In [1]:
import json
import torch
import torch.nn as nn
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import os
import re
import gc
from collections import defaultdict
from torch.utils.data import DataLoader, TensorDataset

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"PyTorch device: {device}")

# Create output directories
os.makedirs("projection_plots", exist_ok=True)
os.makedirs("projection_models", exist_ok=True)

# Initialize model and tokenizer
MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'
    

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto"
)
model.eval()

# Get model config
num_layers = model.config.num_hidden_layers
hidden_size = model.config.hidden_size
print(f"Model has {num_layers} layers, hidden size {hidden_size}")

# Generation config
GEN_CFG = GenerationConfig(max_new_tokens=3, do_sample=False, max_length=None)

def load_dataset(filepath):
    """Load the counting dataset"""
    with open(filepath, 'r') as f:
        data = json.load(f)
    print(f"Loaded {len(data)} examples from dataset")
    return data

def extract_int(text: str):
    """Return first integer, or -1 if none found."""
    INT_RE = re.compile(r"\d+")
    m = INT_RE.search(text)
    return int(m.group()) if m else -1

class ProjectionLayer(nn.Module):
    """Projects from early layer representations to final layer representations"""
    def __init__(self, hidden_size, use_residual=True):
        super().__init__()
        self.use_residual = use_residual
        
        # Multi-layer projection with residual connection
        self.projection = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 2),
            nn.GELU(),
            nn.LayerNorm(hidden_size * 2),
            nn.Dropout(0.1),
            nn.Linear(hidden_size * 2, hidden_size),
        )
        
        # Learnable residual weight
        if use_residual:
            self.residual_weight = nn.Parameter(torch.tensor(0.5))
        
    def forward(self, hidden_states):
        projected = self.projection(hidden_states)
        if self.use_residual:
            # Weighted sum of original and projected
            return self.residual_weight * hidden_states + (1 - self.residual_weight) * projected
        return projected

def collect_hidden_states(model, tokenizer, dataset, source_layer, target_layer, num_samples=2000):
    """Collect hidden states from source and target layers"""
    device = model.model.embed_tokens.weight.device
    
    source_hiddens = []
    target_hiddens = []
    position_indices = []  # Track which token positions we're collecting
    
    print(f"Collecting hidden states from layers {source_layer} and {target_layer}...")
    
    for i, example in enumerate(tqdm(dataset[:num_samples], desc="Collecting")):
        if i >= num_samples:
            break
            
        prompt = example['prompt']
        inputs = tokenizer(prompt, return_tensors="pt", padding=False)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True, return_dict=True)
            
            # Get representations from source and target layers
            source_hidden = outputs.hidden_states[source_layer]
            target_hidden = outputs.hidden_states[target_layer]
            
            # Collect representations from multiple positions (not just the last token)
            # This gives us more training data
            seq_len = source_hidden.shape[1]
            
            # Use last 5 tokens (or all if sequence is shorter)
            num_positions = min(5, seq_len)
            for pos in range(seq_len - num_positions, seq_len):
                source_hiddens.append(source_hidden[:, pos, :].cpu())
                target_hiddens.append(target_hidden[:, pos, :].cpu())
                position_indices.append(pos - seq_len)  # Negative indexing from end
        
        # Clear memory periodically
        if i % 100 == 0:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    
    # Convert to tensors
    source_hiddens = torch.cat(source_hiddens, dim=0)
    target_hiddens = torch.cat(target_hiddens, dim=0)
    
    print(f"Collected {len(source_hiddens)} hidden state pairs")
    return source_hiddens, target_hiddens

def train_projection(source_hiddens, target_hiddens, hidden_size, device, num_epochs=20, batch_size=64):
    """Train a projection from source to target representations"""
    
    # Ensure consistent dtype (use float32 for training stability)
    source_hiddens = source_hiddens.float()
    target_hiddens = target_hiddens.float()
    
    projection = ProjectionLayer(hidden_size, use_residual=True).to(device).float()
    projection.train()
    
    # Move data to device
    source_hiddens = source_hiddens.to(device)
    target_hiddens = target_hiddens.to(device)
    
    # Create data loader
    dataset = TensorDataset(source_hiddens, target_hiddens)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    # Training setup
    optimizer = torch.optim.AdamW(projection.parameters(), lr=1e-3, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    criterion = nn.MSELoss()
    
    # Training metrics
    train_losses = []
    
    print(f"Training projection with {len(source_hiddens)} samples...")
    
    for epoch in range(num_epochs):
        epoch_loss = 0
        num_batches = 0
        
        for batch_source, batch_target in dataloader:
            optimizer.zero_grad()
            
            # Forward pass
            projected = projection(batch_source)
            loss = criterion(projected, batch_target)
            
            # Add cosine similarity loss to encourage directional alignment
            cos_sim = nn.functional.cosine_similarity(projected, batch_target, dim=1).mean()
            loss = loss - 0.1 * cos_sim  # Negative because we want to maximize similarity
            
            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(projection.parameters(), 1.0)
            optimizer.step()
            
            epoch_loss += loss.item()
            num_batches += 1
        
        scheduler.step()
        avg_loss = epoch_loss / num_batches
        train_losses.append(avg_loss)
        
        if (epoch + 1) % 5 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}, LR: {scheduler.get_last_lr()[0]:.6f}")
    
    projection.eval()
    
    # Plot training curve
    plt.figure(figsize=(8, 5))
    plt.plot(train_losses)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Projection Training Loss')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig('projection_plots/training_loss.png', dpi=150)
    plt.close()
    
    return projection

