In [1]:
"""
EfficientNet-B3 Deepfake Detector
- Freeze → Unfreeze fine-tuning
- Per-epoch logging: Acc, Prec/Rec/F1 (macro & weighted), AUC
- Saves graphs: loss/acc/F1/AUC curves
- Saves per-phase, per-epoch: Confusion Matrix + ROC curve
- Final full evaluation plots: Confusion, ROC, PR, Confidence hist, F1 vs Threshold
"""

# =========================
# Imports & setup
# =========================
import os, random, time, json, itertools, warnings
from pathlib import Path
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast

import torchvision
from torchvision import models, transforms

from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, roc_auc_score,
    confusion_matrix, roc_curve, precision_recall_curve, classification_report,
    precision_recall_fscore_support
)

# =========================
# Config
# =========================
class Config:
    # Data
    TRAIN_FAKE_PATH = r"C:\Users\Gan\AI\DataSet\Train\Fake"
    TRAIN_REAL_PATH = r"C:\Users\Gan\AI\DataSet\Train\Real"
    VAL_FAKE_PATH   = r"C:\Users\Gan\AI\DataSet\Validation\Fake"
    VAL_REAL_PATH   = r"C:\Users\Gan\AI\DataSet\Validation\Real"
    TEST_FAKE_PATH  = r"C:\Users\Gan\AI\DataSet\Test\Fake"
    TEST_REAL_PATH  = r"C:\Users\Gan\AI\DataSet\Test\Real"

    # Outputs (your required paths)
    CHECKPOINT_PATH = r"C:\Users\Gan\AI Testing\checkpoint"
    RESULT_GRAPH_PATH = r"C:\Users\Gan\AI Testing\EfficientB3_result_graph"

    # Model / training
    MODEL_NAME   = "efficientnet_b3"
    IMAGE_SIZE   = 300
    BATCH_SIZE   = 16
    NUM_EPOCHS   = 10
    LEARNING_RATE = 1e-4

    # Freeze → unfreeze
    FREEZE_BACKBONE_EPOCHS = 1
    USE_DISCRIM_LR_AFTER_UNFREEZE = True
    BACKBONE_LR_MULT = 0.3          # smaller LR for backbone
    HEAD_LR_MULT     = 3.0          # larger LR for head

    # Data usage
    USE_FULL_DATA = False           # quick debug: 20% train
    DATA_FRACTION = 0.2
    VAL_FRACTION  = 1.0

    # System
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    NUM_WORKERS = 0
    PIN_MEMORY = True
    USE_MIXED_PRECISION = True

    # Early stopping (by F1)
    EARLY_STOPPING_PATIENCE = 7

# =========================
# Dataset
# =========================
class DeepfakeDataset(Dataset):
    def __init__(self, fake_path, real_path, transform=None, subset_fraction=1.0):
        self.transform = transform
        self.images, self.labels = [], []

        def list_imgs(p):
            if not os.path.exists(p): return []
            return [os.path.join(p, f) for f in os.listdir(p)
                    if f.lower().endswith(('.jpg','.jpeg','.png','.bmp','.webp','.tif','.tiff'))]

        fake_files = list_imgs(fake_path)
        real_files = list_imgs(real_path)

        if subset_fraction < 1.0:
            random.shuffle(fake_files); random.shuffle(real_files)
            fake_files = fake_files[:int(len(fake_files)*subset_fraction)]
            real_files = real_files[:int(len(real_files)*subset_fraction)]

        # label: Real=0, Fake=1
        self.images.extend(fake_files); self.labels.extend([1]*len(fake_files))
        self.images.extend(real_files); self.labels.extend([0]*len(real_files))

        combined = list(zip(self.images, self.labels))
        random.shuffle(combined)
        self.images, self.labels = zip(*combined) if combined else ([], [])

        print(f"  Loaded {len(fake_files):,} fake, {len(real_files):,} real (total {len(self.images):,})")

    def __len__(self): return len(self.images)

    def __getitem__(self, idx):
        for _ in range(3):
            try:
                image = Image.open(self.images[idx]).convert('RGB')
                if self.transform: image = self.transform(image)
                return image, self.labels[idx]
            except Exception:
                idx = (idx + 1) % len(self.images)
        return torch.zeros(3, 300, 300), self.labels[idx]

