# GenD: Generalized Deepfake Detection — Self-Contained Training Notebook

## Cách sử dụng trên Kaggle:
1. **Upload dataset** (ảnh frames đã crop mặt) lên Kaggle Dataset
2. **Upload file JSON** mô tả dataset (xem hướng dẫn ở cell Config)
3. **Bật GPU** (Settings > Accelerator > GPU T4 x2)
4. **Bật Internet** (Settings > Internet > On) — cần lần đầu để tải CLIP weights
5. **Chạy tất cả cell** từ trên xuống dưới
6. Model sẽ được lưu tại `/kaggle/working/gend_checkpoints/`

## Cell 1 — Cài đặt thư viện

In [34]:
!pip install -q albumentations transformers lmdb

## Cell 2 — Import toàn bộ thư viện

In [35]:
import os
import sys
import math
import glob
import json
import random
import datetime
import logging
import pickle
from copy import deepcopy
from collections import defaultdict

import numpy as np
import cv2
from PIL import Image
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T

import albumentations as A
from albumentations import DualTransform, ImageOnlyTransform

from sklearn import metrics as sk_metrics

from transformers import CLIPModel

print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

PyTorch: 2.8.0+cu126
CUDA available: True
GPU: Tesla T4
Using device: cuda


## Cell 3 — Cấu hình (Thay thế file YAML)

### Hướng dẫn:
- **`dataset_json_folder`**: Thư mục chứa file JSON mô tả dataset.  
  Trên Kaggle: upload file JSON vào một Kaggle Dataset, rồi trỏ đường dẫn tới đó.
- **`rgb_dir`**: Thư mục gốc chứa ảnh frames (nếu đường dẫn trong JSON là tương đối).
- **`train_dataset`**: List tên dataset đúng với key trong file JSON.
- **`test_dataset`**: List tên dataset dùng để test.

### Cấu trúc file JSON mẫu (`MyDataset.json`):
```json
{
  "MyDataset": {
    "real": {
      "train": {
        "video_001": {
          "label": "FF-real",
          "frames": ["/kaggle/input/mydata/real/001/0.png", ...]
        }
      },
      "test": { ... }
    },
    "fake": {
      "train": {
        "video_002": {
          "label": "FF-DF",
          "frames": ["/kaggle/input/mydata/fake/002/0.png", ...]
        }
      },
      "test": { ... }
    }
  }
}
```
Label phải khớp với key trong `label_dict` bên dưới.

In [50]:

config = {
    'dataset_json_folder': '/kaggle/input/deepfakebench/dataset_json',   # <-- Folder chứa file .json
    'rgb_dir':             '/kaggle/input/deepfakebench/data',  # <-- Folder gốc chứa ảnh
    'train_dataset':       ['FaceForensics++'],                  # <-- Tên dataset (key trong JSON)
    'test_dataset':        ['Celeb-DF-v2'],                      # <-- Tên dataset test

    # Model 
    'model_name':    'gend',
    'backbone_name': 'vit',
    'lambda_align':  1.0,    # Hệ số Alignment Loss
    'lambda_unif':   1.0,    # Hệ số Uniformity Loss

    # ---- Dataset ----
    'compression':    'c23',
    'train_batchSize': 16,   # Giảm xuống nếu hết VRAM (Kaggle T4 = 16GB)
    'test_batchSize':  16,
    'workers':         2,    # Kaggle thường 2-4
    'frame_num':       {'train': 8, 'test': 32},
    'resolution':      224,
    'with_mask':       False,
    'with_landmark':   False,
    'lmdb':            False,

    # ---- Data Augmentation ----
    'use_data_augmentation': True,
    'data_aug': {
        'flip_prob':       0.5,
        'rotate_prob':     0.5,
        'rotate_limit':    [-10, 10],
        'blur_prob':       0.5,
        'blur_limit':      [3, 7],
        'brightness_prob': 0.5,
        'brightness_limit':[-0.1, 0.1],
        'contrast_limit':  [-0.1, 0.1],
        'quality_lower':   40,
        'quality_upper':   100,
    },

    # ---- Normalization (CLIP) ----
    'mean': [0.48145466, 0.4578275, 0.40821073],
    'std':  [0.26862954, 0.26130258, 0.27577711],

    # ---- Optimizer ----
    'optimizer': {
        'type': 'adam',
        'adam': {
            'lr':           0.0002,
            'beta1':        0.9,
            'beta2':        0.999,
            'eps':          1e-8,
            'weight_decay': 0.0005,
            'amsgrad':      False,
        },
        'sgd': {
            'lr':           0.0002,
            'momentum':     0.9,
            'weight_decay': 0.0005,
        },
    },

    # ---- Training ----
    'lr_scheduler': None,   # None, 'step', 'cosine', 'linear'
    'nEpochs':      10,
    'start_epoch':  0,
    'save_epoch':   1,
    'rec_iter':     100,    # In log mỗi bao nhiêu iteration
    'manualSeed':   1024,
    'save_ckpt':    True,
    'save_feat':    True,

    # ---- Loss ----
    'loss_func':    'cross_entropy',

    # ---- Metric ----
    'metric_scoring': 'auc',

    # ---- CUDA ----
    'cuda': True,

    # ---- Output ----
    'log_dir': '/kaggle/working/gend_checkpoints',

    # ---- Label Mapping ----
    # Tên label trong JSON -> số (0=real, 1=fake)
    'label_dict': {
        # FF++ & FaceShifter
        'FF-real': 0, 'FF-DF': 1, 'FF-F2F': 1, 'FF-FS': 1, 'FF-NT': 1, 'FF-FH': 1, 'FF-SH': 1,
        # DFD
        'DFD_fake': 1, 'DFD_real': 0,
        # CelebDF
        'CelebDFv1_real': 0, 'CelebDFv1_fake': 1,
        'CelebDFv2_real': 0, 'CelebDFv2_fake': 1,
        # DFDCP
        'DFDCP_Real': 0, 'DFDCP_FakeA': 1, 'DFDCP_FakeB': 1,
        # DFDC
        'DFDC_Fake': 1, 'DFDC_Real': 0,
        # DeeperForensics
        'DF_fake': 1, 'DF_real': 0,
        # UADFV
        'UADFV_Fake': 1, 'UADFV_Real': 0,
        # Roop
        'roop_Real': 0, 'roop_Fake': 1,
    },
}

os.makedirs(config['log_dir'], exist_ok=True)
print("\u2705 Config loaded.")
print(f"   Train dataset: {config['train_dataset']}")
print(f"   Test dataset:  {config['test_dataset']}")
print(f"   Epochs:        {config['nEpochs']}")
print(f"   Batch size:    {config['train_batchSize']}")
print(f"   Resolution:    {config['resolution']}")

