In [25]:
import os
# import argparse
from tqdm.notebook import tqdm
from pathlib import Path
import random
import shutil


In [26]:
import torch
import torch.nn as nn
import torchvision

import torch.optim as optim

from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger


In [27]:
from dotted_dict import DottedDict

args = DottedDict({
    'base': 'C://Users//tonyz//PycharmProjects//DeepFake',
    'batch_size': 32,
    'pretrain': False,
    'lr':1e-3,
    'step_lr':True,
})

In [28]:
base = Path(args.base)
train_dir = base / 'data' / 'train'
val_dir = base / 'data' / 'test'
pretrain_dir = base /'data' / 'pretrained.pkl'

In [29]:
def create_train_val_mix(deepfake_method='Deepfakes'):
    original = base / 'data' / 'original_sequences' / 'youtube' / 'c23' / 'images'
    manipulated = base / 'data' / 'manipulated_sequences' / str(deepfake_method) / 'c23' / 'images'

    if not os.path.exists(train_dir):
        os.makedirs(train_dir)
        os.makedirs(train_dir / '1')
        os.makedirs(train_dir / '0')

    if not os.path.exists(val_dir):
        os.makedirs(val_dir)
        os.makedirs(val_dir / '1')
        os.makedirs(val_dir / '0')

    # test_size = 0.33
    test_size = 0.2

    ori_sel = os.listdir(original)
    man_sel = os.listdir(manipulated)

    # random seed 42
    random.seed(42)

    for i, x in enumerate(ori_sel):
        for file in (original / x).iterdir():
            if random.uniform(0, 1) < test_size:
                shutil.copy(file, val_dir / '0')
            else:
                shutil.copy(file, train_dir / '0')
               
    for i, x in enumerate(man_sel):
       for file in (manipulated / x).iterdir():
            if random.uniform(0, 1) < test_size:
                shutil.copy(file, val_dir / '1')
            else:
                shutil.copy(file, train_dir / '1')

if not os.path.exists(train_dir): 
    create_train_val_mix('Face2Face')

In [30]:
class FaceForensics(pl.LightningDataModule):
   def __init__(self):
      super().__init__()
      self.batch_size = args.batch_size

      self.transform = transforms.Compose([
      transforms.Resize((256, 256)),
      transforms.ToTensor(),
      transforms.Normalize([0.5]*3, [0.5]*3)
   ])

   def train_dataloader(self):
      train_dataset = torchvision.datasets.ImageFolder(train_dir, transform=self.transform)
      train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)
      return train_loader
   
   def val_dataloader(self):
      val_dataset = torchvision.datasets.ImageFolder(val_dir, transform=self.transform)
      val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=4)
      return val_loader

In [31]:
class Meso4(pl.LightningModule):
    def __init__(self, num_classes=2):
        super().__init__()
        self.num_classes = num_classes

        self.conv1 = nn.Conv2d(3, 8, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(8)
        self.relu = nn.ReLU(inplace=True)
        self.leakyrelu = nn.LeakyReLU(0.1)
        
        self.conv2 = nn.Conv2d(8, 8, 5, padding=2, bias=False)
        self.bn2 = nn.BatchNorm2d(16)
        self.conv3 = nn.Conv2d(8, 16, 5, padding=2, bias=False)
        self.conv4 = nn.Conv2d(16, 16, 5, padding=2, bias=False)
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2))
        self.maxpool2 = nn.MaxPool2d(kernel_size=(4, 4))
        self.dropout = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(16*8*8, 16)
        self.fc2 = nn.Linear(16, num_classes)

        self.loss_fn = nn.CrossEntropyLoss()


    # input is dim 256*256*3
    # based on network architecture described in
    # https://arxiv.org/pdf/1809.00888.pdf
    def forward(self, input):
        x = self.conv1(input)
        x = self.relu(x)
        x = self.bn1(x)
        x = self.maxpool1(x)

        x = self.conv2(x)
        x = self.relu(x)
        x = self.bn1(x)
        x = self.maxpool1(x)

        x = self.conv3(x)
        x = self.relu(x)
        x = self.bn2(x)
        x = self.maxpool1(x)

        x = self.conv4(x)
        x = self.relu(x)
        x = self.bn2(x)
        x = self.maxpool2(x)

        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = self.fc1(x)
        x = self.leakyrelu(x)
        x = self.dropout(x)
        x = self.fc2(x)

        return x
        
    
    def configure_optimizers(self):
        # num_batches = len(self.train_dataloader()) / self.trainer.accumulate_grad_batches
        # using hp from the original paper
        optimizer = optim.Adam(self.parameters(), lr=args.lr)
        lr_scheduler = {'scheduler': optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.1),
        }

        return [optimizer], [lr_scheduler]
    
    def training_step(self, train_batch, batch_idx):
        loss = 0
        cnt = 0
        acc = 0

        img, lbl = train_batch
        output = self.forward(img)
        _, preds = torch.max(output.data, 1)
        
        acc += torch.sum(preds == lbl.data).to(torch.float32)
        loss += self.loss_fn(output, lbl)
        cnt += len(lbl)

        batch_dictionary={
            "loss": loss,
            "correct": acc,
            "total": cnt
        }

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_acc", acc/cnt, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return batch_dictionary
    
    def validation_step(self, val_batch, batch_idx):
        loss = 0
        cnt = 0
        acc = 0

        img, lbl = val_batch
        output = self.forward(img)
        _, preds = torch.max(output.data, 1)

        acc += torch.sum(preds == lbl.data).to(torch.float32)
        loss += self.loss_fn(output, lbl)
        cnt += len(lbl)

        batch_dictionary={
            "loss": loss,
            "correct": acc,
            "total": cnt
        }

        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log("val_acc", acc/cnt, on_step=False, on_epoch=True, prog_bar=True, logger=True)

        return batch_dictionary
    
    def on_train_epoch_end(self):
        torch.cuda.empty_cache()


In [32]:
FF = FaceForensics()
model = Meso4()

if args.pretrain:
    model.load_state_dict(torch.load(pretrain_dir))

logger = TensorBoardLogger('lightning_logs', name='logger')

trainer = pl.Trainer(max_epochs=20, logger=logger, accelerator='gpu', max_time={'minutes':30}, limit_val_batches=0.5,
                     default_root_dir = base / 'model' / 'checkpoints')

trainer.fit(model,datamodule=FF)
# trainer.validate(model, datamodule=FF)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name      | Type             | Params
------------------------------------------------
0  | conv1     | Conv2d           | 216   
1  | bn1       | BatchNorm2d      | 16    
2  | relu      | ReLU             | 0     
3  | leakyrelu | LeakyReLU        | 0     
4  | conv2     | Conv2d           | 1.6 K 
5  | bn2       | BatchNorm2d      | 32    
6  | conv3     | Conv2d           | 3.2 K 
7  | conv4     | Conv2d           | 6.4 K 
8  | maxpool1  | MaxPool2d        | 0     
9  | maxpool2  | MaxPool2d        | 0     
10 | dropout   | Dropout2d        | 0     
11 | fc1       | Linear           | 16.4 K
12 | fc2       | Linear           | 34    
13 | loss_fn   | CrossEntropyLoss | 0     
------------------------------------------------
27.9 K    Trainable params
0         Non-trainable params
27.9

Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.


In [37]:
trainer.save_checkpoint("20_epoch_other.ckpt")

In [34]:
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/

Launching TensorBoard...