In [None]:
import os
import numpy as np
import pytorch_lightning as pl
import torch
from dfdetect.config import Paths, CLA
from dfdetect.data_loaders import (
    DFDC_preprocessed_single_frames,
    Oversampled,
    DFDC,
    CelebDFV2_preprocessed,
    CelebDFV2,
)
from dfdetect.models.dfdetect_model import (
    FeatureType,
    SteganalysisModel,
    TimmModel,
    TNTPl,
    SrnetPl,
    SRNetDouble,
)
import torch.nn.functional as F
from dfdetect.utils import CropResize, Slurm
import dfdetect.utils as utils

from torchvision import transforms
from dfdetect.data_augmentation import DataAugmentations
from copy import copy

import pytorch_grad_cam as gc
import matplotlib.pyplot as plt
from dfdetect.data_loaders import Oversampled

In [None]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell runs, so edits to the dfdetect package are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Sanity check shown as the cell output: the cells below move the model to
# "cuda" and build Grad-CAM with use_cuda=True, so this should display True.
torch.cuda.is_available()

In [None]:
def setup_args():
    """Register every option understood by this notebook on ``CLA`` and parse them.

    Options are registered in the order they appear in ``defaults``; the two
    options listed in ``allowed_values`` are additionally registered with a
    ``choices`` restriction. ``CLA.parse()`` is called once everything is
    registered (presumably reading ``sys.argv`` -- confirm against ``CLA``).
    """
    # Option name -> default value.
    defaults = {
        "project_name": "DFDC-full-single-images",
        "dataset": "DFDC",
        "seed": 0x1B,
        "test": False,
        "valid": False,
        "test_all_frames": False,
        "debug": False,
        "model": "stegano",
        "stegano_spatial_model": "legacy_seresnet18",
        "target_image_size": 128,
        "spatial_features": True,
        "spectral_features": True,
        "dct_features": True,
        "rgb_to_ycc": True,
        "rgb_to_gray": False,
        "batch_size": 64,
        "accumulate_grad_batches": 1,
        "decision_threshold": 0.5,  # When predicting accuracy and f1 score
        "nb_epochs": 200,
        "act_function": "relu",
        "srnet_double_nb_features": 512,
        "srnet_double_num_type2_layers": 5,
        "srnet_double_type_3_layer_sizes": [16, 64, 128, 256],
        "srnet_double_type_2_layer_feat_size": 16,
        "srnet_double_type_1_kernel_size_spatial": 3,
        "srnet_double_type_1_kernel_size_spectral": 2,
        "srnet_double_with_attention": False,
        "plot_ttest": False,
        "confidence_to_stop": 0.05,
    }
    # Options whose value is restricted to a fixed set.
    allowed_values = {
        "dataset": ["DFDC", "celebdfv2"],
        "act_function": ["relu", "gelu"],
    }

    for option_name, default_value in defaults.items():
        if option_name in allowed_values:
            CLA.register(
                option_name, default_value, choices=allowed_values[option_name]
            )
        else:
            CLA.register(option_name, default_value)

    CLA.parse()