✅ Config loaded.
   Train dataset: ['FaceForensics++']
   Test dataset:  ['Celeb-DF-v2']
   Epochs:        10
   Batch size:    16
   Resolution:    224


## Cell 4 — Seed & Logger

In [37]:
# ============================================================
#  Seed để tái lập kết quả
# ============================================================
def init_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = True

init_seed(config['manualSeed'])
print(f"\u2705 Seed set to {config['manualSeed']}")

# ============================================================
#  Simple Logger (thay thế logger.py gốc)
# ============================================================
class NotebookLogger:
    """Logger đơn giản in ra notebook + ghi file."""
    def __init__(self, log_path=None):
        self.log_path = log_path
        if log_path:
            os.makedirs(os.path.dirname(log_path), exist_ok=True)
            self.file = open(log_path, 'a')
        else:
            self.file = None

    def info(self, msg):
        timestamp = datetime.datetime.now().strftime('%H:%M:%S')
        line = f"[{timestamp}] {msg}"
        print(line)
        if self.file:
            self.file.write(line + '\n')
            self.file.flush()

    def warning(self, msg):
        self.info(f"\u26a0\ufe0f {msg}")

    def error(self, msg):
        self.info(f"\u274c {msg}")

timenow = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
log_dir = os.path.join(config['log_dir'], f"{config['model_name']}_{timenow}")
os.makedirs(log_dir, exist_ok=True)
logger = NotebookLogger(os.path.join(log_dir, 'training.log'))
logger.info(f"Log directory: {log_dir}")

✅ Seed set to 1024
[02:56:01] Log directory: /kaggle/working/gend_checkpoints/gend_2026-02-07-02-56-01


## Cell 5 — Augmentation Helpers

Sao chép từ `training/dataset/albu.py`

In [None]:
# ============================================================
#  Nguồn: training/dataset/albu.py
# ============================================================

def isotropically_resize_image(img, size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC):
    h, w = img.shape[:2]
    if max(w, h) == size:
        return img
    if w > h:
        scale = size / w
        h = h * scale
        w = size
    else:
        scale = size / h
        w = w * scale
        h = size
    interpolation = interpolation_up if scale > 1 else interpolation_down
    resized = cv2.resize(img, (int(w), int(h)), interpolation=interpolation)
    return resized


class IsotropicResize(DualTransform):
    def __init__(self, max_side, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC,
                 always_apply=False, p=1):
        # FIX: Use kwargs explicitly (Albumentations updated __init__ signature order in newer versions)
        super(IsotropicResize, self).__init__(p=p, always_apply=always_apply)
        self.max_side = max_side
        self.interpolation_down = interpolation_down
        self.interpolation_up = interpolation_up

    def apply(self, img, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC, **params):
        return isotropically_resize_image(img, size=self.max_side, interpolation_down=interpolation_down,
                                          interpolation_up=interpolation_up)

    def apply_to_mask(self, img, **params):
        return self.apply(img, interpolation_down=cv2.INTER_NEAREST, interpolation_up=cv2.INTER_NEAREST, **params)

    def get_transform_init_args_names(self):
        return ("max_side", "interpolation_down", "interpolation_up")


print("\u2705 Augmentation helpers loaded.")

✅ Augmentation helpers loaded.


## Cell 6 — Dataset Class

Sao chép từ `training/dataset/abstract_dataset.py`.  
- Đã bỏ `from .albu import IsotropicResize` (dùng class ở cell trên).  
- Đã bỏ LMDB-related code cho gọn (giữ `lmdb=False` trong config).

In [None]:
# ============================================================
#  Nguồn: training/dataset/abstract_dataset.py
#  Đã chỉnh sửa: bỏ relative import, bỏ LMDB, inline IsotropicResize
#  FIX: __getitem__ trả về 4 giá trị, collate_fn khớp với project gốc
# ============================================================

FFpp_pool = ['FaceForensics++', 'FaceShifter', 'DeepFakeDetection', 'FF-DF', 'FF-F2F', 'FF-FS', 'FF-NT']


