In [1]:
"""
Advanced TCN Deepfake Detection - Optimized for 90%+ Accuracy
Enhanced temporal modeling with attention mechanisms and advanced training
"""

import subprocess
import sys
import os

def install(package):
    """Auto-install missing packages"""
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install", package, "-q"])
        print(f"‚úì Installed {package}")
    except:
        print(f"! Warning: Could not install {package}, continuing...")

# Auto-install required packages
required_packages = {
    'torch': 'torch torchvision',
    'cv2': 'opencv-python',
    'numpy': 'numpy',
    'sklearn': 'scikit-learn',
    'PIL': 'Pillow',
    'tqdm': 'tqdm',
    'matplotlib': 'matplotlib'
}

print("=== Checking and installing dependencies ===")
for module, package in required_packages.items():
    try:
        __import__(module)
    except ImportError:
        print(f"Installing {package}...")
        install(package)

# Now import everything
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import cv2
import numpy as np
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, roc_auc_score
import random
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Configuration
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\n=== Using device: {DEVICE} ===")

# Updated Paths
DATA_PATHS = {
    'real': r"C:\Users\dibya\OneDrive\Desktop\FF++\real",
    'fake': r"C:\Users\dibya\OneDrive\Desktop\FF++\fake"
}

# Optimized Hyperparameters for High Accuracy
IMG_SIZE = 224  # Increased for better feature extraction
SEQUENCE_LENGTH = 24  # Longer sequences for better temporal patterns
BATCH_SIZE = 4  # Smaller batch for more stable gradients
EPOCHS = 18  # Reduced epochs for faster testing
LEARNING_RATE = 0.00005  # Lower learning rate for fine-grained optimization
WEIGHT_DECAY = 1e-5  # L2 regularization
MAX_VIDEOS_PER_CLASS = 100  # More data for better generalization


class AdvancedVideoDataset(Dataset):
    """Enhanced dataset with advanced augmentation techniques"""
    def __init__(self, data_paths, sequence_length=24, img_size=224, max_videos=100, augment=False):
        self.sequence_length = sequence_length
        self.img_size = img_size
        self.augment = augment
        self.data = []
        
        print("\n=== Loading dataset ===")
        
        # Load real videos (label = 0)
        real_path = Path(data_paths['real'])
        if real_path.exists():
            real_videos = list(real_path.glob('**/*.mp4'))[:max_videos]
            self.data.extend([(str(v), 0) for v in real_videos])
            print(f"Loaded {len(real_videos)} real videos")
        else:
            print(f"! Warning: Real video path not found: {real_path}")
        
        # Load fake videos (label = 1)
        fake_path = Path(data_paths['fake'])
        if fake_path.exists():
            fake_videos = list(fake_path.glob('**/*.mp4'))[:max_videos]
            self.data.extend([(str(v), 1) for v in fake_videos])
            print(f"Loaded {len(fake_videos)} fake videos")
        else:
            print(f"! Warning: Fake video path not found: {fake_path}")
        
        print(f"Total dataset size: {len(self.data)} videos")
        
        # Shuffle data
        random.shuffle(self.data)
        
        if len(self.data) == 0:
            raise ValueError("No videos found! Check your paths.")
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        video_path, label = self.data[idx]
        
        try:
            frames = self.extract_frames(video_path)
            if self.augment:
                frames = self.augment_frames(frames)
            return frames, label
        except Exception as e:
            print(f"! Error loading {video_path}: {e}")
            frames = torch.zeros(self.sequence_length, 3, self.img_size, self.img_size, dtype=torch.float32)  # Ensure float32
            return frames, label
    
    def extract_frames(self, video_path):
        """Extract frames with advanced preprocessing"""
        cap = cv2.VideoCapture(video_path)
        frames = []
        
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        
        if total_frames == 0:
            cap.release()
            raise ValueError(f"No frames in video: {video_path}")
        
        # Sample frames evenly
        if total_frames < self.sequence_length:
            indices = list(range(total_frames))
            while len(indices) < self.sequence_length:
                indices.extend(list(range(total_frames)))
            indices = indices[:self.sequence_length]
        else:
            indices = np.linspace(0, total_frames-1, self.sequence_length, dtype=int)
        
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            
            if ret:
                # Resize
                frame = cv2.resize(frame, (self.img_size, self.img_size))
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                
                # Advanced normalization
                frame = frame.astype(np.float32) / 255.0
                
                # ImageNet normalization for better feature extraction
                mean = np.array([0.485, 0.456, 0.406])
                std = np.array([0.229, 0.224, 0.225])
                frame = (frame - mean) / std
                
                frame = torch.from_numpy(frame).permute(2, 0, 1).float()  # Ensure float32
                frames.append(frame)
            else:
                if frames:
                    frames.append(frames[-1].clone())
                else:
                    frames.append(torch.zeros(3, self.img_size, self.img_size, dtype=torch.float32))  # Ensure float32
        
        cap.release()
        return torch.stack(frames)
    
    def augment_frames(self, frames):
        """Advanced data augmentation"""
        # Random horizontal flip
        if random.random() > 0.5:
            frames = torch.flip(frames, dims=[3])
        
        # Random brightness adjustment
        if random.random() > 0.5:
            brightness_factor = random.uniform(0.7, 1.3)
            frames = frames * brightness_factor
        
        # Random contrast adjustment
        if random.random() > 0.5:
            contrast_factor = random.uniform(0.8, 1.2)
            mean = frames.mean(dim=[2, 3], keepdim=True)
            frames = (frames - mean) * contrast_factor + mean
        
        # Random noise injection
        if random.random() > 0.7:
            noise = torch.randn_like(frames) * 0.02
            frames = frames + noise
        
        # Random rotation (small angles)
        if random.random() > 0.7:
            angle = random.uniform(-5, 5)
            # Simple rotation by shifting
            shift = int(angle / 5)
            if shift != 0:
                frames = torch.roll(frames, shifts=shift, dims=3)
        
        return frames


