In [None]:
from classification_icl import ExperimentConfig, LinearTransformer, GaussianMixtureDataset
import torch
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
from dataclasses import asdict
from typing import List, Dict, Any, Tuple

class CheckpointEvaluator:
    """Evaluator class for analyzing trained model checkpoints"""
    
    def __init__(self, checkpoint_dir: str):
        self.checkpoint_dir = Path(checkpoint_dir)
        
    def load_checkpoint(self, checkpoint_path: str) -> Tuple[LinearTransformer, ExperimentConfig]:
        """Load model and config from checkpoint"""
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        config = ExperimentConfig(**asdict(checkpoint['config']))
        model = LinearTransformer(config.d)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        return model, config

    def evaluate_risk_curves(
        self,
        model: LinearTransformer,
        d: int,
        max_seq_length: int = 50,
        num_samples: int = 1000,
        label_flip_ps: List[float] = [0.0, 0.15],
        device: str = 'cpu'
    ) -> Dict[str, np.ndarray]:
        """
        Evaluate model's risk curve - accuracy vs number of context examples.
        """
        model = model.to(device)
        model.W.data = torch.eye(d, device=device)
        results = {}
        R = d ** 0.7
        
        for label_flip_p in label_flip_ps:
            print(f"\nEvaluating risk curve for d={d}, label_flip_p={label_flip_p}")
            
            # Create dataset with reasonable size N
            dataset = GaussianMixtureDataset(
                d=d,
                N=max_seq_length,
                B=num_samples,
                R=R,  # Using R=1.0 like in test
                is_validation=True,
                label_flip_p=label_flip_p
            )
            
            # Get all data at once
            context_x, context_y, _, _ = [t.to(device) for t in dataset[0]]
            
            # Storage for accuracies at each position
            position_accuracies = np.zeros(max_seq_length-1)
            
            with torch.no_grad():
                for k in range(1, max_seq_length):
                    curr_context_x = context_x[:, :k]
                    curr_context_y = context_y[:, :k]
                    
                    preds = model.compute_in_context_preds(curr_context_x, curr_context_y)
                    accuracy = (preds == curr_context_y).float().mean().item()
                    position_accuracies[k-1] = accuracy
                    
                    if k % 10 == 0:
                        print(f"Position {k}: accuracy = {accuracy:.3f}")
            
            results[label_flip_p] = position_accuracies
            
        return results

    def old_evaluate_risk_curves(
        self,
        model: LinearTransformer,
        d: int,
        max_seq_length: int = 50,
        num_samples: int = 1000,
        label_flip_ps: List[float] = [0.0, 0.15],
        device: str = 'cpu'
    ) -> Dict[str, np.ndarray]:
        """
        Evaluate model's risk curve - accuracy vs number of context examples.
        Uses the model's compute_in_context_preds method for efficiency.
        
        Args:
            model: Trained LinearTransformer model
            d: Input dimension
            max_seq_length: Maximum sequence length to evaluate
            num_samples: Number of sequences to evaluate
            label_flip_ps: List of label flip probabilities to evaluate
            device: Device to run evaluation on
            
        Returns:
            Dictionary mapping label_flip_p to arrays of accuracies at each position
        """
        model = model.to(device)
        model.W.data = torch.eye(d, device=device) #TODO remove shortly, just testing for now
        results = {}
        R = 1
        
        for label_flip_p in label_flip_ps:
            print(f"\nEvaluating risk curve for d={d}, label_flip_p={label_flip_p}")
            
            # Create dataset with longer sequences
            dataset = GaussianMixtureDataset(
                d=d,
                N=max_seq_length,
                B=num_samples,
                R=R,
                is_validation=True,
                label_flip_p=label_flip_p
            )
            
            # Get all data at once
            context_x, context_y, _, _ = [t.to(device) for t in dataset[0]]
            
            # Storage for accuracies at each position
            position_accuracies = np.zeros(max_seq_length-1)
            
            with torch.no_grad():
                for k in range(1, max_seq_length):

                    # Use k examples to predict k+1
                    curr_context_x = context_x[:, :k]
                    curr_context_y = context_y[:, :k]
                    # next_y = context_y[:, k]
                    
                    # Use model's built-in prediction method for the k+1 position
                    preds = model.compute_in_context_preds(curr_context_x, curr_context_y)
                    accuracy = (preds == curr_context_y).float().mean().item()
                    position_accuracies[k-1] = accuracy
                    
                    if k % 10 == 0:
                        print(f"Position {k}: accuracy = {accuracy:.3f}")
            
            results[label_flip_p] = position_accuracies
            
        return results

    def plot_risk_curves(self, results_by_d: Dict[int, Dict[float, np.ndarray]], save_path: str = None):
        """
        Plot risk curves for different dimensions and label flip probabilities.
        
        Args:
            results_by_d: Dictionary mapping dimension d to results from evaluate_risk_curves
            save_path: Optional path to save the plot
        """
        plt.figure(figsize=(10, 6))
        colors = ['blue', 'red', 'green']
        styles = ['-', '--']
        
        for i, (d, results) in enumerate(sorted(results_by_d.items())):
            for j, (label_flip_p, accuracies) in enumerate(results.items()):
                x = np.arange(1, len(accuracies) + 1)
                label = f'd={d}, p={label_flip_p}'
                plt.plot(x, accuracies, label=label, color=colors[i], linestyle=styles[j])
        
        plt.xlabel('Number of Context Examples (k)')
        plt.ylabel('Memorization accuracy')
        plt.title('Memorization accuracy vs. number of in-context examples')
        #plt.ylabel('Accuracy on (k+1)th Example')
        #plt.title('Risk Curves for Different Dimensions and Label Noise')
        plt.legend()
        plt.grid(True)
        
        if save_path:
            plt.savefig(save_path)
            print(f"Plot saved to {save_path}")
        else:
            plt.show()
            
    def evaluate_checkpoint(self, checkpoint_file: str) -> Dict[int, Dict[float, np.ndarray]]:
        """
        Evaluate a single checkpoint and return risk curves.
        
        Args:
            checkpoint_file: Path to checkpoint file
            
        Returns:
            Results dictionary for plotting risk curves
        """
        print(f"\nEvaluating {checkpoint_file}")
        model, config = self.load_checkpoint(checkpoint_file)
        results = self.evaluate_risk_curves(
            model=model,
            d=config.d,
            max_seq_length=600,
            num_samples=1000,
            label_flip_ps=[0.0, 0.15],
            device='cuda' if torch.cuda.is_available() else 'cpu'
        )
        return {config.d: results}

