# 人工知能特論2 課題 Deep Fake Challenge


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import torch.optim as optim
from torch.utils.data import random_split
import matplotlib.pyplot as plt
import torchvision.utils as vutils
from torchvision.io import read_image
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import datetime

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
from pathlib import Path
from dataclasses import dataclass

@dataclass
class Shape:
    C: int
    H: int
    W: int
    N: int

    def tuple(self):
        return (self.C, self.H, self.W, self.N)

@dataclass
class Config:
    # fixed seed
    seed = 111
    batch_size: int = 32

    # load the checkpoint from the start_epoch
    start_epoch: int = 0
    n_epochs: int = 10

    # - /
    #   VideoFrames/{id}_{frame}.jpg
    #   FakeFaces_test.csv: {file_name}, {REAL/FAKE}
    #   FakeFaces_train.csv: {file_name}, {REAL/FAKE}
    dataset_path = Path('./dataset/FakeFaces')
    # -1 to use all data
    dataset_size: int = -1
    # (C, H, W, N) = (3, 360, 640, 60)
    data_shape = Shape(3, 360, 640, 60)
    
    exp_id: str = 'deepfake_detect_attn_0715'

config = Config()

In [12]:
import random

random.seed((config.seed))
torch.manual_seed(config.seed)

<torch._C.Generator at 0x253b6c66e30>

In [13]:
def setup_tensorboard(id):
    # template = "%Y-%m-%d_%H-%M-%S"
    print(f'logdir=runs/{id}')
    writer = SummaryWriter(f'runs/{id}')
    return writer

writer = setup_tensorboard(config.exp_id)

logdir=runs/deepfake_detect_attn_0715


## Dataset, Dataloader
- video_id: 0..=399, 欠落有り
- 640x360, 3 channels, 60 frames

In [14]:
import pandas as pd

df = pd.read_csv(config.dataset_path / 'FakeFaces_train.csv')
df.head()

Unnamed: 0,File Name,Label
0,289_065.jpg,REAL
1,356_295.jpg,REAL
2,356_150.jpg,REAL
3,002_275.jpg,REAL
4,385_290.jpg,REAL


In [15]:
class DeepFakeDetectDataset(Dataset):
    def __init__(self, dataset_dir: Path):
        self.dataset_dir = dataset_dir
        # "File Name", "Label"
        self.df = pd.read_csv(dataset_dir / 'FakeFaces_train.csv') \
            .replace({ 'REAL': 0, 'FAKE': 1 })
        self.len = len(self.df)

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        file_name, label = row
        file_path = self.dataset_dir / 'VideoFrames' / file_name
        im = read_image(str(file_path)) / 255.
        return im, label

dataset = DeepFakeDetectDataset(config.dataset_path)
im, label = dataset[0]
print(im.shape, label)

tra_val_ratio = 0.95
train_dataset_size = int(len(dataset) * tra_val_ratio)
val_dataset_size = len(dataset) - train_dataset_size

train_dataset, valid_dataset = random_split(dataset, [train_dataset_size, val_dataset_size])
train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=config.batch_size, shuffle=False)
print(f'train_dataset_size: {train_dataset_size}, valid_dataset_size: {val_dataset_size}')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

torch.Size([3, 360, 640]) 0
train_dataset_size: 21090, valid_dataset_size: 1110


In [16]:
from glob import glob
from PIL import Image
from torchvision.transforms.functional import to_tensor
from typing import Optional

def read_images_as_tensor(shape: Shape, dir: str, video_id: str) -> Optional[torch.Tensor]:
    tensors = []
    frame_paths = glob(f'{dir}/{video_id}_*.jpg')
    if len(frame_paths) != shape.N:
        return None

    for frame in sorted(frame_paths):
        img = Image.open(frame)
        # (C, H, W)
        img_tensor = to_tensor(img)
        tensors.append(img_tensor)
    # (C, H, W, N)
    return torch.stack(tensors, dim=3)

# video_id = 0
# video_tensor = read_images_as_tensor(config.data_shape, f'{config.dataset_path}/VideoFrames', f'{video_id:03d}')
# assert(video_tensor.shape == config.data_shape.tuple())

## Model

In [17]:

def conv_with_bn_relu(
    in_channels: int, 
    out_channels: int, 
    kernel_size: int, 
    stride: int, 
    padding: int
) -> nn.Sequential:
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
        nn.BatchNorm2d(out_channels),
        nn.ReLU()
    )


class SABlock(nn.Module):
    def __init__(self, in_channels: int):
        super(SABlock, self).__init__()
        self.net = nn.Sequential(
            conv_with_bn_relu(in_channels, in_channels, 3, 1, 1),
            nn.Conv2d(in_channels, 1, 1, 1, 0),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x) * x