class TemporalAttention(nn.Module):
    """Attention mechanism for temporal features"""
    def __init__(self, channels):
        super(TemporalAttention, self).__init__()
        self.query = nn.Conv1d(channels, channels // 8, 1)
        self.key = nn.Conv1d(channels, channels // 8, 1)
        self.value = nn.Conv1d(channels, channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))
        
    def forward(self, x):
        # x: [batch, channels, time]
        batch_size, channels, time = x.size()
        
        # Compute attention
        proj_query = self.query(x).permute(0, 2, 1)  # [B, T, C//8]
        proj_key = self.key(x)  # [B, C//8, T]
        
        energy = torch.bmm(proj_query, proj_key)  # [B, T, T]
        attention = torch.softmax(energy, dim=-1)
        
        proj_value = self.value(x)  # [B, C, T]
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        
        # Residual connection with learnable weight
        out = self.gamma * out + x
        return out


class AdvancedTemporalBlock(nn.Module):
    """Enhanced Temporal Convolutional Block with SE attention"""
    def __init__(self, in_channels, out_channels, kernel_size, dilation):
        super(AdvancedTemporalBlock, self).__init__()
        
        self.padding = (kernel_size - 1) * dilation
        
        # First conv layer
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size,
                              padding=self.padding, dilation=dilation)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.2)
        
        # Second conv layer
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size,
                              padding=self.padding, dilation=dilation)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.2)
        
        # Squeeze-and-Excitation block
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(out_channels, out_channels // 4, 1),
            nn.ReLU(),
            nn.Conv1d(out_channels // 4, out_channels, 1),
            nn.Sigmoid()
        )
        
        # Residual connection
        self.downsample = nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else None
        self.relu = nn.ReLU()
    
    def forward(self, x):
        # First conv
        out = self.conv1(x)
        if self.padding > 0:
            out = out[:, :, :-self.padding]
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.dropout1(out)
        
        # Second conv
        out = self.conv2(out)
        if self.padding > 0:
            out = out[:, :, :-self.padding]
        out = self.bn2(out)
        out = self.relu2(out)
        out = self.dropout2(out)
        
        # SE attention
        se_weight = self.se(out)
        out = out * se_weight
        
        # Residual connection
        res = x if self.downsample is None else self.downsample(x)
        
        # Match dimensions
        if out.size(2) != res.size(2):
            res = res[:, :, :out.size(2)]
        
        return self.relu(out + res)


class AdvancedTCN(nn.Module):
    """State-of-the-art TCN model with attention and advanced architectures"""
    def __init__(self):
        super(AdvancedTCN, self).__init__()
        
        # Enhanced spatial feature extractor (ResNet-like)
        self.spatial = nn.Sequential(
            # Initial conv
            nn.Conv2d(3, 64, 7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1),
            
            # Block 1
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            
            # Block 2
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            
            # Block 3
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            
            # Block 4
            nn.Conv2d(256, 512, 3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            
            nn.AdaptiveAvgPool2d(1)
        )
        
        # Advanced temporal processing
        self.temporal = nn.Sequential(
            AdvancedTemporalBlock(512, 512, kernel_size=3, dilation=1),
            AdvancedTemporalBlock(512, 512, kernel_size=3, dilation=2),
            AdvancedTemporalBlock(512, 512, kernel_size=3, dilation=4),
            AdvancedTemporalBlock(512, 512, kernel_size=3, dilation=8),
        )
        
        # Temporal attention
        self.attention = TemporalAttention(512)
        
        # Enhanced classifier
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 2)
        )
    
    def forward(self, x):
        # x shape: [batch, seq_len, channels, height, width]
        batch_size, seq_len = x.shape[:2]
        
        # Extract spatial features
        x = x.view(batch_size * seq_len, *x.shape[2:])
        x = self.spatial(x)
        x = x.view(batch_size, 512, seq_len)
        
        # Extract temporal features
        x = self.temporal(x)
        
        # Apply attention
        x = self.attention(x)
        
        # Multi-scale temporal pooling
        x_max = torch.max(x, dim=2)[0]
        x_avg = torch.mean(x, dim=2)
        x_std = torch.std(x, dim=2)
        
        # Combine different pooling strategies
        x = x_max + x_avg + 0.5 * x_std
        
        # Classification
        x = self.classifier(x)
        return x


