<a href="https://colab.research.google.com/github/shreyas-shrestha/VizFoldAutoencoder/blob/main/vizfoldfinal.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
from pathlib import Path

# Look for the protein data directory and list the .npy arrays it contains.
proteins_dir = "Proteins_layer47"
if not os.path.exists(proteins_dir):
    print(f"Warning: Directory '{proteins_dir}' not found!")
else:
    npy_files = list(Path(proteins_dir).glob("*.npy"))
    print(f"Found {len(npy_files)} .npy files in '{proteins_dir}' directory:")
    for file_path in npy_files:
        print(f"  - {file_path.name}")

print("Ready to load protein data for autoencoder training.")

In [None]:
# This cell is to define visualization functions

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import pandas as pd
from collections import defaultdict
import warnings
import numpy as np
# NOTE: this silences every Python warning for the rest of the session,
# not only plotting-related ones.
warnings.filterwarnings('ignore')

# Shared Matplotlib defaults for every figure in this notebook.
plt.style.use('default')
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'font.size': 10,
})

print("Visualization libraries loaded successfully!")

def visualize_kfold_splits(data_size, n_splits=5, random_state=42):
    """
    Visualize how K-fold cross-validation splits the data.

    Draws one horizontal strip per fold. Every sample index is a unit-wide bar,
    light blue for training rows and light coral for validation rows. Each strip
    is annotated with the split sizes, and a short summary is printed afterwards.

    Args:
        data_size (int): Number of samples to split.
        n_splits (int): Number of folds (scikit-learn's KFold requires >= 2).
        random_state (int): Seed for the shuffled KFold.
    """
    from sklearn.model_selection import KFold

    # squeeze=False always yields a 2-D array of axes, so no special case is
    # needed for a single row.
    fig, axes = plt.subplots(n_splits, 1, figsize=(15, 2*n_splits), squeeze=False)
    axes = axes[:, 0]

    indices = np.arange(data_size)
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    colors = ['lightblue', 'lightcoral']

    for fold, (train_idx, val_idx) in enumerate(kfold.split(indices)):
        ax = axes[fold]

        # One vectorized barh call per subset (scalar y/width broadcast against
        # the `left` array). Calling barh once per sample is extremely slow
        # for large datasets.
        ax.barh(0, 1, left=train_idx, height=0.5, color=colors[0], alpha=0.7, edgecolor='none')
        ax.barh(0, 1, left=val_idx, height=0.5, color=colors[1], alpha=0.7, edgecolor='none')

        ax.set_xlim(0, data_size)
        ax.set_ylim(-0.5, 0.5)
        ax.set_xlabel('Data Sample Index')
        ax.set_title(f'Fold {fold+1}: Train/Validation Split')
        ax.set_yticks([])

        # Add statistics
        ax.text(data_size*0.02, 0.2, f'Train: {len(train_idx)} samples ({len(train_idx)/data_size*100:.1f}%)', 
                fontsize=9, bbox=dict(boxstyle="round,pad=0.3", facecolor=colors[0], alpha=0.5))
        ax.text(data_size*0.02, -0.2, f'Val: {len(val_idx)} samples ({len(val_idx)/data_size*100:.1f}%)', 
                fontsize=9, bbox=dict(boxstyle="round,pad=0.3", facecolor=colors[1], alpha=0.5))

    plt.tight_layout()
    plt.suptitle(f'{n_splits}-Fold Cross-Validation Data Splits', y=1.02, fontsize=14, fontweight='bold')
    plt.show()

    # Print summary statistics
    print(f"\nK-Fold Cross-Validation Summary:")
    print(f"Total samples: {data_size}")
    print(f"Number of folds: {n_splits}")
    print(f"Training samples per fold: ~{data_size*(n_splits-1)//n_splits}")
    print(f"Validation samples per fold: ~{data_size//n_splits}")
    print(f"Training/Validation ratio: {(n_splits-1)}:1")

