In [2]:
!pip install -q tensorflow_datasets umap-learn "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.htmlimport jax

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torch 2.4.1 requires nvidia-cudnn-cu12==9.1.0.70; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-cudnn-cu12 9.9.0.52 which is incompatible.
tensorflow 2.17.0 requires ml-dtypes<0.5.0,>=0.3.1, but you have ml-dtypes 0.5.1 which is incompatible.
tensorflow 2.17.0 requires numpy<2.0.0,>=1.23.5; python_version <= "3.11", but you have numpy 2.0.2 which is incompatible.[0m[31m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.1.2[0m[39;49m -> [0m[32;49m25.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [12]:
# ===============================================
# ENHANCED VISION TRANSFORMER WITH REGULARIZATION
# ===============================================

import jax
import jax.numpy as jnp
from jax import random, grad, jit, vmap
import numpy as np
from functools import partial
import time
import os
import urllib.request
import tarfile
import pickle
from datetime import datetime
from tqdm.auto import tqdm
from dataclasses import dataclass, field
from typing import Dict, Tuple, List, Any, NamedTuple

# For visualization and metrics
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import umap

# ==============================
# 1. ENHANCED CONFIGURATION 
# ==============================

@dataclass
class ViTConfig:
    """Enhanced Configuration for Vision Transformer with Regularization"""
    # Image processing
    img_size: int = 32
    patch_size: int = 4
    
    # Model architecture
    num_classes: int = 10
    num_heads: int = 8
    num_layers: int = 6
    hidden_dim: int = 384
    mlp_dim: int = 1536
    
    # Enhanced Dropout Configuration
    dropout_rate: float = 0.2
    attention_dropout_rate: float = 0.3
    projection_dropout_rate: float = 0.3
    path_dropout_rate: float = 0.2  # For DroidOut/StochasticDepth
    
    # Batch Normalization
    use_batch_norm: bool = True
    bn_momentum: float = 0.9
    bn_eps: float = 1e-5
    
    # Data Augmentation
    use_data_augmentation: bool = True
    augment_prob: float = 0.4
    mixup_alpha: float = 0.8
    cutmix_alpha: float = 1.0
    cutout_size: int = 8

    # enable_random_flip: bool = True
    # enable_color_jitter: bool = True
    # enable_gaussian_noise: bool = True
    # enable_mixup: bool = True
    # enable_cutout: bool = True
    
    # Label Smoothing
    label_smoothing: float = 0.1
    
    # Weight Decay
    weight_decay: float = 0.05  # Increased for better regularization
    
    # Training parameters
    batch_size: int = 128
    num_epochs: int = 50  # Increased for better convergence
    learning_rate: float = 5e-4
    warmup_steps: int = 1000
    early_stopping_patience: int = 5 
    
    # Optimizer parameters
    beta1: float = 0.9
    beta2: float = 0.999
    eps: float = 1e-8
    
    # EMA (Exponential Moving Average)
    use_ema: bool = True
    ema_decay: float = 0.9999
    
    # Other settings
    seed: int = 42
    initializer_range: float = 0.02
    
    @property
    def num_patches(self) -> int:
        """Calculate the number of patches based on image and patch size"""
        return (self.img_size // self.patch_size) ** 2


@dataclass
class TrainingState:
    """State for tracking training progress"""
    params: Any
    opt_state: Any
    train_losses: List[float] = field(default_factory=list)
    train_accs: List[float] = field(default_factory=list)
    eval_losses: List[float] = field(default_factory=list)
    eval_accs: List[float] = field(default_factory=list)
    step: int = 0
    epoch: int = 0
    best_accuracy: float = 0.0
    best_epoch: int = 0


# ===================================
# 2. UTILITY CLASSES (from original)
# ===================================
class Utils:
    """Utility functions for the Vision Transformer"""
    
    @staticmethod
    def to_device(x):
        """Helper function to move data to the selected device"""
        device = jax.devices()[0]
        return jax.device_put(x, device)
    
    @staticmethod
    def get_initializer(scale: float = 0.02):
        """Get weight initializer function"""
        return lambda key, shape, dtype=jnp.float32: random.normal(key, shape, dtype) * scale


class Metrics:
    """Class for computing and tracking metrics"""
    
    @staticmethod
    def cross_entropy_loss(logits: jnp.ndarray, labels: jnp.ndarray) -> float:
        """Compute cross-entropy loss"""
        one_hot_labels = jax.nn.one_hot(labels, 10)  # CIFAR-10 has 10 classes
        softmax_logits = jax.nn.log_softmax(logits)
        loss = -jnp.sum(one_hot_labels * softmax_logits) / labels.shape[0]
        return loss
    
    @staticmethod
    def accuracy(logits: jnp.ndarray, labels: jnp.ndarray) -> float:
        """Compute accuracy"""
        preds = jnp.argmax(logits, axis=-1)
        return jnp.mean(preds == labels)
    
    @staticmethod
    def get_predictions(logits: jnp.ndarray) -> jnp.ndarray:
        """Get class predictions from logits"""
        return jnp.argmax(logits, axis=-1)
    
    @staticmethod
    def compute_detailed_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> Dict:
        """Compute detailed metrics including confusion matrix"""
        # Compute confusion matrix
        cm = confusion_matrix(y_true, y_pred)
        
        # Compute precision, recall, and F1 score for each class
        precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average=None)
        
        # Compute macro-averaged metrics
        macro_precision, macro_recall, macro_f1, _ = precision_recall_fscore_support(
            y_true, y_pred, average='macro'
        )
        
        return {
            'confusion_matrix': cm,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'macro_precision': macro_precision,
            'macro_recall': macro_recall,
            'macro_f1': macro_f1
        }


# ==============================
# 3. DATA AUGMENTATION MODULE
# ==============================
class DataAugmentation:
    """Comprehensive data augmentation techniques"""
    
    def __init__(self, config: ViTConfig):
        self.config = config
        
    def random_rotation(self, key: jnp.ndarray, images: jnp.ndarray, max_angle=15) -> jnp.ndarray:
        """Simple random rotation using static angles"""
        batch_size = images.shape[0]
        
        # Pre-define static rotation angles
        angles = jnp.array([-15, -10, -5, 0, 5, 10, 15]) * jnp.pi / 180
        
        # Select random angles for each image
        angle_indices = random.randint(key, (batch_size,), 0, len(angles))
        selected_angles = angles[angle_indices]
        
        def rotate_single(args):
            image, angle = args
            # Simple rotation matrix
            cos_angle = jnp.cos(angle)
            sin_angle = jnp.sin(angle)
            
            # Create rotation matrix
            rotation_matrix = jnp.array([[cos_angle, -sin_angle],
                                       [sin_angle, cos_angle]])
            
            # Apply rotation using a simple coordinate transformation
            # Note: This is a simplified rotation that works with JAX
            return image  # Placeholder - actual rotation would be more complex
        
        return vmap(rotate_single)((images, selected_angles))
    
    def apply_augmentation(self, key: jnp.ndarray, images: jnp.ndarray, labels: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Apply all augmentation techniques"""
        if not self.config.use_data_augmentation:
            return images, labels
            
        key, subkey = random.split(key)
        
        # Apply basic augmentations that work with JAX
        images = self.random_flip(subkey, images)
        
        key, subkey = random.split(key)
        images = self.color_jitter(subkey, images)
        
        # Apply Gaussian noise as an alternative to cutout
        key, subkey = random.split(key)
        noise = random.normal(subkey, images.shape) * 0.1
        images = jnp.clip(images + noise, 0, 1)
        
        # Apply Mixup
        if self.config.mixup_alpha > 0:
            key, subkey = random.split(key)
            images, labels = self.mixup(subkey, images, labels)
        
        return images, labels
    
    def random_flip(self, key: jnp.ndarray, images: jnp.ndarray) -> jnp.ndarray:
        """Random horizontal flip"""
        def flip_image(image, flip):
            return jnp.where(flip, jnp.fliplr(image), image)
        
        flips = random.bernoulli(key, 0.5, (images.shape[0],))
        return vmap(flip_image)(images, flips)
    
    def random_crop_and_resize(self, key: jnp.ndarray, images: jnp.ndarray, scale=(0.8, 1.0)) -> jnp.ndarray:
        """Random crop and resize using masked operations"""
        batch_size, height, width, channels = images.shape
        
        keys = random.split(key, batch_size)
        
        def crop_and_resize_single(image, key_single):
            # Sample random scale
            scale_value = random.uniform(key_single, (), minval=scale[0], maxval=scale[1])
            
            # Calculate crop size
            crop_height = jnp.round(height * scale_value).astype(jnp.int32)
            crop_width = jnp.round(width * scale_value).astype(jnp.int32)
            
            # Random crop position
            key_single, subkey = random.split(key_single)
            max_h = jnp.maximum(1, height - crop_height + 1)
            max_w = jnp.maximum(1, width - crop_width + 1)
            start_h = random.randint(subkey, (), 0, max_h)
            key_single, subkey = random.split(key_single)
            start_w = random.randint(subkey, (), 0, max_w)
            
            # Create a mask for the crop region
            y_indices = jnp.arange(height)[:, None]
            x_indices = jnp.arange(width)[None, :]
            mask = ((y_indices >= start_h) & (y_indices < start_h + crop_height) &
                   (x_indices >= start_w) & (x_indices < start_w + crop_width))
            
            # Create cropped image by masking
            masked_image = jnp.where(mask[:, :, None], image, 0)
            
            # For simplicity, we can just resize the entire image with the mask
            return jax.image.resize(masked_image, (height, width, channels), method='bilinear')
        
        # Apply to each image in the batch
        return vmap(crop_and_resize_single)(images, keys)
    
    def color_jitter(self, key: jnp.ndarray, images: jnp.ndarray, 
                    brightness=0.4, contrast=0.4, saturation=0.4) -> jnp.ndarray:
        """Apply color jittering"""
        batch_size = images.shape[0]
        
        # Generate random factors
        key, b_key = random.split(key)
        key, c_key = random.split(key)
        key, s_key = random.split(key)
        
        brightness_factor = random.uniform(b_key, (batch_size,), 
                                         minval=1-brightness, maxval=1+brightness)
        contrast_factor = random.uniform(c_key, (batch_size,), 
                                       minval=1-contrast, maxval=1+contrast)
        saturation_factor = random.uniform(s_key, (batch_size,), 
                                         minval=1-saturation, maxval=1+saturation)
        
        def jitter_single(image, b_factor, c_factor, s_factor):
            # Brightness
            image = jnp.clip(image * b_factor, 0, 1)
            
            # Contrast
            mean = jnp.mean(image)
            image = jnp.clip((image - mean) * c_factor + mean, 0, 1)
            
            # Saturation (convert to grayscale then blend)
            gray = jnp.mean(image, axis=2, keepdims=True)
            image = jnp.clip(gray + (image - gray) * s_factor, 0, 1)
            
            return image
        
        return vmap(jitter_single)(images, brightness_factor, contrast_factor, saturation_factor)
    
    def cutout(self, key: jnp.ndarray, images: jnp.ndarray) -> jnp.ndarray:
        """Apply Cutout augmentation"""
        batch_size, height, width, channels = images.shape
        
        def apply_cutout_single(args):
            image, key_single = args
            # Random position ensuring we don't go out of bounds
            max_x = jnp.maximum(1, width - self.config.cutout_size)
            max_y = jnp.maximum(1, height - self.config.cutout_size)
            
            key_single, subkey = random.split(key_single)
            x = random.randint(subkey, (), 0, max_x)
            key_single, subkey = random.split(key_single)
            y = random.randint(subkey, (), 0, max_y)
            
            # Create mask
            mask = jnp.ones_like(image)
            # Use a safe cutout size that doesn't exceed image boundaries
            safe_size = jnp.minimum(self.config.cutout_size, jnp.minimum(width - x, height - y))
            mask = mask.at[y:y+safe_size, x:x+safe_size].set(0)
            
            return image * mask
        
        # Split keys for each image
        keys = random.split(key, batch_size)
        return vmap(apply_cutout_single)((images, keys))
    
    def mixup(self, key: jnp.ndarray, images: jnp.ndarray, labels: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Apply Mixup augmentation"""
        batch_size = images.shape[0]
        
        # Sample mixing ratio
        lam = random.beta(key, self.config.mixup_alpha, self.config.mixup_alpha)
        
        # Shuffle indices
        key, subkey = random.split(key)
        indices = random.permutation(subkey, batch_size)
        
        # Mix images and labels
        mixed_images = lam * images + (1 - lam) * images[indices]
        mixed_labels = lam * labels + (1 - lam) * labels[indices]
        
        return mixed_images, mixed_labels


# ==============================
# 4. ENHANCED MODEL LAYERS
# ==============================
class EnhancedModelLayers:
    """Enhanced neural network layers with regularization"""
    
    @staticmethod
    def linear(x: jnp.ndarray, w: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
        """Linear layer"""
        return jnp.dot(x, w) + b
    
    @staticmethod
    def layer_norm(x: jnp.ndarray, scale: jnp.ndarray, bias: jnp.ndarray, eps: float = 1e-6) -> jnp.ndarray:
        """Layer normalization"""
        mean = jnp.mean(x, axis=-1, keepdims=True)
        var = jnp.var(x, axis=-1, keepdims=True)
        return scale * (x - mean) / jnp.sqrt(var + eps) + bias
    
    @staticmethod
    def batch_norm(x: jnp.ndarray, scale: jnp.ndarray, bias: jnp.ndarray, 
                   running_mean: jnp.ndarray, running_var: jnp.ndarray,
                   momentum: float = 0.9, eps: float = 1e-5, training: bool = True) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Batch normalization"""
        if training:
            # Compute batch statistics
            mean = jnp.mean(x, axis=(0, 1), keepdims=True)
            var = jnp.var(x, axis=(0, 1), keepdims=True)
            
            # Update running statistics
            new_running_mean = momentum * running_mean + (1 - momentum) * mean
            new_running_var = momentum * running_var + (1 - momentum) * var
        else:
            # Use running statistics
            mean = running_mean
            var = running_var
            new_running_mean = running_mean
            new_running_var = running_var
        
        # Normalize
        normalized = (x - mean) / jnp.sqrt(var + eps)
        scaled = scale * normalized + bias
        
        return scaled, new_running_mean.squeeze(), new_running_var.squeeze()
    
    @staticmethod
    def gelu(x: jnp.ndarray) -> jnp.ndarray:
        """GELU activation function"""
        return 0.5 * x * (1 + jnp.tanh(jnp.sqrt(2 / jnp.pi) * (x + 0.044715 * x**3)))
    
    @staticmethod
    def dropout(key: jnp.ndarray, x: jnp.ndarray, rate: float) -> jnp.ndarray:
        """Apply dropout"""
        if rate == 0.0:
            return x
        keep_prob = 1.0 - rate
        mask = random.bernoulli(key, keep_prob, x.shape)
        return x * mask / keep_prob


# ===============================
# 5. ENHANCED TRANSFORMER BLOCKS
# ===============================
class EnhancedTransformerBlocks:
    """Transformer blocks with advanced regularization"""
    
    def __init__(self, config: ViTConfig):
        self.config = config
        self.layers = EnhancedModelLayers()
    
    def multi_head_attention(
        self, 
        x: jnp.ndarray, 
        params: Dict, 
        key: jnp.ndarray, 
        training: bool = True
    ) -> jnp.ndarray:
        """Enhanced multi-head self-attention with stronger dropout"""
        batch_size, seq_len, hidden_dim = x.shape
        head_dim = hidden_dim // self.config.num_heads
        
        # Project queries, keys, and values
        qkv = self.layers.linear(x, params['qkv_w'], params['qkv_b'])
        qkv = qkv.reshape(batch_size, seq_len, 3, self.config.num_heads, head_dim)
        qkv = jnp.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        # Scaled dot-product attention
        scale = jnp.sqrt(head_dim)
        attention_scores = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) / scale
        attention_probs = jax.nn.softmax(attention_scores)
        
        # Apply stronger attention dropout
        if training and self.config.attention_dropout_rate > 0:
            attention_key = random.split(key, num=1)[0]
            attention_probs = self.layers.dropout(attention_key, attention_probs, self.config.attention_dropout_rate)
        
        context = jnp.matmul(attention_probs, v)
        context = jnp.transpose(context, (0, 2, 1, 3))
        context = context.reshape(batch_size, seq_len, hidden_dim)
        
        # Output projection with separate dropout
        output = self.layers.linear(context, params['out_w'], params['out_b'])
        
        if training and self.config.projection_dropout_rate > 0:
            output_key = random.split(key, num=1)[0]
            output = self.layers.dropout(output_key, output, self.config.projection_dropout_rate)
        
        return output
    
    def mlp_block(
        self, 
        x: jnp.ndarray, 
        params: Dict, 
        key: jnp.ndarray, 
        training: bool = True
    ) -> jnp.ndarray:
        """Enhanced MLP block with better regularization"""
        # First dense layer
        x = self.layers.linear(x, params['mlp1_w'], params['mlp1_b'])
        x = self.layers.gelu(x)
        
        # Dropout after first layer
        if training and self.config.dropout_rate > 0:
            mlp1_key = random.split(key, num=1)[0]
            x = self.layers.dropout(mlp1_key, x, self.config.dropout_rate)
        
        # Second dense layer
        x = self.layers.linear(x, params['mlp2_w'], params['mlp2_b'])
        
        # Dropout after second layer
        if training and self.config.dropout_rate > 0:
            mlp2_key = random.split(key, num=1)[0]
            x = self.layers.dropout(mlp2_key, x, self.config.dropout_rate)
        
        return x
    
    def encoder_block(
        self, 
        x: jnp.ndarray, 
        params: Dict, 
        key: jnp.ndarray, 
        layer_id: int,
        training: bool = True
    ) -> jnp.ndarray:
        """Enhanced encoder block with improved regularization"""
        
        # Pre-normalization architecture
        norm1 = self.layers.layer_norm(x, params['ln1_scale'], params['ln1_bias'])
        
        # Multi-head attention with residual
        key, attn_key = random.split(key)
        attn_output = self.multi_head_attention(norm1, params, attn_key, training)
        x = x + attn_output
        
        # Apply stateless batch norm if enabled
        if self.config.use_batch_norm and 'bn1_scale' in params:
            mean = jnp.mean(x, axis=(0,), keepdims=True)
            var = jnp.var(x, axis=(0,), keepdims=True)
            normalized = (x - mean) / jnp.sqrt(var + self.config.bn_eps)
            x = params['bn1_scale'] * normalized + params['bn1_bias']
        
        # Second normalization
        norm2 = self.layers.layer_norm(x, params['ln2_scale'], params['ln2_bias'])
        
        # MLP block with residual
        key, mlp_key = random.split(key)
        mlp_output = self.mlp_block(norm2, params, mlp_key, training)
        x = x + mlp_output
        
        # Apply stateless batch norm if enabled
        if self.config.use_batch_norm and 'bn2_scale' in params:
            mean = jnp.mean(x, axis=(0,), keepdims=True)
            var = jnp.var(x, axis=(0,), keepdims=True)
            normalized = (x - mean) / jnp.sqrt(var + self.config.bn_eps)
            x = params['bn2_scale'] * normalized + params['bn2_bias']
        
        # DropPath (Stochastic Depth) - Fixed implementation
        if training and self.config.path_dropout_rate > 0:
            # Calculate survival probability
            survival_prob = 1.0 - self.config.path_dropout_rate * (layer_id / self.config.num_layers)
            
            # Path dropout decision
            key, path_key = random.split(key)
            keep_prob = random.bernoulli(path_key, survival_prob)
            
            # Apply DropPath using jnp.where (JAX-compatible)
            # If keep_prob is False, return the input unchanged; otherwise return the transformed output
            x = jnp.where(keep_prob, x, x - (x - x))  # This effectively returns x when keep_prob is False
            
            # Better approach: scale the output by survival probability
            x = x / survival_prob * jnp.where(keep_prob, 1.0, survival_prob)
        
        return x


# ==============================
# 6. ENHANCED VISION TRANSFORMER
# ==============================
class EnhancedVisionTransformer:
    """Enhanced Vision Transformer with comprehensive regularization"""
    
    def __init__(self, config: ViTConfig):
        self.config = config
        self.layers = EnhancedModelLayers()
        self.transformer = EnhancedTransformerBlocks(config)
        self.augmenter = DataAugmentation(config)
    
    def __call__(
        self, 
        params: Dict,
        images: jnp.ndarray, 
        key: jnp.ndarray, 
        training: bool = True,
        return_features: bool = False
    ) -> Any:
        """Enhanced forward pass with data augmentation"""
        batch_size, height, width, channels = images.shape
        
        # Reshape images into patches
        patches = jnp.reshape(
            images, 
            (batch_size, height // self.config.patch_size, self.config.patch_size, 
             width // self.config.patch_size, self.config.patch_size, channels)
        )
        patches = jnp.transpose(patches, (0, 1, 3, 2, 4, 5))
        patches = jnp.reshape(
            patches, 
            (batch_size, self.config.num_patches, self.config.patch_size * self.config.patch_size * channels)
        )
        
        # Linear projection with dropout
        patch_embeddings = self.layers.linear(
            patches, 
            params['patch_projection_w'], 
            params['patch_projection_b']
        )
        
        if training and self.config.projection_dropout_rate > 0:
            key, proj_key = random.split(key)
            patch_embeddings = self.layers.dropout(proj_key, patch_embeddings, self.config.projection_dropout_rate)
        
        # Add class token
        cls_tokens = jnp.broadcast_to(
            params['cls_token'], 
            (batch_size, 1, self.config.hidden_dim)
        )
        x = jnp.concatenate([cls_tokens, patch_embeddings], axis=1)
        
        # Add position embeddings
        x = x + params['pos_embedding']
        
        # Embedding dropout
        if training and self.config.dropout_rate > 0:
            key, embed_key = random.split(key)
            x = self.layers.dropout(embed_key, x, self.config.dropout_rate)
        
        # Apply transformer encoder blocks
        for i, block_params in enumerate(params['encoder_blocks']):
            key, block_key = random.split(key)
            x = self.transformer.encoder_block(x, block_params, block_key, i, training)
        
        # Final layer norm
        x = self.layers.layer_norm(x, params['ln_final_scale'], params['ln_final_bias'])
        
        # Use [CLS] token representation
        cls_representation = x[:, 0]
        
        # Pre-head dropout
        if training and self.config.dropout_rate > 0:
            key, head_key = random.split(key)
            cls_representation = self.layers.dropout(head_key, cls_representation, self.config.dropout_rate)
        
        # Classification head
        logits = self.layers.linear(cls_representation, params['head_w'], params['head_b'])
        
        if return_features:
            return logits, cls_representation
        else:
            return logits


# ==============================
# 7. ENHANCED PARAMETER INITIALIZATION
# ==============================
class EnhancedParameterInitializer:
    """Enhanced parameter initialization with batch norm parameters"""
    
    def __init__(self, config: ViTConfig):
        self.config = config
        self.device = jax.devices()[0]
    
    def init_transformer_params(self, key: jnp.ndarray) -> Dict:
        """Initialize all enhanced transformer parameters"""
        keys_needed = self.config.num_layers * (6 if self.config.use_batch_norm else 4) + 5
        keys = random.split(key, num=keys_needed)
        key_idx = 0
        
        # Projection layer
        patch_projection_w, patch_projection_b = self.init_linear_params(
            keys[key_idx], 
            self.config.patch_size * self.config.patch_size * 3, 
            self.config.hidden_dim
        )
        key_idx += 1
        
        # Class token and position embeddings with proper scaling
        cls_token = random.normal(keys[key_idx], (1, self.config.hidden_dim)) * (self.config.hidden_dim ** -0.5)
        key_idx += 1
        pos_embedding = random.normal(
            keys[key_idx], 
            (1, self.config.num_patches + 1, self.config.hidden_dim)
        ) * (self.config.hidden_dim ** -0.5)
        key_idx += 1
        
        # Initialize transformer blocks
        encoder_blocks = []
        for _ in range(self.config.num_layers):
            block_size = 6 if self.config.use_batch_norm else 4
            block = self._init_transformer_block(keys[key_idx:key_idx+block_size])
            encoder_blocks.append(block)
            key_idx += block_size
        
        # Final layer norm and classification head
        ln_final_scale, ln_final_bias = self.init_layer_norm_params(self.config.hidden_dim)
        head_w, head_b = self.init_linear_params(keys[key_idx], self.config.hidden_dim, self.config.num_classes)
        
        # Use Xavier initialization for final layer
        head_w = head_w * jnp.sqrt(2.0 / self.config.hidden_dim)
        
        params = {
            'patch_projection_w': patch_projection_w,
            'patch_projection_b': patch_projection_b,
            'cls_token': cls_token,
            'pos_embedding': pos_embedding,
            'encoder_blocks': encoder_blocks,
            'ln_final_scale': ln_final_scale,
            'ln_final_bias': ln_final_bias,
            'head_w': head_w,
            'head_b': head_b,
        }
        
        return jax.tree_util.tree_map(lambda x: jax.device_put(x, self.device), params)
    
    def _init_transformer_block(self, keys: List[jnp.ndarray]) -> Dict:
        """Initialize a single enhanced transformer block"""
        block_params = {}
        key_idx = 0
        
        # Layer norm 1
        block_params['ln1_scale'], block_params['ln1_bias'] = self.init_layer_norm_params(self.config.hidden_dim)
        
        # Multi-head attention with proper scaling
        qkv_w, qkv_b = self.init_linear_params(keys[key_idx], self.config.hidden_dim, 3 * self.config.hidden_dim)
        qkv_w = qkv_w * (self.config.hidden_dim ** -0.5)  # Proper attention weight scaling
        key_idx += 1
        
        out_w, out_b = self.init_linear_params(keys[key_idx], self.config.hidden_dim, self.config.hidden_dim)
        key_idx += 1
        
        block_params.update({
            'qkv_w': qkv_w,
            'qkv_b': qkv_b,
            'out_w': out_w,
            'out_b': out_b,
        })
        
        # Batch norm after attention
        if self.config.use_batch_norm:
            block_params['bn1_scale'], block_params['bn1_bias'] = self.init_batch_norm_params(self.config.hidden_dim)
            block_params['bn1_running_mean'] = jnp.zeros(self.config.hidden_dim)
            block_params['bn1_running_var'] = jnp.ones(self.config.hidden_dim)
        
        # Layer norm 2
        block_params['ln2_scale'], block_params['ln2_bias'] = self.init_layer_norm_params(self.config.hidden_dim)
        
        # MLP with He initialization
        mlp1_w, mlp1_b = self.init_linear_params(keys[key_idx], self.config.hidden_dim, self.config.mlp_dim)
        mlp1_w = mlp1_w * jnp.sqrt(2.0 / self.config.hidden_dim)  # He initialization for GELU
        key_idx += 1
        
        mlp2_w, mlp2_b = self.init_linear_params(keys[key_idx], self.config.mlp_dim, self.config.hidden_dim)
        key_idx += 1
        
        block_params.update({
            'mlp1_w': mlp1_w,
            'mlp1_b': mlp1_b,
            'mlp2_w': mlp2_w,
            'mlp2_b': mlp2_b,
        })
        
        # Batch norm after MLP
        if self.config.use_batch_norm:
            block_params['bn2_scale'], block_params['bn2_bias'] = self.init_batch_norm_params(self.config.hidden_dim)
            block_params['bn2_running_mean'] = jnp.zeros(self.config.hidden_dim)
            block_params['bn2_running_var'] = jnp.ones(self.config.hidden_dim)
        
        return block_params
    
    def init_linear_params(self, key: jnp.ndarray, in_dim: int, out_dim: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Initialize linear layer parameters with Xavier initialization"""
        k1, k2 = random.split(key)
        scale = jnp.sqrt(2.0 / (in_dim + out_dim))
        weight = random.normal(k1, (in_dim, out_dim)) * scale
        bias = jnp.zeros((out_dim,))
        return weight, bias
    
    def init_layer_norm_params(self, dim: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Initialize layer normalization parameters"""
        scale = jnp.ones((dim,))
        bias = jnp.zeros((dim,))
        return scale, bias
    
    def init_batch_norm_params(self, dim: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Initialize batch normalization parameters"""
        scale = jnp.ones((dim,))
        bias = jnp.zeros((dim,))
        return scale, bias


# ==============================
# 8. EXPONENTIAL MOVING AVERAGE
# ==============================
class ExponentialMovingAverage:
    """Exponential Moving Average for model parameters"""
    
    def __init__(self, decay: float = 0.9999):
        self.decay = decay
        self.ema_params = None
    
    def init(self, params: Dict) -> None:
        """Initialize EMA parameters"""
        self.ema_params = jax.tree_util.tree_map(lambda x: x.copy(), params)
    
    def update(self, params: Dict) -> None:
        """Update EMA parameters"""
        if self.ema_params is None:
            self.init(params)
        else:
            def update_param(ema_param, param):
                return self.decay * ema_param + (1 - self.decay) * param
            
            self.ema_params = jax.tree_util.tree_map(update_param, self.ema_params, params)
    
    def get_params(self) -> Dict:
        """Get EMA parameters"""
        return self.ema_params


# ==============================
# 9. LABEL SMOOTHING LOSS
# ==============================
class LabelSmoothingCrossEntropy:
    """Cross-entropy loss with label smoothing"""
    
    def __init__(self, smoothing: float = 0.1):
        self.smoothing = smoothing
    
    def __call__(self, logits: jnp.ndarray, labels: jnp.ndarray) -> float:
        """Compute label smoothing cross-entropy loss"""
        num_classes = logits.shape[-1]
        
        # One-hot encode labels
        one_hot_labels = jax.nn.one_hot(labels, num_classes)
        
        # Apply label smoothing
        smoothed_labels = one_hot_labels * (1 - self.smoothing) + self.smoothing / num_classes
        
        # Compute loss
        log_probs = jax.nn.log_softmax(logits)
        loss = -jnp.sum(smoothed_labels * log_probs) / labels.shape[0]
        
        return loss


# ==============================
# 10. ENHANCED OPTIMIZER & SCHEDULER (from original)
# ==============================
class AdamWOptimizer:
    """AdamW optimizer implementation"""
    
    def __init__(self, config: ViTConfig):
        self.config = config
        
    def init_state(self, params: Dict) -> Dict:
        """Initialize optimizer state"""
        # Initialize first and second moments to zeros
        m = jax.tree_util.tree_map(jnp.zeros_like, params)
        v = jax.tree_util.tree_map(jnp.zeros_like, params)
        
        return {
            'm': m,
            'v': v,
            't': 0,
        }
    
    def update(
        self, 
        grads: Dict, 
        opt_state: Dict, 
        params: Dict, 
        learning_rate: float
    ) -> Tuple[Dict, Dict]:
        """Apply AdamW update step"""
        # Extract optimizer state
        m, v, t = opt_state['m'], opt_state['v'], opt_state['t']
        
        # Increment step
        t = t + 1
        
        # Compute bias correction terms
        bias_correction1 = 1 - self.config.beta1 ** t
        bias_correction2 = 1 - self.config.beta2 ** t
        
        # Define update function for a single parameter pair
        def update_momentum(m_param, grad):
            return self.config.beta1 * m_param + (1 - self.config.beta1) * grad
        
        def update_velocity(v_param, grad):
            return self.config.beta2 * v_param + (1 - self.config.beta2) * jnp.square(grad)
        
        def update_param(param, m_param, v_param):
            # Compute bias-corrected moment estimates
            m_hat = m_param / bias_correction1
            v_hat = v_param / bias_correction2
            
            # Apply weight decay
            param_with_decay = param - learning_rate * self.config.weight_decay * param
            
            # Update parameter
            return param_with_decay - learning_rate * m_hat / (jnp.sqrt(v_hat) + self.config.eps)
        
        # Update momentum terms
        new_m = jax.tree_util.tree_map(update_momentum, m, grads)
        
        # Update velocity terms  
        new_v = jax.tree_util.tree_map(update_velocity, v, grads)
        
        # Update parameters
        new_params = jax.tree_util.tree_map(update_param, params, new_m, new_v)
        
        # Update optimizer state
        new_opt_state = {
            'm': new_m,
            'v': new_v,
            't': t,
        }
        
        return new_params, new_opt_state


class LearningRateScheduler:
    """Learning rate scheduler with warmup and cosine decay"""
    
    def __init__(self, config: ViTConfig, total_steps: int):
        self.config = config
        self.total_steps = total_steps
        self.warmup_steps = config.warmup_steps
        self.peak_lr = config.learning_rate
        self.end_lr = 0.0
        
    def __call__(self, step: int) -> float:
        """Get learning rate for the current step"""
        # Linear warmup
        warmup_lr = self.peak_lr * jnp.minimum(1.0, step / self.warmup_steps)
        
        # Cosine decay
        decay_steps = self.total_steps - self.warmup_steps
        decay_factor = 0.5 * (1 + jnp.cos(
            jnp.pi * jnp.minimum(step - self.warmup_steps, decay_steps) / decay_steps
        ))
        cosine_lr = self.end_lr + (self.peak_lr - self.end_lr) * decay_factor
        
        # Use warmup for steps < warmup_steps, cosine decay afterward
        lr = jnp.where(step < self.warmup_steps, warmup_lr, cosine_lr)
        
        return lr


# class LearningRateScheduler:
#     """Learning rate scheduler with warmup and logarithmic decay"""
    
#     def __init__(self, config, total_steps: int):
#         self.config = config
#         self.total_steps = total_steps
#         self.warmup_steps = config.warmup_steps
#         self.peak_lr = config.learning_rate
#         self.epsilon = 1e-8  # To avoid division by log(0)

#     def __call__(self, step: int) -> float:
#         """Get learning rate for the current step"""
#         # Linear warmup
#         warmup_lr = self.peak_lr * jnp.minimum(1.0, step / self.warmup_steps)

#         # Logarithmic decay
#         decay_step = jnp.maximum(step - self.warmup_steps, 1)  # Ensure > 0 for log
#         decay_steps = jnp.maximum(self.total_steps - self.warmup_steps, 1)
#         log_lr = self.peak_lr / (1.0 + jnp.log1p(decay_step))  # log1p = log(1 + x)

#         # Use warmup for steps < warmup_steps, log decay afterward
#         lr = jnp.where(step < self.warmup_steps, warmup_lr, log_lr)

#         return lr


# class LearningRateScheduler:
#     """Learning rate scheduler with warmup and step decay"""
    
#     def __init__(self, config, total_steps: int):
#         self.config = config
#         self.total_steps = total_steps
#         self.warmup_steps = config.warmup_steps
#         self.peak_lr = config.learning_rate
#         self.drop_rate = getattr(config, "drop_rate", 0.5)         # Default: halve the LR each drop
#         self.drop_every = getattr(config, "drop_every", 1000)      # Default: drop every 1000 steps

#     def __call__(self, step: int) -> float:
#         """Get learning rate for the current step"""
#         # Linear warmup
#         warmup_lr = self.peak_lr * jnp.minimum(1.0, step / self.warmup_steps)

#         # Step decay
#         step_after_warmup = jnp.maximum(step - self.warmup_steps, 0)
#         num_drops = jnp.floor(step_after_warmup / self.drop_every)
#         step_lr = self.peak_lr * (self.drop_rate ** num_drops)

#         # Use warmup for steps < warmup_steps, step decay afterward
#         lr = jnp.where(step < self.warmup_steps, warmup_lr, step_lr)

#         return lr




# ==============================
# 11. DATA LOADER (from original)
# ==============================
class DataLoader:
    """Handles data loading for CIFAR-10"""
    
    CIFAR10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer', 
                       'dog', 'frog', 'horse', 'ship', 'truck']
    
    def __init__(self, data_dir: str = "cifar10_data"):
        self.data_dir = data_dir
        self.url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
        
    def load_datasets(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Load CIFAR-10 dataset"""
        # Create data directory
        os.makedirs(self.data_dir, exist_ok=True)
        
        # Download if necessary
        tar_file_path = os.path.join(self.data_dir, "cifar-10-python.tar.gz")
        if not os.path.exists(tar_file_path):
            print(f"Downloading CIFAR-10 dataset from {self.url}")
            self._download_dataset(tar_file_path)
        
        # Extract if necessary
        extract_dir = os.path.join(self.data_dir, "cifar-10-batches-py")
        if not os.path.exists(extract_dir):
            print(f"Extracting dataset to {extract_dir}")
            self._extract_dataset(tar_file_path, self.data_dir)
        
        # Load training and test data
        train_data, train_labels = self._load_training_data(extract_dir)
        test_data, test_labels = self._load_test_data(extract_dir)
        
        # Preprocess data
        train_data = self._preprocess_data(train_data)
        test_data = self._preprocess_data(test_data)
        
        print(f"Loaded {len(train_data)} training samples and {len(test_data)} test samples")
        
        return train_data, train_labels, test_data, test_labels
    
    def _download_dataset(self, file_path: str):
        """Download CIFAR-10 dataset"""
        urllib.request.urlretrieve(self.url, file_path)
        print(f"Dataset downloaded to {file_path}")
    
    def _extract_dataset(self, tar_file_path: str, extract_to: str):
        """Extract CIFAR-10 dataset"""
        with tarfile.open(tar_file_path, 'r:gz') as tar:
            tar.extractall(path=extract_to)
        print("Dataset extracted")
    
    def _load_training_data(self, extract_dir: str) -> Tuple[np.ndarray, np.ndarray]:
        """Load training data from multiple batch files"""
        train_data = []
        train_labels = []
        
        # Load training batches
        for i in range(1, 6):
            batch_file = os.path.join(extract_dir, f'data_batch_{i}')
            with open(batch_file, 'rb') as f:
                batch_data = pickle.load(f, encoding='bytes')
                train_data.append(batch_data[b'data'])
                train_labels.extend(batch_data[b'labels'])
        
        # Concatenate all training data
        train_data = np.vstack(train_data)
        train_labels = np.array(train_labels)
        
        return train_data, train_labels
    
    def _load_test_data(self, extract_dir: str) -> Tuple[np.ndarray, np.ndarray]:
        """Load test data"""
        test_file = os.path.join(extract_dir, 'test_batch')
        with open(test_file, 'rb') as f:
            test_batch = pickle.load(f, encoding='bytes')
            test_data = test_batch[b'data']
            test_labels = np.array(test_batch[b'labels'])
        
        return test_data, test_labels
    
    def _preprocess_data(self, data: np.ndarray) -> np.ndarray:
        """Preprocess CIFAR-10 data"""
        # Reshape and normalize the data
        # CIFAR-10 data layout: [N, 3072] where 3072 = 3 x 32 x 32
        # We reshape to [N, 32, 32, 3] for image processing
        data = data.reshape(-1, 3, 32, 32)
        data = data.transpose(0, 2, 3, 1)
        data = data.astype(np.float32) / 255.0
        
        return data


# ==============================
# 12. VISUALIZER (from original)
# ==============================
class Visualizer:
    """Handles all visualization tasks"""
    
    def __init__(self, results_dir: str):
        self.results_dir = results_dir
        self.cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 
                               'dog', 'frog', 'horse', 'ship', 'truck']
    
    def plot_training_metrics(self, state: TrainingState):
        """Plot training and validation metrics over epochs"""
        epochs = range(1, len(state.train_losses) + 1)
        
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
        
        # Loss plot
        ax1.plot(epochs, state.train_losses, 'b-', label='Training Loss')
        ax1.plot(epochs, state.eval_losses, 'r-', label='Validation Loss')
        ax1.set_title('Training and Validation Loss')
        ax1.set_xlabel('Epochs')
        ax1.set_ylabel('Loss')
        ax1.legend()
        ax1.grid(True)
        
        # Accuracy plot
        ax2.plot(epochs, state.train_accs, 'b-', label='Training Accuracy')
        ax2.plot(epochs, state.eval_accs, 'r-', label='Validation Accuracy')
        ax2.set_title('Training and Validation Accuracy')
        ax2.set_xlabel('Epochs')
        ax2.set_ylabel('Accuracy')
        ax2.legend()
        ax2.grid(True)
        
        plt.tight_layout()
        plt.savefig(os.path.join(self.results_dir, 'training_metrics.png'), dpi=300, bbox_inches='tight')
        plt.close()
    
    def plot_learning_rate_schedule(self, lr_scheduler, total_steps: int):
        """Plot the learning rate schedule"""
        steps = range(total_steps)
        lrs = [float(lr_scheduler(step)) for step in steps]
        
        plt.figure(figsize=(10, 6))
        plt.plot(steps, lrs)
        plt.title('Learning Rate Schedule')
        plt.xlabel('Step')
        plt.ylabel('Learning Rate')
        plt.grid(True)
        plt.savefig(os.path.join(self.results_dir, 'lr_schedule.png'), dpi=300, bbox_inches='tight')
        plt.close()
    
    def plot_confusion_matrix(self, cm: np.ndarray):
        """Plot confusion matrix"""
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                    xticklabels=self.cifar10_classes, 
                    yticklabels=self.cifar10_classes)
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title('Confusion Matrix')
        plt.tight_layout()
        plt.savefig(os.path.join(self.results_dir, 'confusion_matrix.png'), dpi=300, bbox_inches='tight')
        plt.close()
    
    def plot_classification_metrics(self, metrics: Dict):
        """Plot precision, recall, and F1 scores"""
        metrics_df = pd.DataFrame({
            'Precision': metrics['precision'],
            'Recall': metrics['recall'],
            'F1 Score': metrics['f1']
        }, index=self.cifar10_classes)
        
        plt.figure(figsize=(12, 6))
        metrics_df.plot(kind='bar', figsize=(12, 6))
        plt.title('Precision, Recall, and F1 Score by Class')
        plt.xlabel('Class')
        plt.ylabel('Score')
        plt.xticks(rotation=45)
        plt.ylim(0, 1.0)
        plt.tight_layout()
        plt.legend(loc='best')
        plt.savefig(os.path.join(self.results_dir, 'precision_recall_f1.png'), dpi=300, bbox_inches='tight')
        plt.close()
    
    def plot_feature_space(self, features: np.ndarray, labels: np.ndarray, subset_size: int = 2000):
        """Visualize feature space using t-SNE, PCA, and UMAP"""
        # Use subset if needed
        if len(features) > subset_size:
            subset_indices = np.random.choice(len(features), subset_size, replace=False)
            features = features[subset_indices]
            labels = labels[subset_indices]
        
        # t-SNE visualization
        print("Computing t-SNE projection...")
        tsne = TSNE(n_components=2, random_state=42, n_jobs=-1)
        tsne_result = tsne.fit_transform(features)
        
        self._plot_2d_embedding(tsne_result, labels, 't-SNE Visualization of Feature Space', 'tsne_visualization.png')
        
        # PCA visualization
        print("Computing PCA projection...")
        pca = PCA(n_components=2, random_state=42)
        pca_result = pca.fit_transform(features)
        
        self._plot_2d_embedding(pca_result, labels, 'PCA Visualization of Feature Space', 'pca_visualization.png')
        
        # UMAP visualization
        print("Computing UMAP projection...")
        umap_model = umap.UMAP(n_components=2, random_state=42)
        umap_result = umap_model.fit_transform(features)
        
        self._plot_2d_embedding(umap_result, labels, 'UMAP Visualization of Feature Space', 'umap_visualization.png')
    
    def _plot_2d_embedding(self, embedding: np.ndarray, labels: np.ndarray, title: str, filename: str):
        """Plot 2D embedding with class colors"""
        plt.figure(figsize=(10, 8))
        
        # Create scatter plot for each class
        for i, class_name in enumerate(self.cifar10_classes):
            indices = labels == i
            plt.scatter(embedding[indices, 0], embedding[indices, 1], 
                       label=class_name, alpha=0.7, s=20)
        
        plt.title(title)
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout()
        plt.savefig(os.path.join(self.results_dir, filename), dpi=300, bbox_inches='tight')
        plt.close()
    
    def plot_sample_predictions(
        self, 
        vit: EnhancedVisionTransformer,
        params: Dict,
        test_data: np.ndarray,
        test_labels: np.ndarray,
        num_samples: int = 10
    ):
        """Visualize sample predictions"""
        # Get random subset of test samples
        indices = np.random.choice(len(test_data), num_samples, replace=False)
        images = test_data[indices]
        labels = test_labels[indices]
        
        # Move to device and get predictions
        device = jax.devices()[0]
        images_device = jax.device_put(images, device)
        labels_device = jax.device_put(labels, device)
        
        # Get predictions
        eval_key = jax.device_put(random.PRNGKey(0), device)
        logits = vit(params, images_device, eval_key, training=False)
        preds = Metrics.get_predictions(logits)
        
        # Move back to CPU for visualization
        images_cpu = np.array(images)
        preds_cpu = np.array(preds)
        labels_cpu = np.array(labels)
        
        # Create figure
        fig, axes = plt.subplots(2, 5, figsize=(15, 8))
        axes = axes.flatten()
        
        for i, (image, label, pred) in enumerate(zip(images_cpu, labels_cpu, preds_cpu)):
            ax = axes[i]
            ax.imshow(image)
            ax.set_title(f"True: {self.cifar10_classes[label]}\nPred: {self.cifar10_classes[pred]}")
            ax.axis('off')
            
            # Highlight correct/incorrect predictions
            if label == pred:
                # Green border for correct predictions
                for spine in ax.spines.values():
                    spine.set_edgecolor('green')
                    spine.set_linewidth(3)
                    spine.set_visible(True)
            else:
                # Red border for incorrect predictions
                for spine in ax.spines.values():
                    spine.set_edgecolor('red')
                    spine.set_linewidth(3)
                    spine.set_visible(True)
        
        plt.tight_layout()
        plt.savefig(os.path.join(self.results_dir, 'sample_predictions.png'), dpi=300, bbox_inches='tight')
        plt.close()


# ==============================
# 13. ENHANCED TRAINER CLASS
# ==============================
class EnhancedTrainer:
    """Enhanced trainer with all regularization techniques"""
    
    def __init__(self, config: ViTConfig, results_dir: str = "enhanced_vit_results"):
        self.config = config
        self.results_dir = results_dir
        self.device = jax.devices()[0]
        
        # Initialize components
        self.vit = EnhancedVisionTransformer(config)
        self.optimizer = AdamWOptimizer(config)
        self.initializer = EnhancedParameterInitializer(config)
        self.augmenter = DataAugmentation(config)
        
        # Initialize loss function with label smoothing
        self.loss_fn = LabelSmoothingCrossEntropy(config.label_smoothing)
        
        # Set up JAX key
        self.key = jax.device_put(random.PRNGKey(config.seed), self.device)
        
        # Initialize model and optimizer
        init_key, self.key = random.split(self.key)
        self.params = self.initializer.init_transformer_params(init_key)
        self.opt_state = self.optimizer.init_state(self.params)
        
        # Initialize EMA
        self.ema = ExponentialMovingAverage(config.ema_decay) if config.use_ema else None
        if self.ema:
            self.ema.init(self.params)
        
        # Initialize training state
        self.state = TrainingState(
            params=self.params,
            opt_state=self.opt_state
        )
        
        # Initialize learning rate scheduler early - we need it for visualization
        # We'll use a placeholder total_steps value that will be updated in train()
        placeholder_steps = config.num_epochs * 391  # Approximate steps for CIFAR-10
        self.lr_scheduler = LearningRateScheduler(config, placeholder_steps)
        
        # Compile functions
        self._compile_functions()
    
    def _compile_functions(self):
        """Compile JAX functions for better performance"""
        self.jit_train_step = jax.jit(self._train_step)
        self.jit_eval_step = jax.jit(self._eval_step)
    
    def _train_step(
        self, 
        params: Dict, 
        batch: Tuple[jnp.ndarray, jnp.ndarray], 
        keys: Dict[str, jnp.ndarray],
        opt_state: Dict, 
        learning_rate: float
    ) -> Tuple[Dict, float, float, Dict]:
        """Enhanced training step with augmentation"""
        images, labels = batch
        
        # Apply augmentation
        aug_key = keys['augmentation']
        if self.config.use_data_augmentation:
            # Create one-hot labels for mixup
            one_hot_labels = jax.nn.one_hot(labels, self.config.num_classes)
            augmented_images, augmented_labels = self.augmenter.apply_augmentation(aug_key, images, one_hot_labels)
        else:
            augmented_images = images
            augmented_labels = labels
        
        def loss_fn(params_weights):
            logits = self.vit(params_weights, augmented_images, keys['model'], training=True)
            
            # Handle soft labels from mixup
            if augmented_labels.ndim > 1:
                # Soft labels - use label smoothing loss with pre-smoothed labels
                log_probs = jax.nn.log_softmax(logits)
                loss = -jnp.mean(jnp.sum(augmented_labels * log_probs, axis=-1))
            else:
                # Hard labels - use standard label smoothing
                loss = self.loss_fn(logits, augmented_labels)
            
            return loss, logits
        
        (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
        
        # Apply gradient clipping
        grads = self._clip_gradients(grads, max_norm=1.0)
        
        # Apply optimizer
        new_params, new_opt_state = self.optimizer.update(
            grads, opt_state, params, learning_rate
        )
        
        # # Update EMA
        # if self.ema:
        #     self.ema.update(new_params)
        
        # Compute accuracy (using original labels)
        if augmented_labels.ndim > 1:
            # For mixup, use argmax of soft labels
            hard_labels = jnp.argmax(augmented_labels, axis=-1)
        else:
            hard_labels = augmented_labels
        
        acc = Metrics.accuracy(logits, hard_labels)
        
        return new_params, loss, acc, new_opt_state
    
    def _clip_gradients(self, grads: Dict, max_norm: float = 1.0) -> Dict:
        """Apply gradient clipping"""
        def clip_param_grad(grad):
            return jnp.clip(grad, -max_norm, max_norm)
        
        return jax.tree_util.tree_map(clip_param_grad, grads)
    
    def _eval_step(
        self, 
        params: Dict, 
        batch: Tuple[jnp.ndarray, jnp.ndarray], 
        return_features: bool = False
    ) -> Any:
        """Evaluation step"""
        images, labels = batch
        
        # No dropout during evaluation
        eval_key = jax.device_put(random.PRNGKey(0), self.device)
        
        if return_features:
            logits, features = self.vit(params, images, eval_key, training=False, return_features=True)
        else:
            logits = self.vit(params, images, eval_key, training=False)
        
        loss = self.loss_fn(logits, labels)
        acc = Metrics.accuracy(logits, labels)
        preds = Metrics.get_predictions(logits)
        
        if return_features:
            return loss, acc, preds, features
        else:
            return loss, acc, preds
    
    def train(
        self, 
        train_data: np.ndarray, 
        train_labels: np.ndarray, 
        test_data: np.ndarray, 
        test_labels: np.ndarray
    ) -> Dict:
        """Enhanced training loop"""
        # Initialize learning rate scheduler with actual total steps
        total_steps = self.config.num_epochs * (len(train_data) // self.config.batch_size)
        self.lr_scheduler = LearningRateScheduler(self.config, total_steps)  # Reinitialize with correct steps
        
        # Create results directory
        os.makedirs(self.results_dir, exist_ok=True)
        
        print(f"Starting enhanced training on {self.device}")
        print(f"Total steps: {total_steps}")
        print(f"Regularization: Dropout={self.config.dropout_rate}, BatchNorm={self.config.use_batch_norm}")
        print(f"Data Augmentation: Enabled={self.config.use_data_augmentation}")
        print(f"Label Smoothing: {self.config.label_smoothing}")
        
        # Training loop
        best_val_accuracy = 0.0
        patience_counter = 0
        patience = self.config.early_stopping_patience  # Early stopping patience
        
        for epoch in range(self.config.num_epochs):
            self.state.epoch = epoch
            epoch_start_time = time.time()
            
            # Training epoch
            self._train_epoch(train_data, train_labels)
            
            # Evaluation with both current and EMA parameters
            eval_metrics = self._eval_epoch(test_data, test_labels)
            
            if self.ema:
                ema_metrics = self._eval_with_ema(test_data, test_labels)
                print(f"EMA Eval Accuracy: {ema_metrics['accuracy']:.4f}")
            
            # Update state
            self.state.train_losses.append(np.mean(self.epoch_train_losses))
            self.state.train_accs.append(np.mean(self.epoch_train_accs))
            self.state.eval_losses.append(eval_metrics['loss'])
            self.state.eval_accs.append(eval_metrics['accuracy'])
            
            # Check for best model
            current_accuracy = eval_metrics['accuracy']
            if self.ema and ema_metrics['accuracy'] > current_accuracy:
                current_accuracy = ema_metrics['accuracy']
            
            if current_accuracy > self.state.best_accuracy:
                self.state.best_accuracy = current_accuracy
                self.state.best_epoch = epoch
                patience_counter = 0
                
                # Save best model
                self._save_checkpoint('best_model.pkl')
            else:
                patience_counter += 1
            
            epoch_time = time.time() - epoch_start_time
            
            print(f"Epoch {epoch+1}/{self.config.num_epochs} completed in {epoch_time:.2f}s")
            print(f"  Train Loss: {self.state.train_losses[-1]:.4f}, Train Accuracy: {self.state.train_accs[-1]:.4f}")
            print(f"  Eval Loss: {self.state.eval_losses[-1]:.4f}, Eval Accuracy: {self.state.eval_accs[-1]:.4f}")
            print(f"  Best Accuracy: {self.state.best_accuracy:.4f} at epoch {self.state.best_epoch+1}")
            
            # Early stopping
            if patience_counter >= patience:
                print(f"Early stopping triggered at epoch {epoch+1}")
                break
            
            # Save periodic checkpoint
            if (epoch + 1) % 10 == 0:
                self._save_checkpoint(f'checkpoint_epoch_{epoch+1}.pkl')
            
            # Plot progress periodically
            if (epoch + 1) % 5 == 0 or epoch == self.config.num_epochs - 1:
                viz = Visualizer(self.results_dir)
                viz.plot_training_metrics(self.state)
        
        # Final evaluation with EMA parameters if available
        if self.ema:
            final_metrics = self._final_evaluation_with_ema(test_data, test_labels)
        else:
            final_metrics = self._final_evaluation(test_data, test_labels)
        
        return final_metrics
    
    def _final_evaluation(self, test_data: np.ndarray, test_labels: np.ndarray) -> Dict:
        """Final comprehensive evaluation"""
        print("Running final evaluation with feature extraction...")
        
        all_features = []
        all_preds = []
        all_labels = []
        
        num_test_batches = len(test_data) // self.config.batch_size
        if len(test_data) % self.config.batch_size != 0:
            num_test_batches += 1
            
        batches = self._iterate_batches(test_data, test_labels, shuffle=False)
        progress_bar = tqdm(batches, total=num_test_batches, 
                           desc="Final Evaluation",
                           leave=True)
        
        for batch in progress_bar:
            loss, acc, preds, features = self._eval_step(self.state.params, batch, return_features=True)
            all_features.append(np.array(features))
            all_preds.append(np.array(preds))
            all_labels.append(np.array(batch[1]))
            
            progress_bar.set_postfix({
                'loss': f"{float(loss):.4f}",
                'acc': f"{float(acc):.4f}"
            })
        
        # Concatenate results
        features = np.concatenate(all_features)
        predictions = np.concatenate(all_preds)
        labels = np.concatenate(all_labels)
        
        # Compute detailed metrics
        metrics = Metrics.compute_detailed_metrics(labels, predictions)
        
        # Generate visualizations
        viz = Visualizer(self.results_dir)
        viz.plot_confusion_matrix(metrics['confusion_matrix'])
        viz.plot_classification_metrics(metrics)
        viz.plot_feature_space(features, labels, subset_size=2000)
        viz.plot_sample_predictions(self.vit, self.state.params, test_data, test_labels)
        
        return {
            'metrics': metrics,
            'features': features,
            'predictions': predictions,
            'labels': labels
        }
    
    def _final_evaluation_with_ema(self, test_data: np.ndarray, test_labels: np.ndarray) -> Dict:
        """Final evaluation with EMA parameters"""
        print("Running final evaluation with EMA parameters...")
        ema_params = self.ema.get_params()
        
        all_features = []
        all_preds = []
        all_labels = []
        
        num_test_batches = len(test_data) // self.config.batch_size
        if len(test_data) % self.config.batch_size != 0:
            num_test_batches += 1
            
        batches = self._iterate_batches(test_data, test_labels, shuffle=False)
        progress_bar = tqdm(batches, total=num_test_batches, 
                           desc="Final EMA Evaluation",
                           leave=True)
        
        for batch in progress_bar:
            loss, acc, preds, features = self._eval_step(ema_params, batch, return_features=True)
            all_features.append(np.array(features))
            all_preds.append(np.array(preds))
            all_labels.append(np.array(batch[1]))
            
            progress_bar.set_postfix({
                'loss': f"{float(loss):.4f}",
                'acc': f"{float(acc):.4f}"
            })
        
        # Concatenate results
        features = np.concatenate(all_features)
        predictions = np.concatenate(all_preds)
        labels = np.concatenate(all_labels)
        
        # Compute detailed metrics
        metrics = Metrics.compute_detailed_metrics(labels, predictions)
        
        # Generate visualizations with "ema_" prefix
        viz = Visualizer(self.results_dir)
        
        # Save confusion matrix with EMA prefix
        orig_results_dir = viz.results_dir
        viz.results_dir = orig_results_dir
        cm_filename = 'ema_confusion_matrix.png'
        plt.figure(figsize=(10, 8))
        sns.heatmap(metrics['confusion_matrix'], annot=True, fmt='d', cmap='Blues', 
                    xticklabels=viz.cifar10_classes, 
                    yticklabels=viz.cifar10_classes)
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title('EMA Confusion Matrix')
        plt.tight_layout()
        plt.savefig(os.path.join(orig_results_dir, cm_filename), dpi=300, bbox_inches='tight')
        plt.close()
        
        return {
            'metrics': metrics,
            'features': features,
            'predictions': predictions,
            'labels': labels
        }
    
    def _eval_with_ema(self, test_data: np.ndarray, test_labels: np.ndarray) -> Dict:
        """Evaluate using EMA parameters"""
        if not self.ema:
            return {'accuracy': 0.0, 'loss': float('inf')}
        
        ema_params = self.ema.get_params()
        return self._eval_epoch(test_data, test_labels, params=ema_params)
    
    def _save_checkpoint(self, filename: str):
        """Save model checkpoint"""
        checkpoint = {
            'params': self.state.params,
            'opt_state': self.state.opt_state,
            'epoch': self.state.epoch,
            'best_accuracy': self.state.best_accuracy,
            'config': self.config
        }
        
        if self.ema:
            checkpoint['ema_params'] = self.ema.get_params()
        
        filepath = os.path.join(self.results_dir, filename)
        with open(filepath, 'wb') as f:
            pickle.dump(checkpoint, f)
    
    def load_checkpoint(self, filepath: str):
        """Load model checkpoint"""
        with open(filepath, 'rb') as f:
            checkpoint = pickle.load(f)
        
        self.state.params = checkpoint['params']
        self.state.opt_state = checkpoint['opt_state']
        self.state.epoch = checkpoint['epoch']
        self.state.best_accuracy = checkpoint['best_accuracy']
        
        if 'ema_params' in checkpoint and self.ema:
            self.ema.ema_params = checkpoint['ema_params']
    
    def _train_epoch(self, train_data: np.ndarray, train_labels: np.ndarray):
        """Enhanced training epoch"""
        self.epoch_train_losses = []
        self.epoch_train_accs = []
        
        num_batches = len(train_data) // self.config.batch_size
        batches = self._iterate_batches(train_data, train_labels, shuffle=True)
        progress_bar = tqdm(batches, total=num_batches, 
                           desc=f"Epoch {self.state.epoch+1}/{self.config.num_epochs} [Train]",
                           leave=True)
        
        for batch_idx, batch in enumerate(progress_bar):
            # Generate keys for different random operations
            self.key, *sub_keys = random.split(self.key, 4)
            keys = {
                'model': sub_keys[0],
                'augmentation': sub_keys[1],
                'optimizer': sub_keys[2]
            }
            
            # Get current learning rate
            current_lr = float(self.lr_scheduler(self.state.step))
            
            # Update parameters
            self.state.params, loss, acc, self.state.opt_state = self.jit_train_step(
                self.state.params, batch, keys, self.state.opt_state, current_lr
            )


            if self.ema:
                self.ema.update(self.state.params)
            
            self.epoch_train_losses.append(float(loss))
            self.epoch_train_accs.append(float(acc))
            
            # Update progress bar
            progress_bar.set_postfix({
                'loss': f"{float(loss):.4f}",
                'acc': f"{float(acc):.4f}",
                'lr': f"{current_lr:.6f}"
            })
            
            self.state.step += 1
    
    def _eval_epoch(self, test_data: np.ndarray, test_labels: np.ndarray, params: Dict = None) -> Dict:
        """Enhanced evaluation epoch"""
        if params is None:
            params = self.state.params
        
        epoch_eval_losses = []
        epoch_eval_accs = []
        all_preds = []
        all_labels = []
        
        num_test_batches = len(test_data) // self.config.batch_size
        if len(test_data) % self.config.batch_size != 0:
            num_test_batches += 1
            
        batches = self._iterate_batches(test_data, test_labels, shuffle=False)
        progress_bar = tqdm(batches, total=num_test_batches, 
                           desc=f"Epoch {self.state.epoch+1}/{self.config.num_epochs} [Eval]",
                           leave=True)
        
        for batch in progress_bar:
            loss, acc, preds = self.jit_eval_step(params, batch)
            epoch_eval_losses.append(float(loss))
            epoch_eval_accs.append(float(acc))
            
            all_preds.append(np.array(preds))
            all_labels.append(batch[1])
            
            progress_bar.set_postfix({
                'loss': f"{float(loss):.4f}",
                'acc': f"{float(acc):.4f}"
            })
        
        return {
            'loss': np.mean(epoch_eval_losses),
            'accuracy': np.mean(epoch_eval_accs),
            'predictions': np.concatenate(all_preds),
            'labels': np.concatenate(all_labels)
        }
    
    def _iterate_batches(
        self, 
        images: np.ndarray, 
        labels: np.ndarray, 
        shuffle: bool = False
    ):
        """Iterator for creating batches"""
        num_samples = len(images)
        indices = np.arange(num_samples)
        
        if shuffle:
            np.random.shuffle(indices)
        
        for start_idx in range(0, num_samples, self.config.batch_size):
            end_idx = min(start_idx + self.config.batch_size, num_samples)
            batch_indices = indices[start_idx:end_idx]
            yield images[batch_indices], labels[batch_indices]


# ==============================
# 14. ENHANCED MAIN FUNCTION
# ==============================
def main():
    """Enhanced main entry point with all regularization techniques"""
    
    # Enhanced configuration
    config = ViTConfig(
        # Model architecture
        img_size=32,
        patch_size=4,
        num_classes=10,
        num_heads=3,
        num_layers=2,
        hidden_dim=384,
        mlp_dim=1536,
        
        # Enhanced dropout configuration
        dropout_rate = 0.1,
        attention_dropout_rate = 0.1,
        projection_dropout_rate = 0.1,
        path_dropout_rate = 0.1,
        
        # Batch normalization
        use_batch_norm=True,
        bn_momentum=0.9,
        
        # Data augmentation
        # use_data_augmentation=True,
        # mixup_alpha=0.8,
        # cutmix_alpha=1.0,
        # cutout_size=8,

        use_data_augmentation = True,
        augment_prob = 0.3,
        mixup_alpha = 0.8,
        cutmix_alpha = 1.0,
        cutout_size = 8,
        
        # Label smoothing
        label_smoothing=0.1,
        
        # Training parameters
        batch_size=256,
        num_epochs=100,  # Increased for better convergence
        learning_rate=3e-4,
        weight_decay=0.05,
        warmup_steps=1000,
        early_stopping_patience=10,
        
        # EMA
        use_ema=True,
        ema_decay=0.9999,
        
        # Other settings
        seed=42
    )
    
    # Create unique results directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_dir = os.path.join("enhanced_vit_results", f"run_{timestamp}")
    os.makedirs(results_dir, exist_ok=True)
    
    # Print startup information
    print("\n" + "="*80)
    print("ENHANCED VISION TRANSFORMER WITH ADVANCED REGULARIZATION")
    print("="*80)
    print("Configuration:")
    print(f"  Architecture: {config.num_layers} layers × {config.hidden_dim} dim")
    print(f"  Regularization Techniques:")
    print(f"    - Dropout: {config.dropout_rate} (attention: {config.attention_dropout_rate})")
    print(f"    - DropPath: {config.path_dropout_rate}")
    print(f"    - Batch Normalization: {config.use_batch_norm}")
    print(f"    - Data Augmentation: {config.use_data_augmentation}")
    print(f"    - Label Smoothing: {config.label_smoothing}")
    print(f"    - EMA: {config.use_ema} (decay: {config.ema_decay})")
    print(f"    - Weight Decay: {config.weight_decay}")
    print(f"  Training: {config.num_epochs} epochs × {config.batch_size} batch size")
    print(f"  Learning rate: {config.learning_rate} (warmup: {config.warmup_steps} steps)")
    print(f"  Device: {jax.devices()[0]}")
    print(f"  Results: {results_dir}")
    print("="*80 + "\n")
    
    # Initialize components
    data_loader = DataLoader(data_dir="cifar10_data")
    trainer = EnhancedTrainer(config, results_dir)
    visualizer = Visualizer(results_dir)
    
    # Save configuration
    config_path = os.path.join(results_dir, 'enhanced_config.txt')
    with open(config_path, 'w') as f:
        f.write("Enhanced Vision Transformer Configuration\n")
        f.write("="*60 + "\n\n")
        for key, value in config.__dict__.items():
            f.write(f"{key}: {value}\n")
    
    try:
        # Load data
        print("Loading CIFAR-10 dataset...")
        train_data, train_labels, test_data, test_labels = data_loader.load_datasets()
        
        # Create visualizations
        print("Creating initial visualizations...")
        total_steps = config.num_epochs * (len(train_data) // config.batch_size)
        visualizer.plot_learning_rate_schedule(trainer.lr_scheduler or 
                                             LearningRateScheduler(config, total_steps), 
                                             total_steps)
        
        # Run enhanced training
        print("\nStarting enhanced training with regularization...")
        start_time = time.time()
        
        final_metrics = trainer.train(train_data, train_labels, test_data, test_labels)
        
        # Calculate total time
        total_time = time.time() - start_time
        hours = int(total_time // 3600)
        minutes = int((total_time % 3600) // 60)
        seconds = int(total_time % 60)
        
        # Save final results summary
        summary_path = os.path.join(results_dir, 'training_summary.txt')
        with open(summary_path, 'w') as f:
            f.write("Enhanced Vision Transformer Training Summary\n")
            f.write("="*60 + "\n\n")
            f.write(f"Training Time: {hours:02d}:{minutes:02d}:{seconds:02d}\n")
            f.write(f"Best Validation Accuracy: {trainer.state.best_accuracy:.4f}\n")
            f.write(f"Best Epoch: {trainer.state.best_epoch + 1}\n")
            f.write(f"Final Macro F1 Score: {final_metrics['metrics']['macro_f1']:.4f}\n")
            f.write(f"Final Macro Precision: {final_metrics['metrics']['macro_precision']:.4f}\n")
            f.write(f"Final Macro Recall: {final_metrics['metrics']['macro_recall']:.4f}\n")
            f.write(f"\nRegularization Techniques Used:\n")
            f.write(f"  - Dropout: {config.dropout_rate}\n")
            f.write(f"  - Attention Dropout: {config.attention_dropout_rate}\n")
            f.write(f"  - DropPath: {config.path_dropout_rate}\n")
            f.write(f"  - Batch Normalization: {config.use_batch_norm}\n")
            f.write(f"  - Data Augmentation: {config.use_data_augmentation}\n")
            f.write(f"  - Label Smoothing: {config.label_smoothing}\n")
            f.write(f"  - EMA: {config.use_ema}\n")
            f.write(f"  - Weight Decay: {config.weight_decay}\n")
            f.write(f"\nResults Directory: {results_dir}\n")
        
        # Save hardware summary
        hardware_summary_path = os.path.join(results_dir, 'hardware_summary.txt')
        with open(hardware_summary_path, 'w') as f:
            f.write("Hardware Summary\n")
            f.write("="*50 + "\n\n")
            f.write(f"JAX version: {jax.__version__}\n")
            f.write(f"Device used: {jax.devices()[0]}\n")
            f.write(f"Number of devices: {jax.device_count()}\n")
            try:
                f.write(f"Device type: {jax.devices()[0].device_kind}\n")
            except:
                f.write("Device type: Unknown\n")
        
        # Print summary
        print("\n" + "="*80)
        print("ENHANCED TRAINING COMPLETED SUCCESSFULLY!")
        print("="*80)
        print(f"Total time: {hours:02d}:{minutes:02d}:{seconds:02d}")
        print(f"Best accuracy: {trainer.state.best_accuracy:.4f}")
        print(f"Final F1 score: {final_metrics['metrics']['macro_f1']:.4f}")
        print(f"Results in: {results_dir}")
        print("="*80)
        
        # List generated files
        print("\nGenerated files:")
        print("  Visualizations:")
        print("    - training_metrics.png - Learning curves")
        print("    - lr_schedule.png - Learning rate schedule")
        print("    - confusion_matrix.png - Confusion matrix")
        print("    - ema_confusion_matrix.png - EMA Confusion matrix")
        print("    - precision_recall_f1.png - Classification metrics")
        print("    - tsne_visualization.png - t-SNE feature visualization")
        print("    - pca_visualization.png - PCA feature visualization")
        print("    - umap_visualization.png - UMAP feature visualization")
        print("    - sample_predictions.png - Sample predictions")
        print("  Model Files:")
        print("    - best_model.pkl - Best model checkpoint")
        print("    - checkpoint_epoch_*.pkl - Periodic checkpoints")
        print("  Text files:")
        print("    - enhanced_config.txt - Training configuration")
        print("    - training_summary.txt - Final results summary")
        print("    - hardware_summary.txt - Hardware information")
        print("\n")
        
    except Exception as e:
        print(f"\nTraining failed with error: {e}")
        print(f"Partial results saved to: {results_dir}")
        import traceback
        traceback.print_exc()
        return 1
    
    return 0


if __name__ == "__main__":
    exit_code = main()
    # exit(exit_code)


ENHANCED VISION TRANSFORMER WITH ADVANCED REGULARIZATION
Configuration:
  Architecture: 3 layers × 384 dim
  Regularization Techniques:
    - Dropout: 0.2 (attention: 0.3)
    - DropPath: 0.2
    - Batch Normalization: True
    - Data Augmentation: True
    - Label Smoothing: 0.1
    - EMA: True (decay: 0.9999)
    - Weight Decay: 0.05
  Training: 100 epochs × 256 batch size
  Learning rate: 0.0003 (warmup: 1000 steps)
  Device: cuda:0
  Results: enhanced_vit_results/run_20250507_053823

Loading CIFAR-10 dataset...
Loaded 50000 training samples and 10000 test samples
Creating initial visualizations...

Starting enhanced training with regularization...
Starting enhanced training on cuda:0
Total steps: 19500
Regularization: Dropout=0.2, BatchNorm=True
Data Augmentation: Enabled=True
Label Smoothing: 0.1


Epoch 1/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 1/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 1/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.1604
Epoch 1/100 completed in 273.65s
  Train Loss: 2.2110, Train Accuracy: 0.1749
  Eval Loss: 2.0424, Eval Accuracy: 0.2616
  Best Accuracy: 0.2616 at epoch 1


Epoch 2/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 2/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 2/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.1895
Epoch 2/100 completed in 6.92s
  Train Loss: 2.0734, Train Accuracy: 0.2567
  Eval Loss: 1.8708, Eval Accuracy: 0.3621
  Best Accuracy: 0.3621 at epoch 2


Epoch 3/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 3/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 3/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.2040
Epoch 3/100 completed in 6.94s
  Train Loss: 1.9507, Train Accuracy: 0.3294
  Eval Loss: 1.7209, Eval Accuracy: 0.4476
  Best Accuracy: 0.4476 at epoch 3


Epoch 4/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 4/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 4/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.2118
Epoch 4/100 completed in 6.95s
  Train Loss: 1.9091, Train Accuracy: 0.3592
  Eval Loss: 1.6695, Eval Accuracy: 0.4646
  Best Accuracy: 0.4646 at epoch 4


Epoch 5/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 5/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 5/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.2181
Epoch 5/100 completed in 6.93s
  Train Loss: 1.8854, Train Accuracy: 0.3737
  Eval Loss: 1.6557, Eval Accuracy: 0.4691
  Best Accuracy: 0.4691 at epoch 5


Epoch 6/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 6/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 6/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.2222
Epoch 6/100 completed in 6.93s
  Train Loss: 1.8783, Train Accuracy: 0.3741
  Eval Loss: 1.6333, Eval Accuracy: 0.4764
  Best Accuracy: 0.4764 at epoch 6


Epoch 7/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 7/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 7/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.2276
Epoch 7/100 completed in 6.82s
  Train Loss: 1.8262, Train Accuracy: 0.4027
  Eval Loss: 1.6017, Eval Accuracy: 0.4934
  Best Accuracy: 0.4934 at epoch 7


Epoch 8/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 8/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 8/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.2334
Epoch 8/100 completed in 6.91s
  Train Loss: 1.8208, Train Accuracy: 0.4049
  Eval Loss: 1.5812, Eval Accuracy: 0.5000
  Best Accuracy: 0.5000 at epoch 8


Epoch 9/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 9/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 9/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.2354
Epoch 9/100 completed in 6.86s
  Train Loss: 1.8078, Train Accuracy: 0.4123
  Eval Loss: 1.5566, Eval Accuracy: 0.5128
  Best Accuracy: 0.5128 at epoch 9


Epoch 10/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 10/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 10/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.2377
Epoch 10/100 completed in 6.74s
  Train Loss: 1.7842, Train Accuracy: 0.4232
  Eval Loss: 1.5434, Eval Accuracy: 0.5168
  Best Accuracy: 0.5168 at epoch 10


Epoch 11/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 11/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 11/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.2408
Epoch 11/100 completed in 7.00s
  Train Loss: 1.7930, Train Accuracy: 0.4226
  Eval Loss: 1.5231, Eval Accuracy: 0.5230
  Best Accuracy: 0.5230 at epoch 11


Epoch 12/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 12/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 12/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.2466
Epoch 12/100 completed in 6.92s
  Train Loss: 1.7554, Train Accuracy: 0.4421
  Eval Loss: 1.5116, Eval Accuracy: 0.5342
  Best Accuracy: 0.5342 at epoch 12


Epoch 13/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 13/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 13/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.2503
Epoch 13/100 completed in 6.89s
  Train Loss: 1.7472, Train Accuracy: 0.4456
  Eval Loss: 1.5001, Eval Accuracy: 0.5393
  Best Accuracy: 0.5393 at epoch 13


Epoch 14/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 14/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 14/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.2526
Epoch 14/100 completed in 6.88s
  Train Loss: 1.7591, Train Accuracy: 0.4411
  Eval Loss: 1.4963, Eval Accuracy: 0.5422
  Best Accuracy: 0.5422 at epoch 14


Epoch 15/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 15/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 15/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.2549
Epoch 15/100 completed in 6.89s
  Train Loss: 1.7350, Train Accuracy: 0.4523
  Eval Loss: 1.4953, Eval Accuracy: 0.5343
  Best Accuracy: 0.5422 at epoch 14


Epoch 16/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 16/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 16/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.2572
Epoch 16/100 completed in 6.99s
  Train Loss: 1.7351, Train Accuracy: 0.4548
  Eval Loss: 1.4746, Eval Accuracy: 0.5519
  Best Accuracy: 0.5519 at epoch 16


Epoch 17/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 17/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 17/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.2633
Epoch 17/100 completed in 6.80s
  Train Loss: 1.7404, Train Accuracy: 0.4575
  Eval Loss: 1.4619, Eval Accuracy: 0.5556
  Best Accuracy: 0.5556 at epoch 17


Epoch 18/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 18/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 18/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.2676
Epoch 18/100 completed in 6.99s
  Train Loss: 1.7325, Train Accuracy: 0.4606
  Eval Loss: 1.4589, Eval Accuracy: 0.5583
  Best Accuracy: 0.5583 at epoch 18


Epoch 19/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 19/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 19/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.2738
Epoch 19/100 completed in 7.11s
  Train Loss: 1.7153, Train Accuracy: 0.4662
  Eval Loss: 1.4376, Eval Accuracy: 0.5661
  Best Accuracy: 0.5661 at epoch 19


Epoch 20/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 20/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 20/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.2783
Epoch 20/100 completed in 6.68s
  Train Loss: 1.7118, Train Accuracy: 0.4688
  Eval Loss: 1.4436, Eval Accuracy: 0.5590
  Best Accuracy: 0.5661 at epoch 19


Epoch 21/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 21/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 21/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.2849
Epoch 21/100 completed in 6.87s
  Train Loss: 1.6853, Train Accuracy: 0.4816
  Eval Loss: 1.4234, Eval Accuracy: 0.5749
  Best Accuracy: 0.5749 at epoch 21


Epoch 22/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 22/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 22/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.2920
Epoch 22/100 completed in 6.69s
  Train Loss: 1.6942, Train Accuracy: 0.4775
  Eval Loss: 1.4249, Eval Accuracy: 0.5729
  Best Accuracy: 0.5749 at epoch 21


Epoch 23/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 23/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 23/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.3007
Epoch 23/100 completed in 6.64s
  Train Loss: 1.6845, Train Accuracy: 0.4805
  Eval Loss: 1.4214, Eval Accuracy: 0.5731
  Best Accuracy: 0.5749 at epoch 21


Epoch 24/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 24/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 24/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.3086
Epoch 24/100 completed in 6.81s
  Train Loss: 1.7049, Train Accuracy: 0.4784
  Eval Loss: 1.4020, Eval Accuracy: 0.5819
  Best Accuracy: 0.5819 at epoch 24


Epoch 25/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 25/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 25/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.3182
Epoch 25/100 completed in 6.92s
  Train Loss: 1.7065, Train Accuracy: 0.4842
  Eval Loss: 1.3978, Eval Accuracy: 0.5847
  Best Accuracy: 0.5847 at epoch 25


Epoch 26/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 26/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 26/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.3266
Epoch 26/100 completed in 6.92s
  Train Loss: 1.6969, Train Accuracy: 0.4889
  Eval Loss: 1.3984, Eval Accuracy: 0.5880
  Best Accuracy: 0.5880 at epoch 26


Epoch 27/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 27/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 27/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.3364
Epoch 27/100 completed in 6.90s
  Train Loss: 1.6662, Train Accuracy: 0.4942
  Eval Loss: 1.3829, Eval Accuracy: 0.5955
  Best Accuracy: 0.5955 at epoch 27


Epoch 28/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 28/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 28/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.3435
Epoch 28/100 completed in 6.73s
  Train Loss: 1.6593, Train Accuracy: 0.4923
  Eval Loss: 1.3867, Eval Accuracy: 0.5843
  Best Accuracy: 0.5955 at epoch 27


Epoch 29/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 29/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 29/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.3530
Epoch 29/100 completed in 6.75s
  Train Loss: 1.6594, Train Accuracy: 0.5005
  Eval Loss: 1.3916, Eval Accuracy: 0.5850
  Best Accuracy: 0.5955 at epoch 27


Epoch 30/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 30/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 30/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.3632
Epoch 30/100 completed in 6.93s
  Train Loss: 1.6421, Train Accuracy: 0.5057
  Eval Loss: 1.3659, Eval Accuracy: 0.5970
  Best Accuracy: 0.5970 at epoch 30


Epoch 31/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 31/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 31/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.3745
Epoch 31/100 completed in 6.87s
  Train Loss: 1.6527, Train Accuracy: 0.5007
  Eval Loss: 1.3685, Eval Accuracy: 0.6048
  Best Accuracy: 0.6048 at epoch 31


Epoch 32/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 32/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 32/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.3835
Epoch 32/100 completed in 6.91s
  Train Loss: 1.6411, Train Accuracy: 0.5025
  Eval Loss: 1.3463, Eval Accuracy: 0.6065
  Best Accuracy: 0.6065 at epoch 32


Epoch 33/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 33/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 33/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.3917
Epoch 33/100 completed in 6.61s
  Train Loss: 1.6217, Train Accuracy: 0.5156
  Eval Loss: 1.3501, Eval Accuracy: 0.6063
  Best Accuracy: 0.6065 at epoch 32


Epoch 34/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 34/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 34/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.3997
Epoch 34/100 completed in 6.84s
  Train Loss: 1.6278, Train Accuracy: 0.5116
  Eval Loss: 1.3461, Eval Accuracy: 0.6066
  Best Accuracy: 0.6066 at epoch 34


Epoch 35/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 35/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 35/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.4102
Epoch 35/100 completed in 6.69s
  Train Loss: 1.6423, Train Accuracy: 0.5060
  Eval Loss: 1.3496, Eval Accuracy: 0.6064
  Best Accuracy: 0.6066 at epoch 34


Epoch 36/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 36/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 36/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.4189
Epoch 36/100 completed in 7.01s
  Train Loss: 1.5929, Train Accuracy: 0.5308
  Eval Loss: 1.3385, Eval Accuracy: 0.6163
  Best Accuracy: 0.6163 at epoch 36


Epoch 37/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 37/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 37/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.4274
Epoch 37/100 completed in 6.71s
  Train Loss: 1.6317, Train Accuracy: 0.5135
  Eval Loss: 1.3389, Eval Accuracy: 0.6132
  Best Accuracy: 0.6163 at epoch 36


Epoch 38/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 38/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 38/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.4347
Epoch 38/100 completed in 7.01s
  Train Loss: 1.6132, Train Accuracy: 0.5248
  Eval Loss: 1.3282, Eval Accuracy: 0.6164
  Best Accuracy: 0.6164 at epoch 38


Epoch 39/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 39/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 39/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.4420
Epoch 39/100 completed in 6.85s
  Train Loss: 1.6200, Train Accuracy: 0.5163
  Eval Loss: 1.3313, Eval Accuracy: 0.6171
  Best Accuracy: 0.6171 at epoch 39


Epoch 40/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 40/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 40/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.4503
Epoch 40/100 completed in 6.83s
  Train Loss: 1.6017, Train Accuracy: 0.5268
  Eval Loss: 1.3159, Eval Accuracy: 0.6214
  Best Accuracy: 0.6214 at epoch 40


Epoch 41/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 41/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 41/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.4568
Epoch 41/100 completed in 6.87s
  Train Loss: 1.6312, Train Accuracy: 0.5195
  Eval Loss: 1.3125, Eval Accuracy: 0.6286
  Best Accuracy: 0.6286 at epoch 41


Epoch 42/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 42/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 42/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.4633
Epoch 42/100 completed in 6.86s
  Train Loss: 1.6180, Train Accuracy: 0.5275
  Eval Loss: 1.3123, Eval Accuracy: 0.6294
  Best Accuracy: 0.6294 at epoch 42


Epoch 43/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 43/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 43/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.4707
Epoch 43/100 completed in 6.79s
  Train Loss: 1.6006, Train Accuracy: 0.5296
  Eval Loss: 1.3160, Eval Accuracy: 0.6269
  Best Accuracy: 0.6294 at epoch 42


Epoch 44/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 44/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 44/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.4796
Epoch 44/100 completed in 6.63s
  Train Loss: 1.6190, Train Accuracy: 0.5221
  Eval Loss: 1.3068, Eval Accuracy: 0.6284
  Best Accuracy: 0.6294 at epoch 42


Epoch 45/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 45/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 45/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.4859
Epoch 45/100 completed in 6.73s
  Train Loss: 1.6038, Train Accuracy: 0.5333
  Eval Loss: 1.3102, Eval Accuracy: 0.6250
  Best Accuracy: 0.6294 at epoch 42


Epoch 46/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 46/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 46/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.4913
Epoch 46/100 completed in 6.97s
  Train Loss: 1.5659, Train Accuracy: 0.5448
  Eval Loss: 1.2943, Eval Accuracy: 0.6376
  Best Accuracy: 0.6376 at epoch 46


Epoch 47/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 47/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 47/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.4979
Epoch 47/100 completed in 6.67s
  Train Loss: 1.5765, Train Accuracy: 0.5412
  Eval Loss: 1.2984, Eval Accuracy: 0.6347
  Best Accuracy: 0.6376 at epoch 46


Epoch 48/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 48/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 48/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.5035
Epoch 48/100 completed in 6.73s
  Train Loss: 1.6077, Train Accuracy: 0.5358
  Eval Loss: 1.2968, Eval Accuracy: 0.6363
  Best Accuracy: 0.6376 at epoch 46


Epoch 49/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 49/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 49/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.5090
Epoch 49/100 completed in 6.78s
  Train Loss: 1.5769, Train Accuracy: 0.5394
  Eval Loss: 1.3036, Eval Accuracy: 0.6324
  Best Accuracy: 0.6376 at epoch 46


Epoch 50/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 50/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 50/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.5127
Epoch 50/100 completed in 6.81s
  Train Loss: 1.5709, Train Accuracy: 0.5465
  Eval Loss: 1.2882, Eval Accuracy: 0.6401
  Best Accuracy: 0.6401 at epoch 50


Epoch 51/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 51/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 51/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.5174
Epoch 51/100 completed in 6.65s
  Train Loss: 1.5490, Train Accuracy: 0.5519
  Eval Loss: 1.2897, Eval Accuracy: 0.6363
  Best Accuracy: 0.6401 at epoch 50


Epoch 52/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 52/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 52/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.5206
Epoch 52/100 completed in 6.89s
  Train Loss: 1.5219, Train Accuracy: 0.5631
  Eval Loss: 1.2974, Eval Accuracy: 0.6402
  Best Accuracy: 0.6402 at epoch 52


Epoch 53/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 53/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 53/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.5281
Epoch 53/100 completed in 6.90s
  Train Loss: 1.5333, Train Accuracy: 0.5582
  Eval Loss: 1.2796, Eval Accuracy: 0.6478
  Best Accuracy: 0.6478 at epoch 53


Epoch 54/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 54/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 54/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.5354
Epoch 54/100 completed in 6.62s
  Train Loss: 1.5229, Train Accuracy: 0.5605
  Eval Loss: 1.2787, Eval Accuracy: 0.6442
  Best Accuracy: 0.6478 at epoch 53


Epoch 55/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 55/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 55/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.5396
Epoch 55/100 completed in 6.94s
  Train Loss: 1.5419, Train Accuracy: 0.5554
  Eval Loss: 1.2639, Eval Accuracy: 0.6533
  Best Accuracy: 0.6533 at epoch 55


Epoch 56/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 56/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 56/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.5431
Epoch 56/100 completed in 6.74s
  Train Loss: 1.5337, Train Accuracy: 0.5627
  Eval Loss: 1.2768, Eval Accuracy: 0.6468
  Best Accuracy: 0.6533 at epoch 55


Epoch 57/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 57/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 57/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.5474
Epoch 57/100 completed in 6.55s
  Train Loss: 1.5297, Train Accuracy: 0.5580
  Eval Loss: 1.2756, Eval Accuracy: 0.6513
  Best Accuracy: 0.6533 at epoch 55


Epoch 58/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 58/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 58/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.5524
Epoch 58/100 completed in 6.70s
  Train Loss: 1.5263, Train Accuracy: 0.5662
  Eval Loss: 1.2722, Eval Accuracy: 0.6490
  Best Accuracy: 0.6533 at epoch 55


Epoch 59/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 59/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 59/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.5581
Epoch 59/100 completed in 6.65s
  Train Loss: 1.5443, Train Accuracy: 0.5570
  Eval Loss: 1.2739, Eval Accuracy: 0.6517
  Best Accuracy: 0.6533 at epoch 55


Epoch 60/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 60/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 60/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.5636
Epoch 60/100 completed in 6.81s
  Train Loss: 1.5557, Train Accuracy: 0.5601
  Eval Loss: 1.2652, Eval Accuracy: 0.6530
  Best Accuracy: 0.6533 at epoch 55


Epoch 61/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 61/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 61/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.5675
Epoch 61/100 completed in 6.82s
  Train Loss: 1.5329, Train Accuracy: 0.5612
  Eval Loss: 1.2698, Eval Accuracy: 0.6576
  Best Accuracy: 0.6576 at epoch 61


Epoch 62/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 62/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 62/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.5710
Epoch 62/100 completed in 6.88s
  Train Loss: 1.4820, Train Accuracy: 0.5818
  Eval Loss: 1.2614, Eval Accuracy: 0.6644
  Best Accuracy: 0.6644 at epoch 62


Epoch 63/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 63/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 63/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.5747
Epoch 63/100 completed in 6.68s
  Train Loss: 1.4737, Train Accuracy: 0.5812
  Eval Loss: 1.2848, Eval Accuracy: 0.6515
  Best Accuracy: 0.6644 at epoch 62


Epoch 64/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 64/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 64/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.5777
Epoch 64/100 completed in 6.84s
  Train Loss: 1.4871, Train Accuracy: 0.5792
  Eval Loss: 1.2705, Eval Accuracy: 0.6561
  Best Accuracy: 0.6644 at epoch 62


Epoch 65/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 65/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 65/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.5819
Epoch 65/100 completed in 6.72s
  Train Loss: 1.5336, Train Accuracy: 0.5629
  Eval Loss: 1.2722, Eval Accuracy: 0.6640
  Best Accuracy: 0.6644 at epoch 62


Epoch 66/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 66/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 66/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.5855
Epoch 66/100 completed in 6.71s
  Train Loss: 1.5153, Train Accuracy: 0.5734
  Eval Loss: 1.2762, Eval Accuracy: 0.6600
  Best Accuracy: 0.6644 at epoch 62


Epoch 67/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 67/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 67/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.5898
Epoch 67/100 completed in 6.74s
  Train Loss: 1.4995, Train Accuracy: 0.5777
  Eval Loss: 1.2736, Eval Accuracy: 0.6608
  Best Accuracy: 0.6644 at epoch 62


Epoch 68/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 68/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 68/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.5942
Epoch 68/100 completed in 6.73s
  Train Loss: 1.5425, Train Accuracy: 0.5651
  Eval Loss: 1.2641, Eval Accuracy: 0.6610
  Best Accuracy: 0.6644 at epoch 62


Epoch 69/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 69/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 69/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.5946
Epoch 69/100 completed in 6.85s
  Train Loss: 1.4722, Train Accuracy: 0.5839
  Eval Loss: 1.2739, Eval Accuracy: 0.6645
  Best Accuracy: 0.6645 at epoch 69


Epoch 70/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 70/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 70/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.5967
Epoch 70/100 completed in 6.88s
  Train Loss: 1.4638, Train Accuracy: 0.5874
  Eval Loss: 1.2721, Eval Accuracy: 0.6661
  Best Accuracy: 0.6661 at epoch 70


Epoch 71/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 71/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 71/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.5981
Epoch 71/100 completed in 6.71s
  Train Loss: 1.4646, Train Accuracy: 0.5890
  Eval Loss: 1.2777, Eval Accuracy: 0.6647
  Best Accuracy: 0.6661 at epoch 70


Epoch 72/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 72/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 72/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.6003
Epoch 72/100 completed in 6.60s
  Train Loss: 1.4886, Train Accuracy: 0.5831
  Eval Loss: 1.2720, Eval Accuracy: 0.6650
  Best Accuracy: 0.6661 at epoch 70


Epoch 73/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 73/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 73/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.6028
Epoch 73/100 completed in 6.66s
  Train Loss: 1.5054, Train Accuracy: 0.5701
  Eval Loss: 1.2804, Eval Accuracy: 0.6656
  Best Accuracy: 0.6661 at epoch 70


Epoch 74/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 74/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 74/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.6049
Epoch 74/100 completed in 6.76s
  Train Loss: 1.4913, Train Accuracy: 0.5881
  Eval Loss: 1.2824, Eval Accuracy: 0.6648
  Best Accuracy: 0.6661 at epoch 70


Epoch 75/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 75/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 75/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.6078
Epoch 75/100 completed in 6.62s
  Train Loss: 1.4887, Train Accuracy: 0.5830
  Eval Loss: 1.2814, Eval Accuracy: 0.6623
  Best Accuracy: 0.6661 at epoch 70


Epoch 76/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 76/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 76/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.6105
Epoch 76/100 completed in 6.69s
  Train Loss: 1.4565, Train Accuracy: 0.5902
  Eval Loss: 1.2784, Eval Accuracy: 0.6637
  Best Accuracy: 0.6661 at epoch 70


Epoch 77/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 77/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 77/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.6139
Epoch 77/100 completed in 6.86s
  Train Loss: 1.4660, Train Accuracy: 0.5943
  Eval Loss: 1.2831, Eval Accuracy: 0.6667
  Best Accuracy: 0.6667 at epoch 77


Epoch 78/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 78/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 78/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.6161
Epoch 78/100 completed in 6.85s
  Train Loss: 1.4372, Train Accuracy: 0.6007
  Eval Loss: 1.2788, Eval Accuracy: 0.6692
  Best Accuracy: 0.6692 at epoch 78


Epoch 79/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 79/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 79/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.6174
Epoch 79/100 completed in 6.75s
  Train Loss: 1.4883, Train Accuracy: 0.5821
  Eval Loss: 1.2758, Eval Accuracy: 0.6685
  Best Accuracy: 0.6692 at epoch 78


Epoch 80/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 80/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 80/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.6184
Epoch 80/100 completed in 6.77s
  Train Loss: 1.4824, Train Accuracy: 0.5880
  Eval Loss: 1.2817, Eval Accuracy: 0.6646
  Best Accuracy: 0.6692 at epoch 78


Epoch 81/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 81/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 81/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.6205
Epoch 81/100 completed in 6.93s
  Train Loss: 1.4482, Train Accuracy: 0.5959
  Eval Loss: 1.2828, Eval Accuracy: 0.6709
  Best Accuracy: 0.6709 at epoch 81


Epoch 82/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 82/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 82/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.6219
Epoch 82/100 completed in 6.66s
  Train Loss: 1.4959, Train Accuracy: 0.5886
  Eval Loss: 1.2846, Eval Accuracy: 0.6672
  Best Accuracy: 0.6709 at epoch 81


Epoch 83/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 83/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 83/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.6231
Epoch 83/100 completed in 6.62s
  Train Loss: 1.4832, Train Accuracy: 0.5899
  Eval Loss: 1.2796, Eval Accuracy: 0.6673
  Best Accuracy: 0.6709 at epoch 81


Epoch 84/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 84/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 84/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.6246
Epoch 84/100 completed in 6.66s
  Train Loss: 1.4782, Train Accuracy: 0.5827
  Eval Loss: 1.2771, Eval Accuracy: 0.6699
  Best Accuracy: 0.6709 at epoch 81


Epoch 85/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 85/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 85/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.6252
Epoch 85/100 completed in 6.62s
  Train Loss: 1.4724, Train Accuracy: 0.5915
  Eval Loss: 1.2788, Eval Accuracy: 0.6668
  Best Accuracy: 0.6709 at epoch 81


Epoch 86/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 86/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 86/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.6264
Epoch 86/100 completed in 6.99s
  Train Loss: 1.4679, Train Accuracy: 0.5912
  Eval Loss: 1.2807, Eval Accuracy: 0.6715
  Best Accuracy: 0.6715 at epoch 86


Epoch 87/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 87/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 87/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.6288
Epoch 87/100 completed in 6.73s
  Train Loss: 1.4447, Train Accuracy: 0.5931
  Eval Loss: 1.2824, Eval Accuracy: 0.6693
  Best Accuracy: 0.6715 at epoch 86


Epoch 88/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 88/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 88/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.6305
Epoch 88/100 completed in 6.68s
  Train Loss: 1.4665, Train Accuracy: 0.5905
  Eval Loss: 1.2803, Eval Accuracy: 0.6670
  Best Accuracy: 0.6715 at epoch 86


Epoch 89/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 89/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 89/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.6315
Epoch 89/100 completed in 6.69s
  Train Loss: 1.4942, Train Accuracy: 0.5834
  Eval Loss: 1.2828, Eval Accuracy: 0.6696
  Best Accuracy: 0.6715 at epoch 86


Epoch 90/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 90/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 90/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.6350
Epoch 90/100 completed in 6.79s
  Train Loss: 1.4733, Train Accuracy: 0.5899
  Eval Loss: 1.2814, Eval Accuracy: 0.6708
  Best Accuracy: 0.6715 at epoch 86


Epoch 91/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 91/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 91/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.6358
Epoch 91/100 completed in 6.71s
  Train Loss: 1.4825, Train Accuracy: 0.5845
  Eval Loss: 1.2817, Eval Accuracy: 0.6699
  Best Accuracy: 0.6715 at epoch 86


Epoch 92/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 92/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 92/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.6375
Epoch 92/100 completed in 6.67s
  Train Loss: 1.4367, Train Accuracy: 0.6008
  Eval Loss: 1.2834, Eval Accuracy: 0.6688
  Best Accuracy: 0.6715 at epoch 86


Epoch 93/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 93/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 93/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.6384
Epoch 93/100 completed in 6.53s
  Train Loss: 1.4374, Train Accuracy: 0.6024
  Eval Loss: 1.2850, Eval Accuracy: 0.6687
  Best Accuracy: 0.6715 at epoch 86


Epoch 94/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 94/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 94/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.6394
Epoch 94/100 completed in 6.57s
  Train Loss: 1.4436, Train Accuracy: 0.6067
  Eval Loss: 1.2840, Eval Accuracy: 0.6684
  Best Accuracy: 0.6715 at epoch 86


Epoch 95/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 95/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 95/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.6402
Epoch 95/100 completed in 6.47s
  Train Loss: 1.4744, Train Accuracy: 0.5971
  Eval Loss: 1.2829, Eval Accuracy: 0.6701
  Best Accuracy: 0.6715 at epoch 86


Epoch 96/100 [Train]:   0%|          | 0/195 [00:00<?, ?it/s]

Epoch 96/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 96/100 [Eval]:   0%|          | 0/40 [00:00<?, ?it/s]

EMA Eval Accuracy: 0.6421
Epoch 96/100 completed in 6.50s
  Train Loss: 1.4592, Train Accuracy: 0.5984
  Eval Loss: 1.2826, Eval Accuracy: 0.6684
  Best Accuracy: 0.6715 at epoch 86
Early stopping triggered at epoch 96
Running final evaluation with EMA parameters...


Final EMA Evaluation:   0%|          | 0/40 [00:00<?, ?it/s]


ENHANCED TRAINING COMPLETED SUCCESSFULLY!
Total time: 00:15:50
Best accuracy: 0.6715
Final F1 score: 0.6403
Results in: enhanced_vit_results/run_20250507_053823

Generated files:
  Visualizations:
    - training_metrics.png - Learning curves
    - lr_schedule.png - Learning rate schedule
    - confusion_matrix.png - Confusion matrix
    - ema_confusion_matrix.png - EMA Confusion matrix
    - precision_recall_f1.png - Classification metrics
    - tsne_visualization.png - t-SNE feature visualization
    - pca_visualization.png - PCA feature visualization
    - umap_visualization.png - UMAP feature visualization
    - sample_predictions.png - Sample predictions
  Model Files:
    - best_model.pkl - Best model checkpoint
    - checkpoint_epoch_*.pkl - Periodic checkpoints
  Text files:
    - enhanced_config.txt - Training configuration
    - training_summary.txt - Final results summary
    - hardware_summary.txt - Hardware information