class FocalLoss(nn.Module):
    """Focal Loss for handling class imbalance"""
    def __init__(self, alpha=0.25, gamma=2.0):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        
    def forward(self, inputs, targets):
        ce_loss = nn.CrossEntropyLoss(reduction='none')(inputs, targets)
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return focal_loss.mean()


def train_model(model, train_loader, val_loader, epochs=18):
    """Advanced training with mixed precision and gradient accumulation"""
    print(f"\n{'='*70}")
    print(f"TRAINING ADVANCED TCN MODEL")
    print(f"{'='*70}")
    
    # Use Focal Loss for better handling of hard examples
    criterion = FocalLoss(alpha=0.25, gamma=2.0)
    
    # AdamW optimizer with weight decay
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    
    # Cosine annealing with warm restarts
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=10, T_mult=2, eta_min=1e-7
    )
    
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    best_val_acc = 0.0
    patience_counter = 0
    patience = 10
    
    # Enable mixed precision training if available
    scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None
    
    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0
        train_preds, train_labels = [], []
        
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]")
        for batch_idx, (frames, labels) in enumerate(pbar):
            try:
                frames, labels = frames.to(DEVICE), labels.to(DEVICE)
                
                optimizer.zero_grad()
                
                # Mixed precision forward pass
                if scaler is not None:
                    with torch.cuda.amp.autocast():
                        outputs = model(frames)
                        loss = criterion(outputs, labels)
                    
                    scaler.scale(loss).backward()
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    outputs = model(frames)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    optimizer.step()
                
                train_loss += loss.item()
                train_preds.extend(outputs.argmax(dim=1).cpu().numpy())
                train_labels.extend(labels.cpu().numpy())
                
                pbar.set_postfix({'loss': f'{loss.item():.4f}'})
            except Exception as e:
                print(f"! Error in batch {batch_idx}: {e}")
                continue
        
        train_loss /= len(train_loader)
        train_acc = accuracy_score(train_labels, train_preds)
        
        # Validation
        model.eval()
        val_loss = 0
        val_preds, val_labels = [], []
        
        with torch.no_grad():
            for frames, labels in val_loader:
                try:
                    frames, labels = frames.to(DEVICE), labels.to(DEVICE)
                    
                    if scaler is not None:
                        with torch.cuda.amp.autocast():
                            outputs = model(frames)
                            loss = criterion(outputs, labels)
                    else:
                        outputs = model(frames)
                        loss = criterion(outputs, labels)
                    
                    val_loss += loss.item()
                    val_preds.extend(outputs.argmax(dim=1).cpu().numpy())
                    val_labels.extend(labels.cpu().numpy())
                except Exception as e:
                    print(f"! Validation error: {e}")
                    continue
        
        val_loss /= len(val_loader)
        val_acc = accuracy_score(val_labels, val_preds)
        
        # Learning rate scheduling
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        
        print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f} | LR: {current_lr:.7f}")
        
        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            torch.save(model.state_dict(), 'TCN_best.pth')
            print(f"‚úì Saved best model with Val Acc: {val_acc:.4f}")
        else:
            patience_counter += 1
            
        # Early stopping
        if patience_counter >= patience:
            print(f"\n‚úì Early stopping triggered after {epoch+1} epochs")
            break
    
    # Load best model
    model.load_state_dict(torch.load('TCN_best.pth'))
    return history


