# Load Dependencies and Hubbard Brook Info File

In [None]:
import os
import random
import shelve, pickle
import warnings
from datetime import datetime
import numpy as np
import pandas as pd
from scipy.stats import describe
from sklearn.decomposition import PCA, NMF
import matplotlib.pyplot as plt
import seaborn as sns; sns.set_theme()
%matplotlib widget

import hb_dic, nmf_utils

In [None]:
import json

# Parse the Hubbard Brook info file into a plain Python object.
with open("hbinfo.json") as config_file:
    hb_config = json.load(config_file)
# The handle is already closed; remove the name so it does not linger in the
# notebook namespace.
del config_file

# Bare expression: rendered as the cell output for a quick visual check.
hb_config

# NMF for a Single Watershed (Example)

## Load Data and Configurations

Load the data as a ``pd.DataFrame``, then select the time range of interest.

In [None]:
# --- Run settings ------------------------------------------------------------
watershed = hb_config["biogeochem_ref_watershed"]
species = ["DIC", "Ca", "Mg", "K", "Na", "SO4", "Cl", "NO3", "SiO2"]
normalizer = "Na"
bootstrap = 2000
random_seed = 42

# Keyword arguments unpacked into the NMF estimators built in this notebook.
nmf_config = {
    "init": "random",
    "random_state": random_seed,
    "max_iter": 10000,
    "tol": 1e-6,
}

# Extra keyword arguments for the multi-run NMF wrapper.
multinmf_config = {
    "nmf_type": nmf_utils.TrivialRescaledNMF,
    "n_runs": 20000,
}
n_selected = 50

# Keyword arguments forwarded to ``plt.figure`` for every heatmap.
heatmap_fig_config = {
    "figsize": (8, 4),
    "constrained_layout": True,
}

# Load the selected species for the watershed and keep rows from "2000" onward.
# Shaughnessy et al use 2000-2017.
# NOTE(review): the slice has no upper bound, so the end of the range is
# whatever the loaded data contains -- confirm it stops at 2017 if that matters.
df = hb_dic.load_watershed_data(watershed, species)["2000":]

## Data Preprocessing
1. Drop rows with any missing values
2. Convert weight to molar mass
3. Normalize the dataframe by a designated column (default "Na"), then drop that column
4. Divide each row 

In [None]:
# Build the preprocessor from the notebook-level settings rather than a
# repeated literal, so the normalizer actually applied always matches the
# `normalizer` value that is written to the run metadata when results are saved.
preprocessor = nmf_utils.NMFPreprocessor(
    normalizer=normalizer,
    bootstrap=bootstrap,
    bootstrap_random_state=random_seed,
)
# V is the preprocessed data matrix that PCA and NMF are fitted on.
V = preprocessor.transform(df)

## PCA

In [None]:
# Fit a PCA with default settings on the preprocessed matrix; ``fit`` returns
# the estimator itself, so construction and fitting can be chained.
pca = PCA().fit(V)
print("Explained varience ratio:")
print(pca.explained_variance_ratio_)

# Derive the number of endmembers from the fitted PCA (project helper).
n_endmember = nmf_utils.count_endmember(pca)
print("Endmember number (explain >90% ratio):", n_endmember)

In [None]:
# Single NMF run with the endmember count chosen from the PCA step.
nmf = nmf_utils.TrivialRescaledNMF(n_components=n_endmember, **nmf_config)

# W: mixing proportion
W = nmf.fit_transform(V)
# H: chemical signature, labelled with the same columns as V
H = pd.DataFrame(nmf.components_, columns=V.columns)

In [None]:
# Open a new figure with the shared heatmap figure settings.
plt.figure(**heatmap_fig_config)
# Draw the chemical-signature table H with V's column names using the project
# heatmap helper.
# NOTE(review): presumably the helper draws onto the current figure -- confirm
# against nmf_utils.ChemistryHeatmap.
heatmap = nmf_utils.ChemistryHeatmap()
heatmap.plot(H, V.columns)

## MultiNMF