def evaluate_with_projection(model, tokenizer, dataset, stop_at_layer, projection, max_samples=None):
    """Evaluate model with projection from early layer"""
    
    device = model.model.embed_tokens.weight.device
    projection = projection.to(device)
    projection.eval()
    
    correct = 0
    total = 0
    predictions = []
    
    samples = dataset if max_samples is None else dataset[:max_samples]
    
    for example in tqdm(samples, desc=f"Layer {stop_at_layer} + projection"):
        prompt = example['prompt']
        true_answer = example['answer']
        
        inputs = tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        with torch.no_grad():
            generated_ids = inputs['input_ids'].clone()
            
            for _ in range(GEN_CFG.max_new_tokens):
                # Get hidden states at early layer
                outputs = model(generated_ids, output_hidden_states=True, return_dict=True)
                hidden_states = outputs.hidden_states[stop_at_layer]
                
                # Convert to float32 for projection
                hidden_states = hidden_states.float()
                
                # Project to final layer space
                hidden_states = projection(hidden_states)
                
                # Convert back to model's dtype (float16) and apply layer norm
                hidden_states = hidden_states.to(model.dtype)
                hidden_states = model.model.norm(hidden_states)
                logits = model.lm_head(hidden_states)
                
                # Get next token
                next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
                generated_ids = torch.cat([generated_ids, next_token], dim=1)
                
                if next_token.item() == tokenizer.eos_token_id or tokenizer.decode(next_token[0]) == ')':
                    break
        
        # Extract prediction
        generation = tokenizer.decode(generated_ids[0][inputs['input_ids'].size(1):], skip_special_tokens=True)
        predicted_answer = extract_int(generation)
        
        predictions.append({
            'prompt': prompt,
            'true_answer': true_answer,
            'predicted_answer': predicted_answer,
            'generated_text': generation.strip(),
            'correct': predicted_answer == true_answer
        })
        
        if predicted_answer == true_answer:
            correct += 1
        total += 1
        
        # Debug first few
        if total <= 5:
            print(f"\nExample {total}:")
            print(f"  Generated: '{generation.strip()}'")
            print(f"  Predicted: {predicted_answer}, True: {true_answer}")
            print(f"  Correct: {predicted_answer == true_answer}")
    
    accuracy = correct / total if total > 0 else 0
    return accuracy, predictions

def run_projection_experiment(model, tokenizer, dataset, start_layer=10, end_layer=20, 
                            train_samples=2000, eval_samples=None):
    """Run the complete projection experiment"""
    
    results = {}
    
    # Baseline accuracy (given)
    BASELINE_ACCURACY = 0.73520
    results['baseline'] = {
        'layer': num_layers,
        'accuracy': BASELINE_ACCURACY,
        'predictions': []
    }
    
    # Target layer is always the final layer
    target_layer = num_layers
    
    # Collect hidden states for all layers we'll evaluate
    print("\nCollecting hidden states for projection training...")
    all_hidden_states = {}
    
    # Collect target layer hidden states once
    _, target_hiddens = collect_hidden_states(
        model, tokenizer, dataset, target_layer, target_layer, num_samples=train_samples
    )
    
    # Train projections for each source layer
    projections = {}
    
    for source_layer in range(start_layer, end_layer + 1):
        print(f"\n=== Training projection for layer {source_layer} ===")
        
        # Collect source layer hidden states
        source_hiddens, _ = collect_hidden_states(
            model, tokenizer, dataset, source_layer, target_layer, num_samples=train_samples
        )
        
        # Train projection
        device = model.model.embed_tokens.weight.device
        projection = train_projection(
            source_hiddens, target_hiddens, 
            hidden_size=model.config.hidden_size,
            device=device,
            num_epochs=20,
            batch_size=64
        )
        
        projections[source_layer] = projection
        
        # Save projection
        torch.save(projection.state_dict(), f'projection_models/projection_layer_{source_layer}.pt')
        
        # Clear memory
        del source_hiddens
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    
    # Evaluate with projections
    print("\n" + "="*50)
    print("EVALUATION PHASE")
    print("="*50)
    
    for layer in range(start_layer, end_layer + 1):
        print(f"\n=== Evaluating Layer {layer} with Projection ===")
        
        accuracy, predictions = evaluate_with_projection(
            model, tokenizer, dataset, 
            stop_at_layer=layer,
            projection=projections[layer],
            max_samples=eval_samples
        )
        
        results[f'layer_{layer}'] = {
            'layer': layer,
            'accuracy': accuracy,
            'predictions': predictions
        }
        
        print(f"Accuracy for Layer {layer} + Projection: {accuracy:.3%}")
        
        # Clear memory
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    
    return results, projections

