In [None]:
# IMPORTANT: SOME KAGGLE DATA SOURCES ARE PRIVATE
# RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES.
# Log in to Kaggle so that kagglehub can download private datasets in the
# next cell (which uses the `kagglehub` module imported here).
import kagglehub
kagglehub.login()


In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.

# Download the MADDE dataset; kagglehub returns the local path of the downloaded files.
katsuyamucb_madde_dataset_path = kagglehub.dataset_download('katsuyamucb/madde-dataset')

print('Data source import complete.')


In [None]:
ls /root/.cache/kagglehub/datasets/katsuyamucb/madde-dataset/versions/13


# Project Codebase


This is our existing codebase for fine-tuning vision models. The goal of this project is to find the best PEFT methods for deepfake detection.
We tried a couple of deepfake detection models below, but their performance is poor.
So we use our deepfake dataset (~10K) to fine-tune a pre-trained deepfake detection model so that it performs better.

What we want to spend time on:
  - Try as many PEFT methods as possible
  - Record their performance in Weights and Biases
  - Think of why some work better and some don't
  - Compare them to full fine tuning
  - Refactor our codebase

However:
  - For simplicity, we don't try too many models. This is not final yet, but let's **use pre-trained CLIP** for testing.
  - We don't work on non-facial forgery. Let's focus on **forged facial images**.


# Step 0: Setup

In the setup section, we define three utilities:
- DeepfakeDataset: a dataset built from the selected real/fake folder names.
- FineTuner: a base class to inherit from. It provides model loading, data processing, and model evaluation.
- FullFT: a subclass of FineTuner that runs full fine-tuning. Refer to it to see how training/validation logs are recorded with wandb.

In [None]:
# Loading Libraries
import os
import torch
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from collections import OrderedDict

import numpy as np
import pandas as pd
from PIL import Image
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import CLIPProcessor

import matplotlib.pyplot as plt
from tqdm.auto import tqdm  # Import tqdm for progress bars
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc, precision_recall_curve, average_precision_score
import seaborn as sns
from typing import Dict, List, Optional, Union, Tuple
from IPython.display import display

from huggingface_hub import hf_hub_download

# Wandb
import wandb

import warnings
warnings.filterwarnings("ignore")

In [None]:
! pip install git+https://github.com/purewater0901/deepfake-detection.git

from deepfake_detection.model.dfdet import DeepfakeDetectionModel
from deepfake_detection.config import Config

## Log in with your wandb API key

In [None]:
# Please use your own API key. wandb.login() reuses WANDB_API_KEY or a cached
# login when available and otherwise prompts for the key inside the kernel
# (a `!wandb login` subprocess may not be able to read the key prompt in a notebook).
wandb.login()

## Utility functions

