In [1]:
# -*- coding: utf-8 -*-

"""
LSTM V1 CROSS-DATASET EVALUATION: FakeAVCeleb (WITH FACE DETECTION)
Uses SAME preprocessing as training (face detection + mouth crop)
"""

from google.colab import drive
drive.mount('/content/drive')

import os
import cv2
import time
import torch
import zipfile
import librosa
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from tqdm import tqdm
import torchvision.models as models
from torch.cuda.amp import autocast
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import classification_report, roc_auc_score, confusion_matrix, accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns
import warnings

warnings.filterwarnings('ignore')
torch.backends.cudnn.benchmark = True
print("✅ Libraries imported")

# ============================================================================
# EXTRACT DATASET
# ============================================================================

print("\n" + "="*80)
print("📦 EXTRACTING FAKEAVCELEB")
print("="*80)

zip_path = '/content/drive/MyDrive/CSE400 codes - 144/archive.zip'
extract_path = '/content/FakeAVCeleb'

if not os.path.exists(extract_path):
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_path)
    print("✅ Extracted")
else:
    print("✅ Already extracted")

# ============================================================================
# CONFIGURATION
# ============================================================================

class Config:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.vis_image_size = (128, 128)
        self.vis_num_frames = 16
        self.vis_cnn_feature_dim = 576
        self.vis_lstm_hidden = 128
        self.aud_sample_rate = 16000
        self.aud_num_chunks = 5
        self.aud_chunk_duration = 1.0
        self.aud_n_mels = 128
        self.aud_cnn_feature_dim = 576
        self.aud_lstm_hidden = 128
        self.lstm_num_layers = 2
        self.lstm_bidirectional = True
        self.lstm_dropout = 0.2
        self.batch_size = 64  # Reduced for stability
        self.num_workers = 2

config = Config()
print(f"✅ Config loaded")

# ============================================================================
# DATA PROCESSING WITH FACE DETECTION (MATCHES TRAINING)
# ============================================================================

def process_visual_with_face_detection(video_path: str, config: Config):
    """EXACT MATCH to training: face detection + mouth crop"""
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total_frames < config.vis_num_frames:
        cap.release()
        return None

    frame_indices = np.linspace(0, total_frames - 1, config.vis_num_frames, dtype=int)
    frames = []
    face_detector = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')

    for i in frame_indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, i)
        ret, frame = cap.read()
        if not ret:
            continue

        # Face detection
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        faces = face_detector.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=4)

        if len(faces) > 0:
            # Get largest face
            (x, y, w, h) = max(faces, key=lambda rect: rect[2] * rect[3])

            # Mouth region crop (EXACT MATCH TO TRAINING)
            mouth_y_start = y + int(h * 0.6)
            mouth_y_end = y + h
            mouth_x_start = x + int(w * 0.25)
            mouth_x_end = x + int(w * 0.75)

            mouth_crop = frame[mouth_y_start:mouth_y_end, mouth_x_start:mouth_x_end]

            if mouth_crop.size > 0:
                resized_crop = cv2.resize(mouth_crop, config.vis_image_size)
                resized_crop_rgb = cv2.cvtColor(resized_crop, cv2.COLOR_BGR2RGB)
                frames.append(resized_crop_rgb)
            else:
                # Fallback: center crop
                h, w = frame.shape[:2]
                crop_size = min(h, w)
                start_h = (h - crop_size) // 2
                start_w = (w - crop_size) // 2
                cropped = frame[start_h:start_h+crop_size, start_w:start_w+crop_size]
                resized = cv2.resize(cropped, config.vis_image_size)
                resized_rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
                frames.append(resized_rgb)
        else:
            # Fallback: center crop
            h, w = frame.shape[:2]
            crop_size = min(h, w)
            start_h = (h - crop_size) // 2
            start_w = (w - crop_size) // 2
            cropped = frame[start_h:start_h+crop_size, start_w:start_w+crop_size]
            resized = cv2.resize(cropped, config.vis_image_size)
            resized_rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
            frames.append(resized_rgb)

    cap.release()
    return np.stack(frames) if len(frames) == config.vis_num_frames else None

