# Activations visualized - part 2

Loads the saved activations and gradients, reduces their dimensionality, and draws the charts.

In [None]:
%config InlineBackend.figure_format = 'retina'
import matplotlib
matplotlib.rcParams.update({
        # "font.family": "Times New Roman",
        "axes.labelsize": 18,
        "font.size": 18,
        "legend.fontsize": 18,
        "xtick.labelsize": 18,
        "ytick.labelsize": 18,
})


In [None]:
# Process-level environment variables, set in a cell that runs before the one importing torch.
# NOTE(review): presumably pins the notebook to GPU 0 and caps the OpenMP/MKL
# thread pools at 16 threads — confirm these values suit the host machine.
# Comments are kept on their own lines: text after a %env value would become part of the value.
%env CUDA_VISIBLE_DEVICES=0
%env OMP_NUM_THREADS=16 
%env MKL_NUM_THREADS=16 
# %load_ext autoreload
# %autoreload 2

In [None]:
import sys, pathlib, os
# Make modules under ./src importable by appending its absolute path to sys.path.
sys.path.append(os.fspath(pathlib.Path('src').resolve()))

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
# Handle for the first CUDA device (creating the handle does not itself touch the GPU).
device = torch.device('cuda:0')

from tqdm.auto import tqdm, trange
# Record library versions and the chosen device in the cell output;
# the f-string "=" form prints each expression together with its value.
print(f"{torch.__version__=}, {transformers.__version__=}, {device=}")


from matplotlib import pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

In [None]:
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

def reduce_data_dim(data, method='pca', n_components=2):
    """Reduce ``data`` to ``n_components`` dimensions with PCA or t-SNE.

    Args:
        data: 2-D array-like (samples x features) accepted by scikit-learn's
            ``fit_transform``; must expose ``.shape``.
        method: ``'pca'`` or ``'tsne'`` (case-insensitive).
        n_components: Number of dimensions in the returned embedding.

    Returns:
        The array returned by the chosen estimator's ``fit_transform``.

    Raises:
        ValueError: If ``method`` is neither ``'pca'`` nor ``'tsne'``.
            Without this check an unknown method would hand the input back
            unreduced and callers would silently plot raw features.
    """
    method_name = method.lower()

    if method_name == 'pca':
        return PCA(n_components=n_components).fit_transform(data)

    if method_name == 'tsne':
        # TSNE IS SLOW: compress wide inputs (> 50 features) to 32 PCA
        # components before running it.
        if data.shape[-1] > 50:
            data = reduce_data_dim(data, method='pca', n_components=32)
        return TSNE(n_components=n_components).fit_transform(data)

    raise ValueError(
        f"unknown reduction method {method!r}; expected 'pca' or 'tsne'"
    )
    

def plot_act_grad(repacked_data, labels, step=None, method='pca'):
    """Draw a 2x2 grid of 2-D scatter plots for the four stored series.

    The top row shows the ``'fwd_0'`` and ``'fwd_1'`` entries, the bottom row
    the ``'back_0'`` and ``'back_1'`` entries. Each entry is detached, moved
    to the CPU and reduced with ``reduce_data_dim`` before plotting.

    Args:
        repacked_data: Mapping with the four keys above; values must support
            ``.detach().cpu()`` (presumably torch tensors — confirm with caller).
        labels: Per-point values used to colour the scatter markers.
        step: Value shown in the figure title.
        method: Reduction method forwarded to ``reduce_data_dim``.
    """
    fig, axs = plt.subplots(2, 2, figsize=(10,6))

    # One tuple of keys per subplot row, in the order the panels are filled.
    keys_by_row = (('fwd_0', 'fwd_1'), ('back_0', 'back_1'))
    for row_keys, row_axes in zip(keys_by_row, axs):
        for key, panel in zip(row_keys, row_axes):
            points = reduce_data_dim(repacked_data[key].detach().cpu(), method=method)
            panel.scatter(points[:, 0], points[:, 1], c=labels, cmap='viridis', alpha=0.6)
            panel.set_title(key)

    plt.suptitle(f'Visualization using {method.upper()} for two LoRAs, step {step}')
    plt.tight_layout()
    plt.show()
    