def evaluate_model(model, test_loader):
    """Comprehensive evaluation"""
    model.eval()
    all_preds, all_labels, all_probs = [], [], []
    
    with torch.no_grad():
        for frames, labels in tqdm(test_loader, desc="Evaluating TCN"):
            try:
                frames = frames.to(DEVICE)
                outputs = model(frames)
                probs = torch.softmax(outputs, dim=1)
                
                all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
                all_labels.extend(labels.numpy())
                all_probs.extend(probs[:, 1].cpu().numpy())
            except:
                continue
    
    acc = accuracy_score(all_labels, all_preds)
    prec, rec, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')
    cm = confusion_matrix(all_labels, all_preds)
    auc = roc_auc_score(all_labels, all_probs)
    
    print(f"\n{'='*70}")
    print(f"ADVANCED TCN RESULTS")
    print(f"{'='*70}")
    print(f"Accuracy:  {acc:.4f} ({acc*100:.2f}%)")
    print(f"Precision: {prec:.4f}")
    print(f"Recall:    {rec:.4f}")
    print(f"F1-Score:  {f1:.4f}")
    print(f"AUC-ROC:   {auc:.4f}")
    print(f"\nConfusion Matrix:")
    print(f"                Predicted")
    print(f"              Real  Fake")
    print(f"Actual Real   {cm[0][0]:4d}  {cm[0][1]:4d}")
    print(f"       Fake   {cm[1][0]:4d}  {cm[1][1]:4d}")
    
    return {'accuracy': acc, 'precision': prec, 'recall': rec, 'f1': f1, 'auc': auc}


