In [None]:
## Import Libraries
import numpy as np
import pandas as pd
import copy
import glob

from tqdm.auto import trange

import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch import Tensor
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset
import torch.optim as optim

from torchvision import datasets, transforms


import matplotlib.pyplot as plt

# from mpl_toolkits import mplot3d
from matplotlib.colors import ListedColormap

import seaborn as sns

from sklearn.decomposition import PCA, KernelPCA

from sklearn.metrics import adjusted_rand_score, accuracy_score, davies_bouldin_score
from sklearn.model_selection import train_test_split, RandomizedSearchCV

from sklearn.cluster import KMeans, SpectralClustering
from sklearn.mixture import GaussianMixture
from sklearn.svm import SVC

# Used to save data into files
import pickle as pkl
import os

# Used to measure time
import time

In [None]:
# Ten-colour qualitative palette (these hex codes appear to match the first ten
# ColorBrewer "Paired" colours), installed as seaborn's default palette.
custom_colors = (
    "#a6cee3 #1f78b4 #b2df8a #33a02c #fb9a99 "
    "#e31a1c #fdbf6f #ff7f00 #cab2d6 #6a3d9a"
).split()
sns.set_palette(custom_colors)

# Absolute tolerance (np.isclose `atol`) under which a converged alpha is
# treated as having gone to zero.
ATOL = 0.05

In [None]:
# Locate the per-run alpha logs. Sorting gives a deterministic file order;
# later cells identify each file by its position in this list.
folder_path = "./results/"  # Replace with the actual path to your folder
file_pattern = "alphas_df*.csv"

# os.path.join avoids the doubled separator ("./results//...") that plain
# string formatting produces when folder_path already ends with "/".
search_pattern = os.path.join(folder_path, file_pattern)
data_files = sorted(glob.glob(search_pattern))
if not data_files:
    # Downstream cells silently produce nothing (or fail) on an empty list.
    print(f"Warning: no files match {search_pattern!r}")
print(data_files)

In [None]:
def plot_alphas(ax, batch_alpha_0, batch_alpha_1, batch_alpha_2):
    """Plot three alpha trajectories on a single matplotlib axes.

    Each series is drawn against its positional index (0, 1, ..., len-1),
    taken from the length of ``batch_alpha_0``. All three series are assumed
    to have that same length.

    Args:
        ax: Axes to draw on.
        batch_alpha_0, batch_alpha_1, batch_alpha_2: Sequences of alpha
            values, labelled "Alpha 0", "Alpha 1" and "Alpha 2".

    Returns:
        The same ``ax``, after adding the lines, an upper-right legend and a
        fixed [0, 1] y-range.
    """
    batch_positions = range(len(batch_alpha_0))
    labelled_series = (
        (batch_alpha_0, "Alpha 0"),
        (batch_alpha_1, "Alpha 1"),
        (batch_alpha_2, "Alpha 2"),
    )
    for values, name in labelled_series:
        ax.plot(batch_positions, values, label=name)

    ax.legend(loc="upper right")
    # Fixed scale so panels drawn with this helper are directly comparable.
    ax.set_ylim(0, 1)
    return ax

In [None]:
# One figure per (data file, lambda): a 2 x 3 grid with one panel per
# iteration, each showing alpha0/alpha1/alpha2 over the seen batches.
lambda_range = np.arange(start=0, stop=0.051, step=0.005)
ALPHAS_PLOT_DIR = "./alphas_plots"
os.makedirs(ALPHAS_PLOT_DIR, exist_ok=True)

for f_counter, file in enumerate(data_files):
    # The file contents do not depend on lambda, so read and clean them once.
    df = pd.read_csv(file).dropna()

    for counter, lam in enumerate(lambda_range):
        fig, ax = plt.subplots(2, 3, figsize=(12, 8), dpi=200)
        ax = ax.flatten()
        # np.arange values can carry floating-point rounding error, so an exact
        # `==` against the lambdas read from CSV may silently match no rows.
        lambda_match = np.isclose(df["lambda"], lam)
        for i in range(len(ax)):  # one panel per iteration index 0..5
            condition = lambda_match & (df["iteration"] == i).to_numpy()
            ax[i] = plot_alphas(
                ax[i],
                df.loc[condition, "alpha0"],
                df.loc[condition, "alpha1"],
                df.loc[condition, "alpha2"],
            )
            ax[i].legend(loc="upper center")
        ax[0].set_xlabel("Number of seen batches", fontweight="bold")
        ax[0].set_ylabel(r"$\alpha$", rotation=0, labelpad=20, fontweight="bold")
        title = r"$\alpha$ values for $\lambda$ = " + f"{round(lam, 3)}"
        fig.suptitle(title, size=20, fontweight="bold")
        fig.subplots_adjust(top=0.90)
        fig.savefig(
            os.path.join(ALPHAS_PLOT_DIR, f"alphas_{f_counter}_{counter}.png")
        )
        # This loop creates one high-dpi figure per (file, lambda); close each
        # so they do not accumulate in memory.
        plt.close(fig)

In [None]:
# Percentage of iterations whose converged (last logged) alpha lies within
# ATOL of zero, per data file and lambda; "total" averages the three alphas.
alphacols = ["alpha0", "alpha1", "alpha2"]

