In [None]:
# Defaults

# System settings
systmpfs = '/tmp'                                  # parent directory for the per-run temporary directories
address = 'auto'                                   # Dask scheduler address, passed to Client()
inputs_path = 'inputs/Matsch_B2/run'               # GEOtop inputs directory the model is built from
geotop_path = '../geotop/build/geotop'             # NOTE(review): not referenced by any cell in this notebook — confirm whether the model should receive it
variables_path = 'inputs/Matsch_B2/variables.csv'  # calibration variables table (columns used: name, type, lower, upper, suggested)

# Optimizer settings
num_workers = 2                 # number of parallel evaluations given to the nevergrad optimizer
budget = 16                     # total number of loss evaluations allowed to the optimizer
algorithm = 'OnePlusOne'        # key into ng.optimizers.registry
timeout = 120                   # per-run limit forwarded to the model's run_args (presumably seconds, as in subprocess)
monitor_interval = 10           # NOTE(review): not referenced by any cell in this notebook
scale = 'D'                     # pandas resampling rule used by the loss metric
startdate = '01/01/2011 00:00'  # first timestamp compared; must match '%d/%m/%Y %H:%M'
targets = ['soil_moisture_content_50', 'sensible_heat_flux_in_air']  # columns entering the loss
weights = [1.0, 1.0]            # one weight per entry of `targets`, same order

In [None]:
from subprocess import CalledProcessError, TimeoutExpired
from tempfile import TemporaryDirectory, NamedTemporaryFile

import numpy as np
from numpy.random import uniform
import pandas as pd
import matplotlib.pyplot as plt

import nevergrad as ng
import hiplot as hip
from SALib.analyze import delta

In [None]:
from os.path import join as joinpath
from threading import Event, Thread
from time import sleep
from collections.abc import Mapping
from datetime import datetime

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import psutil
import geotopy as gtp

def date_parser(x):
    """Parse a day-first ``'DD/MM/YYYY HH:MM'`` timestamp into a ``datetime``.

    Raises ``ValueError`` if ``x`` does not match that format.
    """
    timestamp_format = '%d/%m/%Y %H:%M'
    parsed = datetime.strptime(x, timestamp_format)
    return parsed

