In [None]:
import torch
import torchvision
import json
import os
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from tqdm import tqdm
import random
from collections import defaultdict
from datetime import datetime

from torchvision.datasets import MNIST

from misc.datasets_ae import *

from umap import UMAP
from sklearn.manifold import TSNE
from sklearn.preprocessing import LabelEncoder

import matplotlib.animation
from IPython.display import HTML

from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_samples
from joblib import Parallel, delayed

import scipy
import skdim

# Seed every RNG used in this notebook: torch (CPU and all CUDA devices),
# Python's `random`, and NumPy's legacy global RNG.
SEED = 91
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, random.seed, np.random.seed):
    _seed_fn(SEED)

In [None]:
def _interval_mask(stat_for_mask, stat_mean, eps):
    """Boolean mask of entries strictly inside the open interval (stat_mean - eps, stat_mean + eps)."""
    stat_for_mask = np.asarray(stat_for_mask)
    return (stat_for_mask > stat_mean - eps) & (stat_for_mask < stat_mean + eps)


def get_mask_from_region__numfound(stat_for_mask, stat_mean,
                         eps,
                         eps_max,
                         n,
                         eps_step=2e-5,
                         patience=1e+5,
                        ):
    """Grow a symmetric window around `stat_mean` until it holds at least `n` samples.

    Starting from half-width `eps`, the half-width is increased by `eps_step` per
    iteration (never beyond `eps_max`) until the window contains at least `n`
    entries of `stat_for_mask`, the half-width reaches `eps_max`, or `patience`
    iterations have run.

    Parameters
    ----------
    stat_for_mask : array-like of float
        Per-sample statistic values.
    stat_mean : float
        Centre of the window.
    eps : float
        Initial half-width of the open window.
    eps_max : float
        Upper bound on the half-width.
    n : int
        Desired number of samples inside the window; must be < len(stat_for_mask).
    eps_step : float
        Increment of the half-width per iteration.
    patience : int or float
        Maximum number of widening iterations.

    Returns
    -------
    mask_out : np.ndarray of bool
        True for entries strictly inside (stat_mean - eps, stat_mean + eps).
    n_found : int
        Number of True entries in `mask_out` (may be < n if a limit was hit).
    """
    assert len(stat_for_mask) > n, "'n' >= Max number of values in stat"

    mask_out = _interval_mask(stat_for_mask, stat_mean, eps)
    n_found = int(np.count_nonzero(mask_out))

    counter = 0
    # Clamp the half-width to eps_max. Callers pass eps_max = half the spacing
    # between adjacent window centres, so overshooting it would make windows overlap.
    while n_found < n and eps < eps_max and counter < patience:
        eps = min(eps + eps_step, eps_max)
        mask_out = _interval_mask(stat_for_mask, stat_mean, eps)
        n_found = int(np.count_nonzero(mask_out))
        counter += 1

    return mask_out, n_found


def get_num_samples_in_interval(
    stat_for_mask,
    stat_mean,
    eps_max,
):
    """Count entries of `stat_for_mask` strictly inside (stat_mean - eps_max, stat_mean + eps_max).

    Returns
    -------
    int
        Number of samples in the open interval.
    """
    return int(np.count_nonzero(_interval_mask(stat_for_mask, stat_mean, eps_max)))