def analyze_data_distribution(data_matrix, protein_names):
    """
    Plot a 2x2 overview of the protein data and print summary statistics.

    Panels, in row-major order:
      * a histogram of every value, with the overall mean marked;
      * per-sample mean vs. standard deviation;
      * per-feature mean vs. standard deviation for a random subset of at most
        1000 columns;
      * counts of the name prefix (text before the first '_') among the first
        100 names.

    Args:
        data_matrix: 2-D numpy array, one row per sample.
        protein_names: sequence of names aligned with the rows of data_matrix.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    ax_values, ax_samples = axes[0]
    ax_features, ax_types = axes[1]

    overall_mean = data_matrix.mean()

    # Panel 1: every value in the matrix.
    ax_values.hist(data_matrix.flatten(), bins=50, alpha=0.7, color='skyblue', edgecolor='black')
    ax_values.set_title('Distribution of All Protein Data Values')
    ax_values.set_xlabel('Value')
    ax_values.set_ylabel('Frequency')
    ax_values.axvline(overall_mean, color='red', linestyle='--', label=f'Mean: {overall_mean:.3f}')
    ax_values.legend()

    # Panel 2: one point per row.
    ax_samples.scatter(data_matrix.mean(axis=1), data_matrix.std(axis=1), alpha=0.6, color='coral')
    ax_samples.set_title('Sample-wise Mean vs Standard Deviation')
    ax_samples.set_xlabel('Sample Mean')
    ax_samples.set_ylabel('Sample Std Dev')

    # Panel 3: one point per column, on a random subset to keep it readable.
    n_features = data_matrix.shape[1]
    chosen = np.random.choice(n_features, min(1000, n_features), replace=False)
    columns = data_matrix[:, chosen]
    ax_features.scatter(columns.mean(axis=0), columns.std(axis=0), alpha=0.6, color='lightgreen')
    ax_features.set_title(f'Feature-wise Statistics (Random {len(chosen)} features)')
    ax_features.set_xlabel('Feature Mean')
    ax_features.set_ylabel('Feature Std Dev')

    # Panel 4: name-prefix counts among the first 100 names.
    prefixes = [name.split('_')[0] for name in protein_names[:100]]
    type_counts = pd.Series(prefixes).value_counts()
    positions = range(len(type_counts))
    ax_types.bar(positions, type_counts.values, color='gold', alpha=0.7)
    ax_types.set_title('Protein Type Distribution (First 100 samples)')
    ax_types.set_xlabel('Protein Type')
    ax_types.set_ylabel('Count')
    ax_types.set_xticks(positions)
    ax_types.set_xticklabels(type_counts.index, rotation=45, ha='right')

    plt.tight_layout()
    plt.show()

    # Text summary
    print(f"\nData Distribution Summary:")
    print(f"Shape: {data_matrix.shape}")
    print(f"Mean: {overall_mean:.6f}")
    print(f"Std: {data_matrix.std():.6f}")
    print(f"Min: {data_matrix.min():.6f}")
    print(f"Max: {data_matrix.max():.6f}")
    print(f"Unique proteins: {len({name.split('_')[0] for name in protein_names})}")

print("Visualization functions defined successfully!")

In [None]:
# =============================================================================
# CELL 3: Data Loading with Fixed Padding Logic
# =============================================================================

from sklearn.preprocessing import StandardScaler

def load_protein_data(data_directory="Proteins_layer47", normalize=True, target_dim=50):
    """
    Load and preprocess protein data for autoencoder training.

    Each ``.npy`` file must hold a 3-D array indexed as ``(L, L, channel)``.
    Every channel's 2-D slice is truncated to at most ``target_dim`` rows and
    columns, zero-padded up to ``target_dim`` x ``target_dim``, and flattened,
    so each channel becomes one row of the returned matrix.

    Args:
        data_directory (str): Path to folder containing protein .npy files.
        normalize (bool): Whether to standardize each column with StandardScaler.
        target_dim (int): Side length of the padded square (default 50, giving
            50 x 50 = 2,500 features per row).

    Returns:
        tuple: (data_matrix, protein_names, original_L)
            - data_matrix: numpy array of shape (total_channels, target_dim**2)
            - protein_names: one "<file stem>_chNNN" name per row
            - original_L: first dimension of the first successfully loaded file
              (only the first file's value is reported)

    Raises:
        FileNotFoundError: if ``data_directory`` does not exist.
        ValueError: if it has no .npy files, or none of them could be loaded.

    Non-finite values are reported and replaced: NaN -> 0.0 and +/-inf -> +/-5.0.
    These replacements are applied after standardization, so when ``normalize``
    is true they are in standardized units.
    """
    data_path = Path(data_directory)
    if not data_path.exists():
        raise FileNotFoundError(f"Directory {data_directory} not found")

    # Sorted so that the row order (and therefore any seeded CV split) does not
    # depend on the filesystem's listing order.
    npy_files = sorted(data_path.glob("*.npy"))
    if not npy_files:
        raise ValueError(f"No .npy files found in {data_directory}")

    print(f"Found {len(npy_files)} protein files")

    blocks = []             # one (n_channels, target_dim**2) array per loaded file
    all_protein_names = []
    original_L = None

    for file_path in npy_files:
        try:
            data = np.load(file_path)
            if data.ndim != 3:
                raise ValueError(f"expected a 3-D (L, L, channels) array, got shape {data.shape}")

            # Truncate anything beyond target_dim, then zero-pad up to it. This
            # covers both L <= target_dim and L > target_dim in one path.
            clipped = data[:target_dim, :target_dim, :]
            rows, cols, n_channels = clipped.shape
            padded = np.zeros((target_dim, target_dim, n_channels), dtype=data.dtype)
            padded[:rows, :cols, :] = clipped

            # Row c of `vectors` equals padded[:, :, c].flatten() (row-major).
            vectors = padded.transpose(2, 0, 1).reshape(n_channels, -1)
        except Exception as e:
            # Best-effort loading: report the bad file and keep going.
            print(f"Error loading {file_path.name}: {e}")
            continue

        if original_L is None:
            original_L = data.shape[0]
        blocks.append(vectors)
        all_protein_names.extend(f"{file_path.stem}_ch{channel:03d}" for channel in range(n_channels))

    if not all_protein_names:
        raise ValueError("No protein data was successfully loaded")

    data_matrix = np.concatenate(blocks, axis=0)
    del blocks  # release the per-file arrays before scaling to lower peak memory

    # StandardScaler rejects +/-inf but tolerates NaN (ignoring it while fitting).
    # So remember where the infinities were, mask them as NaN for scaling, and
    # restore them as +/-5 afterwards.
    has_nan = bool(np.isnan(data_matrix).any())
    inf_mask = np.isinf(data_matrix)
    if inf_mask.any():
        posinf_mask = inf_mask & (data_matrix > 0)
        neginf_mask = inf_mask & (data_matrix < 0)
        data_matrix[inf_mask] = np.nan
    else:
        posinf_mask = neginf_mask = None
    del inf_mask

    # Normalize the data if requested
    if normalize:
        scaler = StandardScaler()
        data_matrix = scaler.fit_transform(data_matrix)

    print(f"Loaded data shape: {data_matrix.shape}")
    print(f"Total vectors: {data_matrix.shape[0]}")
    print(f"Vector size: {data_matrix.shape[1]} ({target_dim}×{target_dim} padded matrices)")
    print(f"Original L dimension: {original_L}")

    if has_nan:
        print("Warning: NaN values detected in data matrix!")
    if posinf_mask is not None:
        print("Warning: Infinite values detected in data matrix!")

    if has_nan or posinf_mask is not None:
        # NaN (including the masked infinities) -> 0.0 first, then the former
        # infinities are overwritten with +/-5.0. A single nan_to_num(nan=0.0)
        # call would instead turn inf into +/-float-max.
        data_matrix = np.nan_to_num(data_matrix, nan=0.0)
        if posinf_mask is not None:
            data_matrix[posinf_mask] = 5.0
            data_matrix[neginf_mask] = -5.0

    print(f"Data value range after preprocessing: [{data_matrix.min():.3f}, {data_matrix.max():.3f}]")

    return data_matrix, all_protein_names, original_L

print("Data loading function with fixed padding ready!")

Found 4 protein files
Loaded data shape: (512, 490000)
Total vectors: 512
Vector size: 490000 (padded to 700×700 = 49,000)
Original L dimension: 280

Data ready for autoencoder:
Shape: (512, 490000)
Data type: float32
Value range: [-11.610, 14.519]
Original L dimension: 280


In [None]:
# =============================================================================
# CELL 4: Autoencoder Model and Training Functions
# =============================================================================

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import numpy as np
import pandas as pd

class SimpleAutoencoder(nn.Module):
    """
    Fully connected autoencoder with mirrored encoder and decoder stacks.

    Encoder: input_dim -> proj_dim -> hidden_dim -> latent_dim
    Decoder: latent_dim -> hidden_dim -> proj_dim -> input_dim

    Each stack is Linear / ReLU / Linear / ReLU / Linear, with no activation on
    the latent code or on the reconstruction.
    """

    def __init__(self, input_dim=2500, proj_dim=1024, hidden_dim=128, latent_dim=8):
        super().__init__()
        widths = [input_dim, proj_dim, hidden_dim, latent_dim]
        self.encoder = self._mlp(widths)
        self.decoder = self._mlp(widths[::-1])

    @staticmethod
    def _mlp(widths):
        """Build Linear layers between consecutive widths, with ReLU between them."""
        layers = []
        for position, (width_in, width_out) in enumerate(zip(widths[:-1], widths[1:])):
            if position:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(width_in, width_out))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.decoder(self.encoder(x))

def train_simple(model, train_loader, val_loader, lr, wd, device, return_history=False, epochs=50):
    """
    Train ``model`` to reconstruct its own input, using SGD and MSE loss.

    Args:
        model: nn.Module whose output has the same shape as its input.
        train_loader: DataLoader whose batches have the input tensor at index 0.
        val_loader: same structure; evaluated after every epoch.
        lr: SGD learning rate.
        wd: SGD weight decay.
        device: device the model and each batch are moved to.
        return_history: if True, also return the per-epoch loss lists.
        epochs: number of passes over ``train_loader`` (default 50).

    Returns:
        The final epoch's validation MSE. When ``return_history`` is true,
        returns ``(val_mse, train_losses, val_losses)`` instead.

        Losses are averaged per sample, not per batch, so a short final batch
        is not over-weighted. An empty loader gives NaN.
    """
    model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=wd)
    criterion = nn.MSELoss()

    def _per_sample_mean(total, count):
        # Guard against empty loaders instead of dividing by zero.
        return total / count if count else float('nan')

    # Track training history
    train_losses = []
    val_losses = []
    avg_val_loss = float('nan')  # returned as-is if epochs == 0

    for epoch in range(epochs):
        # Training
        model.train()
        train_total, train_count = 0.0, 0
        for batch in train_loader:
            data = batch[0].to(device)
            optimizer.zero_grad()
            loss = criterion(model(data), data)
            loss.backward()
            optimizer.step()
            # MSELoss returns a per-batch mean; re-weighting by batch size gives
            # the dataset-wide mean (every sample has the same element count).
            train_total += loss.item() * data.size(0)
            train_count += data.size(0)

        # Validation
        model.eval()
        val_total, val_count = 0.0, 0
        with torch.no_grad():
            for batch in val_loader:
                data = batch[0].to(device)
                val_total += criterion(model(data), data).item() * data.size(0)
                val_count += data.size(0)

        avg_train_loss = _per_sample_mean(train_total, train_count)
        avg_val_loss = _per_sample_mean(val_total, val_count)
        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)

        if (epoch + 1) % 10 == 0:
            print(f"  Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {avg_val_loss:.6f}")

    if return_history:
        return avg_val_loss, train_losses, val_losses
    return avg_val_loss

def plot_training_curves(fold_histories, hyperparams):
    """
    Plot training curves for all folds and hyperparameter combinations.

    There is one subplot per entry of ``hyperparams``, laid out on two rows. In
    each subplot, every fold's training loss is dashed and its validation loss
    is solid, in a colour shared by that fold across subplots. The legend
    appears only on the first subplot.

    Args:
        fold_histories: list aligned with ``hyperparams``. Each item is a list
            of ``(train_losses, val_losses)`` pairs, one per fold.
        hyperparams: list of dicts with ``'lr'`` and ``'wd'`` keys, used in titles.
    """
    n_params = len(hyperparams)
    if n_params == 0:
        print("No hyperparameter results to plot.")
        return

    n_cols = (n_params + 1) // 2
    # squeeze=False always yields a 2-D grid, so flatten() works for any size.
    fig, axes = plt.subplots(2, n_cols, figsize=(15, 8), squeeze=False)
    axes = axes.flatten()

    # One colour per fold, sized from the data. A fixed 5-colour palette
    # raises IndexError when there are more than 5 folds.
    n_folds = max((len(histories) for histories in fold_histories), default=1)
    colors = plt.cm.Set3(np.linspace(0, 1, max(n_folds, 1)))

    for param_idx, (params, histories) in enumerate(zip(hyperparams, fold_histories)):
        ax = axes[param_idx]

        # Plot each fold
        for fold_idx, (train_losses, val_losses) in enumerate(histories):
            epochs = range(1, len(train_losses) + 1)
            ax.plot(epochs, train_losses, '--', color=colors[fold_idx], alpha=0.7, 
                   label=f'Fold {fold_idx+1} Train' if param_idx == 0 else "")
            ax.plot(epochs, val_losses, '-', color=colors[fold_idx], alpha=0.9,
                   label=f'Fold {fold_idx+1} Val' if param_idx == 0 else "")

        ax.set_title(f'lr={params["lr"]}, wd={params["wd"]:.0e}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.grid(True, alpha=0.3)

        if param_idx == 0:
            ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)

    # Hide unused subplots
    for i in range(n_params, len(axes)):
        axes[i].set_visible(False)

    plt.tight_layout()
    plt.suptitle('Training Curves for All Hyperparameter Combinations', y=1.02, fontsize=14)
    plt.show()

def plot_hyperparameter_comparison(results_df):
    """
    Create a 2x2 comparison of the hyperparameter grid-search results.

    Args:
        results_df: DataFrame with one row per (lr, wd) combination and columns
            'lr', 'wd', 'mean_score' and 'std_score'.

    Panels: an annotated heatmap of mean CV score (lr x wd); a bar chart of
    every combination with std error bars; and the mean score grouped by
    learning rate and by weight decay.

    The heatmap is drawn with Matplotlib only. The original called ``sns``
    (seaborn), which this notebook never imports, so it raised NameError.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # 1. Heatmap of mean scores (reversed colormap because lower is better).
    pivot_table = results_df.pivot(index='lr', columns='wd', values='mean_score')
    heat_ax = axes[0,0]
    image = heat_ax.imshow(pivot_table.values, cmap='viridis_r', aspect='auto')
    fig.colorbar(image, ax=heat_ax)
    heat_ax.set_xticks(range(len(pivot_table.columns)))
    heat_ax.set_xticklabels([f'{wd:.0e}' for wd in pivot_table.columns])
    heat_ax.set_yticks(range(len(pivot_table.index)))
    heat_ax.set_yticklabels([f'{lr:.0e}' for lr in pivot_table.index])
    heat_ax.set_xlabel('wd')
    heat_ax.set_ylabel('lr')
    for (row, col), value in np.ndenumerate(pivot_table.values):
        heat_ax.text(col, row, f'{value:.6f}', ha='center', va='center', fontsize=9,
                     bbox=dict(boxstyle="round,pad=0.2", facecolor='white', alpha=0.6))
    heat_ax.set_title('Mean CV Score Heatmap')

    # 2. Bar plot of all combinations
    param_labels = [f"lr={row['lr']:.0e}\nwd={row['wd']:.0e}" for _, row in results_df.iterrows()]
    axes[0,1].bar(range(len(results_df)), results_df['mean_score'], 
                  yerr=results_df['std_score'], capsize=5, alpha=0.7, color='skyblue')
    axes[0,1].set_xlabel('Hyperparameter Combination')
    axes[0,1].set_ylabel('Mean CV Score')
    axes[0,1].set_title('Hyperparameter Comparison with Error Bars')
    axes[0,1].set_xticks(range(len(results_df)))
    axes[0,1].set_xticklabels(param_labels, rotation=45, ha='right', fontsize=8)

    # 3. Learning rate comparison
    lr_grouped = results_df.groupby('lr').agg({'mean_score': ['mean', 'std']}).reset_index()
    lr_grouped.columns = ['lr', 'mean', 'std']
    axes[1,0].bar(range(len(lr_grouped)), lr_grouped['mean'], 
                  yerr=lr_grouped['std'], capsize=5, alpha=0.7, color='lightcoral')
    axes[1,0].set_xlabel('Learning Rate')
    axes[1,0].set_ylabel('Mean CV Score')
    axes[1,0].set_title('Learning Rate Effect')
    axes[1,0].set_xticks(range(len(lr_grouped)))
    axes[1,0].set_xticklabels([f'{lr:.0e}' for lr in lr_grouped['lr']])

    # 4. Weight decay comparison
    wd_grouped = results_df.groupby('wd').agg({'mean_score': ['mean', 'std']}).reset_index()
    wd_grouped.columns = ['wd', 'mean', 'std']
    axes[1,1].bar(range(len(wd_grouped)), wd_grouped['mean'], 
                  yerr=wd_grouped['std'], capsize=5, alpha=0.7, color='lightgreen')
    axes[1,1].set_xlabel('Weight Decay')
    axes[1,1].set_ylabel('Mean CV Score')
    axes[1,1].set_title('Weight Decay Effect')
    axes[1,1].set_xticks(range(len(wd_grouped)))
    axes[1,1].set_xticklabels([f'{wd:.0e}' for wd in wd_grouped['wd']])

    plt.tight_layout()
    plt.show()

def visualize_model_architecture(input_dim=2500, proj_dim=1024, hidden_dim=128, latent_dim=8):
    """
    Draw a block diagram of the mirrored autoencoder architecture.

    Each layer is a rectangle whose height is proportional to its width, with
    a minimum height of 0.5. Arrows connect consecutive layers, and a callout
    shows the input-to-latent compression ratio.

    Args:
        input_dim (int): Size of the input and output layers.
        proj_dim (int): Size of the first encoder / last decoder hidden layer.
        hidden_dim (int): Size of the layers adjacent to the latent code.
        latent_dim (int): Size of the bottleneck.

    The defaults match SimpleAutoencoder's defaults.
    """
    fig, ax = plt.subplots(1, 1, figsize=(14, 8))

    # Define layer dimensions and positions
    layers = [
        {'name': f'Input\n{input_dim:,}', 'size': input_dim, 'pos': 0, 'color': 'lightblue'},
        {'name': f'Projection\n{proj_dim:,}', 'size': proj_dim, 'pos': 2, 'color': 'lightgreen'},
        {'name': f'Hidden\n{hidden_dim:,}', 'size': hidden_dim, 'pos': 4, 'color': 'orange'},
        {'name': f'Latent\n{latent_dim:,}', 'size': latent_dim, 'pos': 6, 'color': 'red'},
        {'name': f'Hidden\n{hidden_dim:,}', 'size': hidden_dim, 'pos': 8, 'color': 'orange'},
        {'name': f'Projection\n{proj_dim:,}', 'size': proj_dim, 'pos': 10, 'color': 'lightgreen'},
        {'name': f'Output\n{input_dim:,}', 'size': input_dim, 'pos': 12, 'color': 'lightblue'}
    ]

    max_size = max(layer['size'] for layer in layers)

    # Draw layers
    for i, layer in enumerate(layers):
        # Height proportional to layer size, with a minimum so tiny layers stay visible
        height = max(0.5, (layer['size'] / max_size) * 4)

        rect = Rectangle((layer['pos'], -height/2), 1.5, height, 
                        facecolor=layer['color'], edgecolor='black', alpha=0.7)
        ax.add_patch(rect)

        ax.text(layer['pos'] + 0.75, height/2 + 0.3, layer['name'], 
               ha='center', va='bottom', fontsize=10, fontweight='bold')

        # Arrow from this layer's right edge to the next layer's left edge
        if i < len(layers) - 1:
            ax.arrow(layer['pos'] + 1.5, 0, 0.4, 0, 
                    head_width=0.1, head_length=0.1, fc='black', ec='black')

    # Add encoder/decoder labels
    ax.text(3, -3, 'ENCODER', ha='center', va='center', fontsize=14, 
           fontweight='bold', bbox=dict(boxstyle="round,pad=0.3", facecolor='lightgray'))
    ax.text(9, -3, 'DECODER', ha='center', va='center', fontsize=14, 
           fontweight='bold', bbox=dict(boxstyle="round,pad=0.3", facecolor='lightgray'))

    # Compression ratio annotation, derived from the actual latent size
    compression_ratio = input_dim / latent_dim
    ax.text(6, 3, f'Compression Ratio:\n{input_dim:,} → {latent_dim:,}\n({compression_ratio:,.0f}:1)', 
           ha='center', va='center', fontsize=12, 
           bbox=dict(boxstyle="round,pad=0.5", facecolor='yellow', alpha=0.7))

    ax.set_xlim(-1, 14)
    ax.set_ylim(-4, 4)
    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_title('Autoencoder Architecture', fontsize=16, fontweight='bold', pad=20)

    plt.tight_layout()
    plt.show()

print("Autoencoder model and training functions ready!")

In [None]:
# =============================================================================
# CELL 5: Main Training Function with Full Visualization
# =============================================================================

def run_simple_training_fixed():
    """
    Grid-search SGD hyperparameters with 5-fold CV, then train and save a final model.

    Steps:
      1. Load and standardize the data with ``load_protein_data``.
      2. For every (lr, wd) pair, train a fresh ``SimpleAutoencoder`` on each
         fold and record the mean and std of the final-epoch validation MSE.
      3. Plot per-fold training curves and the hyperparameter comparison.
      4. Retrain with the best pair on a seeded random 80/20 split, plot its
         curves, and save its state_dict to 'simple_autoencoder_fixed.pth'.

    Returns:
        (final_model, best_params, best_score, results_df). ``best_score`` is
        the lowest mean CV validation MSE, not the final model's score.
    """
    # Load data
    data_matrix, protein_names, original_L = load_protein_data("Proteins_layer47", normalize=True)

    print(f"Data loaded: {data_matrix.shape}")
    input_dim = data_matrix.shape[1]  # feature count actually produced by the loader
    print(f"Using input_dim = {input_dim}")

    # Hyperparameters (exact as specified)
    learning_rates = [0.01, 0.001, 0.0001]
    weight_decays = [1e-4, 1e-5]
    batch_size = 64

    def _make_loader(array, shuffle):
        """Wrap a 2-D array as a float32 DataLoader whose batches are 1-tuples."""
        dataset = TensorDataset(torch.as_tensor(array, dtype=torch.float32))
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    best_score = float('inf')
    best_params = None

    # Storage for visualization
    all_results = []
    all_fold_histories = []
    all_hyperparams = []

    print("\nStarting K-fold cross validation with hyperparameter grid search...")
    print(f"Architecture: input_dim={input_dim}, proj_dim=1024, hidden_dim=128, latent_dim=8")
    print(f"Training: 50 epochs, SGD, MSE loss, batch_size=64")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    # The splits depend only on the row count and the fixed seed, so compute
    # them once and reuse them for every hyperparameter pair.
    # NOTE(review): rows are individual channels, so channels of the same
    # protein can land in both train and validation folds. Confirm this is
    # acceptable, or switch to a grouped split by protein.
    kfold = KFold(n_splits=5, shuffle=True, random_state=42)
    fold_splits = list(kfold.split(data_matrix))
    n_folds = len(fold_splits)

    # Grid search with K-fold CV
    for lr in learning_rates:
        for wd in weight_decays:
            print(f"\nTesting lr={lr}, wd={wd}")

            scores = []
            fold_histories = []

            for fold, (train_idx, val_idx) in enumerate(fold_splits):
                print(f"  Fold {fold+1}/{n_folds}")

                train_loader = _make_loader(data_matrix[train_idx], shuffle=True)
                val_loader = _make_loader(data_matrix[val_idx], shuffle=False)

                model = SimpleAutoencoder(input_dim=input_dim)
                score, train_losses, val_losses = train_simple(
                    model, train_loader, val_loader, lr, wd, device, return_history=True
                )
                scores.append(score)
                fold_histories.append((train_losses, val_losses))

            mean_score = np.mean(scores)
            std_score = np.std(scores)
            print(f"  Mean: {mean_score:.6f} ± {std_score:.6f}")

            all_results.append({
                'lr': lr,
                'wd': wd,
                'mean_score': mean_score,
                'std_score': std_score,
                'scores': scores
            })
            all_fold_histories.append(fold_histories)
            all_hyperparams.append({'lr': lr, 'wd': wd})

            if mean_score < best_score:
                best_score = mean_score
                best_params = {'lr': lr, 'wd': wd}

    print(f"\n{'='*50}")
    print(f"BEST RESULTS:")
    print(f"Best params: lr={best_params['lr']}, wd={best_params['wd']}")
    print(f"Best CV score: {best_score:.6f}")
    print(f"{'='*50}")

    results_df = pd.DataFrame(all_results)

    print("\nGenerating training curve visualizations...")
    plot_training_curves(all_fold_histories, all_hyperparams)

    print("\nGenerating hyperparameter comparison plots...")
    plot_hyperparameter_comparison(results_df)

    # Train final model with best parameters.
    # load_protein_data appends each file's channels as consecutive rows, so a
    # contiguous 80/20 cut would validate on (roughly) only the last protein.
    # A seeded permutation keeps this split consistent with the shuffled CV.
    print("\nTraining final model with best parameters...")
    order = np.random.default_rng(42).permutation(len(data_matrix))
    split_idx = int(0.8 * len(data_matrix))
    train_loader = _make_loader(data_matrix[order[:split_idx]], shuffle=True)
    val_loader = _make_loader(data_matrix[order[split_idx:]], shuffle=False)

    final_model = SimpleAutoencoder(input_dim=input_dim)
    final_score, final_train_losses, final_val_losses = train_simple(
        final_model, train_loader, val_loader, 
        best_params['lr'], best_params['wd'], device, return_history=True
    )

    print(f"Final model validation score: {final_score:.6f}")

    # Plot final training curve
    print("\nGenerating final model training curve...")
    fig, ax = plt.subplots(1, 1, figsize=(10, 6))
    epochs = range(1, len(final_train_losses) + 1)
    ax.plot(epochs, final_train_losses, '--', label='Training Loss', linewidth=2)
    ax.plot(epochs, final_val_losses, '-', label='Validation Loss', linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title(f'Final Model Training (lr={best_params["lr"]}, wd={best_params["wd"]:.0e})')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    # Save model
    torch.save(final_model.state_dict(), 'simple_autoencoder_fixed.pth')
    print("Model saved as 'simple_autoencoder_fixed.pth'")

    return final_model, best_params, best_score, results_df

print("Main training function with full visualization ready!")

In [None]:
# =============================================================================
# CELL 6: Execute Training with Visualizations
# =============================================================================

print("="*70)
print("STARTING COMPREHENSIVE AUTOENCODER TRAINING WITH VISUALIZATION")
print("="*70)

# First, visualize the model architecture. It is drawn with a 2,500-feature
# input, i.e. 50x50 padding; the real input size is only known after loading.
print("\nVisualizing Model Architecture...")
visualize_model_architecture(input_dim=2500)

# Now run the actual training
print("\nStarting actual autoencoder training...")
final_model, best_params, best_score, results_df = run_simple_training_fixed()

print("\nTraining completed!")
print(f"Best hyperparameters: {best_params}")
print(f"Best validation score: {best_score:.6f}")

# Display detailed results summary.
# best_score is the lowest mean cross-validation MSE over the grid, not the
# loss of the final retrained model, so it is labelled accordingly.
print("\n" + "="*60)
print("DETAILED TRAINING RESULTS SUMMARY")
print("="*60)
print(f"Best Mean CV MSE Loss: {best_score:.6f}")
print(f"Best Learning Rate: {best_params['lr']}")
print(f"Best Weight Decay: {best_params['wd']:.2e}")
print("\nAll Hyperparameter Results:")
for _, row in results_df.iterrows():
    print(f"  lr={row['lr']:.0e}, wd={row['wd']:.0e} → MSE: {row['mean_score']:.6f} ± {row['std_score']:.6f}")
print("="*60)


In [None]:
# =============================================================================
# CELL 7: Detailed Loss Analysis and Model Evaluation
# =============================================================================

def evaluate_reconstruction_quality(model, data_loader, device, max_samples=100):
    """
    Evaluate reconstruction quality with dataset-wide MSE and MAE.

    Args:
        model: autoencoder whose output has the same shape as its input. It is
            moved to ``device`` and put in eval mode.
        data_loader: DataLoader whose batches have the input tensor at index 0.
        device: device each batch is moved to.
        max_samples (int): Number of leading samples (and their
            reconstructions) to return for plotting.

    Returns:
        (avg_mse, avg_mae, originals, reconstructions):
          - avg_mse, avg_mae: element-wise means over every sample in the
            loader, or NaN if the loader is empty;
          - originals, reconstructions: numpy arrays with at most
            ``max_samples`` rows.
    """
    model.to(device)
    model.eval()
    mse_fn = nn.MSELoss()
    mae_fn = nn.L1Loss()

    total_mse = 0.0
    total_mae = 0.0
    total_samples = 0

    kept_originals = []
    kept_reconstructions = []
    n_kept = 0

    with torch.no_grad():
        for batch in data_loader:
            data = batch[0].to(device)
            reconstructed = model(data)
            batch_size = data.size(0)

            # Per-batch means re-weighted by batch size give the dataset-wide
            # mean (every sample has the same number of elements).
            total_mse += mse_fn(reconstructed, data).item() * batch_size
            total_mae += mae_fn(reconstructed, data).item() * batch_size
            total_samples += batch_size

            # Keep exactly max_samples rows. Appending whole batches would
            # overshoot the limit by up to one batch.
            room = max_samples - n_kept
            if room > 0:
                kept_originals.append(data[:room].cpu().numpy())
                kept_reconstructions.append(reconstructed[:room].cpu().numpy())
                n_kept += min(room, batch_size)

    if total_samples == 0:
        return float('nan'), float('nan'), np.array([]), np.array([])

    originals = np.concatenate(kept_originals) if kept_originals else np.array([])
    reconstructions = np.concatenate(kept_reconstructions) if kept_reconstructions else np.array([])

    return total_mse / total_samples, total_mae / total_samples, originals, reconstructions

def plot_reconstruction_analysis(originals, reconstructions):
    """
    Plot a 2x3 reconstruction analysis and print summary statistics.

    Args:
        originals: 2-D array (n_samples, n_features) of model inputs.
        reconstructions: array of the same shape holding the model outputs.

    Panels:
      * original vs. reconstructed values, with the identity line dashed;
      * histogram of element-wise errors;
      * histogram of per-sample MSE;
      * per-feature MSE for the first 1000 features;
      * histogram of per-sample Pearson correlations (first 100 samples);
      * per-sample MAE vs. the sample's L2 norm.
    """
    if len(originals) == 0:
        print("No samples to analyze.")
        return

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))

    # 1. Original vs reconstructed values, over up to 1000 randomly chosen rows
    sample_indices = np.random.choice(len(originals), min(1000, len(originals)), replace=False)
    orig_sample = originals[sample_indices].flatten()
    recon_sample = reconstructions[sample_indices].flatten()

    axes[0,0].scatter(orig_sample, recon_sample, alpha=0.5, s=1)
    axes[0,0].plot([orig_sample.min(), orig_sample.max()], 
                   [orig_sample.min(), orig_sample.max()], 'r--', lw=2)
    axes[0,0].set_xlabel('Original Values')
    axes[0,0].set_ylabel('Reconstructed Values')
    axes[0,0].set_title('Original vs Reconstructed Values')
    axes[0,0].grid(True, alpha=0.3)

    # 2. Reconstruction error distribution
    errors = (originals - reconstructions).flatten()
    axes[0,1].hist(errors, bins=50, alpha=0.7, color='orange', edgecolor='black')
    axes[0,1].set_xlabel('Reconstruction Error')
    axes[0,1].set_ylabel('Frequency')
    axes[0,1].set_title('Distribution of Reconstruction Errors')
    axes[0,1].axvline(0, color='red', linestyle='--', label='Perfect Reconstruction')
    axes[0,1].legend()

    # 3. Sample-wise MSE
    sample_mse = np.mean((originals - reconstructions)**2, axis=1)
    axes[0,2].hist(sample_mse, bins=30, alpha=0.7, color='lightgreen', edgecolor='black')
    axes[0,2].set_xlabel('Sample MSE')
    axes[0,2].set_ylabel('Frequency')
    axes[0,2].set_title('Distribution of Sample-wise MSE')

    # 4. Feature-wise reconstruction quality
    feature_mse = np.mean((originals - reconstructions)**2, axis=0)
    axes[1,0].plot(feature_mse[:1000])  # Plot first 1000 features
    axes[1,0].set_xlabel('Feature Index')
    axes[1,0].set_ylabel('Feature MSE')
    axes[1,0].set_title('Feature-wise Reconstruction Error (First 1000)')
    axes[1,0].grid(True, alpha=0.3)

    # 5. Per-sample correlation between original and reconstructed (first 100 rows).
    # Rows where either side is constant give NaN and are skipped.
    correlations = []
    for orig_row, recon_row in zip(originals[:100], reconstructions[:100]):
        corr = np.corrcoef(orig_row, recon_row)[0,1]
        if not np.isnan(corr):
            correlations.append(corr)

    corr_ax = axes[1,1]
    if correlations:
        mean_corr = np.mean(correlations)
        corr_ax.hist(correlations, bins=20, alpha=0.7, color='purple', edgecolor='black')
        corr_ax.axvline(mean_corr, color='red', linestyle='--', label=f'Mean: {mean_corr:.3f}')
        corr_ax.legend()
    else:
        # Every row was skipped (e.g. a collapsed, constant reconstruction).
        # np.min/np.max on the empty list would raise ValueError.
        corr_ax.text(0.5, 0.5, 'No finite correlations', ha='center', va='center',
                     transform=corr_ax.transAxes)
    corr_ax.set_xlabel('Correlation Coefficient')
    corr_ax.set_ylabel('Frequency')
    corr_ax.set_title('Sample-wise Correlation Distribution')

    # 6. Reconstruction quality by magnitude
    orig_magnitudes = np.linalg.norm(originals, axis=1)
    sample_errors = np.mean(np.abs(originals - reconstructions), axis=1)

    axes[1,2].scatter(orig_magnitudes, sample_errors, alpha=0.6, s=10)
    axes[1,2].set_xlabel('Original Sample Magnitude')
    axes[1,2].set_ylabel('Mean Absolute Error')
    axes[1,2].set_title('Reconstruction Error vs Sample Magnitude')
    axes[1,2].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Print summary statistics
    print(f"\nReconstruction Quality Summary:")
    print(f"Mean Reconstruction Error: {np.mean(errors):.6f}")
    print(f"Std Reconstruction Error: {np.std(errors):.6f}")
    print(f"Mean Sample MSE: {np.mean(sample_mse):.6f}")
    if correlations:
        print(f"Mean Sample Correlation: {np.mean(correlations):.4f}")
        print(f"Min/Max Correlations: {np.min(correlations):.4f} / {np.max(correlations):.4f}")
    else:
        print("Mean Sample Correlation: n/a (no finite correlations)")