# サンプルコード参考
class Model(nn.Module):
    def __init__(self, config: Config):
        super(Model, self).__init__()

        (C, H, W, N) = config.data_shape.tuple()
        self.net = nn.Sequential(
            conv_with_bn_relu(C, 32, 5, 5, 0), # size: 1/5
            conv_with_bn_relu(32, 64, 3, 1, 1),
            nn.MaxPool2d(kernel_size=2), # 1/10
            SABlock(64),
            conv_with_bn_relu(64, 128, 3, 1, 1),
            nn.MaxPool2d(kernel_size=2), # 1/20
            conv_with_bn_relu(128, 128, 3, 1, 1),
            nn.MaxPool2d(kernel_size=2), # 1/40
            nn.Flatten(),
            # size is now 1/40
            nn.Linear(128 * (H // 40) * (W // 40) , 1024),
            nn.ReLU(),
            nn.Linear(1024, N),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


In [18]:

def save_checkpoint(model: Model, optimizer, config: Config, epoch: int):
    print(f'save models @ epoch={epoch}')
    torch.save({
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, f'models/{config.exp_id}_{epoch}.pt')


def load_checkpoint(config: Config, start_epoch: int):
    print(f'load models @ epoch={start_epoch}')
    checkpoint = torch.load(f'models/{config.exp_id}_{start_epoch}.pt')

    config.start_epoch = checkpoint['epoch'] + 1
    model = Model(config).to(device)
    model.load_state_dict(checkpoint['model'])
    model.train()

    optimizer = optim.Adam(model.parameters(), lr=config.lr)
    optimizer.load_state_dict(checkpoint['optimizer'])
    
    return model, optimizer

In [19]:
from tqdm import tqdm

def train(
    writer: SummaryWriter,
    model: Model,
    optimizer,
    dataloader: DataLoader,
    validation_dataloader: DataLoader,
    config: Config,
):
    criterion = nn.CrossEntropyLoss()
    print(f'criterion: {criterion}')

    for epoch in range(config.n_epochs):
        now = datetime.datetime.now
        print(f'[{now()}] Epoch {epoch}')

        # train
        loss_avg = 0.
        model.train()
        for i, (x, y) in enumerate(tqdm(dataloader)):
            x, y = x.to(device), y.to(device)
            y_pred = model(x)

            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()

            step = epoch * len(dataloader) + i
            loss_avg += loss.item()
            writer.add_scalar('train/loss', loss.item(), step)
        
        loss_avg /= len(dataloader)
        print(f'[{now()}] Epoch {epoch} tra/loss: {loss_avg:.2f}')

        # validation
        model.eval()
        loss_avg = 0.
        with torch.inference_mode():
            for i, (x, y) in enumerate(tqdm(validation_dataloader)):
                x, y = x.to(device), y.to(device)
                y_pred = model(x)

                loss = criterion(y_pred, y)
                step = epoch * len(validation_dataloader) + i
                loss_avg += loss.item()
                writer.add_scalar('validation/loss', loss.item(), step)

        loss_avg /= len(validation_dataloader)
        print(f'[{now()}] Epoch {epoch} val/loss: {loss_avg:.2f}')

        save_checkpoint(model, optimizer, config, epoch)


# 学習

In [20]:

def clean_cache():
    # empty cache
    torch.cuda.empty_cache()
    # print(torch.cuda.memory_summary(device=None, abbreviated=False))

In [22]:
clean_cache()
model = Model(config).to(device)
optimizer = optim.Adam(model.parameters())

train(writer, model, optimizer, train_dataloader, valid_dataloader, config)

criterion: CrossEntropyLoss()
[2022-07-15 13:43:59.640790] Epoch 0


100%|██████████| 660/660 [03:39<00:00,  3.01it/s]


[2022-07-15 13:47:39.118202] Epoch 0 tra/loss: 6.72


100%|██████████| 35/35 [00:10<00:00,  3.46it/s]


[2022-07-15 13:47:49.233689] Epoch 0 val/loss: 4.88
save models @ epoch=0
[2022-07-15 13:47:49.930801] Epoch 1


100%|██████████| 660/660 [02:28<00:00,  4.44it/s]


[2022-07-15 13:50:18.722059] Epoch 1 tra/loss: 1.51


100%|██████████| 35/35 [00:06<00:00,  5.43it/s]


[2022-07-15 13:50:25.172107] Epoch 1 val/loss: 1.49
save models @ epoch=1
[2022-07-15 13:50:25.871762] Epoch 2


100%|██████████| 660/660 [02:35<00:00,  4.24it/s]


[2022-07-15 13:53:01.402935] Epoch 2 tra/loss: 0.72


100%|██████████| 35/35 [00:08<00:00,  4.37it/s]


[2022-07-15 13:53:09.419449] Epoch 2 val/loss: 0.58
save models @ epoch=2
[2022-07-15 13:53:10.105015] Epoch 3


100%|██████████| 660/660 [02:33<00:00,  4.29it/s]


[2022-07-15 13:55:44.069585] Epoch 3 tra/loss: 0.53


100%|██████████| 35/35 [00:06<00:00,  5.21it/s]


[2022-07-15 13:55:50.792336] Epoch 3 val/loss: 0.51
save models @ epoch=3
[2022-07-15 13:55:51.485112] Epoch 4


 32%|███▏      | 208/660 [00:49<01:36,  4.67it/s]