# =========================
# Model
# =========================
class DeepfakeDetector(nn.Module):
    def __init__(self, model_name="efficientnet_b3", num_classes=2):
        super().__init__()
        self.model_name = model_name

        if model_name == "efficientnet_b3":
            weights = models.EfficientNet_B3_Weights.IMAGENET1K_V1
            self.backbone = models.efficientnet_b3(weights=weights)
            in_features = self.backbone.classifier[1].in_features
            self.backbone.classifier = nn.Sequential(
                nn.Dropout(0.3),
                nn.Linear(in_features, 512),
                nn.ReLU(inplace=True),
                nn.BatchNorm1d(512),
                nn.Dropout(0.2),
                nn.Linear(512, num_classes)
            )
        elif model_name == "efficientnet_b0":
            weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1
            self.backbone = models.efficientnet_b0(weights=weights)
            in_features = self.backbone.classifier[1].in_features
            self.backbone.classifier = nn.Sequential(
                nn.Dropout(0.3),
                nn.Linear(in_features, 512),
                nn.ReLU(inplace=True),
                nn.BatchNorm1d(512),
                nn.Dropout(0.2),
                nn.Linear(512, num_classes)
            )
        elif model_name == "resnet50":
            weights = models.ResNet50_Weights.IMAGENET1K_V2
            self.backbone = models.resnet50(weights=weights)
            in_features = self.backbone.fc.in_features
            self.backbone.fc = nn.Identity()
            self.classifier = nn.Sequential(
                nn.Dropout(0.3),
                nn.Linear(in_features, 512),
                nn.ReLU(inplace=True),
                nn.BatchNorm1d(512),
                nn.Dropout(0.2),
                nn.Linear(512, num_classes)
            )
        else:
            raise ValueError(f"Unknown model: {model_name}")

    def forward(self, x):
        if "resnet" in self.model_name:
            x = self.backbone(x)
            return self.classifier(x)
        else:
            return self.backbone(x)

    def set_backbone_trainable(self, trainable: bool):
        # freeze/unfreeze backbone features
        if "resnet" in self.model_name:
            modules = [self.backbone]
        else:
            modules = [self.backbone.features]
        for m in modules:
            for p in m.parameters():
                p.requires_grad = trainable
        # keep classifier trainable
        if "resnet" in self.model_name:
            for p in self.classifier.parameters(): p.requires_grad = True
        else:
            for p in self.backbone.classifier.parameters(): p.requires_grad = True

def count_trainable_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# =========================
# Optimizer & scheduler
# =========================
def make_optimizer_and_scheduler(model, cfg, steps_per_epoch, total_epochs, phase):
    trainable = [p for p in model.parameters() if p.requires_grad]

    if phase == 2 and cfg.USE_DISCRIM_LR_AFTER_UNFREEZE:
        head_params, body_params = [], []
        for n,p in model.named_parameters():
            if not p.requires_grad: continue
            if ("classifier" in n) or ("fc" in n):
                head_params.append(p)
            else:
                body_params.append(p)
        optimizer = optim.AdamW([
            {"params": body_params, "lr": cfg.LEARNING_RATE * cfg.BACKBONE_LR_MULT},
            {"params": head_params, "lr": cfg.LEARNING_RATE * cfg.HEAD_LR_MULT},
        ], weight_decay=1e-4)
        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=[cfg.LEARNING_RATE * cfg.BACKBONE_LR_MULT,
                    cfg.LEARNING_RATE * cfg.HEAD_LR_MULT],
            steps_per_epoch=steps_per_epoch,
            epochs=total_epochs,
            pct_start=0.1
        )
    else:
        optimizer = optim.AdamW(trainable, lr=cfg.LEARNING_RATE, weight_decay=1e-4)
        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=cfg.LEARNING_RATE * 3,
            steps_per_epoch=steps_per_epoch,
            epochs=total_epochs,
            pct_start=0.1
        )
    return optimizer, scheduler

# =========================
# Train / Validate
# =========================
def train_epoch(model, loader, criterion, optimizer, scaler, device):
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    pbar = tqdm(loader, desc="Training", ncols=100)
    for images, labels in pbar:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with autocast(enabled=scaler is not None):
            outputs = model(images)
            loss = criterion(outputs, labels)

        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer); scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        running_loss += loss.item()
        pbar.set_postfix({"Loss": f"{running_loss/len(loader):.3f}",
                          "Acc": f"{correct/max(1,total):.3f}"})
    return running_loss / max(1,len(loader)), correct / max(1,total)

@torch.no_grad()
def validate_epoch(model, loader, criterion, device, desc="Validation"):
    model.eval()
    running_loss = 0.0
    all_preds, all_labels, all_probs = [], [], []

    pbar = tqdm(loader, desc=desc, ncols=100)
    for images, labels in pbar:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        outputs = model(images)
        loss = criterion(outputs, labels)
        running_loss += loss.item()

        probs = F.softmax(outputs, dim=1)[:, 1]  # P(Fake)
        _, predicted = torch.max(outputs, 1)

        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        all_probs .extend(probs.cpu().numpy())

        acc_live = accuracy_score(all_labels, all_preds) if len(all_labels)>0 else 0.0
        pbar.set_postfix({"Loss": f"{running_loss/max(1,len(all_labels)):.3f}",
                          "Acc": f"{acc_live:.3f}"})

    # Basic (binary)
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, zero_division=0)
    recall    = recall_score(all_labels, all_preds, zero_division=0)
    f1        = f1_score(all_labels, all_preds, zero_division=0)
    try:
        auc = roc_auc_score(all_labels, all_probs)
    except:
        auc = 0.5

    # Macro & Weighted
    prec_macro, rec_macro, f1_macro, _ = precision_recall_fscore_support(
        all_labels, all_preds, average="macro", zero_division=0)
    prec_weight, rec_weight, f1_weight, _ = precision_recall_fscore_support(
        all_labels, all_preds, average="weighted", zero_division=0)

    return {
        "loss": running_loss / max(1,len(loader)),
        "acc": accuracy, "prec": precision, "rec": recall, "f1": f1, "auc": auc,
        "prec_macro": prec_macro, "rec_macro": rec_macro, "f1_macro": f1_macro,
        "prec_weight": prec_weight, "rec_weight": rec_weight, "f1_weight": f1_weight,
        "labels": np.array(all_labels), "preds": np.array(all_preds), "probs": np.array(all_probs)
    }