def process_audio_from_video(video_path: str, config: Config):
    """Extract audio from video"""
    try:
        y, sr = librosa.load(video_path, sr=config.aud_sample_rate, duration=5.0, mono=True)
        target_length = int(config.aud_sample_rate * config.aud_chunk_duration * config.aud_num_chunks)

        if len(y) < target_length:
            y = np.pad(y, (0, target_length - len(y)), mode='constant')
        else:
            y = y[:target_length]

        samples_per_chunk = int(config.aud_chunk_duration * sr)
        mel_spectrograms = []

        for i in range(config.aud_num_chunks):
            chunk = y[i*samples_per_chunk : (i+1)*samples_per_chunk]
            mel = librosa.feature.melspectrogram(y=chunk, sr=sr, n_mels=config.aud_n_mels)
            mel_db = librosa.power_to_db(mel, ref=np.max)
            mel_db = (mel_db - mel_db.mean()) / (mel_db.std() + 1e-9)
            mel_spectrograms.append(torch.tensor(mel_db, dtype=torch.float32))

        return torch.stack(mel_spectrograms, dim=0)
    except Exception:
        return None

# ============================================================================
# DATASET
# ============================================================================

class FakeAVCelebDataset(Dataset):
    def __init__(self, video_paths, labels, config):
        self.video_paths = video_paths
        self.labels = labels
        self.config = config
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.video_paths)

    def __getitem__(self, idx):
        try:
            visual_frames = process_visual_with_face_detection(self.video_paths[idx], self.config)
            if visual_frames is None:
                return None

            visual_tensors = [self.transform(frame) for frame in visual_frames]
            visual_tensor = torch.stack(visual_tensors, dim=0)

            audio_mels = process_audio_from_video(self.video_paths[idx], self.config)
            if audio_mels is None:
                return None

            return visual_tensor, audio_mels.unsqueeze(1), torch.tensor(self.labels[idx], dtype=torch.float32)
        except Exception:
            return None

def collate_fn(batch):
    batch = [x for x in batch if x is not None]
    return torch.utils.data.dataloader.default_collate(batch) if batch else None

# ============================================================================
# MODEL
# ============================================================================

class VisualStream_LSTM(nn.Module):
    def __init__(self, config):
        super().__init__()
        mobilenet = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)
        self.cnn_features = mobilenet.features
        self.avgpool = mobilenet.avgpool
        self.proj = nn.Linear(config.vis_cnn_feature_dim, config.vis_lstm_hidden)
        self.proj_dropout = nn.Dropout(0.3)
        self.lstm = nn.LSTM(config.vis_lstm_hidden, config.vis_lstm_hidden, config.lstm_num_layers,
                           batch_first=True, bidirectional=config.lstm_bidirectional,
                           dropout=config.lstm_dropout if config.lstm_num_layers > 1 else 0.0)
        self.lstm_dropout = nn.Dropout(0.2)
        self.out_dim = config.vis_lstm_hidden * (2 if config.lstm_bidirectional else 1)

    def forward(self, x):
        b, t, c, h, w = x.shape
        x = x.view(b * t, c, h, w)
        features = self.avgpool(self.cnn_features(x)).view(b, t, -1)
        projected = self.proj_dropout(self.proj(features))
        lstm_out, _ = self.lstm(projected)
        return self.lstm_dropout(lstm_out[:, -1, :])

