In [None]:
import os
import sys
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from IPython.display import display
import matplotlib.pyplot as plt

# %matplotlib inline
%matplotlib widget

# Make the parent of the current working directory importable, so modules that
# live there can be imported by later cells (presumably this is where the
# `analyze_tensor` module used below resides -- verify against the repo layout).
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
# Guard against duplicates: without this check every re-run of the cell would
# append the same entry to sys.path again.
if project_root not in sys.path:
    sys.path.append(project_root)

In [None]:
# Location of the checkpoint file to inspect, relative to the notebook's working directory.
checkpoint_path = '../cloud_runs/69.55.141.119/barc/runs/barc/20241106_004918_nogit_nobranch_lr5e-05_bl1e-06_ssu0_bs16_h4_es888_nl18_we10_as1_ph1_ac1_ad1_scosine_oadam_ge1_mh0_ssnone_ss1e-02_c5/Transformer_best_eis3056_ep5856.pt'

# SECURITY NOTE: torch.load deserializes with pickle, which can execute
# arbitrary code -- only load checkpoints that come from a trusted source.
# map_location forces all tensors onto the CPU, so no GPU is needed here.
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

# Training-progress metadata stored in the checkpoint.
print(checkpoint['epoch'], checkpoint['train_loss'], checkpoint['epoch_in_session'])
# Hyperparameters recorded with the run.
print(checkpoint['hyperparameters'])
# Names of every tensor in the saved model state.
print(checkpoint['model_state_dict'].keys())
# Shapes of the embedding table and of the two encoding tensors.
print(checkpoint['model_state_dict']['embedding.weight'].shape, checkpoint['model_state_dict']['positional_encoding'].shape, checkpoint['model_state_dict']['grid_encoding'].shape)

In [None]:
from analyze_tensor import analyze_tensor, visualize_tsne

def generate_tensor_names(layer_num):
    """Build the names of the per-layer tensors for one layer.

    Args:
        layer_num: Layer index, interpolated into the ``layers.<layer_num>.``
            prefix of every returned name.

    Returns:
        list[str]: Twelve names in a fixed order: attention input/output
        projections, ``norm1``/``norm2`` and ``feed_forward`` entries 0 and 3,
        each with its weight followed by its bias.
    """
    # Name suffixes shared by every layer; only the layer index varies.
    suffixes = (
        'attention.in_proj_weight',
        'attention.in_proj_bias',
        'attention.out_proj.weight',
        'attention.out_proj.bias',
        'norm1.weight',
        'norm1.bias',
        'norm2.weight',
        'norm2.bias',
        'feed_forward.0.weight',
        'feed_forward.0.bias',
        'feed_forward.3.weight',
        'feed_forward.3.bias',
    )
    return [f'layers.{layer_num}.{suffix}' for suffix in suffixes]

def analyze_layer_tensors(checkpoint, layer_num):
    """Run ``analyze_tensor`` on each tensor of a single layer.

    Args:
        checkpoint: Mapping whose ``'model_state_dict'`` entry maps tensor
            names to tensors.
        layer_num: Index of the layer whose tensors are analyzed; the names
            come from ``generate_tensor_names(layer_num)``.
    """
    layer_keys = generate_tensor_names(layer_num)
    state_dict = checkpoint['model_state_dict']

    for name in layer_keys:
        tensor = state_dict[name]
        # The label handed to analyze_tensor carries both the name and the shape.
        analyze_tensor(tensor, f'{name}, {tensor.shape}')

