In [5]:
import torch
from torchvision.transforms import v2
import sys
# Add the local `src` directory to the module search path. The entry is a
# relative path, so it resolves against the kernel's current working directory;
# start the notebook from the directory that contains `src`.
# NOTE(review): later cells import through the `src.` package prefix, so this
# entry is presumably needed for imports made inside the package itself - confirm.
sys.path.append('src')

In [6]:
from src.models.frame import FrameModel

# Resize target per checkpoint:
#   ff_attribution  -> 224
#   swinv2_faceswap -> 256
rs_size = 224
# Integer interpolation code handed to v2.Resize; 3 is Pillow's constant for
# bicubic resampling, which torchvision accepts in place of its enum.
interpolation = 3

# Individual preprocessing steps, named so the pipeline below reads in order.
_resize_step = v2.Resize(rs_size, interpolation=interpolation, antialias=False)
_to_float_step = v2.ToDtype(torch.float32, scale=True)
_normalize_step = v2.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Inference-time preprocessing: wrap as image -> resize -> float32 -> normalize.
inference_transforms = v2.Compose(
    [v2.ToImage(), _resize_step, _to_float_step, _normalize_step]
)
def target_transforms(x):
    """Build a float32 tensor from the raw label ``x`` using ``torch.tensor``."""
    return torch.tensor(x, dtype=torch.float32)

  from .autonotebook import tqdm as notebook_tqdm


In [9]:
# Checkpoint to restore: ff_attribution or swinv2_faceswap.
# Keep `rs_size` in the transforms cell in sync with the checkpoint chosen here.
checkpoint_path = "./checkpoints/ff_attribution.ckpt"
model = FrameModel.load_from_checkpoint(checkpoint_path)
# Switch to evaluation mode; as in the chained form, `model` is bound to
# whatever eval() returns.
model = model.eval()

model.safetensors: 100%|██████████| 28.8M/28.8M [00:01<00:00, 15.9MB/s]


In [10]:
# Smoke test: run a single random batch through the preprocessing and the model.
# Tensor layout is (B, C, H, W).
dummy_shape = (1, 3, 224, 224)
test_image = inference_transforms(torch.randn(*dummy_shape).to(model.device))

# Gradients are not needed for a forward-only check.
with torch.no_grad():
    output = model(test_image)

print(output)
print(output.shape)

tensor([[0.0773, 0.5556, 0.1229, 0.0376, 0.2067]], device='cuda:1')
torch.Size([1, 5])


In [16]:
# inference on a dataset
import numpy as np
from src.data.datasets import DeepfakeDataset

# Dataset built from the frame CSV and the LMDB store, using the same
# preprocessing and label conversion defined earlier in the notebook.
ds = DeepfakeDataset(
    "./faceforensics_frames.csv",
    "./ff.lmdb",
    transforms=inference_transforms,
    target_transforms=target_transforms,
    task="multiclass",
)

# Draw one random sample index in [0, len(ds)).
idx = np.random.randint(len(ds))

with torch.no_grad():
    frame, label = ds[idx]
    frame = frame.to(model.device)
    # The model is fed a batch, so add a leading batch dimension of size 1.
    batch = frame.unsqueeze(0)
    output = model(batch)

print(output)
print(output.shape)
print(label)

tensor([[1.4107e-06, 9.9975e-01, 2.5263e-04, 2.1636e-07, 8.3986e-08]],
       device='cuda:1')
torch.Size([1, 5])
tensor(1.)
