Tous les imports sont ici.

In [3]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.io as io
import os
import json
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import timm
import wandb

from PIL import Image
import torchvision.transforms as transforms

FileNotFoundError: [Errno 2] No such file or directory: 'C:\\Users\\tomge/train_dataset\\metadata.json'

Deux fonctions utiles :

In [None]:
def resize_data(data, new_height, new_width, x=0, y=0, height=None, width=None):
    '''
    Crop a window out of an image tensor and resize it to (new_height, new_width).

    data is a tensor of shape (..., channels, image_height, image_width);
    the result has shape (..., channels, new_height, new_width).

    The window has its top-left corner at (x, y) and size (height, width).
    When height / width are omitted the window extends to the bottom / right
    edge of the image. The part of the window lying outside the image is
    clipped by the slicing step.

    Before the final resize the window is placed on a canvas that has the
    target aspect ratio, so the content is not stretched. The canvas is never
    smaller than the requested window; torchvision's CenterCrop pads when the
    requested size exceeds its input.
    '''
    src_h, src_w = data.shape[-2], data.shape[-1]
    if height is None:
        height = src_h - y
    if width is None:
        width = src_w - x

    # Grow exactly one side of the window so that canvas_h / canvas_w matches
    # the target ratio (no change when the ratios already agree).
    target_ratio = new_height/new_width
    canvas_h, canvas_w = height, width
    if height/width > target_ratio:
        canvas_w = int(height / target_ratio)
    elif height/width < target_ratio:
        canvas_h = int(width * target_ratio)

    # Slice the window, clipping it to the image bounds.
    bottom = min(y + height, src_h)
    right = min(x + width, src_w)
    window = data[..., y:bottom, x:right]

    on_canvas = transforms.CenterCrop((canvas_h, canvas_w))(window)
    return transforms.Resize((new_height, new_width))(on_canvas)

def display_image(img) :
    '''
    Show the image tensor img with matplotlib.

    The first axis of img is moved to the last position before plotting
    (presumably channels-first -> channels-last; confirm with the callers).
    '''
    plt.imshow(img.permute(1, 2, 0))

Création du dataset. On ne garde que nb_frames images par vidéo, et on traite les vidéos image par image.

ATTENTION : les images des vidéos sont déjà resized en 256x256, vous pouvez bien sûr modifier cette taille selon ce que prend votre modèle en entrée.

In [None]:
# Number of frames sampled from each video.
nb_frames = 10

