## Preprocessing

#### Preprocessing & Feature Extraction

In [1]:
import os
import cv2
import numpy as np
from tqdm import tqdm
from mtcnn.mtcnn import MTCNN
import matplotlib.pyplot as plt
from scipy.signal import convolve2d
from multiprocessing import Pool, cpu_count, current_process

# ---------------------------------------------------------------------------
# Dataset locations -- point BASE_DIR at your local Celeb-DF v2 copy.
# ---------------------------------------------------------------------------
BASE_DIR = "/mnt/data/CelebDF-V2"
REAL_VIDEOS_PATH, FAKE_VIDEOS_PATH = (
    os.path.join(BASE_DIR, sub_dir) for sub_dir in ("Celeb-real", "Celeb-synthesis")
)

# Root folder that run_preprocessing() creates and fills with per-video features.
PREPROCESSED_FEATURES_PATH = "/mnt/LIDeepDet Features"

# Sampling limits and the face-crop size handed to cv2.resize.
MAX_VIDEOS_PER_FOLDER = 890   # None means "use every video in the folder"
MAX_FRAMES_PER_VIDEO = 200    # frames sampled evenly (np.linspace) per video
IMG_SIZE = (224, 224)         # cv2.resize dsize, i.e. (width, height)

# Per-process MTCNN instance. It stays None until init_worker() fills it, so
# every worker builds its detector once instead of once per video.
global_face_detector = None

def init_worker():
    """Create this process's MTCNN detector unless it already exists.

    Used as the Pool initializer and as a lazy fallback from process_video().
    """
    global global_face_detector
    if global_face_detector is not None:
        return
    worker_name = current_process().name
    print(f"[{worker_name}] Initializing MTCNN detector...")
    global_face_detector = MTCNN()
    print(f"[{worker_name}] MTCNN detector initialized.")