class GEOtopRun(gtp.GEOtop):
    """GEOtop model wrapper used by this notebook.

    ``preprocess`` writes a patched ``geotop.inpts`` into the run directory and
    ``postprocess`` gathers selected outputs of the finished run into a single
    DataFrame.
    """

    def preprocess(self, working_dir, *args, **kwargs):
        """Write ``geotop.inpts`` for one run into ``working_dir``.

        The template ``geotop.inpts`` found in ``self.inputs_dir`` is copied
        line by line. The settings held by ``self``, updated with ``kwargs``
        (``kwargs`` take precedence), replace template values they differ from,
        preceded by a comment recording the old value. Settings that never
        appear in the template are appended at the end of the file.

        Parameters
        ----------
        working_dir : str
            Directory in which ``geotop.inpts`` is written.
        *args
            Accepted for interface compatibility; unused.
        **kwargs
            Per-run setting overrides.
        """
        # ``settings`` holds every value for this run; ``pending`` tracks those
        # not yet written. Keeping them separate means a key that occurs more
        # than once in the template (or a template key missing from the
        # settings) cannot raise KeyError, and every occurrence of an
        # overridden key receives the override.
        settings = {**self, **kwargs}
        pending = dict(settings)

        inpts_src_path = joinpath(self.inputs_dir, 'geotop.inpts')
        inpts_dest_path = joinpath(working_dir, 'geotop.inpts')

        with open(inpts_src_path, 'r') as inpts_src, open(inpts_dest_path, 'w') as inpts_dest:
            inpts_dest.write(f"! GEOtop input file written by GEOtoPy {datetime.now().strftime('%x %X')}\n")
            for line in inpts_src:
                if gtp._comment_re.match(line):
                    inpts_dest.write(line)
                    continue

                try:
                    key, value = gtp.read_setting(line)

                    if key in settings and value != settings[key]:
                        inpts_dest.write(f"! GEOtoPy: {key} overwritten, was {value}\n")
                        line = gtp.print_setting(key, settings[key])
                    else:
                        line = gtp.print_setting(key, value)

                    inpts_dest.write(line)
                    pending.pop(key, None)

                except ValueError as err:
                    # read_setting/print_setting rejected the line: record the
                    # reason and keep the template line verbatim.
                    inpts_dest.write(f"! GEOtoPy: {err}\n")
                    inpts_dest.write(line)

            if pending:
                inpts_dest.write("\n! Settings added by GEOtoPy\n")
                for key, value in pending.items():
                    try:
                        inpts_dest.write(gtp.print_setting(key, value))
                    except ValueError as err:
                        inpts_dest.write(f"! GEOtoPy: {err}\n")
                        inpts_dest.write(f"{key} = {value}\n")

    @staticmethod
    def _read_soil_output(path):
        """Read columns 0, 6 and 7 of a GEOtop soil output file.

        Returns a DataFrame indexed by ``datetime`` with the columns
        ``soil_moisture_content_50`` and ``soil_moisture_content_200``;
        ``-9999`` is read as missing and the first line is skipped.
        """
        return pd.read_csv(path,
                           na_values=['-9999'],
                           usecols=[0, 6, 7],
                           skiprows=1,
                           header=0,
                           names=['datetime', 'soil_moisture_content_50', 'soil_moisture_content_200'],
                           parse_dates=[0],
                           date_parser=date_parser,
                           index_col=0,
                           low_memory=False)

    def postprocess(self, working_dir):
        """Collect the outputs of a finished run stored in ``working_dir``.

        Reads ``theta_liq.txt``, ``theta_ice.txt`` and ``point.txt``.

        Returns
        -------
        pandas.DataFrame
            Indexed like ``point.txt`` (index named ``datetime``) with the
            columns rainfall_amount, wind_speed, relative_humidity,
            air_temperature, surface_downwelling_shortwave_flux,
            soil_moisture_content_50, soil_moisture_content_200,
            latent_heat_flux_in_air and sensible_heat_flux_in_air.
        """
        liq = self._read_soil_output(joinpath(working_dir, 'theta_liq.txt'))
        ice = self._read_soil_output(joinpath(working_dir, 'theta_ice.txt'))

        point = pd.read_csv(joinpath(working_dir, 'point.txt'),
                            na_values=['-9999'],
                            parse_dates=[0],
                            date_parser=date_parser,
                            index_col=0,
                            low_memory=False)
        point.index.rename('datetime', inplace=True)

        canopy = point['Canopy_fraction[-]']
        sim = pd.DataFrame(index=point.index)

        # Sum of the rain and snow precipitation columns.
        sim['rainfall_amount'] = point['Prain_over_canopy[mm]'] + point['Psnow_over_canopy[mm]']

        sim['wind_speed'] = point['Wind_speed[m/s]']

        sim['relative_humidity'] = point['Relative_Humidity[-]']

        sim['air_temperature'] = point['Tair[C]']

        sim['surface_downwelling_shortwave_flux'] = point['SWin[W/m2]']

        # Total (ice + liquid) water content for each soil-moisture column.
        for column in ('soil_moisture_content_50', 'soil_moisture_content_200'):
            sim[column] = ice[column] + liq[column]

        # Fluxes: canopy-fraction-weighted mix of the vegetated (ground +
        # canopy) and unvegetated ground contributions.
        sim['latent_heat_flux_in_air'] = \
            canopy * (point['LEg_veg[W/m2]'] + point['LEv[W/m2]']) + \
            (1 - canopy) * point['LEg_unveg[W/m2]']

        sim['sensible_heat_flux_in_air'] = \
            canopy * (point['Hg_veg[W/m2]'] + point['Hv[W/m2]']) + \
            (1 - canopy) * point['Hg_unveg[W/m2]']

        return sim

