In [None]:
# Defaults
# Base Saltelli sample size N. With second-order indices enabled, SALib's
# Saltelli sampler generates N * (2D + 2) parameter sets, where D is the
# number of calibrated variables. Kept as a power of two for the Sobol'
# sequence underlying the sampler.
SASampleN = 2 ** 8

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from dask.distributed import SLURMCluster, Client
from joblib import Parallel, delayed, parallel_backend
from SALib.sample import saltelli
from SALib.analyze import sobol

from calibrations.matsch_b2 import loss, variables

In [None]:
# Start a SLURM-backed Dask cluster and attach a Client to it, so that the
# joblib 'dask' backend used for the loss evaluations can dispatch work here.
# NOTE(review): SLURMCluster is provided by the dask_jobqueue package
# (`from dask_jobqueue import SLURMCluster`), not dask.distributed; the import
# cell needs fixing for this line to run.
# NOTE(review): SLURMCluster() with no arguments relies on the jobqueue
# configuration (queue, cores, memory, walltime) -- confirm it is set.
cluster = SLURMCluster()
# Ask the cluster for SASampleN cores in total across SLURM jobs.
cluster.scale(cores=SASampleN)

client = Client(cluster)

In [None]:
def SA_loss(xs):
    """Evaluate the calibration loss for one Saltelli sample.

    Parameters
    ----------
    xs : sequence of float
        One row of ``samples``. Values are matched positionally to
        ``variables.index``, the same order used for ``problem['names']``.

    Returns
    -------
    Whatever ``loss`` returns for these keyword arguments (expected to be a
    scalar, since the results are collected into a float array).
    """
    # Call the `loss` imported from calibrations.matsch_b2; the name
    # `loss_function` is not defined anywhere in this notebook.
    # NOTE(review): assumes `loss` is the callable loss function -- confirm.
    return loss(**dict(zip(variables.index, xs)))

# SALib problem definition. SALib documents 'names' as a plain list, so
# convert the pandas Index to avoid ambiguous-truth-value checks on it.
problem = {'num_vars': variables.shape[0],
           'names': list(variables.index),
           'bounds': list(zip(variables.lower, variables.upper))}

# calc_second_order=True yields SASampleN * (2 * num_vars + 2) rows. It must
# match the flag later passed to sobol.analyze.
samples = saltelli.sample(problem, SASampleN, calc_second_order=True)

In [None]:
# Evaluate the loss for every Saltelli sample on the Dask cluster via joblib.
# Parallel returns a list in the same order as `samples`.
with parallel_backend('dask'):
    losses = np.asarray(Parallel(verbose=10)(delayed(SA_loss)(x) for x in samples), dtype=float)

# Sobol analysis needs one finite output per sample. A NaN/inf from a failed
# model run would otherwise propagate silently into the Sobol indices.
assert np.isfinite(losses).all(), (
    f"{(~np.isfinite(losses)).sum()} non-finite losses; "
    "inspect samples[~np.isfinite(losses)]"
)

In [None]:
# Sobol first-order (S1), total (ST) and second-order (S2) indices with
# bootstrap confidence intervals. calc_second_order must match the flag used
# in saltelli.sample. The bootstrap resampling is random, so a fixed (non-zero)
# seed makes the *_conf values reproducible between runs.
SA_SEED = 42
SA = sobol.analyze(problem, losses, calc_second_order=True, parallel=True, seed=SA_SEED)

In [None]:
# One row per variable holding the 1-D results (S1, ST and their confidence
# intervals). The 2-D second-order matrices are excluded from this table.
S1 = pd.DataFrame(
    {name: vals for name, vals in SA.items() if name not in ('S2', 'S2_conf')},
    index=problem['names'],
)
# Rank variables by the magnitude of their first-order index.
S1.sort_values(by='S1', ascending=False, key=lambda col: col.abs())

In [None]:
# SA['S2'] holds the pairwise indices in its upper triangle (row < column).
# Mirror it into the strict lower triangle so the heatmap is symmetric.
# np.tri(..., k=-1) selects the strict lower triangle; the diagonal and upper
# triangle are taken unchanged. This builds a new array, so SA is untouched.
S2 = np.where(np.tri(len(SA['S2']), k=-1, dtype=bool), SA['S2'].T, SA['S2'])

fig, ax = plt.subplots(figsize=(16, 9), dpi=100)
sns.heatmap(S2,
            xticklabels=variables.index,
            yticklabels=variables.index,
            annot=True,
            fmt='.3f',
            linewidths=0.1,
            linecolor='grey',
            ax=ax)
# Trailing ';' suppresses the Text repr; the figure itself still renders.
ax.set_title('Second-order Sobol indices (S2)');