def plot_results(history, results):
    """Visualization of training and results"""
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # Training Loss
    axes[0, 0].plot(history['train_loss'], label='Train', marker='o', linewidth=2, color='#2196F3')
    axes[0, 0].plot(history['val_loss'], label='Validation', marker='s', linewidth=2, color='#FF9800')
    axes[0, 0].set_title('Training and Validation Loss', fontsize=14, weight='bold')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)
    
    # Training Accuracy
    axes[0, 1].plot(history['train_acc'], label='Train', marker='o', linewidth=2, color='#4CAF50')
    axes[0, 1].plot(history['val_acc'], label='Validation', marker='s', linewidth=2, color='#F44336')
    axes[0, 1].set_title('Training and Validation Accuracy', fontsize=14, weight='bold')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)
    axes[0, 1].axhline(y=0.9, color='green', linestyle='--', alpha=0.5, label='90% Target')
    
    # Final Metrics
    metrics = ['accuracy', 'precision', 'recall', 'f1', 'auc']
    values = [results[m] for m in metrics]
    
    colors = ['#4CAF50' if v >= 0.9 else '#FF9800' if v >= 0.8 else '#F44336' for v in values]
    bars = axes[1, 0].bar(range(len(metrics)), values, color=colors, alpha=0.8)
    axes[1, 0].set_title('Final Performance Metrics', fontsize=14, weight='bold')
    axes[1, 0].set_ylabel('Score')
    axes[1, 0].set_xticks(range(len(metrics)))
    axes[1, 0].set_xticklabels([m.upper() for m in metrics])
    axes[1, 0].grid(True, axis='y', alpha=0.3)
    axes[1, 0].set_ylim([0, 1])
    axes[1, 0].axhline(y=0.9, color='green', linestyle='--', alpha=0.5)
    
    # Add value labels
    for bar in bars:
        height = bar.get_height()
        axes[1, 0].text(bar.get_x() + bar.get_width()/2., height,
                      f'{height:.3f}', ha='center', va='bottom', fontsize=10, weight='bold')
    
    # Performance Summary
    axes[1, 1].axis('off')
    summary_text = f"""
    ‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê
         ADVANCED TCN PERFORMANCE
    ‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê
    
    üìä Test Set Results:
    ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
    Accuracy:     {results['accuracy']*100:6.2f}%
    Precision:    {results['precision']*100:6.2f}%
    Recall:       {results['recall']*100:6.2f}%
    F1-Score:     {results['f1']*100:6.2f}%
    AUC-ROC:      {results['auc']*100:6.2f}%
    
    üéØ Target Achievement:
    ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
    90% Accuracy: {'‚úì ACHIEVED' if results['accuracy'] >= 0.9 else '‚úó Not Yet'}
    
    üìà Training Performance:
    ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
    Best Val Acc: {max(history['val_acc'])*100:6.2f}%
    Final Loss:   {history['val_loss'][-1]:6.4f}
    """
    
    axes[1, 1].text(0.5, 0.5, summary_text, ha='center', va='center',
                   fontsize=11, family='monospace',
                   bbox=dict(boxstyle='round', facecolor='#E8F5E9', alpha=0.8))
    
    plt.tight_layout()
    plt.savefig('advanced_tcn_results.png', dpi=300, bbox_inches='tight')
    print("\n‚úì Results saved to 'advanced_tcn_results.png'")
    plt.show()