def get_transforms(set_type):
    """Build the per-frame preprocessing pipeline for a single RGB image.

    The resulting ``transforms.Compose`` applies, in order:

    1. ``DataAugmentations()`` -- only when ``set_type == "train"``, so the
       augmentations operate on the image before any colour conversion;
    2. ``utils.rgb_to_ycc`` or ``utils.rgb_to_gray`` -- when the matching
       ``CLA`` flag is set;
    3. ``CropResize(CLA.target_image_size)``, ``ToTensor``, conversion to
       float and channel-wise normalisation.

    Args:
        set_type: Name of the split. Only the value ``"train"`` changes the
            result (it prepends the data augmentations).

    Returns:
        A ``transforms.Compose`` wrapping the steps above.

    Raises:
        AssertionError: If both ``CLA.rgb_to_ycc`` and ``CLA.rgb_to_gray`` are set.
    """
    both_color_modes = CLA.rgb_to_ycc & CLA.rgb_to_gray
    assert not both_color_modes, "Only one of rgb_to_ycc and rgb_to_gray can be True"

    # Means and variance computed from the training set with the function
    # utils.compute_running_stats(train_set). The YCC and gray branches store
    # variances, hence the square root; the fallback branch stores standard
    # deviations directly.
    if CLA.rgb_to_ycc:
        color_conversion = utils.rgb_to_ycc
        means = torch.tensor([0.3443, 0.5621, 0.4715])
        stds = torch.tensor([0.0377, 0.0017, 0.0010]).sqrt()
    elif CLA.rgb_to_gray:
        color_conversion = utils.rgb_to_gray
        means = torch.tensor([0.5])
        stds = torch.tensor([0.0833]).sqrt()
    else:
        color_conversion = None
        means = torch.tensor([0.485, 0.456, 0.406])
        stds = torch.tensor([0.229, 0.224, 0.225])

    # Steps shared by every split.
    core_steps = [
        CropResize(CLA.target_image_size),
        transforms.ToTensor(),
        transforms.ConvertImageDtype(torch.float),
        transforms.Normalize(mean=means, std=stds),
    ]

    # Steps that must run before the core ones, listed in execution order.
    leading_steps = [] if color_conversion is None else [color_conversion]
    if set_type == "train":  # Add data augmentation for train set in RGB domain
        leading_steps = [DataAugmentations()] + leading_steps

    return transforms.Compose(leading_steps + core_steps)


def get_model():
    """Instantiate the network selected by ``CLA.model``.

    The lower-cased model name picks the class and its keyword arguments:
    ``"stegano"`` / ``"seresnet"`` -> ``SteganalysisModel`` (with and without
    ``original_stegano``), ``"tnt"`` -> ``TNTPl`` (requires a 224 px target
    image size), ``"srnet"`` -> ``SrnetPl``, ``"srnetdouble"`` ->
    ``SRNetDouble``; any other name is handed to ``TimmModel`` as
    ``model_name``. ``decision_threshold`` is always added to the arguments.

    Returns:
        The model, restored with ``load_from_checkpoint`` when
        ``Paths.previous_checkpoint`` is set and exists on disk, otherwise
        freshly constructed.

    Raises:
        AssertionError: If ``"tnt"`` is requested with a target image size
            other than 224.
    """
    requested = CLA.model.lower()

    # Fold the enabled feature families into a single bit mask.
    # NOTE(review): the `spectral_features` option registered in setup_args is
    # never consulted here, only `dct_features` and `spatial_features` are --
    # confirm this is intended.
    features = 0
    for enabled, flag in (
        (CLA.dct_features, FeatureType.DCT),
        (CLA.spatial_features, FeatureType.YCC),
    ):
        if enabled:
            features |= flag

    def steganalysis_kwargs(original_stegano):
        # Arguments shared by the two SteganalysisModel variants.
        return dict(
            features=features,
            original_stegano=original_stegano,
            spatial_model=CLA.stegano_spatial_model,
        )

    def srnet_branch_kwargs(type_1_kernel_size):
        # Per-branch SRNet settings; the two branches differ only in the
        # kernel size of their type-1 layers.
        return {
            "act_function": F.gelu if CLA.act_function == "gelu" else F.relu,
            "nb_features": CLA.srnet_double_nb_features,
            "type_3_layer_sizes": [
                int(size) for size in CLA.srnet_double_type_3_layer_sizes
            ],
            "num_type2_layers": CLA.srnet_double_num_type2_layers,
            "type_2_layer_feat_size": CLA.srnet_double_type_2_layer_feat_size,
            "type_1_kernel_size": type_1_kernel_size,
        }

    if requested in ("stegano", "seresnet"):
        model_class = SteganalysisModel
        model_args = steganalysis_kwargs(original_stegano=(requested == "stegano"))
    elif requested == "tnt":
        assert CLA.target_image_size == 224
        model_class = TNTPl
        model_args = dict(im_size=CLA.target_image_size)
    elif requested == "srnet":
        model_class = SrnetPl
        model_args = dict(in_channels=1 if CLA.rgb_to_gray else 3)
    elif requested == "srnetdouble":
        model_class = SRNetDouble
        model_args = dict(
            features=features,
            srnet_args_spatial=srnet_branch_kwargs(
                CLA.srnet_double_type_1_kernel_size_spatial
            ),
            srnet_args_spectral=srnet_branch_kwargs(
                CLA.srnet_double_type_1_kernel_size_spectral
            ),
            with_attention=CLA.srnet_double_with_attention,
        )
    else:
        # Unknown names are passed through to TimmModel unchanged.
        model_class = TimmModel
        model_args = dict(model_name=requested)

    model_args["decision_threshold"] = CLA.decision_threshold

    checkpoint_path = Paths.previous_checkpoint
    if checkpoint_path is not None and os.path.exists(checkpoint_path):
        return model_class.load_from_checkpoint(checkpoint_path, **model_args)
    # No usable checkpoint: fall back to a freshly initialised model.
    return model_class(**model_args)


