In [None]:
from pathlib import Path
import pandas as pd
import sys

# Make the project root importable so the `configs` and `src` packages resolve.
# The root is taken as the parent of the current working directory, which is only
# correct when the kernel is started from the 'notebooks/' folder.
project_root = Path.cwd().parent
# Guard the append so that re-running this cell does not add duplicate entries
# to sys.path.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))
from configs.path_config import EXTRACTED_DATA_DIR, OUTPUT_DIR

from src.clustering import clustering_preprocess #load_data, drop_columns_by_header_rules, remove_outliers, explain_variance, do_pca
from src.clustering import clustering_models # kmeans_clustering, gmm_clustering, kl_divergence, jeffreys_divergence, merge_clusters_by_divergence, streaming_dpgmm_clustering
from src.clustering import clustering_visualization #plot_clusters_over_time, plot_cluster_mean_and_std
from src.clustering import sankey_diagram #plot_sankey_diagram

### Data loading and preprocessing


In [None]:
# Locate the strain-distribution CSV below EXTRACTED_DATA_DIR and load it.
# Alternative input kept for reference:
# path = OUTPUT_DIR / 'strain_distributions' / 'N-F_Mid_Comp_20091129120000_20210611160000_strain_distribution.csv'
path = EXTRACTED_DATA_DIR.joinpath(
    'strain_distributions',
    'alvbrodel_04',
    'S-C_Close_Comp_20091129120000_20210611160000_strain_distribution_04.csv',
)
df = clustering_preprocess.load_data(path)

In [None]:
# Apply the header-based column filter from clustering_preprocess with threshold=0
# (the meaning of the threshold is defined in that module).
# NOTE(review): `df` is replaced by its own filtered version, so re-running this cell
# on its own feeds the already-filtered frame back in; re-run the load cell first.
df = clustering_preprocess.drop_columns_by_header_rules(
    df,
    threshold=0,
)

In [None]:
# Outlier removal with both thresholds set to 7. The call returns two objects, which
# are unpacked as `df_strain` (presumably the strain columns only - confirm in
# clustering_preprocess) and the updated `df`.
# NOTE(review): `df` is overwritten here as well, so this cell is not idempotent when
# re-run in isolation.
(df_strain, df) = clustering_preprocess.remove_outliers(
    df,
    threshold=7,
    individual_threshold=7,
)

### Explained Variance by Number of Principal Components

In [None]:
# Run the explained-variance analysis on the outlier-filtered `df_strain`
# (see clustering_preprocess.explain_variance for what it plots or prints).
clustering_preprocess.explain_variance(df_strain)

In [None]:
# Number of principal components handed to do_pca.
n_components = 8
# do_pca returns two objects: the normalized PCA components used for clustering
# below, and `df_pca`.
(normalized_pca_components, df_pca) = clustering_preprocess.do_pca(
    n_components,
    df_strain,
    df,
)

### GMM Clustering
Currently disabled: every cell in this section is commented out.

In [None]:
# n_clusters = 10
# data_with_gmm, cluster_color_map  = clustering_models.gmm_clustering(normalized_pca_components, df, n_clusters)
# data_with_gmm

In [None]:
# clustering_visualization.plot_clusters_over_time(data_with_gmm, cluster_color_map, 'GMM')

In [None]:
# clusters_to_keep = ['all'] # 'all' or a list of cluster indices
# clustering_visualization.plot_cluster_mean_and_std(data_with_gmm, clusters_to_keep, cluster_color_map, 'GMM')

### DPGMM Clustering

In [None]:
# Tunable settings for the streaming DPGMM run, gathered in one mapping.
# The durations quoted below come from the original notes and imply one sample per
# day - TODO confirm against the data.
dpgmm_settings = dict(
    prior=0.1,           # How readily new clusters are created; lower values are more restrictive
    n_points=1095,       # Points used for the initial clustering (~3 years)
    window_size=180,     # Sliding-window size (~0.5 years)
    step_size=90,        # Sliding-window step (~3 months)
    max_components=100,  # Upper bound on the number of mixture components
    merge_threshold=7,   # Threshold for merging clusters
)

# The call returns the labelled data, the colour map shared by the plots below, and
# the cluster dictionary consumed by the Sankey section.
(data_with_dpgmm, cluster_color_map, cluster_dict) = clustering_models.streaming_dpgmm_clustering(
    normalized_pca_components=normalized_pca_components,
    df=df,
    **dpgmm_settings,
)

### Plot the Cluster Assignment over Time

In [None]:
# Plot the cluster assignments over time for the DPGMM result, using the colour map
# returned by the clustering step; 'DPGMM' is passed as the method label.
clustering_visualization.plot_clusters_over_time(
    data_with_dpgmm,
    cluster_color_map,
    'DPGMM',
)

In [None]:
# Keep only the rows whose 'Assigned_Cluster_Prob' is strictly above 1e-2.
# NOTE(review): the filtered frame is not used by any later cell visible here; the
# plots below still receive the unfiltered `data_with_dpgmm`.
confident_rows = data_with_dpgmm['Assigned_Cluster_Prob'] > 1e-2
new_data_with_dpgmm = data_with_dpgmm[confident_rows]
# Display the shape of the filtered data.
new_data_with_dpgmm.shape

### Cluster Visualization
Visualizes the representative strain distribution of each cluster in the form of mean strain (dark) and standard deviation of strain (light shade).

In [None]:
# Clusters to draw: a list of cluster indices, or 'all' (per the original note).
clusters_to_keep = [0, 4]
# Plot the mean and standard deviation per selected cluster for the DPGMM result.
clustering_visualization.plot_cluster_mean_and_std(
    data_with_dpgmm,
    clusters_to_keep,
    cluster_color_map,
    'DPGMM',
)

### Plot Sankey Diagram
Visualizes the transitions between clusters, and the merges of clusters, after each step of the sliding window in the clustering algorithm.

In [None]:
# Re-key cluster_dict by datetime.date: each key is parsed with the "%Y-%m-%d" format
# and reduced to its date part; the values are carried over unchanged.
cluster_dict_converted = dict(
    (pd.to_datetime(raw_key, format="%Y-%m-%d").date(), clusters)
    for raw_key, clusters in cluster_dict.items()
)
# Alternative kept for reference: convert only every 12th entry (i % 12 == 0).
# cluster_dict_converted = {
#     pd.to_datetime(k, format="%Y-%m-%d").date(): cluster_dict[k]
#     for i, k in enumerate(cluster_dict)
#     if i % 12 == 0
# }

# Turn the date-keyed cluster dictionary into Sankey links, then into the
# nodes/source/target/value sequences that plot_sankey takes.
links = sankey_diagram.build_sankey_links_from_cluster_dict(cluster_dict_converted)
nodes, source, target, value = sankey_diagram.prepare_sankey_data(links)

# Create the destination folder before asking plot_sankey to save into it; this is
# harmless if plot_sankey already does so. OUTPUT_DIR is combined with `/`, so it is
# treated as a pathlib.Path here.
sankey_save_path = OUTPUT_DIR / 'strain_distributions' / 'DPGMM' / "sankey_diagram_90_day_intervals.pdf"
sankey_save_path.parent.mkdir(parents=True, exist_ok=True)

# NOTE(review): the "90 day" wording in the title and file name mirrors step_size=90
# in the DPGMM cell; keep them in sync when the step size changes.
sankey_diagram.plot_sankey(
    nodes,
    source,
    target,
    value,
    title="Cluster transitions over time (90 day intervals)",
    save_path=sankey_save_path,
    save=True,
)