In [1]:
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from datasets import load_dataset
from metrics import AlignmentMetrics
import itertools
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np

# -------------------- Configuration --------------------

# Pythia revisions to probe: the untrained model (step 0), an early snapshot
# (step 1000), every 7000th step from 8000 onwards, and the final step 143000.
# Building the names from sorted integers keeps them in training order.
all_checkpoints = [
    f"step{step}"
    for step in sorted({0, 1000, *range(8000, 143000, 7000), 143000})
]
print(all_checkpoints)

# Every model size is examined at the same revisions (one shared list object).
model_checkpoints = {
    f"pythia-{size}": all_checkpoints
    for size in ("1.4b", "1b", "410m", "160m", "70m")
}

# Directory used as the on-disk cache for extracted features
FEATURES_DIR = "model_features"
os.makedirs(FEATURES_DIR, exist_ok=True)

# Load a sentence classification dataset (SST-2 from GLUE). The split slice
# caps the run at 1000 validation sentences for faster processing.
# NOTE(review): 1000 is an upper bound, not a guarantee -- if the validation
# split is smaller than that, fewer sentences are returned. Check
# len(sentences) before quoting a sample count.
dataset = load_dataset("glue", "sst2", split="validation[:1000]")

# Materialise the column as a plain list of str. The sentences are later
# sliced into batches and passed directly to the tokenizer, which expects a
# list of strings; depending on the installed `datasets` version, column
# access may return a lazy column object rather than a list.
sentences = list(dataset["sentence"])

['step0', 'step1000', 'step8000', 'step15000', 'step22000', 'step29000', 'step36000', 'step43000', 'step50000', 'step57000', 'step64000', 'step71000', 'step78000', 'step85000', 'step92000', 'step99000', 'step106000', 'step113000', 'step120000', 'step127000', 'step134000', 'step141000', 'step143000']


In [None]:
# -------------------- Feature Extraction Functions --------------------

def _pool_hidden_states(hidden, attention_mask, pooling):
    """Reduce per-token states of shape (batch, seq, dim) to one vector per sentence."""
    if pooling == "first":
        return hidden[:, 0, :]

    if attention_mask is None:
        # Without a mask every position is treated as a real token.
        attention_mask = torch.ones(hidden.shape[:2], dtype=torch.long, device=hidden.device)

    if pooling == "last":
        # Index of the last position whose mask is 1; scanning the flipped mask
        # makes this work for both left- and right-padded batches.
        seq_len = attention_mask.shape[1]
        last_real = seq_len - 1 - attention_mask.flip(dims=[1]).argmax(dim=1)
        rows = torch.arange(hidden.shape[0], device=hidden.device)
        return hidden[rows, last_real]

    # "mean": average over non-padding positions only.
    weights = attention_mask.unsqueeze(-1).to(hidden.dtype)
    return (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1)