class AudioStream_LSTM(nn.Module):
    def __init__(self, config):
        super().__init__()
        mobilenet = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)
        self.cnn_features = mobilenet.features
        self.avgpool = mobilenet.avgpool
        self.proj = nn.Linear(config.aud_cnn_feature_dim, config.aud_lstm_hidden)
        self.proj_dropout = nn.Dropout(0.3)
        self.lstm = nn.LSTM(config.aud_lstm_hidden, config.aud_lstm_hidden, config.lstm_num_layers,
                           batch_first=True, bidirectional=config.lstm_bidirectional,
                           dropout=config.lstm_dropout if config.lstm_num_layers > 1 else 0.0)
        self.lstm_dropout = nn.Dropout(0.2)
        self.out_dim = config.aud_lstm_hidden * (2 if config.lstm_bidirectional else 1)

    def forward(self, x):
        b, t, c, h, w = x.shape
        x = x.view(b * t, c, h, w).repeat(1, 3, 1, 1)
        features = self.avgpool(self.cnn_features(x)).view(b, t, -1)
        projected = self.proj_dropout(self.proj(features))
        lstm_out, _ = self.lstm(projected)
        return self.lstm_dropout(lstm_out[:, -1, :])

class FusionModel_LSTM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.visual_stream = VisualStream_LSTM(config)
        self.audio_stream = AudioStream_LSTM(config)
        self.fusion_head = nn.Sequential(
            nn.Linear(self.visual_stream.out_dim + self.audio_stream.out_dim, 256),
            nn.ReLU(), nn.Dropout(0.6), nn.Linear(256, 1)
        )

    def forward(self, visual, audio):
        return self.fusion_head(torch.cat([self.visual_stream(visual), self.audio_stream(audio)], dim=1))

# ============================================================================
# FIND VIDEOS
# ============================================================================

print("\n" + "="*80)
print("📂 LOCATING VIDEOS")
print("="*80)

def find_videos(root_dir):
    videos = []
    for root, dirs, files in os.walk(root_dir):
        for file in files:
            if file.lower().endswith(('.mp4', '.avi', '.mov')):
                videos.append(os.path.join(root, file))
    return videos

all_videos = find_videos(extract_path)
real_videos = [v for v in all_videos if 'RealVideo-RealAudio' in v]
fake_videos = [v for v in all_videos if 'FakeVideo' in v or ('fake' in v.lower() and 'RealVideo' not in v)]
print(f"Real: {len(real_videos)}, Fake: {len(fake_videos)}")

# ============================================================================
# LOAD MODEL
# ============================================================================

print("\n" + "="*80)
print("🔧 LOADING MODEL")
print("="*80)

model_path = '/content/drive/MyDrive/PTHs/lstm_v1_best.pth'
model = FusionModel_LSTM(config).to(config.device)
model.load_state_dict(torch.load(model_path, map_location=config.device))
model.eval()

for m in model.modules():
    if isinstance(m, nn.Dropout):
        m.p = 0.0

print(f"✅ Model loaded")

# ============================================================================
# EVALUATION
# ============================================================================

def evaluate_seed(seed, real_vids, fake_vids, model, config):
    print(f"\n{'='*80}\n🎲 SEED: {seed}\n{'='*80}")
    np.random.seed(seed)

    real_sample = np.random.choice(real_vids, min(500, len(real_vids)), replace=False).tolist()
    fake_sample = np.random.choice(fake_vids, min(3000, len(fake_vids)), replace=False).tolist()
    test_videos = real_sample + fake_sample
    test_labels = [0]*len(real_sample) + [1]*len(fake_sample)

    print(f"Samples: {len(real_sample)} real + {len(fake_sample)} fake")

    dataset = FakeAVCelebDataset(test_videos, test_labels, config)
    loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=False,
                       num_workers=config.num_workers, collate_fn=collate_fn, pin_memory=True)

    all_preds, all_labels = [], []

    with torch.no_grad():
        for batch in tqdm(loader, desc=f"Seed {seed}"):
            if batch is None:
                continue
            visual, audio, labels = batch
            with autocast():
                outputs = torch.sigmoid(model(visual.to(config.device), audio.to(config.device)))
            all_preds.extend(outputs.cpu().numpy())
            all_labels.extend(labels.numpy())

    all_preds = np.array(all_preds).flatten()
    all_labels = np.array(all_labels).flatten()

    valid_mask = ~np.isnan(all_preds)
    all_preds = all_preds[valid_mask]
    all_labels = all_labels[valid_mask]

    if len(all_preds) < 10:
        return None

    pred_binary = (all_preds > 0.5).astype(int)
    acc = accuracy_score(all_labels, pred_binary)
    auc = roc_auc_score(all_labels, all_preds)
    cm = confusion_matrix(all_labels, pred_binary)
    tn, fp, fn, tp = cm.ravel()
    prec = tp / (tp + fp) if (tp + fp) > 0 else 0
    rec = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0

    print(f"\n✅ Results: Acc={acc:.4f}, AUC={auc:.4f}, F1={f1:.4f}")
    return {'seed': seed, 'accuracy': acc, 'auc': auc, 'precision': prec, 'recall': rec, 'f1_score': f1}

