# In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.models as models
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
import gc
import cv2
import timm
from timm.models.vision_transformer import VisionTransformer

# Seed the NumPy and PyTorch global RNGs so repeated runs draw the same numbers
np.random.seed(42)
torch.manual_seed(42)

# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Memory management functions
def free_memory():
    """Reclaim memory between phases to reduce the chance of OOM errors.

    Runs Python's garbage collector first, then asks PyTorch to empty its
    CUDA cache.
    """
    # Order matters: collect garbage before emptying the CUDA cache.
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

def print_gpu_memory():
    """Print allocated, reserved and available memory for the current CUDA device.

    Does nothing when CUDA is unavailable. "Available" is computed as the
    device's total memory minus the memory PyTorch has reserved.
    """
    if not torch.cuda.is_available():
        return

    # Query every figure for the same device; the original mixed the current
    # device (allocated/reserved) with a hard-coded device 0 (total memory).
    device_index = torch.cuda.current_device()
    allocated_bytes = torch.cuda.memory_allocated(device_index)
    reserved_bytes = torch.cuda.memory_reserved(device_index)
    total_bytes = torch.cuda.get_device_properties(device_index).total_memory

    print(f"GPU memory allocated: {allocated_bytes / 1e9:.2f} GB")
    print(f"GPU memory reserved: {reserved_bytes / 1e9:.2f} GB")
    # Named available_bytes (not free_memory) to avoid shadowing free_memory()
    available_bytes = total_bytes - reserved_bytes
    print(f"GPU memory available: {available_bytes / 1e9:.2f} GB")