class DeepfakeAbstractBaseDataset(Dataset):
    """
    Dataset class dùng cho training GenD.
    Đọc ảnh từ danh sách frame trong file JSON.
    """
    def __init__(self, config=None, mode='train'):
        self.config = config
        self.mode = mode
        self.compression = config['compression']
        self.frame_num = config['frame_num'][mode]
        self.image_list = []
        self.label_list = []

        if mode == 'train':
            dataset_list = config['train_dataset']
            image_list, label_list = [], []
            for one_data in dataset_list:
                tmp_image, tmp_label, tmp_name = self.collect_img_and_label_for_one_dataset(one_data)
                image_list.extend(tmp_image)
                label_list.extend(tmp_label)
        elif mode == 'test':
            one_data = config['test_dataset']
            image_list, label_list, _ = self.collect_img_and_label_for_one_dataset(one_data)
        else:
            raise NotImplementedError('Only train and test modes are supported.')

        assert len(image_list) != 0, f"Collect nothing for {mode} mode!"
        self.image_list, self.label_list = image_list, label_list

        self.data_dict = {
            'image': self.image_list,
            'label': self.label_list,
        }

        self.transform = self.init_data_aug_method()

    def init_data_aug_method(self):
        trans = A.Compose([
            A.HorizontalFlip(p=self.config['data_aug']['flip_prob']),
            A.Rotate(limit=self.config['data_aug']['rotate_limit'], p=self.config['data_aug']['rotate_prob']),
            A.GaussianBlur(blur_limit=self.config['data_aug']['blur_limit'], p=self.config['data_aug']['blur_prob']),
            A.OneOf([
                IsotropicResize(max_side=self.config['resolution'], interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
                IsotropicResize(max_side=self.config['resolution'], interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_LINEAR),
                IsotropicResize(max_side=self.config['resolution'], interpolation_down=cv2.INTER_LINEAR, interpolation_up=cv2.INTER_LINEAR),
            ], p=1),
            A.OneOf([
                A.RandomBrightnessContrast(
                    brightness_limit=self.config['data_aug']['brightness_limit'],
                    contrast_limit=self.config['data_aug']['contrast_limit']),
                A.FancyPCA(),
                A.HueSaturationValue()
            ], p=0.5),
            A.ImageCompression(
                quality_lower=self.config['data_aug']['quality_lower'],
                quality_upper=self.config['data_aug']['quality_upper'], p=0.5),
        ])
        return trans

    def collect_img_and_label_for_one_dataset(self, dataset_name: str):
        label_list = []
        frame_path_list = []
        video_name_list = []

        json_path = os.path.join(self.config['dataset_json_folder'], dataset_name + '.json')
        try:
            with open(json_path, 'r') as f:
                dataset_info = json.load(f)
        except Exception as e:
            raise ValueError(f'Dataset JSON not found: {json_path}. Error: {e}')

        # Xử lý _c40 variant (giống gốc)
        cp = None
        original_name = dataset_name
        if dataset_name == 'FaceForensics++_c40':
            original_name = 'FaceForensics++'; cp = 'c40'
        elif dataset_name == 'FF-DF_c40':
            original_name = 'FF-DF'; cp = 'c40'
        elif dataset_name == 'FF-F2F_c40':
            original_name = 'FF-F2F'; cp = 'c40'
        elif dataset_name == 'FF-FS_c40':
            original_name = 'FF-FS'; cp = 'c40'
        elif dataset_name == 'FF-NT_c40':
            original_name = 'FF-NT'; cp = 'c40'

        for label_key in dataset_info[original_name]:
            sub_dataset_info = dataset_info[original_name][label_key][self.mode]

            # Xử lý compression cho FF++ family (giống gốc)
            ff_family = ['FF-DF', 'FF-F2F', 'FF-FS', 'FF-NT', 'FaceForensics++', 'DeepFakeDetection', 'FaceShifter']
            if cp is None and original_name in ff_family:
                sub_dataset_info = sub_dataset_info[self.compression]
            elif cp == 'c40' and original_name in ff_family:
                sub_dataset_info = sub_dataset_info['c40']

            for video_name, video_info in sub_dataset_info.items():
                unique_video_name = video_info['label'] + '_' + video_name

                if video_info['label'] not in self.config['label_dict']:
                    raise ValueError(f"Label '{video_info['label']}' not found in config['label_dict'].")
                label = self.config['label_dict'][video_info['label']]
                frame_paths = video_info['frames']

                # Sắp xếp frame theo số (giống gốc)
                if len(frame_paths) > 0:
                    if '\\' in frame_paths[0]:
                        frame_paths = sorted(frame_paths, key=lambda x: int(x.split('\\')[-1].split('.')[0]))
                    else:
                        frame_paths = sorted(frame_paths, key=lambda x: int(x.split('/')[-1].split('.')[0]))

                # Lấy số frame cần thiết (giống gốc)
                total_frames = len(frame_paths)
                if self.frame_num < total_frames:
                    total_frames = self.frame_num
                    step = total_frames // self.frame_num
                    frame_paths = [frame_paths[i] for i in range(0, total_frames, step)][:self.frame_num]

                label_list.extend([label] * total_frames)
                frame_path_list.extend(frame_paths)
                video_name_list.extend([unique_video_name] * total_frames)

        # Shuffle
        combined = list(zip(label_list, frame_path_list, video_name_list))
        random.shuffle(combined)
        label_list, frame_path_list, video_name_list = zip(*combined) if combined else ([], [], [])

        return list(frame_path_list), list(label_list), list(video_name_list)

    def load_rgb(self, file_path):
        size = self.config['resolution']
        # Nếu đường dẫn tương đối, gắn với rgb_dir (giống gốc)
        if not file_path.startswith('/'):
            if file_path.startswith('./'):
                file_path = file_path[2:]
            file_path = os.path.join(self.config['rgb_dir'], file_path)

        file_path = file_path.replace('\\', '/')

        if not os.path.exists(file_path):
            raise ValueError(f"File not found: {file_path}")

        img = cv2.imread(file_path)
        if img is None:
            raise ValueError(f'Loaded image is None: {file_path}')

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
        return Image.fromarray(np.array(img, dtype=np.uint8))

    def to_tensor(self, img):
        return T.ToTensor()(img)

    def normalize(self, img):
        mean = self.config['mean']
        std = self.config['std']
        return T.Normalize(mean=mean, std=std)(img)

    def data_aug(self, img, landmark=None, mask=None, augmentation_seed=None):
        """Augmentation — giống gốc abstract_dataset.py (landmark/mask=None cho GenD)."""
        if augmentation_seed is not None:
            random.seed(augmentation_seed)
            np.random.seed(augmentation_seed)

        kwargs = {'image': img}
        # GenD không dùng landmark/mask, nhưng giữ interface cho tương thích
        if landmark is not None:
            kwargs['keypoints'] = landmark
            kwargs['keypoint_params'] = A.KeypointParams(format='xy')
        if mask is not None:
            mask_sq = mask.squeeze(2)
            if mask_sq.max() > 0:
                kwargs['mask'] = mask_sq

        transformed = self.transform(**kwargs)
        augmented_img = transformed['image']
        augmented_landmark = transformed.get('keypoints')
        augmented_mask = transformed.get('mask', mask)

        if augmented_landmark is not None:
            augmented_landmark = np.array(augmented_landmark)
        if augmentation_seed is not None:
            random.seed()
            np.random.seed()

        return augmented_img, augmented_landmark, augmented_mask

    def __getitem__(self, index):
        """Trả về 4 giá trị (image, label, landmark, mask) — giống gốc."""
        image_path = self.data_dict['image'][index]
        label = self.data_dict['label'][index]

        try:
            image = self.load_rgb(image_path)
        except Exception as e:
            print(f"[ERROR] Failed to load index {index}: {e}")
            if index == 0:
                raise e
            return self.__getitem__(0)

        image = np.array(image)
        mask = None
        landmarks = None

        # Augmentation (chỉ khi train) — giống gốc
        if self.mode == 'train' and self.config['use_data_augmentation']:
            image_trans, landmarks_trans, mask_trans = self.data_aug(image, landmarks, mask)
        else:
            image_trans = deepcopy(image)
            landmarks_trans = deepcopy(landmarks)
            mask_trans = deepcopy(mask)

        # To tensor & normalize
        image_trans = self.normalize(self.to_tensor(image_trans))

        return image_trans, label, landmarks_trans, mask_trans

    @staticmethod
    def collate_fn(batch):
        """Collate — giống gốc abstract_dataset.py (unpack 4 giá trị)."""
        images, labels, landmarks, masks = zip(*batch)
        images = torch.stack(images, dim=0)
        labels = torch.LongTensor(labels)

        if not any(landmark is None or (isinstance(landmark, list) and None in landmark) for landmark in landmarks):
            landmarks = torch.stack(landmarks, dim=0)
        else:
            landmarks = None

        if not any(m is None or (isinstance(m, list) and None in m) for m in masks):
            masks = torch.stack(masks, dim=0)
        else:
            masks = None

        data_dict = {
            'image': images,
            'label': labels,
            'landmark': landmarks,
            'mask': masks,
        }
        return data_dict

    def __len__(self):
        return len(self.image_list)


print("\u2705 Dataset class loaded (matching original project).")

✅ Dataset class loaded.


## Cell 7 — Metrics

Sao chép từ `training/metrics/base_metrics_class.py` và `training/metrics/utils.py`

In [None]:
# ============================================================
#  Nguồn: training/metrics/base_metrics_class.py + training/metrics/utils.py
#  FIX: Recorder dùng sum/num (giống gốc), calculate_metrics_for_train dùng .squeeze(),
#       get_test_metrics dùng get_video_metrics gốc, thêm parse_metric_for_print
# ============================================================

class Recorder:
    """Recorder giống hệt project gốc (sum/num thay vì list)."""
    def __init__(self):
        self.sum = 0
        self.num = 0

    def update(self, item, num=1):
        if item is not None:
            if isinstance(item, torch.Tensor):
                item = item.item()
            self.sum += item * num
            self.num += num

    def average(self):
        if self.num == 0:
            return None
        return self.sum / self.num

    def clear(self):
        self.sum = 0
        self.num = 0


def calculate_metrics_for_train(label, output):
    """Tính AUC, EER, ACC, AP — khớp hệt file gốc base_metrics_class.py."""
    if output.size(1) == 2:
        prob = torch.softmax(output, dim=1)[:, 1]
    else:
        prob = output

    # Accuracy
    _, prediction = torch.max(output, 1)
    correct = (prediction == label).sum().item()
    accuracy = correct / prediction.size(0)

    # Average Precision
    y_true = label.cpu().detach().numpy()
    y_pred = prob.cpu().detach().numpy()
    ap = sk_metrics.average_precision_score(y_true, y_pred)

    # AUC and EER (giống gốc: dùng .squeeze())
    try:
        fpr, tpr, thresholds = sk_metrics.roc_curve(
            label.squeeze().cpu().numpy(),
            prob.squeeze().cpu().numpy(),
            pos_label=1
        )
    except:
        # for the case when we only have one sample
        return None, None, accuracy, ap

    if np.isnan(fpr[0]) or np.isnan(tpr[0]):
        # for the case when all the samples within a batch is fake/real
        auc, eer = None, None
    else:
        auc = sk_metrics.auc(fpr, tpr)
        fnr = 1 - tpr
        eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]

    return auc, eer, accuracy, ap


