# Flattened Autoencoder for Particle Physics Anomaly Detection

This notebook implements and trains a simple flattened (fully connected) autoencoder for anomaly detection in particle physics events. The model processes events as flattened vectors of all point coordinates and energies.

## Key Features:
- **Input**: Flattened event data - all [x, y, z, energy] values concatenated into a single vector
- **Architecture**: Simple fully connected encoder-decoder with SiLU activations
- **Padding**: Fixed-size input with zero padding for smaller events
- **Anomaly Detection**: Reconstruction loss-based approach
- **Simplicity**: Straightforward approach that treats each event as a high-dimensional vector

## Comparison with CNN Approach:
- **Pros**: Simpler architecture, fewer hyperparameters, faster training for small datasets
- **Cons**: No spatial structure awareness, higher memory usage for variable-size events, less scalable

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import glob
import os
from tqdm import tqdm
import time
from multiprocessing import Pool, cpu_count

# Pick the compute device: the GPU when CUDA is usable, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Seed PyTorch and NumPy so repeated runs draw the same random numbers.
torch.manual_seed(42)
np.random.seed(42)

# Plot appearance: matplotlib's default style with seaborn's "husl" palette.
plt.style.use('default')
sns.set_palette("husl")

In [None]:
# Configuration
# Single settings dictionary read by the loading, model and training cells.
CONFIG = {
    'data_path': '/nevis/houston/home/sc5303/anomaly/offline_anomaly/anomalous_showers/pi0_decay/pi0_tree_output_*.npy',
    'max_points': 1000,  # largest event (in points) kept; also the padded length
    'min_points': 10,    # smallest event (in points) kept
    'max_files': 5,      # cap on the number of input files (falsy = no cap)
    'batch_size': 32,    # events per mini-batch
    'learning_rate': 1e-3,
    'num_epochs': 100,
    'latent_dim': 20,    # width of the bottleneck
    'hidden_dim': 50,    # width of the single hidden layer
    'test_split': 0.2,
    'val_split': 0.1,
    'num_workers': 4,
    'feature_dim': 4,    # values per point: [x, y, z, energy]
    'normalization': 'tanh',  # one of 'standard', 'minmax', 'tanh' ([-1, 1] range)
    'patience': 15,      # epochs without improvement before early stopping
}

# Echo the settings (printed before the derived 'input_dim' entry is added).
print("Configuration:")
print("\n".join(f"  {key}: {value}" for key, value in CONFIG.items()))

# Length of one flattened event vector.
CONFIG['input_dim'] = CONFIG['max_points'] * CONFIG['feature_dim']
print(f"\nCalculated input dimension: {CONFIG['input_dim']}")

# Collect the matching files in sorted order; a falsy limit slices with None,
# which keeps the whole list.
file_limit = CONFIG['max_files'] or None
data_files = sorted(glob.glob(CONFIG['data_path']))[:file_limit]

print(f"\nFound {len(data_files)} data files:")
for path in data_files:
    size_mb = os.path.getsize(path) / (1024**2)  # bytes -> MB
    print(f"  {os.path.basename(path)} ({size_mb:.1f} MB)")

In [None]:
# Data Loading and Flattening Functions

def load_npy_file(file_path):
    """Load one NPY file and split it into per-event point arrays.

    The file is expected to hold a 2-D array with five columns,
    ``[event, x, y, z, energy]``. Rows are grouped by the ``event`` column and
    a group is kept when its point count lies within
    ``CONFIG['min_points']`` .. ``CONFIG['max_points']`` (inclusive). The
    points of each kept event are ordered by decreasing energy.

    Args:
        file_path: path of the ``.npy`` file to read.

    Returns:
        List of arrays of shape (n_points, 4) holding ``[x, y, z, energy]``,
        in order of first appearance of each event id. An empty list is
        returned when the array layout is unexpected or when loading fails
        (the problem is printed rather than raised, so one bad file does not
        stop a multi-file load).
    """
    try:
        data = np.load(file_path)
        print(f"Loading {os.path.basename(file_path)}: shape {data.shape}")

        # Reject anything that is not a (rows, 5) table before indexing columns.
        if data.ndim != 2:
            print(f"  Warning: Expected a 2D array with 5 columns, got shape {data.shape}")
            return []
        if data.shape[1] != 5:  # [event, x, y, z, energy]
            print(f"  Warning: Expected 5 columns, got {data.shape[1]}")
            return []

        df = pd.DataFrame(data, columns=['event', 'x', 'y', 'z', 'energy'])
        min_points = CONFIG['min_points']
        max_points = CONFIG['max_points']
        events = []

        # A single groupby pass replaces re-filtering the whole table once per
        # event id. sort=False keeps events in order of first appearance.
        # Rows whose event id is NaN are dropped by groupby.
        for _, group in df.groupby('event', sort=False):
            event_data = group[['x', 'y', 'z', 'energy']].values

            # Filter by size
            if min_points <= len(event_data) <= max_points:
                # Highest energy first; a stable sort keeps the file order of
                # points with equal energy, so the layout is reproducible.
                order = np.argsort(-event_data[:, 3], kind='stable')
                events.append(event_data[order])

        print(f"  Loaded {len(events)} valid events")
        return events

    except Exception as e:
        # Deliberately best-effort: report the failure and contribute no events.
        print(f"Error loading {file_path}: {str(e)}")
        return []

def flatten_event(event, max_points, feature_dim=4):
    """Flatten an event into a fixed-size, zero-padded vector.

    Args:
        event: array-like of shape (n_points, feature_dim); with the default
            feature_dim each row is [x, y, z, energy]. An empty event
            (e.g. ``[]``) is accepted and yields an all-zero vector.
        max_points: number of point slots in the output. Events with more
            points are truncated to the first ``max_points`` rows.
        feature_dim: number of values per point (default 4).

    Returns:
        Tuple ``(flattened, actual_points)`` where ``flattened`` is a float
        array of shape (max_points * feature_dim,) and ``actual_points`` is
        the number of rows copied from ``event`` (the rest is zero padding).
    """
    event = np.asarray(event)

    # Start from all zeros so unused slots are already padding.
    padded_event = np.zeros((max_points, feature_dim))

    # Copy actual data (truncate if necessary). The guard skips the assignment
    # for empty events, whose 1-D shape (0,) cannot be broadcast into rows.
    actual_points = min(len(event), max_points)
    if actual_points > 0:
        padded_event[:actual_points] = event[:actual_points]

    return padded_event.flatten(), actual_points

def load_all_data(file_paths):
    """Read every file, flatten the collected events and plot size statistics.

    Args:
        file_paths: iterable of NPY file paths, each passed to load_npy_file.

    Returns:
        Tuple ``(all_flattened, all_lengths)``. When at least one event was
        loaded these are NumPy arrays: the flattened event vectors (one row
        per event) and the number of points used per event. When nothing was
        loaded both are empty lists.
    """
    print("Loading all data files...")
    all_events = []
    for file_path in tqdm(file_paths, desc="Loading files"):
        all_events.extend(load_npy_file(file_path))

    print(f"\nTotal events loaded: {len(all_events)}")

    # Nothing to flatten or plot.
    if not all_events:
        return [], []

    # Point counts of the events as loaded, before any padding.
    event_sizes = list(map(len, all_events))
    print(f"Event size statistics:")
    print(f"  Min: {min(event_sizes)} points")
    print(f"  Max: {max(event_sizes)} points")
    print(f"  Mean: {np.mean(event_sizes):.1f} points")
    print(f"  Median: {np.median(event_sizes):.1f} points")

    # Pad each event to CONFIG['max_points'] slots and flatten it.
    print("\nFlattening events...")
    flattened_pairs = [
        flatten_event(event, CONFIG['max_points'])
        for event in tqdm(all_events, desc="Flattening")
    ]
    vectors, used_counts = zip(*flattened_pairs)
    all_flattened = np.array(vectors)
    all_lengths = np.array(used_counts)

    print(f"Flattened data shape: {all_flattened.shape}")
    print(f"Memory usage: {all_flattened.nbytes / (1024**2):.2f} MB")

    # Share of each flattened vector that is zero padding, in percent.
    padding_ratios = (CONFIG['max_points'] - all_lengths) / CONFIG['max_points'] * 100

    def draw_panel(position, values, xlabel, title, **hist_extra):
        """Draw one histogram panel of the 1x3 summary figure."""
        plt.subplot(1, 3, position)
        plt.hist(values, bins=30, alpha=0.7, edgecolor='black', **hist_extra)
        plt.xlabel(xlabel)
        plt.ylabel('Frequency')
        plt.title(title)
        plt.grid(True, alpha=0.3)

    plt.figure(figsize=(15, 5))
    draw_panel(1, event_sizes,
               'Number of Points per Event', 'Original Event Size Distribution')
    draw_panel(2, all_lengths,
               'Actual Points Used (after truncation)', 'Points Used in Flattened Version',
               color='orange')
    draw_panel(3, padding_ratios,
               'Padding Percentage (%)', 'Zero Padding Analysis',
               color='red')
    plt.tight_layout()
    plt.show()

    print(f"Average padding: {padding_ratios.mean():.1f}%")

    return all_flattened, all_lengths

# Load data
# Runs the loader over the files selected in the configuration cell and reports
# the elapsed wall-clock time. `flattened_data` and `event_lengths` are the
# notebook-level results that the following cells read.
print("Starting data loading...")
start_time = time.time()
flattened_data, event_lengths = load_all_data(data_files)
load_time = time.time() - start_time
print(f"Data loading completed in {load_time:.2f} seconds")

In [None]:
# Dataset Class and Normalization

class FlattenedEventDataset(Dataset):
    """Dataset of flattened, zero-padded events with optional value scaling.

    A single one-column scaler is shared by every value of every event.
    Entries equal to exactly 0 are treated as padding: they are excluded when
    fitting the scaler and are left at 0 when normalizing. Note that a
    genuine measurement of exactly 0 is therefore also left unscaled.

    Args:
        flattened_data: array of shape (n_events, input_dim).
        event_lengths: number of real points per event, one entry per event.
        scaler: already-fitted scaler to reuse (e.g. the training set's); any
            object with a ``transform`` method taking an (n, 1) array.
        fit_scaler: when True and no scaler is given, create and fit one on
            the non-zero values of ``flattened_data``.
        normalization: 'standard', 'minmax' or 'tanh' (min-max to [-1, 1]);
            only used when a scaler is fitted here.

    Raises:
        ValueError: if a scaler must be created and ``normalization`` is not
            one of the supported names.

    Each item is ``(data, length)``: a float32 tensor of shape (input_dim,)
    and an int64 scalar tensor.
    """

    def __init__(self, flattened_data, event_lengths, scaler=None, fit_scaler=True, normalization='tanh'):
        self.flattened_data = flattened_data
        self.event_lengths = event_lengths
        self.scaler = scaler
        self.normalization = normalization

        raw = np.asarray(flattened_data)

        if self.scaler is None and fit_scaler:
            if normalization == 'standard':
                self.scaler = StandardScaler()
            elif normalization == 'minmax':
                self.scaler = MinMaxScaler()
            elif normalization == 'tanh':
                # For tanh activation, normalize to [-1, 1] range
                self.scaler = MinMaxScaler(feature_range=(-1, 1))
            else:
                raise ValueError("normalization must be 'standard', 'minmax', or 'tanh'")

            # Fit on non-zero values only so padding does not skew the statistics.
            non_zero_mask = raw != 0
            if non_zero_mask.any():
                non_zero_data = raw[non_zero_mask].reshape(-1, 1)
                self.scaler.fit(non_zero_data)
                print(f"Fitted {normalization} scaler on {len(non_zero_data)} non-zero values")
                if hasattr(self.scaler, 'mean_'):
                    print(f"Scaler mean: {self.scaler.mean_[0]:.6f}")
                    print(f"Scaler scale: {self.scaler.scale_[0]:.6f}")
                elif hasattr(self.scaler, 'data_min_'):
                    print(f"Scaler range: [{self.scaler.data_min_[0]:.6f}, {self.scaler.data_max_[0]:.6f}]")

        if self.scaler is not None:
            # Work on a copy so the caller's array is never modified. Integer
            # input is promoted to float so scaled values are not truncated.
            normalized = np.array(raw)
            if not np.issubdtype(normalized.dtype, np.floating):
                normalized = normalized.astype(np.float64)

            # One vectorised transform over every non-zero entry; the scaler
            # has a single column, so this equals transforming row by row.
            non_zero_mask = normalized != 0
            if non_zero_mask.any():
                scaled = self.scaler.transform(normalized[non_zero_mask].reshape(-1, 1))
                normalized[non_zero_mask] = np.asarray(scaled).ravel()

            self.normalized_data = normalized
        else:
            self.normalized_data = flattened_data

        # Convert to tensors (copies; float32 data, int64 lengths).
        self.data_tensor = torch.tensor(np.asarray(self.normalized_data), dtype=torch.float32)
        self.lengths_tensor = torch.tensor(np.asarray(event_lengths), dtype=torch.long)

    def __len__(self):
        """Return the number of events."""
        return len(self.data_tensor)

    def __getitem__(self, idx):
        """Return ``(normalized event vector, number of real points)``."""
        return self.data_tensor[idx], self.lengths_tensor[idx]

# Create dataset and split
# Splits the flattened events into train/validation/test, normalises every
# split with the scaler fitted on the training split, and wraps them in loaders.
if len(flattened_data) == 0:
    print("No events loaded - cannot proceed with training")
else:
    print("\nCreating dataset...")

    # First carve off the combined validation+test share, then divide that
    # share between validation and test.
    holdout_fraction = CONFIG['test_split'] + CONFIG['val_split']
    train_data, temp_data, train_lengths, temp_lengths = train_test_split(
        flattened_data, event_lengths,
        test_size=holdout_fraction,
        random_state=42,
    )
    val_data, test_data, val_lengths, test_lengths = train_test_split(
        temp_data, temp_lengths,
        test_size=CONFIG['test_split'] / holdout_fraction,
        random_state=42,
    )

    print(f"Dataset splits:")
    print(f"  Train: {len(train_data)} events")
    print(f"  Validation: {len(val_data)} events")
    print(f"  Test: {len(test_data)} events")

    # The training dataset fits the scaler; the other two reuse it unchanged.
    train_dataset = FlattenedEventDataset(
        train_data, train_lengths,
        normalization=CONFIG['normalization'], fit_scaler=True,
    )
    reuse_scaler = dict(
        scaler=train_dataset.scaler,
        fit_scaler=False,
        normalization=CONFIG['normalization'],
    )
    val_dataset = FlattenedEventDataset(val_data, val_lengths, **reuse_scaler)
    test_dataset = FlattenedEventDataset(test_data, test_lengths, **reuse_scaler)

    # Only the training loader shuffles.
    loader_options = dict(batch_size=CONFIG['batch_size'], num_workers=CONFIG['num_workers'])
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

    print(f"Data loaders created successfully")

    # Peek at one training batch as a sanity check.
    sample_batch = next(iter(train_loader))
    sample_data, sample_lengths = sample_batch
    print(f"Sample batch data shape: {sample_data.shape}")
    print(f"Sample batch lengths: {sample_lengths[:5].tolist()}")
    print(f"Sample batch data range: [{sample_data.min():.3f}, {sample_data.max():.3f}]")
    print(f"Non-zero ratio in sample: {(sample_data != 0).float().mean():.3f}")

In [None]:
# Flattened Autoencoder Model

class FlattenedAutoencoder(nn.Module):
    """Fully connected autoencoder over flattened events.

    Layout: input_dim -> hidden_dim -> latent_dim -> hidden_dim -> input_dim,
    with SiLU after every layer except the last, which ends in Tanh so the
    output lies in [-1, 1].

    Input: (batch_size, input_dim)
    Output: (batch_size, input_dim) reconstruction of the input
    """

    def __init__(self, input_dim=4000, hidden_dim=50, latent_dim=20):
        super().__init__()

        self.input_dim, self.hidden_dim, self.latent_dim = input_dim, hidden_dim, latent_dim

        # Layers are created encoder-first so parameter initialisation order
        # (and therefore seeded weights) is unchanged.
        compress = [
            nn.Linear(input_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, latent_dim),
            nn.SiLU(),
        ]
        expand = [
            nn.Linear(latent_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Tanh(),  # bounded output to match inputs normalised to [-1, 1]
        ]
        self.encoder = nn.Sequential(*compress)
        self.decoder = nn.Sequential(*expand)

    def encode(self, x):
        """Map an input batch to its latent representation."""
        latent = self.encoder(x)
        return latent

    def decode(self, z):
        """Map a latent batch back to input space."""
        reconstruction = self.decoder(z)
        return reconstruction

    def forward(self, x):
        """Reconstruct ``x`` by encoding and then decoding it."""
        return self.decode(self.encode(x))

# Alternative model with more layers for comparison
class DeepFlattenedAutoencoder(nn.Module):
    """Deeper fully connected autoencoder over flattened events.

    Layout: input_dim -> 512 -> 128 -> 50 -> latent_dim -> 50 -> 128 -> 512
    -> input_dim, with SiLU activations, Dropout(0.1) after the 512- and
    128-wide layers, and a final Tanh so the output lies in [-1, 1].

    Exposes the same ``encode`` / ``decode`` / ``forward`` methods as
    FlattenedAutoencoder so the two models are interchangeable.

    Input: (batch_size, input_dim)
    Output: (batch_size, input_dim) reconstruction of the input
    """

    def __init__(self, input_dim=4000, latent_dim=20):
        super().__init__()

        self.input_dim = input_dim
        self.latent_dim = latent_dim

        # Encoder with gradual dimension reduction
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.SiLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 128),
            nn.SiLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 50),
            nn.SiLU(),
            nn.Linear(50, latent_dim),
            nn.SiLU()
        )

        # Decoder with gradual dimension expansion
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 50),
            nn.SiLU(),
            nn.Linear(50, 128),
            nn.SiLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 512),
            nn.SiLU(),
            nn.Dropout(0.1),
            nn.Linear(512, input_dim),
            nn.Tanh()
        )

    def encode(self, x):
        """Encode input to latent space"""
        return self.encoder(x)

    def decode(self, z):
        """Decode latent representation back to input space"""
        return self.decoder(z)

    def forward(self, x):
        """Full forward pass: encode then decode"""
        return self.decode(self.encode(x))

# Create model
# Builds the selected autoencoder, reports its size and runs a small forward
# pass as a sanity check.
if len(flattened_data) > 0:
    # Choose model type
    use_deep_model = False  # Set to True to use deeper model

    if use_deep_model:
        model = DeepFlattenedAutoencoder(
            input_dim=CONFIG['input_dim'],
            latent_dim=CONFIG['latent_dim']
        ).to(device)
        model_name = "Deep Flattened Autoencoder"
        # Layer widths of the deep model are fixed inside the class.
        architecture = (
            f"{CONFIG['input_dim']} -> 512 -> 128 -> 50 -> {CONFIG['latent_dim']}"
            f" -> 50 -> 128 -> 512 -> {CONFIG['input_dim']}"
        )
    else:
        model = FlattenedAutoencoder(
            input_dim=CONFIG['input_dim'],
            hidden_dim=CONFIG['hidden_dim'],
            latent_dim=CONFIG['latent_dim']
        ).to(device)
        model_name = "Simple Flattened Autoencoder"
        architecture = (
            f"{CONFIG['input_dim']} -> {CONFIG['hidden_dim']} -> {CONFIG['latent_dim']}"
            f" -> {CONFIG['hidden_dim']} -> {CONFIG['input_dim']}"
        )

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"{model_name} created successfully!")
    print(f"Architecture: {architecture}")
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    # Size estimate assuming 4 bytes (float32) per parameter.
    print(f"Model size: {total_params * 4 / (1024**2):.2f} MB")

    # Test forward pass. Run it in eval mode so dropout (deep model) does not
    # randomise the printed ranges, then restore the previous mode.
    was_training = model.training
    model.eval()
    with torch.no_grad():
        sample_input, sample_lengths = next(iter(train_loader))
        sample_input = sample_input[:2].to(device)  # Take 2 samples
        sample_output = model(sample_input)

        print(f"\nTest forward pass:")
        print(f"  Input shape: {sample_input.shape}")
        print(f"  Output shape: {sample_output.shape}")
        print(f"  Input range: [{sample_input.min():.3f}, {sample_input.max():.3f}]")
        print(f"  Output range: [{sample_output.min():.3f}, {sample_output.max():.3f}]")

        # Check reconstruction of padded regions
        input_np = sample_input[0].cpu().numpy()
        output_np = sample_output[0].cpu().numpy()

        zero_positions = input_np == 0
        print(f"  Zero positions in input: {zero_positions.sum()}")
        # min()/max() of an empty selection would raise, so only report the
        # range when the sample actually contains zeros.
        if zero_positions.any():
            print(f"  Output values at zero positions range: [{output_np[zero_positions].min():.4f}, {output_np[zero_positions].max():.4f}]")
        else:
            print("  No zero positions in this sample")
    model.train(was_training)

else:
    print("No data available - cannot create model")

In [None]:
# Training Functions

def masked_loss(output, target, lengths, loss_fn=F.mse_loss):
    """Average an element-wise loss over the non-padded part of each event.

    For sample ``i`` only the first ``lengths[i] * CONFIG['feature_dim']``
    entries count; everything after that is padding and is ignored. The
    result is the sum of the loss over all counted entries in the batch
    divided by the number of counted entries.

    Args:
        output, target: tensors of shape (batch_size, input_dim).
        lengths: number of real points per event, shape (batch_size,).
        loss_fn: element-wise loss accepting ``reduction='none'``
            (e.g. ``F.mse_loss``, ``F.l1_loss``).

    Returns:
        Scalar tensor on ``output``'s device. When no entry is counted the
        value is 0 but still connected to ``output``, so ``backward()`` works
        and yields zero gradients.
    """
    input_dim = output.size(1)

    # Number of real entries per sample, limited to the vector length.
    lengths = torch.as_tensor(lengths, device=output.device)
    valid_counts = (lengths * CONFIG['feature_dim']).clamp(min=0, max=input_dim)

    # mask[i, j] is True where entry j of sample i holds real data. Building it
    # once avoids a Python loop with one host sync per sample.
    positions = torch.arange(input_dim, device=output.device)
    mask = positions.unsqueeze(0) < valid_counts.unsqueeze(1)

    per_element = loss_fn(output, target, reduction='none')
    # masked_fill (rather than multiplying by the mask) keeps NaN/inf values in
    # padded positions from leaking into the sum.
    summed = per_element.masked_fill(~mask, 0.0).sum()

    # clamp avoids 0/0 for a batch without any real entries.
    return summed / mask.sum().clamp(min=1)

def train_epoch(model, dataloader, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Each batch is reconstructed, scored with masked_loss against itself,
    back-propagated with the gradient norm clipped to 1.0, and applied with
    ``optimizer``.

    Returns:
        Mean of the per-batch loss values.
    """
    model.train()
    batch_losses = []

    progress_bar = tqdm(dataloader, desc="Training")
    for data, lengths in progress_bar:
        data, lengths = data.to(device), lengths.to(device)

        optimizer.zero_grad()

        # Forward pass: the input is its own reconstruction target.
        loss = masked_loss(model(data), data, lengths)
        loss.backward()

        # Clip before stepping to guard against exploding gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        loss_value = loss.item()
        batch_losses.append(loss_value)
        progress_bar.set_postfix({'Loss': f'{loss_value:.6f}'})

    return sum(batch_losses) / len(batch_losses)

def validate_epoch(model, dataloader, device):
    """Score ``model`` on ``dataloader`` without updating it.

    Runs in eval mode with gradients disabled and evaluates masked_loss of
    each batch's reconstruction against the batch itself.

    Returns:
        Mean of the per-batch loss values.
    """
    model.eval()
    batch_losses = []

    with torch.no_grad():
        for data, lengths in tqdm(dataloader, desc="Validation"):
            data, lengths = data.to(device), lengths.to(device)
            batch_loss = masked_loss(model(data), data, lengths)
            batch_losses.append(batch_loss.item())

    return sum(batch_losses) / len(batch_losses)

# Training setup
# Creates the optimizer, the LR scheduler and the loss-history lists used by
# the training loop.
if len(flattened_data) > 0:
    # Plain MSE criterion; the epoch functions compute their loss through
    # masked_loss instead, so this object is only here for ad-hoc use.
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=CONFIG['learning_rate'], weight_decay=1e-5)
    # Multiply the LR by 0.7 after 7 epochs without validation improvement.
    # The `verbose` argument is deprecated in PyTorch and rejected by newer
    # releases, so it is not passed; the training loop prints the LR each epoch.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.7, patience=7
    )

    # Training history
    train_losses = []
    val_losses = []

    print("Training setup complete!")
    print(f"Optimizer: Adam with lr={CONFIG['learning_rate']}")
    print(f"Scheduler: ReduceLROnPlateau")
    print(f"Loss function: Masked MSE (ignores padded regions)")
    print(f"Training for up to {CONFIG['num_epochs']} epochs")
    print(f"Early stopping patience: {CONFIG['patience']} epochs")
else:
    print("No data available - cannot setup training")

In [None]:
# Training Loop

# Trains with early stopping, checkpoints the best model by validation loss,
# then plots the loss curves and the learning-rate schedule.
if len(flattened_data) > 0:
    print("Starting training...")
    best_val_loss = float('inf')
    patience_counter = 0
    # Learning rate in effect after each epoch's scheduler step, recorded as
    # training runs so the schedule plot shows the real values.
    lr_history = []

    start_time = time.time()

    for epoch in range(CONFIG['num_epochs']):
        print(f"\nEpoch {epoch + 1}/{CONFIG['num_epochs']}")

        # Training
        train_loss = train_epoch(model, train_loader, optimizer, device)
        train_losses.append(train_loss)

        # Validation
        val_loss = validate_epoch(model, val_loader, device)
        val_losses.append(val_loss)

        # Learning rate scheduling (driven by the validation loss)
        scheduler.step(val_loss)

        # Print results
        current_lr = optimizer.param_groups[0]['lr']
        lr_history.append(current_lr)
        print(f"Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f} | LR: {current_lr:.2e}")

        # Early stopping bookkeeping: checkpoint on improvement, otherwise
        # count the epoch against the patience budget.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # Save best model. Besides the weights the checkpoint carries the
            # config and the fitted scaler, so it is a pickled Python object
            # rather than a pure tensor file.
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'config': CONFIG,
                'scaler': train_dataset.scaler,
                # Record the class actually trained (simple or deep model).
                'model_type': type(model).__name__,
                'architecture': {
                    'input_dim': CONFIG['input_dim'],
                    'hidden_dim': CONFIG['hidden_dim'],
                    'latent_dim': CONFIG['latent_dim']
                }
            }, 'best_flat_autoencoder.pth')
            print(f"New best model saved! Val Loss: {val_loss:.6f}")
        else:
            patience_counter += 1

        # Early stopping check
        if patience_counter >= CONFIG['patience']:
            print(f"Early stopping triggered after {CONFIG['patience']} epochs without improvement!")
            break

    total_time = time.time() - start_time
    print(f"\nTraining completed in {total_time:.2f} seconds ({total_time/60:.1f} minutes)")
    print(f"Best validation loss: {best_val_loss:.6f}")
    print(f"Epochs trained: {len(train_losses)}")

    # Plot training history
    plt.figure(figsize=(15, 5))

    plt.subplot(1, 3, 1)
    plt.plot(train_losses, label='Train Loss', color='blue', linewidth=2)
    plt.plot(val_losses, label='Validation Loss', color='red', linewidth=2)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training History')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 3, 2)
    plt.semilogy(train_losses, label='Train Loss', color='blue', linewidth=2)
    plt.semilogy(val_losses, label='Validation Loss', color='red', linewidth=2)
    plt.xlabel('Epoch')
    plt.ylabel('Loss (log scale)')
    plt.title('Training History (Log Scale)')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Learning rate actually used in each epoch of this run.
    plt.subplot(1, 3, 3)
    plt.plot(lr_history, color='green', linewidth=2)
    plt.xlabel('Epoch')
    plt.ylabel('Learning Rate')
    plt.title('Learning Rate Schedule')
    plt.yscale('log')
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Training summary
    print(f"\n" + "="*50)
    print("TRAINING SUMMARY")
    print("="*50)
    print(f"Model: {model_name}")
    print(f"Input dimension: {CONFIG['input_dim']:,}")
    print(f"Latent dimension: {CONFIG['latent_dim']}")
    print(f"Total parameters: {total_params:,}")
    print(f"Training time: {total_time:.1f}s")
    print(f"Final train loss: {train_losses[-1]:.6f}")
    print(f"Best validation loss: {best_val_loss:.6f}")
    print(f"Normalization: {CONFIG['normalization']}")

else:
    print("No data available - cannot train model")

In [None]:
# Model Evaluation and Visualization

def evaluate_model(model, dataloader, device, scaler=None):
    """Compute one reconstruction loss per event and keep a few examples.

    The loss of an event is the MSE between input and reconstruction over its
    first ``length * CONFIG['feature_dim']`` entries (the non-padded part).
    Events with no real entries are skipped. ``scaler`` is accepted but not
    used here.

    Returns:
        Tuple ``(losses, originals, reconstructions, lengths)``: a NumPy array
        of per-event losses, followed by three lists describing the first six
        scored events (full normalised input vector, full reconstruction, and
        number of points).
    """
    model.eval()
    all_losses = []
    sample_originals, sample_reconstructions, sample_lengths = [], [], []

    with torch.no_grad():
        for data, lengths in tqdm(dataloader, desc="Evaluating"):
            data = data.to(device)
            lengths = lengths.to(device)
            reconstructed = model(data)

            # Walk the batch row by row alongside each event's point count.
            for orig_row, recon_row, n_points in zip(data, reconstructed, lengths.tolist()):
                actual_elements = n_points * CONFIG['feature_dim']
                if actual_elements <= 0:
                    continue

                event_loss = F.mse_loss(
                    recon_row[:actual_elements], orig_row[:actual_elements]
                )
                all_losses.append(event_loss.item())

                # Keep the first few events for the visualisation code.
                if len(sample_originals) < 6:
                    sample_originals.append(orig_row.cpu().numpy())
                    sample_reconstructions.append(recon_row.cpu().numpy())
                    sample_lengths.append(n_points)

    return np.array(all_losses), sample_originals, sample_reconstructions, sample_lengths

def unflatten_event(flattened_data, actual_length, max_points, scaler=None):
    """Turn a flattened vector back into an (actual_length, feature_dim) event.

    The vector is reshaped to (max_points, CONFIG['feature_dim']) and the
    padding rows beyond ``actual_length`` are dropped. When ``scaler`` is
    given, every non-zero value is mapped back through
    ``scaler.inverse_transform``; values equal to exactly 0 are left as they are.
    """
    event_rows = flattened_data.reshape(max_points, CONFIG['feature_dim'])[:actual_length]

    # No scaler: return the normalised values untouched.
    if scaler is None:
        return event_rows

    values = event_rows.flatten()  # flatten() copies, so the input stays intact
    present = values != 0
    if not present.any():
        return event_rows

    restored = scaler.inverse_transform(values[present].reshape(-1, 1))
    values[present] = restored.flatten()
    return values.reshape(-1, CONFIG['feature_dim'])

# Evaluates the best checkpoint on the test set and visualises the
# reconstruction-loss distribution and a few example reconstructions.
if len(flattened_data) > 0 and 'model' in locals():
    print("Evaluating model on test set...")

    # Load best model. The checkpoint stores CONFIG and the fitted scaler next
    # to the weights, so it needs full unpickling (weights_only=False); only
    # load checkpoint files written by this notebook. map_location lets a
    # GPU-trained checkpoint load on a CPU-only machine.
    try:
        checkpoint = torch.load(
            'best_flat_autoencoder.pth', map_location=device, weights_only=False
        )
        model.load_state_dict(checkpoint['model_state_dict'])
        print("Loaded best model checkpoint")
    except FileNotFoundError:
        print("Using current model (no checkpoint found)")
    except Exception as exc:
        # Still best-effort, but say why the checkpoint was not used instead
        # of silently evaluating different weights.
        print(f"Using current model (could not load checkpoint: {exc})")

    # Evaluate on test set
    test_losses, test_originals, test_reconstructions, test_sample_lengths = evaluate_model(
        model, test_loader, device, scaler=train_dataset.scaler
    )

    print(f"\nTest Set Evaluation:")
    print(f"Number of test samples: {len(test_losses)}")
    print(f"Mean reconstruction loss: {test_losses.mean():.6f}")
    print(f"Std reconstruction loss: {test_losses.std():.6f}")
    print(f"Min reconstruction loss: {test_losses.min():.6f}")
    print(f"Max reconstruction loss: {test_losses.max():.6f}")

    # Plot reconstruction loss distribution
    plt.figure(figsize=(15, 5))

    plt.subplot(1, 3, 1)
    plt.hist(test_losses, bins=50, alpha=0.7, edgecolor='black')
    plt.xlabel('Reconstruction Loss')
    plt.ylabel('Frequency')
    plt.title('Test Reconstruction Loss Distribution')
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 3, 2)
    plt.hist(test_losses, bins=50, alpha=0.7, edgecolor='black')
    plt.xlabel('Reconstruction Loss')
    plt.ylabel('Frequency')
    plt.title('Test Loss Distribution (Log Scale)')
    plt.yscale('log')
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 3, 3)
    plt.boxplot(test_losses)
    plt.ylabel('Reconstruction Loss')
    plt.title('Test Loss Box Plot')
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Visualize reconstructions
    if test_originals and test_reconstructions:
        print(f"\nVisualizing {len(test_reconstructions)} sample reconstructions...")

        features = ['X', 'Y', 'Z', 'Energy']
        n_examples = min(3, len(test_reconstructions))

        # Unflatten and denormalise each example once; the three views below
        # all reuse these (original event, reconstructed event, length) triples.
        examples = []
        for i in range(n_examples):
            length = test_sample_lengths[i]
            orig_event = unflatten_event(
                test_originals[i], length, CONFIG['max_points'], train_dataset.scaler
            )
            recon_event = unflatten_event(
                test_reconstructions[i], length, CONFIG['max_points'], train_dataset.scaler
            )
            examples.append((orig_event, recon_event, length))

        # Per-feature line plots: one row per example, one column per feature.
        fig, axes = plt.subplots(n_examples, 4, figsize=(16, 4 * n_examples))

        # With a single row, subplots returns a 1-D array; make it 2-D.
        if n_examples == 1:
            axes = axes.reshape(1, -1)

        for i, (orig_event, recon_event, length) in enumerate(examples):
            for j in range(4):
                axes[i, j].plot(orig_event[:, j], 'b-', label='Original', alpha=0.8, linewidth=2, marker='o', markersize=3)
                axes[i, j].plot(recon_event[:, j], 'r--', label='Reconstructed', alpha=0.8, linewidth=2, marker='^', markersize=3)
                axes[i, j].set_title(f'Sample {i+1}: {features[j]} ({length} points)')
                axes[i, j].set_xlabel('Point Index (sorted by energy)')
                axes[i, j].set_ylabel(features[j])
                axes[i, j].legend()
                axes[i, j].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

        # 3D scatter plots for spatial reconstruction (colour = energy)
        fig = plt.figure(figsize=(15, 5))

        for i, (orig_event, recon_event, length) in enumerate(examples):
            ax = fig.add_subplot(1, 3, i+1, projection='3d')

            # Plot original points
            ax.scatter(orig_event[:, 0], orig_event[:, 1], orig_event[:, 2],
                      c=orig_event[:, 3], cmap='viridis', alpha=0.7, s=30, label='Original')

            # Plot reconstructed points
            ax.scatter(recon_event[:, 0], recon_event[:, 1], recon_event[:, 2],
                      c=recon_event[:, 3], cmap='plasma', alpha=0.7, s=20, marker='^', label='Reconstructed')

            ax.set_xlabel('X')
            ax.set_ylabel('Y')
            ax.set_zlabel('Z')
            ax.set_title(f'3D Reconstruction Sample {i+1}\n({length} points)')
            ax.legend()

        plt.tight_layout()
        plt.show()

        # Reconstruction quality analysis: per-feature MSE in original units
        print(f"\nReconstruction Quality Analysis:")
        for i, (orig_event, recon_event, length) in enumerate(examples):
            feature_errors = np.mean((orig_event - recon_event)**2, axis=0)
            print(f"Sample {i+1} ({length} points):")
            for j, feature in enumerate(features):
                print(f"  {feature} MSE: {feature_errors[j]:.6f}")

else:
    print("No trained model available for evaluation")

In [None]:
# Anomaly Detection and Model Comparison

def detect_anomalies(losses, threshold_percentile=95):
    """Flag samples whose reconstruction loss exceeds a percentile cut-off.

    Args:
        losses: Per-sample reconstruction losses, in a form accepted by
            ``np.percentile`` and by an elementwise ``>`` comparison.
        threshold_percentile: Percentile of ``losses`` used as the cut-off.

    Returns:
        A ``(anomaly_mask, threshold)`` tuple: the elementwise result of
        ``losses > threshold`` and the cut-off value itself.
    """
    cutoff = np.percentile(losses, threshold_percentile)
    # Strict comparison: a sample sitting exactly on the cut-off is not flagged
    return losses > cutoff, cutoff

# Percentile-threshold anomaly analysis on the per-sample test losses.
# Relies on `flattened_data`, `test_losses`, `CONFIG` (and optionally
# `total_params`) created by earlier cells; `test_losses` is used as a NumPy
# array below (`.mean()` / `.std()`).
if len(flattened_data) > 0 and 'test_losses' in locals():
    print("Performing anomaly detection analysis...")
    
    # Detect anomalies using different thresholds
    thresholds = [90, 95, 99]
    results = {}
    
    for thresh in thresholds:
        anomaly_mask, threshold_value = detect_anomalies(test_losses, thresh)
        n_anomalies = np.sum(anomaly_mask)
        
        results[thresh] = {
            'mask': anomaly_mask,
            'threshold': threshold_value,
            'n_anomalies': n_anomalies,
            'anomaly_rate': n_anomalies / len(test_losses) * 100
        }
        
        print(f"\nThreshold: {thresh}th percentile")
        print(f"  Threshold value: {threshold_value:.6f}")
        print(f"  Anomalies detected: {n_anomalies}/{len(test_losses)} ({results[thresh]['anomaly_rate']:.1f}%)")
    
    # Plot anomaly detection results
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    colors = ['orange', 'red', 'darkred']
    
    # Top row: the same loss histogram with threshold lines, drawn once with a
    # linear y-axis and once with a log y-axis
    histogram_panels = [
        (axes[0, 0], 'Reconstruction Loss Distribution with Thresholds', False),
        (axes[0, 1], 'Loss Distribution (Log Scale)', True),
    ]
    for panel, panel_title, use_log_scale in histogram_panels:
        panel.hist(test_losses, bins=50, alpha=0.7, edgecolor='black', density=True)
        for line_color, thresh in zip(colors, thresholds):
            panel.axvline(results[thresh]['threshold'],
                          color=line_color, linestyle='--', alpha=0.8,
                          label=f'{thresh}th percentile')
        panel.set_xlabel('Reconstruction Loss')
        panel.set_ylabel('Density')
        panel.set_title(panel_title)
        if use_log_scale:
            panel.set_yscale('log')
        panel.legend()
        panel.grid(True, alpha=0.3)
    
    # Anomaly rates
    percentiles = list(results.keys())
    rates = [results[p]['anomaly_rate'] for p in percentiles]
    
    axes[1, 0].bar([str(p) for p in percentiles], rates, alpha=0.7, color='lightcoral', edgecolor='black')
    axes[1, 0].set_xlabel('Threshold Percentile')
    axes[1, 0].set_ylabel('Anomaly Rate (%)')
    axes[1, 0].set_title('Anomaly Detection Rates')
    axes[1, 0].grid(True, alpha=0.3)
    
    # Flagged fraction as a function of the percentile used for the threshold
    thresh_range = np.arange(50, 100, 1)
    anomaly_rates = []
    
    for t in thresh_range:
        _, thresh_val = detect_anomalies(test_losses, t)
        anomaly_rate = np.sum(test_losses > thresh_val) / len(test_losses) * 100
        anomaly_rates.append(anomaly_rate)
    
    axes[1, 1].plot(thresh_range, anomaly_rates, 'b-', linewidth=2)
    axes[1, 1].set_xlabel('Threshold Percentile')
    axes[1, 1].set_ylabel('Anomaly Rate (%)')
    axes[1, 1].set_title('Anomaly Rate vs Threshold')
    axes[1, 1].grid(True, alpha=0.3)
    
    # Mark common thresholds
    thresh_range_list = list(thresh_range)
    for i, thresh in enumerate(thresholds):
        if thresh in thresh_range:
            idx = thresh_range_list.index(thresh)
            axes[1, 1].scatter(thresh, anomaly_rates[idx], color=colors[i], s=100, zorder=5)
            axes[1, 1].annotate(f'{thresh}%', 
                               (thresh, anomaly_rates[idx]), 
                               xytext=(5, 5), textcoords='offset points')
    
    plt.tight_layout()
    plt.show()
    
    # Model comparison summary
    print("\n" + "=" * 60)
    print("MODEL COMPARISON: FLATTENED vs CNN AUTOENCODER")
    print("=" * 60)
    
    comparison = {
        'Architecture': {
            'Flattened': f'FC: {CONFIG["input_dim"]} -> {CONFIG["hidden_dim"]} -> {CONFIG["latent_dim"]} -> {CONFIG["hidden_dim"]} -> {CONFIG["input_dim"]}',
            'CNN': '1D Conv layers with variable-length sequences'
        },
        'Input Format': {
            'Flattened': f'Fixed vector of {CONFIG["input_dim"]} elements (zero-padded)',
            'CNN': 'Variable-length sequences (dynamically padded)'
        },
        'Memory Usage': {
            'Flattened': f'{CONFIG["input_dim"] * 4 / 1024:.1f} KB per sample (fixed)',
            'CNN': 'Variable (depends on sequence length)'
        },
        'Complexity': {
            'Flattened': f'O({CONFIG["input_dim"]}) - linear in max event size',
            'CNN': 'O(n) - linear in actual event size'
        },
        'Spatial Awareness': {
            'Flattened': 'None - treats all coordinates independently',
            'CNN': 'Local patterns through convolutions'
        },
        'Parameters': {
            'Flattened': f'{total_params:,} parameters' if 'total_params' in locals() else 'N/A',
            'CNN': 'Typically more due to conv layers'
        },
        'Training Speed': {
            'Flattened': 'Fast - simple forward/backward passes',
            'CNN': 'Slower - convolution operations'
        },
        'Best Use Cases': {
            'Flattened': 'Small datasets, simple patterns, fast inference needed',
            'CNN': 'Large datasets, spatial patterns important, better scalability'
        }
    }
    
    for category, details in comparison.items():
        print(f"\n{category}:")
        # The loop variable must NOT be called `model`: at notebook top level it
        # would rebind the trained `model` object to a string, and the saving
        # cell's `model.state_dict()` would then fail.
        for model_name, description in details.items():
            print(f"  {model_name:10}: {description}")
    
    # Performance summary
    print("\n" + "=" * 50)
    print("FLATTENED AUTOENCODER PERFORMANCE SUMMARY")
    print("=" * 50)
    print(f"Test samples: {len(test_losses)}")
    print(f"Mean reconstruction loss: {test_losses.mean():.6f}")
    print(f"Std reconstruction loss: {test_losses.std():.6f}")
    print(f"\nRecommended threshold (95th percentile): {results[95]['threshold']:.6f}")
    print("Expected anomaly rate: ~5%")
    print(f"Actual anomaly rate: {results[95]['anomaly_rate']:.1f}%")
    
    # Advantages and disadvantages
    print("\nAdvantages of Flattened Autoencoder:")
    for advantage in (
        "✓ Simple architecture - easy to understand and debug",
        "✓ Fast training - fewer parameters than CNN",
        "✓ Deterministic input size - no dynamic padding complexity",
        "✓ Good baseline - establishes performance floor",
        "✓ Memory predictable - fixed size per sample",
    ):
        print(advantage)
    
    print("\nDisadvantages of Flattened Autoencoder:")
    for disadvantage in (
        "✗ No spatial structure awareness",
        "✗ Inefficient for variable-size events (lots of zero padding)",
        "✗ Doesn't scale well with max event size",
        "✗ May not capture complex spatial relationships",
        "✗ Large input dimension for big events",
    ):
        print(disadvantage)
    
else:
    print("No test results available for anomaly detection analysis")

In [None]:
# Model Saving and Usage Instructions

# Persist the trained model plus its metadata, write a standalone usage script,
# and print a final summary. Relies on `flattened_data`, `model`, `CONFIG` and
# `train_dataset` from earlier cells; every other variable is optional and is
# probed with `'name' in locals()` before use.
if len(flattened_data) > 0 and 'model' in locals():
    print("Saving final model and analysis results...")
    
    # Save final model state
    final_model_data = {
        'model_state_dict': model.state_dict(),
        'config': CONFIG,
        'scaler': train_dataset.scaler,
        'train_losses': train_losses if 'train_losses' in locals() else [],
        'val_losses': val_losses if 'val_losses' in locals() else [],
        # np.asarray makes this work whether test_losses is an array or a list
        'test_losses': np.asarray(test_losses).tolist() if 'test_losses' in locals() else [],
        'architecture': 'FlattenedAutoencoder',
        'model_params': {
            'input_dim': CONFIG['input_dim'],
            'hidden_dim': CONFIG['hidden_dim'],
            'latent_dim': CONFIG['latent_dim'],
            'total_parameters': total_params if 'total_params' in locals() else 0
        },
        'data_stats': {
            'max_points': CONFIG['max_points'],
            'feature_dim': CONFIG['feature_dim'],
            'normalization': CONFIG['normalization'],
            'total_samples': len(flattened_data),
            'train_samples': len(train_data) if 'train_data' in locals() else 0,
            'test_samples': len(test_data) if 'test_data' in locals() else 0
        }
    }
    
    # NOTE: the checkpoint holds pickled Python objects (config dict, scaler),
    # not just tensors, so it must be loaded with weights_only=False.
    torch.save(final_model_data, 'final_flat_autoencoder.pth')
    print("Model saved as 'final_flat_autoencoder.pth'")
    
    # Create usage example
    usage_example = '''
# Example: How to use the trained flattened autoencoder

import torch
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# 1. Load the trained model
# The checkpoint contains pickled Python objects (config dict, fitted scaler),
# so weights_only must be disabled (it defaults to True from PyTorch 2.6).
# Only load checkpoints from a trusted source. map_location lets a checkpoint
# trained on a GPU be loaded on a CPU-only machine.
checkpoint = torch.load('final_flat_autoencoder.pth', map_location='cpu', weights_only=False)
config = checkpoint['config']
scaler = checkpoint['scaler']

# 2. Initialize model with saved architecture
class FlattenedAutoencoder(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_dim, latent_dim),
            torch.nn.SiLU()
        )
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(latent_dim, hidden_dim),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_dim, input_dim),
            torch.nn.Tanh()
        )
    
    def forward(self, x):
        return self.decoder(self.encoder(x))

# Create and load model
model = FlattenedAutoencoder(
    config['input_dim'], 
    config['hidden_dim'], 
    config['latent_dim']
)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# 3. Process new event data
def preprocess_event(event_points, max_points=1000):
    """
    Convert event to flattened format
    event_points: numpy array of shape (n_points, 4) with [x, y, z, energy]
    """
    # Sort by energy (descending)
    sorted_indices = np.argsort(-event_points[:, 3])
    sorted_event = event_points[sorted_indices]
    
    # Pad to max_points
    padded = np.zeros((max_points, 4))
    n_points = min(len(sorted_event), max_points)
    padded[:n_points] = sorted_event[:n_points]
    
    # Flatten
    flattened = padded.flatten()
    
    # Normalize non-zero values
    non_zero_mask = flattened != 0
    if non_zero_mask.any():
        flattened[non_zero_mask] = scaler.transform(
            flattened[non_zero_mask].reshape(-1, 1)
        ).flatten()
    
    return torch.FloatTensor(flattened).unsqueeze(0), n_points

# 4. Detect anomalies in new data
def detect_anomaly(event_points, threshold=None):
    """Detect if an event is anomalous"""
    if threshold is None:
        # Use 95th percentile from training
        threshold = np.percentile(checkpoint['test_losses'], 95)
    
    # Preprocess (pad to the size the model was trained with)
    processed_event, n_points = preprocess_event(event_points, max_points=config['max_points'])
    
    # Get reconstruction
    with torch.no_grad():
        reconstructed = model(processed_event)
    
    # Calculate loss only on actual data (not padding)
    actual_elements = n_points * 4
    original = processed_event[0][:actual_elements]
    recon = reconstructed[0][:actual_elements]
    
    reconstruction_loss = torch.nn.functional.mse_loss(recon, original).item()
    
    is_anomaly = reconstruction_loss > threshold
    
    return {
        'is_anomaly': is_anomaly,
        'reconstruction_loss': reconstruction_loss,
        'threshold': threshold,
        'anomaly_score': reconstruction_loss / threshold
    }

# 5. Example usage
# new_event = np.random.randn(150, 4)  # Example event with 150 points
# result = detect_anomaly(new_event)
# print(f"Anomaly: {result['is_anomaly']}, Score: {result['anomaly_score']:.3f}")
    '''
    
    # Save usage example
    with open('flat_autoencoder_usage_example.py', 'w', encoding='utf-8') as f:
        f.write(usage_example)
    
    print("Usage example saved as 'flat_autoencoder_usage_example.py'")
    
    # Final summary
    print("\n" + "=" * 60)
    print("FINAL MODEL SUMMARY")
    print("=" * 60)
    
    print("Model Type: Flattened Autoencoder")
    print(f"Architecture: {CONFIG['input_dim']} -> {CONFIG['hidden_dim']} -> {CONFIG['latent_dim']} -> {CONFIG['hidden_dim']} -> {CONFIG['input_dim']}")
    print(f"Total Parameters: {total_params:,}" if 'total_params' in locals() else "Total Parameters: N/A")
    print(f"Input Format: Flattened vectors (max {CONFIG['max_points']} points × 4 features)")
    print(f"Normalization: {CONFIG['normalization']} scaling")
    
    if 'train_losses' in locals() and 'val_losses' in locals():
        print("\nTraining Results:")
        print(f"  Epochs trained: {len(train_losses)}")
        # An empty loss history has no final entry to report
        if len(train_losses) > 0:
            print(f"  Final train loss: {train_losses[-1]:.6f}")
        # best_val_loss comes from the training cell and may be absent; fall
        # back to the minimum of the recorded validation losses
        if 'best_val_loss' in locals():
            print(f"  Best validation loss: {best_val_loss:.6f}")
        elif len(val_losses) > 0:
            print(f"  Best validation loss: {min(val_losses):.6f}")
    
    if 'test_losses' in locals():
        print("\nTest Performance:")
        print(f"  Test samples: {len(test_losses)}")
        print(f"  Mean reconstruction loss: {test_losses.mean():.6f}")
        print(f"  Recommended anomaly threshold: {np.percentile(test_losses, 95):.6f}")
    
    print("\nFiles Created:")
    print("  • best_flat_autoencoder.pth - Best model checkpoint")
    print("  • final_flat_autoencoder.pth - Final model with all metadata")
    print("  • flat_autoencoder_usage_example.py - Usage instructions")
    
    print("\nWhen to use this model:")
    print("  ✓ Small to medium datasets (< 10K events)")
    print("  ✓ Simple anomaly detection tasks")
    print("  ✓ When interpretability is important")
    print("  ✓ As a baseline for comparison")
    print("  ✓ Fast inference requirements")
    
else:
    print("No model available to save")

# Variable-Size Input Comparison

## The Problem with Flattened Approach for Variable Sizes

The flattened autoencoder has fundamental limitations when dealing with variable-size inputs:

### **Memory and Computational Inefficiency**
- **Fixed allocation**: Must allocate memory for `max_points` regardless of actual event size
- **Padding waste**: Small events (50 points) still require same memory as large events (1000 points)
- **Computational overhead**: Network processes zeros in padded regions

### **Loss of Spatial Structure**
- **No locality**: Point at index 0 has no relationship to point at index 1
- **Arbitrary ordering**: Padding breaks any spatial continuity
- **Feature mixing**: All coordinates treated as independent features

### **Scalability Issues**
- **Linear growth**: Input dimension grows as `max_points × features`
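- **Parameter explosion**: First- and last-layer weight counts scale as `input_dim × hidden_dim`, so model size grows linearly with max event size for a fixed hidden width (and quadratically if `hidden_dim` is scaled up along with the input)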
- **GPU memory**: Large input dimensions can exceed GPU memory limits

## Why CNN Autoencoder is Superior for Variable Sizes

### **Dynamic Memory Usage**
```python
# Flattened: Always uses full allocation
flattened_memory = max_points * features * batch_size  # Always 1000 × 4 × 32 = 128,000

# CNN: Uses only what's needed
cnn_memory = actual_points * features * batch_size    # 50 × 4 × 32 = 6,400 (for 50-point event)
```

### **Spatial Awareness**
- **Local patterns**: Convolutions capture relationships between nearby points
- **Translation invariance**: Same patterns detected regardless of position
- **Hierarchical features**: Multiple scales of spatial information

### **Better Scalability**
- **Linear complexity**: O(sequence_length) rather than O(max_sequence_length)
- **Efficient padding**: Only minimal padding for batch processing
- **GPU friendly**: Each batch only needs padding up to its longest event, so variable-length sequences can still be batched efficiently on a GPU

In [None]:
# Concrete Comparison: Variable-Size Input Handling

# Let's analyze the efficiency differences with real numbers from your data
# Compare fixed-size (flattened) storage against actual event sizes.
# Runs only when `event_lengths` exists; it is used as a NumPy array here
# (`.min()`, boolean-mask indexing) -- presumably one length per event.
if 'event_lengths' in locals() and len(event_lengths) > 0:
    print("VARIABLE-SIZE INPUT ANALYSIS")
    print("=" * 50)

    # Summary statistics of the event sizes
    print("Event size statistics from your data:")
    print(f"  Min event size: {event_lengths.min()} points")
    print(f"  Max event size: {event_lengths.max()} points")
    print(f"  Mean event size: {event_lengths.mean():.1f} points")
    print(f"  Median event size: {np.median(event_lengths):.1f} points")

    max_size = CONFIG['max_points']
    actual_sizes = event_lengths
    bytes_per_mb = 1024**2

    # Memory: every flattened event occupies max_size points; the CNN figure
    # counts only the real points (4 bytes per float)
    flattened_total_memory = len(actual_sizes) * max_size * CONFIG['feature_dim'] * 4
    cnn_total_memory = np.sum(actual_sizes) * CONFIG['feature_dim'] * 4
    memory_saved = flattened_total_memory - cnn_total_memory

    print("\nMemory Usage Comparison:")
    print(f"  Flattened approach: {flattened_total_memory / bytes_per_mb:.1f} MB")
    print(f"  CNN approach: {cnn_total_memory / bytes_per_mb:.1f} MB")
    print(f"  Memory saved with CNN: {memory_saved / bytes_per_mb:.1f} MB ({100 * memory_saved / flattened_total_memory:.1f}%)")

    # Share of the fixed-size allocation that is padding; `padding_ratio` is
    # read again by the conclusion printed at the end of this cell
    total_actual = np.sum(actual_sizes)
    total_padding = np.sum(max_size - actual_sizes)
    padding_ratio = total_padding / (total_actual + total_padding) * 100

    print("\nPadding Analysis:")
    print(f"  Total actual data points: {total_actual:,}")
    print(f"  Total padding points: {total_padding:,}")
    print(f"  Wasted space: {padding_ratio:.1f}%")

    # Fill fraction (mean size / max_size) per size bucket; empty buckets are skipped
    small_events = actual_sizes[actual_sizes < 100]
    medium_events = actual_sizes[(actual_sizes >= 100) & (actual_sizes < 500)]
    large_events = actual_sizes[actual_sizes >= 500]

    print("\nEfficiency by Event Size:")
    size_buckets = (
        ("Small events (<100 points)", small_events),
        ("Medium events (100-500 points)", medium_events),
        ("Large events (500+ points)", large_events),
    )
    for bucket_label, bucket in size_buckets:
        if len(bucket) == 0:
            continue
        bucket_efficiency = np.mean(bucket) / max_size * 100
        print(f"  {bucket_label}: {len(bucket)} events, {bucket_efficiency:.1f}% efficiency")

# Demonstrate the difference in computational requirements
# Rough operation counts for the first layer of each model at a few event sizes
print("\n" + "=" * 60)
print("COMPUTATIONAL COMPLEXITY COMPARISON")
print("=" * 60)

example_sizes = [50, 150, 500, 1000]
batch_size = 32

# The flattened model always sees the full padded input, so its first-layer
# cost does not depend on the event size
flat_input_dim = CONFIG['max_points'] * CONFIG['feature_dim']
flat_hidden_ops = batch_size * flat_input_dim * CONFIG['hidden_dim']  # encoder first layer

for size in example_sizes:
    # Approximate first conv layer cost: 64 channels with kernel size 7
    cnn_ops = batch_size * size * CONFIG['feature_dim'] * 64 * 7

    print("\n".join([
        f"\nEvent size: {size} points",
        f"  Flattened AE operations: {flat_hidden_ops:,}",
        f"  CNN AE operations: {cnn_ops:,}",
        f"  CNN efficiency gain: {flat_hidden_ops / cnn_ops:.1f}x fewer operations",
    ]))

print("\n" + "=" * 60)
print("RECOMMENDATIONS FOR VARIABLE-SIZE INPUTS")
print("=" * 60)

# Guidance keyed by heading; each value holds the bullet lines printed under it
recommendations = {
    "Use CNN Autoencoder when:": [
        "• Events have widely varying sizes (10x+ difference)",
        "• Memory efficiency is important",
        "• Spatial patterns matter (points have geometric relationships)",
        "• You have > 1000 events in dataset",
        "• GPU memory is limited",
        "• Training/inference speed is important",
    ],
    "Use Flattened Autoencoder when:": [
        "• All events are similar size (< 2x difference)",
        "• Very small dataset (< 500 events)",
        "• Simple baseline needed for comparison",
        "• Spatial relationships don't matter",
        "• Interpretability is more important than efficiency",
    ],
    "Hybrid Approach:": [
        "• Use CNN for large events (> 200 points)",
        "• Use flattened for small events (< 200 points)",
        "• Automatically choose based on event size",
        "• Ensemble both approaches for better accuracy",
    ],
}

# Heading preceded by a blank line, then its bullets indented by two spaces
for heading, bullets in recommendations.items():
    print("\n" + heading)
    print("\n".join("  " + bullet for bullet in bullets))

print("\n" + "=" * 60)
print("IMPLEMENTATION COMPARISON")
print("=" * 60)

# Illustrative snippets only (printed, never executed). The flattened
# `preprocess` example truncates to max_points so that it stays valid for
# events larger than the fixed input size, matching the truncation used in
# the generated usage example.
print("""
CNN Autoencoder Implementation:
```python
# Handles variable lengths naturally
def collate_fn(batch):
    return pad_sequence(batch, batch_first=True)

# Dynamic loss calculation
def masked_loss(output, target, lengths):
    loss = 0
    for i, length in enumerate(lengths):
        loss += mse_loss(output[i][:length], target[i][:length])
    return loss / len(lengths)
```

Flattened Autoencoder Implementation:
```python
# Fixed size - always processes max_points
def preprocess(event, max_points=1000):
    padded = np.zeros((max_points, 4))
    n_points = min(len(event), max_points)  # truncate oversized events
    padded[:n_points] = event[:n_points]
    return padded.flatten()

# Must manually track actual lengths
def masked_loss(output, target, lengths):
    # Complex indexing to avoid padded regions
    ...
```
""")

# Closing summary: quote the measured padding share when the event-size
# analysis above produced `padding_ratio`, otherwise fall back to a generic range
if 'padding_ratio' in locals():
    waste_summary = f"  • {padding_ratio:.0f}% reduction in wasted computation"
else:
    waste_summary = "  • ~50-80% reduction in wasted computation"

print("\n".join([
    "\nCONCLUSION:",
    "For your particle physics data with variable event sizes,",
    "CNN autoencoder is significantly better due to:",
    waste_summary,
    "  • Natural handling of variable lengths",
    "  • Spatial pattern recognition",
    "  • Better scalability to larger events",
]))