def main():
    """Main execution"""
    try:
        print("\n" + "="*70)
        print("ADVANCED TCN DEEPFAKE DETECTION")
        print("Target: 90%+ Accuracy")
        print("="*70)
        
        # Load dataset with augmentation for training
        print("\n=== Creating training dataset (with augmentation) ===")
        train_dataset_full = AdvancedVideoDataset(
            DATA_PATHS, SEQUENCE_LENGTH, IMG_SIZE, MAX_VIDEOS_PER_CLASS, augment=True
        )
        
        print("\n=== Creating validation/test dataset (no augmentation) ===")
        eval_dataset_full = AdvancedVideoDataset(
            DATA_PATHS, SEQUENCE_LENGTH, IMG_SIZE, MAX_VIDEOS_PER_CLASS, augment=False
        )
        
        # Split dataset
        total_size = len(train_dataset_full)
        train_size = int(0.7 * total_size)
        val_size = int(0.15 * total_size)
        test_size = total_size - train_size - val_size
        
        train_indices = list(range(train_size))
        val_indices = list(range(train_size, train_size + val_size))
        test_indices = list(range(train_size + val_size, total_size))
        
        train_dataset = torch.utils.data.Subset(train_dataset_full, train_indices)
        val_dataset = torch.utils.data.Subset(eval_dataset_full, val_indices)
        test_dataset = torch.utils.data.Subset(eval_dataset_full, test_indices)
        
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, 
                                 num_workers=0, pin_memory=True)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, 
                               num_workers=0, pin_memory=True)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, 
                                num_workers=0, pin_memory=True)
        
        print(f"\nDataset split: Train={train_size}, Val={val_size}, Test={test_size}")
        
        # Train TCN
        print("\n" + "="*70)
        print("TRAINING ADVANCED TCN MODEL")
        print("="*70)
        tcn_model = AdvancedTCN().to(DEVICE)
        
        # Print model info
        total_params = sum(p.numel() for p in tcn_model.parameters())
        trainable_params = sum(p.numel() for p in tcn_model.parameters() if p.requires_grad)
        print(f"\nModel Parameters:")
        print(f"  Total: {total_params:,}")
        print(f"  Trainable: {trainable_params:,}")
        
        tcn_history = train_model(tcn_model, train_loader, val_loader, EPOCHS)
        tcn_results = evaluate_model(tcn_model, test_loader)
        
        # Plot results
        plot_results(tcn_history, tcn_results)
        
        print("\n" + "="*70)
        print("EXPERIMENT COMPLETED SUCCESSFULLY!")
        print("="*70)
        print(f"\nüìä FINAL TCN RESULTS:")
        print(f"  Accuracy:  {tcn_results['accuracy']:.4f} ({tcn_results['accuracy']*100:.2f}%)")
        print(f"  Precision: {tcn_results['precision']:.4f}")
        print(f"  Recall:    {tcn_results['recall']:.4f}")
        print(f"  F1-Score:  {tcn_results['f1']:.4f}")
        print(f"  AUC-ROC:   {tcn_results['auc']:.4f}")
        
        if tcn_results['accuracy'] >= 0.9:
            print(f"\nüéâ TARGET ACHIEVED! Accuracy >= 90%")
        else:
            print(f"\n‚ö†Ô∏è  Target not reached. Consider:")
            print(f"     - Increasing MAX_VIDEOS_PER_CLASS")
            print(f"     - Training for more epochs")
            print(f"     - Adjusting learning rate")
        
        print(f"\n‚úì Best model saved as 'TCN_best.pth'")
        
    except Exception as e:
        print(f"\n! CRITICAL ERROR: {e}")
        print("! Check your data paths and ensure videos exist.")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()

=== Checking and installing dependencies ===

=== Using device: cpu ===

ADVANCED TCN DEEPFAKE DETECTION
Target: 90%+ Accuracy

=== Creating training dataset (with augmentation) ===

=== Loading dataset ===
Loaded 100 real videos
Loaded 100 fake videos
Total dataset size: 200 videos

=== Creating validation/test dataset (no augmentation) ===

=== Loading dataset ===
Loaded 100 real videos
Loaded 100 fake videos
Total dataset size: 200 videos

Dataset split: Train=140, Val=30, Test=30

TRAINING ADVANCED TCN MODEL

Model Parameters:
  Total: 12,065,475
  Trainable: 12,065,475

TRAINING ADVANCED TCN MODEL


Epoch 1/18 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 35/35 [36:12<00:00, 62.06s/it, loss=0.0576]  


Epoch 1: Train Loss: 0.0535, Train Acc: 0.5357 | Val Loss: 0.0429, Val Acc: 0.6000 | LR: 0.0000488
‚úì Saved best model with Val Acc: 0.6000


Epoch 2/18 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 35/35 [30:21<00:00, 52.05s/it, loss=0.0715]


Epoch 2: Train Loss: 0.0510, Train Acc: 0.5143 | Val Loss: 0.0427, Val Acc: 0.6000 | LR: 0.0000452


Epoch 3/18 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 35/35 [30:27<00:00, 52.21s/it, loss=0.0437]


Epoch 3: Train Loss: 0.0555, Train Acc: 0.4143 | Val Loss: 0.0407, Val Acc: 0.6000 | LR: 0.0000397


Epoch 4/18 [Train]:  51%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè    | 18/35 [16:39<15:43, 55.51s/it, loss=0.0370]


KeyboardInterrupt: 