print("Detailed evaluation functions ready!")

# Run detailed evaluation on the trained model.
# `final_model` exists only if the training cell above ran in this session.
# At notebook top level, locals() is the global namespace, so this checks the
# notebook globals.
# NOTE(review): the "test" loader below wraps the whole dataset, including the
# rows the final model was trained on, so these metrics are not a held-out
# test score.
if 'final_model' in locals():
    print("\n" + "="*70)
    print("RUNNING DETAILED MSE LOSS AND RECONSTRUCTION ANALYSIS")
    print("="*70)
    
    # Reload the data from disk with the same arguments used for training
    # (load_protein_data("Proteins_layer47", normalize=True)).
    data_matrix, protein_names, original_L = load_protein_data("Proteins_layer47", normalize=True)
    
    # Evaluation loader over every row, in file order (shuffle=False)
    test_loader = DataLoader(TensorDataset(torch.FloatTensor(data_matrix)), 
                           batch_size=64, shuffle=False)
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # Evaluate reconstruction quality: dataset-wide MSE/MAE plus a subset of
    # inputs and reconstructions for plotting
    print("Evaluating reconstruction quality...")
    mse, mae, originals, reconstructions = evaluate_reconstruction_quality(
        final_model, test_loader, device)
    
    print(f"\nFinal Model Performance:")
    print(f"MSE Loss: {mse:.6f}")
    print(f"MAE Loss: {mae:.6f}")
    print(f"RMSE: {np.sqrt(mse):.6f}")
    
    # Plot detailed reconstruction analysis
    print("\nGenerating detailed reconstruction analysis plots...")
    plot_reconstruction_analysis(originals, reconstructions)
    
else:
    print("No trained model found. Please run the training first!")