# =========================
# Plot helpers
# =========================
def plot_lines(x, ys, labels, title, ylabel, out_path):
    plt.figure(figsize=(8,5))
    for y, lab in zip(ys, labels):
        plt.plot(x, y, marker="o", label=lab)
    plt.title(title); plt.xlabel("Epoch"); plt.ylabel(ylabel)
    plt.grid(True, linestyle="--", linewidth=0.5); plt.legend()
    plt.savefig(out_path, dpi=160, bbox_inches="tight"); plt.close()

def save_confusion_and_roc(y_true, y_prob, phase_label, epoch, out_dir, threshold=0.5):
    y_pred = (y_prob >= threshold).astype(int)

    # Confusion matrix
    cm = confusion_matrix(y_true, y_pred, labels=[0,1])
    fig, ax = plt.subplots(figsize=(6,5))
    im = ax.imshow(cm, cmap="Blues")
    ax.figure.colorbar(im, ax=ax)
    ax.set_xticks([0,1]); ax.set_yticks([0,1])
    ax.set_xticklabels(["Real","Fake"]); ax.set_yticklabels(["Real","Fake"])
    ax.set_xlabel("Predicted"); ax.set_ylabel("Actual")
    ax.set_title(f"Confusion Matrix — {phase_label} (Epoch {epoch})")
    thresh = cm.max()/2
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        ax.text(j, i, format(cm[i, j], 'd'),
                ha="center", va="center",
                color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    cm_path = os.path.join(out_dir, f"confusion_{phase_label.lower()}_epoch{epoch}.png")
    plt.savefig(cm_path, dpi=160, bbox_inches="tight"); plt.close(fig)

    # ROC
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    auc_val = roc_auc_score(y_true, y_prob)
    fig = plt.figure(figsize=(6,5))
    plt.plot(fpr, tpr, lw=2, label=f"AUC = {auc_val:.3f}")
    plt.plot([0,1],[0,1], linestyle="--")
    plt.xlabel("False Positive Rate"); plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curve — {phase_label} (Epoch {epoch})"); plt.legend(loc="lower right")
    roc_path = os.path.join(out_dir, f"roc_{phase_label.lower()}_epoch{epoch}.png")
    plt.savefig(roc_path, dpi=160, bbox_inches="tight"); plt.close(fig)

    return cm_path, roc_path, float(auc_val)

# =========================
# Metrics logger
# =========================
class MetricsLogger:
    def __init__(self, out_dir):
        self.out_dir = out_dir
        self.rows = []

    def log(self, epoch, phase, stats):
        r = dict(epoch=epoch, phase=phase)
        r.update({
            "train_loss": stats.get("train_loss", np.nan),
            "train_acc":  stats.get("train_acc",  np.nan),
            "val_loss":   stats["val_loss"],
            "val_acc":    stats["val_acc"],
            "val_prec":   stats["val_prec"],
            "val_rec":    stats["val_rec"],
            "val_f1":     stats["val_f1"],
            "val_auc":    stats["val_auc"],
            "val_prec_macro":   stats["val_prec_macro"],
            "val_rec_macro":    stats["val_rec_macro"],
            "val_f1_macro":     stats["val_f1_macro"],
            "val_prec_weight":  stats["val_prec_weight"],
            "val_rec_weight":   stats["val_rec_weight"],
            "val_f1_weight":    stats["val_f1_weight"],
        })
        self.rows.append(r)

    def to_df(self):
        return pd.DataFrame(self.rows)

    def save_csv(self, fname="metrics_per_epoch.csv"):
        p = os.path.join(self.out_dir, fname)
        self.to_df().to_csv(p, index=False)
        return p

# =========================
# Full evaluation (final)
# =========================
@torch.no_grad()
def evaluate_full(model, loader, device, out_dir, class_names=("Real","Fake"), threshold=0.5):
    model.eval()
    y_true, y_prob = [], []

    for images, labels in tqdm(loader, desc="Full Eval", ncols=100):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        logits = model(images)
        probs = F.softmax(logits, dim=1)[:,1]
        y_true.extend(labels.cpu().numpy().tolist())
        y_prob.extend(probs.cpu().numpy().tolist())

    y_true = np.array(y_true); y_prob = np.array(y_prob)
    y_pred = (y_prob >= threshold).astype(int)

    # Save raw arrays
    np.save(os.path.join(out_dir, "y_true.npy"), y_true)
    np.save(os.path.join(out_dir, "y_pred.npy"), y_pred)
    np.save(os.path.join(out_dir, "y_prob_fake.npy"), y_prob)

    # Metrics
    acc  = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, zero_division=0)
    rec  = recall_score(y_true, y_pred, zero_division=0)
    f1   = f1_score(y_true, y_pred, zero_division=0)
    auc  = roc_auc_score(y_true, y_prob)

    # Macro & weighted
    prec_macro, rec_macro, f1_macro, _ = precision_recall_fscore_support(
        y_true, y_pred, average="macro", zero_division=0)
    prec_weight, rec_weight, f1_weight, _ = precision_recall_fscore_support(
        y_true, y_pred, average="weighted", zero_division=0)

    # Report
    report = classification_report(y_true, y_pred, target_names=class_names, zero_division=0)
    with open(os.path.join(out_dir, "classification_report.txt"), "w", encoding="utf-8") as f:
        f.write(report + "\n")
        f.write(f"\nOverall:\nacc={acc:.4f}  prec={prec:.4f}  rec={rec:.4f}  f1={f1:.4f}  auc={auc:.4f}\n")
        f.write(f"macro: prec={prec_macro:.4f} rec={rec_macro:.4f} f1={f1_macro:.4f}\n")
        f.write(f"weighted: prec={prec_weight:.4f} rec={rec_weight:.4f} f1={f1_weight:.4f}\n")

    # Plots
    save_confusion_and_roc(y_true, y_prob, "Final", "best", out_dir, threshold=threshold)

    # PR curve
    ps, rs, _ = precision_recall_curve(y_true, y_prob)
    fig = plt.figure(figsize=(6,5))
    plt.plot(rs, ps, lw=2)
    plt.xlabel("Recall"); plt.ylabel("Precision"); plt.title("Precision–Recall Curve (Final)")
    plt.grid(True, linestyle="--", linewidth=0.5)
    plt.savefig(os.path.join(out_dir, "pr_curve.png"), dpi=160, bbox_inches="tight"); plt.close(fig)

    # Confidence histogram
    fig = plt.figure(figsize=(8,4))
    plt.hist([p for p,t in zip(y_prob, y_true) if t==0], bins=30, alpha=0.6, label="P(Fake) | Real")
    plt.hist([p for p,t in zip(y_prob, y_true) if t==1], bins=30, alpha=0.6, label="P(Fake) | Fake")
    plt.axvline(threshold, color="k", linestyle="--", label=f"Threshold={threshold:.2f}")
    plt.xlabel("P(Fake)"); plt.ylabel("Count"); plt.title("Confidence Histogram (Final)"); plt.legend()
    plt.grid(True, linestyle="--", linewidth=0.5)
    plt.savefig(os.path.join(out_dir, "confidence_hist.png"), dpi=160, bbox_inches="tight"); plt.close(fig)

    # F1 vs threshold
    ths = np.linspace(0.05, 0.95, 19)
    f1s = []
    for th in ths:
        preds = (y_prob >= th).astype(int)
        f1s.append(f1_score(y_true, preds, zero_division=0))
    fig = plt.figure(figsize=(6,5))
    plt.plot(ths, f1s, marker="o")
    plt.xlabel("Threshold on P(Fake)"); plt.ylabel("F1"); plt.title("F1 vs Threshold (Final)")
    plt.grid(True, linestyle="--", linewidth=0.5)
    plt.savefig(os.path.join(out_dir, "f1_vs_threshold.png"), dpi=160, bbox_inches="tight"); plt.close(fig)

    return {
        "acc":acc, "prec":prec, "rec":rec, "f1":f1, "auc":auc,
        "prec_macro":prec_macro, "rec_macro":rec_macro, "f1_macro":f1_macro,
        "prec_weight":prec_weight, "rec_weight":rec_weight, "f1_weight":f1_weight
    }