def get_test_metrics(y_pred, y_true, img_names):
    """Tính metric toàn cục — khớp hệt file gốc metrics/utils.py (bao gồm video-level AUC)."""

    def get_video_metrics(image, pred, label):
        """Video-level AUC — khớp hệt hàm gốc trong metrics/utils.py."""
        result_dict = {}
        new_label = []
        new_pred = []

        for item in np.transpose(np.stack((image, pred, label)), (1, 0)):
            s = item[0]
            if '\\' in s:
                parts = s.split('\\')
            else:
                parts = s.split('/')
            a = parts[-2]  # tên video folder (parent directory)

            if a not in result_dict:
                result_dict[a] = []
            result_dict[a].append(item)

        image_arr = list(result_dict.values())
        for video in image_arr:
            pred_sum = 0
            label_sum = 0
            leng = 0
            for frame in video:
                pred_sum += float(frame[1])
                label_sum += int(frame[2])
                leng += 1
            new_pred.append(pred_sum / leng)
            new_label.append(int(label_sum / leng))

        fpr, tpr, thresholds = sk_metrics.roc_curve(new_label, new_pred)
        v_auc = sk_metrics.auc(fpr, tpr)
        fnr = 1 - tpr
        v_eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
        return v_auc, v_eer

    y_pred = y_pred.squeeze()
    # For UCF, where labels for different manipulations are not consistent.
    y_true[y_true >= 1] = 1

    # AUC
    fpr, tpr, thresholds = sk_metrics.roc_curve(y_true, y_pred, pos_label=1)
    auc = sk_metrics.auc(fpr, tpr)
    # EER
    fnr = 1 - tpr
    eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
    # AP
    ap = sk_metrics.average_precision_score(y_true, y_pred)
    # ACC
    prediction_class = (y_pred > 0.5).astype(int)
    correct = (prediction_class == np.clip(y_true, a_min=0, a_max=1)).sum().item()
    acc = correct / len(prediction_class)

    # Video-level AUC (giống gốc: dùng get_video_metrics)
    if type(img_names[0]) is not list:
        # frame-level methods -> tính video-level AUC
        v_auc, _ = get_video_metrics(img_names, y_pred, y_true)
    else:
        # video-level methods
        v_auc = auc

    return {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap, 'pred': y_pred, 'video_auc': v_auc, 'label': y_true}


def parse_metric_for_print(metric_dict):
    """Format best metrics để in log — giống hệt file gốc metrics/utils.py."""
    if metric_dict is None:
        return "\n"
    s = "\n"
    s += "================================ Each dataset best metric ================================ \n"
    for key, value in metric_dict.items():
        if key != 'avg':
            s += f"| {key}: "
            for k, v in value.items():
                s += f" {k}={v} "
            s += "| \n"
        else:
            s += "============================================================================================= \n"
            s += "================================== Average best metric ====================================== \n"
            avg_dict = value
            for avg_key, avg_value in avg_dict.items():
                if avg_key == 'dataset_dict':
                    for dk, dv in avg_value.items():
                        s += f"| {dk}: {dv} | \n"
                else:
                    s += f"| avg {avg_key}: {avg_value} | \n"
    s += "============================================================================================="
    return s


print("\u2705 Metrics loaded (matching original project).")

✅ Metrics loaded.


## Cell 8 — Loss Functions (Alignment & Uniformity)

Sao chép từ `training/detectors/gend_detector.py`

In [42]:
# ============================================================
#  Nguồn: training/detectors/gend_detector.py (phần loss)
# ============================================================

def alignment_loss(embeddings, labels, alpha=2):
    """
    Label-aware Alignment loss: Kéo các mẫu cùng nhãn lại gần nhau.
    """
    if embeddings.size(0) < 2:
        return torch.tensor(0.0, device=embeddings.device)

    labels_equal_mask = (labels[:, None] == labels[None, :]).triu(diagonal=1)
    positive_indices = torch.nonzero(labels_equal_mask, as_tuple=False)

    if positive_indices.numel() == 0:
        return torch.tensor(0.0, device=embeddings.device)

    x = embeddings[positive_indices[:, 0]]
    y = embeddings[positive_indices[:, 1]]

    return (x - y).norm(p=2, dim=1).pow(alpha).mean()