def visualize_results(results):
    """Create visualizations for the projection experiment"""
    
    # 1. Overall accuracy by layer
    plt.figure(figsize=(10, 6))
    
    layers = []
    accuracies = []
    
    for key in sorted(results.keys()):
        if key != 'baseline':
            layer = results[key]['layer']
            accuracy = results[key]['accuracy']
            layers.append(layer)
            accuracies.append(accuracy)
    
    # Baseline
    baseline_acc = results['baseline']['accuracy']
    
    plt.plot(layers, accuracies, 'o-', linewidth=2, markersize=8, 
             label='Early Stop + Projection', color='blue')
    plt.axhline(y=baseline_acc, color='red', linestyle='--', 
                label=f'Full Model ({num_layers} layers): {baseline_acc:.3%}')
    
    # Add 50% line for reference
    plt.axhline(y=0.5, color='gray', linestyle=':', alpha=0.5, label='Random (50%)')
    
    plt.xlabel('Stop at Layer', fontsize=12)
    plt.ylabel('Accuracy', fontsize=12)
    plt.title('Counting Accuracy: Early Layers + Learned Projection', fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.ylim(0, 1.05)
    
    # Annotate best early layer
    if accuracies:
        max_acc = max(accuracies)
        max_layer = layers[accuracies.index(max_acc)]
        plt.annotate(f'Best: {max_acc:.3%} at layer {max_layer}',
                    xy=(max_layer, max_acc),
                    xytext=(max_layer + 1, max_acc - 0.1),
                    arrowprops=dict(arrowstyle='->', color='blue'))
    
    plt.tight_layout()
    plt.savefig('projection_plots/accuracy_by_layer_with_projection.png', dpi=150)
    plt.show()
    
    # 2. Improvement over baseline early stopping
    if os.path.exists('early_stop_results.json'):
        # Load baseline early stop results if available
        with open('early_stop_results.json', 'r') as f:
            baseline_results = json.load(f)
        
        plt.figure(figsize=(10, 6))
        
        baseline_accs = []
        projection_accs = []
        comparison_layers = []
        
        for layer in layers:
            if f'layer_{layer}' in baseline_results.get('layer_results', {}):
                baseline_acc = baseline_results['layer_results'][f'layer_{layer}']['accuracy']
                projection_acc = results[f'layer_{layer}']['accuracy']
                
                baseline_accs.append(baseline_acc)
                projection_accs.append(projection_acc)
                comparison_layers.append(layer)
        
        if comparison_layers:
            x = np.arange(len(comparison_layers))
            width = 0.35
            
            plt.bar(x - width/2, baseline_accs, width, label='Direct Early Stop', color='orange', alpha=0.7)
            plt.bar(x + width/2, projection_accs, width, label='With Projection', color='blue', alpha=0.7)
            
            plt.xlabel('Layer')
            plt.ylabel('Accuracy')
            plt.title('Comparison: Direct Early Stop vs. With Projection')
            plt.xticks(x, comparison_layers)
            plt.legend()
            plt.grid(True, alpha=0.3, axis='y')
            
            plt.tight_layout()
            plt.savefig('projection_plots/comparison_with_baseline.png', dpi=150)
            plt.show()

def save_results(results, output_file='projection_results.json'):
    """Save results to JSON"""
    output = {
        'summary': {
            'baseline_accuracy': results['baseline']['accuracy'],
            'best_projection_layer': None,
            'best_projection_accuracy': 0
        },
        'layer_results': {}
    }
    
    # Find best layer
    for key, result in results.items():
        if key != 'baseline':
            if result['accuracy'] > output['summary']['best_projection_accuracy']:
                output['summary']['best_projection_accuracy'] = result['accuracy']
                output['summary']['best_projection_layer'] = result['layer']
    
    # Add layer results
    for key, result in results.items():
        output['layer_results'][key] = {
            'layer': result['layer'],
            'accuracy': result['accuracy'],
            'num_correct': sum(1 for p in result['predictions'] if p['correct']),
            'num_total': len(result['predictions'])
        }
    
    # Save example predictions
    output['example_predictions'] = {}
    for key, result in results.items():
        output['example_predictions'][key] = result['predictions'][:10]
    
    with open(output_file, 'w') as f:
        json.dump(output, f, indent=2)
    
    print(f"\nResults saved to {output_file}")

def main():
    # Load dataset
    dataset_path = "/net/scratch/slhleosun/counting-items-mechanisms/dataset.json"
    dataset = load_dataset(dataset_path)
    
    # Run experiment
    print("\n" + "="*50)
    print("PROJECTION TRAINING EXPERIMENT")
    print("="*50)
    
    results, projections = run_projection_experiment(
        model, tokenizer, dataset,
        start_layer=10,
        end_layer=20,
        train_samples=2000,  # Number of examples for training projections
        eval_samples=len(dataset)  # Evaluate on full dataset
    )
    
    # Visualize results
    print("\nCreating visualizations...")
    visualize_results(results)
    
    # Save results
    save_results(results)
    
    # Print summary
    print("\n" + "="*50)
    print("SUMMARY")
    print("="*50)
    print(f"Baseline accuracy (full model): {results['baseline']['accuracy']:.3%}")
    
    best_layer = None
    best_accuracy = 0
    for key, result in results.items():
        if key != 'baseline' and result['accuracy'] > best_accuracy:
            best_accuracy = result['accuracy']
            best_layer = result['layer']
    
    if best_layer:
        print(f"Best projection layer: {best_layer} with accuracy {best_accuracy:.3%}")
        print(f"Accuracy drop from baseline: {(results['baseline']['accuracy'] - best_accuracy):.3%}")
    
    print("\nAccuracy by layer:")
    for key in sorted(results.keys()):
        if key != 'baseline':
            layer = results[key]['layer']
            acc = results[key]['accuracy']
            print(f"  Layer {layer}: {acc:.3%}")
    
    print("\nExperiment complete!")
    print("Outputs:")
    print("  - projection_plots/: Visualizations")
    print("  - projection_models/: Trained projections") 
    print("  - projection_results.json: Detailed results")

if __name__ == "__main__":
    main()

PyTorch device: cuda
Loading model and tokenizer...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Model has 32 layers, hidden size 4096
Loaded 5000 examples from dataset

PROJECTION TRAINING EXPERIMENT

Collecting hidden states for projection training...
Collecting hidden states from layers 32 and 32...


Collecting: 100%|███████████████████████████████████████████| 2000/2000 [01:37<00:00, 20.53it/s]


Collected 10000 hidden state pairs

=== Training projection for layer 10 ===
Collecting hidden states from layers 10 and 32...


Collecting: 100%|███████████████████████████████████████████| 2000/2000 [01:21<00:00, 24.50it/s]


Collected 10000 hidden state pairs
Training projection with 10000 samples...
Epoch 5/20, Loss: 0.1046, LR: 0.000854
Epoch 10/20, Loss: 0.0680, LR: 0.000500
Epoch 15/20, Loss: 0.0486, LR: 0.000146
Epoch 20/20, Loss: 0.0388, LR: 0.000000

=== Training projection for layer 11 ===
Collecting hidden states from layers 11 and 32...


Collecting: 100%|███████████████████████████████████████████| 2000/2000 [01:23<00:00, 24.08it/s]


Collected 10000 hidden state pairs
Training projection with 10000 samples...
Epoch 5/20, Loss: 0.1067, LR: 0.000854
Epoch 10/20, Loss: 0.0672, LR: 0.000500
Epoch 15/20, Loss: 0.0476, LR: 0.000146
Epoch 20/20, Loss: 0.0375, LR: 0.000000

=== Training projection for layer 12 ===
Collecting hidden states from layers 12 and 32...


Collecting: 100%|███████████████████████████████████████████| 2000/2000 [01:23<00:00, 24.06it/s]


Collected 10000 hidden state pairs
Training projection with 10000 samples...
Epoch 5/20, Loss: 0.0967, LR: 0.000854
Epoch 10/20, Loss: 0.0637, LR: 0.000500
Epoch 15/20, Loss: 0.0432, LR: 0.000146
Epoch 20/20, Loss: 0.0334, LR: 0.000000

=== Training projection for layer 13 ===
Collecting hidden states from layers 13 and 32...


Collecting: 100%|███████████████████████████████████████████| 2000/2000 [01:23<00:00, 24.00it/s]


Collected 10000 hidden state pairs
Training projection with 10000 samples...
Epoch 5/20, Loss: 0.0562, LR: 0.000854
Epoch 10/20, Loss: 0.0275, LR: 0.000500
Epoch 15/20, Loss: 0.0102, LR: 0.000146
Epoch 20/20, Loss: 0.0019, LR: 0.000000

=== Training projection for layer 14 ===
Collecting hidden states from layers 14 and 32...


Collecting: 100%|███████████████████████████████████████████| 2000/2000 [01:23<00:00, 24.09it/s]


Collected 10000 hidden state pairs
Training projection with 10000 samples...
Epoch 5/20, Loss: 0.0361, LR: 0.000854
Epoch 10/20, Loss: 0.0098, LR: 0.000500
Epoch 15/20, Loss: -0.0071, LR: 0.000146
Epoch 20/20, Loss: -0.0153, LR: 0.000000

=== Training projection for layer 15 ===
Collecting hidden states from layers 15 and 32...


Collecting: 100%|███████████████████████████████████████████| 2000/2000 [01:23<00:00, 23.84it/s]


Collected 10000 hidden state pairs
Training projection with 10000 samples...
Epoch 5/20, Loss: 0.0223, LR: 0.000854
Epoch 10/20, Loss: -0.0047, LR: 0.000500
Epoch 15/20, Loss: -0.0238, LR: 0.000146
Epoch 20/20, Loss: -0.0328, LR: 0.000000

=== Training projection for layer 16 ===
Collecting hidden states from layers 16 and 32...


Collecting: 100%|███████████████████████████████████████████| 2000/2000 [01:23<00:00, 23.95it/s]


Collected 10000 hidden state pairs
Training projection with 10000 samples...
Epoch 5/20, Loss: 0.0219, LR: 0.000854
Epoch 10/20, Loss: -0.0041, LR: 0.000500
Epoch 15/20, Loss: -0.0222, LR: 0.000146
Epoch 20/20, Loss: -0.0310, LR: 0.000000

=== Training projection for layer 17 ===
Collecting hidden states from layers 17 and 32...


Collecting: 100%|███████████████████████████████████████████| 2000/2000 [01:23<00:00, 23.97it/s]


Collected 10000 hidden state pairs
Training projection with 10000 samples...
Epoch 5/20, Loss: 0.0226, LR: 0.000854
Epoch 10/20, Loss: -0.0029, LR: 0.000500
Epoch 15/20, Loss: -0.0199, LR: 0.000146
Epoch 20/20, Loss: -0.0283, LR: 0.000000

=== Training projection for layer 18 ===
Collecting hidden states from layers 18 and 32...


Collecting: 100%|███████████████████████████████████████████| 2000/2000 [01:23<00:00, 23.83it/s]


Collected 10000 hidden state pairs
Training projection with 10000 samples...
Epoch 5/20, Loss: 0.0124, LR: 0.000854
Epoch 10/20, Loss: -0.0211, LR: 0.000500
Epoch 15/20, Loss: -0.0381, LR: 0.000146
Epoch 20/20, Loss: -0.0458, LR: 0.000000

=== Training projection for layer 19 ===
Collecting hidden states from layers 19 and 32...


Collecting: 100%|███████████████████████████████████████████| 2000/2000 [01:22<00:00, 24.10it/s]


Collected 10000 hidden state pairs
Training projection with 10000 samples...
Epoch 5/20, Loss: 0.0181, LR: 0.000854
Epoch 10/20, Loss: -0.0150, LR: 0.000500
Epoch 15/20, Loss: -0.0343, LR: 0.000146
Epoch 20/20, Loss: -0.0425, LR: 0.000000

=== Training projection for layer 20 ===
Collecting hidden states from layers 20 and 32...


Collecting: 100%|███████████████████████████████████████████| 2000/2000 [01:22<00:00, 24.13it/s]


Collected 10000 hidden state pairs
Training projection with 10000 samples...
Epoch 5/20, Loss: 0.0150, LR: 0.000854
Epoch 10/20, Loss: -0.0171, LR: 0.000500
Epoch 15/20, Loss: -0.0359, LR: 0.000146
Epoch 20/20, Loss: -0.0437, LR: 0.000000

EVALUATION PHASE

=== Evaluating Layer 10 with Projection ===


Layer 10 + projection:   0%|                                           | 0/5000 [00:01<?, ?it/s]


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:3! (when checking argument for argument tensors in method wrapper_CUDA_cat)

In [1]:
import json
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import os
import re
import gc
from collections import defaultdict

# Set device - but note that model might be distributed
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"PyTorch device: {device}")