def get_dataset():
    """Build the dataset(s) selected by the ``CLA`` options.

    The return value depends on the options:

    * ``CLA.dataset == "DFDC"``
        - ``test_all_frames``: a ``DFDC`` dataset over the test set (when
          ``CLA.test``) or the validation set, created with ``is_test=True``.
        - ``test`` or ``valid``: a ``DFDC_preprocessed_single_frames`` dataset
          over the matching split, using the ``"test"`` transforms.
        - otherwise (training): a tuple ``(train_set, val_set)`` where the
          training set is wrapped in ``Oversampled``.
    * ``CLA.dataset == "celebdfv2"``
        - ``test``: the dataset itself.
        - ``valid``: the validation part of an 80/20 split.
        - otherwise: a tuple ``(Oversampled(train_set), val_set)``.
    * any other dataset name: ``None``.
    """
    if CLA.dataset == "DFDC":
        if CLA.test_all_frames:
            videos_root = (
                Paths.DFDC.test_set if CLA.test else Paths.DFDC.validation_set
            )
            return DFDC(videos_root, is_test=True)

        if CLA.test or CLA.valid:
            if CLA.test:
                frames_root = Paths.DFDC.preprocessed_dataset_single_frames_test
            else:
                frames_root = Paths.DFDC.preprocessed_dataset_single_frames_val
            return DFDC_preprocessed_single_frames(
                frames_root, transforms=get_transforms("test")
            )

        # Training: oversampled training frames plus the validation frames.
        train_frames = DFDC_preprocessed_single_frames(
            Paths.DFDC.preprocessed_dataset_single_frames_train,
            transforms=get_transforms("train"),
        )
        # balance classes for training with oversampling
        balanced_train = Oversampled(train_frames)
        val_frames = DFDC_preprocessed_single_frames(
            Paths.DFDC.preprocessed_dataset_single_frames_val,
            transforms=get_transforms("val"),
        )
        return balanced_train, val_frames

    if CLA.dataset != "celebdfv2":
        return None

    if CLA.test_all_frames:
        dataset_cls = CelebDFV2
        dataset_root = Paths.CelebDFV2.dataset_path
        frame_transforms = None
    else:
        dataset_cls = CelebDFV2_preprocessed
        dataset_root = Paths.CelebDFV2.preprocessed_path
        frame_transforms = get_transforms("test")

    dataset = dataset_cls(
        dataset_root, is_train=not CLA.test, transforms=frame_transforms
    )
    if CLA.test:
        return dataset

    train_set, val_set = dataset.split_train_val(ratio=0.8)
    # Give the training subset its own (shallow) copy of the dataset so that
    # switching it to the training transforms does not modify `dataset`.
    train_set.dataset = copy(dataset)
    train_set.dataset.transforms = get_transforms("train")
    if CLA.valid:
        return val_set
    return Oversampled(train_set), val_set

In [None]:
import sys

# Replace the arguments the kernel process was started with by the options
# this analysis needs; the program name in sys.argv[0] is left untouched and
# the list object itself is modified in place.
sys.argv[1:] = [
    "--no_spectral_features",
    "--model=srnet",
    "--valid",
]

In [None]:
# Register all options and parse them (presumably from the sys.argv list
# prepared in the previous cell -- confirm against CLA.parse).
setup_args()
# Display the selected dataset name as the cell output.
CLA.dataset

