In [2]:
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from transformers import GPTNeoXForCausalLM
from sklearn.decomposition import PCA
from matplotlib.cm import get_cmap

# Experiment configuration: run name, base model, and input/output locations.
experiment = "fixed_5e6"
MODEL_NAME = "EleutherAI/pythia-70m"
# Alternative base model; swap with the line above to analyze it instead.
#MODEL_NAME = "EleutherAI/pythia-410m"
MODEL_FOLDER = f"{experiment}/models"  # folder scanned for fine-tuned .pt checkpoints
OUTPUT_FOLDER = f"{experiment}/mad_analysis_output"  # destination of mad_values.csv and mad_plot.png
PCA_FOLDER = os.path.join(OUTPUT_FOLDER, 'pca')  # destination of the per-layer PCA plots
# Creating PCA_FOLDER also creates its parent OUTPUT_FOLDER if it is missing.
os.makedirs(PCA_FOLDER, exist_ok=True)

# When True, only checkpoint files whose name starts with "fixed_lr" are analyzed.
FIXED_LR = False

# Function to calculate MAD
def calculate_mad(weights1, weights2):
    """Return the mean absolute difference (MAD) between two weight arrays.

    Args:
        weights1: First array of weights.
        weights2: Second array of weights; must be subtractable from
            ``weights1`` (same or broadcastable shape).

    Returns:
        The mean of the element-wise absolute differences.
    """
    absolute_difference = np.abs(weights1 - weights2)
    return np.mean(absolute_difference)

# Function to perform PCA and save the plot
def plot_pca(weights1, weights2, layer_name, epoch):
    """Save a 2-component PCA scatter plot comparing two sets of weights.

    Both arrays are stacked row-wise, a single PCA is fitted on the stacked
    data, and the two groups are drawn in the projected space with a legend
    distinguishing them. The image is written to ``PCA_FOLDER`` as
    ``<layer_name with dots replaced by underscores>_epoch_<epoch>_pca.png``.

    This is a best-effort helper: any exception is reported on stdout and the
    layer is skipped instead of aborting the caller.

    Args:
        weights1: Pre-trained weights as a NumPy array.
        weights2: Fine-tuned weights as a NumPy array with a shape compatible
            with ``weights1`` for row-wise concatenation.
        layer_name: Parameter name; used in the plot title and file name.
        epoch: Epoch identifier; used in the plot title and file name.

    Returns:
        None.
    """
    fig = None
    try:
        if weights1.ndim < 2:
            # PCA needs a 2-D (samples, features) matrix.
            # NOTE(review): a single feature column presumably cannot yield two
            # components, so 1-D parameters likely end up in the "Skipping"
            # branch below -- confirm against scikit-learn's PCA behaviour.
            weights1 = weights1.reshape(-1, 1)
            weights2 = weights2.reshape(-1, 1)

        n_pretrained = len(weights1)
        stacked = np.concatenate([weights1, weights2], axis=0)
        projected = PCA(n_components=2).fit_transform(stacked)

        fig = plt.figure(figsize=(10, 5))
        plt.scatter(projected[:n_pretrained, 0], projected[:n_pretrained, 1], alpha=0.5, label='Pre-trained')
        plt.scatter(projected[n_pretrained:, 0], projected[n_pretrained:, 1], alpha=0.5, label='Fine-tuned')
        plt.title(f'PCA of Weights: {layer_name} - Epoch {epoch}')
        plt.xlabel('PCA Component 1')
        plt.ylabel('PCA Component 2')
        # The scatter labels are only shown once a legend is requested.
        plt.legend()
        plot_file = os.path.join(PCA_FOLDER, f'{layer_name.replace(".", "_")}_epoch_{epoch}_pca.png')
        plt.savefig(plot_file)
    except Exception as e:
        print(f"Skipping PCA for {layer_name} due to error: {e}")
    finally:
        # Close the figure even when plotting or saving failed, so repeated
        # calls do not accumulate open figures.
        if fig is not None:
            plt.close(fig)

