In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from dfdetect.data_loaders import DFDC_preprocessed
from tqdm.auto import tqdm

In [None]:
# Load IPython's autoreload extension and select mode 2, which re-imports
# edited modules before each cell runs, so changes to the project package are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Apply seaborn's default theme to the matplotlib figures drawn below.
sns.set()

In [None]:
# Instantiate the project's DFDC_preprocessed loader on ./dfdc_preprocessed
# (a relative path — presumably resolved against the notebook's working
# directory; confirm in dfdetect.data_loaders).
data = DFDC_preprocessed("./dfdc_preprocessed")

In [None]:
# Item 0 of the dataset unpacks into two values, named here as a video and its
# label; the cell displays the video's shape.
vid_0, lab_0 = data[0]
vid_0.shape

In [None]:
from torchvision import transforms
from dfdetect.data_loaders import DFDC_preprocessed
from dfdetect.utils import CropResize, FrameBasedTransforms, rgb_to_ycc

# Ordered steps handed to FrameBasedTransforms (presumably applied to every
# frame of a video independently — confirm in dfdetect.utils).
per_frame_steps = [
    CropResize(128),
    rgb_to_ycc,
    transforms.ToTensor(),
    transforms.ConvertImageDtype(torch.float),
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
frame_pipeline = FrameBasedTransforms(transforms.Compose(per_frame_steps))

# The video-level pipeline has a single stage for now; it is still wrapped in
# a Compose, exactly as before.
all_transforms = transforms.Compose([frame_pipeline])

In [None]:
# Run the transform pipeline on the first video and display the shape of the
# result.
vid_0_transformed = all_transforms(vid_0)
vid_0_transformed.shape

In [None]:
# YCbCr components: draw channels 0, 1 and 2 of the first transformed frame
# side by side, each as a grayscale image.
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(12, 4))
for i, ax in enumerate(axs):
    ax.imshow(vid_0_transformed[0, i], cmap="gray")

In [None]:
from numpy import r_
import scipy
import scipy.fftpack


def block_dct(frame, block_size=8):
    """Compute the blockwise 2-D DCT of every channel of ``frame``.

    Each channel is tiled into ``block_size`` x ``block_size`` blocks and every
    block is replaced by its orthonormal 2-D DCT-II. When a spatial dimension
    is not a multiple of ``block_size``, the blocks on the bottom/right edge
    are smaller and are transformed at their own (smaller) size.

    Block DCT adapted from
    https://inst.eecs.berkeley.edu/~ee123/sp16/Sections/JPEG_DCT_Demo.html

    Parameters
    ----------
    frame : array-like, shape (channels, height, width)
        Input image; anything ``np.asarray`` can convert to a float array.
    block_size : int, default 8
        Side length of the square DCT blocks. Must be positive.

    Returns
    -------
    numpy.ndarray
        Float64 array with the same shape as ``frame`` holding the DCT
        coefficients; block (i, j) of the output corresponds to block (i, j)
        of the input.

    Raises
    ------
    ValueError
        If ``frame`` is not 3-dimensional or ``block_size`` is not positive.
    """
    pixels = np.asarray(frame, dtype=float)
    if pixels.ndim != 3:
        raise ValueError(
            f"frame must have shape (channels, height, width), got {pixels.shape}"
        )
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    def dct2(a):
        # Separable 2-D DCT: transform the rows axis, then the columns axis.
        return scipy.fftpack.dct(
            scipy.fftpack.dct(a, axis=0, norm="ortho"), axis=1, norm="ortho"
        )

    _, height, width = pixels.shape
    # Allocate once for all channels; every element is overwritten below.
    dct = np.zeros(pixels.shape)

    for c, component in enumerate(pixels):
        for i in range(0, height, block_size):
            for j in range(0, width, block_size):
                # Slicing past the edge simply yields a smaller block.
                rows = slice(i, i + block_size)
                cols = slice(j, j + block_size)
                dct[c, rows, cols] = dct2(component[rows, cols])

    return dct

In [None]:
def dct2(a):
    """Return the orthonormal 2-D DCT of ``a`` over its first two axes.

    The transform is separable, so it is computed as a 1-D DCT along axis 0
    followed by a 1-D DCT along axis 1, both with ``norm="ortho"``.
    """
    along_axis0 = scipy.fftpack.dct(a, axis=0, norm="ortho")
    both_axes = scipy.fftpack.dct(along_axis0, axis=1, norm="ortho")
    return both_axes


from dfdetect.utils import dct_2d as dct2_torch

In [None]:
# DCT of the first transformed frame via the project's torch implementation
# (dfdetect.utils.dct_2d), with norm="ortho" to match the scipy helper.
dct_with_torch = dct2_torch(vid_0_transformed[0], norm="ortho")

In [None]:
# Reference result: scipy-based dct2 applied to each of the three channels of
# the first transformed frame, stacked back along a leading channel axis.
first_frame_np = vid_0_transformed[0].numpy()
per_channel_dct = [dct2(first_frame_np[i]) for i in range(3)]
dct_with_scipy = np.stack(per_channel_dct, axis=0)

In [None]:
# Display the shape of the stacked scipy result.
dct_with_scipy.shape

In [None]:
# Confirming that dct with torch and scipy is similar: the displayed value is
# the fraction of coefficients that agree within atol=1e-6 (1.0 = all agree).
np.mean(np.isclose(dct_with_torch.numpy(), dct_with_scipy, atol=1e-6))

In [None]:
from einops import rearrange

dct_patch_size = 8
# Both rearrange calls share the same tile-side lengths.
tile_axes = dict(p1=dct_patch_size, p2=dct_patch_size)

frame = vid_0_transformed[0]
# frame = torch.zeros(3, 512, 512)

# Split height and width into an (h, w) grid of p1 x p2 tiles.
tiles = rearrange(frame, "c (h p1) (w p2) -> c h w p1 p2", **tile_axes)
# NOTE(review): unlike the earlier comparison cell, no norm= is passed here —
# confirm dct_2d's default normalisation is the intended one.
patchs_dct = dct2_torch(tiles)
# Fold channel and within-tile coefficient axes into one leading axis, leaving
# one value per tile position in the last two axes.
patchs = rearrange(patchs_dct, "c h w p1 p2 -> (c p1 p2) h w", **tile_axes)

In [None]:
# Display the shape of the rearranged patch-DCT tensor.
patchs.shape