# Custom dataset class for videos
class DeepFakeVideoDataset(Dataset):
    """Dataset yielding a fixed number of sampled frames per video plus its label.

    Each item is a ``(frames, label)`` pair where ``frames`` is the stack of
    ``num_frames`` per-frame tensors and ``label`` is a ``torch.long`` scalar.
    Videos that cannot be opened, or whose first frame cannot be read, are
    dropped when the dataset is constructed.

    Args:
        video_paths: Sequence of video file paths.
        labels: Sequence of labels, indexed in parallel with ``video_paths``.
        num_frames: Number of frames sampled from every video.
        transform: Optional callable invoked as ``transform(image=frame)`` and
            expected to return a mapping with an ``"image"`` entry (the
            albumentations calling convention). When ``None``, frames are
            resized to ``FRAME_SIZE`` x ``FRAME_SIZE`` and converted to
            float CHW tensors scaled to [0, 1].
        max_frame_retries: Number of attempts made to decode a video before
            falling back to blank frames.
    """

    # Side length of blank fallback frames and of frames in the no-transform path
    FRAME_SIZE = 224

    def __init__(self, video_paths, labels, num_frames=16, transform=None, max_frame_retries=3):
        self.video_paths = video_paths
        self.labels = labels
        self.num_frames = num_frames
        self.transform = transform
        self.max_frame_retries = max_frame_retries

        # Pre-check videos and filter out problematic ones
        self._filter_valid_videos()

    def _filter_valid_videos(self):
        """Keep only videos that open and yield a first frame; labels stay aligned."""
        total_videos = len(self.video_paths)  # captured before the list is replaced
        valid_videos = []
        valid_labels = []

        for i, video_path in enumerate(tqdm(self.video_paths, desc="Validating videos")):
            cap = None
            try:
                cap = cv2.VideoCapture(video_path)
                if not cap.isOpened():
                    print(f"Warning: Could not open video {video_path}")
                    continue

                # Check if video has frames
                ret, _ = cap.read()
                if not ret:
                    print(f"Warning: No frames in video {video_path}")
                    continue

                # Fetch the label before appending so a lookup failure cannot
                # leave the two lists with different lengths.
                label = self.labels[i]
                valid_videos.append(video_path)
                valid_labels.append(label)
            except Exception as e:
                print(f"Error validating video {video_path}: {e}")
            finally:
                # Runs on every path, including the `continue` branches above
                if cap is not None:
                    cap.release()

        # Update videos and labels
        self.video_paths = valid_videos
        self.labels = valid_labels
        print(f"Kept {len(valid_videos)} valid videos out of {total_videos} total")

    def __len__(self):
        return len(self.video_paths)

    def _sample_indices(self, total_frames):
        """Return ``num_frames`` frame indices covering a video of ``total_frames`` frames."""
        if total_frames <= self.num_frames:
            # Too few frames: cycle through the available ones until we have enough
            repeats = self.num_frames // total_frames + 1
            return (list(range(total_frames)) * repeats)[:self.num_frames]
        # Otherwise sample evenly from the first to the last frame
        return np.linspace(0, total_frames - 1, self.num_frames, dtype=int)

    def _to_tensor(self, frame):
        """Convert one RGB HWC uint8 frame into the per-frame item that gets stacked."""
        if self.transform:
            return self.transform(image=frame)["image"]

        # No transform: give every frame the same size and tensor layout so
        # that torch.stack succeeds (raw numpy frames previously made it fail).
        size = self.FRAME_SIZE
        if frame.shape[:2] != (size, size):
            frame = cv2.resize(frame, (size, size))
        chw = np.ascontiguousarray(frame.transpose(2, 0, 1))
        return torch.from_numpy(chw).float() / 255.0

    def _blank_frame(self):
        """Return an all-black frame processed the same way as a decoded frame."""
        blank = np.zeros((self.FRAME_SIZE, self.FRAME_SIZE, 3), dtype=np.uint8)
        return self._to_tensor(blank)

    def extract_frames(self, video_path):
        """Decode ``num_frames`` frames from ``video_path`` and stack them.

        Individual frames that fail to read are replaced with blank frames.
        If every decoding attempt raises, the whole clip is filled with blank
        frames and a warning is printed.
        """
        frames = []

        for attempt in range(1, self.max_frame_retries + 1):
            frames = []  # every attempt starts from scratch
            cap = None
            try:
                cap = cv2.VideoCapture(video_path)
                if not cap.isOpened():
                    raise ValueError(f"Cannot open video file: {video_path}")

                # Get total number of frames
                total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                if total_frames <= 0:
                    raise ValueError(f"Video has no frames: {video_path}")

                for idx in self._sample_indices(total_frames):
                    cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
                    ret, frame = cap.read()
                    if ret:
                        # OpenCV decodes to BGR; convert to RGB
                        frames.append(self._to_tensor(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
                    else:
                        # If frame read fails, add a blank frame
                        frames.append(self._blank_frame())

                # One frame is appended per index, so the clip is complete
                break

            except Exception as e:
                print(f"Error extracting frames (attempt {attempt}): {e}")
                frames = []

            finally:
                # Make sure to release the capture object
                if cap is not None:
                    cap.release()

        # If we still don't have enough frames after all retries, create dummy frames
        if len(frames) < self.num_frames:
            missing_frames = self.num_frames - len(frames)
            print(f"Warning: Using {missing_frames} dummy frames for {video_path}")
            frames.extend(self._blank_frame() for _ in range(missing_frames))

        # Stack frames into a tensor
        return torch.stack(frames)

    def __getitem__(self, idx):
        video_path = self.video_paths[idx]
        label = self.labels[idx]

        try:
            frames = self.extract_frames(video_path)
            return frames, torch.tensor(label, dtype=torch.long)
        except Exception as e:
            print(f"Error loading video at index {idx}, path: {video_path}, Error: {e}")
            # Return a dummy tensor and the label
            dummy_frames = torch.zeros((self.num_frames, 3, self.FRAME_SIZE, self.FRAME_SIZE))
            return dummy_frames, torch.tensor(label, dtype=torch.long)

# Vision Transformer for Video - Fixed Implementation
class ViTVideoClassifier(nn.Module):
    """Per-frame ViT encoder followed by a temporal Transformer and an MLP head.

    Every frame is encoded independently by a timm Vision Transformer whose
    classification head is replaced with an identity. The per-frame features
    receive a learned positional embedding, pass through a 2-layer
    ``nn.TransformerEncoder``, are mean-pooled over frames and classified.

    Args:
        num_classes: Number of output logits.
        num_frames: Maximum clip length; size of the positional-embedding table.
        pretrained: Forwarded to ``timm.create_model``.
        memory_efficient: Use ``vit_small_patch16_224`` (embed dim 384) when
            true, otherwise ``vit_base_patch16_224`` (embed dim 768).
    """

    def __init__(self, num_classes=2, num_frames=16, pretrained=True, memory_efficient=True):
        super(ViTVideoClassifier, self).__init__()

        # Use a smaller ViT model for memory efficiency if needed
        if memory_efficient:
            # Use ViT-Small instead of ViT-Base or larger versions
            self.vit_encoder = timm.create_model('vit_small_patch16_224', pretrained=pretrained)
            embed_dim = 384  # ViT-Small embedding dimension
        else:
            # Use ViT-Base
            self.vit_encoder = timm.create_model('vit_base_patch16_224', pretrained=pretrained)
            embed_dim = 768  # ViT-Base embedding dimension

        # Remove the classification head
        self.vit_encoder.head = nn.Identity()

        # Temporal encoder: one layer definition, stacked twice below
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=8,
            dim_feedforward=1024,
            dropout=0.1,
            batch_first=True
        )
        self.temporal_encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=2
        )

        # NOTE: temporal_cnn is not used by forward(). It is kept so that
        # state_dict keys stay compatible with checkpoints saved by this class.
        self.temporal_cnn = nn.Sequential(
            nn.Conv3d(3, 16, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            nn.BatchNorm3d(16),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))
        )

        # Classifier (Decoder) applied to the pooled temporal features
        self.decoder = nn.Sequential(
            nn.Linear(embed_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

        # Learned positional encoding, one vector per frame position
        self.pos_embed = nn.Parameter(torch.zeros(1, num_frames, embed_dim))

        # Initialize positional embeddings
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        """Classify a batch of clips.

        Args:
            x: Tensor of shape ``[batch_size, num_frames, channels, height, width]``.
                Clips may be shorter than the configured ``num_frames``.

        Returns:
            Logits of shape ``[batch_size, num_classes]``.

        Raises:
            ValueError: If ``x`` is not 5-dimensional or holds more frames than
                the positional-embedding table supports.
        """
        if x.dim() != 5:
            raise ValueError(
                f"Expected input of shape [batch, frames, channels, height, width], got {tuple(x.shape)}"
            )
        num_frames = x.shape[1]
        max_frames = self.pos_embed.shape[1]
        if num_frames > max_frames:
            raise ValueError(
                f"Clip has {num_frames} frames but the model supports at most {max_frames}"
            )

        # Encode each frame with the ViT encoder -> list of [batch_size, embed_dim]
        frame_features = [self.vit_encoder(x[:, i]) for i in range(num_frames)]

        # Stack frame features along the temporal dimension
        x = torch.stack(frame_features, dim=1)  # [batch_size, num_frames, embed_dim]

        # Slice the table to the clip length. Adding the full table would
        # silently broadcast a 1-frame clip to max_frames positions.
        x = x + self.pos_embed[:, :num_frames]

        # Apply temporal encoder (Transformer)
        x = self.temporal_encoder(x)

        # Global temporal pooling (mean pooling across frames)
        x = torch.mean(x, dim=1)  # [batch_size, embed_dim]

        # Classification (Decoder)
        return self.decoder(x)

# Training function
def train_epoch(model, dataloader, criterion, optimizer, scheduler, device):
    """Train ``model`` for one pass over ``dataloader`` in full precision.

    Args:
        model: Module called as ``model(frames)`` to produce class logits.
        dataloader: Iterable of ``(frames, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        scheduler: Optional LR scheduler stepped once per batch; may be ``None``.
        device: Device the batches are moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: mean per-sample loss over the samples
        actually processed, and the accuracy over those samples.
    """
    model.train()
    running_loss = 0.0
    samples_seen = 0
    all_preds = []
    all_labels = []

    for frames, labels in tqdm(dataloader, desc="Training"):
        frames, labels = frames.to(device), labels.to(device)
        batch_size = frames.size(0)

        # Clear gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(frames)
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Update learning rate
        if scheduler is not None:
            scheduler.step()

        # Collect statistics (loss is weighted by the batch size)
        running_loss += loss.item() * batch_size
        samples_seen += batch_size
        _, preds = torch.max(outputs, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

        # Drop references so the tensors can be freed before the next batch
        del frames, labels, outputs, loss, preds

    # Divide by the samples actually seen. The loader may skip samples
    # (e.g. drop_last=True), so len(dataloader.dataset) would bias the mean low.
    epoch_loss = running_loss / max(samples_seen, 1)
    epoch_acc = accuracy_score(all_labels, all_preds)

    return epoch_loss, epoch_acc

# Mixed precision training function
def train_epoch_mixed_precision(model, dataloader, criterion, optimizer, scheduler, device, scaler):
    """Train ``model`` for one pass over ``dataloader`` using autocast + GradScaler.

    Args:
        model: Module called as ``model(frames)`` to produce class logits.
        dataloader: Iterable of ``(frames, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped through ``scaler.step``.
        scheduler: Optional LR scheduler stepped once per applied optimizer
            step; may be ``None``.
        device: Device the batches are moved to.
        scaler: Gradient scaler used to scale the loss and step the optimizer.

    Returns:
        ``(epoch_loss, epoch_acc)``: mean per-sample loss over the samples
        actually processed, and the accuracy over those samples.
    """
    # Resolve the autocast API once, outside the loop. The try block must not
    # contain the forward pass: an AttributeError raised by the model would
    # otherwise trigger a silent second forward pass in the fallback branch.
    use_cuda_autocast = torch.device(device).type == "cuda"
    try:
        # For newer PyTorch versions
        from torch.amp import autocast

        def autocast_context():
            return autocast(device_type="cuda", enabled=use_cuda_autocast)
    except (ImportError, AttributeError):
        # For older PyTorch versions
        from torch.cuda.amp import autocast

        def autocast_context():
            return autocast(enabled=use_cuda_autocast)

    model.train()
    running_loss = 0.0
    samples_seen = 0
    all_preds = []
    all_labels = []

    for frames, labels in tqdm(dataloader, desc="Training (mixed precision)"):
        frames, labels = frames.to(device), labels.to(device)
        batch_size = frames.size(0)

        # Clear gradients
        optimizer.zero_grad()

        # Forward pass with mixed precision
        with autocast_context():
            outputs = model(frames)
            loss = criterion(outputs, labels)

        # Backward pass and optimize with scaler
        scale_before = scaler.get_scale()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # The scaler lowers its scale only when it skipped the optimizer step
        # because of inf/NaN gradients; do not advance the schedule then.
        step_applied = scaler.get_scale() >= scale_before
        if scheduler is not None and step_applied:
            scheduler.step()

        # Collect statistics (loss is weighted by the batch size)
        running_loss += loss.item() * batch_size
        samples_seen += batch_size
        _, preds = torch.max(outputs, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

        # Drop references so the tensors can be freed before the next batch
        del frames, labels, outputs, loss, preds

    # Divide by the samples actually seen. The loader may skip samples
    # (e.g. drop_last=True), so len(dataloader.dataset) would bias the mean low.
    epoch_loss = running_loss / max(samples_seen, 1)
    epoch_acc = accuracy_score(all_labels, all_preds)

    return epoch_loss, epoch_acc

# Validation function
def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Returns a 4-tuple ``(loss, accuracy, report, conf_matrix)``. The loss is
    the batch-size-weighted loss sum divided by ``len(dataloader.dataset)``;
    the other three come from scikit-learn's ``accuracy_score``,
    ``classification_report`` and ``confusion_matrix``.
    """
    model.eval()
    total_loss = 0.0
    predictions = []
    targets = []

    with torch.no_grad():
        for batch_frames, batch_labels in tqdm(dataloader, desc="Validation"):
            batch_frames = batch_frames.to(device)
            batch_labels = batch_labels.to(device)

            # Forward pass and loss for this batch
            logits = model(batch_frames)
            batch_loss = criterion(logits, batch_labels)

            # Accumulate the loss weighted by batch size, plus predictions and labels
            total_loss += batch_loss.item() * batch_frames.size(0)
            batch_preds = torch.max(logits, 1)[1]
            predictions.extend(batch_preds.cpu().numpy())
            targets.extend(batch_labels.cpu().numpy())

            # Release this batch's tensors before loading the next one
            del batch_frames, batch_labels, logits, batch_loss, batch_preds

    mean_loss = total_loss / len(dataloader.dataset)
    accuracy = accuracy_score(targets, predictions)
    text_report = classification_report(targets, predictions)
    matrix = confusion_matrix(targets, predictions)
    return mean_loss, accuracy, text_report, matrix

# Main training loop with memory optimization
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device, checkpoint_path="model_checkpoints", scaler=None):
    """Train and validate ``model`` for ``num_epochs`` epochs, saving checkpoints and plots.

    Args:
        model: Model to train.
        train_loader: Loader passed to the per-epoch training function.
        val_loader: Loader passed to ``validate``.
        criterion: Loss function.
        optimizer: Optimizer; its state is stored in every checkpoint.
        scheduler: Optional LR scheduler forwarded to the training function.
        num_epochs: Number of epochs to run.
        device: Device used for training and validation.
        checkpoint_path: Directory for checkpoints (created if missing).
        scaler: When not ``None``, mixed-precision training is used.

    Returns:
        ``(train_losses, train_accs, val_losses, val_accs)``, one entry per epoch.

    Side effects:
        Writes ``best_model_epoch_{n}.pth`` whenever validation accuracy
        improves, ``checkpoint_epoch_{n}.pth`` every 5 epochs,
        ``training_history.png`` and, if at least one epoch ran,
        ``confusion_matrix.png`` for the last epoch.
    """
    # Create checkpoint directory if it doesn't exist
    os.makedirs(checkpoint_path, exist_ok=True)

    best_val_acc = 0.0
    train_losses, train_accs = [], []
    val_losses, val_accs = [], []
    conf_matrix = None  # stays None when num_epochs == 0

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 20)

        # Print GPU memory usage before training
        print("Before training:")
        print_gpu_memory()

        # Train for one epoch
        if scaler is not None:
            train_loss, train_acc = train_epoch_mixed_precision(model, train_loader, criterion, optimizer, scheduler, device, scaler)
        else:
            train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, scheduler, device)

        # Free memory
        free_memory()

        # Print GPU memory usage after training
        print("After training, before validation:")
        print_gpu_memory()

        # Validate
        val_loss, val_acc, report, conf_matrix = validate(model, val_loader, criterion, device)

        # Free memory
        free_memory()

        # Print statistics
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        # Save statistics
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        # Save model if it's the best so far
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            print(f"New best validation accuracy: {best_val_acc:.4f}")
            _save_training_checkpoint(
                os.path.join(checkpoint_path, f"best_model_epoch_{epoch+1}.pth"),
                epoch, model, optimizer, val_acc, val_loss,
            )

        # Save checkpoint every 5 epochs
        if (epoch + 1) % 5 == 0:
            _save_training_checkpoint(
                os.path.join(checkpoint_path, f"checkpoint_epoch_{epoch+1}.pth"),
                epoch, model, optimizer, val_acc, val_loss,
            )

        print(f"GPU memory after epoch {epoch+1}:")
        print_gpu_memory()

    _plot_training_history(train_losses, val_losses, train_accs, val_accs)

    # conf_matrix exists only if at least one epoch ran; the original raised
    # NameError here when num_epochs was 0.
    if conf_matrix is not None:
        _plot_confusion_matrix(conf_matrix)

    return train_losses, train_accs, val_losses, val_accs


def _save_training_checkpoint(path, epoch, model, optimizer, val_acc, val_loss):
    """Write model/optimizer state and the validation metrics of ``epoch`` to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # Plain floats keep the checkpoint free of NumPy scalar objects
        'val_acc': float(val_acc),
        'val_loss': float(val_loss),
    }, path)


def _plot_training_history(train_losses, val_losses, train_accs, val_accs):
    """Save side-by-side loss and accuracy curves to ``training_history.png``."""
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train')
    plt.plot(val_losses, label='Validation')
    plt.title('Loss over epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(train_accs, label='Train')
    plt.plot(val_accs, label='Validation')
    plt.title('Accuracy over epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()

    plt.tight_layout()
    plt.savefig('training_history.png')
    plt.close()


def _plot_confusion_matrix(conf_matrix):
    """Save an annotated heatmap of ``conf_matrix`` to ``confusion_matrix.png``."""
    plt.figure(figsize=(8, 6))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=['Real', 'Fake'], yticklabels=['Real', 'Fake'])
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.savefig('confusion_matrix.png')
    plt.close()

# Main execution code
def main():
    """Run the full pipeline: discover videos, build loaders, train, evaluate, save.

    Raises:
        ValueError: If no usable videos are found (or, with balancing enabled,
            if either class is empty).
    """
    # Hyperparameters optimized for Celeb-DF-v2 dataset with 15GB GPU
    BATCH_SIZE = 2  # Very small batch size to save memory
    NUM_FRAMES = 6  # Reduced for memory efficiency
    NUM_EPOCHS = 10
    LEARNING_RATE = 2e-5  # Slightly lower learning rate for stability
    NUM_WORKERS = 2  # Adjust based on your system
    CHECKPOINT_DIR = "model_checkpoints"

    # Enable mixed precision for better memory efficiency
    USE_MIXED_PRECISION = True

    # Dataset balancing and limiting (in case the dataset is too large)
    MAX_VIDEOS_PER_CLASS = 1000  # Limit videos per class if needed to fit in memory
    BALANCED_DATASET = True  # Ensure equal number of real/fake samples

    # Progress tracking on console
    print("Starting DeepFake detection pipeline...")

    # Data transformations
    print("Setting up data transformations...")
    train_transforms = A.Compose([
        A.Resize(224, 224),
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.2),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

    val_transforms = A.Compose([
        A.Resize(224, 224),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

    # Locate the real and fake videos under the dataset root
    data_path = "/kaggle/input/1000-df/dataset"
    print("Loading dataset from:", data_path)
    real_paths, fake_paths = _discover_videos(data_path)

    if not real_paths and not fake_paths:
        raise ValueError(f"No video files found under {data_path}")

    # Balance and limit dataset if needed
    if BALANCED_DATASET:
        # Ensure equal number of real and fake samples
        min_class_count = min(len(real_paths), len(fake_paths))

        # Further limit if MAX_VIDEOS_PER_CLASS is specified
        if MAX_VIDEOS_PER_CLASS > 0:
            min_class_count = min(min_class_count, MAX_VIDEOS_PER_CLASS)

        if min_class_count == 0:
            raise ValueError(
                f"Cannot build a balanced dataset: found {len(real_paths)} real "
                f"and {len(fake_paths)} fake videos under {data_path}"
            )

        # Randomly sample from both classes
        np.random.seed(42)  # For reproducibility
        sampled_real_paths = np.random.choice(real_paths, min_class_count, replace=False)
        sampled_fake_paths = np.random.choice(fake_paths, min_class_count, replace=False)

        # np.random.choice returns NumPy string scalars; convert back to str
        video_paths = [str(p) for p in sampled_real_paths] + [str(p) for p in sampled_fake_paths]
        labels = [0] * len(sampled_real_paths) + [1] * len(sampled_fake_paths)  # 0 for real, 1 for fake
    else:
        # Use all videos (might be imbalanced)
        video_paths = real_paths + fake_paths
        labels = [0] * len(real_paths) + [1] * len(fake_paths)  # 0 for real, 1 for fake

    print(f"Total videos in dataset: {len(video_paths)}")
    print(f"Real videos: {len(real_paths)} (using {labels.count(0)})")
    print(f"Fake videos: {len(fake_paths)} (using {labels.count(1)})")

    # Split the data
    train_paths, val_paths, train_labels, val_labels = train_test_split(
        video_paths, labels, test_size=0.2, random_state=42, stratify=labels
    )

    # Create datasets with progress information
    print(f"Creating training dataset with {len(train_paths)} videos...")
    train_dataset = DeepFakeVideoDataset(train_paths, train_labels, num_frames=NUM_FRAMES, transform=train_transforms)

    print(f"Creating validation dataset with {len(val_paths)} videos...")
    val_dataset = DeepFakeVideoDataset(val_paths, val_labels, num_frames=NUM_FRAMES, transform=val_transforms)

    print(f"Final dataset sizes - Training: {len(train_dataset)}, Validation: {len(val_dataset)}")

    # Create data loaders with appropriate memory settings
    print("Creating data loaders...")
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        persistent_workers=NUM_WORKERS > 0,
        drop_last=True  # Drop last incomplete batch to avoid size mismatch issues
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        persistent_workers=NUM_WORKERS > 0,
        drop_last=False  # Keep all validation samples
    )

    # Initialize model with proper seeding for reproducibility
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    print("Initializing Vision Transformer model...")
    model = ViTVideoClassifier(num_classes=2, num_frames=NUM_FRAMES, pretrained=True, memory_efficient=True)
    model = model.to(device)

    # Print parameter counts
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"\nTotal parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Non-trainable parameters: {total_params - trainable_params:,}")

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)

    # Learning rate scheduler (stepped once per batch by the training functions)
    total_steps = len(train_loader) * NUM_EPOCHS
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=LEARNING_RATE, total_steps=total_steps
    )

    # Initialize mixed precision scaler if enabled. Skipped on CPU, where the
    # CUDA scaler and autocast would only disable themselves with warnings.
    scaler = None
    if USE_MIXED_PRECISION and device.type == "cuda":
        print("Using mixed precision training for better memory efficiency")

        # Suppress FutureWarning about GradScaler deprecation
        import warnings
        warnings.filterwarnings("ignore", category=FutureWarning)

        try:
            # Try the new import format first (for newer PyTorch versions)
            from torch.amp import GradScaler
            scaler = GradScaler()
        except (ImportError, AttributeError):
            # Fall back to the older import format
            from torch.cuda.amp import GradScaler
            scaler = GradScaler()

    # Initial GPU memory check
    print("Initial GPU memory:")
    print_gpu_memory()

    # Train the model
    train_losses, train_accs, val_losses, val_accs = train_model(
        model, train_loader, val_loader, criterion, optimizer, scheduler, NUM_EPOCHS, device,
        checkpoint_path=CHECKPOINT_DIR, scaler=scaler
    )

    # Final evaluation on validation set using the best checkpoint of this run.
    # train_model writes "best_model_epoch_{n}.pth", so the file must be looked
    # up; the original hard-coded "best_model_epoch_x.pth", which never exists.
    best_checkpoint = _latest_best_checkpoint(CHECKPOINT_DIR)
    if best_checkpoint is not None:
        print(f"Loading best checkpoint: {best_checkpoint}")
        checkpoint = _load_checkpoint(best_checkpoint, device)
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        print("Warning: no best-model checkpoint found; evaluating the final weights")
    val_loss, val_acc, report, conf_matrix = validate(model, val_loader, criterion, device)

    print("\nFinal model performance:")
    print(f"Validation Accuracy: {val_acc:.4f}")
    print("\nClassification Report:")
    print(report)

    # Save final model
    torch.save({
        'model_state_dict': model.state_dict(),
        'val_acc': val_acc,
    }, 'final_deepfake_detector.pth')

    print("Training completed successfully!")


def _find_videos(directory, extensions=('.mp4', '.avi')):
    """Recursively collect paths under ``directory`` whose names end with ``extensions``."""
    video_paths = []
    if not os.path.exists(directory):
        print(f"Directory not found: {directory}")
        return video_paths

    print(f"Scanning directory: {directory}")
    for root, _, files in os.walk(directory):
        video_paths.extend(
            os.path.join(root, file) for file in files if file.lower().endswith(extensions)
        )
    return video_paths


def _discover_videos(data_path):
    """Return ``(real_paths, fake_paths)`` found under ``data_path``.

    Known directory layouts are tried first. If none yields both classes, the
    whole tree is scanned and each video is classified by keywords in its path.
    """
    possible_structures = [
        # Standard structure
        {"real_dirs": ["real"], "fake_dirs": ["fake"]},
        # Celeb-DF layout: Celeb-real / YouTube-real / Celeb-synthesis
        {"real_dirs": ["Celeb-real", "YouTube-real"], "fake_dirs": ["Celeb-synthesis"]},
        # Another possible structure
        {"real_dirs": ["original"], "fake_dirs": ["manipulated", "deepfake"]},
    ]

    # Try each structure until one yields videos for both classes
    for structure in possible_structures:
        real_paths = []
        fake_paths = []
        for real_dir in structure["real_dirs"]:
            real_paths.extend(_find_videos(os.path.join(data_path, real_dir)))
        for fake_dir in structure["fake_dirs"]:
            fake_paths.extend(_find_videos(os.path.join(data_path, fake_dir)))

        if real_paths and fake_paths:
            print(f"Found valid dataset structure with {len(real_paths)} real videos and {len(fake_paths)} fake videos")
            return real_paths, fake_paths

    print("Could not find videos in expected directories. Scanning entire dataset directory...")

    # Start from empty lists: the last structure attempt may have filled one
    # class, and the full scan below would add those files a second time.
    real_paths = []
    fake_paths = []
    real_keywords = ("real", "original", "true")
    fake_keywords = ("fake", "synthesis", "deepfake", "manipulated")

    for video_path in _find_videos(data_path):
        # Match on the path relative to the dataset root so that the root's
        # own name cannot influence the label.
        lower_path = os.path.relpath(video_path, data_path).lower()
        if any(keyword in lower_path for keyword in real_keywords):
            real_paths.append(video_path)
        elif any(keyword in lower_path for keyword in fake_keywords):
            fake_paths.append(video_path)
        else:
            # If can't determine, default to real
            real_paths.append(video_path)

    print(f"After scanning, found {len(real_paths)} likely real videos and {len(fake_paths)} likely fake videos")
    return real_paths, fake_paths


def _latest_best_checkpoint(checkpoint_dir):
    """Return the most recently written ``best_model_epoch_<n>.pth``, or ``None``."""
    prefix, suffix = "best_model_epoch_", ".pth"
    if not os.path.isdir(checkpoint_dir):
        return None

    candidates = []
    for name in os.listdir(checkpoint_dir):
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        epoch_text = name[len(prefix):-len(suffix)]
        if epoch_text.isdigit():
            path = os.path.join(checkpoint_dir, name)
            # Newest file wins (stale files from earlier runs are older);
            # the epoch number breaks ties.
            candidates.append((os.path.getmtime(path), int(epoch_text), path))

    return max(candidates)[2] if candidates else None


def _load_checkpoint(path, map_location):
    """Load a checkpoint dict written by this script onto ``map_location``."""
    # NOTE: weights_only=False means full pickle loading. That is acceptable
    # only because the file was produced by this same run; never use it on
    # untrusted files. It is needed because the checkpoint may hold NumPy scalars.
    try:
        return torch.load(path, map_location=map_location, weights_only=False)
    except TypeError:
        # Older PyTorch versions do not accept the weights_only argument
        return torch.load(path, map_location=map_location)

# Run the pipeline only when this file is executed as a script, not when imported
if __name__ == "__main__":
    main()