def extract_features(model, tokenizer, sentences, batch_size=64, pooling="first"):
    """Extract one feature vector per sentence from the model's final hidden states.

    The model is put in eval mode and moved to CUDA when available (CPU
    otherwise). Sentences are tokenized in batches (padded, truncated to 128
    tokens), run through the model with ``output_hidden_states=True``, and the
    last entry of ``hidden_states`` is pooled to a single vector per sentence.

    Args:
        model: Model callable as ``model(**inputs, output_hidden_states=True)``
            whose output exposes ``hidden_states``.
        tokenizer: Callable returning a mapping of tensors for a list of strings.
        sentences: Non-empty, sliceable sequence of input strings.
        batch_size: Number of sentences per forward pass.
        pooling: How token states become a sentence vector:
            ``"first"`` (default, original behaviour) takes position 0,
            ``"last"`` takes the last non-padding position,
            ``"mean"`` averages over non-padding positions.

    Returns:
        A CPU tensor with one row per sentence, concatenated in input order.

    Raises:
        ValueError: If ``pooling`` is unknown or ``sentences`` is empty.

    NOTE(review): this notebook loads its models with AutoModelForCausalLM.
    Under causal attention position 0 cannot attend to later tokens, so the
    default "first" features presumably depend only on each sentence's first
    token. Consider "last" or "mean" for sentence-level comparisons -- confirm
    with the analysis owner. Switching also requires clearing cached features.
    """
    if pooling not in ("first", "last", "mean"):
        raise ValueError(f"Unknown pooling {pooling!r}; expected 'first', 'last' or 'mean'.")
    if len(sentences) == 0:
        raise ValueError("sentences must contain at least one item.")

    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    features = []
    # Mixed precision is only enabled on CUDA. The context is keyed on the
    # actual device so the CPU fallback does not request CUDA autocast.
    with torch.autocast(device_type=device.type, enabled=device.type == "cuda"), torch.no_grad():
        for i in tqdm(range(0, len(sentences), batch_size), desc="Extracting features", leave=False):
            batch_sentences = sentences[i:i + batch_size]
            inputs = tokenizer(batch_sentences, return_tensors="pt", padding=True, truncation=True, max_length=128)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # Ask the model to return the hidden states of every layer
            outputs = model(**inputs, output_hidden_states=True)
            # The last element of hidden_states is the final layer's output
            last_hidden_state = outputs.hidden_states[-1]
            pooled = _pool_hidden_states(last_hidden_state, inputs.get("attention_mask"), pooling)
            features.append(pooled.cpu())  # Move to CPU to save GPU memory

    return torch.cat(features, dim=0)

# -------------------- Load Models and Extract Features --------------------

def extract_features_for_checkpoint(model_name, checkpoint, sentences, batch_size=64):
    """Return cached features for one model revision, extracting them if needed.

    Features are cached under FEATURES_DIR as ``<model_name>_<checkpoint>.pt``
    (any ``/`` in the model name becomes ``_``). On a cache hit the file is
    loaded and no model is instantiated. Otherwise the model and tokenizer are
    loaded from ``EleutherAI/<model_name>`` at revision ``checkpoint``,
    features are extracted with ``extract_features`` and written to the cache.

    Args:
        model_name: Model name without the ``EleutherAI/`` prefix.
        checkpoint: Revision name passed to ``from_pretrained``.
        sentences: Input sentences forwarded to ``extract_features``.
        batch_size: Batch size forwarded to ``extract_features``.

    Returns:
        Tuple ``(checkpoint, features)``.
    """
    feature_save_path = os.path.join(FEATURES_DIR, f"{model_name.replace('/', '_')}_{checkpoint}.pt")

    # If features are already saved, load them
    if os.path.exists(feature_save_path):
        print(f"Loading saved features for {model_name} at {checkpoint}...")
        # NOTE: torch.load unpickles its input; only point FEATURES_DIR at
        # files this notebook wrote itself.
        return checkpoint, torch.load(feature_save_path, map_location="cpu")

    model = tokenizer = None
    try:
        # Load model and corresponding tokenizer with the specific checkpoint revision
        model = AutoModelForCausalLM.from_pretrained(f"EleutherAI/{model_name}", revision=checkpoint)
        tokenizer = AutoTokenizer.from_pretrained(f"EleutherAI/{model_name}", revision=checkpoint)

        # Batched tokenization pads, so make sure a padding token exists
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        features = extract_features(model, tokenizer, sentences, batch_size)
    finally:
        # Release the model even when loading or extraction fails, so one bad
        # checkpoint does not keep GPU memory pinned for the remaining ones.
        del model, tokenizer
        torch.cuda.empty_cache()

    # Write to a temporary name and rename into place: an interrupted save must
    # not leave a truncated file that the cache check above would accept.
    partial_path = feature_save_path + ".tmp"
    try:
        torch.save(features, partial_path)
        os.replace(partial_path, feature_save_path)
    finally:
        # After a successful rename the temporary file no longer exists.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print(f"Saved features for {model_name} at {checkpoint} to {feature_save_path}")

    return checkpoint, features

# -------------------- Process Models Across Checkpoints --------------------