# Create output directory
os.makedirs("early_stop_plots", exist_ok=True)

# Initialize model and tokenizer
MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto"
)
model.eval()

# Get model config
num_layers = model.config.num_hidden_layers
hidden_size = model.config.hidden_size
vocab_size = model.config.vocab_size
print(f"Model has {num_layers} layers, hidden size {hidden_size}, vocab size {vocab_size}")

# Generation config
GEN_CFG = GenerationConfig(max_new_tokens=3, do_sample=False, max_length=None)

def load_dataset(filepath):
    """Load the counting dataset"""
    with open(filepath, 'r') as f:
        data = json.load(f)
    print(f"Loaded {len(data)} examples from dataset")
    return data

def extract_int(text: str):
    """Return first integer inside parentheses, or None."""
    INT_RE = re.compile(r"\d+")
    m = INT_RE.search(text)
    return int(m.group()) if m else -1

class EarlyStopLlamaModel(torch.nn.Module):
    """Wrapper to allow early stopping at specific layers"""
    def __init__(self, base_model, stop_at_layer=None):
        super().__init__()
        self.base_model = base_model
        self.stop_at_layer = stop_at_layer if stop_at_layer is not None else base_model.config.num_hidden_layers
        
    def forward(self, input_ids, attention_mask=None, **kwargs):
        # Use the base model's forward method but intercept at the right layer
        # We'll use a hook to stop execution at the desired layer
        
        outputs = []
        
        def hook_fn(module, input, output):
            # Capture the output and prevent further processing
            outputs.append(output[0])  # hidden states
            return output
        
        # Register hook on the target layer
        if self.stop_at_layer < self.base_model.config.num_hidden_layers:
            hook = self.base_model.model.layers[self.stop_at_layer - 1].register_forward_hook(hook_fn)
        
        try:
            # Run normal forward pass - it will process up to our hook
            with torch.no_grad():
                # Get model outputs
                model_outputs = self.base_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    return_dict=True
                )
                
                # Get hidden states at the stop layer
                if outputs:
                    # Use the captured hidden states from our hook
                    hidden_states = outputs[0]
                else:
                    # If stop_at_layer == num_layers, use the final hidden states
                    hidden_states = model_outputs.hidden_states[self.stop_at_layer]
                
                # Apply final layer norm
                hidden_states = self.base_model.model.norm(hidden_states)
                
                # Get logits
                logits = self.base_model.lm_head(hidden_states)
                
                return logits
                
        finally:
            # Remove hook
            if self.stop_at_layer < self.base_model.config.num_hidden_layers:
                hook.remove()

