In [3]:
import torch
from torchvision.transforms import v2
import sys
sys.path.append('src')

In [13]:
from src.models.frame import FrameModel

# ff_attribution -> 224
# swinv2_faceswap -> 256
rs_size = 224
interpolation = 3
inference_transforms = v2.Compose([
    v2.ToImage(),
    v2.Resize(rs_size, interpolation=interpolation, antialias=False),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
target_transforms = lambda x: torch.tensor(x, dtype=torch.float32)

In [14]:
# ff_attribution or swinv2_faceswap
model = FrameModel.load_from_checkpoint("./checkpoints/ff_attribution.ckpt").eval()

In [15]:
# inference on dummy data
# (B, C, H, W)
test_image = torch.randn(1, 3, 224, 224).to(model.device)
test_image = inference_transforms(test_image)

with torch.no_grad():
    output = model(test_image)
    print(output)
    print(output.shape)

tensor([[0.0613, 0.5350, 0.1486, 0.0287, 0.2265]], device='cuda:1')
torch.Size([1, 5])


In [17]:
# inference on a dataset

import numpy as np
from src.data.datasets import DeepfakeDataset
ds = DeepfakeDataset(
    "./src/data/csvs/ff_test.csv",
    "/fssd8/user-data/spirosbax/data/xai_test_data.lmdb",
    transforms=inference_transforms,
    target_transforms=target_transforms,
    task="binary"
)

idx = np.random.randint(0, len(ds))
with torch.no_grad():
    frame, label = ds[idx]
    frame = frame.to(model.device)
    output = model(frame.unsqueeze(0))
    print(output)
    print(output.shape)
    print(label)

tensor([[5.6502e-02, 9.4350e-01, 1.5102e-10, 5.7913e-10, 4.3230e-08]],
       device='cuda:1')
torch.Size([1, 5])
tensor(1.)
