In [3]:
!pip install -q tensorflow_datasets umap-learn "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.htmlimport jax

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torch 2.4.1 requires nvidia-cudnn-cu12==9.1.0.70; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-cudnn-cu12 9.9.0.52 which is incompatible.
tensorflow 2.17.0 requires ml-dtypes<0.5.0,>=0.3.1, but you have ml-dtypes 0.5.1 which is incompatible.[0m[31m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.1.2[0m[39;49m -> [0m[32;49m25.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [None]:
# ==============================
# 1. IMPORTS
# ==============================
import jax
import jax.numpy as jnp
from jax import random, grad, jit, vmap
import numpy as np
from functools import partial
import time
import os
import urllib.request
import tarfile
import pickle
from datetime import datetime
from tqdm.auto import tqdm
from dataclasses import dataclass, field
from typing import Dict, Tuple, List, Any, NamedTuple

# For visualization and metrics
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import umap


In [None]:
# ==============================
# 2. CONFIGURATION CLASSES
# ==============================
@dataclass
class ViTConfig:
    """Configuration for Vision Transformer model"""
    # Image processing
    img_size: int = 32
    patch_size: int = 4
    
    # Model architecture
    num_classes: int = 10
    num_heads: int = 8
    num_layers: int = 6
    hidden_dim: int = 384
    mlp_dim: int = 1536
    dropout_rate: float = 0.1
    initializer_range: float = 0.02
    
    # Training parameters
    batch_size: int = 128
    num_epochs: int = 30
    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    warmup_steps: int = 500
    
    # Optimizer parameters
    beta1: float = 0.9
    beta2: float = 0.999
    eps: float = 1e-8
    
    # Other settings
    seed: int = 42
    
    @property
    def num_patches(self) -> int:
        """Calculate the number of patches based on image and patch size"""
        return (self.img_size // self.patch_size) ** 2


@dataclass
class TrainingState:
    """State for tracking training progress"""
    params: Any
    opt_state: Any
    train_losses: List[float] = field(default_factory=list)
    train_accs: List[float] = field(default_factory=list)
    eval_losses: List[float] = field(default_factory=list)
    eval_accs: List[float] = field(default_factory=list)
    step: int = 0
    epoch: int = 0
    best_accuracy: float = 0.0
    best_epoch: int = 0


In [None]:
# ==============================
# 3. UTILITY CLASSES
# ==============================
class Utils:
    """Utility functions for the Vision Transformer"""
    
    @staticmethod
    def to_device(x):
        """Helper function to move data to the selected device"""
        device = jax.devices()[0]
        return jax.device_put(x, device)
    
    @staticmethod
    def get_initializer(scale: float = 0.02):
        """Get weight initializer function"""
        return lambda key, shape, dtype=jnp.float32: random.normal(key, shape, dtype) * scale


class Metrics:
    """Class for computing and tracking metrics"""
    
    @staticmethod
    def cross_entropy_loss(logits: jnp.ndarray, labels: jnp.ndarray) -> float:
        """Compute cross-entropy loss"""
        one_hot_labels = jax.nn.one_hot(labels, 10)  # CIFAR-10 has 10 classes
        softmax_logits = jax.nn.log_softmax(logits)
        loss = -jnp.sum(one_hot_labels * softmax_logits) / labels.shape[0]
        return loss
    
    @staticmethod
    def accuracy(logits: jnp.ndarray, labels: jnp.ndarray) -> float:
        """Compute accuracy"""
        preds = jnp.argmax(logits, axis=-1)
        return jnp.mean(preds == labels)
    
    @staticmethod
    def get_predictions(logits: jnp.ndarray) -> jnp.ndarray:
        """Get class predictions from logits"""
        return jnp.argmax(logits, axis=-1)
    
    @staticmethod
    def compute_detailed_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> Dict:
        """Compute detailed metrics including confusion matrix"""
        # Compute confusion matrix
        cm = confusion_matrix(y_true, y_pred)
        
        # Compute precision, recall, and F1 score for each class
        precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average=None)
        
        # Compute macro-averaged metrics
        macro_precision, macro_recall, macro_f1, _ = precision_recall_fscore_support(
            y_true, y_pred, average='macro'
        )
        
        return {
            'confusion_matrix': cm,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'macro_precision': macro_precision,
            'macro_recall': macro_recall,
            'macro_f1': macro_f1
        }



In [None]:
# ==============================
# 4. PARAMETER INITIALIZATION
# ==============================
class ParameterInitializer:
    """Handles parameter initialization for the Vision Transformer"""
    
    def __init__(self, config: ViTConfig):
        self.config = config
        self.device = jax.devices()[0]
        
    def init_linear_params(self, key: jnp.ndarray, in_dim: int, out_dim: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Initialize linear layer parameters"""
        k1, k2 = random.split(key)
        weight_init = Utils.get_initializer(self.config.initializer_range)
        weight = weight_init(k1, (in_dim, out_dim))
        bias = jnp.zeros((out_dim,))
        return weight, bias
    
    def init_layer_norm_params(self, dim: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Initialize layer normalization parameters"""
        scale = jnp.ones((dim,))
        bias = jnp.zeros((dim,))
        return scale, bias
    
    def init_transformer_params(self, key: jnp.ndarray) -> Dict:
        """Initialize all transformer parameters"""
        keys = random.split(key, num=self.config.num_layers * 4 + 5)
        key_idx = 0
        
        # Projection layer (patch embedding)
        patch_projection_w, patch_projection_b = self.init_linear_params(
            keys[key_idx], 
            self.config.patch_size * self.config.patch_size * 3, 
            self.config.hidden_dim
        )
        key_idx += 1
        
        # Class token and position embeddings
        cls_token = random.normal(keys[key_idx], (1, self.config.hidden_dim)) * self.config.initializer_range
        key_idx += 1
        pos_embedding = random.normal(
            keys[key_idx], 
            (1, self.config.num_patches + 1, self.config.hidden_dim)
        ) * self.config.initializer_range
        key_idx += 1
        
        # Initialize transformer blocks
        encoder_blocks = []
        for _ in range(self.config.num_layers):
            block = self._init_transformer_block(keys[key_idx:key_idx+4])
            encoder_blocks.append(block)
            key_idx += 4
        
        # Final layer norm and classification head
        ln_final_scale, ln_final_bias = self.init_layer_norm_params(self.config.hidden_dim)
        head_w, head_b = self.init_linear_params(keys[key_idx], self.config.hidden_dim, self.config.num_classes)
        
        # Construct the parameter dictionary
        params = {
            'patch_projection_w': patch_projection_w,
            'patch_projection_b': patch_projection_b,
            'cls_token': cls_token,
            'pos_embedding': pos_embedding,
            'encoder_blocks': encoder_blocks,
            'ln_final_scale': ln_final_scale,
            'ln_final_bias': ln_final_bias,
            'head_w': head_w,
            'head_b': head_b,
        }
        
        # Move all parameters to device
        params = jax.tree_util.tree_map(lambda x: jax.device_put(x, self.device), params)
        
        return params
    
    def _init_transformer_block(self, keys: List[jnp.ndarray]) -> Dict:
        """Initialize a single transformer block"""
        # Layer norm 1
        ln1_scale, ln1_bias = self.init_layer_norm_params(self.config.hidden_dim)
        
        # Multi-head attention
        qkv_w, qkv_b = self.init_linear_params(keys[0], self.config.hidden_dim, 3 * self.config.hidden_dim)
        out_w, out_b = self.init_linear_params(keys[1], self.config.hidden_dim, self.config.hidden_dim)
        
        # Layer norm 2
        ln2_scale, ln2_bias = self.init_layer_norm_params(self.config.hidden_dim)
        
        # MLP
        mlp1_w, mlp1_b = self.init_linear_params(keys[2], self.config.hidden_dim, self.config.mlp_dim)
        mlp2_w, mlp2_b = self.init_linear_params(keys[3], self.config.mlp_dim, self.config.hidden_dim)
        
        return {
            'ln1_scale': ln1_scale,
            'ln1_bias': ln1_bias,
            'qkv_w': qkv_w,
            'qkv_b': qkv_b,
            'out_w': out_w,
            'out_b': out_b,
            'ln2_scale': ln2_scale,
            'ln2_bias': ln2_bias,
            'mlp1_w': mlp1_w,
            'mlp1_b': mlp1_b,
            'mlp2_w': mlp2_w,
            'mlp2_b': mlp2_b,
        }



In [None]:

# ==============================
# 5. MODEL COMPONENTS
# ==============================
class ModelLayers:
    """Basic neural network layers"""
    
    @staticmethod
    def linear(x: jnp.ndarray, w: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
        """Linear layer"""
        return jnp.dot(x, w) + b
    
    @staticmethod
    def layer_norm(x: jnp.ndarray, scale: jnp.ndarray, bias: jnp.ndarray, eps: float = 1e-6) -> jnp.ndarray:
        """Layer normalization"""
        mean = jnp.mean(x, axis=-1, keepdims=True)
        var = jnp.var(x, axis=-1, keepdims=True)
        return scale * (x - mean) / jnp.sqrt(var + eps) + bias
    
    @staticmethod
    def gelu(x: jnp.ndarray) -> jnp.ndarray:
        """GELU activation function"""
        return 0.5 * x * (1 + jnp.tanh(jnp.sqrt(2 / jnp.pi) * (x + 0.044715 * x**3)))
    
    @staticmethod
    def dropout(key: jnp.ndarray, x: jnp.ndarray, rate: float) -> jnp.ndarray:
        """Apply dropout"""
        if rate == 0.0:
            return x
        keep_prob = 1.0 - rate
        mask = random.bernoulli(key, keep_prob, x.shape)
        return x * mask / keep_prob



In [None]:
# ==============================
# 6. TRANSFORMER COMPONENTS
# ==============================
class TransformerBlocks:
    """Transformer-specific components"""
    
    def __init__(self, config: ViTConfig):
        self.config = config
        self.layers = ModelLayers()
    
    def multi_head_attention(
        self, 
        x: jnp.ndarray, 
        params: Dict, 
        key: jnp.ndarray, 
        training: bool = True
    ) -> jnp.ndarray:
        """Multi-head self-attention"""
        batch_size, seq_len, hidden_dim = x.shape
        head_dim = hidden_dim // self.config.num_heads
        
        # Project queries, keys, and values
        qkv = self.layers.linear(x, params['qkv_w'], params['qkv_b'])
        qkv = qkv.reshape(batch_size, seq_len, 3, self.config.num_heads, head_dim)
        qkv = jnp.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        # Scaled dot-product attention
        scale = jnp.sqrt(head_dim)
        attention_scores = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) / scale
        attention_probs = jax.nn.softmax(attention_scores)
        
        # Apply dropout to attention probabilities during training
        if training and self.config.dropout_rate > 0:
            attention_key = random.split(key, num=1)[0]
            attention_probs = self.layers.dropout(attention_key, attention_probs, self.config.dropout_rate)
        
        context = jnp.matmul(attention_probs, v)
        context = jnp.transpose(context, (0, 2, 1, 3))
        context = context.reshape(batch_size, seq_len, hidden_dim)
        
        # Output projection
        output = self.layers.linear(context, params['out_w'], params['out_b'])
        
        # Apply dropout during training
        if training and self.config.dropout_rate > 0:
            output_key = random.split(key, num=1)[0]
            output = self.layers.dropout(output_key, output, self.config.dropout_rate)
        
        return output
    
    def mlp_block(
        self, 
        x: jnp.ndarray, 
        params: Dict, 
        key: jnp.ndarray, 
        training: bool = True
    ) -> jnp.ndarray:
        """MLP block with GELU activation"""
        # First dense layer
        x = self.layers.linear(x, params['mlp1_w'], params['mlp1_b'])
        x = self.layers.gelu(x)
        
        # Apply dropout during training
        if training and self.config.dropout_rate > 0:
            mlp1_key = random.split(key, num=1)[0]
            x = self.layers.dropout(mlp1_key, x, self.config.dropout_rate)
        
        # Second dense layer
        x = self.layers.linear(x, params['mlp2_w'], params['mlp2_b'])
        
        # Apply dropout during training
        if training and self.config.dropout_rate > 0:
            mlp2_key = random.split(key, num=1)[0]
            x = self.layers.dropout(mlp2_key, x, self.config.dropout_rate)
        
        return x
    
    def encoder_block(
        self, 
        x: jnp.ndarray, 
        params: Dict, 
        key: jnp.ndarray, 
        training: bool = True
    ) -> jnp.ndarray:
        """Transformer encoder block"""
        # Layer norm 1
        norm1 = self.layers.layer_norm(x, params['ln1_scale'], params['ln1_bias'])
        
        # Multi-head attention with residual connection
        attn_key = random.split(key, num=1)[0]
        attn_output = self.multi_head_attention(norm1, params, attn_key, training)
        x = x + attn_output
        
        # Layer norm 2
        norm2 = self.layers.layer_norm(x, params['ln2_scale'], params['ln2_bias'])
        
        # MLP block with residual connection
        mlp_key = random.split(key, num=1)[0]
        mlp_output = self.mlp_block(norm2, params, mlp_key, training)
        x = x + mlp_output
        
        return x


In [None]:
# ==============================
# 7. MAIN MODEL CLASS
# ==============================
class VisionTransformer:
    """Vision Transformer model implementation"""
    
    def __init__(self, config: ViTConfig):
        self.config = config
        self.layers = ModelLayers()
        self.transformer = TransformerBlocks(config)
    
    def __call__(
        self, 
        params: Dict, 
        images: jnp.ndarray, 
        key: jnp.ndarray, 
        training: bool = True,
        return_features: bool = False
    ) -> Any:
        """Forward pass of Vision Transformer"""
        batch_size, height, width, channels = images.shape
        
        # Reshape images into patches
        patches = jnp.reshape(
            images, 
            (batch_size, height // self.config.patch_size, self.config.patch_size, 
             width // self.config.patch_size, self.config.patch_size, channels)
        )
        patches = jnp.transpose(patches, (0, 1, 3, 2, 4, 5))
        patches = jnp.reshape(
            patches, 
            (batch_size, self.config.num_patches, self.config.patch_size * self.config.patch_size * channels)
        )
        
        # Linear projection of flattened patches
        patch_embeddings = self.layers.linear(
            patches, 
            params['patch_projection_w'], 
            params['patch_projection_b']
        )
        
        # Add class token
        cls_tokens = jnp.broadcast_to(
            params['cls_token'], 
            (batch_size, 1, self.config.hidden_dim)
        )
        x = jnp.concatenate([cls_tokens, patch_embeddings], axis=1)
        
        # Add position embeddings
        x = x + params['pos_embedding']
        
        # Apply dropout to embeddings during training
        if training and self.config.dropout_rate > 0:
            embed_key = random.split(key, num=1)[0]
            x = self.layers.dropout(embed_key, x, self.config.dropout_rate)
        
        # Apply transformer encoder blocks
        for i, block_params in enumerate(params['encoder_blocks']):
            block_key = random.fold_in(key, i)
            x = self.transformer.encoder_block(x, block_params, block_key, training)
        
        # Apply final layer norm
        x = self.layers.layer_norm(x, params['ln_final_scale'], params['ln_final_bias'])
        
        # Use [CLS] token representation for classification
        cls_representation = x[:, 0]
        
        # Classification head
        logits = self.layers.linear(cls_representation, params['head_w'], params['head_b'])
        
        if return_features:
            return logits, cls_representation
        else:
            return logits

In [None]:
# ==============================
# 8. OPTIMIZER
# ==============================
class AdamWOptimizer:
    """AdamW optimizer implementation"""
    
    def __init__(self, config: ViTConfig):
        self.config = config
        
    def init_state(self, params: Dict) -> Dict:
        """Initialize optimizer state"""
        # Initialize first and second moments to zeros
        m = jax.tree_util.tree_map(jnp.zeros_like, params)
        v = jax.tree_util.tree_map(jnp.zeros_like, params)
        
        return {
            'm': m,
            'v': v,
            't': 0,
        }
    
    def update(
        self, 
        grads: Dict, 
        opt_state: Dict, 
        params: Dict, 
        learning_rate: float
    ) -> Tuple[Dict, Dict]:
        """Apply AdamW update step"""
        # Extract optimizer state
        m, v, t = opt_state['m'], opt_state['v'], opt_state['t']
        
        # Increment step
        t = t + 1
        
        # Compute bias correction terms
        bias_correction1 = 1 - self.config.beta1 ** t
        bias_correction2 = 1 - self.config.beta2 ** t
        
        # Define update function for a single parameter pair
        def update_momentum(m_param, grad):
            return self.config.beta1 * m_param + (1 - self.config.beta1) * grad
        
        def update_velocity(v_param, grad):
            return self.config.beta2 * v_param + (1 - self.config.beta2) * jnp.square(grad)
        
        def update_param(param, m_param, v_param):
            # Compute bias-corrected moment estimates
            m_hat = m_param / bias_correction1
            v_hat = v_param / bias_correction2
            
            # Apply weight decay
            param_with_decay = param - learning_rate * self.config.weight_decay * param
            
            # Update parameter
            return param_with_decay - learning_rate * m_hat / (jnp.sqrt(v_hat) + self.config.eps)
        
        # Update momentum terms
        new_m = jax.tree_util.tree_map(update_momentum, m, grads)
        
        # Update velocity terms  
        new_v = jax.tree_util.tree_map(update_velocity, v, grads)
        
        # Update parameters
        new_params = jax.tree_util.tree_map(update_param, params, new_m, new_v)
        
        # Update optimizer state
        new_opt_state = {
            'm': new_m,
            'v': new_v,
            't': t,
        }
        
        return new_params, new_opt_state



In [None]:

# ==============================
# 9. LEARNING RATE SCHEDULER
# ==============================
class LearningRateScheduler:
    """Learning rate scheduler with warmup and cosine decay"""
    
    def __init__(self, config: ViTConfig, total_steps: int):
        self.config = config
        self.total_steps = total_steps
        self.warmup_steps = config.warmup_steps
        self.peak_lr = config.learning_rate
        self.end_lr = 0.0
        
    def __call__(self, step: int) -> float:
        """Get learning rate for the current step"""
        # Linear warmup
        warmup_lr = self.peak_lr * jnp.minimum(1.0, step / self.warmup_steps)
        
        # Cosine decay
        decay_steps = self.total_steps - self.warmup_steps
        decay_factor = 0.5 * (1 + jnp.cos(
            jnp.pi * jnp.minimum(step - self.warmup_steps, decay_steps) / decay_steps
        ))
        cosine_lr = self.end_lr + (self.peak_lr - self.end_lr) * decay_factor
        
        # Use warmup for steps < warmup_steps, cosine decay afterward
        lr = jnp.where(step < self.warmup_steps, warmup_lr, cosine_lr)
        
        return lr


In [None]:
# ==============================
# 10. DATA LOADER
# ==============================
class DataLoader:
    """Handles data loading for CIFAR-10"""
    
    CIFAR10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer', 
                       'dog', 'frog', 'horse', 'ship', 'truck']
    
    def __init__(self, data_dir: str = "cifar10_data"):
        self.data_dir = data_dir
        self.url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
        
    def load_datasets(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Load CIFAR-10 dataset"""
        # Create data directory
        os.makedirs(self.data_dir, exist_ok=True)
        
        # Download if necessary
        tar_file_path = os.path.join(self.data_dir, "cifar-10-python.tar.gz")
        if not os.path.exists(tar_file_path):
            print(f"Downloading CIFAR-10 dataset from {self.url}")
            self._download_dataset(tar_file_path)
        
        # Extract if necessary
        extract_dir = os.path.join(self.data_dir, "cifar-10-batches-py")
        if not os.path.exists(extract_dir):
            print(f"Extracting dataset to {extract_dir}")
            self._extract_dataset(tar_file_path, self.data_dir)
        
        # Load training and test data
        train_data, train_labels = self._load_training_data(extract_dir)
        test_data, test_labels = self._load_test_data(extract_dir)
        
        # Preprocess data
        train_data = self._preprocess_data(train_data)
        test_data = self._preprocess_data(test_data)
        
        print(f"Loaded {len(train_data)} training samples and {len(test_data)} test samples")
        
        return train_data, train_labels, test_data, test_labels
    
    def _download_dataset(self, file_path: str):
        """Download CIFAR-10 dataset"""
        urllib.request.urlretrieve(self.url, file_path)
        print(f"Dataset downloaded to {file_path}")
    
    def _extract_dataset(self, tar_file_path: str, extract_to: str):
        """Extract CIFAR-10 dataset"""
        with tarfile.open(tar_file_path, 'r:gz') as tar:
            tar.extractall(path=extract_to)
        print("Dataset extracted")
    
    def _load_training_data(self, extract_dir: str) -> Tuple[np.ndarray, np.ndarray]:
        """Load training data from multiple batch files"""
        train_data = []
        train_labels = []
        
        # Load training batches
        for i in range(1, 6):
            batch_file = os.path.join(extract_dir, f'data_batch_{i}')
            with open(batch_file, 'rb') as f:
                batch_data = pickle.load(f, encoding='bytes')
                train_data.append(batch_data[b'data'])
                train_labels.extend(batch_data[b'labels'])
        
        # Concatenate all training data
        train_data = np.vstack(train_data)
        train_labels = np.array(train_labels)
        
        return train_data, train_labels
    
    def _load_test_data(self, extract_dir: str) -> Tuple[np.ndarray, np.ndarray]:
        """Load test data"""
        test_file = os.path.join(extract_dir, 'test_batch')
        with open(test_file, 'rb') as f:
            test_batch = pickle.load(f, encoding='bytes')
            test_data = test_batch[b'data']
            test_labels = np.array(test_batch[b'labels'])
        
        return test_data, test_labels
    
    def _preprocess_data(self, data: np.ndarray) -> np.ndarray:
        """Preprocess CIFAR-10 data"""
        # Reshape and normalize the data
        # CIFAR-10 data layout: [N, 3072] where 3072 = 3 x 32 x 32
        # We reshape to [N, 32, 32, 3] for image processing
        data = data.reshape(-1, 3, 32, 32)
        data = data.transpose(0, 2, 3, 1)
        data = data.astype(np.float32) / 255.0
        
        return data


In [None]:
# ==============================
# 11. VISUALIZER
# ==============================
class Visualizer:
    """Handles all visualization tasks"""
    
    def __init__(self, results_dir: str):
        self.results_dir = results_dir
        self.cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 
                               'dog', 'frog', 'horse', 'ship', 'truck']
    
    def plot_training_metrics(self, state: TrainingState):
        """Plot training and validation metrics over epochs"""
        epochs = range(1, len(state.train_losses) + 1)
        
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
        
        # Loss plot
        ax1.plot(epochs, state.train_losses, 'b-', label='Training Loss')
        ax1.plot(epochs, state.eval_losses, 'r-', label='Validation Loss')
        ax1.set_title('Training and Validation Loss')
        ax1.set_xlabel('Epochs')
        ax1.set_ylabel('Loss')
        ax1.legend()
        ax1.grid(True)
        
        # Accuracy plot
        ax2.plot(epochs, state.train_accs, 'b-', label='Training Accuracy')
        ax2.plot(epochs, state.eval_accs, 'r-', label='Validation Accuracy')
        ax2.set_title('Training and Validation Accuracy')
        ax2.set_xlabel('Epochs')
        ax2.set_ylabel('Accuracy')
        ax2.legend()
        ax2.grid(True)
        
        plt.tight_layout()
        plt.savefig(os.path.join(self.results_dir, 'training_metrics.png'), dpi=300, bbox_inches='tight')
        plt.close()
    
    def plot_learning_rate_schedule(self, lr_scheduler, total_steps: int):
        """Plot the learning rate schedule"""
        steps = range(total_steps)
        lrs = [float(lr_scheduler(step)) for step in steps]
        
        plt.figure(figsize=(10, 6))
        plt.plot(steps, lrs)
        plt.title('Learning Rate Schedule')
        plt.xlabel('Step')
        plt.ylabel('Learning Rate')
        plt.grid(True)
        plt.savefig(os.path.join(self.results_dir, 'lr_schedule.png'), dpi=300, bbox_inches='tight')
        plt.close()
    
    def plot_confusion_matrix(self, cm: np.ndarray):
        """Plot confusion matrix"""
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                    xticklabels=self.cifar10_classes, 
                    yticklabels=self.cifar10_classes)
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title('Confusion Matrix')
        plt.tight_layout()
        plt.savefig(os.path.join(self.results_dir, 'confusion_matrix.png'), dpi=300, bbox_inches='tight')
        plt.close()
    
    def plot_classification_metrics(self, metrics: Dict):
        """Plot precision, recall, and F1 scores"""
        metrics_df = pd.DataFrame({
            'Precision': metrics['precision'],
            'Recall': metrics['recall'],
            'F1 Score': metrics['f1']
        }, index=self.cifar10_classes)
        
        plt.figure(figsize=(12, 6))
        metrics_df.plot(kind='bar', figsize=(12, 6))
        plt.title('Precision, Recall, and F1 Score by Class')
        plt.xlabel('Class')
        plt.ylabel('Score')
        plt.xticks(rotation=45)
        plt.ylim(0, 1.0)
        plt.tight_layout()
        plt.legend(loc='best')
        plt.savefig(os.path.join(self.results_dir, 'precision_recall_f1.png'), dpi=300, bbox_inches='tight')
        plt.close()
    
    def plot_feature_space(self, features: np.ndarray, labels: np.ndarray, subset_size: int = 2000):
        """Visualize feature space using t-SNE, PCA, and UMAP"""
        # Use subset if needed
        if len(features) > subset_size:
            subset_indices = np.random.choice(len(features), subset_size, replace=False)
            features = features[subset_indices]
            labels = labels[subset_indices]
        
        # t-SNE visualization
        print("Computing t-SNE projection...")
        tsne = TSNE(n_components=2, random_state=42, n_jobs=-1)
        tsne_result = tsne.fit_transform(features)
        
        self._plot_2d_embedding(tsne_result, labels, 't-SNE Visualization of Feature Space', 'tsne_visualization.png')
        
        # PCA visualization
        print("Computing PCA projection...")
        pca = PCA(n_components=2, random_state=42)
        pca_result = pca.fit_transform(features)
        
        self._plot_2d_embedding(pca_result, labels, 'PCA Visualization of Feature Space', 'pca_visualization.png')
        
        # UMAP visualization
        print("Computing UMAP projection...")
        umap_model = umap.UMAP(n_components=2, random_state=42)
        umap_result = umap_model.fit_transform(features)
        
        self._plot_2d_embedding(umap_result, labels, 'UMAP Visualization of Feature Space', 'umap_visualization.png')
    
    def _plot_2d_embedding(self, embedding: np.ndarray, labels: np.ndarray, title: str, filename: str):
        """Plot 2D embedding with class colors"""
        plt.figure(figsize=(10, 8))
        
        # Create scatter plot for each class
        for i, class_name in enumerate(self.cifar10_classes):
            indices = labels == i
            plt.scatter(embedding[indices, 0], embedding[indices, 1], 
                       label=class_name, alpha=0.7, s=20)
        
        plt.title(title)
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout()
        plt.savefig(os.path.join(self.results_dir, filename), dpi=300, bbox_inches='tight')
        plt.close()
    
    def plot_sample_predictions(
        self, 
        vit: VisionTransformer,
        params: Dict,
        test_data: np.ndarray,
        test_labels: np.ndarray,
        num_samples: int = 10
    ):
        """Visualize sample predictions"""
        # Get random subset of test samples
        indices = np.random.choice(len(test_data), num_samples, replace=False)
        images = test_data[indices]
        labels = test_labels[indices]
        
        # Move to device and get predictions
        device = jax.devices()[0]
        images_device = jax.device_put(images, device)
        labels_device = jax.device_put(labels, device)
        
        # Get predictions
        eval_key = jax.device_put(random.PRNGKey(0), device)
        logits = vit(params, images_device, eval_key, training=False)
        preds = Metrics.get_predictions(logits)
        
        # Move back to CPU for visualization
        images_cpu = np.array(images)
        preds_cpu = np.array(preds)
        labels_cpu = np.array(labels)
        
        # Create figure
        fig, axes = plt.subplots(2, 5, figsize=(15, 8))
        axes = axes.flatten()
        
        for i, (image, label, pred) in enumerate(zip(images_cpu, labels_cpu, preds_cpu)):
            ax = axes[i]
            ax.imshow(image)
            ax.set_title(f"True: {self.cifar10_classes[label]}\nPred: {self.cifar10_classes[pred]}")
            ax.axis('off')
            
            # Highlight correct/incorrect predictions
            if label == pred:
                # Green border for correct predictions
                for spine in ax.spines.values():
                    spine.set_edgecolor('green')
                    spine.set_linewidth(3)
                    spine.set_visible(True)
            else:
                # Red border for incorrect predictions
                for spine in ax.spines.values():
                    spine.set_edgecolor('red')
                    spine.set_linewidth(3)
                    spine.set_visible(True)
        
        plt.tight_layout()
        plt.savefig(os.path.join(self.results_dir, 'sample_predictions.png'), dpi=300, bbox_inches='tight')
        plt.close()

In [None]:
# ==============================
# 12. TRAINER CLASS
# ==============================
class Trainer:
    """Handles the training loop and evaluation for Vision Transformer"""
    
    def __init__(self, config: ViTConfig, results_dir: str = "vit_results"):
        self.config = config
        self.results_dir = results_dir
        self.device = jax.devices()[0]
        
        # Initialize components
        self.vit = VisionTransformer(config)
        self.optimizer = AdamWOptimizer(config)
        self.initializer = ParameterInitializer(config)
        
        # Set up JAX key
        self.key = jax.device_put(random.PRNGKey(config.seed), self.device)
        
        # Initialize model and optimizer
        init_key, self.key = random.split(self.key)
        self.params = self.initializer.init_transformer_params(init_key)
        self.opt_state = self.optimizer.init_state(self.params)
        
        # Initialize learning rate scheduler (will be set in train method)
        self.lr_scheduler = None
        
        # Initialize state tracking
        self.state = TrainingState(
            params=self.params,
            opt_state=self.opt_state
        )
        
        # Create JIT-compiled functions
        self._compile_functions()
    
    def _compile_functions(self):
        """Compile JAX functions for better performance"""
        self.jit_train_step = jax.jit(self._train_step)
        self.jit_eval_step = jax.jit(self._eval_step)
    
    def _train_step(
        self, 
        params: Dict, 
        batch: Tuple[jnp.ndarray, jnp.ndarray], 
        dropout_key: jnp.ndarray, 
        opt_state: Dict, 
        learning_rate: float
    ) -> Tuple[Dict, float, float, Dict]:
        """Single training step"""
        images, labels = batch
        images = jax.device_put(images, self.device)
        labels = jax.device_put(labels, self.device)
        
        def loss_fn(params_weights):
            logits = self.vit(params_weights, images, dropout_key, training=True)
            loss = Metrics.cross_entropy_loss(logits, labels)
            return loss, logits
        
        (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
        
        # Apply optimizer
        new_params, new_opt_state = self.optimizer.update(
            grads, opt_state, params, learning_rate
        )
        
        acc = Metrics.accuracy(logits, labels)
        
        return new_params, loss, acc, new_opt_state
    
    def _eval_step(
        self, 
        params: Dict, 
        batch: Tuple[jnp.ndarray, jnp.ndarray], 
        return_features: bool = False
    ) -> Any:
        """Single evaluation step"""
        images, labels = batch
        images = jax.device_put(images, self.device)
        labels = jax.device_put(labels, self.device)
        
        # No dropout during evaluation
        eval_key = jax.device_put(random.PRNGKey(0), self.device)
        
        if return_features:
            logits, features = self.vit(params, images, eval_key, training=False, return_features=True)
        else:
            logits = self.vit(params, images, eval_key, training=False)
        
        loss = Metrics.cross_entropy_loss(logits, labels)
        acc = Metrics.accuracy(logits, labels)
        preds = Metrics.get_predictions(logits)
        
        if return_features:
            return loss, acc, preds, features
        else:
            return loss, acc, preds
    
    def train(
        self, 
        train_data: np.ndarray, 
        train_labels: np.ndarray, 
        test_data: np.ndarray, 
        test_labels: np.ndarray
    ) -> Dict:
        """Main training loop"""
        # Initialize learning rate scheduler
        total_steps = self.config.num_epochs * (len(train_data) // self.config.batch_size)
        self.lr_scheduler = LearningRateScheduler(self.config, total_steps)
        
        # Create results directory
        os.makedirs(self.results_dir, exist_ok=True)
        
        print(f"Starting training on {self.device}")
        print(f"Total steps: {total_steps}")
        
        # Training loop
        for epoch in range(self.config.num_epochs):
            self.state.epoch = epoch
            epoch_start_time = time.time()
            
            # Training epoch
            self._train_epoch(train_data, train_labels)
            
            # Evaluation epoch
            eval_metrics = self._eval_epoch(test_data, test_labels)
            
            # Update state
            self.state.train_losses.append(np.mean(self.epoch_train_losses))
            self.state.train_accs.append(np.mean(self.epoch_train_accs))
            self.state.eval_losses.append(eval_metrics['loss'])
            self.state.eval_accs.append(eval_metrics['accuracy'])
            
            # Check for best model
            if eval_metrics['accuracy'] > self.state.best_accuracy:
                self.state.best_accuracy = eval_metrics['accuracy']
                self.state.best_epoch = epoch
            
            epoch_time = time.time() - epoch_start_time
            
            print(f"Epoch {epoch+1}/{self.config.num_epochs} completed in {epoch_time:.2f}s")
            print(f"  Train Loss: {self.state.train_losses[-1]:.4f}, Train Accuracy: {self.state.train_accs[-1]:.4f}")
            print(f"  Eval Loss: {self.state.eval_losses[-1]:.4f}, Eval Accuracy: {self.state.eval_accs[-1]:.4f}")
            print(f"  Best Accuracy: {self.state.best_accuracy:.4f} at epoch {self.state.best_epoch+1}")
            
            # Save progress plots periodically
            if (epoch + 1) % 5 == 0 or epoch == self.config.num_epochs - 1:
                viz = Visualizer(self.results_dir)
                viz.plot_training_metrics(self.state)
        
        # Final evaluation with features
        final_metrics = self._final_evaluation(test_data, test_labels)
        
        return final_metrics
    
    def _train_epoch(self, train_data: np.ndarray, train_labels: np.ndarray):
        """Train for one epoch"""
        self.epoch_train_losses = []
        self.epoch_train_accs = []
        
        # Create batch iterator with progress bar
        num_batches = len(train_data) // self.config.batch_size
        batches = self._iterate_batches(train_data, train_labels, shuffle=True)
        progress_bar = tqdm(batches, total=num_batches, 
                           desc=f"Epoch {self.state.epoch+1}/{self.config.num_epochs} [Train]",
                           leave=True)
        
        for batch_idx, batch in enumerate(progress_bar):
            # Generate new dropout key
            dropout_key, self.key = random.split(self.key)
            dropout_key = jax.device_put(dropout_key, self.device)
            
            # Get current learning rate
            current_lr = float(self.lr_scheduler(self.state.step))
            
            # Update parameters
            self.state.params, loss, acc, self.state.opt_state = self.jit_train_step(
                self.state.params, batch, dropout_key, self.state.opt_state, current_lr
            )
            
            self.epoch_train_losses.append(float(loss))
            self.epoch_train_accs.append(float(acc))
            
            # Update progress bar
            progress_bar.set_postfix({
                'loss': f"{float(loss):.4f}",
                'acc': f"{float(acc):.4f}",
                'lr': f"{current_lr:.6f}"
            })
            
            self.state.step += 1
    
    def _eval_epoch(self, test_data: np.ndarray, test_labels: np.ndarray) -> Dict:
        """Evaluate for one epoch"""
        epoch_eval_losses = []
        epoch_eval_accs = []
        all_preds = []
        all_labels = []
        
        # Create test batch iterator with progress bar
        num_test_batches = len(test_data) // self.config.batch_size
        if len(test_data) % self.config.batch_size != 0:
            num_test_batches += 1
            
        batches = self._iterate_batches(test_data, test_labels, shuffle=False)
        progress_bar = tqdm(batches, total=num_test_batches, 
                           desc=f"Epoch {self.state.epoch+1}/{self.config.num_epochs} [Eval]",
                           leave=True)
        
        for batch in progress_bar:
            loss, acc, preds = self.jit_eval_step(self.state.params, batch)
            epoch_eval_losses.append(float(loss))
            epoch_eval_accs.append(float(acc))
            
            # Collect predictions and labels
            all_preds.append(np.array(preds))
            all_labels.append(batch[1])
            
            # Update progress bar
            progress_bar.set_postfix({
                'loss': f"{float(loss):.4f}",
                'acc': f"{float(acc):.4f}"
            })
        
        # Compute average metrics
        return {
            'loss': np.mean(epoch_eval_losses),
            'accuracy': np.mean(epoch_eval_accs),
            'predictions': np.concatenate(all_preds),
            'labels': np.concatenate(all_labels)
        }
    
    def _final_evaluation(self, test_data: np.ndarray, test_labels: np.ndarray) -> Dict:
        """Final comprehensive evaluation"""
        print("Running final evaluation with feature extraction...")
        
        all_features = []
        all_preds = []
        all_labels = []
        
        # Create progress bar for final evaluation
        num_test_batches = len(test_data) // self.config.batch_size
        if len(test_data) % self.config.batch_size != 0:
            num_test_batches += 1
            
        batches = self._iterate_batches(test_data, test_labels, shuffle=False)
        progress_bar = tqdm(batches, total=num_test_batches, 
                           desc="Final Evaluation",
                           leave=True)
        
        # Collect features and predictions
        for batch in progress_bar:
            loss, acc, preds, features = self._eval_step(self.state.params, batch, return_features=True)
            all_features.append(np.array(features))
            all_preds.append(np.array(preds))
            all_labels.append(np.array(batch[1]))
            
            # Update progress bar
            progress_bar.set_postfix({
                'loss': f"{float(loss):.4f}",
                'acc': f"{float(acc):.4f}"
            })
        
        # Concatenate all results
        features = np.concatenate(all_features)
        predictions = np.concatenate(all_preds)
        labels = np.concatenate(all_labels)
        
        # Compute detailed metrics
        metrics = Metrics.compute_detailed_metrics(labels, predictions)
        
        # Generate visualizations
        viz = Visualizer(self.results_dir)
        viz.plot_confusion_matrix(metrics['confusion_matrix'])
        viz.plot_classification_metrics(metrics)
        viz.plot_feature_space(features, labels, subset_size=2000)
        
        return {
            'metrics': metrics,
            'features': features,
            'predictions': predictions,
            'labels': labels
        }
    
    def _iterate_batches(
        self, 
        images: np.ndarray, 
        labels: np.ndarray, 
        shuffle: bool = False
    ):
        """Iterator for creating batches"""
        num_samples = len(images)
        indices = np.arange(num_samples)
        
        if shuffle:
            np.random.shuffle(indices)
        
        for start_idx in range(0, num_samples, self.config.batch_size):
            end_idx = min(start_idx + self.config.batch_size, num_samples)
            batch_indices = indices[start_idx:end_idx]
            yield images[batch_indices], labels[batch_indices]



In [8]:
# ==============================
# 13. MAIN FUNCTION
# ==============================
def main():
    """Main entry point for training"""
    
    # Set up configuration - modify these values to customize training
    config = ViTConfig(
        # Model architecture
        img_size=32,
        patch_size=4,
        num_classes=10,
        num_heads=8,
        num_layers=6,
        hidden_dim=384,
        mlp_dim=1536,
        dropout_rate=0.1,
        
        # Training parameters
        batch_size=128,
        num_epochs=30,
        learning_rate=3e-4,
        weight_decay=0.01,
        warmup_steps=500,
        
        # Other settings
        seed=42
    )
    
    # Create unique results directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_dir = os.path.join("vit_results", f"run_{timestamp}")
    os.makedirs(results_dir, exist_ok=True)
    
    # Print startup information
    print("\n" + "="*60)
    print("VISION TRANSFORMER TRAINING")
    print("="*60)
    print("Configuration:")
    print(f"  - Model: {config.num_layers} layers × {config.hidden_dim} dim")
    print(f"  - Training: {config.num_epochs} epochs × {config.batch_size} batch size")
    print(f"  - Learning rate: {config.learning_rate} (warmup: {config.warmup_steps} steps)")
    print(f"  - Device: {jax.devices()[0]}")
    print(f"  - Results: {results_dir}")
    print("="*60 + "\n")
    
    # Initialize all components
    data_loader = DataLoader(data_dir="cifar10_data")
    trainer = Trainer(config, results_dir)
    visualizer = Visualizer(results_dir)
    
    # Save configuration to file
    config_path = os.path.join(results_dir, 'config.txt')
    with open(config_path, 'w') as f:
        f.write("Vision Transformer Configuration\n")
        f.write("="*50 + "\n\n")
        for key, value in config.__dict__.items():
            f.write(f"{key}: {value}\n")
    
    try:
        # Load and prepare data
        print("Loading CIFAR-10 dataset...")
        train_data, train_labels, test_data, test_labels = data_loader.load_datasets()
        
        # Create learning rate visualization
        print("Creating visualizations...")
        total_steps = config.num_epochs * (len(train_data) // config.batch_size)
        visualizer.plot_learning_rate_schedule(trainer.lr_scheduler or 
                                             LearningRateScheduler(config, total_steps), 
                                             total_steps)
        
        # Run training
        print("\nTraining model...")
        start_time = time.time()
        
        final_metrics = trainer.train(train_data, train_labels, test_data, test_labels)
        
        # Create final visualizations
        print("\nGenerating final visualizations...")
        visualizer.plot_sample_predictions(
            trainer.vit,
            trainer.state.params,
            test_data,
            test_labels
        )
        
        # Calculate total training time
        total_time = time.time() - start_time
        hours = int(total_time // 3600)
        minutes = int((total_time % 3600) // 60)
        seconds = int(total_time % 60)
        
        # Save final results summary
        summary_path = os.path.join(results_dir, 'training_summary.txt')
        with open(summary_path, 'w') as f:
            f.write("Vision Transformer Training Summary\n")
            f.write("="*50 + "\n\n")
            f.write(f"Training Time: {hours:02d}:{minutes:02d}:{seconds:02d}\n")
            f.write(f"Best Validation Accuracy: {trainer.state.best_accuracy:.4f}\n")
            f.write(f"Best Epoch: {trainer.state.best_epoch + 1}\n")
            f.write(f"Final Macro F1 Score: {final_metrics['metrics']['macro_f1']:.4f}\n")
            f.write(f"Final Macro Precision: {final_metrics['metrics']['macro_precision']:.4f}\n")
            f.write(f"Final Macro Recall: {final_metrics['metrics']['macro_recall']:.4f}\n")
            f.write(f"\nResults Directory: {results_dir}\n")
        
        # Save hardware summary
        hardware_summary_path = os.path.join(results_dir, 'hardware_summary.txt')
        with open(hardware_summary_path, 'w') as f:
            f.write("Hardware Summary\n")
            f.write("="*50 + "\n\n")
            f.write(f"JAX version: {jax.__version__}\n")
            f.write(f"Device used: {jax.devices()[0]}\n")
            f.write(f"Number of devices: {jax.device_count()}\n")
            try:
                f.write(f"Device type: {jax.devices()[0].device_kind}\n")
            except:
                f.write("Device type: Unknown\n")
        
        # Print summary
        print("\n" + "="*60)
        print("TRAINING COMPLETED SUCCESSFULLY!")
        print("="*60)
        print(f"Total time: {hours:02d}:{minutes:02d}:{seconds:02d}")
        print(f"Best accuracy: {trainer.state.best_accuracy:.4f}")
        print(f"Final F1 score: {final_metrics['metrics']['macro_f1']:.4f}")
        print(f"Results in: {results_dir}")
        print("="*60)
        
        # List generated files
        print("\nGenerated files:")
        print("  Visualizations:")
        print("    - training_metrics.png - Learning curves")
        print("    - lr_schedule.png - Learning rate schedule")
        print("    - confusion_matrix.png - Confusion matrix")
        print("    - precision_recall_f1.png - Classification metrics")
        print("    - tsne_visualization.png - t-SNE feature visualization")
        print("    - pca_visualization.png - PCA feature visualization")
        print("    - umap_visualization.png - UMAP feature visualization")
        print("    - sample_predictions.png - Sample predictions")
        print("  Text files:")
        print("    - config.txt - Training configuration")
        print("    - training_summary.txt - Final results summary")
        print("    - hardware_summary.txt - Hardware information")
        print("\n")
        
    except Exception as e:
        print(f"\nTraining failed with error: {e}")
        print(f"Partial results saved to: {results_dir}")
        import traceback
        traceback.print_exc()
        return 1
    
    return 0


if __name__ == "__main__":
    exit_code = main()


VISION TRANSFORMER TRAINING
Configuration:
  - Model: 6 layers × 384 dim
  - Training: 30 epochs × 128 batch size
  - Learning rate: 0.0003 (warmup: 500 steps)
  - Device: cuda:0
  - Results: vit_results/run_20250506_204653

Loading CIFAR-10 dataset...
Loaded 50000 training samples and 10000 test samples
Creating visualizations...

Training model...
Starting training on cuda:0
Total steps: 11700


Epoch 1/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 1/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 1/30 completed in 317.00s
  Train Loss: 1.8906, Train Accuracy: 0.2865
  Eval Loss: 1.5669, Eval Accuracy: 0.4156
  Best Accuracy: 0.4156 at epoch 1


Epoch 2/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 2/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 2/30 completed in 7.58s
  Train Loss: 1.4729, Train Accuracy: 0.4619
  Eval Loss: 1.3942, Eval Accuracy: 0.4877
  Best Accuracy: 0.4877 at epoch 2


Epoch 3/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 3/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 3/30 completed in 7.60s
  Train Loss: 1.2997, Train Accuracy: 0.5307
  Eval Loss: 1.2964, Eval Accuracy: 0.5303
  Best Accuracy: 0.5303 at epoch 3


Epoch 4/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 4/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 4/30 completed in 7.66s
  Train Loss: 1.2284, Train Accuracy: 0.5565
  Eval Loss: 1.2476, Eval Accuracy: 0.5485
  Best Accuracy: 0.5485 at epoch 4


Epoch 5/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 5/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 5/30 completed in 7.65s
  Train Loss: 1.1712, Train Accuracy: 0.5793
  Eval Loss: 1.2490, Eval Accuracy: 0.5521
  Best Accuracy: 0.5521 at epoch 5


Epoch 6/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 6/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 6/30 completed in 7.65s
  Train Loss: 1.1309, Train Accuracy: 0.5925
  Eval Loss: 1.1386, Eval Accuracy: 0.5877
  Best Accuracy: 0.5877 at epoch 6


Epoch 7/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 7/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 7/30 completed in 7.69s
  Train Loss: 1.0873, Train Accuracy: 0.6106
  Eval Loss: 1.1367, Eval Accuracy: 0.5965
  Best Accuracy: 0.5965 at epoch 7


Epoch 8/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 8/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 8/30 completed in 7.68s
  Train Loss: 1.0537, Train Accuracy: 0.6237
  Eval Loss: 1.0884, Eval Accuracy: 0.6097
  Best Accuracy: 0.6097 at epoch 8


Epoch 9/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 9/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 9/30 completed in 7.74s
  Train Loss: 1.0173, Train Accuracy: 0.6357
  Eval Loss: 1.0940, Eval Accuracy: 0.6042
  Best Accuracy: 0.6097 at epoch 8


Epoch 10/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 10/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 10/30 completed in 7.72s
  Train Loss: 0.9809, Train Accuracy: 0.6471
  Eval Loss: 1.0874, Eval Accuracy: 0.6053
  Best Accuracy: 0.6097 at epoch 8


Epoch 11/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 11/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 11/30 completed in 7.65s
  Train Loss: 0.9461, Train Accuracy: 0.6606
  Eval Loss: 1.0798, Eval Accuracy: 0.6165
  Best Accuracy: 0.6165 at epoch 11


Epoch 12/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 12/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 12/30 completed in 7.65s
  Train Loss: 0.9102, Train Accuracy: 0.6718
  Eval Loss: 1.0379, Eval Accuracy: 0.6284
  Best Accuracy: 0.6284 at epoch 12


Epoch 13/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 13/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 13/30 completed in 7.72s
  Train Loss: 0.8685, Train Accuracy: 0.6885
  Eval Loss: 1.0481, Eval Accuracy: 0.6295
  Best Accuracy: 0.6295 at epoch 13


Epoch 14/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 14/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 14/30 completed in 7.65s
  Train Loss: 0.8324, Train Accuracy: 0.7032
  Eval Loss: 1.0247, Eval Accuracy: 0.6372
  Best Accuracy: 0.6372 at epoch 14


Epoch 15/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 15/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 15/30 completed in 7.68s
  Train Loss: 0.7926, Train Accuracy: 0.7170
  Eval Loss: 1.0353, Eval Accuracy: 0.6448
  Best Accuracy: 0.6448 at epoch 15


Epoch 16/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 16/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 16/30 completed in 7.73s
  Train Loss: 0.7434, Train Accuracy: 0.7332
  Eval Loss: 1.0492, Eval Accuracy: 0.6443
  Best Accuracy: 0.6448 at epoch 15


Epoch 17/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 17/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 17/30 completed in 7.72s
  Train Loss: 0.7024, Train Accuracy: 0.7471
  Eval Loss: 1.0911, Eval Accuracy: 0.6354
  Best Accuracy: 0.6448 at epoch 15


Epoch 18/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 18/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 18/30 completed in 7.63s
  Train Loss: 0.6553, Train Accuracy: 0.7660
  Eval Loss: 1.0523, Eval Accuracy: 0.6440
  Best Accuracy: 0.6448 at epoch 15


Epoch 19/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 19/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 19/30 completed in 7.63s
  Train Loss: 0.6047, Train Accuracy: 0.7827
  Eval Loss: 1.0732, Eval Accuracy: 0.6479
  Best Accuracy: 0.6479 at epoch 19


Epoch 20/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 20/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 20/30 completed in 7.74s
  Train Loss: 0.5570, Train Accuracy: 0.8005
  Eval Loss: 1.0925, Eval Accuracy: 0.6495
  Best Accuracy: 0.6495 at epoch 20


Epoch 21/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 21/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 21/30 completed in 7.67s
  Train Loss: 0.5072, Train Accuracy: 0.8183
  Eval Loss: 1.1373, Eval Accuracy: 0.6521
  Best Accuracy: 0.6521 at epoch 21


Epoch 22/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 22/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 22/30 completed in 7.78s
  Train Loss: 0.4608, Train Accuracy: 0.8357
  Eval Loss: 1.1836, Eval Accuracy: 0.6474
  Best Accuracy: 0.6521 at epoch 21


Epoch 23/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 23/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 23/30 completed in 7.80s
  Train Loss: 0.4117, Train Accuracy: 0.8528
  Eval Loss: 1.2176, Eval Accuracy: 0.6541
  Best Accuracy: 0.6541 at epoch 23


Epoch 24/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 24/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 24/30 completed in 7.78s
  Train Loss: 0.3729, Train Accuracy: 0.8658
  Eval Loss: 1.2480, Eval Accuracy: 0.6484
  Best Accuracy: 0.6541 at epoch 23


Epoch 25/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 25/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 25/30 completed in 7.68s
  Train Loss: 0.3369, Train Accuracy: 0.8791
  Eval Loss: 1.2915, Eval Accuracy: 0.6521
  Best Accuracy: 0.6541 at epoch 23


Epoch 26/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 26/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 26/30 completed in 7.64s
  Train Loss: 0.3057, Train Accuracy: 0.8918
  Eval Loss: 1.3325, Eval Accuracy: 0.6530
  Best Accuracy: 0.6541 at epoch 23


Epoch 27/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 27/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 27/30 completed in 7.64s
  Train Loss: 0.2849, Train Accuracy: 0.8985
  Eval Loss: 1.3331, Eval Accuracy: 0.6560
  Best Accuracy: 0.6560 at epoch 27


Epoch 28/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 28/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 28/30 completed in 7.65s
  Train Loss: 0.2713, Train Accuracy: 0.9052
  Eval Loss: 1.3497, Eval Accuracy: 0.6521
  Best Accuracy: 0.6560 at epoch 27


Epoch 29/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 29/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 29/30 completed in 7.88s
  Train Loss: 0.2635, Train Accuracy: 0.9084
  Eval Loss: 1.3557, Eval Accuracy: 0.6522
  Best Accuracy: 0.6560 at epoch 27


Epoch 30/30 [Train]:   0%|          | 0/390 [00:00<?, ?it/s]

Epoch 30/30 [Eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 30/30 completed in 7.86s
  Train Loss: 0.2574, Train Accuracy: 0.9096
  Eval Loss: 1.3568, Eval Accuracy: 0.6532
  Best Accuracy: 0.6560 at epoch 27
Running final evaluation with feature extraction...


Final Evaluation:   0%|          | 0/79 [00:00<?, ?it/s]

Computing t-SNE projection...
Computing PCA projection...
Computing UMAP projection...


  warn(



Generating final visualizations...

TRAINING COMPLETED SUCCESSFULLY!
Total time: 00:10:50
Best accuracy: 0.6560
Final F1 score: 0.6515
Results in: vit_results/run_20250506_204653

Generated files:
  Visualizations:
    - training_metrics.png - Learning curves
    - lr_schedule.png - Learning rate schedule
    - confusion_matrix.png - Confusion matrix
    - precision_recall_f1.png - Classification metrics
    - tsne_visualization.png - t-SNE feature visualization
    - pca_visualization.png - PCA feature visualization
    - umap_visualization.png - UMAP feature visualization
    - sample_predictions.png - Sample predictions
  Text files:
    - config.txt - Training configuration
    - training_summary.txt - Final results summary
    - hardware_summary.txt - Hardware information




<Figure size 1200x600 with 0 Axes>