def evaluate_early_stop(base_model, tokenizer, dataset, stop_at_layer, max_samples=None):
    """
    Evaluate model accuracy when stopping at a specific layer
    """
    correct = 0
    total = 0
    predictions = []
    
    samples = dataset if max_samples is None else dataset[:max_samples]
    
    for example in tqdm(samples, desc=f"Layer {stop_at_layer}"):
        prompt = example['prompt']
        true_answer = example['answer']
        
        # Tokenize
        inputs = tokenizer(prompt, return_tensors="pt")
        # Move to the same device as the model's embeddings
        device = base_model.model.embed_tokens.weight.device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        # Generate with early stop
        with torch.no_grad():
            # Simple greedy generation
            generated_ids = inputs['input_ids'].clone()
            
            for _ in range(GEN_CFG.max_new_tokens):
                # Get model outputs with hidden states
                outputs = base_model(
                    generated_ids,
                    output_hidden_states=True,
                    return_dict=True
                )
                
                # Get hidden states at the stop layer (layer outputs are 0-indexed, but we count from 1)
                # hidden_states[0] is embeddings, hidden_states[1] is layer 0 output, etc.
                hidden_states = outputs.hidden_states[stop_at_layer]
                
                # Apply final layer norm
                hidden_states = base_model.model.norm(hidden_states)
                
                # Get logits
                logits = base_model.lm_head(hidden_states)
                
                # Get next token
                next_token_logits = logits[:, -1, :]
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
                
                # Ensure next_token is on the same device as generated_ids
                next_token = next_token.to(generated_ids.device)
                
                # Append to sequence
                generated_ids = torch.cat([generated_ids, next_token], dim=1)
                
                # Check if we hit EOS or closing parenthesis
                if next_token.item() == tokenizer.eos_token_id or tokenizer.decode(next_token[0]) == ')':
                    break
        
        # Decode and extract prediction
        generation = tokenizer.decode(generated_ids[0][inputs['input_ids'].size(1):], skip_special_tokens=True)
        predicted_answer = extract_int(generation)
        
        predictions.append({
            'prompt': prompt,
            'true_answer': true_answer,
            'predicted_answer': predicted_answer,
            'generated_text': generation.strip(),
            'correct': predicted_answer == true_answer
        })
        
        if predicted_answer == true_answer:
            correct += 1
        total += 1
        
        # Debug first few examples
        if total <= 3:
            print(f"\nDebug Example {total}:")
            print(f"  Prompt ending: ...{prompt[-50:]}")
            print(f"  Generated: '{generation.strip()}'")
            print(f"  Extracted number: {predicted_answer}")
            print(f"  True answer: {true_answer}")
            print(f"  Correct: {predicted_answer == true_answer}")
    
    accuracy = correct / total if total > 0 else 0
    return accuracy, predictions