In [None]:
# Custom dataset for deepfake detection
class DeepfakeDataset(Dataset):
    """
    Dataset of real and fake images read from two class folders under a root directory.

    Each item is a tuple ``(image, label, img_path)`` with label 0 for real and
    1 for fake. Only ``.png`` files are used.
    """
    def __init__(
        self,
        root_dir: str,
        real_folder: str = 'Real',
        fake_folder: str = 'Fake',
        transform=None,
        processor=None,
        placeholder_size: Tuple[int, ...] = (3, 299, 299),
    ):
        """
        Args:
            root_dir (str): Root directory containing the class folders.
            real_folder (str): Folder (relative to root_dir) containing real images.
            fake_folder (str): Folder (relative to root_dir) containing fake images.
            transform (callable, optional): Applied to each RGB PIL image. Takes
                precedence over `processor`.
            processor (callable, optional): HF-style image processor, called as
                ``processor(images=image, return_tensors="pt")``; only used when
                `transform` is not set.
            placeholder_size (tuple): Shape of the zero tensor returned for images
                that fail to load. Match it to your model's input shape so that a
                batch containing a bad image can still be collated.

        Raises:
            FileNotFoundError: If either class folder does not exist.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.processor = processor
        self.placeholder_size = tuple(placeholder_size)
        self.class_folders = {
            0: real_folder,  # 0 = real
            1: fake_folder,  # 1 = fake
        }

        self.samples = []
        self.load_samples()

    def load_samples(self):
        """Collect ``(image_path, label)`` pairs for every .png file in each class folder."""
        for class_idx, folder_name in self.class_folders.items():
            class_dir = os.path.join(self.root_dir, folder_name)
            if not os.path.exists(class_dir):
                raise FileNotFoundError(f"Directory not found: {class_dir}")

            # os.listdir order is arbitrary; sort so the sample order is reproducible.
            for img_name in sorted(os.listdir(class_dir)):
                # Restrict to png so the model cannot overfit to format differences between classes.
                if img_name.lower().endswith('.png'):
                    self.samples.append((os.path.join(class_dir, img_name), class_idx))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]

        try:
            # The context manager closes the file handle once the RGB copy exists.
            with Image.open(img_path) as img:
                image = img.convert('RGB')

            if self.transform:
                image = self.transform(image)
            elif self.processor:
                # The processor returns a batched dict; drop the batch dimension.
                image = self.processor(images=image, return_tensors="pt")["pixel_values"].squeeze(0)

            return image, label, img_path

        except Exception as e:
            # Best effort: report the bad file and return a zero image so one
            # unreadable file does not abort an epoch.
            print(f"Error loading image {img_path}: {e}")
            placeholder = torch.zeros(self.placeholder_size)
            return placeholder, label, img_path


In [None]:
# Base class for fine-tuning experiments (full fine-tuning is implemented in FullFT below).
# Inherit from it when you develop a new fine-tuner.

class FineTuner():
    """
    Base class for fine-tuning experiments.

    Provides model loading, train/validation/test data loaders, test-set
    evaluation and optional Weights & Biases (wandb) logging. Subclasses
    override `Tune()` (which should return the tuned model) and are expected to
    set `self.method_name`, which is used for wandb run metadata and artifact naming.

    Folder layout implied by the path joins below:
        <data_dir>/<real_folder>/{Train,Validation,Test}/*.png
        <data_dir>/<fake_folder>/{Train,Validation,Test}/*.png
    """
    def __init__(self, model_name, data_dir, real_folder, fake_folder, num_epochs, batch_size, learning_rate, use_wandb = False, model = None, processor = None):
        """
        Args:
            model_name (str): timm model name. Used to create the model when `model`
                is None, and for logging and artifact naming.
            data_dir (str): Root directory of the dataset.
            real_folder (str): Folder (relative to data_dir) with real images.
            fake_folder (str): Folder (relative to data_dir) with the fake images used for training.
            num_epochs (int): Number of training epochs.
            batch_size (int): Batch size for all data loaders.
            learning_rate (float): Optimizer learning rate.
            use_wandb (bool): Whether to log to wandb.
            model (nn.Module, optional): Pre-built model. If None, a 2-class timm model is created.
            processor (callable, optional): HF-style image processor. When given, it
                replaces the timm transforms for all splits.
        """
        self.model_name = model_name

        # Create a pretrained 2-class timm model only when the caller did not supply one.
        if model is None:
            self.model = timm.create_model(self.model_name, pretrained=True, num_classes=2)
            print("Loaded ", model_name, " for fine-tuning")
        else:
            self.model = model

        self.data_dir = data_dir
        self.real_folder = real_folder
        self.fake_folder = fake_folder
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.use_wandb = use_wandb
        self.test_fake_folder = None
        self.processor = processor

    def set_TestFolder(self, test_fake_folder):
        """Set an extra fake folder (unseen during training) that `Experiment` also evaluates."""
        self.test_fake_folder = test_fake_folder

    @staticmethod
    def _extract_logits(outputs):
        """Return the logits tensor from a model output.

        Supports both models that return the logits tensor directly (e.g. timm
        models) and models that return an object with a `.logits_labels` attribute.
        """
        return getattr(outputs, "logits_labels", outputs)

    def _base_transform(self):
        """Build timm's default (non-training) transform from the model's data config."""
        config = resolve_data_config({}, model=self.model)
        return create_transform(**config)

    def _make_dataset(self, split, fake_folder, transform=None):
        """Build a DeepfakeDataset for `split` ('Train', 'Validation' or 'Test').

        Uses `self.processor` when it is set; otherwise uses `transform`.
        """
        use_processor = bool(self.processor)
        return DeepfakeDataset(
            root_dir=self.data_dir,
            real_folder=os.path.join(self.real_folder, split),
            fake_folder=os.path.join(fake_folder, split),
            transform=None if use_processor else transform,
            processor=self.processor if use_processor else None,
        )

    def _make_loader(self, dataset, shuffle):
        """Wrap `dataset` in a DataLoader with this tuner's batch size."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=4,
            pin_memory=True
        )

    def get_Train_Val_loader(self):
        """Return `(train_loader, val_loader)` for the 'Train' and 'Validation' splits."""
        train_transform = val_transform = None
        if not self.processor:
            base_transform = self._base_transform()

            # Colour-jitter magnitudes, sampled once per call (change the ranges if you want).
            brightness = np.random.uniform(0.05, 0.2)
            contrast = np.random.uniform(0.05, 0.2)
            saturation = np.random.uniform(0.05, 0.2)
            hue = np.random.uniform(0, 0.1)  # Hue is typically smaller values

            # Build a NEW Compose for training. Appending to `base_transform.transforms`
            # would mutate the list shared with the validation transform. The
            # augmentations run on the PIL image *before* timm's own transforms, so
            # ColorJitter never sees normalized tensors (it clamps float tensors to [0, 1]).
            train_transform = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(10),
                transforms.ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue),
                *base_transform.transforms,
            ])
            val_transform = base_transform

        train_dataset = self._make_dataset('Train', self.fake_folder, train_transform)
        val_dataset = self._make_dataset('Validation', self.fake_folder, val_transform)
        return self._make_loader(train_dataset, shuffle=True), self._make_loader(val_dataset, shuffle=False)

    def get_Test_loader(self, test_folder):
        """Return a loader over the 'Test' split of `real_folder` vs `test_folder` (fake)."""
        transform = None if self.processor else self._base_transform()
        dataset = self._make_dataset('Test', test_folder, transform)
        return self._make_loader(dataset, shuffle=False)

    def Tune(self):
        """Train the model and return it. Subclasses override this; the base version does nothing."""
        pass

    def Evaluation(self, test_folder, model = None):
        """
        Evaluate the model on the 'Test' split of `real_folder` vs `test_folder`.

        Prints the accuracy, displays the classification report and a confusion
        matrix, and, when wandb is enabled, logs accuracy, the confusion matrix
        and a ROC curve.

        Args:
            test_folder (str): Fake-image folder (relative to data_dir) to test on.
            model (nn.Module, optional): If given, replaces `self.model` before evaluating.

        Returns:
            pd.DataFrame: sklearn classification report (output_dict=True) as a DataFrame.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        test_loader = self.get_Test_loader(test_folder)
        print('\n\n----- Test on ',test_folder,'-----')
        print(f"Device: {device}")

        if model is not None:
            self.model = model
        self.model.eval()
        model = self.model.to(device)

        all_preds = []
        all_probs = []
        all_labels = []
        all_paths = []
        confidence_threshold = 0.5

        with torch.no_grad():
            for inputs, labels, paths in tqdm(test_loader, desc="Testing"):
                inputs = inputs.to(device)

                outputs = self._extract_logits(model(inputs))

                # Cast to float32 first (e.g. BF16 outputs) before softmax and numpy conversion.
                probs = torch.softmax(outputs.float(), dim=1).cpu().numpy()

                # Predict "Fake" (class 1) when its probability reaches the threshold.
                preds = (probs[:, 1] >= confidence_threshold).astype(int)

                all_preds.extend(preds)
                all_probs.extend(probs)
                all_labels.extend(labels.numpy())
                all_paths.extend(paths)

        all_preds = np.array(all_preds)
        all_probs = np.array(all_probs)
        all_labels = np.array(all_labels)

        # Pin labels to [0, 1] so both classes are reported even if a split holds only one;
        # otherwise sklearn rejects the 2-entry target_names.
        report = classification_report(all_labels, all_preds,
                                      labels=[0, 1],
                                      target_names=['Real', 'Fake'],
                                      output_dict=True)

        test_acc = (all_preds == all_labels).mean()

        print(f"\n\n{'-'*50}")
        print(f"Test Result Summary:")
        print(f"{'-'*50}")
        print(f"Test accuracy: {test_acc:.4f}")

        # Visualize the result ------------------------------------------------
        report_df = pd.DataFrame(report)
        display(report_df)

        cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
        fig, ax = plt.subplots(figsize=(5, 4))
        im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
        ax.set_title('Confusion Matrix')
        fig.colorbar(im, ax=ax)
        tick_marks = np.arange(2)
        ax.set_xticks(tick_marks)
        ax.set_xticklabels(['Real', 'Fake'], rotation=45)
        ax.set_yticks(tick_marks)
        ax.set_yticklabels(['Real', 'Fake'])
        # sklearn's confusion matrix has true labels on rows and predictions on columns.
        ax.set_xlabel('Predicted label')
        ax.set_ylabel('True label')

        # Annotate each cell; use white text on dark cells for readability.
        thresh = cm.max() / 2.
        for i in range(cm.shape[0]):
            for j in range(cm.shape[1]):
                ax.text(j, i, format(cm[i, j], 'd'),
                        horizontalalignment="center",
                        color="white" if cm[i, j] > thresh else "black")

        fig.tight_layout()
        plt.show()

        if self.use_wandb:
            wandb.log({
                "test_dataset": test_folder,
                "test_accuracy": test_acc,
                "test_confusion_matrix": wandb.plot.confusion_matrix(
                    probs=None,
                    y_true=all_labels,
                    preds=all_preds,
                    class_names=['Real', 'Fake']
                ),
                "test_roc_curve": wandb.plot.roc_curve(
                    y_true= all_labels,
                    y_probas= all_probs,
                    labels=['Real', 'Fake']
                )
            })

        return report_df

    def log_wandb_train(
            self,
            all_labels,
            all_preds,
            all_probs,
            epoch,
            train_loss,
            train_acc,
            val_loss,
            val_acc,
            optimizer,
        ):
        """
        Log one epoch's metrics to wandb (no-op when wandb is disabled).

        On the final epoch (`epoch == num_epochs - 1`), this also logs the
        validation confusion matrix and ROC curve built from `all_labels`,
        `all_preds` and `all_probs`.
        """
        if not self.use_wandb:
            return

        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_accuracy": train_acc,
            "val_loss": val_loss,
            "val_accuracy": val_acc,
            "learning_rate": optimizer.param_groups[0]['lr']
        })

        if epoch == self.num_epochs - 1:
            wandb.log({
                "confusion_matrix": wandb.plot.confusion_matrix(
                    probs=None,
                    y_true=all_labels,
                    preds=all_preds,
                    class_names=['Real', 'Fake']
                )
            })

            # The ROC curve uses class probabilities, not hard predictions.
            wandb.log({
                "roc_curve": wandb.plot.roc_curve(
                    y_true= all_labels,
                    y_probas= np.array(all_probs),
                    labels=['Real', 'Fake']
                )
            })

    def Experiment(self, wandb_run_name):
        """
        Run one experiment.

        Steps: optional wandb setup, `Tune()`, evaluation on the training fake
        folder, then evaluation on `test_fake_folder` if it is set.

        Args:
            wandb_run_name (str): Name of the wandb run (used only when wandb is enabled).

        Returns:
            Whatever `Tune()` returns (the tuned model for subclasses).
        """
        # Subclasses set `method_name`; fall back to the class name so the base class also works.
        method_name = getattr(self, "method_name", type(self).__name__)

        if self.use_wandb:
            wandb.init(project='Fine-Tuning Experiment', name=wandb_run_name)

            wandb.config.update({
                "model": self.model_name,
                "batch_size": self.batch_size,
                "learning_rate": self.learning_rate,
                "num_epochs": self.num_epochs,
                "fine_tuning_type": method_name,
                "dataset_dir": self.data_dir,
                "real_folder": self.real_folder,
                "fake_folder": self.fake_folder
            })

            total_params = sum(p.numel() for p in self.model.parameters())
            trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
            wandb.log({
                "total_parameters": total_params,
                "trainable_parameters": trainable_params,
                "frozen_parameters": total_params - trainable_params,
                "percent_trainable": 100 * trainable_params / total_params
            })

        tuned_model = self.Tune()
        self.Evaluation(self.fake_folder)
        if self.test_fake_folder is not None:
            self.Evaluation(self.test_fake_folder)

        if self.use_wandb:
            # wandb artifact names allow only alphanumerics, '-', '_' and '.', so strip '/' from hub-style ids.
            model_artifact = wandb.Artifact(
                name=f"{method_name}-{self.model_name}".replace("/", "_"),
                type="model"
            )
            # TODO: attach the trained weights, e.g. model_artifact.add_file(<checkpoint path>).
            wandb.log_artifact(model_artifact)
            wandb.finish()

        return tuned_model


In [None]:
# Hopefully we can run this too and compare it with PEFT, but full fine-tuning with 10+ epochs takes a very long time.
# To see how training works, set the number of epochs to < 5.

class FullFT(FineTuner):
    """
    Full fine-tuning: every model parameter is passed to the AdamW optimizer.

    Full fine-tuning for many epochs is slow; use a small `num_epochs` to check
    that training works.
    """
    def __init__(self, model_name, data_dir, real_folder, fake_folder, num_epochs, batch_size, learning_rate, use_wandb= False, model = None, processor = None):
        super().__init__(model_name, data_dir, real_folder, fake_folder, num_epochs, batch_size, learning_rate, use_wandb, model, processor)
        self.method_name = 'Full_FT'

    @staticmethod
    def _extract_logits(outputs):
        """Return logits: the output tensor itself, or its `.logits_labels` attribute if present."""
        return getattr(outputs, "logits_labels", outputs)

    def Tune(self):
        """
        Fully fine-tune `self.model` using the settings stored on the instance.

        Uses data_dir, real_folder, fake_folder, num_epochs, batch_size,
        learning_rate and use_wandb. After each epoch, the weights are saved to
        "<model_name>.pth" whenever validation accuracy ties or beats the best so
        far; a full checkpoint is written every 5 epochs. Once training ends,
        the best weights are loaded back into `self.model`.

        Returns:
            nn.Module: `self.model` holding the best validation weights.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

        self.model = self.model.to(device)
        # Hub-style ids contain '/', which would point torch.save into a non-existent sub-directory.
        model_save_path = f"{self.model_name.replace('/', '_')}.pth"

        train_loader, val_loader = self.get_Train_Val_loader()

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(self.model.parameters(), lr=self.learning_rate)

        # Halve the learning rate when validation loss has not improved for 3 epochs.
        # (The current LR is printed every epoch, so the deprecated `verbose` flag is not needed.)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=3
        )

        train_losses = []
        val_losses = []
        train_accuracies = []
        val_accuracies = []
        best_val_acc = 0.0
        best_saved = False

        epoch_loop = tqdm(range(self.num_epochs), desc="Training Progress", unit="epoch")
        for epoch in epoch_loop:
            # ---- Training phase ----
            self.model.train()
            running_loss = 0.0
            correct = 0
            total = 0

            for inputs, labels, _ in train_loader:
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()

                outputs = self._extract_logits(self.model(inputs)).float()
                loss = criterion(outputs, labels)

                loss.backward()
                optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            train_loss = running_loss / len(train_loader.dataset)
            train_acc = correct / total
            train_losses.append(train_loss)
            train_accuracies.append(train_acc)

            # ---- Validation phase ----
            self.model.eval()
            val_running_loss = 0.0
            val_correct = 0
            val_total = 0
            all_preds = []
            all_labels = []
            all_probs = []

            with torch.no_grad():
                for inputs, labels, _ in val_loader:
                    inputs, labels = inputs.to(device), labels.to(device)

                    # Cast to float32 exactly as in training (e.g. BF16 model outputs).
                    outputs = self._extract_logits(self.model(inputs)).float()
                    loss = criterion(outputs, labels)

                    val_running_loss += loss.item() * inputs.size(0)
                    probs = torch.nn.functional.softmax(outputs, dim=1)
                    _, predicted = torch.max(outputs, 1)
                    val_total += labels.size(0)
                    val_correct += (predicted == labels).sum().item()

                    all_preds.extend(predicted.cpu().numpy())
                    all_labels.extend(labels.cpu().numpy())
                    all_probs.extend(probs.cpu().numpy())

            val_loss = val_running_loss / len(val_loader.dataset)
            val_acc = val_correct / val_total
            val_losses.append(val_loss)
            val_accuracies.append(val_acc)

            scheduler.step(val_loss)

            print(f"\n\n{'-'*50}")
            print(f"Epoch {epoch+1}/{self.num_epochs} Summary:")
            print(f"{'-'*50}")
            # Save the best model so far (ties also overwrite, so the latest best wins).
            if val_acc >= best_val_acc:
                torch.save(self.model.state_dict(), model_save_path)
                print(f"✅ New best model saved! Validation accuracy: {val_acc:.4f} (previous best: {best_val_acc:.4f})")
                best_val_acc = val_acc
                best_saved = True

            # Full checkpoint every 5 epochs. The parentheses matter: `epoch + 1 % 5` is `epoch + 1`.
            if (epoch + 1) % 5 == 0:
                checkpoint_path = f"checkpoint_epoch_{epoch+1}.pth"
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'train_loss': train_loss,
                    'val_loss': val_loss,
                    'train_acc': train_acc,
                    'val_acc': val_acc,
                    'best_val_acc': best_val_acc
                }, checkpoint_path)
                print(f"Checkpoint saved: {checkpoint_path}")

            print(f"Training:   Loss: {train_loss:.4f} | Accuracy: {train_acc:.4f} ({correct}/{total})")
            print(f"Validation: Loss: {val_loss:.4f} | Accuracy: {val_acc:.4f} ({val_correct}/{val_total})")

            if epoch > 0:
                train_loss_change = train_loss - train_losses[-2]
                train_acc_change = train_acc - train_accuracies[-2]
                val_loss_change = val_loss - val_losses[-2]
                val_acc_change = val_acc - val_accuracies[-2]

                print(f"Changes from previous epoch:")
                print(f"  Train Loss: {train_loss_change:+.4f} | Train Acc: {train_acc_change:+.4f}")
                print(f"  Val Loss: {val_loss_change:+.4f} | Val Acc: {val_acc_change:+.4f}")

            current_lr = optimizer.param_groups[0]['lr']
            print(f"Current learning rate: {current_lr:.6f}")

            self.log_wandb_train(
                all_labels,
                all_preds,
                all_probs,
                epoch,
                train_loss,
                train_acc,
                val_loss,
                val_acc,
                optimizer
            )

        # Restore the best weights once training has finished. Reloading inside the
        # epoch loop would roll the model back mid-training whenever validation accuracy dropped.
        if best_saved:
            print(f"Loading best model from {model_save_path}")
            self.model.load_state_dict(torch.load(model_save_path, map_location=device))

        tqdm.write(f"\nTraining completed after {self.num_epochs} epochs")
        tqdm.write(f"Best validation accuracy: {best_val_acc:.4f}")

        return self.model

In [None]:
# Intrinsic SAID (Structure-Aware Intrinsic Dimension) fine-tuning

In [None]:
# Colab-only: upload files from your machine. The extension build below expects
# fwh_cpp.cpp and fwh_cu.cu in the working directory.
from google.colab import files
uploaded = files.upload()

In [None]:
!pip install ninja


In [None]:
# JIT-compile the fast Walsh-Hadamard transform extension from the C++/CUDA
# sources (presumably uploaded above). `wht` is the loaded extension module;
# its transform is imported by name further below.
from torch.utils.cpp_extension import load

wht = load(
    name="wht",
    sources=["fwh_cpp.cpp", "fwh_cu.cu"],
    verbose=True  # print the build log
)

In [None]:
# Sanity check: show the compiled extension module.
print(wht)

In [None]:

# The code is from Armen Aghajanyan (Facebook), from the paper
# "Intrinsic Dimensionality Explains the Effectiveness of Language Model Fine-Tuning"
# https://arxiv.org/abs/2012.13255

import torch
import numpy as np
from torch import nn
from torch.nn import functional as F
from typing import Tuple, Set
from wht import fast_walsh_hadamard_transform as fast_walsh_hadamard_transform_cuda


def fast_walsh_hadamard_torched(x, axis: int = 0, normalize: bool = True):
    """
    Pure-PyTorch fast Walsh-Hadamard transform of `x` along `axis`.

    The output uses Sylvester (natural) ordering; e.g. for a length-4 vector
    [x0, x1, x2, x3] the unnormalized result is
    [x0+x1+x2+x3, x0-x1+x2-x3, x0+x1-x2-x3, x0-x1-x2+x3].

    Args:
        x (torch.Tensor): Input tensor of any shape (contiguous or not).
        axis (int): Axis to transform; its size must be a power of two.
            Negative values count from the end.
        normalize (bool): If True, divide by sqrt(axis size), making the transform orthonormal.

    Returns:
        torch.Tensor: Tensor with the same shape as `x`.
    """
    orig_shape = x.size()
    ndim = len(orig_shape)
    if axis < 0:
        axis += ndim
    assert 0 <= axis < ndim, (
        "For a vector of shape %s, axis must be in [0, %d] but it is %d"
        % (orig_shape, ndim - 1, axis)
    )

    h_dim = int(orig_shape[axis])
    # Exact integer power-of-two test (a float log2 can round incorrectly).
    assert h_dim > 0 and h_dim & (h_dim - 1) == 0, (
        "hadamard can only be computed over axis with size that is a power of two, but"
        " chosen axis %d has size %d" % (axis, h_dim)
    )
    h_dim_exp = h_dim.bit_length() - 1

    # Collapse the dims before/after `axis` and split `axis` into log2(h_dim)
    # binary dims, so each butterfly stage is a chunk/cat over one size-2 dim.
    # reshape (not view) so non-contiguous inputs also work.
    working_shape = (
        [int(np.prod(orig_shape[:axis]))]
        + [2] * h_dim_exp
        + [int(np.prod(orig_shape[axis + 1:]))]
    )
    ret = x.reshape(working_shape)

    for dim in range(1, h_dim_exp + 1):
        first, second = torch.chunk(ret, 2, dim=dim)
        ret = torch.cat((first + second, first - second), dim=dim)

    if normalize:
        ret = ret / np.sqrt(float(h_dim))

    return ret.reshape(orig_shape)


def fastfood_vars(DD, device=0):
    """
    Sample the random parameters of a Fastfood transform.

    :param DD: desired (output) dimension
    :param device: device to place the tensors on (the default, integer 0, is
        interpreted by PyTorch as the first CUDA device)
    :return: [BB, Pi, GG, divisor, LL], where
        LL is the smallest power of two >= DD,
        BB is a LongTensor of random +-1 signs (length LL),
        Pi is a random permutation of range(LL),
        GG is a standard-normal FloatTensor (length LL), and
        divisor = sqrt(LL * sum(GG ** 2)).
    """
    # Smallest power of two >= DD, computed exactly with integers.
    # ceil(log(DD) / log(2)) can round up past exact powers of two and double LL.
    LL = 1 << max(int(DD) - 1, 0).bit_length()

    # Binary scaling matrix where $B_{i,i} \in \{\pm 1 \}$ drawn iid
    BB = torch.FloatTensor(LL).uniform_(0, 2).type(torch.LongTensor)
    BB = (BB * 2 - 1)
    BB.requires_grad_(False)

    # Random permutation matrix
    Pi = torch.LongTensor(np.random.permutation(LL))
    Pi.requires_grad_(False)

    # Gaussian scaling matrix, whose elements $G_{i,i} \sim \mathcal{N}(0, 1)$
    GG = torch.FloatTensor(LL,).normal_()
    GG.requires_grad_(False)
    divisor = torch.sqrt(LL * torch.sum(torch.pow(GG, 2)))
    return [BB.to(device), Pi.to(device), GG.to(device), divisor.to(device), LL]


def random_vars(desired_dim, intrinsic_dim, device=0):
    """
    Build a dense random projection matrix.

    Args:
        desired_dim: number of rows (output dimension).
        intrinsic_dim: number of columns (intrinsic dimension).
        device: target device for the matrix.

    Returns:
        list: [R, divisor], where R is a float32 (desired_dim, intrinsic_dim)
        matrix with entries drawn from N(0, 0.01**2) that does not require
        grad, and divisor is the Frobenius norm of R.
    """
    sampled = torch.FloatTensor(desired_dim, intrinsic_dim).normal_(std=0.01)
    matrix = sampled.to(device)
    matrix.requires_grad_(False)
    frobenius = torch.norm(matrix)
    return [matrix, frobenius]


def fastfood_torched(x, DD: int, param_list: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]):
    """
    Apply the Fastfood projection H G Pi H B to `x`, then truncate to DD entries and rescale.

    :param x: intrinsic vector (assumed 1-D; its length is read from dim 0)
    :param DD: desired output dimension
    :param param_list: (BB, Pi, GG, divisor, LL) as returned by fastfood_vars
    :return: the first DD entries of the transform, divided by divisor * sqrt(DD / LL)
    """
    BB, Pi, GG, divisor, LL = param_list

    # Zero-pad x to LL entries. If x is longer than LL, the negative pad crops
    # it to its first LL entries instead.
    padded = F.pad(x, pad=(0, LL - x.size(0)), value=0.0, mode="constant")

    # Apply the factors right to left: B (random signs), then H ...
    hadamard_of_signed = FastWalshHadamard.apply(padded * BB)
    # ... then Pi (permutation) and G (Gaussian scaling) ...
    scaled_permuted = hadamard_of_signed[Pi] * GG
    # ... and finally H again.
    projected = FastWalshHadamard.apply(scaled_permuted)

    scale = divisor * np.sqrt(float(DD) / LL)
    return projected[:int(DD)] / scale


def random_torched(intrinsic_vec, param_list: Tuple[torch.Tensor, int]):
    """
    Dense random projection: return R @ intrinsic_vec.

    :param intrinsic_vec: vector to project
    :param param_list: (R, divisor) as returned by random_vars
    :return: R @ intrinsic_vec (not normalized by divisor)
    """
    projection_matrix, _divisor = param_list
    # TODO: for now we are not normalizing with the divisor, to be added later.
    return projection_matrix @ intrinsic_vec


class FastWalshHadamard(torch.autograd.Function):
    """
    Autograd wrapper around the unnormalized fast Walsh-Hadamard transform.

    CUDA tensors go through the compiled extension kernel; other tensors use
    the pure-PyTorch implementation. Inputs are cast to float32 first.
    """
    @staticmethod
    def forward(ctx, input):
        # Stash 1/sqrt(n), with n = size of dim 0, for the backward pass.
        inv_sqrt_n = torch.tensor([1 / np.sqrt(float(input.size(0)))]).to(input)
        ctx.save_for_backward(inv_sqrt_n)
        as_float = input.float()
        if not input.is_cuda:
            return fast_walsh_hadamard_torched(as_float, normalize=False)
        return fast_walsh_hadamard_transform_cuda(as_float, False)

    @staticmethod
    def backward(ctx, grad_output):
        # NOTE(review): the CPU forward is the *unnormalized* transform H, whose exact
        # gradient is H @ grad_output. This returns H @ grad_output scaled by 1/sqrt(n),
        # as in the upstream code; confirm whether this scaling is intended.
        (inv_sqrt_n,) = ctx.saved_tensors
        grad_copy = grad_output.clone().float()
        if grad_output.is_cuda:
            transformed = fast_walsh_hadamard_transform_cuda(grad_copy, False)
        else:
            transformed = fast_walsh_hadamard_torched(grad_copy, normalize=False)
        return inv_sqrt_n * transformed.to(grad_output)


class IntrinsicDimensionLight:
    """Forward-pre-hook that trains a module through a low-dimensional "intrinsic" vector.

    Construction registers a zero-initialised ``intrinsic_parameter`` on the module, snapshots
    the selected parameters as frozen initial values, builds one projection per selected
    parameter, and turns off gradients on the originals. On each forward pass, ``__call__``
    overwrites every selected parameter with ``initial_value + projection(intrinsic_parameter)``,
    multiplied by a learned per-tensor scalar when ``said=True``. Only the intrinsic vector(s)
    therefore receive gradients.
    """

    # Projection types understood by get_projection_params / get_projected_param.
    _PROJECTIONS = ("fastfood", "random")

    def __init__(self, module: nn.Module, intrinsic_dimension: int,  output_dir,
                 str_filter: Set[str] = set(), said=False, projection="fastfood", device="cpu"):
        """
        Adds hook only for the parameters selected inside the str_filter, and if str_filter is empty, this selects
        all the parameters with gradient = True.

        :param module: module whose parameters are reparameterized. It gains the attribute
            ``intrinsic_parameter`` (and ``intrinsic_parameter_said`` when ``said``).
        :param intrinsic_dimension: intrinsic budget. With ``said=True`` the dense vector is
            shrunk by ``said_size + 1``, where said_size is the number of parameter tensors in
            ``module`` at construction time.
        :param output_dir: unused; kept for call-site compatibility.
        :param str_filter: name substrings. A trainable parameter is selected when its name
            contains any of them. With an empty filter, the freshly registered
            ``intrinsic_parameter`` is itself selected; ``__call__`` skips it.
        :param said: also learn one scalar per selected parameter tensor.
        :param projection: "fastfood" or "random".
        :param device: "cpu" keeps the intrinsic vectors on CPU; any other value uses CUDA.
        :raises ValueError: if ``projection`` is not a supported type.
        :raises AssertionError: if ``said`` and ``intrinsic_dimension <= said_size``.
        """
        # Fail fast: an unknown projection would otherwise only surface as a None projection
        # and a confusing error during the first forward pass.
        if projection not in self._PROJECTIONS:
            raise ValueError(
                f"Unknown projection {projection!r}; expected one of {self._PROJECTIONS}")
        self.projection = projection
        self.name_base_localname = []
        self.initial_value = dict()
        self.projection_params = {}
        self.said = said
        self.device = device
        # Number of parameter tensors *before* intrinsic_parameter is registered below.
        self.said_size = len(list(module.named_parameters()))
        if self.said:
            print(f"Intrinsic Dimension: {intrinsic_dimension}")
            print(f"SAID Size: {self.said_size}")
            assert intrinsic_dimension > self.said_size
            # Reserve room for the SAID scalars (at most said_size + 1 of them, see `length`
            # below) so dense vector + scalars do not exceed the requested budget.
            intrinsic_dimension -= (self.said_size+1)

        self.intrinsic_dimension = intrinsic_dimension
        self.intrinsic_parameter = nn.Parameter(
            torch.zeros((intrinsic_dimension)).cpu() if device=="cpu" else torch.zeros((intrinsic_dimension)).cuda())
        module.register_parameter(
            "intrinsic_parameter", self.intrinsic_parameter)
        setattr(module, "intrinsic_parameter", self.intrinsic_parameter)

        length = 0
        for name, param in module.named_parameters():
            if param.requires_grad and (len(str_filter) == 0 or any([x in name for x in str_filter])):
                length += 1
                # Frozen snapshot (theta_0) of this parameter, on the intrinsic vector's device.
                self.initial_value[name] = v0 = (
                    param.clone().detach().requires_grad_(False).to(self.intrinsic_parameter.device)
                )
                DD = np.prod(v0.size())
                self.projection_params[name] = self.get_projection_params(DD, self.intrinsic_parameter.device)
                # Resolve a dotted name ("a.b.weight") to (owning submodule, "weight") so
                # __call__ can replace the attribute directly.
                base, localname = module, name
                while "." in localname:
                    prefix, localname = localname.split(".", 1)
                    base = base.__getattr__(prefix)
                self.name_base_localname.append((name, base, localname))
                if "intrinsic_parameter" not in name:
                    param.requires_grad_(False)
        if said:
            # One multiplicative scale per selected tensor, initialised to 1.
            self.intrinsic_parameter_said = nn.Parameter(
                torch.ones((length)).cpu() if device == "cpu" else torch.ones((length)).cuda())
            module.register_parameter(
                "intrinsic_parameter_said", self.intrinsic_parameter_said)
            setattr(module, "intrinsic_parameter_said",
                    self.intrinsic_parameter_said)

    def get_projection_params(self, DD, device):
        """Build the projection state for a parameter with ``DD`` elements.

        :raises ValueError: if ``self.projection`` is not a supported type.
        """
        if self.projection == "fastfood":
            return fastfood_vars(DD, device)
        elif self.projection == "random":
            return random_vars(DD, self.intrinsic_dimension, device)
        raise ValueError(f"Unknown projection {self.projection!r}")

    def move_to(self, x_tuple, target):
        """Move a tensor, or every tensor inside a sequence, to ``target``.

        For a sequence, non-tensor items are kept as-is and a tuple is returned.
        """
        if isinstance(x_tuple, torch.Tensor):
            return x_tuple.to(target)
        a = []
        for x in x_tuple:
            if isinstance(x, torch.Tensor):
                a.append(x.to(target))
            else:
                a.append(x)
        return tuple(a)

    def requires_to(self, x_tuple, target):
        """Set ``requires_grad=target`` on a tensor, or on every tensor inside a sequence."""
        if isinstance(x_tuple, torch.Tensor):
            x_tuple.requires_grad_(target)
            # Do not fall through: iterating a tensor walks its rows (and fails for 0-d tensors).
            return
        for x in x_tuple:
            if isinstance(x, torch.Tensor):
                x.requires_grad_(target)

    def projection_vars_requires_grad_(self, requires_grad):
        """Toggle ``requires_grad`` on every tensor stored in the projection parameters."""
        # Iterate the values (the per-parameter projection tuples), not (name, value) pairs,
        # so the tensors inside are actually reached.
        for params in self.projection_params.values():
            self.requires_to(params, requires_grad)

    def get_projected_param(self, intrinsic_vec, DD, projection_params, init_shape):
        """Project ``intrinsic_vec`` to ``DD`` entries and reshape it to ``init_shape``.

        :raises ValueError: if ``self.projection`` is not a supported type.
        """
        if self.projection == "fastfood":
            return fastfood_torched(intrinsic_vec, DD, projection_params).view(
                    init_shape
                )
        elif self.projection == "random":
            return random_torched(intrinsic_vec, projection_params).view(
                init_shape
            )
        raise ValueError(f"Unknown projection {self.projection!r}")

    def __call__(self, module, inputs):
        """Forward-pre-hook: rebuild every selected parameter from the intrinsic vector.

        Each selected attribute is replaced by ``initial_value + ray``, where ``ray`` is the
        projection of ``module.intrinsic_parameter``, scaled by its SAID scalar when enabled.
        The replacement is a plain (non-Parameter) tensor, so gradients flow back only to
        the intrinsic parameters.
        """
        index = 0  # position in intrinsic_parameter_said; skipped entries do not advance it
        with torch.enable_grad():
            for name, base, localname in self.name_base_localname:
                if localname == "intrinsic_parameter":
                    continue
                if self.device == "cpu":
                    # Follow the current attribute's device/dtype (e.g. after model.to(...)).
                    self.initial_value[name] = self.initial_value[name].to(
                        getattr(base, localname))
                    device_dtype = getattr(base, localname).dtype

                init_shape = self.initial_value[name].size()
                DD = np.prod(init_shape)
                if self.device == "cpu":
                    self.projection_params[name] = self.move_to(
                        self.projection_params[name], module.intrinsic_parameter.device)

                ray = self.get_projected_param(module.intrinsic_parameter, DD, self.projection_params[name], init_shape)
                if self.said:
                    ray = ray * self.intrinsic_parameter_said[index]
                if self.device == "cpu":
                    param = (self.initial_value[name] + ray).to(device_dtype)
                else:
                    param = (self.initial_value[name] + ray)
                delattr(base, localname)
                setattr(base, localname, param)
                index += 1

    @staticmethod
    def _check_no_existing_hook(module):
        """Raise RuntimeError if ``module`` already carries an IntrinsicDimensionLight pre-hook.

        Two hooks would both register under the single ``intrinsic_parameter`` attribute.
        """
        for hook in module._forward_pre_hooks.values():
            if isinstance(hook, IntrinsicDimensionLight):
                raise RuntimeError("Cannot register two intrinsic dimension hooks on "
                                   "the same module {}".format(type(module).__name__))

    @staticmethod
    def apply(module, intrinsic_dimension, output_dir, str_filter=set(), said=False, projection="fastfood", device="cpu"):
        """Create an IntrinsicDimensionLight for ``module``, register it as a forward-pre-hook and return it.

        :raises RuntimeError: if the module already has an IntrinsicDimensionLight hook.
        """
        IntrinsicDimensionLight._check_no_existing_hook(module)
        fn = IntrinsicDimensionLight(
            module, intrinsic_dimension, output_dir, str_filter, said, projection, device)
        module.register_forward_pre_hook(fn)
        return fn

    @staticmethod
    def apply_with_tensor(module, intrinsic_vector, str_filter=set()):
        """Like ``apply`` (no SAID, default projection/device), with the intrinsic vector initialised from ``intrinsic_vector``.

        :param intrinsic_vector: 1-D tensor whose length sets the intrinsic dimension.
        :raises RuntimeError: if the module already has an IntrinsicDimensionLight hook.
        """
        assert isinstance(intrinsic_vector,
                          torch.Tensor) and intrinsic_vector.ndim == 1

        IntrinsicDimensionLight._check_no_existing_hook(module)
        fn = IntrinsicDimensionLight(
            module, intrinsic_vector.size(0), None, str_filter=str_filter, said=False)
        # __call__ reads module.intrinsic_parameter (the registered Parameter), so copy the
        # values in; rebinding fn.intrinsic_parameter would never reach the forward pass.
        with torch.no_grad():
            fn.intrinsic_parameter.copy_(intrinsic_vector)
        module.register_forward_pre_hook(fn)
        return fn


def intrinsic_dimension(module, intrinsic_dimension,  output_dir, str_filter, projection, device="cpu"):
    """Attach a plain (non-SAID) intrinsic-dimension hook to ``module`` and return the module."""
    IntrinsicDimensionLight.apply(
        module,
        intrinsic_dimension,
        output_dir,
        str_filter=str_filter,
        said=False,
        projection=projection,
        device=device,
    )
    return module


def intrinsic_dimension_said(module, intrinsic_dimension,  output_dir, str_filter, projection, device="cpu"):
    """Attach an intrinsic-dimension hook with per-tensor SAID scalars to ``module`` and return the module."""
    IntrinsicDimensionLight.apply(
        module,
        intrinsic_dimension,
        output_dir,
        str_filter=str_filter,
        said=True,
        projection=projection,
        device=device,
    )
    return module


In [None]:
class IntrinsicSAIDFT(FullFT):
    """Fine-tune ``self.model`` through an intrinsic-dimension reparameterization with SAID scalars.

    After FullFT's constructor runs, every parameter of ``self.model`` that requires grad is
    frozen and re-expressed by ``IntrinsicDimensionLight`` (``said=True``) as
    ``theta_0 + s_i * P_i(z)``. Only the intrinsic vector ``z`` and the per-tensor scalars
    ``s_i`` remain trainable. Training itself is delegated unchanged to ``FullFT.Tune``.
    """

    def __init__(self, model_name, data_dir, real_folder, fake_folder, num_epochs, batch_size, learning_rate, use_wandb= False, model = None, processor = None,
                 intrinsic_dim=20000, projection="fastfood", device="cuda"):
        """
        :param intrinsic_dim: intrinsic budget passed to IntrinsicDimensionLight. With SAID,
            the dense vector gets ``intrinsic_dim - (number of parameter tensors + 1)`` entries,
            so this must exceed the model's parameter-tensor count. Try different values
            (e.g. 100, 500, 1000 on small models).
        :param projection: "fastfood" or "random".
        :param device: "cpu" or a CUDA device string for the intrinsic vectors.
        Remaining arguments are forwarded unchanged to FullFT.
        """
        super().__init__(model_name, data_dir, real_folder, fake_folder, num_epochs, batch_size, learning_rate, use_wandb, model, processor)
        self.method_name = 'IntrinsicSAID_FT'
        # NOTE(review): assumes FullFT builds its optimizer later (e.g. in Tune) from
        # self.model.parameters(). If it is created inside super().__init__, the intrinsic
        # parameters registered below would not be optimized -- TODO confirm.
        IntrinsicDimensionLight.apply(
            module=self.model,
            intrinsic_dimension=intrinsic_dim,
            output_dir=None,
            str_filter=set(),       # empty: reparameterize every parameter that requires grad
            said=True,              # also learn one scalar per reparameterized tensor
            projection=projection,  # "fastfood" or "random"
            device=device,          # "cpu" or CUDA
        )
        # Count on self.model: the `model` argument may be None (its default).
        self.total_params = sum(p.numel() for p in self.model.parameters())
        self.trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)

    def Tune(self):
        """
        Run FullFT's training loop unchanged.

        All settings (data_dir, real_folder, fake_folder, num_epochs, batch_size,
        learning_rate, use_wandb) come from the constructor.
        """
        return super().Tune()

In [None]:
! pip install git+https://github.com/purewater0901/deepfake-detection.git

from deepfake_detection.model.dfdet import DeepfakeDetectionModel
from deepfake_detection.config import Config

In [None]:
# Fetch the pretrained deepfake-detection checkpoint once and cache it under weights/.
MODEL_URL = "https://huggingface.co/yermandy/deepfake-detection/resolve/main/model.ckpt"
model_path = "weights/model.ckpt"
if not os.path.exists(model_path):
    print("Downloading model")
    os.makedirs("weights", exist_ok=True)
    # download_url_to_file raises on failure and moves the file into place only when the
    # download completes. A failed fetch therefore cannot leave an empty/truncated checkpoint
    # that the exists() check above would silently accept on the next run. `wget -O` run via
    # os.system created the file even on failure, and its exit status was ignored.
    torch.hub.download_url_to_file(MODEL_URL, model_path)
# NOTE(review): on PyTorch >= 2.6 torch.load defaults to weights_only=True. If
# ckpt["hyper_parameters"] holds non-allowlisted objects, this trusted file may need
# weights_only=False.
ckpt = torch.load(model_path, map_location="cpu")

model = DeepfakeDetectionModel(Config(**ckpt["hyper_parameters"]))
model.load_state_dict(ckpt["state_dict"])
print(model)

# Get preprocessing function
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14", use_fast=True)

In [None]:
def _count_params(m):
    """Return (total, trainable) numbers of parameter elements in module ``m``."""
    total = sum(p.numel() for p in m.parameters())
    trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
    return total, trainable


# Counts as loaded from the checkpoint (some tensors may arrive frozen).
total_params, trainable_params = _count_params(model)
print(total_params, trainable_params)

# Unfreeze everything; the fine-tuning method decides what to freeze afterwards.
for param in model.parameters():
    param.requires_grad = True
total_params, trainable_params = _count_params(model)
print(total_params, trainable_params)

In [None]:
# Experiment settings forwarded to the fine-tuner. The `model` and `processor` loaded above
# are passed explicitly alongside them.
config = dict(
    model_name="vit_large_patch14_clip_224.openai",
    data_dir="/root/.cache/kagglehub/datasets/katsuyamucb/madde-dataset/versions/13",
    real_folder="Real_split",
    fake_folder="All_fakes_split",
    num_epochs=10,
    batch_size=16,
    learning_rate=1e-4,
    use_wandb=False,
)

intrinsic = IntrinsicSAIDFT(model=model, processor=processor, **config)
# Optional: evaluate on a single fake subset, e.g. `.set_TestFolder('StyleGAN_split')`
# (previously called on a FullFT-based tuner; assumed inherited here -- confirm).
tuned_model = intrinsic.Experiment("clip_pretrained")