# =========================
# Main
# =========================
def main():
    cfg = Config()
    Path(cfg.CHECKPOINT_PATH).mkdir(parents=True, exist_ok=True)
    Path(cfg.RESULT_GRAPH_PATH).mkdir(parents=True, exist_ok=True)

    print("="*80)
    print("DEEPFAKE DETECTION TRAINING — FREEZE → UNFREEZE FINE-TUNING")
    print("="*80)
    print(f"✓ Checkpoint path: {cfg.CHECKPOINT_PATH}")
    print(f"✓ Result path    : {cfg.RESULT_GRAPH_PATH}")
    print("\nSystem Info:")
    print("  Device:", cfg.DEVICE)
    if torch.cuda.is_available():
        print("  GPU:", torch.cuda.get_device_name(0))
        print(f"  VRAM: {torch.cuda.get_device_properties(0).total_memory/1024**3:.2f} GB")

    random.seed(42); np.random.seed(42); torch.manual_seed(42)
    if torch.cuda.is_available(): torch.backends.cudnn.benchmark = True

    print("\nModel Configuration:")
    print(f"  Model         : {cfg.MODEL_NAME}")
    print(f"  Image Size    : {cfg.IMAGE_SIZE}")
    print(f"  Batch Size    : {cfg.BATCH_SIZE}")
    print(f"  Epochs        : {cfg.NUM_EPOCHS}")
    print(f"  LR            : {cfg.LEARNING_RATE}")
    print(f"  Freeze epochs : {cfg.FREEZE_BACKBONE_EPOCHS}")

    # Transforms
    train_tf = transforms.Compose([
        transforms.Resize((cfg.IMAGE_SIZE, cfg.IMAGE_SIZE)),
        transforms.RandomHorizontalFlip(0.5),
        transforms.ColorJitter(0.2,0.2,0.2,0.05),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
    ])
    eval_tf = transforms.Compose([
        transforms.Resize((cfg.IMAGE_SIZE, cfg.IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
    ])

    frac_train = 1.0 if cfg.USE_FULL_DATA else cfg.DATA_FRACTION
    frac_val   = cfg.VAL_FRACTION

    print("\n" + "="*60)
    print("PREPARING DATASETS")
    print("="*60)
    print(f"Using {frac_train*100:.0f}% of TRAIN, {frac_val*100:.0f}% of VAL")

    train_dataset = DeepfakeDataset(cfg.TRAIN_FAKE_PATH, cfg.TRAIN_REAL_PATH, transform=train_tf, subset_fraction=frac_train)
    val_dataset   = DeepfakeDataset(cfg.VAL_FAKE_PATH,   cfg.VAL_REAL_PATH,   transform=eval_tf,  subset_fraction=frac_val)

    train_loader = DataLoader(train_dataset, batch_size=cfg.BATCH_SIZE, shuffle=True,
                              num_workers=cfg.NUM_WORKERS, pin_memory=cfg.PIN_MEMORY, drop_last=True)
    val_loader   = DataLoader(val_dataset,   batch_size=cfg.BATCH_SIZE*2, shuffle=False,
                              num_workers=cfg.NUM_WORKERS, pin_memory=cfg.PIN_MEMORY)

    print("\nDataset Summary:")
    print(f"  Train: {len(train_dataset):,} images")
    print(f"  Val  : {len(val_dataset):,} images")

    # Model
    print("\n" + "="*60)
    print("INITIALIZING MODEL")
    print("="*60)
    model = DeepfakeDetector(model_name=cfg.MODEL_NAME).to(cfg.DEVICE)
    print(f"Total params     : {sum(p.numel() for p in model.parameters()):,}")

    # Phase split
    phase1_epochs = min(cfg.FREEZE_BACKBONE_EPOCHS, cfg.NUM_EPOCHS)
    phase2_epochs = cfg.NUM_EPOCHS - phase1_epochs

    # Phase 1: freeze
    if phase1_epochs > 0:
        model.set_backbone_trainable(False)
        print(f"[Phase 1] Backbone FROZEN for {phase1_epochs} epoch(s). Trainable params: {count_trainable_params(model):,}")
    else:
        print("[Phase 1] No freezing (skip).")

    scaler   = GradScaler() if cfg.USE_MIXED_PRECISION and cfg.DEVICE.type == "cuda" else None
    criterion= nn.CrossEntropyLoss()
    p1_epochs= phase1_epochs if phase1_epochs > 0 else cfg.NUM_EPOCHS
    optimizer, scheduler = make_optimizer_and_scheduler(model, cfg, len(train_loader), p1_epochs, phase=1)
    logger = MetricsLogger(cfg.RESULT_GRAPH_PATH)

    print("\n" + "="*60)
    print("STARTING TRAINING")
    print("="*60)

    best_val_f1, patience, epoch = 0.0, 0, 0
    epochs_axis, train_losses, val_losses, val_accs, val_f1s, val_aucs = [], [], [], [], [], []

    # ---------- Phase 1 loop ----------
    for e in range(p1_epochs):
        epoch += 1
        print(f"\n{'='*50}\nEPOCH {epoch}/{cfg.NUM_EPOCHS}  [Phase 1: frozen]\n{'='*50}")
        print(f"LR group(s): {[pg['lr'] for pg in optimizer.param_groups]}")

        t0 = time.time()
        tr_loss, tr_acc = train_epoch(model, train_loader, criterion, optimizer, scaler, cfg.DEVICE)
        scheduler.step()
        val_stats = validate_epoch(model, val_loader, criterion, cfg.DEVICE, desc="Validation")
        dt = time.time()-t0

        print("\nResults:")
        print(f"  Train Loss: {tr_loss:.4f} | Train Acc: {tr_acc:.3f}")
        print(f"  Val   Loss: {val_stats['loss']:.4f} | Val Acc: {val_stats['acc']:.3f} | "
              f"F1: {val_stats['f1']:.3f} | AUC: {val_stats['auc']:.3f}")
        print(f"  Time: {dt:.1f}s")

        # Log series
        epochs_axis.append(epoch)
        train_losses.append(tr_loss); val_losses.append(val_stats["loss"])
        val_accs.append(val_stats["acc"]); val_f1s.append(val_stats["f1"]); val_aucs.append(val_stats["auc"])

        # Save per-phase confusion & ROC
        save_confusion_and_roc(val_stats["labels"], val_stats["probs"], "Phase1", epoch, cfg.RESULT_GRAPH_PATH, threshold=0.5)

        # Save best by F1
        if val_stats["f1"] > best_val_f1:
            best_val_f1, patience = val_stats["f1"], 0
            torch.save({"epoch": epoch, "model_state_dict": model.state_dict(),
                        "val_acc": val_stats["acc"], "val_f1": val_stats["f1"]},
                       os.path.join(cfg.CHECKPOINT_PATH, f"best_{cfg.MODEL_NAME}.pth"))
            print(f"  ✓ Saved best model (F1: {val_stats['f1']:.3f})")
        else:
            patience += 1
            if patience >= cfg.EARLY_STOPPING_PATIENCE:
                print(f"\n⚠ Early stopping at epoch {epoch} (Phase 1)")
                break

        # log row
        logger.log(epoch, "phase1", {
            "train_loss": tr_loss, "train_acc": tr_acc,
            "val_loss": val_stats["loss"], "val_acc": val_stats["acc"],
            "val_prec": val_stats["prec"], "val_rec": val_stats["rec"], "val_f1": val_stats["f1"],
            "val_auc": val_stats["auc"],
            "val_prec_macro": val_stats["prec_macro"], "val_rec_macro": val_stats["rec_macro"], "val_f1_macro": val_stats["f1_macro"],
            "val_prec_weight": val_stats["prec_weight"], "val_rec_weight": val_stats["rec_weight"], "val_f1_weight": val_stats["f1_weight"],
        })

    # ---------- Phase 2 loop ----------
    if phase2_epochs > 0 and patience < cfg.EARLY_STOPPING_PATIENCE:
        model.set_backbone_trainable(True)
        print(f"\n[Phase 2] Backbone UNFROZEN for {phase2_epochs} epoch(s). Trainable params: {count_trainable_params(model):,}")
        optimizer, scheduler = make_optimizer_and_scheduler(model, cfg, len(train_loader), phase2_epochs, phase=2)
        patience = 0

        for e in range(phase2_epochs):
            epoch += 1
            print(f"\n{'='*50}\nEPOCH {epoch}/{cfg.NUM_EPOCHS}  [Phase 2: fine-tune]\n{'='*50}")
            print(f"LR group(s): {[pg['lr'] for pg in optimizer.param_groups]}")

            t0 = time.time()
            tr_loss, tr_acc = train_epoch(model, train_loader, criterion, optimizer, scaler, cfg.DEVICE)
            scheduler.step()
            val_stats = validate_epoch(model, val_loader, criterion, cfg.DEVICE, desc="Validation")
            dt = time.time()-t0

            print("\nResults:")
            print(f"  Train Loss: {tr_loss:.4f} | Train Acc: {tr_acc:.3f}")
            print(f"  Val   Loss: {val_stats['loss']:.4f} | Val Acc: {val_stats['acc']:.3f} | "
                  f"F1: {val_stats['f1']:.3f} | AUC: {val_stats['auc']:.3f}")
            print(f"  Time: {dt:.1f}s")

            # series
            epochs_axis.append(epoch)
            train_losses.append(tr_loss); val_losses.append(val_stats["loss"])
            val_accs.append(val_stats["acc"]); val_f1s.append(val_stats["f1"]); val_aucs.append(val_stats["auc"])

            # per-epoch confusion & ROC
            save_confusion_and_roc(val_stats["labels"], val_stats["probs"], "Phase2", epoch, cfg.RESULT_GRAPH_PATH, threshold=0.5)

            # save best by F1
            if val_stats["f1"] > best_val_f1:
                best_val_f1, patience = val_stats["f1"], 0
                torch.save({"epoch": epoch, "model_state_dict": model.state_dict(),
                            "val_acc": val_stats["acc"], "val_f1": val_stats["f1"]},
                           os.path.join(cfg.CHECKPOINT_PATH, f"best_{cfg.MODEL_NAME}.pth"))
                print(f"  ✓ Saved best model (F1: {val_stats['f1']:.3f})")
            else:
                patience += 1
                if patience >= cfg.EARLY_STOPPING_PATIENCE:
                    print(f"\n⚠ Early stopping at epoch {epoch} (Phase 2)")
                    break

            # log
            logger.log(epoch, "phase2", {
                "train_loss": tr_loss, "train_acc": tr_acc,
                "val_loss": val_stats["loss"], "val_acc": val_stats["acc"],
                "val_prec": val_stats["prec"], "val_rec": val_stats["rec"], "val_f1": val_stats["f1"],
                "val_auc": val_stats["auc"],
                "val_prec_macro": val_stats["prec_macro"], "val_rec_macro": val_stats["rec_macro"], "val_f1_macro": val_stats["f1_macro"],
                "val_prec_weight": val_stats["prec_weight"], "val_rec_weight": val_stats["rec_weight"], "val_f1_weight": val_stats["f1_weight"],
            })

    # ---------- Save curves & CSV ----------
    print("\n" + "="*60)
    print("TRAINING COMPLETE - SAVING CURVES & FULL EVAL")
    print("="*60)
    csv_path = logger.save_csv()

    plot_lines(epochs_axis, [train_losses, val_losses], ["Train Loss","Val Loss"],
               "Training vs Validation Loss", "Loss",
               os.path.join(cfg.RESULT_GRAPH_PATH, "loss_curve.png"))
    plot_lines(epochs_axis, [val_accs], ["Val Accuracy"],
               "Validation Accuracy", "Accuracy",
               os.path.join(cfg.RESULT_GRAPH_PATH, "accuracy_curve.png"))
    plot_lines(epochs_axis, [val_f1s], ["Val F1"],
               "Validation F1", "F1",
               os.path.join(cfg.RESULT_GRAPH_PATH, "f1_curve.png"))
    plot_lines(epochs_axis, [val_aucs], ["Val AUC"],
               "Validation AUC (ROC)", "AUC",
               os.path.join(cfg.RESULT_GRAPH_PATH, "auc_curve.png"))

    # ---------- Load best & full evaluation ----------
    best_ckpt = os.path.join(cfg.CHECKPOINT_PATH, f"best_{cfg.MODEL_NAME}.pth")
    if os.path.isfile(best_ckpt):
        state = torch.load(best_ckpt, map_location=cfg.DEVICE)
        sd = state.get("model_state_dict", state.get("model_state"))
        model.load_state_dict(sd)
        final_scores = evaluate_full(model, val_loader, cfg.DEVICE, cfg.RESULT_GRAPH_PATH, class_names=("Real","Fake"), threshold=0.5)
        print("\nFull evaluation on best checkpoint:")
        print(final_scores)
    else:
        print(f"\n⚠ Best checkpoint not found at {best_ckpt}; skipping full evaluation.")

    print(f"\n✓ Best model saved to: {cfg.CHECKPOINT_PATH}")
    print(f"✓ Graphs & CSV saved to: {cfg.RESULT_GRAPH_PATH}")
    print("="*60); print("DONE"); print("="*60)

# =========================
# Single image test (optional)
# =========================
def single_image_test(img_path, threshold=0.5):
    cfg = Config()
    device = cfg.DEVICE
    model = DeepfakeDetector(model_name=cfg.MODEL_NAME).to(device)
    ckpt_path = os.path.join(cfg.CHECKPOINT_PATH, f"best_{cfg.MODEL_NAME}.pth")
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError(f"Best checkpoint not found at: {ckpt_path}")
    state = torch.load(ckpt_path, map_location=device)
    sd = state.get("model_state_dict", state.get("model_state"))
    model.load_state_dict(sd); model.eval()

    tf = transforms.Compose([
        transforms.Resize((cfg.IMAGE_SIZE, cfg.IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
    ])

    img = Image.open(img_path).convert("RGB")
    x = tf(img).unsqueeze(0).to(device)

    t0 = time.time()
    with torch.no_grad():
        with autocast(enabled=cfg.USE_MIXED_PRECISION and device.type=="cuda"):
            logits = model(x)
            probs = torch.softmax(logits, dim=1).squeeze(0).cpu().numpy()
    dt_ms = (time.time()-t0)*1000.0

    p_real, p_fake = float(probs[0]), float(probs[1])
    pred = "Fake" if p_fake >= threshold else "Real"

    print("\n--- Single-image prediction ---")
    print(f"Path       : {img_path}")
    print(f"Pred label : {pred}  (threshold={threshold:.2f})")
    print(f"Prob(Fake) : {p_fake:.4f}")
    print(f"Prob(Real) : {p_real:.4f}")
    print(f"Latency    : {dt_ms:.1f} ms")

    plt.figure(figsize=(5,5)); plt.imshow(img); plt.axis("off")
    plt.title(f"{pred}  |  P(Fake)={p_fake:.3f}")
    plt.show()
    return pred, p_fake

# =========================
# Run
# =========================
if __name__ == "__main__":
    main()
    # Example:
    # single_image_test(r"C:\Users\Gan\Pictures\example.jpg", threshold=0.5)

DEEPFAKE DETECTION TRAINING — FREEZE → UNFREEZE FINE-TUNING
✓ Checkpoint path: C:\Users\Gan\AI Testing\checkpoint
✓ Result path    : C:\Users\Gan\AI Testing\EfficientB3_result_graph

System Info:
  Device: cuda
  GPU: NVIDIA GeForce RTX 4050 Laptop GPU
  VRAM: 6.00 GB

Model Configuration:
  Model         : efficientnet_b3
  Image Size    : 300
  Batch Size    : 16
  Epochs        : 10
  LR            : 0.0001
  Freeze epochs : 1

PREPARING DATASETS
Using 20% of TRAIN, 100% of VAL
  Loaded 14,000 fake, 14,000 real (total 28,000)
  Loaded 19,641 fake, 19,787 real (total 39,428)

Dataset Summary:
  Train: 28,000 images
  Val  : 39,428 images

INITIALIZING MODEL
Total params     : 11,485,226
[Phase 1] Backbone FROZEN for 1 epoch(s). Trainable params: 788,994

STARTING TRAINING

EPOCH 1/10  [Phase 1: frozen]
LR group(s): [1.200000000000002e-05]


Training: 100%|██████████████████████████| 1750/1750 [08:14<00:00,  3.54it/s, Loss=0.558, Acc=0.707]
Validation: 100%|████████████████████████| 1233/1233 [08:52<00:00,  2.32it/s, Loss=0.016, Acc=0.760]



Results:
  Train Loss: 0.5579 | Train Acc: 0.707
  Val   Loss: 0.5181 | Val Acc: 0.760 | F1: 0.743 | AUC: 0.846
  Time: 1027.2s
  ✓ Saved best model (F1: 0.743)

[Phase 2] Backbone UNFROZEN for 9 epoch(s). Trainable params: 11,485,226

EPOCH 2/10  [Phase 2: fine-tune]
LR group(s): [1.1999999999999987e-06, 1.200000000000002e-05]


Training: 100%|██████████████████████████| 1750/1750 [10:43<00:00,  2.72it/s, Loss=0.440, Acc=0.793]
Validation: 100%|████████████████████████| 1233/1233 [10:03<00:00,  2.04it/s, Loss=0.018, Acc=0.814]



Results:
  Train Loss: 0.4405 | Train Acc: 0.793
  Val   Loss: 0.5749 | Val Acc: 0.814 | F1: 0.798 | AUC: 0.907
  Time: 1246.7s
  ✓ Saved best model (F1: 0.798)

EPOCH 3/10  [Phase 2: fine-tune]
LR group(s): [1.2000286828724482e-06, 1.200028682872453e-05]


Training: 100%|██████████████████████████| 1750/1750 [10:46<00:00,  2.71it/s, Loss=0.340, Acc=0.850]
Validation: 100%|████████████████████████| 1233/1233 [05:14<00:00,  3.92it/s, Loss=0.019, Acc=0.853]



Results:
  Train Loss: 0.3404 | Train Acc: 0.850
  Val   Loss: 0.5932 | Val Acc: 0.853 | F1: 0.852 | AUC: 0.933
  Time: 960.7s
  ✓ Saved best model (F1: 0.852)

EPOCH 4/10  [Phase 2: fine-tune]
LR group(s): [1.2001147313755184e-06, 1.2001147313755198e-05]


Training: 100%|██████████████████████████| 1750/1750 [07:22<00:00,  3.96it/s, Loss=0.276, Acc=0.884]
Validation: 100%|████████████████████████| 1233/1233 [04:42<00:00,  4.36it/s, Loss=0.015, Acc=0.860]



Results:
  Train Loss: 0.2760 | Train Acc: 0.884
  Val   Loss: 0.4810 | Val Acc: 0.860 | F1: 0.849 | AUC: 0.950
  Time: 725.1s

EPOCH 5/10  [Phase 2: fine-tune]
LR group(s): [1.2002581451664253e-06, 1.2002581451664267e-05]


Training: 100%|██████████████████████████| 1750/1750 [07:35<00:00,  3.84it/s, Loss=0.238, Acc=0.904]
Validation: 100%|████████████████████████| 1233/1233 [04:42<00:00,  4.36it/s, Loss=0.015, Acc=0.892]



Results:
  Train Loss: 0.2379 | Train Acc: 0.904
  Val   Loss: 0.4785 | Val Acc: 0.892 | F1: 0.891 | AUC: 0.961
  Time: 738.6s
  ✓ Saved best model (F1: 0.891)

EPOCH 6/10  [Phase 2: fine-tune]
LR group(s): [1.2004589236738452e-06, 1.2004589236738479e-05]


Training: 100%|██████████████████████████| 1750/1750 [07:18<00:00,  3.99it/s, Loss=0.213, Acc=0.915]
Validation: 100%|████████████████████████| 1233/1233 [04:41<00:00,  4.38it/s, Loss=0.048, Acc=0.893]



Results:
  Train Loss: 0.2125 | Train Acc: 0.915
  Val   Loss: 1.5335 | Val Acc: 0.893 | F1: 0.887 | AUC: 0.967
  Time: 720.5s

EPOCH 7/10  [Phase 2: fine-tune]
LR group(s): [1.2007170660979315e-06, 1.2007170660979349e-05]


Training: 100%|██████████████████████████| 1750/1750 [07:21<00:00,  3.97it/s, Loss=0.198, Acc=0.925]
Validation: 100%|████████████████████████| 1233/1233 [04:41<00:00,  4.38it/s, Loss=0.015, Acc=0.902]



Results:
  Train Loss: 0.1983 | Train Acc: 0.925
  Val   Loss: 0.4704 | Val Acc: 0.902 | F1: 0.898 | AUC: 0.971
  Time: 722.9s
  ✓ Saved best model (F1: 0.898)

EPOCH 8/10  [Phase 2: fine-tune]
LR group(s): [1.2010325714103084e-06, 1.2010325714103111e-05]


Training: 100%|██████████████████████████| 1750/1750 [07:20<00:00,  3.97it/s, Loss=0.181, Acc=0.932]
Validation: 100%|████████████████████████| 1233/1233 [04:42<00:00,  4.37it/s, Loss=0.028, Acc=0.911]



Results:
  Train Loss: 0.1805 | Train Acc: 0.932
  Val   Loss: 0.8835 | Val Acc: 0.911 | F1: 0.907 | AUC: 0.976
  Time: 722.8s
  ✓ Saved best model (F1: 0.907)

EPOCH 9/10  [Phase 2: fine-tune]
LR group(s): [1.2014054383540908e-06, 1.2014054383540935e-05]


Training: 100%|██████████████████████████| 1750/1750 [07:21<00:00,  3.97it/s, Loss=0.177, Acc=0.933]
Validation: 100%|████████████████████████| 1233/1233 [04:42<00:00,  4.36it/s, Loss=0.007, Acc=0.913]



Results:
  Train Loss: 0.1769 | Train Acc: 0.933
  Val   Loss: 0.2262 | Val Acc: 0.913 | F1: 0.909 | AUC: 0.979
  Time: 724.2s
  ✓ Saved best model (F1: 0.909)

EPOCH 10/10  [Phase 2: fine-tune]
LR group(s): [1.2018356654438742e-06, 1.2018356654438817e-05]


Training: 100%|██████████████████████████| 1750/1750 [07:22<00:00,  3.96it/s, Loss=0.165, Acc=0.942]
Validation: 100%|████████████████████████| 1233/1233 [04:43<00:00,  4.35it/s, Loss=0.023, Acc=0.906]



Results:
  Train Loss: 0.1649 | Train Acc: 0.942
  Val   Loss: 0.7472 | Val Acc: 0.906 | F1: 0.899 | AUC: 0.982
  Time: 725.9s

TRAINING COMPLETE - SAVING CURVES & FULL EVAL


Full Eval: 100%|████████████████████████████████████████████████| 1233/1233 [04:33<00:00,  4.50it/s]



Full evaluation on best checkpoint:
{'acc': 0.9133357005173988, 'prec': 0.9559865092748735, 'rec': 0.8658927753169391, 'f1': 0.9087120301354492, 'auc': 0.9787630943032425, 'prec_macro': 0.9171281099845113, 'rec_macro': 0.9131606697628816, 'f1_macro': 0.9131128044461193, 'prec_weight': 0.9169842191887216, 'rec_weight': 0.9133357005173988, 'f1_weight': 0.9131291003031083}

✓ Best model saved to: C:\Users\Gan\AI Testing\checkpoint
✓ Graphs & CSV saved to: C:\Users\Gan\AI Testing\EfficientB3_result_graph
DONE