def run_early_stop_experiment(model, tokenizer, dataset, start_layer=10, end_layer=20):
    """
    Run early stop experiment across multiple layers
    """
    results = {}
    
    # Use provided baseline accuracy
    BASELINE_ACCURACY = 0.73520
    results['baseline'] = {
        'layer': num_layers,
        'accuracy': BASELINE_ACCURACY,
        'predictions': []  # Empty since we're not running it
    }
    
    # Run for specified layer range
    for layer in range(start_layer, end_layer + 1):
        print(f"\n=== Evaluating Early Stop at Layer {layer} ===")
        accuracy, predictions = evaluate_early_stop(
            model, tokenizer, dataset, layer, max_samples=len(dataset)
        )
        results[f'layer_{layer}'] = {
            'layer': layer,
            'accuracy': accuracy,
            'predictions': predictions
        }
        print(f"Accuracy for Early Stop at Layer {layer}: {accuracy:.3%}")
        
        # Clear memory periodically
        if layer % 5 == 0:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    
    return results

def analyze_results_by_category(results, dataset):
    """
    Analyze results broken down by category and list length
    """
    analysis = {}
    
    for key, result in results.items():
        layer = result['layer']
        predictions = result['predictions']
        
        # Analyze by category
        category_stats = defaultdict(lambda: {'correct': 0, 'total': 0})
        length_stats = defaultdict(lambda: {'correct': 0, 'total': 0})
        
        for i, pred in enumerate(predictions):
            if i < len(dataset):
                category = dataset[i]['type']
                list_length = dataset[i]['list_length']
                
                category_stats[category]['total'] += 1
                length_stats[list_length]['total'] += 1
                
                if pred['correct']:
                    category_stats[category]['correct'] += 1
                    length_stats[list_length]['correct'] += 1
        
        # Calculate accuracies
        category_acc = {cat: stats['correct']/stats['total'] 
                       for cat, stats in category_stats.items() if stats['total'] > 0}
        length_acc = {length: stats['correct']/stats['total'] 
                     for length, stats in length_stats.items() if stats['total'] > 0}
        
        analysis[key] = {
            'layer': layer,
            'overall_accuracy': result['accuracy'],
            'category_accuracy': category_acc,
            'length_accuracy': length_acc
        }
    
    return analysis