In [None]:
def _plot_stat_hist_with_centres(stat_values, centres, thr_l, thr_r, n_bins, log_hist_y):
    """Draw a histogram of `stat_values` with a dashed vertical line at every window centre."""
    fig, ax = plt.subplots(figsize=(16,5))
    for x_i in centres:
        # axvline's ymin/ymax are axes fractions; the defaults (0, 1) span the full height.
        ax.axvline(x_i, c='darkorange', linestyle='dashed')

    ax.hist(stat_values, bins=n_bins)
    # Pad the x-range by 1% of its width on both sides (also correct when thr_r <= 0).
    pad = 0.01 * (thr_r - thr_l)
    ax.set_xlim(thr_l - pad, thr_r + pad)
    if log_hist_y:
        ax.set_yscale('log')
    x_start, x_end = ax.get_xlim()
    ax.xaxis.set_ticks(np.linspace(x_start, x_end, max(n_bins//2, 5)))
    ax.tick_params(axis='x', rotation=70)
    plt.show()


def _collect_regions(stat_values, centres, half_width, n_per_region, eps_start=1e-5):
    """For each centre, return the mask of samples to draw and the sample count in its full window."""
    region_masks, region_counts = [], []
    for centre in tqdm(centres):
        region_mask, _ = get_mask_from_region__numfound(
            stat_values, centre, eps_start,
            eps_max=half_width,
            eps_step=1e-2*half_width,
            n=n_per_region,
        )
        region_masks.append(region_mask)
        region_counts.append(
            get_num_samples_in_interval(stat_values, centre, eps_max=half_width)
        )
    return region_masks, np.array(region_counts)


def plot_img_zoom_for_stat__imghist(stat_values, thr_l, thr_r,
                           name_stat,
                           N_IMG_ZOOM=3,
                           N_IMG_PER_ZOOM=3,
                           figsize=(6,6),
                           log_hist_y = False,
                           dataset=None,
                           sample_indices=None,
                          ):
    """Show what samples look like along the range of a per-sample statistic.

    First plots a histogram of `stat_values` with `N_IMG_ZOOM` evenly spaced
    centres in [thr_l, thr_r]. It then draws an image grid with one column per
    centre, filled with random samples whose statistic lies near that centre.
    In each column, the number of colour (viridis) tiles is proportional to the
    log1p-smoothed count of samples in that window, so the grid reads like a
    histogram. A "None" tile marks the bar height; tiles below it are drawn in gray.

    Parameters
    ----------
    stat_values : array-like of float
        One value per sample; position k corresponds to dataset index sample_indices[k].
    thr_l, thr_r : float
        Range of the statistic to scan.
    name_stat : str
        Human-readable name of the statistic, used in the grid title.
    N_IMG_ZOOM : int
        Number of window centres (grid columns and histogram bins).
    N_IMG_PER_ZOOM : int
        Number of grid rows (maximum images per column).
    figsize : tuple
        Size of the image-grid figure.
    log_hist_y : bool
        Use a log-scaled y-axis for the histogram.
    dataset : indexable, optional
        Returns (image_tensor, label) per index; defaults to the notebook global
        `train_dataset_sums`.
    sample_indices : np.ndarray, optional
        Dataset index for each position of `stat_values`; defaults to the notebook
        global `idx_bed`.
    """
    if dataset is None:
        dataset = train_dataset_sums
    if sample_indices is None:
        sample_indices = idx_bed
    stat_values = np.asarray(stat_values)

    centres = np.linspace(thr_l, thr_r, N_IMG_ZOOM)
    # Each window's half-width is half the spacing between neighbouring centres, so
    # windows do not overlap. With a single centre, fall back to half the scanned range.
    if len(centres) > 1:
        half_width = (centres[1] - centres[0]) / 2
    else:
        half_width = (thr_r - thr_l) / 2

    _plot_stat_hist_with_centres(stat_values, centres, thr_l, thr_r, N_IMG_ZOOM, log_hist_y)
    region_masks, region_counts = _collect_regions(stat_values, centres, half_width, N_IMG_PER_ZOOM)

    # Take log1p of the counts to smooth them, then normalise to [0, 1]. Guard
    # against all-empty windows to avoid dividing by zero.
    log_counts = np.log1p(region_counts)
    max_log = log_counts.max()
    if max_log > 0:
        fill_coefs = log_counts / max_log
    else:
        fill_coefs = np.zeros_like(log_counts, dtype=float)

    # squeeze=False keeps `axes` 2-D even with a single row or column.
    f, axes = plt.subplots(N_IMG_PER_ZOOM, len(centres), figsize=figsize, squeeze=False)
    for col, (centre, region_mask, fill_coef) in enumerate(zip(centres, region_masks, fill_coefs)):
        positions = np.flatnonzero(region_mask)
        np.random.shuffle(positions)
        n_filled = np.floor(fill_coef * N_IMG_PER_ZOOM)

        for row in range(N_IMG_PER_ZOOM):
            ax_ij = axes[row, col]
            # Use an empty tile when the window has too few samples, and a separator
            # tile to mark the height of this column's "bar".
            if row >= len(positions) or row == n_filled:
                ax_ij.set_title(f"None\n{centre:.3e}", fontsize=8)
                continue
            pos = positions[row]
            img, label = dataset[sample_indices[pos]]
            c_map_str = 'gray' if row > n_filled else 'viridis'
            ax_ij.imshow(img[0], cmap=c_map_str)
            ax_ij.set_title(f"{label}\n{stat_values[pos]:.3e}", fontsize=8)

    f.suptitle(f'Images along {name_stat}\n[{thr_l:.3e} - {thr_r:.3e}]', fontsize=20)
    for ax in axes.ravel():
        ax.set_axis_off()

    plt.show()

In [None]:
# Bottleneck activations saved during training. They are used below as a 3-D
# array (epoch, sample, feature); e.g. norms are taken over axis=2.
results_path = os.path.join("MNIST", "results")
logits_filename = "bn_logits_10_600_wimgsums.npy"
logits_path = os.path.join(results_path, logits_filename)

data = np.load(logits_path)

In [None]:
# ToTensor, then standardise with the MNIST normalisation constants.
image_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ]
)
train_dataset = MNIST(root="./.cache", download=True, transform=image_transform)

# `MNIST_w_imagesums` comes from the star import of misc.datasets_ae; its
# augmentation logic is not visible in this notebook.
aug_ratio = 0.1
train_dataset_sums = MNIST_w_imagesums(train_dataset, aug_ratio=aug_ratio)

# Index arrays exposed by the wrapper (presumably augmented-sample ids — TODO confirm).
id_new_full, id_new = train_dataset_sums.idxs_aug_full, train_dataset_sums.idxs_aug

# Positional index of every sample in the wrapped dataset.
idx_bed = np.arange(len(train_dataset_sums))

In [None]:
# Fetch every (image, label) pair once. The previous two-pass version indexed the
# dataset twice, running the transform pipeline twice per sample.
train_img_list, train_labels_list = [], []
for sample_idx in range(len(train_dataset_sums)):
    img, label = train_dataset_sums[sample_idx]
    train_img_list.append(img[0].numpy())  # first (channel) slice of the image tensor
    train_labels_list.append(label)
train_img = np.array(train_img_list)
train_labels = np.array(train_labels_list)
del train_img_list, train_labels_list

In [None]:
# Sanity check: shape of the saved activations (indexed below as epoch, sample, feature).
data.shape

In [None]:
# Inspect the wrapper's `idxs_aug_full` array (presumably all augmented-sample indices — TODO confirm).
id_new_full

In [None]:
# Size of `id_new`, which is used as the row index of the metrics tables below.
id_new.shape

In [None]:
# Sample indices saved next to the logits. Load them relative to the project root,
# like the other result files, instead of from a machine-specific absolute path.
idxs_saved = np.load(os.path.join("MNIST", "results", "img_indices_10_600_wimgsums.npy"))
idxs_saved

In [None]:
# Show the sample at position 54002, which is inside the range [54_000, 57_000)
# that the 2-D panels below treat as distorted images.
plt.imshow(train_img[1+1+54000])

# 2D Projection

In [None]:
# 2-D UMAP of the bottleneck activations at the last saved epoch.
# NOTE(review): `random_state` is deliberately left unset (UMAP disables parallelism
# when it is given), so the projection is not exactly reproducible across runs.
umap_2d = UMAP(n_components=2)
projector = umap_2d

data_project = data[-1]  # last epoch embeddings
proj_2d = projector.fit_transform(data_project)

In [None]:
# 2-D plane of the embeddings in three panels: all images, originals only, and
# distorted only. Positions [0, 54_000) are the original images and
# [54_000, 57_000) the distorted ones.
N_TRAIN = 60_000
panels = [
    (slice(None), 'Original+Distorted images'),
    (np.arange(0, round(N_TRAIN * 0.9)), 'Original images'),
    (np.arange(round(N_TRAIN * 0.9), round(N_TRAIN * 0.95)), 'Distorted images'),
]

f, axes = plt.subplots(1, 3, figsize=(20,6))
for panel_i, (idx_plot, title) in enumerate(panels):
    # Re-encode labels per panel so colours and legend match the classes present.
    lbl_enc = LabelEncoder()
    labels_ordinal = lbl_enc.fit_transform(train_labels[idx_plot])
    scat = axes[panel_i].scatter(
        proj_2d[idx_plot, 0],
        proj_2d[idx_plot, 1],
        c=labels_ordinal,
        cmap='Spectral',
    )
    axes[panel_i].legend(*scat.legend_elements(), fontsize=9, loc='lower left')
    axes[panel_i].set_title(title)
    if panel_i == 0:
        print(f"All classes with augmentations:\n{lbl_enc.classes_}")

for ax_i in axes:
    ax_i.set_xlim(-8, 15)
    ax_i.set_ylim(-8, 18)

plt.suptitle('2D UMAP Projection of the Bottleneck layer activations', fontsize=16)
plt.show()



# Sample-wise metrics

In [None]:
# One row per sample, indexed by `id_new`; the metric columns are filled in the next cell.
metrics_df = pd.DataFrame(index=id_new)

In [None]:
# Per-sample training-dynamics metrics, collected in `metrics_df` and saved to
# metrics_600.csv. The intermediate arrays (`sil_scores_arr`, `dif`,
# `lid_tle_list`, `stat_h`) are reused by the plotting cells below.

# 1) Silhouette score of every sample under a 10-cluster KMeans, every 10th epoch.
def func_par(x_cluster):
    """Per-sample silhouette scores of `x_cluster` under its 10-cluster KMeans labelling."""
    kmeans = KMeans(n_clusters=10, random_state=SEED, max_iter=200)
    return silhouette_samples(x_cluster, kmeans.fit_predict(x_cluster))

results = Parallel(n_jobs=-1)(delayed(func_par)(x_cluster) for x_cluster in tqdm(data[::10]))
sil_scores_arr = np.array(results)
metrics_df['sil_score__last'] = sil_scores_arr[-1]
# NOTE(review): the column name says "from10", but the std spans every saved epoch
# (row 0 onwards). The name and the slice disagree.
metrics_df['sil_score__std_from10'] = sil_scores_arr.std(axis=0)
metrics_df['sil_score__mean_from0'] = sil_scores_arr.mean(axis=0)

# 2) Norms of the change in embedding between adjacent epochs: emb_t - emb_(t-1).
dif = data[1:] - data[:-1]
l1_norm = np.linalg.norm(dif, ord=1, axis=2)
stat_l2 = np.linalg.norm(dif, ord=2, axis=2)
for norm_name, norm_vals in (('L1', l1_norm), ('L2', stat_l2)):
    for agg in ('std', 'var', 'mean'):
        metrics_df[f'dif_{norm_name}__{agg}_from0'] = getattr(norm_vals, agg)(axis=0)

# 3) L2 norm of each embedding, summarised over different epoch windows.
l2_embs = np.linalg.norm(data, ord=2, axis=2)
metrics_df['L2__last'] = l2_embs[-1]
for agg, starts in (('var', (0, 20, 250)), ('std', (0, 20, 250)), ('mean', (0, 20))):
    for start in starts:
        metrics_df[f'L2__{agg}_from{start}'] = getattr(l2_embs[start:], agg)(axis=0)

# 4) Pointwise local intrinsic dimension (skdim TLE) at epochs 10, 20, ...
lid_estimator = skdim.id.TLE()
lid_tle_list = np.array([
    lid_estimator.fit_transform_pw(X=data[epoch_i], n_jobs=-1, n_neighbors=30)
    for epoch_i in tqdm(range(10, len(data), 10))
])
metrics_df["LID__last"] = lid_tle_list[-1]
for agg in ('var', 'std', 'mean'):
    metrics_df[f"LID__{agg}_from_10"] = getattr(lid_tle_list, agg)(axis=0)

# 5) Shannon entropy of each activation vector. scipy normalises pk along axis 2.
# NOTE(review): negative activations would give nan here — confirm they are non-negative.
stat_h = scipy.stats.entropy(pk=data, axis=2)
metrics_df["H__last"] = stat_h[-1]
for ref_epoch in (250, 450):
    metrics_df[f"H__diff_last_{ref_epoch}"] = stat_h[-1] - stat_h[ref_epoch]
for agg in ('var', 'std'):
    for start in (0, 450):
        metrics_df[f"H__{agg}_from_{start}"] = getattr(stat_h[start:], agg)(axis=0)
for start in (0, 400):
    metrics_df[f"H__mean_from_{start}"] = stat_h[start:].mean(axis=0)

metrics_df.to_csv('metrics_600.csv', )

## 1) Silhouette for 10 clusters 

In [None]:
# sil_scores_arr = []

# def func_par(x_cluster):
#     cluster_model = KMeans(n_clusters=10, random_state=SEED, max_iter=200)
#     clusters_pred = cluster_model.fit_predict(x_cluster)
#     sil_scores = silhouette_samples(x_cluster, clusters_pred)
#     return sil_scores

# # Only every 10th epoch
# results = Parallel(n_jobs=-1)\
#     (delayed(func_par)(x_cluster) for x_cluster in tqdm(data[::10]))
    
# sil_scores_arr = np.array(results)

In [None]:
fig, ax = plt.subplots()

def animate(i):
    """Frame i: silhouette-score histograms at saved row i (every 10th epoch); last 3000 samples = distorted."""
    ax.clear()
    ax.set_title(f"Sil scores at {10*i} epoch")
    for part in (sil_scores_arr[i, :-3000], sil_scores_arr[i, -3000:]):
        ax.hist(part, bins=100)
    ax.legend(["Original", 'Distorted'])
    ax.set_yscale('log')

ani = matplotlib.animation.FuncAnimation(fig, animate, frames=len(sil_scores_arr))
HTML(ani.to_jshtml())

In [None]:
# Silhouette scores at saved row 10 (epoch 100): originals vs. the last 3000 (distorted) samples.
plt.title(f"Sil scores at {100} epoch")
for part in (sil_scores_arr[10, :-3000], sil_scores_arr[10, -3000:]):
    plt.hist(part, bins=100)
plt.legend(["Original", 'Distorted'])
plt.yscale('log')

In [None]:
# sil_scores_arr holds every 10th epoch (data[::10]), so its last row is epoch
# 10*(len-1). That is the same 10*i convention as the cells above, not a hard-coded 600.
last_saved_epoch = 10 * (len(sil_scores_arr) - 1)
plt.title(f"Sil scores at {last_saved_epoch} epoch")
plt.hist(sil_scores_arr[-1,:-3000], bins=100);
plt.hist(sil_scores_arr[-1,-3000:], bins=100);
plt.legend(["Original", 'Distorted'])
plt.yscale('log')

> Looks good: judging by the silhouette score at each individual epoch, the augmented images are not separable from the very beginning.

In [None]:
fig, ax = plt.subplots()

def animate(i):
    """Frame i: std over saved rows i.. of each sample's silhouette score."""
    ax.clear()
    for part in (sil_scores_arr[i:, :-3000], sil_scores_arr[i:, -3000:]):
        ax.hist(part.std(axis=0), bins=100)
    ax.set_title(f"Std of Sil scores starting from {10*i} epoch")
    ax.legend(["Original", 'Distorted'])
    ax.set_yscale('log')

ani = matplotlib.animation.FuncAnimation(fig, animate, frames=len(sil_scores_arr))
HTML(ani.to_jshtml())

> Looks good: the augmented images are separable from the very beginning.

#### Mean of Sil score

In [None]:
fig, ax = plt.subplots()

def animate(i):
    """Frame i: mean over saved rows i.. of each sample's silhouette score."""
    ax.clear()
    for part in (sil_scores_arr[i:, :-3000], sil_scores_arr[i:, -3000:]):
        ax.hist(part.mean(axis=0), bins=100)
    ax.set_title(f"Mean of Sil scores starting from {10*i} epoch")
    ax.legend(["Original", 'Distorted'])
    ax.set_yscale('log')

ani = matplotlib.animation.FuncAnimation(fig, animate, frames=len(sil_scores_arr))
HTML(ani.to_jshtml())

In [None]:
# Image grid along the mean (over all saved epochs) silhouette score per sample.
stat_plot = sil_scores_arr.mean(axis=0)
zoom_kwargs = dict(figsize=(50, 30), N_IMG_ZOOM=40, N_IMG_PER_ZOOM=15, log_hist_y=False)

plot_img_zoom_for_stat__imghist(
    stat_values=stat_plot,
    thr_l=stat_plot.min(),
    thr_r=stat_plot.max(),
    name_stat="Mean starting from 0 epoch (epochs) of Silhouette score (emb_size) of samples",
    **zoom_kwargs,
)

### Adding metrics

In [None]:
# metrics_df['sil_score__last'] = sil_scores_arr[-1]

# metrics_df['sil_score__std_from10'] = sil_scores_arr[0:].std(axis=0)

# metrics_df['sil_score__mean_from0'] = sil_scores_arr[0:].mean(axis=0)

## 2) Difference between embeddings of adjacent epochs

In [None]:
# # emb_t - emb_(t-1)
# dif = data[1:] - data[:-1]

### 2.1) L1 norm of the difference between adjacent-epoch embeddings

In [None]:
# Per-sample L1 norm of emb_{t+1} - emb_t (norm over axis 2); `dif` comes from the metrics cell.
stat = np.linalg.norm(dif, ord=1, axis=2)

In [None]:
fig, ax = plt.subplots()

def animate(i):
    """Frame i: histogram of ||data[i+1] - data[i]||_1 per sample."""
    ax.clear()
    for part in (stat[i, :-3000], stat[i, -3000:]):
        ax.hist(part, bins=100)
    ax.set_title(f"L1 norm of difference of adj embs at {i} epoch")
    ax.legend(["Original", 'Distorted'])
    ax.set_yscale('log')

ani = matplotlib.animation.FuncAnimation(fig, animate, frames=len(stat))
HTML(ani.to_jshtml())

In [None]:
fig, ax = plt.subplots()

def animate(i):
    """Frame i: std over epochs i.. of the per-sample L1 norm of adjacent-embedding differences."""
    ax.clear()
    for part in (stat[i:, :-3000], stat[i:, -3000:]):
        ax.hist(part.std(axis=0), bins=100)
    ax.set_title(f"Std of L1 norm of difference of adj embs after {i} epoch")
    ax.set_yscale('log')
    ax.legend(["Original", 'Distorted'])

ani = matplotlib.animation.FuncAnimation(fig, animate, frames=len(stat))
HTML(ani.to_jshtml())

In [None]:
fig, ax = plt.subplots()

def animate(i):
    """Frame i: mean over epochs i.. of the per-sample L1 norm of adjacent-embedding differences."""
    ax.clear()
    for part in (stat[i:, :-3000], stat[i:, -3000:]):
        ax.hist(part.mean(axis=0), bins=100)
    ax.set_title(f"Mean of L1 norm of difference of adj embs after {i} epoch")
    ax.set_yscale('log')
    ax.legend(["Original", 'Distorted'])

ani = matplotlib.animation.FuncAnimation(fig, animate, frames=len(stat))
HTML(ani.to_jshtml())

In [None]:
# l1_norm = np.linalg.norm(dif, ord=1, axis=2)
# metrics_df['dif_L1__std_from0'] = l1_norm[0:].std(axis=0)
# metrics_df['dif_L1__var_from0'] = l1_norm[0:].var(axis=0)
# metrics_df['dif_L1__mean_from0'] = l1_norm[0:].mean(axis=0)

### 2.2) L2 norm of the difference between adjacent-epoch embeddings

In [None]:
# Per-sample L2 norm of emb_{t+1} - emb_t; `stat` aliases the same array for the cells below.
stat = l2norm = np.linalg.norm(dif, ord=2, axis=2)

In [None]:
fig, ax = plt.subplots()

def animate(i):
    """Frame i: histogram of ||data[i+1] - data[i]||_2 per sample."""
    ax.clear()
    for part in (l2norm[i, :-3000], l2norm[i, -3000:]):
        ax.hist(part, bins=100)
    ax.set_title(f"L2 norm of difference of adj embs at {i} epoch")
    ax.set_yscale('log')
    ax.legend(["Original", 'Distorted'])

ani = matplotlib.animation.FuncAnimation(fig, animate, frames=len(l2norm))
HTML(ani.to_jshtml())

In [None]:
fig, ax = plt.subplots()

def animate(i):
    """Frame i: std over epochs i.. of the per-sample L2 norm of adjacent-embedding differences."""
    ax.clear()
    for part in (l2norm[i:, :-3000], l2norm[i:, -3000:]):
        ax.hist(part.std(axis=0), bins=100)
    ax.set_title(f"Std of L2 norm of difference of adj embs after {i} epoch")
    ax.set_yscale('log')
    ax.legend(["Original", 'Distorted'])

ani = matplotlib.animation.FuncAnimation(fig, animate, frames=len(l2norm))
HTML(ani.to_jshtml())

In [None]:
fig, ax = plt.subplots()

def animate(i):
    """Frame i: mean over epochs i.. of the per-sample L2 norm of adjacent-embedding differences."""
    ax.clear()
    for part in (l2norm[i:, :-3000], l2norm[i:, -3000:]):
        ax.hist(part.mean(axis=0), bins=100)
    ax.set_title(f"Mean of L2 norm of difference of adj embs after {i} epoch")
    ax.set_yscale('log')
    ax.legend(["Original", 'Distorted'])

ani = matplotlib.animation.FuncAnimation(fig, animate, frames=len(l2norm))
HTML(ani.to_jshtml())

In [None]:
# stat_l2 = np.linalg.norm(dif, ord=2, axis=2)
# metrics_df['dif_L2__std_from0'] = stat_l2[0:].std(axis=0)
# metrics_df['dif_L2__var_from0'] = stat_l2[0:].var(axis=0)
# metrics_df['dif_L2__mean_from0'] = stat_l2[0:].mean(axis=0)

### 2.3) Mean of the difference between adjacent-epoch embeddings

In [None]:
# Per-sample mean over axis 2 of emb_{t+1} - emb_t (signed, unlike the norms above).
stat = np.mean(dif, axis=2)

In [None]:
fig, ax = plt.subplots()

def animate(i):
    """Frame i: histogram of the per-sample mean of data[i+1] - data[i]."""
    ax.clear()
    for part in (stat[i, :-3000], stat[i, -3000:]):
        ax.hist(part, bins=100)
    ax.set_title(f"Mean of difference of adj embs at {i} epoch")
    ax.set_yscale('log')
    ax.legend(["Original", 'Distorted'])

ani = matplotlib.animation.FuncAnimation(fig, animate, frames=len(stat))
HTML(ani.to_jshtml())

In [None]:
fig, ax = plt.subplots()

def animate(i):
    """Frame i: mean over epochs i.. of the per-sample mean adjacent-embedding difference."""
    ax.clear()
    for part in (stat[i:, :-3000], stat[i:, -3000:]):
        ax.hist(part.mean(axis=0), bins=100)
    ax.set_title(f"Mean of mean of difference of adj embs from {i} epoch")
    ax.set_yscale('log')
    ax.legend(["Original", 'Distorted'])

ani = matplotlib.animation.FuncAnimation(fig, animate, frames=len(stat))
HTML(ani.to_jshtml())

In [None]:
fig, ax = plt.subplots()

def animate(i):
    """Frame i: std over epochs i.. of the per-sample mean adjacent-embedding difference."""
    ax.clear()
    for part in (stat[i:, :-3000], stat[i:, -3000:]):
        ax.hist(part.std(axis=0), bins=100)
    ax.set_title(f"Std of mean of difference of adj embs from {i} epoch")
    ax.set_yscale('log')
    ax.legend(["Original", 'Distorted'])

ani = matplotlib.animation.FuncAnimation(fig, animate, frames=len(stat))
HTML(ani.to_jshtml())

## 3) L2 of embeddings

In [None]:
# L2 norm of each sample's embedding at every epoch (norm over axis 2).
l2_embs = np.linalg.norm(data, ord=2, axis=2)

In [None]:
# Alias read by the animation cells in this section.
stat = l2_embs

In [None]:
# L2-norm trajectories over epochs of 150 randomly chosen samples (one line per sample).
n_samples, n_pick = l2_embs.shape[1], 150
idx_plot = np.random.choice(n_samples, size=n_pick, replace=False)

plt.figure(figsize=(20, 3))
plt.plot(l2_embs[:, idx_plot])
plt.title("L2 norm of BN embeddings of randomly sampled objects")
plt.show()

### 3.1) at i epoch

In [None]:
fig, ax = plt.subplots()

def animate(i):
    """Frame i: histogram of each sample's embedding L2 norm at epoch i."""
    ax.clear()
    for part in (stat[i, :-3000], stat[i, -3000:]):
        ax.hist(part, bins=100)
    ax.set_title(f"L2 of embs at {i} epoch")
    ax.set_yscale('log')
    ax.legend(["Original", 'Distorted'])

ani = matplotlib.animation.FuncAnimation(fig, animate, frames=len(stat))
HTML(ani.to_jshtml())

In [None]:
fig, ax = plt.subplots()

def animate(i):
    """Frame i: mean over epochs i.. of each sample's embedding L2 norm."""
    ax.clear()
    for part in (stat[i:, :-3000], stat[i:, -3000:]):
        ax.hist(part.mean(axis=0), bins=100)
    ax.set_title(f"Mean of L2 norm of embs from {i} epoch")
    ax.set_yscale('log')
    ax.legend(["Original", 'Distorted'])

ani = matplotlib.animation.FuncAnimation(fig, animate, frames=len(stat))
HTML(ani.to_jshtml())

In [None]:
fig, ax = plt.subplots()

def animate(i):
    """Frame i: std over epochs i.. of each sample's embedding L2 norm."""
    ax.clear()
    for part in (stat[i:, :-3000], stat[i:, -3000:]):
        ax.hist(part.std(axis=0), bins=100)
    ax.set_title(f"Std of L2 norm of embs from {i} epoch")
    ax.set_yscale('log')
    ax.legend(["Original", 'Distorted'])

ani = matplotlib.animation.FuncAnimation(fig, animate, frames=len(stat))
HTML(ani.to_jshtml())

In [None]:
# metrics_df['L2__last'] = l2_embs[-1]

# metrics_df['L2__var_from0'] = l2_embs[0:].var(axis=0)
# metrics_df['L2__var_from20'] = l2_embs[20:].var(axis=0)
# metrics_df['L2__var_from250'] = l2_embs[250:].var(axis=0)

# metrics_df['L2__std_from0'] = l2_embs[0:].std(axis=0)
# metrics_df['L2__std_from20'] = l2_embs[20:].std(axis=0)
# metrics_df['L2__std_from250'] = l2_embs[250:].std(axis=0)

# metrics_df['L2__mean_from0'] = l2_embs[0:].mean(axis=0)
# metrics_df['L2__mean_from20'] = l2_embs[20:].mean(axis=0)

## 4) Intrinsic dimension of the layer

### 4.1) Intrinsic dimension

#### 4.1.1) TLE - skdim 

In [None]:
# %%time

# lid_estimator = skdim.id.TLE()

# lid_tle_list = []

# for epoch_i in tqdm(range(0,len(data),10)):
#     lid_tle = lid_estimator.fit_transform_pw(
#         X=data[epoch_i],
#         n_jobs=-1,
#         n_neighbors=30
#     )
    
#     lid_tle_list.append(lid_tle)
    
# lid_tle_list = np.array(lid_tle_list)   

In [None]:
# lid_estimator = skdim.id.TLE()
# lid_tle_list = []
# for epoch_i in tqdm(range(10,len(data),10)):
#     lid_tle = lid_estimator.fit_transform_pw(
#         X=data[epoch_i],
#         n_jobs=-1,
#         n_neighbors=30
#     )
#     lid_tle_list.append(lid_tle)
# lid_tle_list = np.array(lid_tle_list)  


In [None]:
# LID per sample at the last epoch where it was estimated: originals vs. the last 3000 samples.
last_lid = lid_tle_list[-1]
for part in (last_lid[:-3000], last_lid[-3000:]):
    plt.hist(part, bins=100)
plt.title(f"LID of embs at {list(range(0,len(data),10))[-1]} epoch")
plt.yscale('log')
plt.legend(["Original", 'Distorted'])

In [None]:
# Rows of lid_tle_list are the epochs range(10, len(data), 10) (see the metrics cell),
# so slicing from row 1 starts at epoch 20. Derive the title instead of hard-coding 10*1.
lid_start_epoch = range(10, len(data), 10)[1]
plt.hist(lid_tle_list[1:,:-3000].std(axis=0), bins=100);
plt.hist(lid_tle_list[1:,-3000:].std(axis=0), bins=100);
plt.title(f"STD of LID of embs from {lid_start_epoch} epoch")
plt.yscale('log')
plt.legend(["Original", 'Distorted'])

In [None]:
# Rows of lid_tle_list are the epochs range(10, len(data), 10) (see the metrics cell),
# so slicing from row 1 starts at epoch 20. Derive the title instead of hard-coding 10*1.
lid_start_epoch = range(10, len(data), 10)[1]
plt.hist(lid_tle_list[1:,:-3000].mean(axis=0), bins=100);
plt.hist(lid_tle_list[1:,-3000:].mean(axis=0), bins=100);
plt.title(f"Mean of LID of embs from {lid_start_epoch} epoch")
plt.yscale('log')
plt.legend(["Original", 'Distorted'])

In [None]:
# Image grid along the mean LID per sample over rows 1.. of lid_tle_list. Rows are
# epochs range(10, len(data), 10), so the mean starts at that range's second epoch.
# The previous label ("Std after 7th epoch ... of loss") was a copy-paste error.
stat_plot = lid_tle_list[1:].mean(axis=0)

plot_img_zoom_for_stat__imghist(
    stat_values=stat_plot,
    thr_l=stat_plot.min(),
    thr_r=stat_plot.max(),
    name_stat=f"Mean starting from {range(10, len(data), 10)[1]} epoch (epochs) of LID (TLE) of samples",
    figsize=(70,25),
    N_IMG_ZOOM=60,
    N_IMG_PER_ZOOM=15,
    log_hist_y=True,
)

In [None]:
# plt.hist(lid_tle[:-3000], bins=100);
# plt.hist(lid_tle[-3000:], bins=100);
    
# plt.yscale('log')
# plt.show()

In [None]:
%%time
fig, ax = plt.subplots()

# Rows of lid_tle_list are the epochs range(10, len(data), 10), so row i of
# lid_tle_list[1:] is the epoch lid_epochs[i]. The old title 100*i was wrong.
lid_epochs = list(range(10, len(data), 10))[1:]
lid_rows = lid_tle_list[1:]
# Fixed x-range for every frame, so it is computed once instead of once per frame.
x_lo = lid_rows.mean() - 4 * lid_rows.std()
x_hi = lid_rows.mean() + 4 * lid_rows.std()

def animate(i):
    """Frame i: per-sample LID at epoch lid_epochs[i]; last 3000 samples = distorted."""
    ax.clear()
    plt.hist(lid_rows[i, :-3000], bins=100);
    plt.hist(lid_rows[i, -3000:], bins=100);
    plt.title(f"LID of embs at {lid_epochs[i]} epoch")
    plt.xlim(x_lo, x_hi)
    plt.yscale('log')
    plt.legend(["Original", 'Distorted'])

ani = matplotlib.animation.FuncAnimation(fig, animate, frames=len(lid_rows))
HTML(ani.to_jshtml())

In [None]:
# LID trajectories (rows 1.. of lid_tle_list) of 50 randomly chosen samples.
n_samples, n_pick = lid_tle_list.shape[1], 50
idx_plot = np.random.choice(n_samples, size=n_pick, replace=False)

plt.figure(figsize=(20, 3))
plt.plot(lid_tle_list[1:, idx_plot])
plt.title("LID of BN layer of randomly sampled objects")
plt.show()

## 5) Entropy of a layer

In [None]:
# stat_h = scipy.stats.entropy(
#     pk=data,
#     axis=2,
# )

In [None]:
# metrics_df["H__last"] = stat_h[-1]

# metrics_df["H__diff_last_250"] = stat_h[-1]-stat_h[250]
# metrics_df["H__diff_last_450"] = stat_h[-1]-stat_h[450]

# metrics_df["H__var_from_0"] = stat_h[0:].var(axis=0)
# metrics_df["H__var_from_450"] = stat_h[450:].var(axis=0)

# metrics_df["H__std_from_0"] = stat_h[0:].std(axis=0)
# metrics_df["H__std_from_450"] = stat_h[450:].std(axis=0)

# metrics_df["H__mean_from_0"] = stat_h[0:].mean(axis=0)
# metrics_df["H__mean_from_400"] = stat_h[400:].mean(axis=0)

In [None]:
# metrics_df.to_csv('metrics_600.csv', )

In [None]:
# Entropy trajectories over epochs of 50 randomly chosen samples (one line per sample).
n_samples, n_pick = stat_h.shape[1], 50
idx_plot = np.random.choice(n_samples, size=n_pick, replace=False)

plt.figure(figsize=(20, 3))
plt.plot(stat_h[:, idx_plot])
plt.title("Entropy of BN layer of randomly sampled objects")
plt.show()

> With this measure, the separation may only appear in later epochs.

In [None]:
%%time
fig, ax = plt.subplots()

# Fixed x-range for every frame: mean +/- 4 std of the entropy over all epochs and samples.
mean_plot = stat_h.mean()
std_plot = stat_h.std()
x_range = (mean_plot - 4 * std_plot, mean_plot + 4 * std_plot)

def animate(i):
    """Frame i: mean over epochs i.. of each sample's bottleneck entropy."""
    ax.clear()
    for part in (stat_h[i:, :-3000], stat_h[i:, -3000:]):
        ax.hist(part.mean(axis=0), bins=100)
    ax.set_title(f"Mean of Entropy of bottleneck layer from {i} epoch")
    ax.set_xlim(*x_range)
    ax.set_yscale('log')
    ax.legend(["Original", 'Distorted'])

ani = matplotlib.animation.FuncAnimation(fig, animate, frames=len(stat_h))
HTML(ani.to_jshtml())

In [None]:
# Image grid along the mean (over all epochs) bottleneck entropy per sample.
# The previous label ("Std after 7th epoch ... of loss") was a copy-paste error.
stat_plot = stat_h.mean(axis=0)

plot_img_zoom_for_stat__imghist(
    stat_values=stat_plot,
    thr_l=stat_plot.min(),
    thr_r=stat_plot.max(),
    name_stat="Mean starting from 0 epoch (epochs) of Entropy of BN activations of samples",
    figsize=(70,25),
    N_IMG_ZOOM=60,
    N_IMG_PER_ZOOM=15,
    log_hist_y=True,
)

In [None]:
%%time
fig, ax = plt.subplots()

# Fixed x-range for every frame, taken from the per-sample std over all epochs.
per_sample_std = stat_h.std(axis=0)
mean_plot = per_sample_std.mean()
std_plot = per_sample_std.std()
x_range = (mean_plot - 4 * std_plot, mean_plot + 4 * std_plot)

def animate(i):
    """Frame i: std over epochs i.. of each sample's bottleneck entropy."""
    ax.clear()
    for part in (stat_h[i:, :-3000], stat_h[i:, -3000:]):
        ax.hist(part.std(axis=0), bins=100)
    ax.set_title(f"Std of Entropy of bottleneck layer from {i} epoch")
    ax.set_xlim(*x_range)
    ax.set_yscale('log')
    ax.legend(["Original", 'Distorted'])

ani = matplotlib.animation.FuncAnimation(fig, animate, frames=len(stat_h))
HTML(ani.to_jshtml())

In [None]:
%%time
fig, ax = plt.subplots()

# Summary of the per-sample variance over all epochs (kept for optional x-limits).
per_sample_var = stat_h.var(axis=0)
mean_plot = per_sample_var.mean()
std_plot = per_sample_var.std()

def animate(i):
    """Frame i: variance over epochs i.. of each sample's bottleneck entropy."""
    ax.clear()
    for part in (stat_h[i:, :-3000], stat_h[i:, -3000:]):
        ax.hist(part.var(axis=0), bins=100)
    ax.set_title(f"Var of Entropy of bottleneck layer from {i} epoch")
    ax.set_yscale('log')
    ax.legend(["Original", 'Distorted'])

ani = matplotlib.animation.FuncAnimation(fig, animate, frames=len(stat_h))
HTML(ani.to_jshtml())

# Loss during training

In [None]:
# Per-sample training metrics, saved as a 0-d object array wrapping a dict-like; only
# its 'samplewise_loss' entry is used here.
# NOTE(review): allow_pickle=True unpickles the file; only load trusted files.
results_path = os.path.join("MNIST", "results")
data_loss_filename = "samplewise_metrics_10_600_wimgsums.npy"
data_loss_path = os.path.join(results_path, data_loss_filename)

data_loss = np.array(np.load(data_loss_path, allow_pickle=True).item()['samplewise_loss'])

In [None]:
# Per-sample summaries of the AE loss trajectory, one row per sample (indexed by id_new),
# saved to metrics_losses_600.csv. Column order: last, var*, mean*, std*, diff*.
loss_df = pd.DataFrame(index=id_new)
loss_df["loss__last"] = data_loss[-1]

for agg in ("var", "mean", "std"):
    for start in (0, 20, 50):
        loss_df[f"loss__{agg}_from_{start}"] = getattr(data_loss[start:], agg)(axis=0)

for start in (0, 20, 50):
    loss_df[f"loss__diff_last_{start}"] = data_loss[-1] - data_loss[start]

loss_df.to_csv('metrics_losses_600.csv',)

## Std of loss from i

In [None]:
# Std over all epochs of each sample's AE loss: originals vs. the last 3000 (distorted) samples.
for part in (data_loss[0:, :-3000], data_loss[0:, -3000:]):
    plt.hist(part.std(axis=0), bins=100)
plt.legend(["Original", 'Distorted'])
plt.title(f'Std of loss from {0} epoch')
plt.show()

## Loss value at i

In [None]:
# AE loss trajectories over epochs of 50 randomly chosen samples (one line per sample).
n_samples, n_pick = data_loss.shape[1], 50
idx_plot = np.random.choice(n_samples, size=n_pick, replace=False)

plt.figure(figsize=(20, 3))
plt.plot(data_loss[:, idx_plot])
plt.title("Loss of AE of randomly sampled objects")
plt.show()

In [None]:
import matplotlib.animation
from IPython.display import HTML

fig, ax = plt.subplots()

def animate(i):
    """Frame i: per-sample AE loss at epoch i; last 3000 samples = distorted."""
    ax.clear()
    for part in (data_loss[i, :-3000], data_loss[i, -3000:]):
        ax.hist(part, bins=100)
    ax.set_title(f"Loss of AE at {i} epoch")
    ax.legend(["Original", 'Distorted'])

ani = matplotlib.animation.FuncAnimation(fig, animate, frames=len(data_loss), interval=50)
HTML(ani.to_jshtml())

In [None]:
# Image grid along the per-sample AE loss at the last epoch.
# The previous label ("Std after 7th epoch ...") did not describe this statistic.
stat_plot = data_loss[-1]

plot_img_zoom_for_stat__imghist(
    stat_values=stat_plot,
    thr_l=stat_plot.min(),
    thr_r=stat_plot.max(),
    name_stat="AE loss per sample at the last epoch",
    figsize=(70,25),
    N_IMG_ZOOM=60,
    N_IMG_PER_ZOOM=15,
    log_hist_y=False,
)

## Mean of loss from i

In [None]:
# Mean over all epochs of each sample's AE loss: originals vs. the last 3000 (distorted) samples.
for part in (data_loss[0:, :-3000], data_loss[0:, -3000:]):
    plt.hist(part.mean(axis=0), bins=100)
plt.title(f"Mean of loss of AE from {0} epoch")
plt.legend(["Original", 'Distorted'])

## Var of loss from i

In [None]:
import matplotlib.animation
from IPython.display import HTML

fig, ax = plt.subplots()

def animate(i):
    """Frame i: variance over epochs i.. of each sample's AE loss; last 3000 samples = distorted."""
    ax.clear()
    # Use a descriptive title instead of the bare frame number.
    ax.set_title(f"Var of loss of AE from {i} epoch")
    for part in (data_loss[i:, :-3000], data_loss[i:, -3000:]):
        ax.hist(part.var(axis=0), bins=100)
    ax.set_yscale('log')
    ax.legend(["Original", 'Distorted'])

ani = matplotlib.animation.FuncAnimation(fig, animate, frames=len(data_loss), interval=50)
HTML(ani.to_jshtml())

## Difference between the last-epoch loss and the loss at epoch i

In [None]:
import matplotlib.animation
from IPython.display import HTML

fig, ax = plt.subplots()

def animate(i):
    """Frame i: last-epoch loss minus epoch-i loss per sample; last 3000 samples = distorted."""
    ax.clear()
    # Use a descriptive title instead of the bare frame number.
    ax.set_title(f"Loss of AE at last epoch minus loss at {i} epoch")
    ax.hist(data_loss[-1, :-3000] - data_loss[i, :-3000], bins=100)
    ax.hist(data_loss[-1, -3000:] - data_loss[i, -3000:], bins=100)
    ax.set_yscale('log')
    ax.legend(["Original", 'Distorted'])

ani = matplotlib.animation.FuncAnimation(fig, animate, frames=len(data_loss))
HTML(ani.to_jshtml())