class observations(Mapping):
    """Read-only mapping from variable name to an observed time series.

    The data are averaged to a fixed resampling ``scale`` at construction;
    ``metric`` scores a simulation against them and ``compare`` plots
    observations and simulation side by side at several time scales.
    """

    def __init__(self, source, scale='D', start=None, end=None):
        """
        Parameters
        ----------
        source : pandas.DataFrame or str
            A DataFrame with a datetime index (left unmodified), or the path of
            a CSV file whose first column holds ``'%d/%m/%Y %H:%M'`` timestamps
            (``-9999`` and ``-99.99`` are read as missing).
        scale : str, default 'D'
            pandas resampling rule used to average the data.
        start, end : str, optional
            Bounds in the same timestamp format, used to slice the data before
            resampling.
        """
        self.scale = scale

        if isinstance(source, pd.DataFrame):
            obs = source
        else:
            obs = pd.read_csv(source,
                              na_values=['-9999', '-99.99'],
                              parse_dates=[0],
                              date_parser=date_parser,
                              index_col=0)

        # rename_axis returns a new object, so a DataFrame passed in by the
        # caller keeps its own index name.
        obs = obs.rename_axis('datetime')

        # A None bound leaves that end of the slice open.
        lower = date_parser(start) if start else None
        upper = date_parser(end) if end else None
        obs = obs.loc[lower:upper]

        self.data = obs.resample(scale).mean()

        # Mean of the squared observations, used to normalise ``metric``.
        self.mean_square = (self.data * self.data).mean()

    def __getitem__(self, key):

        return self.data[key]

    def __len__(self):

        return len(self.data)

    def __iter__(self):

        return iter(self.data)

    def compare(self, target, simulation, scales=None, desc=None, unit=None, rel=False, figsize=(16,9), dpi=100):
        """Plot observations vs. simulation for ``target`` at several time scales.

        Each entry of ``scales`` (label -> pandas resampling rule; default
        daily, weekly and monthly) produces one row of three panels: both
        series, their error ``obs - sim`` (divided by ``|obs|`` when ``rel``
        is true), and a histogram of the error sharing the error panel's
        y-limits.

        Returns
        -------
        matplotlib.figure.Figure
        """
        if not scales:
            scales = {'Daily': 'D', 'Weekly': 'W', 'Monthly': 'M'}

        # squeeze=False keeps ``axes`` 2-D even when a single scale is
        # requested, so ``axes[i, :]`` below always works.
        fig, axes = plt.subplots(ncols=3,
                                 nrows=len(scales),
                                 figsize=figsize,
                                 dpi=dpi,
                                 constrained_layout=True,
                                 squeeze=False)

        if desc:
            fig.suptitle(desc)

        for i, (Tstr, T) in enumerate(scales.items()):
            comp_plot, diff_plot, hist_plot = axes[i, :]

            obs_resampled = self[target].resample(T).mean()
            sim_resampled = simulation[target].resample(T).mean()

            err = obs_resampled - sim_resampled
            if rel:
                err = err / obs_resampled.abs()

            data = pd.DataFrame({'Observations': obs_resampled, 'Simulation': sim_resampled})
            sns.lineplot(data=data, ax=comp_plot)
            comp_plot.set_title(Tstr)
            comp_plot.set_xlabel('')
            if unit:
                comp_plot.set_ylabel(f"[{unit}]")

            sns.lineplot(data=err, ax=diff_plot)
            plt.setp(diff_plot.get_xticklabels(), rotation=20)
            diff_plot.set_xlabel('')
            if rel:
                diff_plot.set_ylabel("Relative error")
                diff_plot.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1.0))
            elif unit:
                diff_plot.set_ylabel(f"Error [{unit}]")
            else:
                diff_plot.set_ylabel("Error")

            sns.histplot(y=err, kde=True, stat='probability', ax=hist_plot)
            y1, y2 = diff_plot.get_ylim()
            hist_plot.set_ylim(y1,y2)
            hist_plot.set_yticklabels([])
            hist_plot.set_ylabel('')

        return fig

    def metric(self, target, simulation):
        """Normalised RMSE of ``simulation[target]`` against the observations.

        The simulation is averaged to ``self.scale`` and aligned with the
        observations on the index; the result is
        ``sqrt(mean((obs - sim)**2) / mean(obs**2))`` with missing values
        skipped by the means.
        """
        diff = self[target] - simulation[target].resample(self.scale).mean()

        return np.sqrt((diff * diff).mean() / self.mean_square[target])

    