# Load models
def load_models(saved_model_path, device=None):
    """Load the pre-trained base model and a fine-tuned copy of it.

    Args:
        saved_model_path: Path to a file containing the fine-tuned state dict
            (as accepted by ``torch.load``).
        device: Device to place both models on. Defaults to ``'cuda'`` when a
            CUDA device is available and ``'cpu'`` otherwise, so the function
            also works on machines without a GPU.

    Returns:
        A ``(model_pretrained, model_fine_tuned)`` tuple; both are
        ``MODEL_NAME`` instances, the second with the saved weights loaded.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model_pretrained = GPTNeoXForCausalLM.from_pretrained(MODEL_NAME).to(device)
    model_fine_tuned = GPTNeoXForCausalLM.from_pretrained(MODEL_NAME).to(device)

    # map_location lets a checkpoint saved on another device load here.
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources.
    state_dict = torch.load(saved_model_path, map_location=device)
    model_fine_tuned.load_state_dict(state_dict)
    return model_pretrained, model_fine_tuned

# Function to analyze models
def _epoch_from_filename(file_name):
    """Return the integer found between the last '_' and the next '.' of a checkpoint file name."""
    return int(file_name.split('_')[-1].split('.')[0])


def _compute_mad_records(model_folder, fixed_lr):
    """Compute the per-layer MAD between the pre-trained model and every checkpoint in ``model_folder``."""
    model_files = [
        f for f in os.listdir(model_folder)
        if f.endswith(".pt") and (not fixed_lr or f.startswith("fixed_lr"))
    ]
    if not model_files:
        raise ValueError(f"No matching .pt checkpoints found in {model_folder}")
    # os.listdir order is arbitrary; sort so the CSV is reproducible.
    model_files.sort(key=lambda f: (_epoch_from_filename(f), f))

    # Everything is compared as NumPy arrays, so the models stay on the CPU.
    model_pretrained = GPTNeoXForCausalLM.from_pretrained(MODEL_NAME)
    layer_names = [name for name, _ in model_pretrained.named_parameters() if "weight" in name]
    pretrained_state = model_pretrained.state_dict()
    # The pre-trained weights never change: convert them once, not per checkpoint.
    pretrained_weights = {
        name: pretrained_state[name].detach().cpu().numpy().flatten()
        for name in layer_names
    }

    # One model instance is reused; load_state_dict overwrites its weights
    # for every checkpoint instead of re-instantiating the model each time.
    model_fine_tuned = GPTNeoXForCausalLM.from_pretrained(MODEL_NAME)
    records = []
    for file_name in model_files:
        epoch = _epoch_from_filename(file_name)
        model_path = os.path.join(model_folder, file_name)
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        model_fine_tuned.load_state_dict(torch.load(model_path, map_location='cpu'))
        fine_tuned_state = model_fine_tuned.state_dict()

        for layer_name in layer_names:
            weights_fine_tuned = fine_tuned_state[layer_name].detach().cpu().numpy().flatten()
            mad = calculate_mad(pretrained_weights[layer_name], weights_fine_tuned)
            records.append({'Layer': layer_name, 'MAD': mad, 'Model': file_name, 'Epoch': epoch})
    return records


def analyze_models(model_folder, fixed_lr=False):
    """Compute (or reload) per-layer MAD values and save a scatter plot of them.

    If ``OUTPUT_FOLDER/mad_values.csv`` exists it is used as a cache and no
    model is loaded. Otherwise every ``.pt`` checkpoint in ``model_folder`` is
    compared against the pre-trained ``MODEL_NAME`` weights and the result is
    written to that CSV. In both cases ``OUTPUT_FOLDER/mad_plot.png`` is
    (re)generated, with one scatter series per epoch coloured on the viridis
    scale and a legend for four representative epochs.

    Args:
        model_folder: Directory containing checkpoint files named
            ``<prefix>_<epoch>.pt``.
        fixed_lr: When True, only files starting with ``"fixed_lr"`` are used.

    Raises:
        ValueError: If no checkpoint matches, or the MAD table is empty.
    """
    csv_mad_path = os.path.join(OUTPUT_FOLDER, 'mad_values.csv')

    if os.path.exists(csv_mad_path):
        df_mad = pd.read_csv(csv_mad_path)
        print(f'Loaded existing MAD values from {csv_mad_path}')
    else:
        df_mad = pd.DataFrame(_compute_mad_records(model_folder, fixed_lr))
        os.makedirs(OUTPUT_FOLDER, exist_ok=True)
        df_mad.to_csv(csv_mad_path, index=False)
        print(f'MAD values saved to {csv_mad_path}')

    if df_mad.empty:
        raise ValueError(f"No MAD values available in {csv_mad_path}")

    unique_epochs = sorted(df_mad['Epoch'].unique())
    n_epochs = len(unique_epochs)
    # Legend entries: first, one-third, two-thirds and last epoch.
    selected_epochs = [
        unique_epochs[0],
        unique_epochs[n_epochs // 3],
        unique_epochs[2 * n_epochs // 3],
        unique_epochs[-1],
    ]
    norm = plt.Normalize(unique_epochs[0], unique_epochs[-1])
    # pyplot.get_cmap replaces the deprecated matplotlib.cm.get_cmap.
    cmap = plt.get_cmap('viridis')
    # Derive the layer count from the data so the cached path needs no model.
    n_layers = df_mad['Layer'].nunique()

    # Plot and save the MAD scatter plot
    fig, ax = plt.subplots(figsize=(18, 6))
    try:
        handles = []
        labels = []
        for epoch in unique_epochs:
            epoch_df = df_mad[df_mad['Epoch'] == epoch]
            scatter = ax.scatter(range(len(epoch_df)), epoch_df['MAD'], color=cmap(norm(epoch)), alpha=0.6)
            if epoch in selected_epochs:
                handles.append(scatter)
                labels.append(f'Epoch {epoch}')
        ax.set_xlabel('Layer Index')
        ax.set_ylabel('MAD')
        ax.set_title('MAD for Each Layer')
        ax.legend(handles, labels)
        ax.set_xticks(range(0, n_layers, 20))  # show every 20th layer index
        fig.tight_layout()
        plot_mad_path = os.path.join(OUTPUT_FOLDER, 'mad_plot.png')
        fig.savefig(plot_mad_path)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
    print(f'MAD plot saved to {plot_mad_path}')

def main():
    """Entry point: run the MAD analysis on the configured checkpoint folder."""
    analyze_models(model_folder=MODEL_FOLDER, fixed_lr=FIXED_LR)

# Run the analysis when this code is executed directly rather than imported.
if __name__ == "__main__":
    main()


  cmap = get_cmap('viridis')


MAD values saved to fixed_5e6/mad_analysis_output/mad_values.csv
MAD plot saved to fixed_5e6/mad_analysis_output/mad_plot.png
