## Modeling Profiler

This notebook uses the ESAT simulator to evaluate potential approaches for optimizing the modeling of very large datasets.

The first approach will look at implementing and validating the following workflow:
1. Create a subset dataset of the input by randomly selecting N values from the input/uncertainty.
2. Train a single model on that data until convergence.
3. Use the factor profile H matrix to calculate a W for the complete dataset.
4. Calculate Q(full)
5. Take a new subset of the data, restart training with the prior H.
6. Repeat until Q(full) is no longer decreasing.

Then run a model on the full dataset with the same random seed and evaluate the differences in loss and factor profiles.

#### Code Imports

In [None]:
from esat.data.datahandler import DataHandler
from esat.model.batch_sa import BatchSA
from esat.model.sa import SA
from esat.data.analysis import ModelAnalysis, BatchAnalysis
from esat_eval.simulator import Simulator
from esat.estimator import FactorEstimator
from esat_eval.factor_catalog import FactorCatalog, Factor

from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import min_weight_full_bipartite_matching
from sklearn.manifold import TSNE, MDS
import tensorflow as tf
from tensorflow.keras.models import Model as kModel
from tensorflow.keras.layers import Input, Dense

from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.colors as pc
import plotly.io as pio
import logging
import time
import pandas as pd
import numpy as np
import copy
import os

logger = logging.getLogger(__name__)

#### Synthetic Dataset

Generate a synthetic dataset where the factor profiles and contributions are pre-determined for model output analysis.

In [None]:
# Synthetic dataset parameters
seed = 3                       # Random seed passed to the Simulator
syn_factors = 5                # Number of factors in the synthetic dataset
syn_features = 40              # Number of features in the synthetic dataset
syn_samples = 10000             # Number of samples in the synthetic dataset
outliers = True                # Add outliers to the dataset
outlier_p = 0.2               # Fraction (0-1) of outliers in the dataset
outlier_mag = 3                # Magnitude of outliers
contribution_max = 2           # Maximum value of the contribution matrix (W) (randomly sampled from a uniform distribution)
noise_mean_min = 0.4          # Min value for the mean of the noise added to the synthetic dataset; used to randomly determine the mean noise fraction for each feature.
noise_mean_max = 0.5          # Max value for the mean of the noise added to the synthetic dataset; used to randomly determine the mean noise fraction for each feature.
noise_scale = 0.25              # Scale of the noise added to the synthetic dataset
uncertainty_mean_min = 0.04    # Min value for the mean uncertainty of a data feature; used to randomly determine the mean uncertainty fraction for each feature in the uncertainty dataset.
uncertainty_mean_max = 0.07    # Max value for the mean uncertainty of a data feature; used to randomly determine the mean uncertainty fraction for each feature in the uncertainty dataset.
uncertainty_scale = 0.01       # Scale of the uncertainty matrix

In [None]:
# Initialize the simulator with the above parameters
# Build the synthetic-data simulator from the parameters defined in the previous cell.
# NOTE(review): how these parameters are used is internal to esat_eval.Simulator
# and not visible here.
simulator = Simulator(seed=seed,
                      factors_n=syn_factors,
                      features_n=syn_features,
                      samples_n=syn_samples,
                      outliers=outliers,
                      outlier_p=outlier_p,
                      outlier_mag=outlier_mag,
                      contribution_max=contribution_max,
                      noise_mean_min=noise_mean_min,
                      noise_mean_max=noise_mean_max,
                      noise_scale=noise_scale,
                      uncertainty_mean_min=uncertainty_mean_min,
                      uncertainty_mean_max=uncertainty_mean_max,
                      uncertainty_scale=uncertainty_scale
                     )

In [None]:
# Example command for passing in a custom factor profile matrix, instead of the randomly generated profile matrix.
# my_profile = np.ones(shape=(syn_factors, syn_features))
# simulator.generate_profiles(profiles=my_profile)

In [None]:
# Example of how to customize the factor contributions. Curve_type options: 'uniform', 'decreasing', 'increasing', 'logistic', 'periodic'
# simulator.update_contribution(factor_i=0, curve_type="logistic", scale=0.1, frequency=0.5)
# simulator.update_contribution(factor_i=1, curve_type="periodic", minimum=0.0, maximum=1.0, frequency=0.5, scale=0.1)
# simulator.update_contribution(factor_i=2, curve_type="increasing", minimum=0.0, maximum=1.0, scale=0.1)
# simulator.update_contribution(factor_i=3, curve_type="decreasing", minimum=0.0, maximum=1.0, scale=0.1)
# simulator.plot_synthetic_contributions()

#### Load Data
Assign the processed data and uncertainty datasets to the variables V and U. These steps will be simplified/streamlined in a future version of the code.

In [None]:
# Pull the generated input (data) and uncertainty tables from the simulator
# (presumably pandas DataFrames, given the `_df` names — confirm).
syn_input_df, syn_uncertainty_df = simulator.get_data()

In [None]:
# Wrap the synthetic input/uncertainty tables in a DataHandler, then extract the
# data matrix V and the uncertainty matrix U used by the models below.
data_handler = DataHandler.load_dataframe(input_df=syn_input_df, uncertainty_df=syn_uncertainty_df)
V, U = data_handler.get_data()

In [None]:
# Paths to the real-world example datasets, resolved relative to ../data from
# the current working directory (the notebook's launch directory).
cwd = os.getcwd()
data_dir = os.path.join(cwd, "..", "data")

# Baton Rouge Dataset
br_input_file = os.path.join(data_dir, "Dataset-BatonRouge-con.csv")
br_uncertainty_file = os.path.join(data_dir, "Dataset-BatonRouge-unc.csv")
br_output_path = os.path.join(data_dir, "output", "BatonRouge")
# Baltimore Dataset
b_input_file = os.path.join(data_dir, "Dataset-Baltimore_con.txt")
b_uncertainty_file = os.path.join(data_dir, "Dataset-Baltimore_unc.txt")
b_output_path = os.path.join(data_dir, "output", "Baltimore")
# Saint Louis Dataset
sl_input_file = os.path.join(data_dir, "Dataset-StLouis-con.csv")
sl_uncertainty_file = os.path.join(data_dir, "Dataset-StLouis-unc.csv")
sl_output_path = os.path.join(data_dir, "output", "StLouis")

# Optional: uncomment to replace the synthetic V/U loaded above with the
# Baltimore dataset.
# data_handler = DataHandler(
#     input_path=b_input_file,
#     uncertainty_path=b_uncertainty_file,
#     index_col="Date"
# )
# V, U = data_handler.get_data()
# print(f"{V.shape}")

### Train Batch Profile

In [None]:
def calculate_W(V, U, H):
    """Estimate a contribution matrix W for data V from a fixed profile matrix H.

    W is computed as ``(V / U) @ H.T``: the uncertainty-scaled data projected
    onto the factor profiles. Before the projection, non-positive entries of H
    are floored to 1e-8.

    NOTE(review): this is a projection, not a least-squares / NMF solve of
    ``V ~= W @ H``. It is presumably intended only as a starting estimate — confirm.

    Parameters
    ----------
    V : array-like, shape (n_samples, n_features)
        Data matrix.
    U : array-like, shape (n_samples, n_features)
        Uncertainty matrix matching V; entries must be non-zero.
    H : array-like, shape (n_factors, n_features)
        Factor profile matrix. It is NOT modified.

    Returns
    -------
    numpy.ndarray, shape (n_samples, n_factors)
    """
    # Work on a float copy. The previous in-place flooring silently altered the
    # caller's H (e.g. a model's profile matrix). For an integer H it also
    # truncated 1e-8 to 0.
    H = np.array(H, dtype=float, copy=True)
    H[H <= 0.0] = 1e-8
    return np.matmul(V * np.divide(1, U), H.T)

def q_loss(V, U, H, W):
    """Return Q, the sum of squared uncertainty-scaled residuals of ``W @ H``.

    ``Q = sum(((V - W @ H) / U) ** 2)``. The residuals must be squared: summing
    the signed residuals (the previous behaviour) lets positive and negative
    errors cancel, so a poor fit could report Q ~= 0. ``mse`` is this value
    divided by ``V.size``.

    Parameters
    ----------
    V, U : array-like
        Data and matching uncertainty matrices (U non-zero).
    H, W : array-like
        Factor profile and contribution matrices with ``W @ H`` shaped like V.

    Returns
    -------
    float
    """
    residuals = (V - np.matmul(W, H)) / U
    return np.sum(residuals ** 2)