def analyze_layer_tensors_vertical(checkpoint, names, layers_count, stack=True):
    """Combine the same-named tensors of several layers and analyze them as one tensor.

    Args:
        checkpoint: Mapping whose ``'model_state_dict'`` entry maps tensor
            names to tensors.
        names: Per-layer name suffixes (e.g. ``'norm1.weight'``); each one is
            appended to ``layers.<i>.`` for every layer index ``i``.
        layers_count: Number of layers; indices ``0 .. layers_count - 1`` are used.
        stack: When true the tensors are combined with ``torch.stack``;
            otherwise they are concatenated with ``torch.cat`` along dimension 1.
    """
    # Layer-major order: every requested name of layer 0, then of layer 1, ...
    tensor_names = []
    for layer_idx in range(layers_count):
        for suffix in names:
            tensor_names.append(f'layers.{layer_idx}.{suffix}')

    print('tensor_names', tensor_names)

    # Fetch the tensors in that order and merge them into a single tensor.
    collected = [checkpoint['model_state_dict'][key] for key in tensor_names]
    tensor = torch.stack(collected) if stack else torch.cat(collected, dim=1)

    # The label records which suffixes were combined and the resulting shape.
    analyze_tensor(tensor, f'{names}, {tensor.shape}')



In [None]:
# Inspect `embedding.weight`: run the generic tensor analysis, then the
# `visualize_tsne` helper (presumably a t-SNE projection plot -- see analyze_tensor.py).
embedding_weight = checkpoint['model_state_dict']['embedding.weight']
analyze_tensor(embedding_weight, 'embedding.weight')
visualize_tsne(embedding_weight, 'embedding.weight', perplexity=5)

In [None]:
# Inspect the `fc_out` parameters: analyze the weight, then the bias, and
# finally run the `visualize_tsne` helper on the weight only.
fc_out_weight = checkpoint['model_state_dict']['fc_out.weight']
analyze_tensor(fc_out_weight, 'fc_out.weight')

fc_out_bias = checkpoint['model_state_dict']['fc_out.bias']
analyze_tensor(fc_out_bias, 'fc_out.bias')

visualize_tsne(fc_out_weight, 'fc_out.weight', perplexity=5)

# ts.show(checkpoint['model_state_dict']['fc_out.bias'].unsqueeze(0), interpolation='nearest', figsize=(20, 20))

In [None]:
# Run the tensor analysis on the `positional_encoding` entry of the model state
positional_encoding = checkpoint['model_state_dict']['positional_encoding']
analyze_tensor(positional_encoding, 'positional_encoding')

In [None]:
# Run the tensor analysis on the `grid_encoding` entry of the model state
grid_encoding = checkpoint['model_state_dict']['grid_encoding']
analyze_tensor(grid_encoding, 'grid_encoding')

In [None]:
# Show the value stored under `grid_scale` in the model state
grid_scale = checkpoint['model_state_dict']['grid_scale']
print(grid_scale)

In [None]:
# Stack the `norm1`/`norm2` parameters of all 18 layers and analyze each stack
# as a single tensor: first the weights, then the biases.
# (18 is the layer count of this model; presumably the `nl18` in the run name -- confirm.)
analyze_layer_tensors_vertical(checkpoint, ('norm1.weight', 'norm2.weight'), 18)
# Both normalization biases are covered here, mirroring the weights call above.
analyze_layer_tensors_vertical(checkpoint, ('norm1.bias', 'norm2.bias'), 18)

# analyze_layer_tensors_vertical(checkpoint, ('attention.in_proj_weight', ), 18, stack=False)
# analyze_layer_tensors_vertical(checkpoint, ('attention.out_proj.weight', ), 18, stack=False)
# analyze_layer_tensors_vertical(checkpoint, ('attention.in_proj_bias', ), 18)
# analyze_layer_tensors_vertical(checkpoint, ('attention.out_proj.bias', ), 18)

# analyze_layer_tensors_vertical(checkpoint, ('feed_forward.0.weight', ), 18, stack=False)
# analyze_layer_tensors_vertical(checkpoint, ('feed_forward.3.weight', ), 18, stack=False)
# analyze_layer_tensors_vertical(checkpoint, ('feed_forward.0.bias', ), 18)
# analyze_layer_tensors_vertical(checkpoint, ('feed_forward.3.bias', ), 18)

In [None]:
# Choose the layer index to inspect, then analyze every tensor of that layer.
layer_number = 0
analyze_layer_tensors(checkpoint, layer_num=layer_number)