In [None]:
import os, sys
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
import random
import torch as t
from torch import nn, Tensor
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from dataclasses import dataclass
import numpy as np
import einops
from jaxtyping import Float, Int
from typing import Optional, Callable, Union, List, Tuple
from functools import partial
from tqdm.notebook import tqdm
# from tqdm.auto import tqdm
from dataclasses import dataclass
from rich import print as rprint
from rich.table import Table
from IPython.display import display, HTML
from pathlib import Path
import pandas as pd
pd.options.plotting.backend = "plotly"
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import copy
import shap
import plotly.express as px

# Locate the chapter's exercises directory by cutting the current working
# directory at the chapter name, so the project-local imports below resolve.
chapter = "chapter1_transformer_interp"
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
# NOTE(review): section_dir is not used in the visible cells, and it names
# "part4_..." while the utils import below comes from "part31_..." - confirm
# which directory is intended.
section_dir = (exercises_dir / "part4_superposition_and_saes").resolve()
# Append only when missing so that re-running this cell does not grow sys.path.
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

from plotly_utils import imshow, line, hist
from part31_superposition_and_saes.utils import (
    plot_features_in_2d,
    plot_features_in_Nd,
    plot_features_in_Nd_discrete,
    plot_correlated_features,
    plot_feature_geometry,
    frac_active_line_plot,
    animate_features_in_2d
)

from feature_geometry_utils import *
# Start from CUDA when it is present, otherwise from the CPU.
device = t.device("cuda" if t.cuda.is_available() else "cpu")

# An available MPS backend overrides the choice above. When MPS is missing,
# report which of the two possible reasons applies.
if t.backends.mps.is_available():
    device = t.device("mps")
elif t.backends.mps.is_built():
    print("MPS not available because the current MacOS version is not 12.3+ "
          "and/or you do not have an MPS-enabled device on this machine.")
else:
    print("MPS not available because the current PyTorch install was not "
          "built with MPS enabled.")

# True only when this code runs as the top-level script / notebook module.
MAIN = __name__ == "__main__"

# Seed torch's global RNG for the cells that follow.
t.manual_seed(20)

importance = (1 ** t.arange(cfg.n_features))

feature_probability = (25 ** -t.linspace(0, 1, cfg.n_instances))

In [None]:
def feature_probability(n_instances, base=25):
    """Return one feature probability per instance, geometrically spaced.

    The values are ``base ** -x`` for ``x`` evenly spaced over ``[0, 1]``, so
    they run from ``1`` (first instance) down to ``1 / base`` (last instance).

    Args:
        n_instances: Number of instances, i.e. number of probabilities.
        base: Inverse of the smallest probability. The default of 25 keeps the
            original schedule.

    Returns:
        Tensor of shape ``(n_instances, 1)``; the trailing singleton axis lets
        it broadcast against a per-feature axis.
    """
    feature_prob = base ** -t.linspace(0, 1, n_instances)
    # Same result as einops "instances -> instances ()": add a trailing axis.
    return feature_prob.unsqueeze(-1)

def feature_importance(n_features, base=1):
    """Return per-feature importance weights ``base ** i`` for feature ``i``.

    With the default ``base=1`` every feature gets importance 1 (uniform
    importance), which is the original behaviour. A base below 1 gives
    geometrically decaying importance.

    Args:
        n_features: Number of features.
        base: Base of the geometric sequence.

    Returns:
        Tensor of shape ``(1, n_features)``; the leading singleton axis lets it
        broadcast against a per-instance axis.
    """
    importance = base ** t.arange(n_features)
    # Same result as einops "features -> () features": add a leading axis.
    return importance.unsqueeze(0)

In [None]:
# Display the probability schedule for 10 instances.
feature_probability(10)

In [None]:
# Display the importance weights for 6 features.
feature_importance(6)