# One dict per (file, lambda), turned into a DataFrame once at the end instead
# of repeatedly concatenating onto an empty frame.
percentage_records = []
for i, file in enumerate(data_files):
    df = pd.read_csv(file)
    lambda_range = df["lambda"].unique()

    for Lambda in lambda_range:
        # Gets the last row for each iteration with the given lambda
        df_lambda = df[df["lambda"] == Lambda].groupby("iteration").tail(1)

        # NOTE(review): unlike the trajectory plots, NaN rows are not dropped
        # here, so a NaN alpha counts as "not zero". Confirm this is intended.
        record = {"df": i, "lambda": Lambda}
        for col in alphacols:
            record[col] = np.isclose(df_lambda[col], 0, atol=ATOL).mean() * 100
        percentage_records.append(record)

percentages_df = pd.DataFrame(percentage_records, columns=["df", "lambda", *alphacols])
percentages_df["total"] = percentages_df[alphacols].mean(axis=1)
# An empty selection (e.g. a NaN lambda, for which `==` matches nothing)
# yields NaN percentages; drop those rows.
percentages_df = percentages_df.dropna()


In [None]:
# Mean share (over alpha0..alpha2) of converged alphas within ATOL of zero,
# as a function of lambda, with one line per data file.
fig, ax = plt.subplots(figsize=(10, 5), dpi=200)
# Loop names chosen so the global `df` DataFrame is not overwritten with a
# scalar; sort=False keeps files in first-appearance order.
for file_id, file_rows in percentages_df.groupby("df", sort=False):
    ax.plot(
        file_rows["lambda"],
        file_rows["total"],
        marker="o",
        label=file_id,
    )
ax.set_xlabel(r"$\lambda$", fontweight="bold", fontsize=14)
ax.set_ylabel(
    r'% of $\alpha$' + '\nthat go to zero',
    rotation=0,
    labelpad=70,
    fontweight="bold",
    fontsize=14,
)
ax.set_xticks(np.arange(0, 0.051, 0.005))

ax.set_yticks(np.arange(0, 101, 10))
ax.legend(title="Data file")
ax.grid(True, axis="y")
os.makedirs("./plots", exist_ok=True)
fig.savefig("./plots/alphas_total.png")

## Distribution of converged values

In [None]:
# Getting all the converged values: the last logged row of every iteration,
# for every lambda, across ALL data files. The collection is initialised
# outside the file loop so earlier files are not discarded.
value_columns = ["lambda", "alpha0", "alpha1", "alpha2"]
converged_frames = []
for file in data_files:
    df = pd.read_csv(file)
    lambda_range = df["lambda"].unique()

    for Lambda in lambda_range:
        df_lambda = df[df["lambda"] == Lambda].groupby("iteration").tail(1)
        converged_frames.append(df_lambda[value_columns])

# Concatenate once; pd.concat raises on an empty list, so fall back to an
# empty frame with the expected columns.
final_values = (
    pd.concat(converged_frames, ignore_index=True)
    if converged_frames
    else pd.DataFrame(columns=value_columns)
)

In [None]:
# Long format for seaborn: each row of final_values becomes one row per alpha
# column, with the column name in "variable" and its value in "value"; the
# lambda is carried along as an identifier.
melted_values = pd.melt(
    final_values,
    id_vars="lambda",
    var_name="variable",
    value_name="value",
)
print(melted_values)

In [None]:

# Strip plot of the converged alpha values, one grid row per alpha.
g = sns.FacetGrid(
    melted_values,
    row="variable",
    hue="variable",
    palette=custom_colors,
    margin_titles=True,
)
g.map(sns.stripplot, "value")
g.set(xlim=(0, 1))
# With margin_titles=True, FacetGrid.set_titles ignores the generic `template`
# argument and only formats row_template/col_template, so the row template
# must be given explicitly. "{row_name}" labels each row with its alpha name.
g.set_titles(row_template="{row_name}", fontweight="bold")

g.set_ylabels("Count", rotation=0, labelpad=20, fontweight="bold")
g.set_xlabels(r"$\alpha$", fontweight="bold", fontsize=14)
g.fig.set_size_inches(10, 5.5)
g.fig.set_dpi(200)
g.fig.tight_layout()
plt.show()

In [None]:
# Every converged alpha value plotted against its lambda, coloured by which
# alpha column it came from.
fig, ax = plt.subplots(figsize=(10, 5), dpi=200)
sns.scatterplot(
    data=melted_values,
    x="lambda",
    y="value",
    hue="variable",
    palette=custom_colors,
    ax=ax,
)
ax.set_xlabel(r"$\lambda$", fontsize=14, fontweight="bold")
ax.set_ylabel(r"$\alpha$", rotation=0, labelpad=20, fontsize=14, fontweight="bold")

fig.savefig("./plots/scatter.png")

In [None]:
# Histograms of converged alpha values: one grid row per alpha, one column per
# lambda, with ten equal-width bins spanning [0, 1].
bin_edges = [i / 10 for i in range(11)]

g = sns.FacetGrid(
    melted_values,
    row="variable",
    col="lambda",
    hue="variable",
    palette=custom_colors,
    margin_titles=True,
)
g.map(sns.histplot, "value", bins=bin_edges)
g.set(xlim=(0, 1))
# With margin_titles=True, set_titles only honours row_template/col_template
# (a positional `template` is ignored), so both are passed explicitly.
g.set_titles(
    row_template="{row_name}",
    col_template=r"$\lambda$ = {col_name}",
    fontweight="bold",
)

g.set_ylabels("Count", rotation=0, labelpad=20, fontweight="bold")
g.set_xlabels(r"$\alpha$", fontweight="bold")
g.fig.set_size_inches(15, 5.5)
g.fig.set_dpi(200)
g.fig.tight_layout()
os.makedirs("./plots", exist_ok=True)
# Save the grid's own figure; a bare `fig` here would be whatever figure an
# earlier cell left behind (the scatter plot), not this histogram.
g.fig.savefig("./plots/alpha_hist.png")