In [None]:
# Default parameters

# Directory under which each model run gets its own temporary working directory
systmpfs = '/tmp'

# Sensitivity-analysis settings
target = 'soil_moisture_content_50'  # name handed to obs.metric when scoring a run
SASampleN = 2                        # base sample count N; Nsamples = N x (2D + 2)
timeout = 30                         # forwarded as the run 'timeout' (presumably seconds -- confirm)
num_workers = 2                      # worker processes for model runs and the Sobol analysis

In [None]:
from subprocess import CalledProcessError, TimeoutExpired
from tempfile import TemporaryDirectory
from concurrent.futures import ProcessPoolExecutor

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from SALib.sample import saltelli
from SALib.analyze import sobol

from common import observations, GEOtopRun

In [None]:
class GEOtopRunLogVars(GEOtopRun):
    """GEOtopRun variant that takes log-typed parameters as base-10 exponents.

    Keyword arguments whose row in the module-level ``variables`` table has
    ``type == 'log'`` are replaced by ``10 ** value`` before being forwarded
    to ``GEOtopRun.preprocess``; every other keyword argument is passed
    through unchanged.
    """

    def preprocess(self, working_dir, *args, **kwargs):
        # `variables` is looked up at call time, so it only has to exist
        # before the first evaluation; an unknown parameter name raises here.
        converted = {
            name: 10 ** exponent if variables.type[name] == 'log' else exponent
            for name, exponent in kwargs.items()
        }

        super().preprocess(working_dir, *args, **converted)

In [None]:
# Model wrapper around the GEOtop executable; `run_args` carries the
# per-run options, including the configured timeout.
model = GEOtopRunLogVars(
    'inputs/Matsch_B2/run',
    exe='../geotop/build/geotop',
    run_args=dict(check=True, capture_output=True, timeout=timeout),
)

# Observations used to score each simulation
obs = observations('inputs/Matsch_B2/obs.csv')

# Parameter table indexed by parameter name; this notebook reads its
# `type`, `lower` and `upper` columns.
variables = pd.read_csv('inputs/Matsch_B2/variables.csv', index_col='name')

In [None]:
def loss_function(*args, **kwargs):
    """Run the model once in a temporary directory and score the result.

    All positional and keyword arguments are forwarded to ``model.eval``
    after the temporary directory path.

    Returns ``obs.metric(target, sim)``, or ``np.nan`` when the run raises
    ``CalledProcessError`` or ``TimeoutExpired``.
    """
    with TemporaryDirectory(dir=systmpfs) as scratch_dir:
        try:
            sim = model.eval(scratch_dir, *args, **kwargs)
        except (CalledProcessError, TimeoutExpired):
            # A failed or timed-out run has no simulation to score
            return np.nan

    # < (y_obs - y_sim)^2 > / < y_obs^2  >
    return obs.metric(target, sim)

In [None]:
def SA_loss(xs):
    """Evaluate the loss for one sample row.

    The values in ``xs`` are paired, in order, with the parameter names in
    ``variables.index`` and passed to ``loss_function`` as keyword arguments.
    """
    named_values = dict(zip(variables.index, xs))
    return loss_function(**named_values)

# SALib problem definition: one variable per row of `variables`
problem = {'num_vars': variables.shape[0],
           'names': variables.index,
           'bounds': list(zip(variables.lower, variables.upper))}

# Saltelli design; with second-order terms enabled this has N x (2D + 2) rows
samples = saltelli.sample(problem, SASampleN, calc_second_order=True)

# Evaluate every sample row in a pool of worker processes.
# Executor.map yields results in input order, so losses[i] belongs to samples[i].
with ProcessPoolExecutor(max_workers=num_workers) as executor:
    losses = np.fromiter(executor.map(SA_loss, samples), dtype=float)

# loss_function returns NaN for failed or timed-out runs. Rows cannot simply
# be dropped (the analysis needs the complete sample design), so make the
# problem visible instead of letting NaNs silently flow into the indices.
failed_runs = int(np.isnan(losses).sum())
if failed_runs:
    print(f'WARNING: {failed_runs} of {losses.size} model runs returned NaN '
          '(failed or timed out); Sobol indices computed from these losses '
          'will not be reliable.')

In [None]:
# Sobol sensitivity indices, including second-order terms, from the evaluated losses
SA = sobol.analyze(
    problem,
    losses,
    calc_second_order=True,
    parallel=True,
    n_processors=num_workers,
)

In [None]:
# Tabulate the per-parameter results; the pairwise S2 matrices are left out
S1 = pd.DataFrame(
    {name: values for name, values in SA.items() if name not in ('S2', 'S2_conf')},
    index=problem['names'],
)
# Rank parameters by the magnitude of their first-order index
S1.sort_values(by='S1', key=np.abs, ascending=False)

In [None]:
# Work on a copy and mirror the upper triangle into the lower one so the
# matrix is symmetric; the diagonal is left as it is.
S2 = SA['S2'].copy()
lower = np.tril_indices_from(S2, k=-1)
S2[lower] = S2.T[lower]

# Annotated heatmap of the pairwise (second-order) indices
f, ax = plt.subplots(figsize=(16, 9), dpi=100)
sns.heatmap(S2,
            ax=ax,
            annot=True,
            fmt='.3f',
            linewidths=0.1,
            linecolor='grey',
            xticklabels=variables.index,
            yticklabels=variables.index)