class TrainVideoDataset(Dataset):
    '''
    Labelled training videos stored as .mp4 files directly under root_dir.

    root_dir must contain a metadata.json file mapping each video file name
    to its label. Only .mp4 files that have an entry in metadata.json are used.

    Each item is a pair (video, label):
      - video: nb_frames frames spread evenly over the clip, each passed
        through resize_data(img, 256, 256) and divided by 255;
      - label: scalar float tensor, 1.0 when the metadata value is 'FAKE',
        0.0 otherwise.
    '''
    def __init__(self, root_dir):
        self.root_dir = root_dir
        with open(os.path.join(root_dir, "metadata.json"), 'r') as file:
            raw_labels = json.load(file)
        # Encode the labels once: 'FAKE' -> 1.0, anything else -> 0.0.
        self.data = {
            name: torch.tensor(1.0 if tag == 'FAKE' else 0.0)
            for name, tag in raw_labels.items()
        }
        self.video_files = [f for f in os.listdir(root_dir) if f.endswith('.mp4') and f in self.data]

    def __len__(self):
        return len(self.video_files)

    def __getitem__(self, idx):
        '''
        Load video number idx and return (frames, label).

        Raises ValueError when no frame could be decoded from the file.
        '''
        video_path = os.path.join(self.root_dir, self.video_files[idx])
        video, _audio, _info = io.read_video(video_path, pts_unit='sec')

        # Move the last axis to position 1 (presumably (T, H, W, C) -> (T, C, H, W),
        # as documented for torchvision.io.read_video).
        video = video.permute(0,3,1,2)
        length = video.shape[0]
        if length == 0:
            raise ValueError(f"No frame could be read from {video_path}")

        # Evenly spaced indices from the first to the last frame (inclusive).
        # Integer arithmetic keeps every index within [0, length - 1]; the
        # previous formula i * (length // (nb_frames - 1)) produced the index
        # `length` whenever length was a multiple of nb_frames - 1, and divided
        # by zero for nb_frames == 1.
        last_frame = length - 1
        spans = max(nb_frames - 1, 1)
        frame_ids = [(i * last_frame) // spans for i in range(nb_frames)]
        video = video[frame_ids]

        video = torch.stack([resize_data(img, 256, 256)/255 for img in video])

        label = self.data[self.video_files[idx]]

        return video, label

    
# Training set read from ~/train_dataset; that folder must contain metadata.json next to the .mp4 files.
dataset=TrainVideoDataset(os.path.expanduser("~/train_dataset"))

Exemple d'utilisation de la fonction resize_data.

In [None]:
# Take the first sample of the training set and its first sampled frame.
video, label = dataset[0]
img=video[0]

print(img.shape)
# Select the window whose top-left corner is (x=50, y=40), with height 200 and
# width 190, then resize it to height 220 x width 300.
img=resize_data(img, 220 , 300, 50, 40, 200, 190)
display_image(img)

Modèle.

In [None]:
class DeepfakeDetector(nn.Module):
    '''
    Minimal baseline model: flatten -> linear layer -> batch norm -> affine shift.

    Expects inputs that flatten to nb_frames * 3 * 256 * 256 values per sample
    and returns one value per sample, computed as (normalised score + 1) / 2.
    '''
    def __init__(self):
        super().__init__()
        in_features = nb_frames * 3 * 256 * 256
        self.dense = nn.Linear(in_features, 1)
        self.flat = nn.Flatten()
        self.norm = nn.BatchNorm1d(1)

    def forward(self, x):
        # Flatten everything but the batch axis, score it, then normalise the score.
        score = self.norm(self.dense(self.flat(x)))
        return (score + 1) / 2

Boucle d'entraînement.

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loss_fn = nn.MSELoss()
model = DeepfakeDetector().to(device)

# Hyper-parameters: defined once and reused for the optimizer, the loader and
# the wandb config so that the logged values always match the actual run.
learning_rate = 0.001
nb_epochs = 5
batch_size = 2

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

run = wandb.init(
    project="automathon",
    name="nom-de-votre-equipe",
    config={
        "learning_rate": learning_rate,
        "architecture": "-",
        "dataset": "DeepFake Detection Challenge",
        "epochs": nb_epochs,
        "batch_size": batch_size,
    },
)

# drop_last=True: the model ends with BatchNorm1d, which raises in training
# mode when a batch holds a single sample (e.g. the last batch of an
# odd-sized dataset).
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

model.train()
for epoch in range(nb_epochs) :
    pbar = tqdm(loader, desc="Epoch {}".format(epoch), ncols=0)
    for sample in pbar:
        optimizer.zero_grad()

        X, label = sample

        X = X.to(device)
        # Labels come as shape (batch,); add an axis to match the (batch, 1) predictions.
        label = label.to(device).unsqueeze(dim=1)
        label_pred = model(X)
        # nn.MSELoss expects (input, target).
        loss = loss_fn(label_pred, label)
        loss.backward()
        optimizer.step()

        # Let wandb auto-increment its step so every batch gets its own point;
        # passing step=epoch would write every batch of an epoch to the same step.
        run.log({"training_loss": loss.item(), "epoch": epoch})
        # Show the loss in the progress bar instead of printing one line per batch.
        pbar.set_postfix(loss=f"{loss.item():.4f}")
    pbar.close()

Boucle de test : crée la liste des prédictions de votre modèle sur les données de test.

ATTENTION : on utilise un autre DataLoader (car on n'a pas accès aux labels pour le test), où on resize également les images en 256x256.

In [None]:

class TestVideoDataset(Dataset):
    '''
    Unlabelled test videos: every .mp4 file found directly under root_dir.

    Each item is a tensor holding nb_frames frames spread evenly over the
    clip, each passed through resize_data(img, 256, 256) and divided by 255.
    Items follow the order returned by os.listdir.
    '''
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.video_files = [f for f in os.listdir(root_dir) if f.endswith('.mp4')]

    def __len__(self):
        return len(self.video_files)

    def __getitem__(self, idx):
        '''
        Load video number idx and return its sampled frames.

        Raises ValueError when no frame could be decoded from the file.
        '''
        video_path = os.path.join(self.root_dir, self.video_files[idx])
        video, _audio, _info = io.read_video(video_path, pts_unit='sec')

        # Move the last axis to position 1 (presumably (T, H, W, C) -> (T, C, H, W),
        # as documented for torchvision.io.read_video).
        video = video.permute(0,3,1,2)
        length = video.shape[0]
        if length == 0:
            raise ValueError(f"No frame could be read from {video_path}")

        # Evenly spaced indices from the first to the last frame (inclusive).
        # Integer arithmetic keeps every index within [0, length - 1]; the
        # previous formula i * (length // (nb_frames - 1)) produced the index
        # `length` whenever length was a multiple of nb_frames - 1, and divided
        # by zero for nb_frames == 1.
        last_frame = length - 1
        spans = max(nb_frames - 1, 1)
        frame_ids = [(i * last_frame) // spans for i in range(nb_frames)]
        video = video[frame_ids]

        video = torch.stack([resize_data(img, 256, 256)/255 for img in video])

        return video


test_data=TestVideoDataset(os.path.expanduser("~/test_dataset"))

# shuffle=False keeps the predictions in the same order as test_data.video_files.
test_loader=DataLoader(test_data, batch_size=batch_size, shuffle=False)

predictions=[]

# Inference mode: BatchNorm uses its running statistics instead of the
# statistics of the current batch (which would also fail on a batch of one
# sample), and no autograd graph is built or kept alive in `predictions`.
model.eval()
with torch.no_grad():
    for sample in test_loader :
        X = sample.to(device)
        label_pred = model(X)

        predictions.append(label_pred)

Création du fichier tests.csv qui contient vos prédictions. 

In [None]:
row_id_column_name = "ID"

# Gather the predictions of every batch. The previous version only exported
# `label_pred`, i.e. the last batch. Tensors are detached and moved to the CPU
# first because .numpy() does not work on CUDA tensors or on tensors that
# require grad.
if predictions:
    all_preds = torch.cat([p.detach().cpu().reshape(-1) for p in predictions])
else:
    all_preds = torch.empty(0)

# One row per test video, indexed 0..n-1 in the order of the test loader.
y_pred = pd.DataFrame({"TARGET": all_preds.numpy()}, index=range(len(all_preds)))
print(y_pred)

# NOTE(review): the notebook text mentions "tests.csv" but the file is written
# as "tests" (no extension); confirm which name the submission expects.
y_pred.to_csv("tests", sep=',', index_label=row_id_column_name)