In [None]:
def save_experiment_metadata(n_features, seed, n_instances=8, optim_fn=t.optim.Adam,
                             n_hidden=2, steps=10_000):
    """Train one model and return its weights, losses and bookkeeping in a dict.

    Note that ``seed`` is only printed and stored in the result; this function
    does not seed any random number generator. Callers that want a reproducible
    run must seed torch / NumPy / ``random`` themselves before calling.

    Args:
        n_features: Number of input features of the model.
        seed: Seed label recorded alongside the results (not applied here).
        n_instances: Number of instances trained side by side.
        optim_fn: Optimizer constructor passed through to ``Config``.
        n_hidden: Hidden dimension passed to ``Config`` (default 2, the value
            that used to be hard-coded).
        steps: Number of optimisation steps passed to ``model.optimize``
            (default 10_000, the value that used to be hard-coded).

    Returns:
        dict with the keys ``n_features``, ``seed``, ``model``, ``importance``,
        ``feature_prob``, ``W``, ``b``, ``summed_loss``, ``per_feature_losses``,
        ``losses``, ``test_summed_losses``, ``test_per_feature_losses``,
        ``test_losses``, ``batches``, ``all_W`` and ``all_b``.
    """
    print(f"n_features:{n_features}")
    print(f"seed:{seed}")

    cfg = Config(
        n_instances=n_instances,
        n_features=n_features,
        n_hidden=n_hidden,
        optim_fn=optim_fn,
    )

    # Importance varies across features and is shared by every instance.
    importance = feature_importance(cfg.n_features)

    # Feature probability is the same for all features of an instance but
    # varies across instances.
    feature_prob = feature_probability(cfg.n_instances)

    model = Model(
        cfg=cfg,
        device=device,
        importance=importance,
        feature_probability=feature_prob,
    )
    (
        summed_losses,
        losses,
        batches,
        all_W,
        all_b,
        per_feature_losses,
        test_losses,
        test_summed_losses,
        test_per_feature_losses,
    ) = model.optimize(steps=steps)

    # Key order is kept as before so that .keys() listings stay unchanged.
    return {
        "n_features": n_features,
        "seed": seed,
        "model": model,
        "importance": model.importance,
        "feature_prob": feature_prob,
        # Detached so the stored final parameters carry no autograd history.
        "W": model.W.detach(),
        "b": model.b_final.detach(),
        "summed_loss": summed_losses,
        "per_feature_losses": per_feature_losses,
        "losses": losses,
        "test_summed_losses": test_summed_losses,
        "test_per_feature_losses": test_per_feature_losses,
        "test_losses": test_losses,
        "batches": batches,
        "all_W": all_W,
        "all_b": all_b,
    }

## Notes 11th Jan
- Case 1: where 6th feature is not learned (pentagon representation)
- Case 2: where 6th feature is learned and feature norm is high (hexagon representation)
- Case 3: where 6th feature is learned and feature norm is low (pentagon representation)


## To Do:
- Save the random seed and the initialization for future runs.
- Isolate that particular seed and sparsity combination (train just 1 instance) so that we can reproducibly check how much the initialization affects whether a hexagon forms.
- Is there a propensity to learn 5 vs. 6 features depending on the initialization?
- What is the optimizer doing in each of the cases, and how is it affecting the number of features learned?
- If we slightly change the initialization, does the outcome change a little or a lot? (This can indicate where on the peak that point is.)
- How strongly is the optimizer driving the model away from, or back towards, a given initialization?
- 

## Bigger question
- Why 5, and not some other number of features such as 6 or 7?
- 

In [None]:
# Seeds that must always be part of the sweep.
PINNED_SEEDS = [20, 26]
# Fixed seed for the private generator below, so that Restart & Run All
# reproduces the same seed list instead of a fresh one on every kernel start.
SEED_SAMPLER_SEED = 0

# Exclude the pinned seeds from the pool so the final list never holds
# duplicates (sampling 20 or 26 would otherwise train the same run twice).
seed_candidates = [s for s in range(1, 100) if s not in PINNED_SEEDS]
random_seed_list = random.Random(SEED_SAMPLER_SEED).sample(seed_candidates, 8)
random_seed_list.extend(PINNED_SEEDS)
print(random_seed_list)

#### Hexagon seeds: 26, 20, 88, 87


- Case 1: where 6th feature is not learned (pentagon representation)(seed=26 & sparsity=0.057)
- Case 2: where 6th feature is learned and feature norm is high (hexagon representation) (seed=20, sparsity=0.057 & seed=26, sparsity=0.040)
- Case 3: where 6th feature is learned and feature norm is low (pentagon representation)

In [None]:
# One 6-feature, 10-instance training run per seed. torch, NumPy and Python's
# `random` are each reset to the run's seed right before the run starts.
results_list_hexagon_replication = []

for seed in random_seed_list:
    for reseed in (t.manual_seed, np.random.seed, random.seed):
        reseed(seed)
    run_results = save_experiment_metadata(6, seed, 10)
    results_list_hexagon_replication.append(run_results)

