# MobileViT Deepfake Detection Training

This notebook trains a MobileViT model for deepfake detection using the Real and Fake Images Dataset.

**Dataset**: Real and Fake Images Dataset for Image Forensics  
**Model**: Simplified MobileViT-style network (defined in this notebook, trained from scratch)  
**Platform**: Google Colab with GPU acceleration  

## 1. Setup and Installation

In [None]:
# Install required packages for Colab
# NOTE(review): this line installs PyTorch wheels built against CUDA 11.8 and may
# replace the torch build Colab ships with — confirm it matches the runtime's driver.
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# NOTE(review): tensorboard is not referenced by any code in this notebook.
!pip install tensorboard
# NOTE(review): timm is not referenced by any code in this notebook.
!pip install timm
# einops provides `rearrange`, imported in the model-definition cell.
!pip install einops
# scikit-learn supplies accuracy / precision / recall / F1 / confusion-matrix metrics.
!pip install scikit-learn
# Plotting for training curves and the confusion matrix.
!pip install matplotlib seaborn
# NOTE(review): `yaml` is imported in the next cell but never used.
!pip install pyyaml
# PIL is used to decode dataset images.
!pip install pillow

In [None]:
import os
import sys
import warnings
import json
from datetime import datetime
import random
import numpy as np

import torch
import torch.nn as nn
import yaml

warnings.filterwarnings('ignore')

# Detect Google Colab by probing for its runtime package. IN_COLAB decides
# whether the final cell triggers a browser download of the results archive.
try:
    import google.colab  # noqa: F401  (imported only to test availability)
except ImportError:
    IN_COLAB = False
    print("Not running in Colab")
else:
    IN_COLAB = True
    print("Running in Google Colab")

## 2. Dataset Configuration

In [None]:
# Dataset configuration for Kaggle dataset in Colab
# Location where kagglehub caches version 1 of this dataset; change it if the
# data was downloaded elsewhere.
DATASET_PATH = "/root/.cache/kagglehub/datasets/shivamardeshna/real-and-fake-images-dataset-for-image-forensics/versions/1"

# Verify dataset path exists (this only reports; execution continues either way)
if os.path.exists(DATASET_PATH):
    print(f"Dataset found at: {DATASET_PATH}")
    print(f"Dataset contents: {os.listdir(DATASET_PATH)}")
else:
    print(f"Dataset not found at: {DATASET_PATH}")
    print("Please ensure the dataset is downloaded using kagglehub")
    
# Configuration dictionary shared by every later cell (and stored in checkpoints
# and the final results JSON).
config = {
    # Data configuration
    'data_dir': DATASET_PATH,
    'batch_size': 32,
    'image_size': 224,  # square side length images are resized/cropped to
    'augmentation': 'advanced',  # 'advanced' = crop/flip/rotate/color-jitter; anything else = resize + flip
    
    # Model configuration
    'model_name': 'mobilevit_s',  # NOTE(review): informational only; create_model does not read it
    'num_classes': 2,
    'pretrained': True,  # NOTE(review): not read by create_model; weights start randomly initialised
    
    # Training configuration
    'epochs': 50,  # Reduced for Colab; also the cosine schedule's T_max
    'learning_rate': 1e-4,
    'weight_decay': 1e-2,  # AdamW weight decay
    'mixed_precision': True,  # autocast + GradScaler during training
    'gradient_clipping': 1.0,  # max gradient norm; <= 0 disables clipping
    'patience': 10,  # early stop after this many epochs without a val-F1 improvement
    
    # System configuration
    'num_workers': 2,  # Reduced for Colab
    'seed': 42,
    
    # Logging configuration
    'experiment_name': f"mobilevit_deepfake_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
    'log_dir': '/content/runs',  # NOTE(review): created, but nothing in this notebook writes to it
    'checkpoint_dir': '/content/checkpoints',
    'results_dir': '/content/results'
}

print(f"Experiment: {config['experiment_name']}")
print(f"Epochs: {config['epochs']}")
print(f"Batch size: {config['batch_size']}")

## 3. Utility Functions

In [None]:
def setup_directories(config: dict):
    """Make sure the output directories named in ``config`` exist.

    Only ``log_dir``, ``checkpoint_dir`` and ``results_dir`` are considered;
    keys absent from ``config`` are skipped and existing directories are
    left untouched.
    """
    for key in ('log_dir', 'checkpoint_dir', 'results_dir'):
        if key not in config:
            continue
        target = config[key]
        os.makedirs(target, exist_ok=True)
        print(f"Created directory: {target}")


def set_random_seeds(seed: int):
    """Seed Python, NumPy and PyTorch RNGs so runs are repeatable.

    On CUDA machines the GPU generators are seeded too, and cuDNN is switched
    to deterministic kernels with autotuning disabled.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Give up cuDNN autotuning speed in exchange for reproducible kernels.
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False

    print(f"Random seed set to: {seed}")


def setup_device_and_memory():
    """Choose the compute device and print a short hardware summary.

    Returns ``torch.device('cuda')`` when a GPU is visible (after printing its
    name, CUDA version and total memory of device 0, and emptying the CUDA
    cache), otherwise ``torch.device('cpu')``.
    """
    if not torch.cuda.is_available():
        print("Using CPU")
        return torch.device('cpu')

    gpu_name = torch.cuda.get_device_name()
    print(f"GPU: {gpu_name}")
    print(f"CUDA Version: {torch.version.cuda}")
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU Memory: {total_gb:.1f} GB")

    # Release any cached blocks left over from earlier cells.
    torch.cuda.empty_cache()
    return torch.device('cuda')


# Setup environment: create output directories, seed every RNG, and pick the
# compute device. `device` is a notebook-level global that later cells use
# (e.g. the model is moved onto it).
setup_directories(config)
set_random_seeds(config['seed'])
device = setup_device_and_memory()
print(f"Using device: {device}")

## 4. Model Definition

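Because the project's local modules are not available in Colab, we define a simplified MobileViT-style model inline. It is trained from scratch: no pretrained weights are loaded, regardless of the `pretrained` entry in `config`.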

In [None]:
# Simplified MobileViT implementation for Colab
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> SiLU.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        kernel_size, stride, padding: Forwarded to ``nn.Conv2d``. The default
            ``padding=1`` preserves spatial size only for 3x3 kernels; callers
            pass ``padding=0`` for 1x1 convolutions.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        # No conv bias: BatchNorm's learned shift makes it redundant.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.SiLU()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))


class MobileViTBlock(nn.Module):
    """Local convolution plus a patch-wise transformer, fused by a 1x1 conv.

    The local features are cut into non-overlapping ``patch_size x patch_size``
    patches. As in MobileViT, each of the ``patch_size**2`` pixel positions
    inside a patch forms its own sequence whose tokens are the
    ``out_channels``-dim features of that position across all patches. The
    token width therefore equals the transformer's ``d_model``.

    Spatial sizes that are not multiples of ``patch_size`` are bilinearly
    resized up to the next multiple for the transformer and resized back
    afterwards, so any input size is accepted.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output (and transformer ``d_model``);
            must be divisible by ``num_heads``.
        patch_size: Side length of a patch.
        num_heads: Attention heads in the transformer layer.
        mlp_ratio: Feed-forward width as a multiple of ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, patch_size=2, num_heads=4, mlp_ratio=2):
        super().__init__()
        self.patch_size = patch_size
        self.num_heads = num_heads

        # Local representation
        self.local_conv = nn.Sequential(
            ConvBlock(in_channels, in_channels),
            ConvBlock(in_channels, out_channels, kernel_size=1, padding=0)
        )

        # Global representation with self-attention over patches
        self.transformer = nn.TransformerEncoderLayer(
            d_model=out_channels,
            nhead=num_heads,
            dim_feedforward=out_channels * mlp_ratio,
            dropout=0.1,
            batch_first=True
        )

        # Fusion
        self.fusion = ConvBlock(out_channels, out_channels, kernel_size=1, padding=0)

    def _unfold(self, x):
        """(B, C, nh*p, nw*p) -> (B*p*p, nh*nw, C): one sequence per intra-patch position."""
        b, c, h, w = x.shape
        p = self.patch_size
        nh, nw = h // p, w // p
        x = x.reshape(b, c, nh, p, nw, p)
        # Reorder to (B, p, p, nh, nw, C) so the (B, p, p) dims become the batch.
        return x.permute(0, 3, 5, 2, 4, 1).reshape(b * p * p, nh * nw, c)

    def _fold(self, tokens, b, nh, nw):
        """Inverse of ``_unfold``: (B*p*p, nh*nw, C) -> (B, C, nh*p, nw*p)."""
        p = self.patch_size
        c = tokens.shape[-1]
        x = tokens.reshape(b, p, p, nh, nw, c)
        return x.permute(0, 5, 3, 1, 4, 2).reshape(b, c, nh * p, nw * p)

    def forward(self, x):
        # Local representation
        local_rep = self.local_conv(x)
        b, _, h, w = local_rep.shape
        p = self.patch_size

        # Round each spatial dim up to a multiple of the patch size so patches tile exactly.
        target_h = -(-h // p) * p
        target_w = -(-w // p) * p
        needs_resize = (target_h, target_w) != (h, w)
        feats = local_rep
        if needs_resize:
            feats = F.interpolate(local_rep, size=(target_h, target_w),
                                  mode='bilinear', align_corners=False)

        # Global representation: attend across patches, then fold back to a feature map.
        global_rep = self.transformer(self._unfold(feats))
        global_rep = self._fold(global_rep, b, target_h // p, target_w // p)

        if needs_resize:
            global_rep = F.interpolate(global_rep, size=(h, w),
                                       mode='bilinear', align_corners=False)

        # Fusion of local and global features
        return self.fusion(global_rep + local_rep)


class SimplifiedMobileViT(nn.Module):
    """Compact MobileViT-style classifier.

    A strided stem, three stages (each conv + MobileViTBlock, stages 2 and 3
    downsampling by 2), global average pooling and a dropout + linear head.
    Overall spatial downsampling is 8x.

    Args:
        num_classes: Number of output logits.
        image_size: Accepted for API compatibility; the network itself works
            with any input size because pooling is adaptive.
    """

    def __init__(self, num_classes=2, image_size=224):
        super().__init__()

        # Stem
        self.stem = nn.Sequential(
            ConvBlock(3, 32, stride=2),
            ConvBlock(32, 64)
        )

        # MobileViT stages
        self.stage1 = nn.Sequential(
            ConvBlock(64, 96),
            MobileViTBlock(96, 128)
        )

        self.stage2 = nn.Sequential(
            ConvBlock(128, 144, stride=2),
            MobileViTBlock(144, 192)
        )

        self.stage3 = nn.Sequential(
            ConvBlock(192, 240, stride=2),
            MobileViTBlock(240, 320)
        )

        # Classifier
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(320, num_classes)
        )

    def forward(self, x):
        x = self.stem(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)

        x = self.global_pool(x)
        x = x.flatten(1)
        x = self.classifier(x)

        return x


def create_model(config: dict) -> nn.Module:
    """Instantiate SimplifiedMobileViT from ``config`` and report its size.

    Reads ``num_classes`` (default 2) and ``image_size`` (default 224) from
    ``config`` and prints total and trainable parameter counts.
    """
    num_classes = config.get('num_classes', 2)
    print(f"Creating MobileViT model with {num_classes} classes...")
    model = SimplifiedMobileViT(
        num_classes=num_classes,
        image_size=config.get('image_size', 224),
    )

    params = list(model.parameters())
    total_params = sum(p.numel() for p in params)
    trainable_params = sum(p.numel() for p in params if p.requires_grad)

    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    return model


# Create the model and move its parameters onto the device chosen during setup.
model = create_model(config)
model = model.to(device)
print("Model created and moved to device")

## 5. Data Loading and Preprocessing

In [None]:
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import glob


class DeepfakeDataset(Dataset):
    """Binary image dataset: label 0 = real, label 1 = fake.

    Images are collected recursively from a "real" and a "fake" directory under
    ``data_dir``. If ``data_dir/real`` or ``data_dir/fake`` is missing, the
    immediate subdirectories are searched for names containing "real"/"authentic"
    or "fake"/"manipulated" instead.

    Files with a ``.jpg``, ``.jpeg`` or ``.png`` extension are collected,
    compared case-insensitively so ``.JPG`` etc. are included. Hidden files
    and directories are skipped. Paths are sorted so the sample order (and any
    split derived from it) is reproducible.

    Args:
        data_dir: Dataset root directory.
        transform: Optional callable applied to each PIL image.
        split: Name used only in log messages.

    Raises:
        ValueError: If no images are found.
    """

    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, data_dir, transform=None, split='train'):
        self.data_dir = data_dir
        self.transform = transform
        self.split = split

        self.images = []
        self.labels = []

        # Look for real and fake subdirectories
        real_dir = os.path.join(data_dir, 'real')
        fake_dir = os.path.join(data_dir, 'fake')

        # If direct real/fake dirs don't exist, look for other naming patterns
        if not os.path.exists(real_dir) or not os.path.exists(fake_dir):
            print(f"Looking for image files in: {data_dir}")
            all_dirs = sorted(
                d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))
            )
            print(f"Found directories: {all_dirs}")

            for dir_name in all_dirs:
                dir_path = os.path.join(data_dir, dir_name)
                lowered = dir_name.lower()
                if 'real' in lowered or 'authentic' in lowered:
                    real_dir = dir_path
                elif 'fake' in lowered or 'manipulated' in lowered:
                    fake_dir = dir_path

        for class_dir, label, class_name in ((real_dir, 0, 'real'), (fake_dir, 1, 'fake')):
            if not os.path.isdir(class_dir):
                # A missing class silently turns this into a one-class problem; say so.
                print(f"Warning: no {class_name} image directory found under {data_dir}")
                continue
            found = self._find_images(class_dir)
            self.images.extend(found)
            self.labels.extend([label] * len(found))
            print(f"Found {len(found)} {class_name} images")

        print(f"Total {split} images: {len(self.images)}")

        if len(self.images) == 0:
            raise ValueError(f"No images found in {data_dir}")

    @classmethod
    def _find_images(cls, root):
        """Return sorted paths of non-hidden image files anywhere under ``root``."""
        found = []
        for dirpath, dirnames, filenames in os.walk(root):
            # Prune hidden directories in place so os.walk does not descend into them.
            dirnames[:] = [d for d in dirnames if not d.startswith('.')]
            found.extend(
                os.path.join(dirpath, name)
                for name in filenames
                if not name.startswith('.') and name.lower().endswith(cls.IMAGE_EXTENSIONS)
            )
        return sorted(found)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]

        try:
            # Context manager closes the file handle once the RGB copy is made.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except Exception as e:
            # Best effort: keep training going on unreadable files. The blank
            # substitute keeps the original label.
            print(f"Error loading image {img_path}: {e}")
            image = Image.new('RGB', (224, 224), color='black')

        if self.transform:
            image = self.transform(image)

        return image, label


def get_transforms(image_size=224, augmentation='advanced'):
    """Build the (train, validation) torchvision transform pipelines.

    Both pipelines end with ToTensor and ImageNet mean/std normalisation.
    The validation pipeline only resizes to ``image_size`` x ``image_size``.
    For training, ``augmentation == 'advanced'`` resizes to 110% and then applies
    random crop, horizontal flip, rotation (±15°) and color jitter. Any other
    value uses resize plus horizontal flip only.
    """

    def tensor_tail():
        # Fresh instances per pipeline: ToTensor + ImageNet normalisation.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]

    if augmentation == 'advanced':
        enlarged = int(image_size * 1.1)
        train_steps = [
            transforms.Resize((enlarged, enlarged)),
            transforms.RandomCrop((image_size, image_size)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        ]
    else:
        train_steps = [
            transforms.Resize((image_size, image_size)),
            transforms.RandomHorizontalFlip(p=0.5),
        ]

    val_steps = [transforms.Resize((image_size, image_size))]

    return (
        transforms.Compose(train_steps + tensor_tail()),
        transforms.Compose(val_steps + tensor_tail()),
    )


def create_data_loaders(config):
    """Build training and validation DataLoaders from one image directory.

    The directory is scanned once by ``DeepfakeDataset`` and split at random
    into train/validation subsets (80/20 by default; override the train share
    with ``config['train_fraction']``). The permutation is drawn from a
    generator seeded with ``config['seed']`` (default 42), so the split is
    reproducible.

    Training samples use the augmenting transforms and validation samples the
    deterministic ones. Each subset wraps its own dataset object, because
    subsets that share one dataset also share its ``transform``. Setting a
    transform for validation would then silently disable training augmentation.

    Returns:
        ``(train_loader, val_loader)``.
    """
    import copy

    train_transforms, val_transforms = get_transforms(
        image_size=config['image_size'],
        augmentation=config['augmentation']
    )

    # For this demo one directory is split manually; in practice you'd have
    # separate train/val/test directories.
    train_source = DeepfakeDataset(
        data_dir=config['data_dir'],
        transform=train_transforms,
        split='full'
    )
    # Shallow copy: shares the read-only path/label lists but owns its transform.
    val_source = copy.copy(train_source)
    val_source.transform = val_transforms

    total_size = len(train_source)
    train_size = int(config.get('train_fraction', 0.8) * total_size)

    generator = torch.Generator().manual_seed(config.get('seed', 42))
    indices = torch.randperm(total_size, generator=generator).tolist()

    train_dataset = torch.utils.data.Subset(train_source, indices[:train_size])
    val_dataset = torch.utils.data.Subset(val_source, indices[train_size:])

    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        pin_memory=pin_memory
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        pin_memory=pin_memory
    )

    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")

    return train_loader, val_loader


# Create data loaders. On failure print a hint and re-raise: every later cell
# needs train_loader / val_loader, so carrying on would only surface a
# confusing NameError further down the notebook.
try:
    train_loader, val_loader = create_data_loaders(config)
except Exception as e:
    print(f"Error creating data loaders: {e}")
    print("Please check the dataset path and structure")
    raise
else:
    print("Data loaders created successfully")

## 6. Training Setup

In [None]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns


class DeepfakeTrainer:
    """Train and validate a real/fake image classifier.

    The trainer wraps:
    - an AdamW optimizer and a cosine-annealing LR schedule (``T_max=epochs``,
      ``eta_min=1e-6``);
    - cross-entropy loss;
    - optional CUDA mixed precision and gradient-norm clipping;
    - per-epoch checkpointing and early stopping on validation F1.

    Args:
        model: Network to train, already placed on its target device.
        train_loader: Iterable of ``(inputs, integer_labels)`` training batches.
        val_loader: Same for validation.
        config: Dict providing ``learning_rate``, ``weight_decay``, ``epochs``,
            ``mixed_precision``, ``gradient_clipping``, ``patience`` and
            ``checkpoint_dir``.
        device: Where batches are moved. Defaults to the device of the model's
            first parameter.
    """

    def __init__(self, model, train_loader, val_loader, config, device=None):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        # Derive the device from the model rather than a notebook global.
        self.device = torch.device(device) if device is not None else next(model.parameters()).device

        # Setup optimizer and scheduler
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config['weight_decay']
        )

        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=config['epochs'],
            eta_min=1e-6
        )

        # Loss function
        self.criterion = nn.CrossEntropyLoss()

        # CUDA autocast/GradScaler do nothing useful on CPU (they only warn),
        # so mixed precision is enabled only when training on a GPU.
        self.use_amp = bool(config['mixed_precision']) and self.device.type == 'cuda'
        self.scaler = torch.cuda.amp.GradScaler() if self.use_amp else None

        # Training state
        self.current_epoch = 0
        self.best_f1 = 0.0
        self.train_losses = []
        self.val_losses = []
        self.val_f1_scores = []
        self.learning_rates = []  # LR in effect during each epoch
        self.patience_counter = 0

    def _autocast(self):
        """Return a CUDA autocast context when AMP is on, else a no-op context."""
        import contextlib

        if self.use_amp:
            return torch.autocast(device_type='cuda')
        return contextlib.nullcontext()

    def train_epoch(self):
        """Run one optimisation pass over ``train_loader``.

        Returns:
            Sample-weighted mean training loss (0.0 if the loader is empty).
            The value is also appended to ``train_losses``.
        """
        self.model.train()
        total_loss = 0.0
        total_samples = 0
        num_batches = len(self.train_loader)
        clip = self.config['gradient_clipping']

        for batch_idx, (data, target) in enumerate(self.train_loader):
            data = data.to(self.device, non_blocking=True)
            target = target.to(self.device, non_blocking=True)

            self.optimizer.zero_grad()

            with self._autocast():
                output = self.model(data)
                loss = self.criterion(output, target)

            if self.scaler is not None:
                self.scaler.scale(loss).backward()
                if clip > 0:
                    # Unscale first so the clipping threshold is in true gradient units.
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss.backward()
                if clip > 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip)
                self.optimizer.step()

            # Weight by batch size so a short final batch doesn't skew the mean.
            batch_size = target.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size

            # Print progress
            if batch_idx % 50 == 0:
                print(f'Epoch {self.current_epoch + 1}, Batch {batch_idx}/{num_batches}, Loss: {loss.item():.4f}')

        avg_loss = total_loss / total_samples if total_samples else 0.0
        self.train_losses.append(avg_loss)
        return avg_loss

    def validate(self):
        """Evaluate on ``val_loader``.

        Returns:
            Dict of plain floats: ``loss`` (sample-weighted mean), ``accuracy``,
            and support-weighted ``precision``, ``recall`` and ``f1``.

        Raises:
            ValueError: If the validation loader yields no samples.
        """
        self.model.eval()
        total_loss = 0.0
        total_samples = 0
        all_predictions = []
        all_targets = []

        with torch.no_grad():
            for data, target in self.val_loader:
                data = data.to(self.device, non_blocking=True)
                target = target.to(self.device, non_blocking=True)

                with self._autocast():
                    output = self.model(data)
                    loss = self.criterion(output, target)

                total_loss += loss.item() * target.size(0)
                total_samples += target.size(0)

                all_predictions.extend(output.argmax(dim=1).cpu().tolist())
                all_targets.extend(target.cpu().tolist())

        if total_samples == 0:
            raise ValueError("Validation loader yielded no samples")

        avg_loss = total_loss / total_samples
        accuracy = accuracy_score(all_targets, all_predictions)
        precision, recall, f1, _ = precision_recall_fscore_support(
            all_targets, all_predictions, average='weighted', zero_division=0
        )

        # Plain Python floats keep checkpoints loadable by torch.load's
        # restricted (weights_only) unpickler, which rejects NumPy scalars.
        metrics = {
            'loss': float(avg_loss),
            'accuracy': float(accuracy),
            'precision': float(precision),
            'recall': float(recall),
            'f1': float(f1)
        }

        self.val_losses.append(metrics['loss'])
        self.val_f1_scores.append(metrics['f1'])
        return metrics

    def save_checkpoint(self, filepath, is_best=False):
        """Save model/optimizer/scheduler (and scaler) state to ``filepath``.

        When ``is_best`` is true, the same checkpoint is also written as
        ``best_model.pth`` next to ``filepath``. Missing parent directories are
        created.
        """
        checkpoint = {
            'epoch': self.current_epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_f1': float(self.best_f1),
            'config': self.config
        }
        if self.scaler is not None:
            checkpoint['scaler_state_dict'] = self.scaler.state_dict()

        directory = os.path.dirname(filepath) or '.'
        os.makedirs(directory, exist_ok=True)
        torch.save(checkpoint, filepath)

        if is_best:
            torch.save(checkpoint, os.path.join(directory, 'best_model.pth'))

    def train(self):
        """Run the full training loop with checkpointing and early stopping.

        Each epoch runs training, validation and one scheduler step, then saves
        ``checkpoint_epoch_{epoch}.pth`` (0-based epoch). ``best_model.pth`` is
        also written when validation F1 improves. Training stops early after
        ``patience`` consecutive epochs without improvement.

        Returns:
            Dict with per-epoch ``train_losses``, ``val_losses``,
            ``val_f1_scores`` and ``learning_rates`` (LR used in that epoch).
        """
        epochs = self.config['epochs']
        print(f"Starting training for {epochs} epochs...")

        for epoch in range(epochs):
            self.current_epoch = epoch
            epoch_lr = self.optimizer.param_groups[0]['lr']
            self.learning_rates.append(epoch_lr)

            train_loss = self.train_epoch()
            val_metrics = self.validate()

            # Advance the LR schedule once per epoch
            self.scheduler.step()

            # Print epoch results (1-based for humans)
            print(f"\nEpoch {epoch + 1}/{epochs}:")
            print(f"Train Loss: {train_loss:.4f}")
            print(f"Val Loss: {val_metrics['loss']:.4f}")
            print(f"Val Accuracy: {val_metrics['accuracy']:.4f}")
            print(f"Val F1: {val_metrics['f1']:.4f}")
            print(f"Learning Rate: {epoch_lr:.6f}")

            is_best = val_metrics['f1'] > self.best_f1
            if is_best:
                self.best_f1 = val_metrics['f1']
                self.patience_counter = 0
                print(f"New best F1 score: {self.best_f1:.4f}")
            else:
                self.patience_counter += 1

            checkpoint_path = os.path.join(self.config['checkpoint_dir'], f'checkpoint_epoch_{epoch}.pth')
            self.save_checkpoint(checkpoint_path, is_best)

            # Early stopping
            if self.patience_counter >= self.config['patience']:
                print(f"Early stopping triggered after {epoch + 1} epochs")
                break

            print("-" * 60)

        return {
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'val_f1_scores': self.val_f1_scores,
            'learning_rates': self.learning_rates
        }


# Create trainer from the model, the two data loaders and the shared config.
# The model was moved to its device in the model-creation cell.
trainer = DeepfakeTrainer(model, train_loader, val_loader, config)
print("Trainer initialized")

## 7. Start Training

In [None]:
# Start training: print the key settings, run the trainer's loop, and keep the
# returned per-epoch history for the plotting cell.
print("Starting training...")
print(f"Device: {device}")
print(f"Mixed precision: {config['mixed_precision']}")
print(f"Batch size: {config['batch_size']}")
print(f"Learning rate: {config['learning_rate']}")
print("=" * 60)

# Train the model (may stop early per config['patience'])
training_history = trainer.train()

print("\nTraining completed!")
print(f"Best validation F1 score: {trainer.best_f1:.4f}")

## 8. Training Visualization

In [None]:
# Plot training history: losses, validation F1, and the per-epoch learning rate.
plt.figure(figsize=(15, 5))

# Plot losses
plt.subplot(1, 3, 1)
plt.plot(training_history['train_losses'], label='Train Loss')
plt.plot(training_history['val_losses'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

# Plot F1 scores
plt.subplot(1, 3, 2)
plt.plot(training_history['val_f1_scores'], label='Validation F1', color='orange')
plt.title('Validation F1 Score')
plt.xlabel('Epoch')
plt.ylabel('F1 Score')
plt.legend()
plt.grid(True)

# Plot learning rate per epoch
plt.subplot(1, 3, 3)
epochs = range(len(training_history['train_losses']))
lrs = training_history.get('learning_rates')
if not lrs:
    # No per-epoch LRs in the history: rebuild the cosine-annealing curve from
    # the trainer's scheduler parameters. get_last_lr() alone is only the final
    # value and would draw a flat line.
    sched = trainer.scheduler
    base_lr = sched.base_lrs[0]
    lrs = [
        sched.eta_min + (base_lr - sched.eta_min) * (1 + np.cos(np.pi * e / sched.T_max)) / 2
        for e in epochs
    ]
plt.plot(epochs, lrs, label='Learning Rate', color='green')
plt.title('Learning Rate Schedule')
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.legend()
plt.grid(True)
plt.yscale('log')

plt.tight_layout()
plt.savefig(os.path.join(config['results_dir'], 'training_history.png'), dpi=300, bbox_inches='tight')
plt.show()

print(f"Training plots saved to {config['results_dir']}")

## 9. Model Evaluation

In [None]:
# Load the best checkpoint (highest validation F1) for final evaluation.
best_model_path = os.path.join(config['checkpoint_dir'], 'best_model.pth')

if os.path.exists(best_model_path):
    print(f"Loading best model from {best_model_path}")
    # weights_only=False: this checkpoint is written by this notebook and may hold
    # non-tensor objects (e.g. NumPy scalars), which the restricted unpickler
    # that recent PyTorch uses by default rejects. Never do this for untrusted files.
    checkpoint = torch.load(best_model_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Best model F1 score: {float(checkpoint['best_f1']):.4f}")
else:
    print("No best model checkpoint found, using current model")

# Evaluate on the validation set.
# NOTE: this split also drove checkpoint selection, so these metrics are
# optimistic; use a held-out test split for an unbiased estimate.
# Autocast is only enabled on CUDA; elsewhere it would just warn and do nothing.
use_amp = bool(config['mixed_precision']) and device.type == 'cuda'
model.eval()
all_predictions = []
all_targets = []
all_probs = []

with torch.no_grad():
    for data, target in val_loader:
        data = data.to(device)

        with torch.autocast(device_type=device.type, enabled=use_amp):
            output = model(data)

        # Compute probabilities in float32 even if the logits came out as fp16.
        output = output.float()
        probs = torch.softmax(output, dim=1)
        pred = output.argmax(dim=1)

        all_predictions.extend(pred.cpu().numpy())
        all_targets.extend(target.cpu().numpy())
        all_probs.extend(probs.cpu().numpy())

# Calculate final metrics (support-weighted across both classes)
final_accuracy = accuracy_score(all_targets, all_predictions)
final_precision, final_recall, final_f1, _ = precision_recall_fscore_support(
    all_targets, all_predictions, average='weighted', zero_division=0
)

print("\n" + "="*60)
print("FINAL EVALUATION RESULTS")
print("="*60)
print(f"Accuracy: {final_accuracy:.4f}")
print(f"Precision: {final_precision:.4f}")
print(f"Recall: {final_recall:.4f}")
print(f"F1 Score: {final_f1:.4f}")
print("="*60)

# Confusion Matrix; labels=[0, 1] keeps it 2x2 (matching the tick labels)
# even if one class never appears in targets or predictions.
cm = confusion_matrix(all_targets, all_predictions, labels=[0, 1])
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=['Real', 'Fake'], yticklabels=['Real', 'Fake'])
plt.title('Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.savefig(os.path.join(config['results_dir'], 'confusion_matrix.png'), dpi=300, bbox_inches='tight')
plt.show()

# Save final results
final_results = {
    'experiment_name': config['experiment_name'],
    'final_metrics': {
        'accuracy': float(final_accuracy),
        'precision': float(final_precision),
        'recall': float(final_recall),
        'f1': float(final_f1)
    },
    'training_config': config,
    'best_epoch_f1': float(trainer.best_f1)
}

results_path = os.path.join(config['results_dir'], 'final_results.json')
with open(results_path, 'w') as f:
    json.dump(final_results, f, indent=2)

print(f"\nResults saved to: {results_path}")
print(f"All outputs saved to: {config['results_dir']}")

## 10. Download Results (Optional)

In [None]:
# Compress results for download
import zipfile

def create_results_zip(zip_path='/content/deepfake_detection_results.zip',
                       results_dir=None, checkpoint_dir=None):
    """Bundle the results directory and the best checkpoint into a compressed zip.

    Archive member names are relative to each directory's parent, e.g.
    ``results/final_results.json`` and ``checkpoints/best_model.pth``.

    Args:
        zip_path: Destination archive path.
        results_dir: Directory archived recursively; defaults to
            ``config['results_dir']``.
        checkpoint_dir: Directory holding ``best_model.pth`` (included only if
            present); defaults to ``config['checkpoint_dir']``.

    Returns:
        ``zip_path``.
    """
    results_dir = os.path.abspath(results_dir if results_dir is not None else config['results_dir'])
    checkpoint_dir = os.path.abspath(
        checkpoint_dir if checkpoint_dir is not None else config['checkpoint_dir']
    )
    zip_abspath = os.path.abspath(zip_path)

    # ZIP_DEFLATED actually compresses; the zipfile default (ZIP_STORED) does not.
    with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_DEFLATED) as zipf:
        results_parent = os.path.dirname(results_dir)
        for root, _dirs, filenames in os.walk(results_dir):
            for name in sorted(filenames):
                file_path = os.path.join(root, name)
                if os.path.abspath(file_path) == zip_abspath:
                    continue  # never add the archive to itself
                zipf.write(file_path, os.path.relpath(file_path, results_parent))

        best_model_path = os.path.join(checkpoint_dir, 'best_model.pth')
        if os.path.exists(best_model_path):
            zipf.write(best_model_path, os.path.relpath(best_model_path, os.path.dirname(checkpoint_dir)))

    print(f"Results compressed to: {zip_path}")
    return zip_path

# Create zip file with the results directory and best checkpoint
zip_path = create_results_zip()

# Download in Colab; elsewhere just report where the archive was written
if IN_COLAB:
    from google.colab import files
    print("Downloading results...")
    files.download(zip_path)
else:
    print(f"Results available at: {zip_path}")

# Final summary (final_f1 comes from the evaluation cell)
print("\n" + "="*60)
print("TRAINING COMPLETED SUCCESSFULLY!")
print(f"Final F1 Score: {final_f1:.4f}")
print(f"Experiment: {config['experiment_name']}")
print("="*60)