In [None]:
# Silence warnings only for the duration of the multi-run fit.
# ``catch_warnings`` restores the previous filter state on exit -- even if the
# fit raises -- whereas ``warnings.resetwarnings()`` would discard every filter,
# including ones that were configured before this cell ran.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    multi_nmf = nmf_utils.MinStdPickedNMF(
        n_selected=n_selected,
        n_components=n_endmember,
        **multinmf_config,
        **nmf_config,
    )
    # Hs holds the results selected by the multi-run estimator.
    Hs = multi_nmf.fit_transform(V)

In [None]:
# Project helper configured with the number of endmembers.
permuter = nmf_utils.NMFKmeansPermuter(n_endmember=n_endmember)

# Called with inplace=True and the return value is discarded, so Hs itself is
# expected to be modified here.
# NOTE(review): presumably this reorders the endmembers of each run so they line
# up across runs before averaging -- confirm against nmf_utils.
permuter.fit_transform(Hs, inplace=True)
labels = permuter.labels_
# Average over the first axis of Hs and label the result with V's columns.
H_mean = pd.DataFrame(Hs.mean(axis=0), columns=V.columns)
# NOTE(review): presumably undoes the scaling applied by the preprocessor --
# confirm what `scaler_` holds.
H_mean *= preprocessor.scaler_

## Save Data and Figure

In [None]:
# Each run is written to its own timestamped directory.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
new_dir = "../models/output/{}/".format(timestamp)
# makedirs also creates "../models/output/" when it does not exist yet. It
# still raises FileExistsError if this exact run directory already exists, so
# an earlier run is never overwritten silently.
os.makedirs(new_dir)

# Everything needed to reproduce the run, stored next to the results.
config_dict = dict(
    watershed=watershed,
    species=species,
    normalizer=normalizer,
    bootstrap=bootstrap,
    random_seed=random_seed,
    # Coerce to a built-in int: json cannot serialise NumPy integer scalars,
    # and the helper that computes this value may return one.
    n_endmember=int(n_endmember),
    nmf_config=nmf_config,
    n_multinmf=multinmf_config["n_runs"],
    n_selected=n_selected,
    heatmap_fig_config=heatmap_fig_config,
    timestamp=timestamp,
)

with open(os.path.join(new_dir, "meta.txt"), "w") as f:
    json.dump(config_dict, f)

# Persist the per-run results and their mean.
with shelve.open(os.path.join(new_dir, "data")) as db:
    db["raw"] = Hs
    db["mean"] = H_mean

# One PNG per selected run, reusing a single figure that is cleared in between.
plt.figure(**heatmap_fig_config)
heatmap = nmf_utils.ChemistryHeatmap()
# A distinct loop name keeps the notebook-level DataFrame `H` from being
# rebound to an element of Hs, which has no `.columns` when Hs is an ndarray.
for i, H_run in enumerate(Hs):
    heatmap.plot(H_run, V.columns)
    plt.savefig(os.path.join(new_dir, "heatmap_{}.png".format(i)))
    plt.clf()

# Finally the mean signature; its column labels are V's columns.
heatmap.plot(H_mean, V.columns)
plt.savefig(os.path.join(new_dir, "heatmap_mean.svg"))


In [None]:
def filter_sample(v, w, h, idx, lower_bound=0.9, upper_bound=1.1):
    """Check one sample's mixing proportions and report its residual.

    Args:
        v: Table of samples; row ``idx`` is read with ``v.iloc[idx]``.
        w: Mixing proportions; ``w[idx]`` is the proportion vector of the sample.
        h: Matrix multiplied by the proportions to reconstruct the sample.
        idx: Position of the sample to check.
        lower_bound: Smallest accepted sum of the proportions (inclusive).
        upper_bound: Largest accepted sum of the proportions (inclusive).

    Returns:
        A pair ``(accepted, err)``. ``accepted`` tells whether the proportions
        sum to a value within ``[lower_bound, upper_bound]``. ``err`` is the sum
        of the signed differences between the sample and its reconstruction
        ``np.dot(w[idx], h)``, or ``None`` when the sample is not accepted.
        Because the differences are signed, positive and negative deviations
        can cancel each other.
    """
    observed = v.iloc[idx]
    proportions = w[idx]

    total = sum(proportions)
    within_bounds = lower_bound <= total <= upper_bound
    if not within_bounds:
        return within_bounds, None

    reconstructed = np.dot(proportions, h)
    return within_bounds, np.sum(observed - reconstructed)