def reduce50(dd, q=50):
    """Return the ``U`` factor of a low-rank PCA of ``dd`` (at most ``q`` columns).

    Args:
        dd: Tensor whose last two dimensions are (rows, columns).
        q: Upper bound on the number of components; defaults to 50.

    Returns:
        The first element of ``torch.pca_lowrank``'s result, with
        ``min(q, rows, columns)`` columns. The tensor lives on CUDA device 0
        when CUDA is available, otherwise on ``dd``'s own device.
    """
    # Keep the GPU path when CUDA exists; otherwise stay where the input
    # already is instead of failing on `.to(0)`.
    target = 0 if torch.cuda.is_available() else dd.device
    matrix = dd.to(target).to(torch.float32)

    # torch.pca_lowrank rejects q larger than min(rows, columns), so clamp it
    # (the same guard the loading loop applies before its second PCA).
    rank = min(q, *matrix.shape[-2:])
    return torch.pca_lowrank(matrix, q=rank)[0]

In [None]:
# Directory that holds the saved "outs_<step>_reg_<k_reg>.pt" files read by the loading cells.
chart_data_path = "./outs"
# Show the directory contents in sorted order as the cell output.
sorted(os.listdir(chart_data_path))

In [None]:
# Collects the reduced series, keyed as "<kind>_step_<step>_reg_<k_reg>".
chart_data = {}

## Loading raw data and reducing dimensions

In [None]:
# For every saved checkpoint (training step x regularisation setting): load the
# raw tensors, compress them with low-rank PCA, embed them in 2-D with t-SNE and
# store the result in chart_data.
for step in [0, 1000, 4000, 16000]:
    for k_reg in [0, 1000]:
        filename = os.path.join(chart_data_path, f"outs_{step}_reg_{k_reg}.pt")
        print(filename)

        data = torch.load(filename)

        # .detach() is applied uniformly so the tensors can be converted to arrays.
        chart_data[f'acts_step_{step}_reg_{k_reg}'] = reduce_data_dim(reduce50(data['acts']).detach().cpu(), method='tsne')
        chart_data[f'grads_step_{step}_reg_{k_reg}'] = reduce_data_dim(reduce50(data['grads']).detach().cpu(), method='tsne')
        chart_data[f'labels_step_{step}_reg_{k_reg}'] = data['labels']

        # The reshaped view does not depend on the loop index, so build it once.
        # NOTE(review): 872 and 192 are hard-coded; presumably samples x LoRA
        # blocks — confirm against the code that writes these files.
        lora_grads = data['lora_grads'].view([872, 192, -1])

        rr = []
        for i in trange(192):
            d = lora_grads[:, i, :]
            try:
                r = torch.pca_lowrank(d.to(0).to(torch.float32), q=4)
            except (RuntimeError, ValueError) as err:
                # Best-effort: skip a slice that cannot be decomposed, but say
                # so, and let KeyboardInterrupt and programming errors propagate.
                print(f"skipping lora_grads slice {i}: {err!r}")
                continue
            rr.append(r[0])

        if not rr:
            raise RuntimeError(f"no lora_grads slice of {filename} could be reduced")

        # Concatenate the per-slice components, then compress again before t-SNE.
        r1 = torch.hstack(rr)
        rank2 = min(50, r1.shape[-1])
        r2 = torch.pca_lowrank(r1, q=rank2)[0]
        r3 = reduce_data_dim(r2.detach().cpu(), method='tsne')
        chart_data[f'loragrads_step_{step}_reg_{k_reg}'] = r3