# Example usage:
def main():
    evaluator = CheckpointEvaluator('checkpoints/')
    
    # Collect results for all dimensions
    all_results = {}
    # dimensions = [50, 500, 2500]
    dimensions = [500]
    
    for d in dimensions:
        # checkpoint_file = f"checkpoints/checkpoint_d{d}_*.pt"  # Adjust pattern as needed
        matches = list(Path('checkpoints/').glob(f"checkpoint_d{d}*.pt"))
        if matches:
            results = evaluator.evaluate_checkpoint(str(matches[0]))
            all_results.update(results)
    
    # Plot combined results
    evaluator.plot_risk_curves(all_results, save_path=None)

if __name__ == "__main__":
    main()


Evaluating checkpoints/checkpoint_d500_B500_R111_20241118_163035_step_499.pt

Evaluating risk curve for d=500, label_flip_p=0.0
Position 10: accuracy = 1.000
Position 20: accuracy = 1.000
Position 30: accuracy = 1.000
Position 40: accuracy = 1.000
Position 50: accuracy = 1.000
Position 60: accuracy = 1.000
Position 70: accuracy = 1.000
Position 80: accuracy = 1.000
Position 90: accuracy = 1.000
Position 100: accuracy = 1.000
Position 110: accuracy = 1.000
Position 120: accuracy = 1.000
Position 130: accuracy = 1.000
Position 140: accuracy = 1.000
Position 150: accuracy = 1.000
Position 160: accuracy = 1.000
Position 170: accuracy = 1.000
Position 180: accuracy = 1.000
Position 190: accuracy = 1.000
Position 200: accuracy = 1.000
Position 210: accuracy = 1.000
Position 220: accuracy = 1.000
Position 230: accuracy = 1.000
Position 240: accuracy = 1.000
Position 250: accuracy = 1.000
Position 260: accuracy = 1.000
Position 270: accuracy = 1.000
Position 280: accuracy = 1.000
Position 290

