In [None]:
import pandas as pd
import matplotlib.pyplot as plt

from dask.distributed import Client
from dask_jobqueue import SLURMCluster

import nevergrad as ng
from mhpc_project.matsch_b2 import CalibrationModel, Variables, Loss, Calibration
from geotopy.utils import date_parser, comparison_plot, DataFrameLogger
from geotopy.measures import KGE

In [None]:
# Calibration configuration handed to `Calibration` further down.
# `optimizer.num_workers` is also the number of cores requested from the cluster.
settings = dict(
    parametrization=dict(
        mutable_sigma=True,
        init_sigma=0.1,
        lower=0.0,
        upper=1.0,
    ),
    optimizer=dict(
        budget=4096,
        num_workers=256,
    ),
)

In [None]:
# Start a SLURM-backed Dask cluster, request as many cores as the optimizer
# has workers, and attach a client to it.
cluster = SLURMCluster()
worker_cores = settings['optimizer']['num_workers']
cluster.scale(cores=worker_cores)
client = Client(cluster)

In [None]:
# Load the observed series: column 0 is parsed as the datetime index, column 7
# holds the observed values, and '-9999' / '-99.99' are treated as missing.
#
# `read_csv(squeeze=True)` was deprecated in pandas 1.4 and removed in 2.0, so
# the single remaining column is squeezed into a Series after reading instead;
# this form works on both old and new pandas versions.
#
# NOTE(review): `date_parser=` is deprecated since pandas 2.0 in favour of
# `date_format=` / `pd.to_datetime`. It is kept because the timestamp format
# handled by geotopy's `date_parser` is not visible from this notebook.
observations = (
    pd.read_csv('../data/Matsch B2/obs.csv',
                na_values=['-9999', '-99.99'],
                usecols=[0, 7],
                parse_dates=[0],
                date_parser=date_parser,
                index_col=0)
    .squeeze('columns')
    .rename_axis('datetime')
)

# Wire up the calibration: the model to run, the variables description, the
# KGE measure scored against the observations, and the loss that ties the
# three together.
# NOTE(review): the unit of `timeout` is not visible here (presumably seconds) - confirm.
model = CalibrationModel('../data/Matsch B2/geotop', run_args={'timeout': 120})
variables = Variables('../data/Matsch B2/variables.csv')
measure = KGE(observations)
loss = Loss(model, variables, measure)
calibration = Calibration(loss, settings)

In [None]:
# Baseline: run the model without arguments, report the measure on its
# output, and plot it against the observations.
simulation = model()
print(f"Before optimization loss is {measure(simulation)}")
plot_description = 'Soil moisture content @ 5cm'
comparison_plot(observations, simulation, desc=plot_description)
plt.show()

In [None]:
# Attach two 'tell' callbacks to the optimizer: a progress bar, and a
# DataFrameLogger whose record is inspected after the run.
optimizer = calibration.optimizer
optimizer.register_callback('tell', ng.callbacks.ProgressBar())
logger = DataFrameLogger(variables)
optimizer.register_callback('tell', logger)

In [None]:
# Run the calibration. It returns two values: the loss reported below as the
# post-optimization loss, and a mapping that is unpacked into `model(**settings)`.
# NOTE(review): this rebinds `loss` (until now the `Loss` instance) and
# `settings` (until now the configuration dict), so the earlier cells that read
# those names cannot be re-run after this one without restarting from the top.
loss, settings = calibration()


In [None]:
# Take the experiment object exposed by the DataFrameLogger callback and render it.
# NOTE(review): presumably this holds what the logger recorded on each 'tell'
# during the calibration run - confirm against geotopy's DataFrameLogger.
experiment = logger.experiment
experiment.display()

In [None]:
# Re-run the model with the values returned by the calibration, report the
# loss it returned, and plot the new simulation against the observations.
simulation = model(**settings)
print(f"After optimization loss is {loss}")
plot_description = 'Soil moisture content @ 5cm'
comparison_plot(observations, simulation, desc=plot_description)
plt.show()