# model name -> {checkpoint name -> feature tensor}
model_features = {}
batch_size = 128  # Adjust based on available GPU memory
max_workers = 4  # Number of checkpoints of one model processed concurrently

# Walk the models one at a time; within a model, fan its checkpoints out to a
# thread pool and collect each result as soon as it is ready.
for model_name, checkpoints in tqdm(model_checkpoints.items(), desc="Loading checkpoints and extracting features"):
    model_features[model_name] = {}

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Remember which (model, checkpoint) every future belongs to so that a
        # failure can be reported with its origin.
        pending = {}
        for requested in checkpoints:
            job = pool.submit(extract_features_for_checkpoint, model_name, requested, sentences, batch_size)
            pending[job] = (model_name, requested)

        finished_jobs = tqdm(as_completed(pending), total=len(pending), desc=f"Feature extraction for {model_name}")
        for job in finished_jobs:
            owner, requested = pending[job]
            try:
                finished, extracted = job.result()
            except Exception as err:
                # Best effort: report the failure and continue with the rest.
                print(f"[ERROR] Processing {owner} at {requested}: {err}")
            else:
                model_features[owner][finished] = extracted

Loading checkpoints and extracting features:   0%|                                                                    | 0/5 [00:00<?, ?it/s]
Feature extraction for pythia-1.4b:   0%|                                                                            | 0/23 [00:00<?, ?it/s][A

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.93G [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.93G [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.93G [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.93G [00:00<?, ?B/s]

In [None]:
# -------------------- Alignment Metric Computation --------------------

# One record per (model pair, checkpoint): the mutual k-NN alignment score.
alignment_records = []

# Only checkpoints for which every model produced features can be compared,
# so intersect the configured checkpoints with each model's available ones.
common_checkpoints = set(all_checkpoints)
for model_name in model_features.keys():
    common_checkpoints.intersection_update(model_features[model_name].keys())

# Order by numeric training step. A plain string sort would put "step106000"
# before "step15000", leaving the records out of training order.
ordered_checkpoints = sorted(common_checkpoints, key=lambda name: int(name.replace("step", "")))

# Process alignment across all model pairs at each common checkpoint
for checkpoint in tqdm(ordered_checkpoints, desc="Calculating alignment scores"):
    for model_a, model_b in itertools.combinations(model_features.keys(), 2):
        try:
            features_a = model_features[model_a][checkpoint]
            features_b = model_features[model_b][checkpoint]
            score = AlignmentMetrics.mutual_knn(features_a, features_b, topk=50)

            alignment_records.append({
                'Model A': model_a,
                'Model B': model_b,
                'Checkpoint': checkpoint,
                'Score': score
            })
        except Exception as e:
            # Best effort: report the failing pair and keep going.
            print(f"[ERROR] Calculating alignment score for {model_a} and {model_b} at {checkpoint}: {e}")

In [None]:
# Create DataFrame from results
score_df = pd.DataFrame(alignment_records)

# -------------------- Plotting the Results --------------------

if score_df.empty:
    # With no records the frame has no columns, so plotting would fail on
    # the missing 'Checkpoint' / 'Score' columns.
    print("No alignment scores were computed; nothing to plot.")
else:
    # Plot against the numeric training step rather than the "stepN" label:
    # string labels give a categorical axis that is neither ordered nor spaced
    # by training progress. score_df itself is left untouched.
    plot_df = score_df.assign(
        Step=score_df["Checkpoint"].str.replace("step", "", regex=False).astype(int)
    ).sort_values("Step")

    # Plot alignment scores as training progresses
    fig, ax = plt.subplots(figsize=(12, 8))
    sns.lineplot(data=plot_df, x='Step', y='Score', hue='Model A', style='Model B', markers=True, ax=ax)
    ax.set_title('Alignment Score (mutual_knn) vs. Training Checkpoint')
    ax.set_xlabel('Training Checkpoint')
    ax.set_ylabel('Alignment Score')
    ax.tick_params(axis='x', labelrotation=45)
    ax.legend(loc='best', bbox_to_anchor=(1, 1))
    fig.tight_layout()
    plt.show()