In [10]:
def test_numerical_precision():
    """Test to examine numerical precision effects with different R values"""
    # Setup parameters
    d = 1000  
    N = 50  # Longer sequence to see effects
    B = 100   
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Test both R values
    R_values = [1.0, d]
    
    for R in R_values:
        print(f"\n=== Testing R = {R} ===")
        
        # Create dataset
        dataset = GaussianMixtureDataset(d, N, B, R, label_flip_p=0.15)
        context_x, context_y, _, _ = [t.to(device) for t in dataset[0]]
        
        # Create model and set W to identity
        model = LinearTransformer(d).to(device)
        with torch.no_grad():
            model.W.data = torch.eye(d, device=device)
        
        # Examine quantities for different sequence lengths
        for k in [1, 5, 10, 20, 50]:
            print(f"\nSequence length k = {k}")
            curr_x = context_x[:, :k]
            curr_y = context_y[:, :k]
            
            # Compute components explicitly
            curr_y_signal = 2 * curr_y - 1
            
            # Examine x magnitudes
            x_norms = torch.norm(curr_x, dim=2)  # (B, k)
            print(f"X norms - mean: {x_norms.mean():.2f}, std: {x_norms.std():.2f}")
            
            # Examine contribution to mean
            contributions = curr_y_signal[..., None] * curr_x  # (B, k, d)
            contrib_norms = torch.norm(contributions, dim=2)  # (B, k)
            print(f"Contribution norms - mean: {contrib_norms.mean():.2f}, std: {contrib_norms.std():.2f}")
            
            # Examine hat_mu
            hat_mu = (1/k) * torch.sum(contributions, dim=1)  # (B, d)
            mu_norms = torch.norm(hat_mu, dim=1)  # (B,)
            print(f"hat_mu norms - mean: {mu_norms.mean():.2f}, std: {mu_norms.std():.2f}")
            
            # Examine logits before thresholding
            logits = (hat_mu[:, None, :] @ curr_x.transpose(-1, -2))[:, 0, :]  # (B, k)
            print(f"Logit magnitudes - mean: {logits.abs().mean():.2f}, std: {logits.std():.2f}")
            print(f"Logit range: [{logits.min():.2f}, {logits.max():.2f}]")
            
            # Get predictions and check accuracy
            preds = model.compute_in_context_preds(curr_x, curr_y)
            accuracy = (preds == curr_y).float().mean().item()
            print(f"Accuracy: {accuracy:.3f}")
            
            # Check if reconstructions are exact
            # Get a few examples where prediction failed
            if accuracy < 1.0:
                failed_cases = (preds != curr_y).any(dim=1)
                b_idx = failed_cases.nonzero()[0].item()
                print("\nExample of failed reconstruction:")
                print(f"Original y: {curr_y[b_idx].cpu().numpy()}")
                print(f"Predicted:  {preds[b_idx].cpu().numpy()}")
                print(f"Logits:     {logits[b_idx].cpu().numpy()}")
                
                # Examine exact y_i * x_i @ x_i values for this case
                for i in range(k):
                    x_i = curr_x[b_idx, i]
                    y_i = curr_y_signal[b_idx, i]
                    recon = y_i * (x_i @ x_i)
                    print(f"Position {i}: y_i * (x_i @ x_i) = {recon:.5f}")

if __name__ == "__main__":
    test_numerical_precision()


=== Testing R = 1.0 ===

Sequence length k = 1
X norms - mean: 31.72, std: 0.77
Contribution norms - mean: 31.72, std: 0.77
hat_mu norms - mean: 31.72, std: 0.77
Logit magnitudes - mean: 1006.67, std: 999.79
Logit range: [-1116.14, 1096.52]
Accuracy: 1.000

Sequence length k = 5
X norms - mean: 31.63, std: 0.70
Contribution norms - mean: 31.63, std: 0.70
hat_mu norms - mean: 14.16, std: 0.31
Logit magnitudes - mean: 200.50, std: 201.29
Logit range: [-240.88, 234.39]
Accuracy: 1.000

Sequence length k = 10
X norms - mean: 31.63, std: 0.70
Contribution norms - mean: 31.63, std: 0.70
hat_mu norms - mean: 9.98, std: 0.20
Logit magnitudes - mean: 99.70, std: 100.29
Logit range: [-130.24, 127.40]
Accuracy: 1.000

Sequence length k = 20
X norms - mean: 31.62, std: 0.70
Contribution norms - mean: 31.62, std: 0.70
hat_mu norms - mean: 7.09, std: 0.14
Logit magnitudes - mean: 50.27, std: 50.79
Logit range: [-77.50, 74.13]
Accuracy: 1.000

Sequence length k = 50
X norms - mean: 31.63, std: 0.70


In [8]:

def validate_with_identity_model(d: int, N: int = 50, num_samples: int = 1000):
    """
    Validate evaluation code using a model with identity weight matrix.
    This should perform well when signal-to-noise ratio is high enough.
    """
    print(f"\nValidating with identity model: d={d}, N={N}")
    
    # Create model and set weights to identity
    model = LinearTransformer(d)
    with torch.no_grad():
        model.W.copy_(torch.eye(d))
    model.eval()
    
    # Create dataset with high signal-to-noise ratio
    R = 2*d**0.3 # High SNR regime
    print(f"Using R = {R:.2f}")
    
    dataset = GaussianMixtureDataset(
        d=d,
        N=N,
        B=num_samples,
        R=R,
        is_validation=True,
        label_flip_p=0.0
    )
    
    # Get data
    context_x, context_y, _, _ = dataset[0]
    accuracies = []
    
    with torch.no_grad():
        # Test prediction at each sequence length
        for k in range(1, N):
            curr_context_x = context_x[:, :k]
            curr_context_y = context_y[:, :k]
            next_y = context_y[:, k]
            
            # Get predictions
            preds = model.compute_in_context_preds(curr_context_x, curr_context_y)
            accuracy = (preds[:, -1] == next_y).float().mean().item()
            accuracies.append(accuracy)
            
            print(f"k={k}: accuracy={accuracy:.3f}")
            
            # Additional debugging for first few iterations
            if k <= 3:
                # Check intermediate values
                batch_idx = 0  # Look at first sequence in batch
                context_term = (1/k) * torch.sum(
                    (2 * curr_context_y[batch_idx] - 1).unsqueeze(-1) * curr_context_x[batch_idx], 
                    dim=0
                )
                signal_norm = torch.norm(context_term).item()
                print(f"  Signal norm at position {k}: {signal_norm:.3f}")
                
                # Print a few example predictions
                for i in range(min(3, len(next_y))):
                    pred = preds[i, -1].item()
                    true_y = next_y[i].item()
                    print(f"  Example {i}: pred={pred:.1f}, true={true_y}")
    
    return accuracies