In [None]:
def plot_replication_results(results_list):
    """Plot the learned 2D feature directions of every run in ``results_list``.

    Args:
        results_list: Iterable of result dicts as returned by
            ``save_experiment_metadata``; the keys ``n_features``, ``W``,
            ``importance`` and ``feature_prob`` are read.
    """
    for results in results_list:
        print("n_features:",results['n_features'])

        plot_features_in_2d(
            results['W'],
            colors = results['importance'],
            title = f"Superposition: {results['n_features']} features represented in 2D space",
            # One subplot title per instance, showing that instance's feature probability.
            subplot_titles = [f"1 - S = {i:.3f}" for i in results['feature_prob'].squeeze()],
        )


plot_replication_results(results_list_hexagon_replication)

results_list_heptagon_replication = []

# Ten 7-feature runs. Each run is seeded with its own seed and that seed is
# what gets recorded, so every stored result can be reproduced on its own.
# NOTE(review): n_instances is left at the function default here, unlike the
# 6-feature runs, which pass 10 - confirm that this is intended.
for seed in range(10):
    t.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    results_list_heptagon_replication.append(save_experiment_metadata(7, seed))

plot_replication_results(results_list_heptagon_replication)

### TO DO:
- Analyze both 6-feature (hexagon) instances in the n_features = 6 runs and look at how the norm, test loss, and per-feature loss evolve in each.
- Compare them to the instance where the 5-and-1 overlap is present.
- Analyze the 6-feature instance in the n_features = 7 runs, and compare it to the 6-feature instances from the n_features = 6 runs.

# Analyzing Hexagon in n_features = 6

## 1st Instance, 2nd iteration, second lowest sparsity

In [None]:
# Pick the run stored at position 1 of the 6-feature results for a closer look.
# NOTE(review): which seed sits at position 1 depends on the order of
# random_seed_list - confirm this is the intended hexagon run.
first_hexagon_nfeatures_6 = results_list_hexagon_replication[1]

In [None]:
# List the fields stored in this run's results dictionary.
first_hexagon_nfeatures_6.keys()

In [None]:
# Loss history of a single instance of the selected run. The instance index is
# defined here explicitly; the cell previously read an `index` name that no
# cell defines, so it failed on a fresh kernel.
# NOTE(review): 8 matches the first instance analysed further down - confirm.
instance_index = 8
pd.DataFrame(first_hexagon_nfeatures_6['losses'])[instance_index].values

In [None]:
def count_features_learned(all_W, index):
    """Count the learned features of one instance in every weight snapshot.

    A feature counts as learned when the sum of the absolute values of its
    weights (taken over axis 0 of ``W[index]``) rounds to a value above zero.

    Args:
        all_W: Sequence of weight snapshots. Each snapshot is indexed with
            ``index`` first (presumably the instance axis - confirm against
            ``Model.optimize``).
        index: Which instance to inspect.

    Returns:
        list[int]: one count per snapshot, in snapshot order.
    """
    counts = []
    for W in all_W:
        abs_weight_sum = np.abs(np.asarray(W[index])).sum(axis=0)
        counts.append(int(np.count_nonzero(np.round(abs_weight_sum, 0) > 0)))
    return counts


def norm_vs_feature_learned(result, index):
    """Plot weight norms of one instance against its number of learned features.

    The left axis shows the norm taken over axis 0 of ``W[index]`` (one line
    per remaining column) plus the norm of the whole ``W[index]``; the right
    axis shows the learned-feature count. The x axis is the snapshot position
    in ``result['all_W']``.

    Args:
        result: Result dict; the keys ``n_features`` and ``all_W`` are read.
        index: Which instance to plot.
    """
    print("n_feature:", result['n_features'])
    # as_tensor accepts arrays and tensors alike without the copy (and the
    # warning) that t.tensor() causes when it is handed an existing tensor.
    instance_weights = [t.as_tensor(W[index]) for W in result['all_W']]
    feature_wise_norm = [t.linalg.vector_norm(W_inst, dim=0).cpu().tolist() for W_inst in instance_weights]
    overall_norm = [t.linalg.vector_norm(W_inst).item() for W_inst in instance_weights]
    n_features_learned = count_features_learned(result['all_W'], index)

    df = pd.DataFrame(feature_wise_norm)
    df['overall_norm'] = overall_norm

    fig = df.plot()

    # Second trace, drawn against the right-hand axis.
    fig.add_trace(
        go.Scatter(x=df.index, y=n_features_learned, name='n_features_learned', yaxis='y2')
    )

    # Label each axis with what it actually shows: norms on the left, the
    # learned-feature count on the right.
    fig.update_layout(
        yaxis=dict(title='norm', side='left'),
        yaxis2=dict(title='n_features_learned', overlaying='y', side='right'),
        xaxis=dict(title='x'),
        title='Per feature norm vs overall norm'
    )

    fig.show()