# ============================================================================
# RUN 5 SEEDS
# ============================================================================

print("\n" + "="*80)
print("🚀 RUNNING 5 SEEDS")
print("="*80)

seeds = [42, 123, 456, 789, 2024]
results = []

for seed in seeds:
    result = evaluate_seed(seed, real_videos, fake_videos, model, config)
    if result:
        results.append(result)

if results:
    df = pd.DataFrame(results)
    print("\n" + "="*80)
    print("📊 FINAL STATISTICS")
    print("="*80)
    for metric in ['accuracy', 'auc', 'precision', 'recall', 'f1_score']:
        print(f"{metric.upper():12s}: {df[metric].mean()*100:5.2f}% ± {df[metric].std()*100:4.2f}%")

    df.to_csv('/content/lstm_fakeavceleb_5runs.csv', index=False)
    print("\n✅ Saved: /content/lstm_fakeavceleb_5runs.csv")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✅ Libraries imported

📦 EXTRACTING FAKEAVCELEB
✅ Extracted
✅ Config loaded

📂 LOCATING VIDEOS
Real: 500, Fake: 20544

🔧 LOADING MODEL
Downloading: "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v3_small-047dcff4.pth


100%|██████████| 9.83M/9.83M [00:00<00:00, 117MB/s]


✅ Model loaded

🚀 RUNNING 5 SEEDS

🎲 SEED: 42
Samples: 500 real + 3000 fake


Seed 42: 100%|██████████| 55/55 [16:43<00:00, 18.25s/it]



✅ Results: Acc=0.8471, AUC=0.8421, F1=0.9066

🎲 SEED: 123
Samples: 500 real + 3000 fake


Seed 123: 100%|██████████| 55/55 [16:25<00:00, 17.92s/it]



✅ Results: Acc=0.8469, AUC=0.8406, F1=0.9064

🎲 SEED: 456
Samples: 500 real + 3000 fake


Seed 456: 100%|██████████| 55/55 [16:24<00:00, 17.89s/it]



✅ Results: Acc=0.8477, AUC=0.8425, F1=0.9070

🎲 SEED: 789
Samples: 500 real + 3000 fake


Seed 789: 100%|██████████| 55/55 [16:24<00:00, 17.89s/it]



✅ Results: Acc=0.8526, AUC=0.8498, F1=0.9102

🎲 SEED: 2024
Samples: 500 real + 3000 fake


Seed 2024: 100%|██████████| 55/55 [16:21<00:00, 17.85s/it]


✅ Results: Acc=0.8491, AUC=0.8439, F1=0.9079

📊 FINAL STATISTICS
ACCURACY    : 84.87% ± 0.23%
AUC         : 84.38% ± 0.36%
PRECISION   : 95.17% ± 0.01%
RECALL      : 86.75% ± 0.27%
F1_SCORE    : 90.76% ± 0.16%

✅ Saved: /content/lstm_fakeavceleb_5runs.csv