def uniformity_loss(x, t=2, clip_value=1e-6):
    """
    Uniformity loss: Đẩy các mẫu phân bố đều trên mặt cầu.
    """
    if x.size(0) < 2:
        return torch.tensor(0.0, device=x.device)
    return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().clamp(min=clip_value).log()


print("\u2705 Loss functions loaded.")

✅ Loss functions loaded.


## Cell 9 — Model GenD (CLIP ViT-L/14 + LN-Tuning)

Sao chép từ `training/detectors/gend_detector.py`.  
- Đã bỏ `@DETECTOR.register_module` decorator.  
- Đã bỏ `from detectors import DETECTOR`.

In [43]:
# ============================================================
#  Nguồn: training/detectors/gend_detector.py
#  Đã chỉnh sửa: bỏ DETECTOR registry, inline loss imports
# ============================================================

class GenDDetector(nn.Module):
    def __init__(self, config=None):
        super(GenDDetector, self).__init__()
        self.config = config

        # 1. Load Backbone (CLIP ViT-L/14)
        print("Loading CLIP ViT-L/14 for GenD...")
        self.backbone = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").vision_model

        # 2. Setup LN-Tuning (Chỉ train LayerNorm)
        self._setup_training_params()

        # 3. Classifier Head
        self.head = nn.Linear(1024, 2)

        # 4. Loss Weights
        self.lambda_align = config.get('lambda_align', 1.0) if config else 1.0
        self.lambda_unif = config.get('lambda_unif', 1.0) if config else 1.0

        self.loss_ce = nn.CrossEntropyLoss()

    def _setup_training_params(self):
        """
        Đóng băng toàn bộ backbone, CHỈ mở khóa các lớp LayerNorm.
        """
        for param in self.backbone.parameters():
            param.requires_grad = False

        trainable_params = 0
        for name, param in self.backbone.named_parameters():
            if 'layer_norm' in name or 'layernorm' in name:
                param.requires_grad = True
                trainable_params += param.numel()

        print(f"GenD Initialized. Trainable backbone params (LayerNorm): {trainable_params:,}")
        print(f"Classifier head params: {1024 * 2 + 2:,}")

    def features(self, data_dict: dict) -> torch.Tensor:
        outputs = self.backbone(data_dict['image'])
        feat = outputs.pooler_output  # [B, 1024]
        feat = F.normalize(feat, p=2, dim=1)
        return feat

    def classifier(self, features: torch.Tensor) -> torch.Tensor:
        return self.head(features)

    def forward(self, data_dict: dict, inference=False) -> dict:
        features = self.features(data_dict)
        pred = self.classifier(features)
        prob = torch.softmax(pred, dim=1)[:, 1]
        return {'cls': pred, 'prob': prob, 'feat': features}

    def get_losses(self, data_dict: dict, pred_dict: dict) -> dict:
        label = data_dict['label']
        pred = pred_dict['cls']
        features = pred_dict['feat']

        loss_cls = self.loss_ce(pred, label)

        loss_align = torch.tensor(0.0, device=pred.device)
        loss_unif = torch.tensor(0.0, device=pred.device)

        if self.training:
            loss_align = alignment_loss(features, label)
            loss_unif = uniformity_loss(features)

        overall = loss_cls + (self.lambda_align * loss_align) + (self.lambda_unif * loss_unif)

        return {
            'overall': overall,
            'ce_loss': loss_cls,
            'align_loss': loss_align,
            'unif_loss': loss_unif,
        }

    def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict:
        label = data_dict['label']
        pred = pred_dict['cls']
        auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach())
        return {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap}


print("\u2705 GenDDetector class loaded.")

✅ GenDDetector class loaded.


## Cell 10 — SAM Optimizer & Scheduler

Sao chép từ `training/optimizor/SAM.py` và `training/optimizor/LinearLR.py`

In [44]:
# ============================================================
#  Nguồn: training/optimizor/SAM.py
# ============================================================

def disable_running_stats(model):
    def _disable(module):
        if isinstance(module, nn.BatchNorm2d):
            module.backup_momentum = module.momentum
            module.momentum = 0
    model.apply(_disable)

def enable_running_stats(model):
    def _enable(module):
        if isinstance(module, nn.BatchNorm2d) and hasattr(module, "backup_momentum"):
            module.momentum = module.backup_momentum
    model.apply(_enable)


class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        assert rho >= 0.0, f"Invalid rho: {rho}"
        defaults = dict(rho=rho, **kwargs)
        super(SAM, self).__init__(params, defaults)
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)
            for p in group["params"]:
                if p.grad is None: continue
                e_w = p.grad * scale.to(p)
                p.add_(e_w)
                self.state[p]["e_w"] = e_w
        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.sub_(self.state[p]["e_w"])
        self.base_optimizer.step()
        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        assert closure is not None, "SAM requires closure"
        closure = torch.enable_grad()(closure)
        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        shared_device = self.param_groups[0]["params"][0].device
        norm = torch.norm(
            torch.stack([
                p.grad.norm(p=2).to(shared_device)
                for group in self.param_groups for p in group["params"]
                if p.grad is not None
            ]), p=2)
        return norm


# ============================================================
#  Nguồn: training/optimizor/LinearLR.py
# ============================================================
from torch.optim.lr_scheduler import _LRScheduler