def per_feature_loss_viz(result, index, test_loss=False, log=False,moving_avg_window = 20):
    """Plot smoothed per-feature and overall loss against learned-feature count.

    Args:
        result: Result dict; ``n_features``, ``all_W`` and either the
            ``per_feature_losses`` / ``losses`` keys or their ``test_``
            counterparts are read.
        index: Which instance to plot.
        test_loss: Read the ``test_*`` keys instead of the plain ones.
        log: Plot the natural log of the smoothed losses.
        moving_avg_window: Window length of the rolling mean.
    """
    print("n_feature:", result['n_features'])
    if test_loss:
        per_feature_loss_key = 'test_per_feature_losses'
        loss_key = 'test_losses'
        title = "Test Per Feature loss vs test overall loss"
    else:
        per_feature_loss_key = 'per_feature_losses'
        loss_key = 'losses'
        title = 'Per Feature loss vs overall loss'

    per_feature_loss_df = pd.DataFrame([per_feature_loss[index] for per_feature_loss in result[per_feature_loss_key]])
    per_feature_loss_df['overall_loss'] = pd.DataFrame(result[loss_key])[index].values
    n_features_learned = count_features_learned(result['all_W'], index)

    # The first (window - 1) rows of a rolling mean are NaN, so drop them.
    first_full_window = moving_avg_window - 1
    moving_avg_loss = per_feature_loss_df.rolling(window=moving_avg_window).mean().iloc[first_full_window:]

    fig = (np.log(moving_avg_loss) if log else moving_avg_loss).plot()

    # Second trace, drawn against the right-hand axis and trimmed the same way.
    # NOTE(review): this pairs loss row i with snapshot i of all_W - confirm
    # both are recorded at the same cadence, especially for the test losses.
    fig.add_trace(
        go.Scatter(x=moving_avg_loss.index, y=n_features_learned[first_full_window:], name='n_features_learned', yaxis='y2')
    )

    loss_axis_title = 'per_feature_loss/overall loss'
    if log:
        loss_axis_title = f'log({loss_axis_title})'

    fig.update_layout(
        yaxis=dict(title=loss_axis_title, side='left'),
        yaxis2=dict(title='n_features_learned', overlaying='y', side='right'),
        xaxis=dict(title='x'),
        title=title
    )

    fig.show()

In [None]:
# Shape check of the rolling mean used in per_feature_loss_viz. The inputs are
# rebuilt here because that function's locals do not exist at cell scope, so
# the cell previously failed on a fresh kernel.
# NOTE(review): instance 8 and window 20 (the function default) are assumed - confirm.
moving_avg_window = 20
per_feature_loss_df = pd.DataFrame(
    [per_feature_loss[8] for per_feature_loss in first_hexagon_nfeatures_6['per_feature_losses']]
)
per_feature_loss_df.rolling(window=moving_avg_window).mean().shape

In [None]:
# Norm trajectories and learned-feature count: run at list position 5, instance 9.
norm_vs_feature_learned(results_list_hexagon_replication[5],9)

In [None]:
# Smoothed per-feature vs overall loss (non-test keys): run at list position 0, instance 9.
per_feature_loss_viz(results_list_hexagon_replication[0], 9)

In [None]:
# Norm trajectories and learned-feature count: selected run, instance 8.
norm_vs_feature_learned(first_hexagon_nfeatures_6,8)

In [None]:
# Norm trajectories and learned-feature count: selected run, instance 9.
norm_vs_feature_learned(first_hexagon_nfeatures_6,9)

In [None]:
# Smoothed per-feature vs overall loss (non-test keys): selected run, instance 8.
per_feature_loss_viz(first_hexagon_nfeatures_6, 8)

In [None]:
# Same plot as above for instance 8, read from the test_* loss keys.
per_feature_loss_viz(first_hexagon_nfeatures_6, 8,test_loss=True)

In [None]:
# Norm trajectories and learned-feature count: selected run, instance 9
# (same call as the earlier instance-9 norm cell).
norm_vs_feature_learned(first_hexagon_nfeatures_6,9)

In [None]:
# Smoothed per-feature vs overall loss (non-test keys): selected run, instance 9.
per_feature_loss_viz(first_hexagon_nfeatures_6, 9)

In [None]:
# Same plot for instance 9, read from the test_* loss keys.
per_feature_loss_viz(first_hexagon_nfeatures_6, 9,test_loss = True)