validate_with_identity_model(d=50, N=200, num_samples=100)


Validating with identity model: d=50, N=200
Using R = 6.47
k=1: accuracy=0.460
  Signal norm at position 1: 8.911
  Example 0: pred=1.0, true=0.0
  Example 1: pred=0.0, true=0.0
  Example 2: pred=0.0, true=1.0
k=2: accuracy=0.440
  Signal norm at position 2: 7.665
  Example 0: pred=0.0, true=1.0
  Example 1: pred=0.0, true=0.0
  Example 2: pred=1.0, true=1.0
k=3: accuracy=0.510
  Signal norm at position 3: 7.490
  Example 0: pred=1.0, true=1.0
  Example 1: pred=0.0, true=1.0
  Example 2: pred=1.0, true=0.0
k=4: accuracy=0.580
k=5: accuracy=0.590
k=6: accuracy=0.510
k=7: accuracy=0.460
k=8: accuracy=0.610
k=9: accuracy=0.520
k=10: accuracy=0.400
k=11: accuracy=0.490
k=12: accuracy=0.490
k=13: accuracy=0.510
k=14: accuracy=0.510
k=15: accuracy=0.540
k=16: accuracy=0.510
k=17: accuracy=0.560
k=18: accuracy=0.470
k=19: accuracy=0.510
k=20: accuracy=0.540
k=21: accuracy=0.500
k=22: accuracy=0.520
k=23: accuracy=0.370
k=24: accuracy=0.410
k=25: accuracy=0.610
k=26: accuracy=0.470
k=27: accu