In [None]:
# Checkpoint that get_model() restores. If this file does not exist,
# get_model() does not fail: it silently builds a freshly initialised model,
# so make sure the path is valid before interpreting the Grad-CAM figures.
Paths.previous_checkpoint = "./checkpoints/5090479/checkpoint_epoch=25-val_accuracy_epoch=0.76066-val_auroc_epoch=0.84424.ckpt"
Paths.previous_checkpoint

In [None]:
# Build (or restore) the model, move it to the GPU and switch it to
# evaluation mode in one chained expression.
device = torch.device("cuda")
model = get_model().to(device).eval()

In [None]:
batch_size = 64
num_workers = 2
data_loader_args = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True,
)

dataset = get_dataset()
# Detach the transform pipeline from the dataset: indexing the dataset then
# yields the sample without these transforms, and `all_transforms` is applied
# explicitly wherever the model input is needed.
all_transforms, dataset.transforms = dataset.transforms, None

In [None]:
# dataset = Oversampled(dataset)

In [None]:
# model.spatial_feat_extractor = model.ycc_feat_extractor
# model.spectral_feat_extractor = model.dct_feat_extractor

In [None]:
# model.ycc_feat_extractor.feat_extractor[9]

In [None]:
# len(model.ycc_feat_extractor.feat_extractor)

In [None]:
# Attribute names of the layers to visualise, in order. After the two plain
# layers "layer1" and "layer2", every name is "layer<block><sub>": blocks 3-7
# and 12 have two sub-layers, blocks 8-11 have three.
layers = ["layer1", "layer2"] + [
    f"layer{block}{sub}"
    for block, sub_count in (
        (3, 2),
        (4, 2),
        (5, 2),
        (6, 2),
        (7, 2),
        (8, 3),
        (9, 3),
        (10, 3),
        (11, 3),
        (12, 2),
    )
    for sub in range(1, sub_count + 1)
]

In [None]:
# Produce one Grad-CAM figure per entry of `layers`: each figure has one row
# per sample, with the untransformed image on the left and the class
# activation map of that layer on the right. Figures are written to
# `output_dir` as PNG files.
# (The DCT-branch visualisation that used to fill two extra columns is
# disabled; only the image-domain CAM is drawn.)
cls = gc.GradCAM
num_samples = 6
output_dir = "grad_cam_dfdc_srnet"

# savefig() does not create missing directories, so make sure the target
# exists before the first figure is written.
os.makedirs(output_dir, exist_ok=True)

# Pick the samples once, with a fixed seed, so every layer is visualised on
# the same images. Indices are drawn with replacement (np.random.choice
# default), so a sample may appear more than once.
np.random.seed(0x1B)
sample_indices = np.random.choice(len(dataset), size=num_samples)
# Load each sample a single time instead of re-reading it for every layer.
samples = [dataset[sample_index] for sample_index in sample_indices]

for layer_nb, layer_name in enumerate(layers):
    layer = getattr(model.model, layer_name)
    cam_ycc = cls(model=model, target_layers=[layer], use_cuda=True)
    cam_ycc.batch_size = 32

    fig, axs = plt.subplots(num_samples, 2, figsize=(4 * 2, num_samples * 4))
    fig.suptitle(f"{cls.__name__} of feature extractors at layer {layer_nb}")

    for i, (img, label) in enumerate(samples):
        # Left column: the image as returned by the dataset, i.e. before
        # `all_transforms` is applied.
        axs[i, 0].imshow(img)
        # NOTE(review): assumes a truthy label means "real" -- confirm
        # against the dataset's label convention.
        axs[i, 0].set_title("Image is " + ("real" if label else "fake"))

        # Right column: activation map for the transformed image
        # (a batch of one).
        img_tensor = all_transforms(img).unsqueeze(0)
        cam_ycc_out = cam_ycc(img_tensor, eigen_smooth=True)
        axs[i, 1].imshow(cam_ycc_out[0])
        axs[i, 1].set_title(f"{cls.__name__} of image")

    fig.tight_layout()
    fig.savefig(
        f"{output_dir}/{cls.__name__}_layer_{layer_nb}.png", bbox_inches="tight"
    )
    # Render the figure in the notebook now, then release it: keeping one
    # open figure per layer until the end of the cell wastes memory and
    # triggers matplotlib's "too many open figures" warning.
    plt.show()
    plt.close(fig)