In [None]:
import pickle
import sys

import pandas as pd
import numpy as np
import plotly.express as px

sys.path.append('../cluster_library')
sys.path.append('../')

import cluster_validation.davies_bouldin as db
import federated_clustering.local_learners as fcll
import federated_clustering.global_learner as fcgl

In [None]:
# Available experiment result pickles; point path_to_experiment_file at the run to analyse
# instead of keeping the alternative as commented-out code.
RESULTS_GLOBAL_DRIFT = './results/global_drift_2023-12-14_22-42-10_results.pkl'
RESULTS_LOCAL_BUT_NO_GLOBAL_DRIFT = './results/local_but_no_global_drift_2023-12-14_21-43-00_results.pkl'

path_to_experiment_file = RESULTS_GLOBAL_DRIFT

In [None]:

# Deserialize the full results dictionary of the selected run.
# NOTE: pickle.load can execute arbitrary code — only open result files you produced yourself.
with open(path_to_experiment_file, mode='rb') as results_file:
    experiment_dict = pickle.load(results_file)

In [None]:
# Show which top-level sections the loaded results dictionary contains.
experiment_dict.keys()

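### Initial clustering

Global cluster centers and their Davies–Bouldin (DB) index from the initial clustering run stored in the results file.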

In [None]:
# Pull the initial global centers and their DB index out of the initial clustering results.
init_global_centers, init_db_index = (
    experiment_dict.get('initial_cluster_results').get(result_key)
    for result_key in ('global_cluster_centers', 'global_db')
)

In [None]:
# Sanity check of the center array; the plots index columns 0 and 1,
# so this is expected to be (n_centers, 2) — confirm for new result files.
init_global_centers.shape

In [None]:
# Plot the initial global centers. Coordinates must be passed as keywords:
# px.scatter's first positional parameter is `data_frame`, so passing the two
# columns positionally does not produce an x-vs-y plot of the centers.
fig_init_centers = px.scatter(
    x=init_global_centers[:, 0],
    y=init_global_centers[:, 1],
    title=f'Initial global cluster centers with DB index of {init_db_index:.4f}',
    symbol_sequence=['x'],
)
fig_init_centers.show()

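### Inspect individual experiments

For each experiment, plot every client's new data together with the initial global centers; the title states whether drift was detected and the recalculated DB index.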

In [None]:
# Identifiers of the individual experiments stored in the results.
experiment_dict.get('experiments').keys()

In [None]:
def _format_db_index(db_index):
    """Format a Davies-Bouldin index for a plot title; 'n/a' when the value is missing."""
    return 'n/a' if db_index is None else f'{db_index:.4f}'


def create_experiment_data_vis(experiment_no, experiment_dict, init_global_centers):
    """Scatter-plot one experiment's new client data together with the initial global centers.

    Parameters
    ----------
    experiment_no : hashable
        Key of the experiment inside ``experiment_dict['experiments']``.
    experiment_dict : dict
        Loaded results. For the chosen experiment, the entries ``'drift_detected'``,
        ``'recalculated_db'`` and ``'new_data_per_client'`` (mapping client number ->
        point data whose columns 0 and 1 are plotted) are read.
    init_global_centers : array-like
        Initial global cluster centers with exactly two columns (x, y).

    Returns
    -------
    plotly Figure
        Points coloured per client (``client_<no>``) and the centers
        (``global_init_centers``); the marker symbol separates client points from centers.
        The title reports the drift flag and the recalculated DB index ('n/a' if missing).

    Raises
    ------
    KeyError
        If ``experiment_no`` is not among the stored experiments.
    """
    experiments = experiment_dict.get('experiments') or {}
    if experiment_no not in experiments:
        raise KeyError(f'experiment {experiment_no!r} not found in experiment_dict["experiments"]')
    dict_one_experiment = experiments[experiment_no]

    title_str = (f'Drift was detected: {dict_one_experiment.get("drift_detected")} '
                 f'with a DB index of {_format_db_index(dict_one_experiment.get("recalculated_db"))}')

    # One frame per client: columns 0/1 become x/y and every row is tagged with its
    # client so plotly assigns one colour per client.
    new_data_per_client = dict_one_experiment.get('new_data_per_client') or {}
    client_frames = [
        pd.DataFrame(client_data)
        .rename(columns={0: 'x', 1: 'y'})
        .assign(color=f'client_{client_no}', marker='client')
        for client_no, client_data in new_data_per_client.items()
    ]

    df_centers = (pd.DataFrame(init_global_centers, columns=['x', 'y'])
                  .assign(color='global_init_centers', marker='center'))

    # The centers frame is always present, so concat never receives an empty list
    # even when no client delivered new data.
    df_vis = pd.concat([*client_frames, df_centers]).reset_index(drop=True)

    return px.scatter(data_frame=df_vis,
                      x='x',
                      y='y',
                      color='color',
                      symbol='marker',
                      title=title_str)

In [None]:
# One figure per stored experiment instead of assuming exactly 100 runs.
# The (integer) keys are sorted so that, when they are 0..N-1, list position
# equals experiment number — the display cells below index the list that way.
list_fig_experiment_data = [
    create_experiment_data_vis(experiment_no,
                               experiment_dict=experiment_dict,
                               init_global_centers=init_global_centers)
    for experiment_no in sorted(experiment_dict.get('experiments'))
]

In [None]:
# Render the figure at position 3 of list_fig_experiment_data (experiment 3).
list_fig_experiment_data[3].show()

In [None]:
# Render the figure at position 2 of list_fig_experiment_data (experiment 2).
list_fig_experiment_data[2].show()