class LinearDecayLR(_LRScheduler):
    def __init__(self, optimizer, n_epoch, start_decay, last_epoch=-1):
        self.start_decay = start_decay
        self.n_epoch = n_epoch
        super(LinearDecayLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        last_epoch = self.last_epoch
        n_epoch = self.n_epoch
        b_lr = self.base_lrs[0]
        start_decay = self.start_decay
        if last_epoch > start_decay:
            lr = b_lr - b_lr / (n_epoch - start_decay) * (last_epoch - start_decay)
        else:
            lr = b_lr
        return [lr]


print("\u2705 SAM Optimizer & LinearDecayLR loaded.")

✅ SAM Optimizer & LinearDecayLR loaded.


## Cell 11 — Helper Functions (Optimizer, Scheduler, DataLoader)

Sao chép logic từ `training/train.py`

In [None]:
# ============================================================
#  Nguồn: training/train.py (các hàm helper)
#  FIX: choose_optimizer dùng model.parameters() (giống gốc, không filter),
#       prepare_testing_data thêm drop_last cho DeepFakeDetection
# ============================================================

def choose_optimizer(model, config):
    """Chọn optimizer — giống hệt train.py gốc (dùng model.parameters() không filter)."""
    opt_name = config['optimizer']['type']

    if opt_name == 'sgd':
        return optim.SGD(
            params=model.parameters(),
            lr=config['optimizer'][opt_name]['lr'],
            momentum=config['optimizer'][opt_name]['momentum'],
            weight_decay=config['optimizer'][opt_name]['weight_decay'],
        )
    elif opt_name == 'adam':
        return optim.Adam(
            params=model.parameters(),
            lr=config['optimizer'][opt_name]['lr'],
            weight_decay=config['optimizer'][opt_name]['weight_decay'],
            betas=(config['optimizer'][opt_name]['beta1'], config['optimizer'][opt_name]['beta2']),
            eps=config['optimizer'][opt_name]['eps'],
            amsgrad=config['optimizer'][opt_name]['amsgrad'],
        )
    elif opt_name == 'sam':
        return SAM(
            model.parameters(),
            optim.SGD,
            lr=config['optimizer'][opt_name]['lr'],
            momentum=config['optimizer'][opt_name]['momentum'],
        )
    else:
        raise NotImplementedError(f'Optimizer {opt_name} is not implemented')


def choose_scheduler(config, optimizer):
    """Chọn scheduler — giống hệt train.py gốc."""
    if config['lr_scheduler'] is None:
        return None
    elif config['lr_scheduler'] == 'step':
        return optim.lr_scheduler.StepLR(optimizer, step_size=config['lr_step'], gamma=config['lr_gamma'])
    elif config['lr_scheduler'] == 'cosine':
        return optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['lr_T_max'], eta_min=config['lr_eta_min'])
    elif config['lr_scheduler'] == 'linear':
        return LinearDecayLR(optimizer, config['nEpochs'], int(config['nEpochs'] / 4))
    else:
        raise NotImplementedError(f'Scheduler {config["lr_scheduler"]} not implemented')


def prepare_training_data(config):
    """Tạo train DataLoader — giống hệt train.py gốc (branch DeepfakeAbstractBaseDataset)."""
    train_set = DeepfakeAbstractBaseDataset(config=config, mode='train')
    train_data_loader = DataLoader(
        dataset=train_set,
        batch_size=config['train_batchSize'],
        shuffle=True,
        num_workers=int(config['workers']),
        collate_fn=train_set.collate_fn,
    )
    return train_data_loader


def prepare_testing_data(config):
    """Tạo test DataLoaders — giống hệt train.py gốc (bao gồm drop_last cho DeepFakeDetection)."""
    def get_test_data_loader(config, test_name):
        test_cfg = config.copy()
        test_cfg['test_dataset'] = test_name
        test_set = DeepfakeAbstractBaseDataset(config=test_cfg, mode='test')
        test_data_loader = DataLoader(
            dataset=test_set,
            batch_size=config['test_batchSize'],
            shuffle=False,
            num_workers=int(config['workers']),
            collate_fn=test_set.collate_fn,
            drop_last=(test_name == 'DeepFakeDetection'),  # giống gốc
        )
        return test_data_loader

    test_data_loaders = {}
    for one_test_name in config['test_dataset']:
        test_data_loaders[one_test_name] = get_test_data_loader(config, one_test_name)
    return test_data_loaders


print("\u2705 Helper functions loaded (matching original project).")

✅ Helper functions loaded.


## Cell 12 — Training Loop & Validation

Viết lại từ `training/trainer/trainer.py` — khớp 100% pipeline gốc:
- **Log format**: `Iter: {step_cnt}  training-loss, {k}: {v}` (giống gốc, mỗi 300 iteration)
- **Mid-epoch testing**: Test 2 lần/epoch (epoch ≥ 1), 1 lần (epoch 0)
- **Test logging**: Gồm test loss, test metrics, video_auc, acc_real/acc_fake
- **Best metric**: Track per-dataset, save checkpoint per test dataset
- **train_step**: Hỗ trợ cả Adam (1-step) và SAM (2-step)

In [None]:
# ============================================================
#  Training Loop (viết lại từ trainer/trainer.py, khớp 100% pipeline gốc)
#  FIX: Log format giống gốc (Iter: step_cnt, training-loss/metric),
#       mid-epoch testing, test loss logging, acc_real/acc_fake,
#       best metric tracking per dataset, 300-iteration log interval
# ============================================================

FFpp_pool_ckpt = ['FaceForensics++', 'FF-DF', 'FF-F2F', 'FF-FS', 'FF-NT']


def get_respect_acc(prob, label):
    """Tính accuracy riêng cho real và fake — giống hệt trainer gốc."""
    pred = np.where(prob > 0.5, 1, 0)
    judge = (pred == label)
    real_idx = np.where(label == 0)[0]
    fake_idx = np.where(label == 1)[0]
    acc_real = np.count_nonzero(judge[real_idx]) / len(real_idx) if len(real_idx) > 0 else 0.0
    acc_fake = np.count_nonzero(judge[fake_idx]) / len(fake_idx) if len(fake_idx) > 0 else 0.0
    return acc_real, acc_fake


def train_step(model, data_dict, optimizer, config):
    """Một bước train — giống hệt trainer.train_step()."""
    if config['optimizer']['type'] == 'sam':
        for i in range(2):
            predictions = model(data_dict)
            losses = model.get_losses(data_dict, predictions)
            if i == 0:
                pred_first = predictions
                losses_first = losses
            optimizer.zero_grad()
            losses['overall'].backward()
            if i == 0:
                optimizer.first_step(zero_grad=True)
            else:
                optimizer.second_step(zero_grad=True)
        return losses_first, pred_first
    else:
        predictions = model(data_dict)
        losses = model.get_losses(data_dict, predictions)
        optimizer.zero_grad()
        losses['overall'].backward()
        optimizer.step()
        return losses, predictions


@torch.no_grad()
def test_one_dataset(model, data_loader):
    """Test 1 dataset — giống hệt trainer.test_one_dataset()."""
    test_recorder_loss = defaultdict(Recorder)
    prediction_lists = []
    feature_lists = []
    label_lists = []

    for i, data_dict in tqdm(enumerate(data_loader), total=len(data_loader)):
        if data_dict is None:
            continue
        # Fix label to binary (giống gốc)
        if 'label_spe' in data_dict:
            data_dict.pop('label_spe')
        data_dict['label'] = torch.where(data_dict['label'] != 0, 1, 0)
        # Move to GPU
        for key in data_dict.keys():
            if data_dict[key] is not None and isinstance(data_dict[key], torch.Tensor):
                data_dict[key] = data_dict[key].to(device)
        # Forward (inference mode)
        predictions = model(data_dict, inference=True)
        label_lists += list(data_dict['label'].cpu().detach().numpy())
        prediction_lists += list(predictions['prob'].cpu().detach().numpy())
        feature_lists += list(predictions['feat'].cpu().detach().numpy())
        # Compute losses (giống gốc)
        losses = model.get_losses(data_dict, predictions)
        for name, value in losses.items():
            test_recorder_loss[name].update(value)

    return test_recorder_loss, np.array(prediction_lists), np.array(label_lists), np.array(feature_lists)