def extract_illumination_map_paper(face_crop):
    """
    Build an edge-aware, smoothed illumination map for a BGR face crop.

    The goal is to keep the face's overall structure while smoothing texture
    detail. Steps:
      1. Initial estimate ``M_hat`` = per-pixel maximum over the colour
         channels, scaled to [0, 1] (the paper's Equation 2).
      2. Smooth ``M_hat`` with a guided filter (``cv2.ximgproc``, from
         opencv-contrib), using the grayscale crop as the guide.

    Args:
        face_crop: BGR uint8 face image, or None.

    Returns:
        A 3-channel uint8 image (the single-channel map replicated), min-max
        normalised to [0, 255]. Returns None if ``face_crop`` is None.
    """
    if face_crop is None:
        return None

    # 1. Initial illumination map M_hat on a [0, 1] scale.
    m_hat = np.max(face_crop, axis=-1).astype(np.float32) / 255.0

    # 2. The guide must use the same [0, 1] scale as m_hat. The guided filter's
    # eps regularises the guide's local variance. With a raw uint8 guide
    # (variance in 0..255**2 units), eps=0.01 is negligible, and the output
    # reproduces the guide's texture instead of smoothing it.
    guide_image = cv2.cvtColor(face_crop, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.0

    radius = 32     # window radius: larger -> smoother
    epsilon = 0.01  # regularisation: larger -> smoother, less edge-preserving
    guided_filter = cv2.ximgproc.createGuidedFilter(guide=guide_image, radius=radius, eps=epsilon)
    M = guided_filter.filter(src=m_hat)

    # Stretch to [0, 255] for saving and visualisation.
    smoothed_map = np.uint8(cv2.normalize(M, None, 0, 255, cv2.NORM_MINMAX))
    return cv2.cvtColor(smoothed_map, cv2.COLOR_GRAY2BGR)

def extract_face_material_map_paper(face_crop, mask_size=5):
    """
    Compute a face-material (texture) map from the PLGF magnitude.

    This follows the Pattern of Local Gravitational Force, Equations (5-7) of
    the paper. A neighbour at offset (x, y) pulls on the centre pixel in
    direction (cos theta, sin theta), with strength scaled by 1 / (x^2 + y^2).
    Convolving the grayscale image with the x and y force kernels gives the
    two force components. Their Euclidean norm is the texture response.

    Args:
        face_crop: BGR face image, or None.
        mask_size: the kernel spans offsets -mask_size//2 .. mask_size//2 on
            each axis, so an even value behaves like mask_size + 1.

    Returns:
        A 3-channel uint8 map, min-max normalised to [0, 255], or None if
        ``face_crop`` is None.
    """
    if face_crop is None:
        return None

    intensity = cv2.cvtColor(face_crop, cv2.COLOR_BGR2GRAY).astype(np.float64)

    half = mask_size // 2
    offs_y, offs_x = np.mgrid[-half:half + 1, -half:half + 1]
    angle = np.arctan2(offs_y, offs_x)
    sq_dist = offs_x ** 2 + offs_y ** 2 + 1e-12  # tiny offset avoids 0/0 at the centre
    force_x = np.cos(angle) / sq_dist
    force_y = np.sin(angle) / sq_dist
    # The centre pixel exerts no force on itself.
    force_x[half, half] = 0
    force_y[half, half] = 0

    pull_x = convolve2d(intensity, force_x, mode='same', boundary='symm')
    pull_y = convolve2d(intensity, force_y, mode='same', boundary='symm')
    strength = np.sqrt(pull_x ** 2 + pull_y ** 2)

    scaled = np.uint8(cv2.normalize(strength, None, 0, 255, cv2.NORM_MINMAX))
    return cv2.cvtColor(scaled, cv2.COLOR_GRAY2BGR)

def estimate_light_direction(face_crop):
    """
    Estimate a 2-D unit light-direction vector for a BGR face crop.

    Heuristic based on the Lambertian model (paper Section 3.4): the brightest
    grayscale pixel is assumed to face the light. The direction is the vector
    from the crop centre to that pixel, scaled to unit length.

    Returns:
        A float32 array ``[dx, dy]`` in image coordinates (y grows downward).
        If the brightest pixel is the centre, the vector is left at zero.
        If ``face_crop`` is None, returns ``np.array([0.0, 0.0])``.
    """
    if face_crop is None:
        return np.array([0.0, 0.0])

    gray = cv2.cvtColor(face_crop, cv2.COLOR_BGR2GRAY)
    height, width = gray.shape[:2]
    brightest_x, brightest_y = cv2.minMaxLoc(gray)[3]

    direction = np.array([brightest_x - width // 2, brightest_y - height // 2], dtype=np.float32)
    length = np.linalg.norm(direction)
    if length > 0:
        direction /= length
    return direction

def process_video(video_info):
    """
    Extract and save per-frame features for one video (multiprocessing worker).

    Up to MAX_FRAMES_PER_VIDEO frame indices are sampled evenly. For each frame
    where MTCNN finds a face, the highest-confidence face is cropped and resized
    to IMG_SIZE. The following files are then written to
    ``video_info['output_dir']``:
    ``NNNN_frame.png``, ``NNNN_illum.png``, ``NNNN_material.png`` and
    ``NNNN_lightvec.npy``. NNNN counts saved frames, not source-frame indices.

    Args:
        video_info: dict with keys 'path', 'output_dir', 'id' and 'category'.

    Returns:
        ``(video_id, viz_features, category)``. ``viz_features`` holds the
        first saved crop and its features when the video is 'real' and its
        output folder had no ``0000_frame.png`` beforehand; otherwise None.
        If the video cannot be opened or reports no frames, returns
        ``(video_id, None, None)``.
    """
    video_path = video_info['path']
    output_dir = video_info['output_dir']
    video_id = video_info['id']

    global global_face_detector
    if global_face_detector is None:
        init_worker()  # fallback; normally done by Pool(initializer=init_worker)

    os.makedirs(output_dir, exist_ok=True)

    cap = cv2.VideoCapture(video_path)
    # try/finally so the capture is released even if detection or feature
    # extraction raises partway through the video.
    try:
        if not cap.isOpened():
            print(f"[{current_process().name}] Warning: Could not open video {video_path}")
            return video_id, None, None

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames < 1:
            print(f"[{current_process().name}] Warning: Video {video_path} has no frames.")
            return video_id, None, None

        if total_frames > MAX_FRAMES_PER_VIDEO:
            indices = np.linspace(0, total_frames - 1, MAX_FRAMES_PER_VIDEO, dtype=int)
        else:
            indices = np.arange(total_frames)

        # Visualization data is only collected for real videos whose output
        # folder did not already contain a first frame.
        collect_viz = (video_info['category'] == 'real'
                       and not os.path.exists(os.path.join(output_dir, "0000_frame.png")))
        viz_features = None
        saved_frame_count = 0

        for frame_idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_idx))
            ret, frame = cap.read()
            if not ret:
                continue

            # MTCNN expects RGB; OpenCV decodes BGR.
            faces = global_face_detector.detect_faces(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            if not faces:
                continue

            best_face = max(faces, key=lambda f: f['confidence'])
            x, y, w, h = best_face['box']
            # MTCNN boxes can start outside the frame (negative x/y). Clamp both
            # corners so the crop keeps the detected region instead of being
            # shifted right/down by the clamped amount.
            x1, y1 = max(0, x), max(0, y)
            x2, y2 = min(frame.shape[1], x + w), min(frame.shape[0], y + h)
            face_crop = frame[y1:y2, x1:x2]
            if face_crop.size == 0:
                continue

            face_crop_resized = cv2.resize(face_crop, IMG_SIZE)
            illum_map = extract_illumination_map_paper(face_crop_resized)
            material_map = extract_face_material_map_paper(face_crop_resized)
            light_vector = estimate_light_direction(face_crop_resized)

            fid = f"{saved_frame_count:04d}"
            cv2.imwrite(os.path.join(output_dir, f"{fid}_frame.png"), face_crop_resized)
            if illum_map is not None:
                cv2.imwrite(os.path.join(output_dir, f"{fid}_illum.png"), illum_map)
            if material_map is not None:
                cv2.imwrite(os.path.join(output_dir, f"{fid}_material.png"), material_map)
            np.save(os.path.join(output_dir, f"{fid}_lightvec.npy"), light_vector)

            if collect_viz and saved_frame_count == 0:
                viz_features = {
                    'frames': [face_crop_resized],
                    'illum_maps': [illum_map],
                    'material_maps': [material_map],
                    'light_vectors': [light_vector],
                }
            saved_frame_count += 1

        return video_id, viz_features, video_info['category']
    finally:
        cap.release()

def visualize_preprocessing_steps(features):
    """
    Plot the first stored sample side by side.

    The three panels are the face crop, the illumination map with the
    estimated light direction drawn as a red arrow, and the material map.

    Args:
        features: dict whose 'frames', 'illum_maps', 'material_maps' and
            'light_vectors' entries are lists. Nothing is drawn if it is
            falsy or has no frames.
    """
    if not features or not features['frames']:
        return

    frame = features['frames'][0]
    illum_map = features['illum_maps'][0]
    material_map = features['material_maps'][0]
    light_vector = features['light_vectors'][0]

    fig, (ax_face, ax_illum, ax_material) = plt.subplots(1, 3, figsize=(15, 5))
    fig.suptitle('Preprocessing Visualization (Corrected Implementation)', fontsize=16)

    ax_face.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    ax_face.set_title('1. Original Face Crop')
    ax_face.axis('off')

    ax_illum.imshow(illum_map)
    ax_illum.set_title('2. Illumination Map')
    ax_illum.axis('off')
    # Arrow from the crop centre along the unit light vector, 50 px long.
    center_x, center_y = IMG_SIZE[0] // 2, IMG_SIZE[1] // 2
    arrow_len = 50
    ax_illum.arrow(center_x, center_y, light_vector[0] * arrow_len, light_vector[1] * arrow_len,
                   head_width=10, head_length=10, fc='r', ec='r')

    ax_material.imshow(material_map, cmap='gray')
    ax_material.set_title('3. Face Material Map')
    ax_material.axis('off')

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

def run_preprocessing():
    """
    Run the feature-extraction pipeline over the real and fake video folders.

    The run can be resumed:
      - Finished video ids ("<category>/<video name>") are appended to
        ``_progress.txt`` in PREPROCESSED_FEATURES_PATH and skipped later.
      - A video whose output folder already holds files, but which is missing
        from the log, is treated as done and added to the log.

    Remaining videos go to a worker Pool running process_video(); each worker
    builds its own MTCNN detector. The first real video that returns
    visualization data is plotted at the end.
    """
    print("Starting preprocessing...")

    os.makedirs(PREPROCESSED_FEATURES_PATH, exist_ok=True)
    progress_file = os.path.join(PREPROCESSED_FEATURES_PATH, '_progress.txt')
    completed_videos = set()
    try:
        with open(progress_file, 'r') as f:
            completed_videos = {line.strip() for line in f}
        print(f"Found {len(completed_videos)} videos already processed. Resuming.")
    except FileNotFoundError:
        print("Starting a new preprocessing run.")

    # MTCNN is deliberately not created here: each worker builds its own in init_worker().
    videos_to_process = []
    backfilled_ids = []
    for category, path in {'real': REAL_VIDEOS_PATH, 'fake': FAKE_VIDEOS_PATH}.items():
        if not os.path.exists(path):
            print(f"Warning: Directory not found, skipping: {path}")
            continue

        # os.listdir order is arbitrary. Sort so the MAX_VIDEOS_PER_FOLDER
        # subset is the same on every run and resumed runs stay consistent.
        video_files = sorted(f for f in os.listdir(path) if f.endswith(('.mp4', '.avi', '.mov')))
        if MAX_VIDEOS_PER_FOLDER is not None:
            video_files = video_files[:MAX_VIDEOS_PER_FOLDER]

        for video_name in video_files:
            video_id = os.path.splitext(video_name)[0]
            video_unique_id = f"{category}/{video_id}"
            if video_unique_id in completed_videos:
                continue

            video_output_dir = os.path.join(PREPROCESSED_FEATURES_PATH, category, video_id)
            if os.path.isdir(video_output_dir) and os.listdir(video_output_dir):
                # Output exists but the log doesn't know about it: record it as done.
                backfilled_ids.append(video_unique_id)
                completed_videos.add(video_unique_id)
            else:
                videos_to_process.append({
                    'id': video_unique_id,
                    'path': os.path.join(path, video_name),
                    'output_dir': video_output_dir,
                    'category': category,
                })

    if backfilled_ids:
        with open(progress_file, 'a') as progress_log:
            progress_log.writelines(f"{vid}\n" for vid in backfilled_ids)

    if not videos_to_process:
        print("All specified videos have already been processed. Nothing to do.")
        return

    print(f"Total new videos to process: {len(videos_to_process)}")

    # Every worker loads its own MTCNN model, so more workers than CPU cores
    # only adds memory pressure without adding throughput.
    num_workers = min(32, cpu_count())
    print(f"Using {num_workers} worker processes...")

    first_real_video_features = None

    # The Pool uses the platform's default start method. init_worker runs once
    # in each worker, so every process builds its MTCNN detector exactly once.
    with Pool(processes=num_workers, initializer=init_worker) as pool:
        with open(progress_file, 'a') as progress_log:
            # imap_unordered yields results as soon as any worker finishes,
            # which keeps the progress bar live.
            for video_id, features_for_viz, category in tqdm(
                pool.imap_unordered(process_video, videos_to_process),
                total=len(videos_to_process),
                desc="Processing Videos",
            ):
                if video_id is None:
                    continue
                # Flush after every video so an interrupted run can resume here.
                progress_log.write(f"{video_id}\n")
                progress_log.flush()

                if category == 'real' and features_for_viz is not None and first_real_video_features is None:
                    first_real_video_features = features_for_viz

    print("\nPreprocessing complete.")
    if first_real_video_features:
        print("Displaying visualization for the first real video processed in this session...")
        visualize_preprocessing_steps(first_real_video_features)
    else:
        print("No new real videos were processed in this session that required visualization data.")

# Script entry point. Under the 'spawn'/'forkserver' start methods, worker
# processes re-import this module; the __main__ guard stops them from
# re-running the whole pipeline.
if __name__ == "__main__":
    run_preprocessing()

## Training

#### Config

In [None]:
import os
import cv2
import timm
import glob
import json
import wandb
import torch
import numpy as np
import torch.nn as nn
import seaborn as sns
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score, confusion_matrix

class Config:
    """Training hyper-parameters and paths, read as class attributes."""

    # --- Paths ---
    PREPROCESSED_DATA_DIR = "/mnt/LIDeepDet Features"
    OUTPUT_DIR = "Outputs"

    # --- Experiment tracking (Weights & Biases) ---
    WANDB_PROJECT_NAME = "LIDeepDet"
    WANDB_RUN_NAME = "run-890-videos-weight-decay"

    # --- Hardware ---
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # --- Optimisation ---
    BATCH_SIZE = 4
    LEARNING_RATE = 1e-5
    EPOCHS = 50

    # --- Model / input ---
    IMG_SIZE = (224, 224)
    VIT_MODEL_NAME = 'vit_base_patch16_224'
    EMBED_DIM = 768  # must equal the backbone's feature width
    RESUME_CHECKPOINT = None

#### Dataloaders

In [None]:
class DeepfakeDataset(Dataset):
    """
    Video-level dataset over the preprocessed feature folders.

    Expected layout: ``<root_dir>/real/<video>/`` and
    ``<root_dir>/fake/<video>/`` directories containing ``NNNN_frame.png``,
    ``NNNN_illum.png`` and ``NNNN_material.png`` files.

    Each item is one video (label 0.0 = real, 1.0 = fake). Every access picks
    a random saved frame of that video and returns
    ``((rgb, illum, material), label)``. The images are converted to RGB and
    passed through ``transform`` when one is given.
    """

    def __init__(self, root_dir, transform=None):
        self.transform = transform
        # Sorted so folder order, and therefore any index-based split, is reproducible.
        real_videos = sorted(d for d in glob.glob(os.path.join(root_dir, 'real', '*')) if os.path.isdir(d))
        fake_videos = sorted(d for d in glob.glob(os.path.join(root_dir, 'fake', '*')) if os.path.isdir(d))
        all_folders = [(d, 0) for d in real_videos] + [(d, 1) for d in fake_videos]
        # Folders without any saved frame (no face detected) are dropped.
        self.video_folders = [(path, label) for path, label in all_folders
                              if glob.glob(os.path.join(path, '*_frame.png'))]
        # Count classes over the kept folders, so class weights derived from
        # these counts (e.g. a BCE pos_weight) match the samples actually served.
        self.num_real = sum(1 for _, label in self.video_folders if label == 0)
        self.num_fake = len(self.video_folders) - self.num_real
        if len(all_folders) != len(self.video_folders):
            print(f"Warning: Filtered out {len(all_folders) - len(self.video_folders)} empty video directories.")

    def __len__(self):
        return len(self.video_folders)

    def _load_sample(self, idx):
        """Load one random frame triple of video ``idx``; None if any file is missing or unreadable."""
        video_dir, label = self.video_folders[idx]
        frame_files = glob.glob(os.path.join(video_dir, '*_frame.png'))
        if not frame_files:
            return None
        rgb_path = np.random.choice(frame_files)
        frame_id = os.path.basename(rgb_path).split('_')[0]
        paths = [rgb_path,
                 os.path.join(video_dir, f"{frame_id}_illum.png"),
                 os.path.join(video_dir, f"{frame_id}_material.png")]
        if not all(os.path.exists(p) for p in paths):
            return None
        images = []
        for p in paths:
            # Check for None BEFORE cvtColor: cv2.imread returns None on
            # unreadable files, and cvtColor(None) raises instead of returning None.
            img = cv2.imread(p)
            if img is None:
                return None
            images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        if self.transform:
            images = [self.transform(img) for img in images]
        return tuple(images), torch.tensor(label, dtype=torch.float32)

    def __getitem__(self, idx):
        n = len(self)
        # Keep normal sequence semantics: out-of-range indices raise IndexError.
        if not -n <= idx < n:
            raise IndexError(f"index {idx} out of range for dataset of size {n}")
        # If a sample is unusable, fall back to the following videos
        # (wrapping around), but stop after one full pass instead of
        # recursing without bound.
        for offset in range(n):
            sample = self._load_sample((idx + offset) % n)
            if sample is not None:
                return sample
        raise RuntimeError("No loadable sample found in any video folder.")

#### Model Architecture

In [None]:
class CrossAttention(nn.Module):
    """Thin wrapper over ``nn.MultiheadAttention`` (batch_first) that returns only the attended output."""

    def __init__(self, embed_dim, num_heads=8):
        super().__init__()
        # The attribute name is part of the state_dict keys ("multihead_attn.*").
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, query, key, value):
        """Attend from ``query`` over ``key``/``value``; the attention weights are discarded."""
        return self.multihead_attn(query=query, key=key, value=value)[0]

class LIDeepDet(nn.Module):
    """
    Three-stream ViT deepfake detector (RGB, illumination map, material map).

    Each stream has its own timm ViT backbone. For every ordered pair of
    streams (A, B), A's class token attends over B's full token sequence
    through one shared CrossAttention module. The six attended vectors are
    concatenated and passed to an MLP head that outputs one logit per sample.
    """

    def __init__(self, vit_model_name, embed_dim, pretrained=True):
        super().__init__()
        self.backbone_rgb = timm.create_model(vit_model_name, pretrained=pretrained, num_classes=0)
        self.backbone_illum = timm.create_model(vit_model_name, pretrained=pretrained, num_classes=0)
        self.backbone_material = timm.create_model(vit_model_name, pretrained=pretrained, num_classes=0)
        self.cross_attention = CrossAttention(embed_dim)
        self.classifier = nn.Sequential(nn.LayerNorm(embed_dim * 6), nn.Linear(embed_dim * 6, embed_dim), nn.GELU(), nn.Linear(embed_dim, 1))

    def forward(self, rgb_img, illum_img, material_img):
        # forward_features returns the token sequence; index 0 is the class token.
        t_rgb = self.backbone_rgb.forward_features(rgb_img)
        t_illum = self.backbone_illum.forward_features(illum_img)
        t_material = self.backbone_material.forward_features(material_img)

        # Queries: class tokens kept as length-1 sequences.
        q_rgb, q_illum, q_material = t_rgb[:, :1], t_illum[:, :1], t_material[:, :1]

        # Keys/values: the FULL token sequence of the other stream. Attending
        # over a single class-token key would be degenerate: softmax over one
        # key is always 1, so the result would not depend on the query at all.
        a_ri = self.cross_attention(q_rgb, t_illum, t_illum)
        a_rm = self.cross_attention(q_rgb, t_material, t_material)
        a_ir = self.cross_attention(q_illum, t_rgb, t_rgb)
        a_im = self.cross_attention(q_illum, t_material, t_material)
        a_mr = self.cross_attention(q_material, t_rgb, t_rgb)
        a_mi = self.cross_attention(q_material, t_illum, t_illum)

        fused_features = torch.cat([a_ri, a_rm, a_ir, a_im, a_mr, a_mi], dim=-1).squeeze(1)
        return self.classifier(fused_features)

#### Training & Evaluation Loops

In [6]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import cv2
import numpy as np
import os
import glob
import json
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, confusion_matrix, precision_score, recall_score, f1_score
import timm
import wandb
import seaborn as sns
import matplotlib.pyplot as plt

# ==============================================================================
# 1. CONFIGURATION
# ==============================================================================
class Config:
    """Training hyper-parameters and paths for this cell, read as class attributes."""

    # --- Paths ---
    PREPROCESSED_DATA_DIR = "/mnt/LIDeepDet Features"
    OUTPUT_DIR = "Outputs"
    RESUME_CHECKPOINT = None

    # --- Model / input ---
    VIT_MODEL_NAME = 'vit_base_patch16_224'
    EMBED_DIM = 768  # must equal the backbone's feature width
    IMG_SIZE = (224, 224)

    # --- Optimisation & early stopping ---
    EPOCHS = 50
    PATIENCE = 10  # epochs without val-loss improvement before stopping
    BATCH_SIZE = 4
    LEARNING_RATE = 1e-5

    # --- Hardware ---
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # W&B settings are disabled while W&B logging is off; restore them
    # before re-enabling wandb.init in main():
    # WANDB_PROJECT_NAME = "LIDeepDet"
    # WANDB_RUN_NAME = "run-100-videos-weight-decay"

# ==============================================================================
# 2. DATASET & MODEL
# ==============================================================================
# DeepfakeDataset and LIDeepDet are redefined here so this cell runs standalone;
# these definitions shadow the ones from the earlier cells.
class DeepfakeDataset(Dataset):
    """
    Video-level dataset over the preprocessed feature folders.

    Expected layout: ``<root_dir>/real/<video>/`` and
    ``<root_dir>/fake/<video>/`` directories containing ``NNNN_frame.png``,
    ``NNNN_illum.png`` and ``NNNN_material.png`` files.

    Each item is one video (label 0.0 = real, 1.0 = fake). Every access picks
    a random saved frame of that video and returns
    ``((rgb, illum, material), label)``. The images are converted to RGB and
    passed through ``transform`` when one is given.
    """

    def __init__(self, root_dir, transform=None):
        self.transform = transform
        # Sorted so folder order, and therefore any index-based split, is
        # reproducible and identical between separately built instances.
        real_videos = sorted(d for d in glob.glob(os.path.join(root_dir, 'real', '*')) if os.path.isdir(d))
        fake_videos = sorted(d for d in glob.glob(os.path.join(root_dir, 'fake', '*')) if os.path.isdir(d))
        all_folders = [(d, 0) for d in real_videos] + [(d, 1) for d in fake_videos]
        # Folders without any saved frame (no face detected) are dropped.
        self.video_folders = [(p, l) for p, l in all_folders if glob.glob(os.path.join(p, '*_frame.png'))]
        # Count classes over the kept folders, so class weights derived from
        # these counts (e.g. a BCE pos_weight) match the samples actually served.
        self.num_real = sum(1 for _, l in self.video_folders if l == 0)
        self.num_fake = len(self.video_folders) - self.num_real
        if len(all_folders) != len(self.video_folders):
            print(f"Warning: Filtered {len(all_folders) - len(self.video_folders)} empty video directories.")

    def __len__(self):
        return len(self.video_folders)

    def _load_sample(self, idx):
        """Load one random frame triple of video ``idx``; None if any file is missing or unreadable."""
        video_dir, label = self.video_folders[idx]
        frame_files = glob.glob(os.path.join(video_dir, '*_frame.png'))
        if not frame_files:
            return None
        frame_path = np.random.choice(frame_files)
        frame_id = os.path.basename(frame_path).split('_')[0]
        paths = [frame_path,
                 os.path.join(video_dir, f"{frame_id}_illum.png"),
                 os.path.join(video_dir, f"{frame_id}_material.png")]
        if not all(os.path.exists(p) for p in paths):
            return None
        images = []
        for p in paths:
            # Check for None BEFORE cvtColor: cv2.imread returns None on
            # unreadable files, and cvtColor(None) raises instead of returning None.
            img = cv2.imread(p)
            if img is None:
                return None
            images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        if self.transform:
            images = [self.transform(img) for img in images]
        return tuple(images), torch.tensor(label, dtype=torch.float32)

    def __getitem__(self, idx):
        n = len(self)
        # Keep normal sequence semantics: out-of-range indices raise IndexError.
        if not -n <= idx < n:
            raise IndexError(f"index {idx} out of range for dataset of size {n}")
        # If a sample is unusable, fall back to the following videos
        # (wrapping around), but stop after one full pass instead of
        # recursing without bound.
        for offset in range(n):
            sample = self._load_sample((idx + offset) % n)
            if sample is not None:
                return sample
        raise RuntimeError("No loadable sample found in any video folder.")

class LIDeepDet(nn.Module):
    """
    Three-stream ViT deepfake detector (RGB, illumination map, material map).

    Each stream has its own timm ViT backbone. For every ordered pair of
    streams (A, B), A's class token attends over B's full token sequence
    through one shared ``nn.MultiheadAttention``. The six attended vectors are
    concatenated and passed to an MLP head that outputs one logit per sample.
    """

    def __init__(self, vit_model_name, embed_dim, pretrained=True):
        super().__init__()
        self.backbone_rgb = timm.create_model(vit_model_name, pretrained=pretrained, num_classes=0)
        self.backbone_illum = timm.create_model(vit_model_name, pretrained=pretrained, num_classes=0)
        self.backbone_material = timm.create_model(vit_model_name, pretrained=pretrained, num_classes=0)
        self.cross_attention = nn.MultiheadAttention(embed_dim, 8, batch_first=True)
        self.classifier = nn.Sequential(nn.LayerNorm(embed_dim * 6), nn.Linear(embed_dim * 6, embed_dim), nn.GELU(), nn.Linear(embed_dim, 1))

    def forward(self, rgb, illum, material):
        # forward_features returns the token sequence; index 0 is the class token.
        backbones = (self.backbone_rgb, self.backbone_illum, self.backbone_material)
        t_rgb, t_illum, t_mat = [b.forward_features(x) for b, x in zip(backbones, (rgb, illum, material))]
        # Queries: class tokens kept as length-1 sequences.
        q_rgb, q_illum, q_mat = t_rgb[:, :1], t_illum[:, :1], t_mat[:, :1]

        def attend(query, context):
            # Keys/values are the other stream's FULL token sequence. With only
            # its class token as the single key, softmax would always give 1
            # and the output would ignore the query entirely.
            return self.cross_attention(query, context, context, need_weights=False)[0]

        fused = torch.cat([
            attend(q_rgb, t_illum), attend(q_rgb, t_mat),
            attend(q_illum, t_rgb), attend(q_illum, t_mat),
            attend(q_mat, t_rgb), attend(q_mat, t_illum),
        ], dim=-1).squeeze(1)
        return self.classifier(fused)

# ==============================================================================
# 3. HELPER FUNCTIONS (ENHANCED METRICS & PLOTTING)
# ==============================================================================
def get_metrics(labels, preds):
    """
    Compute binary-classification metrics from labels and probabilities.

    Args:
        labels: array-like of 0/1 ground-truth labels, any shape (flattened).
        preds: array-like of predicted probabilities of the positive (fake)
            class, with the same number of elements as ``labels``.

    Returns:
        dict with 'accuracy', 'precision', 'recall' and 'f1_score' (all at a
        0.5 threshold) plus 'auc'. 'auc' is 0.5 when only one class is
        present, because ROC AUC is undefined in that case.
    """
    # Flatten both inputs. Callers pass (N, 1) column arrays, and mixing an
    # (N,) with an (N, 1) array would silently broadcast the accuracy
    # comparison to an (N, N) matrix.
    labels = np.asarray(labels).ravel()
    preds = np.asarray(preds).ravel()
    binary_preds = (preds > 0.5).astype(int)
    metrics = {
        "accuracy": np.mean(binary_preds == labels),
        "precision": precision_score(labels, binary_preds, zero_division=0),
        "recall": recall_score(labels, binary_preds, zero_division=0),
        "f1_score": f1_score(labels, binary_preds, zero_division=0),
    }
    if len(np.unique(labels)) > 1:
        metrics["auc"] = roc_auc_score(labels, preds)
    else:
        metrics["auc"] = 0.5  # Default AUC if only one class is present
    return metrics

def train_one_epoch(model, loader, optimizer, loss_fn, device):
    """
    Run one training pass of ``model`` over ``loader``.

    Each batch is ``(images, labels)``: ``images`` is a sequence of tensors
    passed positionally to the model, and labels get a trailing dimension
    added to match the single-logit output.

    Returns:
        get_metrics() on the sigmoid outputs, plus 'loss' (mean of the
        per-batch losses).
    """
    model.train()
    running_loss = 0.0
    probs, targets = [], []

    for batch_images, batch_labels in tqdm(loader, desc="Training", leave=False):
        inputs = [x.to(device) for x in batch_images]
        target = batch_labels.to(device).unsqueeze(1)

        optimizer.zero_grad()
        logits = model(*inputs)
        batch_loss = loss_fn(logits, target)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        probs.extend(torch.sigmoid(logits).detach().cpu().numpy())
        targets.extend(target.cpu().numpy())

    results = get_metrics(np.array(targets), np.array(probs))
    results['loss'] = running_loss / len(loader)
    return results

@torch.no_grad()
def evaluate(model, loader, loss_fn, device):
    """
    Score ``model`` on ``loader`` without gradient tracking.

    Returns:
        get_metrics() on the sigmoid outputs, plus 'loss' (mean of the
        per-batch losses), 'predictions' (array of probabilities) and
        'labels' (array of targets, shaped like the predictions).
    """
    model.eval()
    running_loss = 0.0
    probs, targets = [], []

    for batch_images, batch_labels in tqdm(loader, desc="Evaluating", leave=False):
        inputs = [x.to(device) for x in batch_images]
        target = batch_labels.to(device).unsqueeze(1)
        logits = model(*inputs)
        running_loss += loss_fn(logits, target).item()
        probs.extend(torch.sigmoid(logits).cpu().numpy())
        targets.extend(target.cpu().numpy())

    prob_arr, target_arr = np.array(probs), np.array(targets)
    results = get_metrics(target_arr, prob_arr)
    results['loss'] = running_loss / len(loader)
    results['predictions'] = prob_arr
    results['labels'] = target_arr
    return results

def log_diagnostic_plots(val_metrics, epoch):
    """
    Log a validation confusion matrix and a prediction histogram to W&B.

    Does nothing when no W&B run is active (wandb.init() not called), because
    wandb.log raises in that case and would abort training.

    Args:
        val_metrics: dict containing 'labels' and 'predictions' arrays, as
            returned by evaluate().
        epoch: epoch number shown in the plot titles.
    """
    if wandb.run is None:
        return

    labels = np.asarray(val_metrics['labels']).ravel().astype(int)
    preds = np.asarray(val_metrics['predictions']).ravel()

    # 1. Confusion matrix. labels=[0, 1] keeps it 2x2 even if a class is
    # missing from this validation split, so it always matches the two tick labels.
    cm = confusion_matrix(labels, (preds > 0.5).astype(int), labels=[0, 1])
    fig = plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=['Real', 'Fake'], yticklabels=['Real', 'Fake'])
    plt.xlabel('Predicted'); plt.ylabel('Actual'); plt.title(f'Validation Confusion Matrix - Epoch {epoch}')
    wandb.log({"val_confusion_matrix": wandb.Image(fig)})
    plt.close(fig)

    # 2. Prediction distribution per true class (empty classes are skipped).
    fig = plt.figure(figsize=(10, 6))
    for cls, colour, name in ((0, 'blue', 'Real Predictions'), (1, 'red', 'Fake Predictions')):
        class_preds = preds[labels == cls]
        if class_preds.size:
            sns.histplot(x=class_preds, color=colour, alpha=0.5, label=name, bins=30)
    plt.title(f'Validation Prediction Distribution - Epoch {epoch}'); plt.xlabel('Predicted Probability (of being Fake)'); plt.legend()
    wandb.log({"val_prediction_distribution": wandb.Image(fig)})
    plt.close(fig)