In [None]:
# NOTE(review): this cell repeats the computation of the preceding loading cell
# and overwrites the same chart_data keys — confirm whether both are needed.
# For every saved checkpoint: load the raw tensors, compress them with low-rank
# PCA, embed them in 2-D with t-SNE and store the result in chart_data.
for step in [0, 1000, 4000, 16000]:
    for k_reg in [0, 1000]:
        filename = os.path.join(chart_data_path, f"outs_{step}_reg_{k_reg}.pt")
        print(filename)

        data = torch.load(filename)

        # .detach() is applied uniformly so the tensors can be converted to arrays.
        chart_data[f'acts_step_{step}_reg_{k_reg}'] = reduce_data_dim(reduce50(data['acts']).detach().cpu(), method='tsne')
        chart_data[f'grads_step_{step}_reg_{k_reg}'] = reduce_data_dim(reduce50(data['grads']).detach().cpu(), method='tsne')
        chart_data[f'labels_step_{step}_reg_{k_reg}'] = data['labels']

        # The reshaped view does not depend on the loop index, so build it once.
        # NOTE(review): 872 and 192 are hard-coded; presumably samples x LoRA
        # blocks — confirm against the code that writes these files.
        lora_grads = data['lora_grads'].view([872, 192, -1])

        rr = []
        for i in trange(192):
            d = lora_grads[:, i, :]
            try:
                r = torch.pca_lowrank(d.to(0).to(torch.float32), q=4)
            except (RuntimeError, ValueError) as err:
                # Best-effort: skip a slice that cannot be decomposed, but say
                # so, and let KeyboardInterrupt and programming errors propagate.
                print(f"skipping lora_grads slice {i}: {err!r}")
                continue
            rr.append(r[0])

        if not rr:
            raise RuntimeError(f"no lora_grads slice of {filename} could be reduced")

        # Concatenate the per-slice components, then compress again before t-SNE.
        r1 = torch.hstack(rr)
        rank2 = min(50, r1.shape[-1])
        r2 = torch.pca_lowrank(r1, q=rank2)[0]
        r3 = reduce_data_dim(r2.detach().cpu(), method='tsne')
        chart_data[f'loragrads_step_{step}_reg_{k_reg}'] = r3

In [None]:
# Inspect the shape of lora_grads after the (872, 192, -1) view;
# `data` is whichever file the loading loop left bound last.
data['lora_grads'].view([872, 192,-1]).shape

In [None]:
# Write the reduced series to disk; the "Drawing Charts" cells load this file back.
torch.save(chart_data, 'chart_data.pt')  # 400 KB

In [None]:
# Show which series keys are stored in chart_data.
chart_data.keys()

## Drawing Charts

In [None]:
# Load the saved series from disk, replacing any in-memory chart_data.
chart_data = torch.load('chart_data.pt')

In [None]:
# One figure per (regularisation setting, quantity), each with four panels —
# one per training step — saved as pdf/<quantity>_reg_<k_reg>.pdf and shown inline.

# Create the output directory once; exist_ok replaces listing the working
# directory and conditionally calling mkdir for every figure.
os.makedirs("pdf", exist_ok=True)

for k_reg in [0, 1000]:
    for val in ['acts', 'grads', 'loragrads']:

        print(val, k_reg)
        fig, axs = plt.subplots(1, 4, figsize=(15,3))
        for ax, step in zip(axs, [0, 1000, 4000, 16000]):

            series = chart_data[f"{val}_step_{step}_reg_{k_reg}"]
            labels = chart_data[f"labels_step_{step}_reg_{k_reg}"]
            ax.scatter(series[:,0], series[:,1], c=labels, cmap='viridis', alpha=0.6)
            ax.set_title(f'Step: {step}')
            # Hide the tick marks on both axes.
            ax.set_xticks([])
            ax.set_yticks([])

        # No grid call here: a bare `plt.grid` expression only looks the
        # function up without calling it, so it never drew anything.
        plt.savefig(f'pdf/{val}_reg_{k_reg}.pdf', bbox_inches='tight')
        plt.show()