class monitor:
    """Record system-wide CPU and memory usage from a background thread.

    Sampling starts as soon as the object is created and repeats every
    ``interval`` seconds until ``stop()`` (or ``plot()``, which stops first)
    is called.

    Parameters
    ----------
    interval : float
        Seconds between two samples.
    """
    def __init__(self, interval):
        # Parallel lists, one entry per sample.
        self.datetime = []
        self.cpu_usage = []
        self.memory_usage = []
        self.interval = interval
        self.running = False
        # Set by stop() so the sampling thread wakes up immediately instead of
        # finishing a full ``interval`` sleep before noticing it must exit.
        self._stop_event = Event()

        # Daemon thread: it does not keep the interpreter alive on exit.
        self.thread = Thread(target=self.run, args=())
        self.thread.daemon = True
        self.start()

    def sample(self):
        """Append one timestamped CPU and memory usage reading (percent)."""
        self.datetime.append(datetime.now())
        self.cpu_usage.append(psutil.cpu_percent())
        self.memory_usage.append(psutil.virtual_memory().percent)

    def run(self):
        """Thread body: sample every ``interval`` seconds until stopped."""
        while self.running:
            self.sample()
            # wait() returns True as soon as stop() sets the event.
            if self._stop_event.wait(self.interval):
                break

    def start(self):
        """Start the sampling thread (a Thread can only be started once)."""
        self.running = True
        self.thread.start()

    def stop(self):
        """Stop sampling and wait for the thread to finish."""
        self.running = False
        self._stop_event.set()
        self.thread.join()

    def plot(self, figsize=(16,9), dpi=100):
        """Stop sampling, then draw CPU and memory usage over time."""
        self.stop()
        stats = pd.DataFrame({'CPU': self.cpu_usage, 'Memory': self.memory_usage}, index=self.datetime)
        fig = plt.figure(figsize=figsize, dpi=dpi)
        axes = fig.add_subplot()
        axes.set_title('Resource Monitor')
        axes.set_xlabel('Time')
        axes.set_ylabel('Usage [%]')
        sns.lineplot(data=stats, ax=axes)

In [None]:
from dask.distributed import Client

In [None]:
# Connect to the Dask scheduler configured in the settings cell; the optimizer
# later uses this client as its executor.
client = Client(address=address)

In [None]:
class GEOtopRunLogVars(GEOtopRun):
    """``GEOtopRun`` whose log-type calibration variables arrive in log10 space.

    Relies on the module-level ``variables`` table (loaded in a later cell):
    every setting whose ``type`` there is ``'log'`` is written to the input
    file as ``10 ** value``.
    """

    def preprocess(self, working_dir, *args, **kwargs):
        """Convert log10-space variables back to linear scale, then delegate."""
        # Settings that are not listed in ``variables`` at all (e.g. extra
        # fixed overrides) pass through unchanged instead of raising KeyError.
        log_names = set(variables.index[variables['type'] == 'log'])
        linear_kwargs = {key: 10 ** value if key in log_names else value
                         for key, value in kwargs.items()}

        super().preprocess(working_dir, *args, **linear_kwargs)

In [None]:
model = GEOtopRunLogVars(inputs_path,
                         run_args={'check': True,
                                   'capture_output': True,
                                   'timeout': timeout})

variables = pd.read_csv(variables_path, index_col='name')

# Synthetic "true" parameter values, drawn uniformly within each variable's
# [lower, upper] bounds. Set ``synth_seed`` to an integer to make the synthetic
# truth (and therefore the whole experiment) reproducible; None keeps it random.
synth_seed = None
synth_rng = np.random.default_rng(synth_seed)
variables['synth'] = synth_rng.uniform(low=variables.lower.to_numpy(),
                                       high=variables.upper.to_numpy())

In [None]:
def loss_function(*args, sim=None, **kwargs):
    """Weighted mean of ``synth.metric`` over ``targets`` for one simulation.

    When ``sim`` is None the model is run in a fresh temporary directory under
    ``systmpfs`` with ``*args``/``**kwargs``. If that run raises
    ``CalledProcessError`` or ``TimeoutExpired``, NaN is returned so the
    optimization keeps going (NaN losses are dropped before the sensitivity
    analysis).
    """
    if sim is None:
        with TemporaryDirectory(dir=systmpfs) as tmpdir:
            try:
                sim = model.eval(tmpdir, *args, **kwargs)
            except (CalledProcessError, TimeoutExpired):
                return np.nan

    weighted_errors = (weight * synth.metric(target, sim)
                       for weight, target in zip(weights, targets))
    return sum(weighted_errors) / sum(weights)

In [None]:
# Run the model once with the synthetic "true" parameters; its output plays the
# role of the observations that the optimizer tries to reproduce.
with TemporaryDirectory(dir=systmpfs) as tmpdir:
    synth_run = model.eval(tmpdir, **variables.synth.to_dict())

