In [None]:
import re

import pandas as pd
import numpy as np
from scipy.stats import pearsonr
import seaborn as sns
from matplotlib import pyplot as plt

from ipywidgets import widgets, interact, interactive_output
from IPython.display import display

In [None]:
# Load the WWTP-influence table and index its rows by the 'Sample' column.
df = pd.read_csv('../data/WWTP_influence_IOW.csv').set_index('Sample')

# Shorten the column labels: strip the shared prefix, then turn every double
# underscore into a line break (presumably so long labels wrap in the plots).
df.columns = (
    df.columns
    .str.removeprefix('WWTP_influence_as_')
    .str.replace('__', '\n', regex=False)
)

# Min-max scale every column to the range [0, 1].
# Note: a constant column (max == min) becomes all-NaN here (0 / 0).
col_min = df.min()
col_max = df.max()
df = (df - col_min) / (col_max - col_min)

In [None]:
## function to plot column labels on the diagonal (doesn't work when filtering df columns interactively)
# it = iter(list(df.columns))
# def diagfunc(*args, **kws):
#     plt.gca().annotate(next(it), xy=(0.05, 0.6), xytext=(5,-5), ha="left", va="top", 
#                        xycoords=plt.gca().transAxes, textcoords="offset points")


def reg_coef(x,y,label=None,color=None,**kwargs):
    """
    Draw the Pearson correlation of x and y as a shaded circle on the current axes.

    Meant to be used as a ``PairGrid.map_upper`` callback.

    - The text colour is the grayscale string of round(|r|): '0' (black) or '1' (white).
    - The circle's shade is 1 - |r|, so it gets darker as |r| grows.
    - Its padding is 5 * |r|**4, so it also gets larger as |r| grows.
    - The axes frame is hidden.

    Pairs in which x or y is not finite are dropped before r is computed. If fewer
    than two pairs remain, or r is undefined (e.g. one variable is constant), an
    'n/a' label is drawn instead of failing on a NaN coefficient.

    :param x: first variable (array-like of numbers)
    :param y: second variable (array-like of numbers, same length as x)
    :param label: ignored; accepted so keyword arguments passed by the grid mapping are tolerated
    :param color: ignored; accepted for the same reason
    :param kwargs: ignored
    """
    ax = plt.gca()
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    # NaN in the inputs makes r NaN, so correlate only complete, finite pairs.
    finite = np.isfinite(x) & np.isfinite(y)
    r = np.nan
    if np.count_nonzero(finite) >= 2:  # pearsonr needs at least two observations
        r, _ = pearsonr(x[finite], y[finite])

    if np.isnan(r):
        # round(nan) would raise ValueError and 'nan' is not a valid colour,
        # so show a neutral placeholder instead.
        ax.annotate(
            'r =\nn/a',
            fontfamily='monospace',
            xy=(0.5,0.5),
            xycoords='axes fraction',
            ha='center',
            annotation_clip=True)
        ax.set_axis_off()
        return

    strength = abs(r)
    ax.annotate(
        f'r =\n{r:.2f}',
        fontweight='extra bold',
        fontfamily='monospace',
        color=f'{round(strength)}',  # grayscale: '0' black / '1' white for contrast
        xy=(0.5,0.5),
        xycoords='axes fraction',
        ha='center',
        # circle grows and darkens with the strength of the correlation
        bbox={'boxstyle': f'Circle, pad={5*strength**4}', 'color': f'{1-strength}'},
        annotation_clip=True)
    ax.set_axis_off()

def scattermatrix(df, param_type, param_sed, param_size, param_season, param_buffer,
                  savepath='../plots/corrmatrix_wwtpinfluence.svg'):
    """
    Select columns of a data frame by name fragments and draw a correlation pair plot.

    A column is kept only if, for EVERY parameter group, its label contains at least
    one of that group's selected values. A value counts as present only when it is not
    directly preceded by a letter or digit, so e.g. 'sed' does not match inside 'nosed'.

    The kept columns are drawn as a seaborn PairGrid:
    - histograms on the diagonal;
    - ``reg_coef`` annotations in the upper triangle;
    - ``sns.regplot`` in the lower triangle.
    The figure is then saved to ``savepath``.

    :param df: pandas data frame with string column labels
    :param param_type: iterable of type fragments to accept
    :param param_sed: iterable of sedimentation fragments to accept
    :param param_size: iterable of size fragments to accept
    :param param_season: iterable of season fragments to accept
    :param param_buffer: iterable of buffer values to accept (converted to str)
    :param savepath: file the figure is written to; ``None`` skips saving
    :return: the filtered data frame. If no column matches, it has no columns and
        nothing is plotted (PairGrid cannot be built without variables).
    """
    groups = [param_type, param_sed, param_size, param_season, param_buffer]
    # Convert every selected value to str (buffer values arrive as ints) and build one
    # pattern per value: the value must not follow a letter/digit (prevents 'sed' ~ 'nosed').
    patterns = [
        [re.compile(r'(?<![A-Za-z0-9])' + re.escape(str(arg))) for arg in group]
        for group in groups
    ]

    # Boolean mask over the columns: each group must have at least one matching pattern.
    keep = [
        all(any(p.search(str(col)) for p in group) for group in patterns)
        for col in df.columns
    ]
    df = df.loc[:, keep]

    if df.shape[1] == 0:
        # Nothing selected (e.g. an empty widget selection): sns.PairGrid would raise.
        return df

    g = sns.PairGrid(df)
    g.map_diag(sns.histplot)
    g.map_upper(reg_coef)
    g.map_lower(sns.regplot)
    if savepath is not None:
        g.savefig(savepath)
    return df

In [None]:
def _multi_select(options, value, description):
    """Build an enabled SelectMultiple widget tall enough to show all of its options."""
    return widgets.SelectMultiple(
        options=options,
        value=value,
        rows=len(options),
        description=description,
        disabled=False,
    )


# One multi-select per group of column-name fragments; their current selections
# are passed to `scattermatrix` as the filter values.
param_type = _multi_select(
    ['tracer_mean_dist', 'endpoints_mean_dist', 'cumulated_residence', 'mean_time_travelled'],
    ['mean_time_travelled'],
    'Type',
)
param_sed = _multi_select(['nosed', 'sed'], ['nosed'], 'Sedimentation')
param_size = _multi_select(['0µm', '18µm', 'allsizes'], ['allsizes'], 'Size')
param_season = _multi_select(['spring', 'summer', 'autumn', 'allseasons'], ['allseasons'], 'Season')
param_buffer = _multi_select([222, 444], [222, 444], 'Buffer')

# Keyword name of `scattermatrix` -> widget that supplies it (order = layout order).
filter_controls = {
    'param_type': param_type,
    'param_sed': param_sed,
    'param_size': param_size,
    'param_season': param_season,
    'param_buffer': param_buffer,
}

ui = widgets.HBox(list(filter_controls.values()))
# Re-run `scattermatrix` on the fixed data frame whenever any selection changes.
out1 = interactive_output(scattermatrix, {'df': widgets.fixed(df), **filter_controls})

display(ui, out1)