[0.46000000834465027,
 0.4399999976158142,
 0.5099999904632568,
 0.5799999833106995,
 0.5899999737739563,
 0.5099999904632568,
 0.46000000834465027,
 0.6100000143051147,
 0.5199999809265137,
 0.4000000059604645,
 0.49000000953674316,
 0.49000000953674316,
 0.5099999904632568,
 0.5099999904632568,
 0.5400000214576721,
 0.5099999904632568,
 0.5600000023841858,
 0.4699999988079071,
 0.5099999904632568,
 0.5400000214576721,
 0.5,
 0.5199999809265137,
 0.3700000047683716,
 0.4099999964237213,
 0.6100000143051147,
 0.4699999988079071,
 0.5799999833106995,
 0.5299999713897705,
 0.4699999988079071,
 0.49000000953674316,
 0.5,
 0.4699999988079071,
 0.49000000953674316,
 0.49000000953674316,
 0.5199999809265137,
 0.5299999713897705,
 0.5099999904632568,
 0.5400000214576721,
 0.5199999809265137,
 0.5699999928474426,
 0.5099999904632568,
 0.550000011920929,
 0.5299999713897705,
 0.5600000023841858,
 0.5799999833106995,
 0.5,
 0.49000000953674316,
 0.5899999737739563,
 0.4099999964237213,
 0.490000

In [9]:
import torch
import numpy as np
from classification_icl import LinearTransformer
from typing import Tuple

def generate_test_data(d: int, N: int, B: int, R: float) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Generate test data with known signal structure to verify model behavior.
    """
    # Generate mean vectors
    mus = torch.randn(B, d)
    mus = mus / torch.norm(mus, dim=1, keepdim=True) * R
    
    # Generate labels
    y = (torch.rand(B, N) > 0.5).float()
    y_signal = 2 * y - 1  # Convert to ±1
    
    # Generate inputs with signal
    z = torch.randn(B, N, d)  # noise
    x = y_signal.unsqueeze(-1) * mus.unsqueeze(1) + z
    
    return x, y

def debug_identity_model(d: int = 5000, N: int = 20, B: int = 1000):
    """
    Debug the identity model with detailed computation checks.
    """
    print(f"\nDebugging identity model with d={d}, N={N}, B={B}")
    
    # Set up model with identity matrix
    model = LinearTransformer(d)
    with torch.no_grad():
        model.W.copy_(torch.eye(d))
    model.eval()
    
    # Generate data with known structure
    R = 2 * d**0.3
    print(f"R = {R:.2f}")
    x, y = generate_test_data(d, N, B, R)
    
    # Debug computations step by step
    with torch.no_grad():
        for k in range(1, N):
            print(f"\nk = {k}")
            
            # Get current context and next example
            context_x = x[:, :k]  # (B, k, d)
            context_y = y[:, :k]  # (B, k)
            next_x = x[:, k]      # (B, d)
            next_y = y[:, k]      # (B)
            
            # Manual computation for clarity
            # 1. Convert labels to ±1
            y_signal = 2 * context_y - 1  # (B, k)
            
            # 2. Compute context term (1/k)∑y_i x_i
            context_term = (1/k) * torch.sum(y_signal.unsqueeze(-1) * context_x, dim=1)  # (B, d)
            
            # 3. Since W is identity, transformed = context_term
            # 4. Compute prediction logit
            logits = torch.sum(context_term * next_x, dim=1)  # (B)
            
            # Compute accuracy
            preds = (logits > 0).float()
            acc = (preds == next_y).float().mean().item()
            
            # Detailed diagnostics
            signal_norm = torch.norm(context_term, dim=1)
            noise_level = torch.norm(next_x, dim=1)
            dot_products = torch.sum(context_term * next_x, dim=1)
            
            print(f"Accuracy: {acc:.3f}")
            print(f"Signal norm stats: mean={signal_norm.mean():.3f}, std={signal_norm.std():.3f}")
            print(f"Noise level stats: mean={noise_level.mean():.3f}, std={noise_level.std():.3f}")
            print(f"Dot product stats: mean={dot_products.mean():.3f}, std={dot_products.std():.3f}")
            
            # Look at a few examples in detail
            if k <= 3:
                for i in range(3):
                    print(f"\nExample {i}:")
                    print(f"True next y: {next_y[i].item()}")
                    print(f"Predicted y: {preds[i].item()}")
                    print(f"Logit: {logits[i].item():.3f}")
                    print(f"Signal norm: {signal_norm[i].item():.3f}")
                    print(f"Dot product: {dot_products[i].item():.3f}")
                    # Check alignment between context_term and true mu
                    if i == 0:
                        # Estimate mu from next_x using true label
                        true_signal = next_y[i].item() * 2 - 1
                        estimated_mu = true_signal * next_x[i] / R
                        alignment = torch.dot(context_term[i], estimated_mu) / (torch.norm(context_term[i]) * torch.norm(estimated_mu))
                        print(f"Alignment with true direction: {alignment.item():.3f}")

if __name__ == "__main__":
    debug_identity_model()


Debugging identity model with d=5000, N=20, B=1000
R = 25.75

k = 1
Accuracy: 1.000
Signal norm stats: mean=75.252, std=0.753
Noise level stats: mean=75.257, std=0.752
Dot product stats: mean=-22.941, std=670.401

Example 0:
True next y: 0.0
Predicted y: 0.0
Logit: -655.774
Signal norm: 75.684
Dot product: -655.774
Alignment with true direction: 0.117

Example 1:
True next y: 0.0
Predicted y: 0.0
Logit: -610.141
Signal norm: 75.443
Dot product: -610.141

Example 2:
True next y: 1.0
Predicted y: 1.0
Logit: 695.984
Signal norm: 74.590
Dot product: 695.984

k = 2
Accuracy: 1.000
Signal norm stats: mean=56.254, std=0.560
Noise level stats: mean=75.232, std=0.760
Dot product stats: mean=-4.550, std=666.523

Example 0:
True next y: 0.0
Predicted y: 0.0
Logit: -616.523
Signal norm: 55.962
Dot product: -616.523
Alignment with true direction: 0.148

Example 1:
True next y: 1.0
Predicted y: 1.0
Logit: 682.434
Signal norm: 55.825
Dot product: 682.434

Example 2:
True next y: 0.0
Predicted y: 0.0

In [13]:
import torch
import numpy as np
from classification_icl import LinearTransformer

def test_identity_simple(d=10, N=200, B=100):
    """
    Simple test of identity matrix behavior with clear signal structure
    """
    # Create model with identity matrix
    model = LinearTransformer(d)
    with torch.no_grad():
        model.W.copy_(torch.eye(d))
    model.eval()
    
    # Generate data with known signal
    # R = 2 * d**0.3
    R = d**0.1
    print(f"\nTesting identity matrix: d={d}, N={N}, B={B}, R={R:.2f}")
    
    # Generate mean vectors (true signals)
    mus = torch.randn(B, d)
    mus = mus / torch.norm(mus, dim=1, keepdim=True) * R
    
    # Generate sequences with these means
    y = (torch.rand(B, N) > 0.5).float()
    y_signal = 2 * y - 1  # Convert to ±1
    
    # Add noise to create inputs
    z = torch.randn(B, N, d)  # Standard normal noise
    x = y_signal.unsqueeze(-1) * mus.unsqueeze(1) + z
    
    print("\nVerifying signal and noise properties:")
    signal_norms = torch.norm(y_signal.unsqueeze(-1) * mus.unsqueeze(1), dim=-1)
    noise_norms = torch.norm(z, dim=-1)
    print(f"Signal norms: mean={signal_norms.mean():.3f}, std={signal_norms.std():.3f}")
    print(f"Noise norms: mean={noise_norms.mean():.3f}, std={noise_norms.std():.3f}")
    
    # Test prediction at each k
    for k in range(1, 5):  # Just test first few k values
        print(f"\nk = {k}")
        
        # Get current context
        context_x = x[:, :k]  # (B, k, d)
        context_y = y[:, :k]  # (B, k)
        next_x = x[:, k]      # (B, d)
        next_y = y[:, k]      # (B)
        
        # Manual prediction computation
        y_signal = 2 * context_y - 1
        context_term = (1/k) * torch.sum(y_signal.unsqueeze(-1) * context_x, dim=1)
        
        # Since W is identity, just compute dot product with next_x
        logits = torch.sum(context_term * next_x, dim=1)
        preds = (logits > 0).float()
        acc = (preds == next_y).float().mean().item()
        
        # Compute some diagnostics
        signal_norm = torch.norm(context_term, dim=1)
        print(f"Accuracy: {acc:.3f}")
        print(f"Signal norm: mean={signal_norm.mean():.3f}, std={signal_norm.std():.3f}")
        
        # Look at a few example predictions in detail
        for i in range(3):
            alignment = torch.dot(context_term[i], mus[i]) / (torch.norm(context_term[i]) * torch.norm(mus[i]))
            print(f"\nExample {i}:")
            print(f"True y: {next_y[i].item()}")
            print(f"Pred y: {preds[i].item()}")
            print(f"Logit: {logits[i].item():.3f}")
            print(f"Signal norm: {signal_norm[i].item():.3f}")
            print(f"Alignment with true mu: {alignment.item():.3f}")

if __name__ == "__main__":
    test_identity_simple()


Testing identity matrix: d=10, N=200, B=100, R=1.26

Verifying signal and noise properties:
Signal norms: mean=1.259, std=0.000
Noise norms: mean=3.087, std=0.692

k = 1
Accuracy: 0.710
Signal norm: mean=3.255, std=0.666

Example 0:
True y: 1.0
Pred y: 1.0
Logit: 0.820
Signal norm: 3.008
Alignment with true mu: 0.015

Example 1:
True y: 0.0
Pred y: 0.0
Logit: -1.353
Signal norm: 3.446
Alignment with true mu: 0.263

Example 2:
True y: 1.0
Pred y: 1.0
Logit: 6.998
Signal norm: 3.700
Alignment with true mu: 0.149

k = 2
Accuracy: 0.670
Signal norm: mean=2.519, std=0.609

Example 0:
True y: 1.0
Pred y: 0.0
Logit: -0.719
Signal norm: 2.393
Alignment with true mu: 0.521

Example 1:
True y: 0.0
Pred y: 1.0
Logit: 1.792
Signal norm: 2.315
Alignment with true mu: 0.251

Example 2:
True y: 0.0
Pred y: 0.0
Logit: -1.944
Signal norm: 3.212
Alignment with true mu: 0.418

k = 3
Accuracy: 0.800
Signal norm: mean=2.186, std=0.488

Example 0:
True y: 1.0
Pred y: 1.0
Logit: 1.822
Signal norm: 1.913
Ali

In [1]:
from classification_icl import ExperimentConfig, LinearTransformer, GaussianMixtureDataset

import torch
import numpy as np
import pandas as pd
from pathlib import Path
import json
from dataclasses import asdict
from typing import List, Dict, Any, Tuple

class CheckpointEvaluator:
    """Evaluator class for analyzing trained model checkpoints"""
    
    def __init__(self, checkpoint_dir: str):
        self.checkpoint_dir = Path(checkpoint_dir)
        
    def load_checkpoint(self, checkpoint_path: str) -> Tuple[LinearTransformer, ExperimentConfig]:
        """Load model and config from checkpoint"""
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        
        # Recreate config
        config = ExperimentConfig(**asdict(checkpoint['config']))
        
        # Initialize and load model
        model = LinearTransformer(config.d)
        model.load_state_dict(checkpoint['model_state_dict'])
        # model.to(config.device)
        model.eval()
        
        return model, config

    def evaluate_sequential_predictions(
        self, 
        model: LinearTransformer,
        config: ExperimentConfig,
        num_samples: int = 100,
        max_seq_length: int = 50
    ) -> Dict[str, Any]:
        """
        Evaluate model's sequential prediction performance.
        """
        eval_dataset = GaussianMixtureDataset(
            d=config.d,
            N=max_seq_length,
            B=num_samples,
            R=config.R_val,
            device=config.device,
            is_validation=True,
            label_flip_p=config.label_flip_p
        )
        
        accuracies = []
        confidences = []
        position_accuracies = np.zeros(max_seq_length)
        position_counts = np.zeros(max_seq_length)
        
        with torch.no_grad():
            context_x, context_y, target_x, target_y = eval_dataset[0]
            print(f"Data shapes: context_x {context_x.shape}, context_y {context_y.shape}")
            print(f"Data ranges: context_x [{context_x.min():.3f}, {context_x.max():.3f}], context_y [{context_y.min():.3f}, {context_y.max():.3f}]")
            
            for i in range(num_samples):
                x_seq = context_x[i]  # Shape: (N, d)
                y_seq = context_y[i]  # Shape: (N,)
                
                seq_preds = []
                seq_confs = []
                
                for k in range(max_seq_length - 1):
                    curr_context_x = x_seq[:k+1][None, ...]  
                    curr_context_y = y_seq[:k+1][None, ...]  
                    curr_target_x = x_seq[k+1][None, ...]    
                    curr_target_y = y_seq[k+1]              
                    
                    logits = model(curr_context_x, curr_context_y, curr_target_x)
                    
                    if torch.isnan(logits).any():
                        print(f"NaN in logits at position {k}, sequence {i}")
                        print(f"Input ranges: context_x [{curr_context_x.min():.3f}, {curr_context_x.max():.3f}]")
                        continue
                    
                    pred = (logits > 0).float().item()
                    conf = torch.sigmoid(logits).item()
                    
                    seq_preds.append(pred)
                    seq_confs.append(conf)
                    
                    position_accuracies[k] += float(pred == curr_target_y.item())
                    position_counts[k] += 1
                
                accuracies.append(seq_preds)
                confidences.append(seq_confs)
                
                if i % 10 == 0:
                    print(f"Processed sequence {i}, current mean acc: {np.mean(position_accuracies[position_counts > 0] / position_counts[position_counts > 0]):.3f}")
        
        # Safe division for position accuracies
        valid_positions = position_counts > 0
        if valid_positions.any():
            position_accuracies[valid_positions] = position_accuracies[valid_positions] / position_counts[valid_positions]
        
        mean_acc = float(np.mean(position_accuracies[valid_positions])) if valid_positions.any() else 0.0
        
        metrics = {
            'position_accuracies': position_accuracies.tolist(),  # convert to list for JSON serialization
            'mean_accuracy': mean_acc,
            'accuracies': accuracies,
            'confidences': confidences,
            'config': asdict(config)
        }
        
        print(f"\nFinal metrics:")
        print(f"Mean accuracy: {mean_acc:.3f}")
        print(f"Position accuracies range: [{min(position_accuracies[valid_positions]):.3f}, {max(position_accuracies[valid_positions]):.3f}]")
        
        return metrics

    def old_evaluate_sequential_predictions(
        self, 
        model: LinearTransformer,
        config: ExperimentConfig,
        num_samples: int = 100,
        max_seq_length: int = 50
    ) -> Dict[str, Any]:
        """
        Evaluate model's sequential prediction performance.
        Uses broadcasting notation instead of squeeze/unsqueeze.
        """
        # Create dataset with longer sequences
        eval_dataset = GaussianMixtureDataset(
            d=config.d,
            N=max_seq_length,
            B=num_samples,
            R=config.R_val,
            device='cpu',
            is_validation=True,
            label_flip_p=config.label_flip_p
        )
        
        # Storage for metrics
        accuracies = []
        confidences = []
        position_accuracies = np.zeros(max_seq_length)
        position_counts = np.zeros(max_seq_length)
        
        with torch.no_grad():
            # Get data using dataset's __getitem__
            x_all, y_all, target_x, target_y = eval_dataset[0]  # Changed this line
        
            # For each sequence
            for i in range(num_samples):
                x_seq = x_all[i]   # Shape: (N, d)
                y_seq = y_all[i]   # Shape: (N,)
                
                # For each position k, predict k+1 using 1:k
                seq_preds = []
                seq_confs = []
                
                for k in range(max_seq_length - 1):
                    # Get context up to position k
                    context_x = x_seq[:k+1][None, :, :]  # Add batch dim: (1, k+1, d)
                    context_y = y_seq[:k+1][None, :]     # Add batch dim: (1, k+1)
                    target_x = x_seq[k+1][None, :]       # Add batch dim: (1, d)
                    target_y = y_seq[k+1]                # Keep as scalar
                    
                    # Make prediction
                    logits = model(context_x, context_y, target_x)  # Shape: (1,)
                    
                    # Store prediction and confidence
                    pred = (logits > 0).float().item()
                    conf = torch.sigmoid(logits).item()
                    seq_preds.append(pred)
                    seq_confs.append(conf)
                    
                    # Update position accuracies
                    position_accuracies[k] += (pred == target_y.item())
                    position_counts[k] += 1
                
                # Store sequence metrics
                accuracies.append(seq_preds)
                confidences.append(seq_confs)
        
        # Compute metrics
        position_accuracies = position_accuracies / position_counts
        
        metrics = {
            'position_accuracies': position_accuracies,
            'mean_accuracy': np.mean(position_accuracies),
            'accuracies': accuracies,
            'confidences': confidences,
            'config': asdict(config)
        }
        
        return metrics

    def evaluate_all_checkpoints(
        self,
        checkpoint_pattern: str = "checkpoint_step_*.pt",
        save_results: bool = True
    ) -> Dict[str, Dict[str, Any]]:
        """
        Evaluate all checkpoints matching pattern.
        
        Args:
            checkpoint_pattern: Pattern to match checkpoint files
            save_results: Whether to save results to disk
            
        Returns:
            Dictionary mapping checkpoint paths to their metrics
        """
        results = {}
        
        # Find all checkpoint files
        checkpoint_files = list(self.checkpoint_dir.glob(checkpoint_pattern))
        print(f"Found {len(checkpoint_files)} checkpoints to evaluate")
        
        for checkpoint_file in checkpoint_files:
            print(f"\nEvaluating {checkpoint_file}")
            
            # Load checkpoint
            model, config = self.load_checkpoint(checkpoint_file)
            
            # Run evaluation
            metrics = self.evaluate_sequential_predictions(model, config)
            
            # Store results
            results[str(checkpoint_file)] = metrics
            
        return results
    

    def analyze_results(self, results: Dict[str, Dict[str, Any]]) -> pd.DataFrame:
        """
        Analyze evaluation results across checkpoints.
        
        Args:
            results: Dictionary of evaluation results per checkpoint
            
        Returns:
            DataFrame with analysis
        """
        records = []
        
        for checkpoint_path, metrics in results.items():
            config = metrics['config']
            
            record = {
                'checkpoint': checkpoint_path,
                'd': config['d'],
                'B': config['B'],
                'R': config['R_train'],
                'mean_accuracy': metrics['mean_accuracy'],
                'final_accuracy': metrics['position_accuracies'][-1],
                'early_accuracy': np.mean(metrics['position_accuracies'][:5]),
            }
            records.append(record)
            
        df = pd.DataFrame(records)
        return df



ckpt = CheckpointEvaluator('checkpoints/')

In [2]:
eval = CheckpointEvaluator('checkpoints/')
res = eval.evaluate_all_checkpoints(checkpoint_pattern = 'checkpoint*')

Found 5 checkpoints to evaluate

Evaluating checkpoints/checkpoint_d50_B50_R13_20241118_163034_step_499.pt
Data shapes: context_x torch.Size([100, 50, 50]), context_y torch.Size([100, 50])
Data ranges: context_x [-9.566, 9.064], context_y [0.000, 1.000]
Processed sequence 0, current mean acc: 1.000
Processed sequence 10, current mean acc: 1.000
Processed sequence 20, current mean acc: 1.000
Processed sequence 30, current mean acc: 1.000
Processed sequence 40, current mean acc: 1.000
Processed sequence 50, current mean acc: 1.000
Processed sequence 60, current mean acc: 1.000
Processed sequence 70, current mean acc: 1.000
Processed sequence 80, current mean acc: 1.000
Processed sequence 90, current mean acc: 1.000

Final metrics:
Mean accuracy: 1.000
Position accuracies range: [1.000, 1.000]

Evaluating checkpoints/checkpoint_d500_B500_R111_20241118_163035_step_499.pt
Data shapes: context_x torch.Size([100, 50, 500]), context_y torch.Size([100, 50])
Data ranges: context_x [-23.959, 25.3

In [3]:
df = eval.analyze_results(res)

In [5]:
pprint(df)

Pretty printing has been turned ON


In [6]:
df

Unnamed: 0,checkpoint,d,B,R,mean_accuracy,final_accuracy,early_accuracy
0,checkpoints/checkpoint_d50_B50_R13_20241118_16...,50,50,35.355339,1.0,0.0,1.0
1,checkpoints/checkpoint_d500_B500_R111_20241118...,500,500,111.803399,1.0,0.0,1.0
2,checkpoints/checkpoint_d2500_B2500_R35_2024111...,2500,2500,250.0,1.0,0.0,1.0
3,checkpoints/checkpoint_d50_B50_R35_20241118_16...,50,50,35.355339,1.0,0.0,1.0
4,checkpoints/checkpoint_d500_B500_R23_20241118_...,500,500,111.803399,1.0,0.0,1.0


In [6]:

ckpt = CheckpointEvaluator('checkpoints/')
c = torch.load('checkpoints/checkpoint_d2500_B2500_R35_20241118_171451_step_299.pt')

In [24]:
from dataclasses import asdict
asdict(c['config'])

{'d': 2500,
 'N': 40,
 'B': 2500,
 'B_val': 500,
 'R_train': 250.0,
 'R_val': 35.35533905932738,
 'max_steps': 300,
 'checkpoint_steps': [299],
 'label_flip_p': 0.0,
 'learning_rate': 0.01,
 'use_cuda': True,
 'use_wandb': False,
 'wandb_project': 'linear-transformer',
 'save_checkpoints': True,
 'save_results': True,
 'checkpoint_dir': 'checkpoints',
 'results_dir': 'results_20241118_171451',
 'experiment_name': 'd2500_B2500_R35_20241118_171451'}

In [10]:
from pprint import pprint 
pprint(c)

{'config': ExperimentConfig(d=2500,
                            N=40,
                            B=2500,
                            B_val=500,
                            R_train=250.0,
                            R_val=35.35533905932738,
                            max_steps=300,
                            checkpoint_steps=[299],
                            label_flip_p=0.0,
                            learning_rate=0.01,
                            use_cuda=True,
                            use_wandb=False,
                            wandb_project='linear-transformer',
                            save_checkpoints=True,
                            save_results=True,
                            checkpoint_dir='checkpoints',
                            results_dir='results_20241118_171451',
                            experiment_name='d2500_B2500_R35_20241118_171451'),
 'metrics': {'batch_time': [1.220353603363037,
                            0.5143530368804932,
                    