def mse(V, U, H, W):
    """Mean squared uncertainty-scaled residual of the reconstruction ``W @ H``.

    Computes ``sum(((V - W @ H) / U) ** 2) / V.size``. This is the Q value
    averaged over every data point.
    """
    scaled_error = (V - np.matmul(W, H)) / U
    q_value = np.sum(scaled_error ** 2)
    return q_value / V.size

def compare_H(H1, H2):
    """Return the pairwise r^2 matrix between the rows of H1 and the rows of H2.

    Entry ``[i, j]`` is the squared Pearson correlation between the profiles
    ``H1[i]`` and ``H2[j]``, both cast to float. A constant profile yields NaN,
    as ``np.corrcoef`` does.

    Returns
    -------
    numpy.ndarray, shape (H1.shape[0], H2.shape[0])
    """
    n_first, n_second = H1.shape[0], H2.shape[0]
    r_squared = np.zeros((n_first, n_second))
    for row in range(n_first):
        profile_a = H1[row].astype(float)
        r_squared[row, :] = [
            np.corrcoef(H2[col].astype(float), profile_a)[0, 1] ** 2
            for col in range(n_second)
        ]
    return r_squared
            
def plot_correlations(matrix):
    """Display `matrix` in a plotly Table.

    The header labels are "Factor 0" .. "Factor n-1" for ``n = matrix.shape[0]``.
    Plotly treats ``cells.values`` as a list of columns, so row i of `matrix`
    is rendered as column i.
    """
    header_labels = [f"Factor {idx}" for idx in range(matrix.shape[0])]
    table = go.Table(header={"values": header_labels}, cells={"values": matrix})
    go.Figure(data=[table]).show()

def prepare_data(V, U, i_selection):
    """Return the rows `i_selection` of V and U as numeric numpy arrays.

    Each selected subset is wrapped in a DataFrame. Every column is passed
    through ``pd.to_numeric``, which converts object/string entries; the
    columns are taken from V's subset and applied to both frames. The frames
    are then converted back with ``to_numpy``. The inputs are not modified.

    Parameters
    ----------
    V, U : numpy.ndarray
        Data and uncertainty matrices of matching shape.
    i_selection : index
        Row selector (e.g. an index array or slice) applied as ``M[i_selection, :]``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(V_subset, U_subset)``.
    """
    v_frame, u_frame = (pd.DataFrame(matrix.copy()[i_selection, :]) for matrix in (V, U))
    for column in v_frame.columns:
        v_frame[column] = pd.to_numeric(v_frame[column])
        u_frame[column] = pd.to_numeric(u_frame[column])
    return v_frame.to_numpy(), u_frame.to_numpy()

### Version 2 of the FactorCatalog

FactorCatalog V2 takes a more robust approach to grouping factor profiles from multiple models, with the grouping occurring after all factor profiles have been collected. The updated procedure
is designed to be used to investigate potential solutions for creating models for very large datasets using subsets of the data. The algorithm is described as:
1. Specify your hyper-parameters: samples_n, batches_n, models_n, random_seed, correlation_threshold, factors_k
2. For each batch in batches_n.
3. Create a subset dataset using samples_n randomly selected values from V/U.
4. Create a BatchSA instance of models_n models, using random_seed and factors_k.
5. Add each output factor to the FactorCatalog (factor model_i, fit_rmse, factor_i, H)
6. Once all batches are completed, cluster the factor collection using a constrained k-means cluster function.
7. Score the models based upon a heuristic, such as the sum of (cluster_cor_avg*cluster_members)
8. Evaluate the clustered profile matrix (built from the cluster centroid values) on the complete dataset.

The primary modification of the FactorCatalog is to use the constrained k-means clustering function for grouping 'like' factor profiles. The procedure will work by:
1. Starting with the factors_k=clusters_n, calculate the correlation of all factors/points to the centroids, by model.
2. Initialize the clusters by randomly creating clusters or by selecting clusters_n 'dissimilar' factors.
3. Randomly shuffle the model order for factor assignment:
   1. Assign the factors to the clusters by order of correlation, closer to 1.0 goes first.
   2. If more than one factor in a model would be assigned to the same cluster, assign the factor with the highest correlation to that cluster, then repeat for the remaining factors, excluding the already-assigned cluster(s), until all factors are assigned.
   3. If any given factor does not have a correlation above the specified threshold, create a new cluster centered at that point.
4. At the end of each assignment iteration, remove any cluster which has no members.
5. Stop once max_assignment_n iterations have been reached, or earlier if an assignment iteration no longer changes any assignments.

Constrained k-means clustering is a standard clustering approach, except that the distance is calculated as 1/r2 of the point to the cluster centroid. There are two other differences: a) the number of clusters can increase and decrease depending on the correlation threshold, and b) a model can only contribute one factor to any given cluster.