synth = observations(synth_run, scale=scale, start=startdate)

In [None]:
# Baseline: run the model with its default settings (no overrides) and compare
# it with the synthetic observations before any calibration.
with TemporaryDirectory(dir=systmpfs) as tmpdir:
    sim = model.eval(tmpdir)

print(f"Before optimization loss is {loss_function(sim=sim)}")
for target in targets:
    synth.compare(target, sim, desc=target)
    plt.show()

In [None]:
# One bounded scalar per calibration variable, starting from its suggested value.
kwargs = {
    name: ng.p.Scalar(init=row.suggested, lower=row.lower, upper=row.upper)
    for name, row in variables.iterrows()
}
parametrization = ng.p.Instrumentation(**kwargs)

optimizer_cls = ng.optimizers.registry[algorithm]
optimizer = optimizer_cls(parametrization=parametrization,
                          budget=budget,
                          num_workers=num_workers)

# Log every candidate passed to ``tell`` into a temporary file for later
# analysis. ``logfile`` stays referenced so the file is not closed/deleted.
logfile = NamedTemporaryFile(dir=systmpfs)
logger = ng.callbacks.ParametersLogger(logfile.name)
optimizer.register_callback("tell", logger)

In [None]:
# Run the optimization on the Dask cluster. batch_mode=False: a new evaluation
# is launched as soon as any running one finishes, instead of waiting for a
# whole batch of num_workers evaluations.
recommendation = optimizer.minimize(
    loss_function,
    executor=client,
    batch_mode=False,
)

In [None]:
# Sensitivity analysis (SALib delta method) over the logged evaluations.
# Failed runs were logged with a NaN loss and are excluded.
samples = pd.DataFrame(logger.load()).dropna(subset=['#loss'])
points = samples[variables.index].to_numpy()
losses = samples['#loss'].to_numpy()

problem = dict(
    num_vars=len(variables),
    names=variables.index,
    bounds=list(zip(variables.lower, variables.upper)),
)

SA = delta.analyze(problem, points, losses)

In [None]:
# Compare calibrated values with the synthetic truth. ``err`` is the absolute
# deviation divided by the variable's range, times 3 (factor not explained here).
variables['best'] = pd.Series(recommendation.kwargs)
value_range = variables.upper - variables.lower
variables['err'] = 3 * (variables.synth - variables.best).abs() / value_range

pd.concat([variables, SA.to_df()], axis=1).sort_values('err')

In [None]:
# Convert the candidates logged during optimization into a HiPlot experiment
# for interactive exploration.
experiment = logger.to_hiplot_experiment()

In [None]:
# Log columns that are bookkeeping rather than calibration parameters or loss.
optimizer_options = ['noise_handling', 'mutation', 'crossover', 'initialization', 'scale',
                     'recommendation', 'F1', 'F2', 'popsize', 'propagate_heritage']
hidden_columns = (['uid', 'from_uid', '#parametrization', '#optimizer']
                  + ['#optimizer#' + option for option in optimizer_options]
                  + ['#session', '#lineage', '#meta-sigma'])

# Also hide the per-variable '#sigma' and '#sigma#sigma' columns.
for var_name in variables.index:
    hidden_columns += [var_name + '#sigma', var_name + '#sigma#sigma']

table = experiment.display_data(hip.Displays.TABLE)
table.update({'hide': hidden_columns,
              'order_by': [['#num-tell', 'asc']]})

plot = experiment.display_data(hip.Displays.PARALLEL_PLOT)
plot.update({'hide': hidden_columns + ['#num-tell'],
             'order': ['#generation', *variables.index, '#loss']})

In [None]:
# Render the interactive HiPlot view (table and parallel-coordinates plot).
experiment.display()

In [None]:
# Re-run the model with the recommended parameters. Use the same temporary
# location as every other run (systmpfs), and compute the loss exactly as in
# the "before" cell so the two numbers are directly comparable.
with TemporaryDirectory(dir=systmpfs) as tmpdir:
    sim = model.eval(tmpdir, **recommendation.kwargs)

print(f"After optimization loss is {loss_function(sim=sim)}")
for t in targets:
    synth.compare(t, sim, desc=t)
    plt.show()

In [None]:
# Mean normalised parameter error across all calibration variables.
variables['err'].mean()