def save_best_and_log(logger, epoch, iteration, step_cnt, losses_recorder, key,
                      metric_one_dataset, best_metrics_all_time, metric_scoring,
                      config, log_dir, model):
    """
    Lưu checkpoint nếu cải thiện + in log — giống hệt trainer.save_best().
    In testing-loss và testing-metric với format gốc.
    """
    best_metric = best_metrics_all_time[key].get(
        metric_scoring,
        float('-inf') if metric_scoring != 'eer' else float('inf')
    )
    improved = (metric_one_dataset[metric_scoring] > best_metric) if metric_scoring != 'eer' else \
               (metric_one_dataset[metric_scoring] < best_metric)
    if improved:
        best_metrics_all_time[key][metric_scoring] = metric_one_dataset[metric_scoring]
        if key == 'avg':
            best_metrics_all_time[key]['dataset_dict'] = metric_one_dataset['dataset_dict']
        if config['save_ckpt'] and key not in FFpp_pool_ckpt:
            save_checkpoint(model, logger,
                            os.path.join(log_dir, 'test', key, 'ckpt_best.pth'),
                            info=f"{epoch}+{iteration}")

    # Log test losses (giống gốc)
    if losses_recorder is not None:
        loss_str = f"dataset: {key}    step: {step_cnt}    "
        for k, v in losses_recorder.items():
            v_avg = v.average()
            if v_avg is None:
                print(f'{k} is not calculated')
                continue
            loss_str += f"testing-loss, {k}: {v_avg}    "
        logger.info(loss_str)

    # Log test metrics (giống gốc: bao gồm video_auc, acc_real, acc_fake)
    metric_str = f"dataset: {key}    step: {step_cnt}    "
    for k, v in metric_one_dataset.items():
        if k == 'pred' or k == 'label' or k == 'dataset_dict':
            continue
        metric_str += f"testing-metric, {k}: {v}    "
    if 'pred' in metric_one_dataset:
        acc_real, acc_fake = get_respect_acc(metric_one_dataset['pred'], metric_one_dataset['label'])
        metric_str += f'testing-metric, acc_real:{acc_real}; acc_fake:{acc_fake}'
    logger.info(metric_str)


def test_epoch(model, test_loaders, logger, epoch, iteration, step_cnt,
               config, best_metrics_all_time, metric_scoring, log_dir):
    """
    Test toàn bộ test datasets — giống hệt trainer.test_epoch().
    Bao gồm: test loss, test metrics, video_auc, acc_real/acc_fake, best ckpt tracking.
    """
    model.eval()
    avg_metric = {'acc': 0, 'auc': 0, 'eer': 0, 'ap': 0, 'video_auc': 0, 'dataset_dict': {}}

    for key, loader in test_loaders.items():
        # Test one dataset
        losses_recorder, predictions_nps, label_nps, feature_nps = test_one_dataset(model, loader)

        # Get test metrics (includes video-level AUC)
        data_dict = loader.dataset.data_dict
        metric_one_dataset = get_test_metrics(
            y_pred=predictions_nps, y_true=label_nps, img_names=data_dict['image']
        )

        # Accumulate for average
        for metric_name, value in metric_one_dataset.items():
            if metric_name in avg_metric:
                avg_metric[metric_name] += value
        avg_metric['dataset_dict'][key] = metric_one_dataset[metric_scoring]

        # Save best + log (giống gốc)
        save_best_and_log(logger, epoch, iteration, step_cnt, losses_recorder, key,
                          metric_one_dataset, best_metrics_all_time, metric_scoring,
                          config, log_dir, model)

        # Save features if configured
        if config.get('save_feat', False) and feature_nps is not None:
            feat_save_dir = os.path.join(log_dir, 'test', key)
            os.makedirs(feat_save_dir, exist_ok=True)
            np.save(os.path.join(feat_save_dir, 'feat_best.npy'), feature_nps)
            logger.info(f"Feature saved to {os.path.join(feat_save_dir, 'feat_best.npy')}")

    # Average metrics (giống gốc)
    if len(test_loaders) > 0 and config.get('save_avg', False):
        for k in avg_metric:
            if k != 'dataset_dict':
                avg_metric[k] /= len(test_loaders)
        save_best_and_log(logger, epoch, iteration, step_cnt, None, 'avg',
                          avg_metric, best_metrics_all_time, metric_scoring,
                          config, log_dir, model)

    logger.info('===> Test Done!')
    return best_metrics_all_time