def visualize_results(results, analysis):
    """
    Create visualizations for the early stop experiment
    """
    # 1. Overall accuracy by layer
    plt.figure(figsize=(10, 6))
    
    layers = []
    accuracies = []
    
    for key in sorted(results.keys()):
        if key != 'baseline':
            layer = results[key]['layer']
            accuracy = results[key]['accuracy']
            layers.append(layer)
            accuracies.append(accuracy)
    
    # Add baseline
    baseline_acc = results['baseline']['accuracy']
    
    plt.plot(layers, accuracies, 'o-', linewidth=2, markersize=8, label='Early Stop')
    plt.axhline(y=baseline_acc, color='red', linestyle='--', 
                label=f'Full Model ({num_layers} layers): {baseline_acc:.3%}')
    
    plt.xlabel('Stop at Layer', fontsize=12)
    plt.ylabel('Accuracy', fontsize=12)
    plt.title('Counting Accuracy with Early Layer Decoding', fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.ylim(0, 1.05)
    
    # Add annotations for key points
    if accuracies:
        max_early_acc = max(accuracies)
        max_early_layer = layers[accuracies.index(max_early_acc)]
        plt.annotate(f'Best: {max_early_acc:.3f} at layer {max_early_layer}',
                    xy=(max_early_layer, max_early_acc),
                    xytext=(max_early_layer + 1, max_early_acc - 0.05),
                    arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0.3'))
    
    plt.tight_layout()
    plt.savefig('early_stop_plots/accuracy_by_layer.png', dpi=150)
    plt.show()
    
    # 2. Heatmap of accuracy by category and layer
    categories = sorted(set(cat for a in analysis.values() 
                           for cat in a['category_accuracy'].keys()))
    
    if categories:
        accuracy_matrix = []
        layer_labels = []
        
        for key in sorted(results.keys()):
            if key != 'baseline' and key in analysis:
                layer = analysis[key]['layer']
                layer_labels.append(f'Layer {layer}')
                row = [analysis[key]['category_accuracy'].get(cat, 0) for cat in categories]
                accuracy_matrix.append(row)
        
        if accuracy_matrix:
            plt.figure(figsize=(10, 8))
            sns.heatmap(np.array(accuracy_matrix).T, 
                       xticklabels=layer_labels,
                       yticklabels=categories,
                       annot=True, fmt='.3f',
                       cmap='RdYlGn',
                       vmin=0, vmax=1,
                       cbar_kws={'label': 'Accuracy'})
            plt.xlabel('Layer')
            plt.ylabel('Category')
            plt.title('Counting Accuracy by Category and Early Stop Layer')
            plt.tight_layout()
            plt.savefig('early_stop_plots/accuracy_by_category_heatmap.png', dpi=150)
            plt.show()
    
    # 3. Error analysis - where do errors occur?
    plt.figure(figsize=(12, 6))
    
    for i, (key, result) in enumerate(results.items()):
        if key != 'baseline' and i < 5:  # Plot first 5 layers
            layer = result['layer']
            predictions = result['predictions']
            
            errors = []
            for pred in predictions:
                if not pred['correct'] and pred['predicted_answer'] != -1:
                    error = pred['predicted_answer'] - pred['true_answer']
                    errors.append(error)
            
            if errors:
                plt.subplot(2, 3, i + 1)
                plt.hist(errors, bins=20, alpha=0.7, edgecolor='black')
                plt.xlabel('Prediction Error')
                plt.ylabel('Count')
                plt.title(f'Layer {layer}')
                plt.grid(True, alpha=0.3)
    
    plt.suptitle('Distribution of Prediction Errors by Layer')
    plt.tight_layout()
    plt.savefig('early_stop_plots/error_distribution.png', dpi=150)
    plt.show()

def save_detailed_results(results, analysis, output_file='early_stop_results.json'):
    """
    Save detailed results to JSON file
    """
    # Convert to serializable format
    output = {
        'summary': {
            'baseline_accuracy': results['baseline']['accuracy'],
            'best_early_stop_layer': None,
            'best_early_stop_accuracy': 0
        },
        'layer_results': {},
        'analysis': analysis
    }
    
    # Find best early stop layer
    for key, result in results.items():
        if key != 'baseline':
            if result['accuracy'] > output['summary']['best_early_stop_accuracy']:
                output['summary']['best_early_stop_accuracy'] = result['accuracy']
                output['summary']['best_early_stop_layer'] = result['layer']
    
    # Add layer results (without full predictions to save space)
    for key, result in results.items():
        output['layer_results'][key] = {
            'layer': result['layer'],
            'accuracy': result['accuracy'],
            'num_correct': sum(1 for p in result['predictions'] if p['correct']),
            'num_total': len(result['predictions'])
        }
    
    # Add some example predictions
    output['example_predictions'] = {}
    for key, result in results.items():
        # Save first 5 predictions as examples
        output['example_predictions'][key] = result['predictions'][:5]
    
    with open(output_file, 'w') as f:
        json.dump(output, f, indent=2)
    
    print(f"\nDetailed results saved to {output_file}")

def main():
    # Load dataset
    dataset_path = "/net/scratch/slhleosun/counting-items-mechanisms/dataset.json"
    dataset = load_dataset(dataset_path)
    
    # Run experiment
    print("\nRunning early stop decode experiment...")
    results = run_early_stop_experiment(
        model, tokenizer, dataset, 
        start_layer=10, 
        end_layer=20
    )
    
    # Analyze results
    print("\nAnalyzing results...")
    analysis = analyze_results_by_category(results, dataset)
    
    # Create visualizations
    print("\nCreating visualizations...")
    visualize_results(results, analysis)
    
    # Save results
    save_detailed_results(results, analysis)
    
    # Print summary
    print("\n" + "="*50)
    print("SUMMARY")
    print("="*50)
    print(f"Baseline accuracy (all layers): {results['baseline']['accuracy']:.3%}")
    
    best_layer = None
    best_accuracy = 0
    for key, result in results.items():
        if key != 'baseline' and result['accuracy'] > best_accuracy:
            best_accuracy = result['accuracy']
            best_layer = result['layer']
    
    if best_layer:
        print(f"Best early stop layer: {best_layer} with accuracy {best_accuracy:.3%}")
        print(f"Accuracy drop from baseline: {(results['baseline']['accuracy'] - best_accuracy):.3%}")
    
    # Print accuracy progression
    print("\nAccuracy by layer:")
    for key in sorted(results.keys()):
        if key != 'baseline':
            layer = results[key]['layer']
            acc = results[key]['accuracy']
            print(f"  Layer {layer}: {acc:.3%}")
    a
    print("\nExperiment complete! Results saved to early_stop_plots/")

if __name__ == "__main__":
    main()

PyTorch device: cuda
Loading model and tokenizer...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Model has 32 layers, hidden size 4096, vocab size 128256
Loaded 5000 examples from dataset

Running early stop decode experiment...

=== Evaluating Early Stop at Layer 10 ===


Layer 10:   0%|                                             | 1/5000 [00:17<23:51:32, 17.18s/it]


Debug Example 1:
  Prompt ending: ...ol, bear, lion, sheep, tangerine, peach]
Answer: (
  Generated: 'zielosenhower'
  Extracted number: -1
  True answer: 2
  Correct: False

Debug Example 2:
  Prompt ending: ...[cello, plane, cherry, banana, bed, cow]
Answer: (
  Generated: 'ziel♠theless'
  Extracted number: -1
  True answer: 1
  Correct: False


Layer 10:   0%|                                              | 3/5000 [00:17<6:19:27,  4.56s/it]


Debug Example 3:
  Prompt ending: ...pet, horse, cabinet, table, desk, flute]
Answer: (
  Generated: 'zielosenhower'
  Extracted number: -1
  True answer: 3
  Correct: False


Layer 10: 100%|█████████████████████████████████████████████| 5000/5000 [07:05<00:00, 11.75it/s]


Accuracy for Early Stop at Layer 10: 0.000%

=== Evaluating Early Stop at Layer 11 ===


Layer 11:   0%|                                              | 1/5000 [00:02<3:41:40,  2.66s/it]


Debug Example 1:
  Prompt ending: ...ol, bear, lion, sheep, tangerine, peach]
Answer: (
  Generated: '.nihkáchLBL'
  Extracted number: -1
  True answer: 2
  Correct: False

Debug Example 2:
  Prompt ending: ...[cello, plane, cherry, banana, bed, cow]
Answer: (
  Generated: '.nihkáchLBL'
  Extracted number: -1
  True answer: 1
  Correct: False


Layer 11:   0%|                                              | 3/5000 [00:02<1:05:00,  1.28it/s]


Debug Example 3:
  Prompt ending: ...pet, horse, cabinet, table, desk, flute]
Answer: (
  Generated: '.nihkáchLBL'
  Extracted number: -1
  True answer: 3
  Correct: False


Layer 11:  54%|████████████████████████▎                    | 2701/5000 [03:44<03:10, 12.04it/s]


KeyboardInterrupt: 