# ==============================================================================
# 4. MAIN FUNCTION
# ==============================================================================
def main():
    """
    Train LIDeepDet on the preprocessed features and evaluate the best checkpoint.

    Procedure:
      - Video folders are shuffled with a fixed seed and split 80/10/10 into
        train/validation/test.
      - Training uses augmentations; validation and test only resize and
        normalise.
      - Each epoch's validation loss drives early stopping (Config.PATIENCE)
        and checkpointing to ``<OUTPUT_DIR>/best_checkpoint.pth``.
      - The best checkpoint is finally scored on the test split.
    """
    config = Config()
    device = torch.device(config.DEVICE)
    # best_checkpoint.pth is written into OUTPUT_DIR. Create it here so main()
    # also works when called without the __main__ guard (e.g. from a notebook cell).
    os.makedirs(config.OUTPUT_DIR, exist_ok=True)
    # wandb.init(project=config.WANDB_PROJECT_NAME, name=config.WANDB_RUN_NAME, config=vars(config))

    # --- EXTENSIVE AUGMENTATION PIPELINE ---
    train_transform = transforms.Compose([
        transforms.ToPILImage(),  # Convert cv2 image to PIL Image for augmentations
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.2),
        transforms.RandomResizedCrop(size=config.IMG_SIZE, scale=(0.8, 1.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Validation and test sets should NOT have augmentations
    eval_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(config.IMG_SIZE),  # Just resize, no cropping
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    print("Loading and splitting dataset with extensive augmentations...")
    dataset_for_train = DeepfakeDataset(root_dir=config.PREPROCESSED_DATA_DIR, transform=train_transform)
    dataset_for_eval = DeepfakeDataset(root_dir=config.PREPROCESSED_DATA_DIR, transform=eval_transform)

    # The Subsets below share one permutation across both dataset objects.
    # This assumes both instances list the video folders in the same order.
    n_videos = len(dataset_for_train)
    train_size, val_size = int(0.8 * n_videos), int(0.1 * n_videos)
    # Fixed seed -> identical split on every run, so results are comparable.
    split_rng = torch.Generator().manual_seed(getattr(config, 'SEED', 42))
    indices = torch.randperm(n_videos, generator=split_rng).tolist()

    train_dataset = torch.utils.data.Subset(dataset_for_train, indices[:train_size])
    val_dataset = torch.utils.data.Subset(dataset_for_eval, indices[train_size:train_size + val_size])
    test_dataset = torch.utils.data.Subset(dataset_for_eval, indices[train_size + val_size:])

    train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=config.BATCH_SIZE, shuffle=False, num_workers=0)
    print(f"Data split -> Train: {len(train_dataset)}, Validation: {len(val_dataset)}, Test: {len(test_dataset)}")

    model = LIDeepDet(config.VIT_MODEL_NAME, config.EMBED_DIM).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.LEARNING_RATE, weight_decay=1e-5)
    # pos_weight = #real / #fake rebalances the positive (fake) term of BCE.
    pos_weight = torch.tensor(dataset_for_train.num_real / dataset_for_train.num_fake if dataset_for_train.num_fake > 0 else 1, dtype=torch.float32).to(device)
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    best_model_path = os.path.join(config.OUTPUT_DIR, 'best_checkpoint.pth')
    best_val_loss = float('inf')
    epochs_no_improve = 0  # epochs since the last validation-loss improvement

    for epoch in range(config.EPOCHS):
        train_metrics = train_one_epoch(model, train_loader, optimizer, loss_fn, device)
        val_metrics = evaluate(model, val_loader, loss_fn, device)

        print(f"Epoch {epoch+1}/{config.EPOCHS} -> Train Loss: {train_metrics['loss']:.4f}, Train Acc: {train_metrics['accuracy']:.4f} | Val Loss: {val_metrics['loss']:.4f}, Val Acc: {val_metrics['accuracy']:.4f}")

        # wandb_logs = {"epoch": epoch + 1}
        # for key, val in train_metrics.items(): wandb_logs[f"train_{key}"] = val
        # for key, val in val_metrics.items():
        #     if key not in ['predictions', 'labels']: wandb_logs[f"val_{key}"] = val
        # wandb.log(wandb_logs)

        # wandb.log raises when no run is active, and wandb.init is disabled above.
        if wandb.run is not None:
            log_diagnostic_plots(val_metrics, epoch + 1)

        # --- Early Stopping Logic ---
        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            epochs_no_improve = 0
            print(f"  -> New best model found with Val Loss: {best_val_loss:.4f}. Saving lean checkpoint...")
            torch.save(model.state_dict(), best_model_path)
            # artifact = wandb.Artifact(f'best-model-run-{wandb.run.id}', type='model')
            # artifact.add_file(best_model_path)
            # wandb.log_artifact(artifact)
        else:
            epochs_no_improve += 1
            print(f"  -> Validation loss did not improve. Patience: {epochs_no_improve}/{config.PATIENCE}")
            if epochs_no_improve >= config.PATIENCE:
                print(f"Early stopping triggered after {config.PATIENCE} epochs without improvement.")
                break

    # --- Final Test Evaluation ---
    print("\n--- Training Finished. Starting Final Evaluation on Test Set ---")
    if os.path.exists(best_model_path):
        # Reuse the trained model rather than building a second three-backbone
        # model (and re-fetching pretrained weights) while this one is still
        # resident. The optimizer state is no longer needed, so release it first.
        del optimizer
        model.load_state_dict(torch.load(best_model_path, map_location=device))
        test_metrics = evaluate(model, test_loader, loss_fn, device)
        print(f"Final Test Set Results -> Accuracy: {test_metrics['accuracy']:.4f}, AUC: {test_metrics['auc']:.4f}, F1: {test_metrics['f1_score']:.4f}")
        # wandb.log({f"test_{k}": v for k, v in test_metrics.items() if k not in ['predictions', 'labels']})
    else:
        print("No best model checkpoint found (training might have stopped too early without any improvement).")

    # wandb.finish()

if __name__ == "__main__":
    # OUTPUT_DIR must exist before main() writes best_checkpoint.pth into it.
    os.makedirs(Config.OUTPUT_DIR, exist_ok=True)
    main()

Loading and splitting dataset with extensive augmentations...
Data split -> Train: 1424, Validation: 178, Test: 178


                                                 

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 MiB. GPU 0 has a total capacity of 19.56 GiB of which 25.19 MiB is free. Process 7259 has 6.07 GiB memory in use. Including non-PyTorch memory, this process has 12.85 GiB memory in use. Of the allocated memory 12.00 GiB is allocated by PyTorch, and 637.59 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)