def train_one_epoch(model, train_loader, test_loaders, optimizer, config, logger, epoch,
                    best_metrics_all_time, metric_scoring, log_dir):
    """
    Train 1 epoch — giống hệt trainer.train_epoch().
    Bao gồm:
    - Log loss & metric mỗi 300 iterations (format gốc: "Iter: step_cnt  training-loss, ...")
    - Mid-epoch testing: test 2 lần/epoch khi epoch >= 1, 1 lần khi epoch 0
    """
    logger.info(f"===> Epoch[{epoch}] start!")
    model.train()

    # Giống trainer gốc: test nhiều lần / epoch
    times_per_epoch = 2 if epoch >= 1 else 1
    test_step = len(train_loader) // times_per_epoch
    step_cnt = epoch * len(train_loader)

    # Recorders (giống gốc)
    train_recorder_loss = defaultdict(Recorder)
    train_recorder_metric = defaultdict(Recorder)

    test_best_metric = None

    for iteration, data_dict in tqdm(enumerate(train_loader), total=len(train_loader)):
        model.train()  # Đảm bảo train mode sau mỗi lần test
        if data_dict is None:
            step_cnt += 1
            continue

        # Move to GPU (giống gốc)
        for key in data_dict.keys():
            if data_dict[key] is not None and isinstance(data_dict[key], torch.Tensor):
                data_dict[key] = data_dict[key].to(device)

        # Train step
        losses, predictions = train_step(model, data_dict, optimizer, config)

        # Record batch metrics (giống gốc)
        batch_metrics = model.get_train_metrics(data_dict, predictions)
        for name, value in batch_metrics.items():
            train_recorder_metric[name].update(value)
        for name, value in losses.items():
            train_recorder_loss[name].update(value)

        # --- Logging mỗi 300 iterations (giống trainer gốc: iteration % 300 == 0) ---
        if iteration % 300 == 0:
            # Loss log (format gốc)
            loss_str = f"Iter: {step_cnt}    "
            for k, v in train_recorder_loss.items():
                v_avg = v.average()
                if v_avg is None:
                    loss_str += f"training-loss, {k}: not calculated"
                    continue
                loss_str += f"training-loss, {k}: {v_avg}    "
            logger.info(loss_str)

            # Metric log (format gốc)
            metric_str = f"Iter: {step_cnt}    "
            for k, v in train_recorder_metric.items():
                v_avg = v.average()
                if v_avg is None:
                    metric_str += f"training-metric, {k}: not calculated    "
                    continue
                metric_str += f"training-metric, {k}: {v_avg}    "
            logger.info(metric_str)

            # Clear recorders (giống gốc: chỉ tính 300 samples gần nhất)
            for name, recorder in train_recorder_loss.items():
                recorder.clear()
            for name, recorder in train_recorder_metric.items():
                recorder.clear()

        # --- Mid-epoch testing (giống trainer gốc) ---
        if (step_cnt + 1) % test_step == 0:
            if test_loaders is not None and len(test_loaders) > 0:
                logger.info("===> Test start!")
                test_best_metric = test_epoch(
                    model, test_loaders, logger, epoch, iteration, step_cnt,
                    config, best_metrics_all_time, metric_scoring, log_dir
                )

        step_cnt += 1

    return test_best_metric


def save_checkpoint(model, logger, path, info=""):
    """Lưu model state dict — giống hệt trainer.save_ckpt()."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(model.state_dict(), path)
    logger.info(f"Checkpoint saved to {path}, current ckpt is {info}")


print("\u2705 Training loop & validation loaded (matching original project).")

✅ Training loop & validation loaded.


## Cell 13 — Khởi tạo Model + Optimizer + DataLoaders

Cell này sẽ:
1. Tạo model GenD (tải CLIP weights từ internet)
2. Tạo optimizer và scheduler
3. Tạo train/test DataLoaders

In [None]:
# 1. Model
model = GenDDetector(config=config)
model = model.to(device)

# In số params
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info(f"Total parameters:     {total_params:,}")
logger.info(f"Trainable parameters: {trainable_params:,}")

# 2. Optimizer (giống gốc: dùng model.parameters() không filter)
optimizer = choose_optimizer(model, config)
scheduler = choose_scheduler(config, optimizer)
logger.info(f"Optimizer: {config['optimizer']['type']}")
logger.info(f"Scheduler: {config['lr_scheduler']}")

# 3. DataLoaders
logger.info("Preparing training data...")
train_loader = prepare_training_data(config)
logger.info(f"Train samples: {len(train_loader.dataset):,}  |  Batches: {len(train_loader)}")

logger.info("Preparing testing data...")
test_loaders = prepare_testing_data(config)
for name, loader in test_loaders.items():
    logger.info(f"Test [{name}]: {len(loader.dataset):,} samples  |  {len(loader)} batches")

# 4. Print configuration (giống train.py gốc)
logger.info("--------------- Configuration ---------------")
params_string = "Parameters: \n"
for key, value in config.items():
    params_string += f"{key}: {value}\n"
logger.info(params_string)

Loading CLIP ViT-L/14 for GenD...
GenD Initialized. Trainable backbone params (LayerNorm): 100,352
Classifier head params: 2,050
[03:00:52] Total parameters:     303,181,826
[03:00:52] Trainable parameters: 102,402
[03:00:52] Optimizer: adam
[03:00:52] Scheduler: None
[03:00:52] Preparing training data...


ValueError: Dataset JSON not found: /kaggle/input/deepfakebench/dataset_json/FaceForensics++.json. Error: [Errno 2] No such file or directory: '/kaggle/input/deepfakebench/dataset_json/FaceForensics++.json'

## Cell 14 — CHẠY TRAINING

In [None]:
# ============================================================
#  Main Training Loop — giống hệt train.py gốc
#  FIX: best_metrics_all_time per-dataset, parse_metric_for_print,
#       scheduler.step() chỉ 1 lần sau tất cả epochs (giống gốc)
# ============================================================

best_metrics_all_time = defaultdict(
    lambda: defaultdict(lambda: float('-inf') if metric_scoring != 'eer' else float('inf'))
)
metric_scoring = config['metric_scoring']

best_metric = None

for epoch in range(config['start_epoch'], config['nEpochs'] + 1):
    model.epoch = epoch  # giống gốc: trainer.model.epoch = epoch

    # --- Train (bao gồm mid-epoch testing) ---
    best_metric = train_one_epoch(
        model, train_loader, test_loaders, optimizer, config, logger, epoch,
        best_metrics_all_time, metric_scoring, log_dir
    )

    if best_metric is not None:
        logger.info(f"===> Epoch[{epoch}] end with testing {metric_scoring}: {parse_metric_for_print(best_metric)}!")

    # Save checkpoint mỗi save_epoch
    if config['save_ckpt'] and (epoch % config['save_epoch'] == 0):
        epoch_ckpt = os.path.join(log_dir, f'ckpt_epoch_{epoch}.pth')
        save_checkpoint(model, logger, epoch_ckpt, info=f"epoch={epoch}")

# Scheduler step (giống gốc: chỉ step 1 lần sau tất cả epochs)
if scheduler is not None:
    scheduler.step()

if best_metric is not None:
    logger.info(f"Stop Training on best Testing metric {parse_metric_for_print(best_metric)}")

logger.info(f"\n{'='*60}")
logger.info(f"\u2705 Training complete!")
logger.info(f"Checkpoints saved in: {log_dir}")
logger.info(f"{'='*60}")

## Cell 15 — Load Checkpoint & Test lại


In [None]:
# ============================================================
#  (Tùy chọn) Load checkpoint và test lại
# ============================================================

# Bỏ comment dòng dưới để chạy:

# ckpt_path = os.path.join(log_dir, 'ckpt_best.pth')
# state_dict = torch.load(ckpt_path, map_location=device)
# model.load_state_dict(state_dict)
# logger.info(f"Loaded checkpoint: {ckpt_path}")
# 
# test_results = validate(model, test_loaders, logger)
# avg_auc = np.mean([r['auc'] for r in test_results.values()])
# logger.info(f"\u2705 Test AUC (from saved checkpoint): {avg_auc:.4f}")