Once all the factors for all models have been clustered, the FactorCatalog models can be scored using the heuristic stated in step 7. The factor profile matrix (H) of the best model, or of any selected model, can then be chosen for final evaluation. This H is not the matrix produced by the model. Instead, each row is the mean of the FactorCatalog factors in the corresponding cluster (the centroid of the cluster that the model's factor was assigned to). This approach allows the factor profile values to be provided as a distribution of possible values for each feature, demonstrating the potential uncertainty in the factor profile. The clustered factor profile is then used to fit the full dataset while keeping H constant. The resulting loss can then be compared against the loss of a long-running brute-force model trained on the full dataset.

An interesting next step would be to run a Monte Carlo (MC) simulation that randomly samples the factor profile from these distributions. This would show how that sampling affects the model's W matrix and its loss.

In [None]:
# Define the autoencoder
def build_autoencoder(input_dim, encoding_dim):
    """Build a symmetric dense autoencoder and its encoder half.

    Architecture: input -> 128 -> 64 -> encoding_dim (linear latent space)
    -> 64 -> 128 -> input_dim (sigmoid output). All hidden layers use ReLU.

    Parameters
    ----------
    input_dim : int
        Length of each input vector.
    encoding_dim : int
        Size of the latent representation.

    Returns
    -------
    tuple
        ``(autoencoder, encoder)`` Keras models. ``autoencoder`` maps an input to
        its reconstruction. ``encoder`` maps an input to its latent vector and
        shares its layers with ``autoencoder``.
    """
    inputs = Input(shape=(input_dim,))

    # Encoder: shrink through two ReLU layers down to the linear latent layer.
    latent = inputs
    for width in (128, 64):
        latent = Dense(width, activation='relu')(latent)
    latent = Dense(encoding_dim, activation='linear')(latent)

    # Decoder: mirror the encoder back up to the input width.
    reconstruction = latent
    for width in (64, 128):
        reconstruction = Dense(width, activation='relu')(reconstruction)
    reconstruction = Dense(input_dim, activation='sigmoid')(reconstruction)

    full_model = kModel(inputs=inputs, outputs=reconstruction)
    encoder_model = kModel(inputs=inputs, outputs=latent)
    return full_model, encoder_model


class Factor:
    """One factor profile produced by a model, plus its current cluster assignment.

    NOTE: this definition shadows the ``Factor`` imported from
    ``esat_eval.factor_catalog`` at the top of the notebook.

    Attributes
    ----------
    factor_id : identifier of this factor
    profile : the factor profile vector
    model_id : identifier of the model that produced the factor
    cluster_id : id of the assigned cluster, or None when unassigned
    cor : similarity (r^2) recorded at assignment, or None when unassigned
    """

    def __init__(self, factor_id, profile, model_id):
        self.factor_id = factor_id
        self.profile = profile
        self.model_id = model_id
        # Cluster membership starts empty; see assign()/deallocate().
        self.cluster_id = None
        self.cor = None

    def assign(self, cluster_id, cor):
        """Record membership in `cluster_id` with similarity `cor`."""
        self.cluster_id, self.cor = cluster_id, cor

    def deallocate(self):
        """Clear any cluster membership."""
        self.cluster_id, self.cor = None, None

    def distance(self, cluster):
        """Return the squared Pearson correlation (r^2) between this profile and `cluster`.

        Despite the name, larger values mean *more* similar. 1.0 means perfectly
        correlated or perfectly anti-correlated.
        """
        mine = np.array(self.profile).astype(float)
        other = np.array(cluster).astype(float)
        pearson = np.corrcoef(other, mine)[0, 1]
        return pearson ** 2


class Model:
    """Container for the factors produced by one model run.

    Attributes
    ----------
    model_id : identifier of the model
    factors : list of Factor objects, in insertion order
    score : None until assigned by the catalog's scoring step
    """

    def __init__(self, model_id):
        self.model_id = model_id
        self.factors = []
        self.score = None

    def add_factor(self, factor):
        """Append `factor` to this model's factor list."""
        self.factors += [factor]


class Cluster:
    """A cluster of factor profiles with a centroid and running summary statistics.

    Attributes
    ----------
    cluster_id : identifier of the cluster
    centroid : numpy.ndarray, the cluster's centre profile
    factors : list of member factors
    count : number of members
    mean_r2 : mean of the members' ``cor`` values (0 when empty)
    std : per-feature standard deviation of member profiles (0 when empty)
    min_values, max_values : per-feature min / max over member profiles
        (NaN when empty)
    """

    def __init__(self,
                 cluster_id,
                 centroid: np.ndarray
                ):
        self.cluster_id = cluster_id
        self.centroid = centroid
        self.factors = []
        self.count = 0

        self.mean_r2 = 0
        self.std = 0
        # NaN placeholders until the first member is added (see add()).
        self.min_values = np.full(len(centroid), np.nan)
        self.max_values = np.full(len(centroid), np.nan)

    def __len__(self):
        return self.count

    def add(self, factor: "Factor", cor: float):
        """Assign `factor` to this cluster with similarity `cor` and update the statistics."""
        factor.assign(cluster_id=self.cluster_id, cor=cor)
        self.factors.append(factor)
        self.count += 1
        # np.fmin/np.fmax ignore NaN, so the NaN placeholders are replaced by the
        # first member's values. np.minimum/np.maximum propagate NaN, which left
        # the ranges NaN forever.
        self.min_values = np.fmin(self.min_values, factor.profile)
        self.max_values = np.fmax(self.max_values, factor.profile)
        self.mean_r2 = np.mean([member.cor for member in self.factors])
        self.std = np.std([member.profile for member in self.factors], axis=0)

    def purge(self):
        """Remove (and deallocate) every member and reset the statistics; the centroid is kept."""
        for factor in self.factors:
            factor.deallocate()
        self.factors = []
        self.count = 0
        self.mean_r2 = 0
        self.std = 0
        self.min_values = np.full(len(self.centroid), np.nan)
        self.max_values = np.full(len(self.centroid), np.nan)

    def recalculate(self):
        """Move the centroid to the mean of the member profiles (no-op when empty)."""
        if len(self.factors) > 0:
            factor_matrix = np.array([factor.profile for factor in self.factors])
            self.centroid = np.mean(factor_matrix, axis=0)

    def plot(self):
        """Show a per-feature box plot of the member profiles with the centroid overlaid."""
        n_features = len(self.centroid)
        # reshape keeps a 2-D (0, n_features) array when the cluster has no members.
        factor_matrix = np.array(
            [factor.profile for factor in self.factors], dtype=float
        ).reshape(-1, n_features)
        feature_names = [f"Feature {i+1}" for i in range(n_features)]

        box_plot = go.Figure()
        for i, feature_name in enumerate(feature_names):
            box_plot.add_trace(go.Box(
                y=factor_matrix[:, i],
                boxpoints="all",
                jitter=0.5,
                whiskerwidth=0.2,
                marker_size=2,
                line_width=1,
                name=feature_name)
            )
        # Boxes are placed on a categorical axis by name, so the centroid markers
        # use the same labels to line up with their feature's box.
        box_plot.add_trace(go.Scatter(
            x=feature_names,
            y=self.centroid,
            name="Centroid",
            mode='markers',
            marker=dict(color='red', size=10)
        ))
        box_plot.update_layout(title="Clustered Factor Profile", width=1200, height=800)
        box_plot.show()

        
class BatchFactorCatalog:
    def __init__(self,
                 n_factors: int,
                 n_features: int,
                 threshold: float = 0.8,
                 seed: int = 42
                ):
        """Create an empty batch factor catalog.

        Parameters
        ----------
        n_factors : int
            Number of clusters created by ``initialize_clusters``.
        n_features : int
            Length of each factor profile / centroid.
        threshold : float
            Correlation threshold. It is only stored here; presumably it is used
            by the clustering code later in the class — confirm.
        seed : int
            Seed for the numpy Generator used for random centroid initialization.
        """
        # Hyper-parameters and RNG.
        self.n_factors, self.n_features, self.threshold = n_factors, n_features, threshold
        self.rng = np.random.default_rng(seed)

        # Registries keyed by id, with counters used to mint new ids.
        self.models, self.model_count = {}, 0
        self.factors, self.factor_count = {}, 0

        # Running per-feature min/max over every added factor profile; they bound
        # the uniform draws used to initialize centroids.
        self.factor_min = self.factor_max = None

        # Clustering state.
        self.clusters = {}
        self.dropped_clusters = []
        self.max_clusters_found = 0
        # Per-iteration clustering snapshots, read by animate().
        self.state = {}

    def results(self):
        """Summarise every cluster as ``{cluster_id: {"count", "mean_r2", "std"}}``."""
        return {
            cluster.cluster_id: {
                "count": len(cluster),
                "mean_r2": cluster.mean_r2,
                "std": cluster.std,
            }
            for cluster in self.clusters.values()
        }

    def plot(self, base_matrix = None, count_threshold: int = 3, method="mds", membership_p: float = 0.0):
        """Project all factor profiles and cluster centroids to 3-D and show a scatter plot.

        Parameters
        ----------
        base_matrix : optional
            Unused; kept for interface compatibility.
        count_threshold : int
            Used when `membership_p` is 0: only centroids of clusters with more than
            this many members are drawn. 0 disables the filter.
        method : str
            "tsne" (t-SNE), "mds" (metric MDS), or anything else for a small dense
            autoencoder trained on the factor profiles.
        membership_p : float
            If > 0, only factors and centroids of clusters whose member count exceeds
            ``int(total_factors * membership_p)`` are drawn.
        """
        all_factors = np.array([v.profile for v in self.factors.values()])
        model_assignment = np.array([v.model_id for v in self.factors.values()])
        factor_assignments = np.array([v.cluster_id for v in self.factors.values()])
        cluster_centroids = [(c, cluster.centroid) for c, cluster in self.clusters.items()]

        i_cluster, i_centroids = zip(*cluster_centroids)
        i_centroids = np.array(i_centroids)

        # One row per factor: its feature values, cluster id and hover text.
        df_pca0 = pd.DataFrame(all_factors)
        factor_columns = df_pca0.columns
        df_pca0["Cluster"] = factor_assignments
        df_pca0["text"] = "Profile " + df_pca0["Cluster"].astype(str) + " Model: " + model_assignment.astype(str)

        # One row per cluster that actually has members, with its member count.
        df_centroids0 = pd.DataFrame(i_centroids, index=list(i_cluster))
        cluster_columns = df_centroids0.columns
        assigned_centroids, cluster_size = np.unique(factor_assignments, return_counts=True)
        df_centroids0["Cluster"] = list(i_cluster)
        df_centroids0 = df_centroids0.loc[assigned_centroids]
        df_centroids0["count"] = cluster_size

        if membership_p > 0.0:
            cluster_n_threshold = int(len(all_factors) * membership_p)
            # Size of each factor's own cluster as a scalar per row. The previous
            # version stored 1-element arrays per row, which pandas cannot
            # assign/cast as a column.
            cluster_counts = dict(zip(assigned_centroids, cluster_size))
            df_pca0["cluster_n"] = df_pca0["Cluster"].map(cluster_counts).astype(int)
            df_pca0 = df_pca0[df_pca0["cluster_n"] > cluster_n_threshold]

            df_centroids0 = df_centroids0[df_centroids0["count"] > cluster_n_threshold]
        elif count_threshold > 0:
            df_centroids0 = df_centroids0[df_centroids0["count"] > count_threshold]

        all_factors = df_pca0[factor_columns].values
        i_centroids = df_centroids0[cluster_columns].values

        if method.lower() == "tsne":
            samples = np.vstack((all_factors, i_centroids))
            # scikit-learn requires perplexity < n_samples; min(50, n_samples)
            # raised for catalogs with 50 or fewer points.
            factor_model = TSNE(n_components=3, random_state=0, perplexity=min(50, len(samples) - 1))
            reduction_results = factor_model.fit_transform(samples)
            factor_reduction = reduction_results[:len(all_factors), :]
            pca_centroids = reduction_results[len(all_factors):, :]
        elif method.lower() == "mds":
            factor_model = MDS(n_components=3, random_state=0, metric=True, max_iter=300)
            reduction_results = factor_model.fit_transform(np.vstack((all_factors, i_centroids)))
            factor_reduction = reduction_results[:len(all_factors), :]
            pca_centroids = reduction_results[len(all_factors):, :]
        else:
            autoencoder, encoder = build_autoencoder(input_dim=i_centroids.shape[1], encoding_dim=3)
            autoencoder.compile(optimizer='adam', loss='mse')
            autoencoder.fit(all_factors, all_factors, epochs=50, batch_size=16, shuffle=True, validation_split=0.2, verbose=0)
            factor_reduction = encoder.predict(all_factors, verbose=0)
            pca_centroids = encoder.predict(i_centroids, verbose=0)

        df_pca = pd.DataFrame(factor_reduction, columns=['PCA1', 'PCA2', 'PCA3'])
        df_pca["Cluster"] = df_pca0["Cluster"].values
        df_pca["text"] = df_pca0["text"].values

        df_centroids = pd.DataFrame(pca_centroids, columns=['PCA1', 'PCA2', 'PCA3'], index=list(df_centroids0.index))
        df_centroids["text"] = "Centroid " + df_centroids.index.astype(str)
        df_centroids["Cluster"] = df_centroids0["Cluster"].values
        df_centroids["count"] = df_centroids0["count"].values

        # Previously called the global `factor_catalog` instead of this instance.
        color_map = self.generate_continuous_colormap(len(df_centroids["Cluster"]), colormap_name='rainbow')
        full_colormap = dict(zip(list(df_centroids["Cluster"]), [color[1] for color in color_map]))
        # Factors whose cluster centroid was filtered out above have no colour
        # entry; give them a neutral colour instead of NaN.
        df_pca["color"] = df_pca["Cluster"].map(full_colormap).fillna("lightgrey")
        df_centroids["color"] = df_centroids["Cluster"].map(full_colormap)

        fig = go.Figure()
        fig.add_trace(go.Scatter3d(
            x=df_pca["PCA1"],
            y=df_pca["PCA2"],
            z=df_pca["PCA3"],
            mode='markers',
            marker=dict(
                size=3,
                color=df_pca["color"],
                opacity=0.5
            ),
            text=df_pca["text"],
            hoverinfo='text'
        ))
        fig.add_trace(go.Scatter3d(
            x=df_centroids['PCA1'],
            y=df_centroids['PCA2'],
            z=df_centroids['PCA3'],
            mode='markers',
            marker=dict(
                size=3,
                color="black",
                symbol='x',
                opacity=0.5
            ),
            text=df_centroids["text"],
            hoverinfo='text'
        ))
        fig.update_layout(
            title=f"Factor Clustering - Method: {method}",
            scene=dict(
                xaxis_title="PCA1",
                yaxis_title="PCA2",
                zaxis_title="PCA3"
            ),
            showlegend=False,
            height=800,
            width=1000,
            margin=dict(l=5, r=5, b=5, t=50)
        )
        fig.show()

    def generate_continuous_colormap(self, n, colormap_name='viridis'):
        """Sample `n` evenly spaced colours from a matplotlib colormap.

        Parameters
        ----------
        n : int
            Number of colours. ``n == 1`` returns the colormap's first colour.
            ``n <= 0`` returns an empty list.
        colormap_name : str
            Any name accepted by ``matplotlib.pyplot.get_cmap``.

        Returns
        -------
        list of (float, str)
            Plotly-style colour scale: ``(position, "#rrggbb")`` pairs with
            positions evenly spaced from 0.0 to 1.0.
        """
        cmap = plt.get_cmap(colormap_name)

        # ``i / (n - 1)`` raised ZeroDivisionError for n == 1 (e.g. a plot with a
        # single surviving cluster). A denominator of 1 places that colour at 0.0.
        denominator = max(n - 1, 1)
        positions = [i / denominator for i in range(n)]

        # Matplotlib returns RGBA floats in [0, 1]; drop alpha and convert to hex.
        colors_hex = ['#%02x%02x%02x' % (int(r * 255), int(g * 255), int(b * 255))
                      for r, g, b, _ in (cmap(p) for p in positions)]

        return list(zip(positions, colors_hex))

    def animate(self, to_file: bool=False, base_matrix = None, profiled_matrix = None, count_threshold: int = 3, use_kernel: bool = False):
        """Animate the recorded clustering history as a 3-D PCA scatter.

        All factor profiles are projected to three components once. Each entry of
        ``self.state`` becomes one animation frame, showing the factors coloured by
        their assigned cluster and the cluster centroids as black crosses.

        Parameters
        ----------
        to_file : bool
            If True, write the figure to "factor_clustering.html" instead of showing it.
        base_matrix, profiled_matrix : array-like, optional
            Extra profile matrices projected with the same PCA and drawn in every
            frame (black / green crosses).
        count_threshold : int
            Only centroids of clusters with more than this many members are drawn.
            0 disables the filter.
        use_kernel : bool
            Use an RBF KernelPCA (gamma=0.1) instead of linear PCA.

        NOTE(review): ``self.state`` is expected to be indexed 0..len-1. Each entry
        holds "cluster_centroids" (iterable of (cluster_id, centroid)) and
        "assignment" (a cluster id per factor, in ``self.factors`` order). It is
        populated by code not shown here — confirm.
        """
        # PCA / KernelPCA were used here but never imported at file level.
        from sklearn.decomposition import PCA, KernelPCA

        all_factors = np.array([v.profile for v in self.factors.values()])
        if use_kernel:
            factor_pca = KernelPCA(n_components=3, kernel='rbf', gamma=0.1)
        else:
            factor_pca = PCA(n_components=3)
        pca_results = factor_pca.fit_transform(all_factors)

        plot_base = False
        if base_matrix is not None:
            pcs_base = factor_pca.transform(base_matrix)
            df_base = pd.DataFrame(pcs_base, columns=['PCA1', 'PCA2', 'PCA3'])
            plot_base = True

        plot_profiled = False
        if profiled_matrix is not None:
            pcs_profiled = factor_pca.transform(profiled_matrix)
            df_profiled = pd.DataFrame(pcs_profiled, columns=['PCA1', 'PCA2', 'PCA3'])
            plot_profiled = True

        frames = []
        state0 = None
        color_map = self.generate_continuous_colormap(self.max_clusters_found, colormap_name='rainbow')
        full_colormap = dict(zip(list(np.arange(self.max_clusters_found)), [color[1] for color in color_map]))

        for i in range(len(self.state)):
            cluster_centroids = self.state[i]["cluster_centroids"]
            i_cluster, i_centroids = zip(*cluster_centroids)
            factor_assignments = self.state[i]["assignment"]

            pca_centroids = factor_pca.transform(i_centroids)
            df_pca = pd.DataFrame(pca_results, columns=['PCA1', 'PCA2', 'PCA3'])
            df_pca["Cluster"] = factor_assignments
            df_pca["text"] = "Profile " + df_pca["Cluster"].astype(str)
            df_pca["color"] = df_pca["Cluster"].map(full_colormap)

            # Keep only centroids of clusters that have members in this iteration.
            df_centroids = pd.DataFrame(pca_centroids, columns=['PCA1', 'PCA2', 'PCA3'], index=list(i_cluster))
            assigned_centroids, cluster_size = np.unique(factor_assignments, return_counts=True)
            df_centroids["text"] = "Centroid " + df_centroids.index.astype(str)
            df_centroids["Cluster"] = list(i_cluster)
            df_centroids = df_centroids.loc[assigned_centroids]
            df_centroids["count"] = cluster_size

            if count_threshold > 0:
                # .copy() so the column added below is not written to a filtered view.
                df_centroids = df_centroids[df_centroids["count"] > count_threshold].copy()

            df_centroids["color"] = df_centroids["Cluster"].map(full_colormap)

            data = []
            data.append(go.Scatter3d(
                x=df_pca["PCA1"],
                y=df_pca["PCA2"],
                z=df_pca["PCA3"],
                mode='markers',
                marker=dict(
                    size=3,
                    color=df_pca["color"],
                    opacity=0.8
                ),
                text=df_pca["text"],
                hoverinfo='text'
            ))
            data.append(go.Scatter3d(
                x=df_centroids['PCA1'],
                y=df_centroids['PCA2'],
                z=df_centroids['PCA3'],
                mode='markers',
                marker=dict(
                    size=3,
                    color="black",
                    symbol='x',
                ),
                text=df_centroids["text"],
                hoverinfo='text'
            ))
            if plot_base:
                data.append(go.Scatter3d(
                    x=df_base['PCA1'],
                    y=df_base['PCA2'],
                    z=df_base['PCA3'],
                    mode='markers',
                    marker=dict(
                        size=8,
                        color="black",
                        symbol='cross'
                    ),
                    name="Base Factor"
                ))
            if plot_profiled:
                data.append(go.Scatter3d(
                    x=df_profiled['PCA1'],
                    y=df_profiled['PCA2'],
                    z=df_profiled['PCA3'],
                    mode='markers',
                    marker=dict(
                        size=8,
                        color="green",
                        symbol='cross'
                    ),
                    name="P Factor"
                ))

            if i == 0:
                state0 = data
            frames.append(
                go.Frame(data=data,
                         layout=go.Layout(
                             annotations=[
                                dict(
                                    x=1,
                                    y=1,
                                    showarrow=False,
                                    text=f"Iteration: {i + 1}/{len(self.state)}",
                                    xref="paper",
                                    yref="paper",
                                    font=dict(size=14)
                                )
                            ]
                        ),
                         name=str(i)
                        )
            )

        # Fix the axes to the extent of the factor projections (plus a margin) so
        # they do not rescale between frames.
        df_pca0 = pd.DataFrame(pca_results, columns=['PC1', 'PC2', 'PC3'])
        lower = df_pca0.min() - 1.25
        upper = df_pca0.max() + 1.25

        fig = go.Figure(
            data=state0,
            layout=go.Layout(
                title="Factor Profile Clustering",
                height=1000,
                width=1000,
                scene=dict(
                    xaxis=dict(range=[lower['PC1'], upper['PC1']], autorange=False),
                    yaxis=dict(range=[lower['PC2'], upper['PC2']], autorange=False),
                    zaxis=dict(range=[lower['PC3'], upper['PC3']], autorange=False),
                    aspectmode="manual",
                    aspectratio=dict(x=1, y=1, z=1)
                ),
                updatemenus=[dict(
                    type="buttons",
                    showactive=False,
                    buttons=[
                        dict(label="Play",
                             method="animate",
                             args=[None, dict(frame=dict(duration=500, redraw=True), fromcurrent=True, mode="immediate")]),
                        dict(label="Pause",
                             method="animate",
                             args=[[None], dict(frame=dict(duration=0, redraw=False), mode="immediate")]),
                        dict(
                            args=[[0], dict(frame=dict(duration=0, redraw=True), mode="immediate")],
                            label="Reset",
                            method="animate"),
                    ]
                )],
                annotations=[
                    dict(
                        x=1,
                        y=1,
                        showarrow=False,
                        text="Iteration: NA",
                        xref="paper",
                        yref="paper",
                        font=dict(size=14)
                    )
                ],
                showlegend=False
            ),
            frames=frames
        )
        if to_file:
            pio.write_html(fig, file="factor_clustering.html", auto_open=False)
        else:
            fig.show()
        
    def add_model(self, model: SA, norm: bool = True):
        """Register every factor profile (row of ``model.H``) of `model` in the catalog.

        Each row of ``model.H`` becomes a new Factor with a fresh sequential factor
        id. The model gets the next model id and is stored in ``self.models`` under
        ``str(model_id)``. The catalog's per-feature min/max ranges are widened to
        cover every stored profile.

        Parameters
        ----------
        model : SA
            A fitted model exposing an ``H`` matrix whose rows are factors.
        norm : bool
            If True, store ``H`` divided by its column sums, so each feature sums
            to 1 across factors. Otherwise store the raw rows of ``H``.
        """
        model_id = self.model_count
        # The stored profile matrix does not change per factor, so choose it once
        # (and only normalise when it is actually used).
        profiles = model.H / np.sum(model.H, axis=0) if norm else model.H
        i_model = Model(model_id=model_id)
        for i in range(model.H.shape[0]):
            factor_id = self.factor_count
            self.factor_count += 1
            factor = Factor(factor_id=factor_id, profile=profiles[i], model_id=model_id)

            i_model.add_factor(factor)
            self.factors[factor_id] = factor
            self.update_ranges(profiles[i])

        self.models[str(model_id)] = i_model
        self.model_count += 1

    def compare(self, matrix):
        """Match each row of `matrix` to the cluster whose centroid has the highest r^2 with it.

        Returns
        -------
        dict
            ``{row_index: {"cluster_id": best_id, "r2": best_r2}}``. A cluster wins
            only if its r^2 is strictly greater than the running best, so ties keep
            the first cluster in iteration order. A row with no positive r^2 maps
            to ``{"cluster_id": None, "r2": 0.0}``.
        """
        matches = {}
        for row_index in range(matrix.shape[0]):
            profile = matrix[row_index]
            best_id, best_r2 = None, 0.0
            for cluster in self.clusters.values():
                candidate_r2 = self.distance(profile, cluster.centroid)
                if candidate_r2 > best_r2:
                    best_id, best_r2 = cluster.cluster_id, candidate_r2
            matches[row_index] = {"cluster_id": best_id, "r2": best_r2}
        return matches

    def score(self):
        """Score each model by the summed size of the clusters its factors belong to.

        For every model in ``self.models``, the score is the sum, over its factors,
        of ``len(cluster)`` for the factor's assigned cluster. A factor whose
        ``cluster_id`` is not a key of ``self.clusters`` contributes 0 and is logged
        at INFO level. The float total is written to ``model.score``.
        """
        for model in self.models.values():
            total = 0.0
            for factor in model.factors:
                if factor.cluster_id in self.clusters:
                    total += len(self.clusters[factor.cluster_id])
                    continue
                logger.info(f"Factor {factor.factor_id} assigned to non-existent cluster {factor.cluster_id}")
            model.score = total

    def update_ranges(self, factor):
        """Widen the running per-feature min/max bounds to include ``factor``.

        The first profile seen (both bounds still None) initializes both bounds with
        shallow copies of it; later profiles update them element-wise.
        """
        uninitialized = self.factor_min is None and self.factor_max is None
        if not uninitialized:
            self.factor_min = np.minimum(self.factor_min, factor)
            self.factor_max = np.maximum(self.factor_max, factor)
            return
        self.factor_min = copy.copy(factor)
        self.factor_max = copy.copy(factor)

    def initialize_clusters(self):
        """Seed ``self.n_factors`` clusters with random centroids.

        Each centroid coordinate is drawn with ``self.rng.uniform`` between that
        feature's ``factor_min`` and ``factor_max``. Clusters are stored under ids
        ``0..n_factors-1``, replacing any existing entries with those ids.
        """
        for cluster_id in range(self.n_factors):
            centroid = np.array(
                [
                    self.rng.uniform(low=self.factor_min[f], high=self.factor_max[f])
                    for f in range(self.n_features)
                ],
                dtype=float,
            )
            self.clusters[cluster_id] = Cluster(cluster_id=cluster_id, centroid=centroid)

    def purge_clusters(self):
        """Call ``purge()`` on every cluster, in ``self.clusters`` order."""
        for cluster in self.clusters.values():
            cluster.purge()

    def distance(self, factor1, factor2):
        """Return the squared Pearson correlation (r²) between two profiles.

        Higher means more similar (despite the name). When r² is undefined, because
        either profile has zero variance, 0.0 is returned instead of NaN. NaN would
        compare False everywhere and scramble the ``sort`` used when ranking clusters.
        """
        f1 = np.array(factor1).astype(float)
        f2 = np.array(factor2).astype(float)
        with np.errstate(divide="ignore", invalid="ignore"):
            corr = np.corrcoef(f2, f1)[0, 1]
        r_sq = corr ** 2
        return r_sq if np.isfinite(r_sq) else 0.0

    def calculate_centroids(self):
        """Recompute every cluster's centroid and return them stacked row-wise.

        Clusters are processed in ``self.clusters`` order. Row k of the result is the
        k-th cluster's ``centroid`` after its ``recalculate()`` call.
        """
        def _refreshed(cluster):
            cluster.recalculate()
            return cluster.centroid

        return np.array([_refreshed(cluster) for cluster in self.clusters.values()])

    def cluster_cleanup(self):
        """Purge redundant and empty clusters; return the centroids of all clusters.

        For every unordered pair of clusters whose centroid similarity
        (``self.distance``) exceeds ``self.threshold``, the smaller cluster is marked.
        On a size tie, the second of the pair is marked. Clusters with no members are
        marked as well. Marked clusters are emptied via ``purge()`` but remain in
        ``self.clusters``.

        Returns
        -------
        numpy.ndarray
            One row per cluster in ``self.clusters`` order, purged clusters included.
            The array is built after purging.
        """
        keys = list(self.clusters)
        to_drop = set()
        for pos, first_key in enumerate(keys):
            first = self.clusters[first_key]
            for second_key in keys[pos + 1:]:
                second = self.clusters[second_key]
                if self.distance(first.centroid, second.centroid) > self.threshold:
                    to_drop.add(first_key if len(first) < len(second) else second_key)
        to_drop.update(key for key, cl in self.clusters.items() if len(cl) == 0)
        centroid_list = [cl.centroid for cl in self.clusters.values()]
        for key in to_drop:
            self.clusters[key].purge()
        return np.array(centroid_list)

    def save_state(self, iteration):
        """Snapshot this catalog's clustering state under ``self.state[iteration]``.

        Stores the cluster id of every factor, in ``self.factors`` order, as
        ``"assignment"``. Stores the ``(cluster_id, centroid)`` pairs of all non-None
        clusters as ``"cluster_centroids"``. Also tracks the largest cluster count
        seen. Reads ``self``, not the notebook-global ``factor_catalog``, so each
        catalog records its own state.
        """
        factor_assignment = np.array([factor.cluster_id for factor in self.factors.values()])
        cluster_centroids = [(c, cluster.centroid) for c, cluster in self.clusters.items() if cluster is not None]
        self.state[iteration] = {"assignment": factor_assignment, "cluster_centroids": cluster_centroids}
        self.max_clusters_found = max(self.max_clusters_found, len(cluster_centroids))

    def matrix_difference(self, i_centroids, j_centroids):
        """Measure how far two centroid matrices are apart.

        For equal shapes, returns the mean Euclidean distance between corresponding
        rows. Otherwise it compares the overlapping top-left block the same way, then
        adds the mean of the larger matrix's surplus rows. Those rows are scaled by
        their share of that matrix's row count.
        """
        if i_centroids.shape == j_centroids.shape:
            return np.mean(np.linalg.norm(i_centroids - j_centroids, axis=1))
        rows = min(i_centroids.shape[0], j_centroids.shape[0])
        cols = min(i_centroids.shape[1], j_centroids.shape[1])
        overlap_i = i_centroids[:rows, :cols]
        overlap_j = j_centroids[:rows, :cols]
        shift = np.mean(np.linalg.norm(overlap_i - overlap_j, axis=1))
        larger = i_centroids if i_centroids.shape[0] > j_centroids.shape[0] else j_centroids
        surplus = larger[rows:]
        shift += np.mean((len(surplus) / larger.shape[0]) * surplus)
        return shift

    def cluster(self, max_iterations: int = 20, threshold: float = None, early_stopping: bool = True):
        """Cluster all registered factor profiles across models.

        Starts from ``self.n_factors`` randomly seeded centroids. Each iteration:

        1. Empties all clusters.
        2. Visits the models in a random order drawn from ``self.rng``.
        3. Assigns each model's factors to distinct clusters; a factor with no free
           cluster at or above ``threshold`` starts a new cluster.
        4. Snapshots the state and recomputes the centroids.

        Iteration stops when the centroid shift (``matrix_difference``) drops below
        1e-4 after more than 5 iterations (only if ``early_stopping``), or when
        ``max_iterations`` is reached. Model scores are computed at the end.

        Parameters
        ----------
        max_iterations : int
            Upper bound on clustering iterations.
        threshold : float, optional
            Minimum ``self.distance`` (r²) needed to join an existing cluster;
            defaults to ``self.threshold``.
        early_stopping : bool
            Whether to stop once the centroids stop moving.
        """
        self.initialize_clusters()
        centroids = self.calculate_centroids()
        converged = False
        current_iter = 0
        threshold = self.threshold if threshold is None else threshold
        with tqdm(total=max_iterations, desc="Running clustering. N Clusters: NA, Added: NA") as pbar:
            while not converged:
                if current_iter >= max_iterations:
                    logger.info(f"Factor clustering did not converge after {max_iterations} iterations.")
                    break
                self.purge_clusters()

                # Random model order so no model systematically claims clusters first.
                model_list = self.rng.permutation(list(self.models.keys()))
                for model_i in model_list:
                    self._assign_model_factors(model_i, threshold)

                # Record the assignment, then recalculate centroids of clusters.
                self.save_state(iteration=current_iter)
                new_centroids = self.calculate_centroids()

                if (self.matrix_difference(i_centroids=new_centroids, j_centroids=centroids) < 0.0001 and current_iter > 5 and early_stopping) or (current_iter >= max_iterations):
                    converged = True

                pbar.update(1)
                pbar.set_description(f"Running clustering. N Clusters: {len(new_centroids)}, Added: {len(new_centroids) - len(centroids)}")
                centroids = new_centroids
                current_iter += 1
        self.score()

    def _assign_model_factors(self, model_id, threshold):
        """Greedily give each factor of one model its own cluster, best correlations first."""
        factor_ids = [factor.factor_id for factor in self.models[model_id].factors]
        # Rank every cluster for every factor, most similar first.
        ranked = {}
        for factor_id in factor_ids:
            profile = self.factors[factor_id].profile
            similarities = [(c_id, self.distance(profile, cluster.centroid)) for c_id, cluster in self.clusters.items()]
            similarities.sort(key=lambda x: x[1], reverse=True)
            ranked[factor_id] = similarities
        # Factors with the highest best-similarity choose first. The key is the
        # similarity value itself, not the (cluster_id, similarity) tuple.
        order = sorted(
            factor_ids,
            key=lambda f: ranked[f][0][1] if ranked[f] else float("-inf"),
            reverse=True,
        )
        already_assigned = set()
        for factor_id in order:
            factor = self.factors[factor_id]
            match = next(
                ((c_id, cor) for c_id, cor in ranked[factor_id] if c_id not in already_assigned and cor >= threshold),
                None,
            )
            if match is not None:
                c_id, cor = match
                # Record the similarity to the cluster actually joined.
                self.clusters[c_id].add(factor=factor, cor=cor)
                already_assigned.add(c_id)
            else:
                new_cluster_id = self.dropped_clusters.pop(0) if len(self.dropped_clusters) > 0 else len(self.clusters)
                new_cluster = Cluster(cluster_id=new_cluster_id, centroid=factor.profile)
                new_cluster.add(factor=factor, cor=1.0)
                self.clusters[new_cluster_id] = new_cluster
                already_assigned.add(new_cluster_id)

In [None]:
# Notebook-wide random generator; every later subset/seed draw consumes this stream in order.
rng = np.random.default_rng(seed=seed)

In [None]:
%%capture

# factors = syn_factors
factors = 5
method = "ls-nmf"                   # "ls-nmf", "ws-nmf"
models = 20                         # the number of models to train
init_method = "col_means"           # default is column means "col_means", "kmeans", "cmeans"
seed = 42                           # random seed for initialization
max_iterations = 20000              # the maximum number of iterations for fitting a model
converge_delta = 0.1                # convergence criteria for the change in loss, Q
converge_n = 25                     # convergence criteria for the number of steps where the loss changes by less than converge_delta

samples_n = V.shape[0]
batch_size = 250
max_batches = 20
i_batches = 0
n_models = 10
max_iter = 20000

i_selection = rng.choice(V.shape[0], size=batch_size, replace=False, shuffle=True)
i_V, i_U = prepare_data(V=V, U=U, i_selection=i_selection)

logger.info(f"Total V: {V.shape[0]}, Batch Size: {batch_size}, factors: {factors}")

factor_catalog = BatchFactorCatalog(n_factors=factors, n_features=i_V.shape[1], threshold=0.8, seed=42)

t0 = time.time()
change_p = 1.0
with tqdm(range(max_batches), desc="Generating subset profiles") as pbar:
    for i in range(max_batches):
        j_selection = rng.choice(i_V.shape[0], size=int(batch_size*change_p), replace=False, shuffle=True)
        idx_change = rng.choice(batch_size, size=int(batch_size*change_p), replace=False, shuffle=True)
        i_selection[idx_change] = j_selection
        i_V, i_U = prepare_data(V=V, U=U, i_selection=i_selection)

        batch_sa = BatchSA(V=i_V, U=i_U, factors=factors, models=n_models, method=method, seed=rng.integers(low=0, high=1e8), max_iter=max_iter,
                            converge_delta=converge_delta, converge_n=converge_n, verbose=False)
        _ = batch_sa.train()
        pbar.update(1)
        for sa in batch_sa.results:
            factor_catalog.add_model(model=sa, norm=True)
        
        pbar.set_description(f"Generating subset profiles.")
t1 = time.time()
logger.info(f"Runtime: {((t1-t0)/60):.2f} min(s)")

In [None]:
t2 = time.time()                     # clustering stage start (part of the profiler total)
factor_catalog.cluster(50)           # cluster all subset-model factors, at most 50 iterations
t3 = time.time()

In [None]:
# cluster_distances = []
# cluster_keys = list(factor_catalog.clusters.keys())
# for i in range(len(cluster_keys)):
#     i_key = cluster_keys[i]
#     i_centroid = factor_catalog.clusters[i_key].centroid
#     for j in range(i, len(cluster_keys)):
#         if i != j:
#             j_key = cluster_keys[j]
#             j_centroid = factor_catalog.clusters[j_key].centroid
#             dist = float(np.linalg.norm(i_centroid - j_centroid))
#             cluster_distances.append(dist)

# print(f"Cluster Mean: {np.mean(cluster_distances):.2f}, Var: {np.var(cluster_distances):.2f}, STD: {np.std(cluster_distances):.2f}")   

In [None]:
# Same embeddings as before; the loop variable must not be `method` (a notebook global used by BatchSA).
for reduction in ("tsne", "ae", "mds"):
    factor_catalog.plot(method=reduction, membership_p=0.03)

In [None]:
# Per-model summary: score plus each factor's cluster and that cluster's size, best models first.
model_scores = {}
for model_id, model in factor_catalog.models.items():
    model_factors = [
        {"id": f.factor_id, "cluster_id": f.cluster_id, "cluster_count": len(factor_catalog.clusters[f.cluster_id])}
        for f in model.factors
    ]
    model_scores[model_id] = {"score": model.score, "factors": model_factors}
model_scores = dict(sorted(model_scores.items(), key=lambda kv: kv[1]["score"], reverse=True))
model_scores

In [None]:
def select_unique(fg, count_threshold: int = 1):
    """Select models whose factor-to-cluster mappings are distinct.

    Models are visited in descending ``score`` order, and the first one is always
    kept. A later model is kept only if both hold:

    - its sorted list of cluster ids is not identical to one already kept;
    - against at least one kept model, it uses a cluster that model does not use,
      and that cluster has more than ``count_threshold`` members.

    Parameters
    ----------
    fg : BatchFactorCatalog
        Clustered catalog (``models`` with ``score``/``factors``, ``clusters``).
    count_threshold : int
        A differing cluster counts as new only if its size exceeds this value.

    Returns
    -------
    dict
        ``{model_id: {factor_id: cluster_id}}`` for each selected model, in selection order.
    """
    cluster_count = {}
    model_scores = {}
    for model_id, model in fg.models.items():
        model_factors = []
        for factor in model.factors:
            count = len(fg.clusters[factor.cluster_id])
            model_factors.append({"id": factor.factor_id, "cluster_id": factor.cluster_id, "cluster_count": count})
            cluster_count.setdefault(factor.cluster_id, count)
        model_scores[model_id] = {"score": model.score, "factors": model_factors}
    ranked = sorted(model_scores.items(), key=lambda item: item[1]["score"], reverse=True)

    added_map = []
    unique_models = {}
    for model_id, model in ranked:
        cluster_mapping = {f["id"]: f["cluster_id"] for f in model["factors"]}
        f_map = sorted(f["cluster_id"] for f in model["factors"])
        # An exact duplicate of a kept mapping is never unique, even if it differs from another kept one.
        if f_map in added_map:
            continue
        add = not added_map or any(
            cluster_count[d] > count_threshold
            for i_map in added_map
            for d in set(f_map) - set(i_map)
        )
        if add:
            added_map.append(f_map)
            unique_models[model_id] = cluster_mapping
    return unique_models
            

In [None]:
unique_models = select_unique(factor_catalog, count_threshold=1)  # distinct cluster mappings, best-scored first

In [None]:
# One H per unique model (up to 20): the centroids of the clusters its factors mapped to.
profile_batches = min(20, len(unique_models))
profile_H = []
for mapping in unique_models.values():
    wanted = list(mapping.values())
    profile_H.append(np.array([cl.centroid for cl in factor_catalog.clusters.values() if cl.cluster_id in wanted]))
    if len(profile_H) == profile_batches:
        break
profile_H = np.array(profile_H)

profile_H.shape

In [None]:
i = list(unique_models)[0]                        # first (highest-scoring) unique model
i_factors = list(unique_models[i].values())       # the cluster ids of its factors
i_H = np.array([cl.centroid for cl in factor_catalog.clusters.values() if cl.cluster_id in i_factors])
i_H = i_H / i_H.sum(axis=0)                       # each feature column sums to 1

In [None]:
%%capture

ta0 = time.time()
run_batch = True
batch_seed = rng.integers(low=0, high=1e8)
if run_batch:
    batch_sa = BatchSA(V=V, U=U, factors=factors, models=20, method=method, seed=batch_seed, max_iter=50000,
                       converge_delta=converge_delta, converge_n=converge_n, verbose=False)
    _ = batch_sa.train()
    batch_sa.details()
else:
    base_sa = SA(factors=factors, method=method, V=V, U=U, seed=seed, verbose=True)
    base_sa.initialize(H=None, W=None)
    _ = base_sa.train(max_iter=50000, converge_delta=converge_delta, converge_n=converge_n)
    base_sa.summary()
ta1 = time.time()
logger.info(f"Total BatchSA Runtime: {(ta1-ta0)/60:.2f} min(s)")

In [None]:
fc_base = BatchFactorCatalog(n_factors=factors, n_features=V.shape[1], threshold=0.8, seed=42)
# Catalog the normalized factor profiles of every full-dataset baseline model, then cluster them.
for model in batch_sa.results:
    fc_base.add_model(model, norm=True)
fc_base.cluster(50)
    

In [None]:
# Loop variable is deliberately not `method`, which is a notebook global used by BatchSA.
for reduction in ("tsne", "ae", "mds"):
    fc_base.plot(method=reduction, membership_p=0.03)

In [None]:
%%capture
t4 = time.time()
if run_batch:
    #TODO: Stacking W with in-order sequence of smaller subsets until its the same number of samples as V
    # profile_batch = BatchSA(V=V, U=U, H=profile_H, factors=factors, models=profile_batches, method=method, seed=batch_seed, max_iter=2000,
    #                         converge_delta=converge_delta, converge_n=converge_n, verbose=False, hold_h=True)
    # _ = profile_batch.train()
    # batch_H = [model.H for model in profile_batch.results]
    # batch_W = [model.W for model in profile_batch.results]
    # profile_batch2 = BatchSA(V=V, U=U, H=batch_H, W=batch_W, factors=factors, models=profile_batches, method=method, seed=batch_seed, max_iter=500,
    #                          converge_delta=converge_delta, converge_n=converge_n, verbose=False)
    # _ = profile_batch2.train()
    # profile_batch2.details()
    profile_batch2 = BatchSA(V=V, U=U, H=profile_H, W=None, factors=factors, models=profile_batches, method=method, seed=batch_seed, max_iter=20000,
                             converge_delta=converge_delta, converge_n=converge_n, delay_h=750, verbose=False)
    _ = profile_batch2.train()
    profile_batch2.details()
else:
    final_sa1 = SA(factors=factors, method=method, V=V, U=U, seed=seed, verbose=True)
    final_sa1.initialize(H=i_H, W=None)
    _ = final_sa1.train(max_iter=1000, converge_delta=converge_delta, converge_n=converge_n, hold_h=True)
    final_sa2 = SA(factors=factors, method=method, V=V, U=U, seed=seed, verbose=True)
    final_sa2.initialize(H=final_sa1.H, W=final_sa1.W)
    _ = final_sa2.train(max_iter=10000, converge_delta=converge_delta, converge_n=converge_n, hold_h=False)
    final_sa2.summary()
t5 = time.time()
total_runtime = ((t1-t0) + (t3-t2) + (t5-t4))/60
logger.info(f"Total Profiler Runtime: {(t5-t4)/60:.2f} min(s)")

In [None]:
fc_profiled = BatchFactorCatalog(n_factors=factors, n_features=V.shape[1], threshold=0.8, seed=42)
# Catalog the normalized factor profiles of every profiler-initialized model, then cluster them.
for model in profile_batch2.results:
    fc_profiled.add_model(model, norm=True)
fc_profiled.cluster(50)

In [None]:
# Loop variable is deliberately not `method`, which is a notebook global used by BatchSA.
for reduction in ("tsne", "ae", "mds"):
    fc_profiled.plot(method=reduction, membership_p=0.03)

In [None]:
# if run_batch:
#     base_matrix = np.vstack([model.H for model in batch_sa.results])
#     profile_matrix = np.vstack([model.H for model in profile_batch.results])
# else:
#     base_matrix=base_sa.H
#     profiled_matrix=final_sa2.H
# factor_catalog.animate(base_matrix=base_matrix, profiled_matrix=profile_matrix)

In [None]:
# factor_catalog.animate(base_matrix=final_sa2.H)

In [None]:
# base_test = factor_catalog.compare(matrix=base_sa.H)
# base_test

In [None]:
# final_test = factor_catalog.compare(matrix=final_sa2.H)
# final_test

In [None]:
def distance(factor1, factor2):
    """Return the squared Pearson correlation (r²) between two factor profiles.

    Returns 0.0 instead of NaN when either profile has zero variance, so an undefined
    similarity cannot corrupt the assignment below.
    """
    f1 = np.array(factor1).astype(float)
    f2 = np.array(factor2).astype(float)
    with np.errstate(divide="ignore", invalid="ignore"):
        corr = np.corrcoef(f2, f1)[0, 1]
    r_sq = corr ** 2
    return r_sq if np.isfinite(r_sq) else 0.0


def match_score(h_a, h_b):
    """Return the mean r² of the best one-to-one pairing of the factor rows of two H matrices.

    The pairing maximizes total r² and is solved on the dense r² matrix. A csr_matrix
    would drop exact-zero entries, turning them into missing edges. The mean is
    rounded to 4 decimals.
    """
    from scipy.optimize import linear_sum_assignment

    cor = np.array([[distance(a_row, b_row) for b_row in h_b] for a_row in h_a])
    rows, cols = linear_sum_assignment(cor, maximize=True)
    return float(np.round(cor[rows, cols].mean(), 4))


# For each profiler-initialized model, find the baseline model whose factors match best.
complete_mapping = {}
for i_key, i_model in enumerate(profile_batch2.results):
    best_j = -1
    best_cor = 0.0
    for j_key, j_base in enumerate(batch_sa.results):
        final_cor = match_score(i_model.H, j_base.H)
        if final_cor > best_cor:
            best_cor = final_cor
            best_j = j_key
    complete_mapping[i_key] = {"best": best_j, "best_cor": best_cor}

In [None]:
# Display, per profiled model index, the best-matching baseline model index ("best") and its mean matched r² ("best_cor").
complete_mapping

In [None]:
def animate(factor_catalog, method="mds"):
    """Animate how factor clustering evolved over the iterations saved in ``factor_catalog.state``.

    All factor profiles and every saved centroid are embedded together in 3-D, so all
    frames share one space. The reduction is t-SNE for ``"tsne"``, metric MDS for
    ``"mds"``, and for any other value an autoencoder from ``build_autoencoder``.
    Frame ``i`` colors each factor by its cluster at iteration ``i``. It also marks,
    with a black ``x``, the centroids of clusters with more than 3 assigned factors.

    Parameters
    ----------
    factor_catalog : BatchFactorCatalog
        Catalog after ``cluster()``; ``state`` must be keyed ``0..len(state)-1``.
    method : str
        Dimensionality-reduction method (case-insensitive).
    """
    factor_objs = list(factor_catalog.factors.values())
    all_factors = np.array([f.profile for f in factor_objs])
    model_assignment = np.array([f.model_id for f in factor_objs])

    color_map = factor_catalog.generate_continuous_colormap(len(factor_catalog.clusters), colormap_name='rainbow')
    full_colormap = dict(zip(list(factor_catalog.clusters.keys()), [color[1] for color in color_map]))
    n_states = len(factor_catalog.state)

    # Stack every iteration's centroids so they are embedded together with the factors.
    sample_lengths = []
    samples = []
    for i in range(n_states):
        _, state_i_centroids = zip(*factor_catalog.state[i]["cluster_centroids"])
        state_i_centroids = np.array(state_i_centroids)
        samples.append(state_i_centroids)
        sample_lengths.append(len(state_i_centroids))
    i_centroids = np.vstack(samples)

    method_key = method.lower()
    if method_key == "tsne":
        samples = np.vstack((all_factors, i_centroids))
        # scikit-learn requires perplexity < n_samples; min(50, n) fails when n <= 50.
        factor_model = TSNE(n_components=3, random_state=0, perplexity=min(50, len(samples) - 1))
        reduction_results = factor_model.fit_transform(samples)
        factor_reduction = reduction_results[:len(all_factors), :]
        pca_centroids = reduction_results[len(all_factors):, :]
    elif method_key == "mds":
        factor_model = MDS(n_components=3, random_state=0, metric=True, max_iter=300)
        reduction_results = factor_model.fit_transform(np.vstack((all_factors, i_centroids)))
        factor_reduction = reduction_results[:len(all_factors), :]
        pca_centroids = reduction_results[len(all_factors):, :]
    else:
        autoencoder, encoder = build_autoencoder(input_dim=i_centroids.shape[1], encoding_dim=3)
        autoencoder.compile(optimizer='adam', loss='mse')
        autoencoder.fit(all_factors, all_factors, epochs=50, batch_size=16, shuffle=True, validation_split=0.2, verbose=0)
        factor_reduction = encoder.predict(all_factors, verbose=0)
        pca_centroids = encoder.predict(i_centroids, verbose=0)
    # Split the embedded centroids back into one block per saved iteration.
    cumulative_lengths = np.cumsum([0] + sample_lengths)
    state_centroids = [pca_centroids[cumulative_lengths[i]:cumulative_lengths[i + 1]] for i in range(len(sample_lengths))]

    df_pca0 = pd.DataFrame(factor_reduction, columns=['x', 'y', 'z'])
    df_pca0["text"] = "P: " + df_pca0.index.astype(str) + ", M: " + model_assignment.astype(str)

    state0 = [
        go.Scatter3d(
            x=df_pca0["x"],
            y=df_pca0["y"],
            z=df_pca0["z"],
            mode='markers',
            marker=dict(
                size=3,
                color="gray",
                opacity=0.5
            ),
            text=df_pca0["text"],
            hoverinfo='text'),
        # Empty placeholder for the centroid trace. Frames update traces that already
        # exist in the figure, so without it the animated centroids would not appear.
        go.Scatter3d(
            x=[],
            y=[],
            z=[],
            mode='markers',
            marker=dict(size=3, color="black", symbol='x'),
            hoverinfo='text'),
    ]
    frames = []
    for i in range(n_states):
        i_cluster, _ = zip(*factor_catalog.state[i]["cluster_centroids"])
        factor_assignments = factor_catalog.state[i]["assignment"]
        frame_centroids = np.array(state_centroids[i])

        df_pca = pd.DataFrame(factor_reduction, columns=['x', 'y', 'z'])
        df_pca["Cluster"] = factor_assignments
        df_pca["text"] = "P: " + df_pca0.index.astype(str) + ", M: " + model_assignment.astype(str) + ", C:" + df_pca["Cluster"].astype(str)

        df_centroids = pd.DataFrame(frame_centroids, columns=['x', 'y', 'z'], index=list(i_cluster))
        # Keep only centroids of clusters that have assigned factors, with their member counts.
        assigned_centroids, cluster_size = np.unique(factor_assignments, return_counts=True)
        df_centroids["text"] = "Centroid " + df_centroids.index.astype(str)
        df_centroids["Cluster"] = list(i_cluster)
        df_centroids = df_centroids.loc[assigned_centroids].copy()
        df_centroids["count"] = cluster_size
        df_centroids = df_centroids[df_centroids["count"] > 3]

        df_pca["color"] = df_pca["Cluster"].map(full_colormap)
        df_centroids["color"] = df_centroids["Cluster"].map(full_colormap)

        data = [
            go.Scatter3d(
                x=df_pca["x"],
                y=df_pca["y"],
                z=df_pca["z"],
                mode='markers',
                marker=dict(
                    size=3,
                    color=df_pca["color"],
                    opacity=0.75
                ),
                text=df_pca["text"],
                hoverinfo='text'
            ),
            go.Scatter3d(
                x=df_centroids['x'],
                y=df_centroids['y'],
                z=df_centroids['z'],
                mode='markers',
                marker=dict(
                    size=3,
                    color="black",
                    symbol='x',
                ),
                text=df_centroids["text"],
                hoverinfo='text'
            )
        ]
        frames.append(
            go.Frame(data=data,
                     layout=go.Layout(
                         annotations=[
                             dict(
                                 x=1,
                                 y=1,
                                 showarrow=False,
                                 text=f"Iteration: {i + 1}/{len(factor_catalog.state)}",
                                 xref="paper",
                                 yref="paper",
                                 font=dict(size=14)
                             )
                         ]
                     ),
                     name=str(i)
                     )
        )

    fig = go.Figure(
        data=state0,
        layout=go.Layout(
            title="Factor Profile Clustering",
            height=1000,
            width=1000,
            updatemenus=[dict(
                type="buttons",
                showactive=False,
                buttons=[
                    dict(label="Play",
                         method="animate",
                         args=[None, dict(frame=dict(duration=500, redraw=True), fromcurrent=True, mode="immediate")]),
                    dict(label="Pause",
                         method="animate",
                         args=[[None], dict(frame=dict(duration=0, redraw=False), mode="immediate")]),
                    dict(
                        args=[[0], dict(frame=dict(duration=0, redraw=True), mode="immediate")],
                        label="Reset",
                        method="animate"),
                ]
            )],
            showlegend=False
        ),
        frames=frames
    )
    # Show the figure
    fig.show()

In [None]:
animate(factor_catalog, method="mds")  # animate the subset-model clustering using the MDS embedding