In [1]:
import os
from pathlib import Path
from glob import glob

from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class DeepfakeDataset(Dataset):
    def __init__(self, ds_name, train=True, transform=None):
        self.path = os.path.join('/media/data1/sangyong/df_datasets/', ds_name)
        if train:
            self.real_path = os.path.join(self.path, 'train/real')
            self.fake_path = os.path.join(self.path, 'train/rake')
        else:
            self.real_path = os.path.join(self.path, 'test/real')
            self.fake_path = os.path.join(self.path, 'test/fake')
            
        self.real_list = glob(os.path.join(self.real_path, '**/*.png'))
        self.fake_list = glob(os.path.join(self.fake_path, '**/*.png'))
        
        self.transform = transform
        
        self.img_list = self.real_list + self.fake_list
        self.class_list = [0]*len(self.real_list) + [1]*len(self.fake_list)
    def __len__(self):
        return len(self.img_list)
    
    def __getitem__(self, idx):
        img_path = self.img_list[idx]
        label = self.class_list[idx]
        img = Image.open(img_path)
        
        if not self.transform == None:
            img = self.transform(img)
            
        return img, label
    
if __name__ == "__main__":
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(299)
    ])
    
    dataset = DeepfakeDataset(ds_name="DeepFake", train=True, transform=transform)
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=True,
                            drop_last=False)
    for epoch in range(2):
        print(f"epoch : {epoch}")
        for batch in dataloader:
            img, label = batch
            print(img.size(), label)
            break

  from .autonotebook import tqdm as notebook_tqdm


epoch : 0
torch.Size([1, 4, 299, 299]) tensor([0])
epoch : 1
torch.Size([1, 4, 299, 299]) tensor([0])


In [2]:
path = '/media/data1/sangyong/df_datasets/DeepFake'
real_path = os.path.join(path, 'train/real')
fake_path = os.path.join(path, 'train/fake')
real_list = glob(os.path.join(real_path, '**/*.png'))
fake_list = glob(os.path.join(fake_path, '**/*.png'))

print(len(real_list), len(fake_list))


60000 60000


## Model

In [11]:
import torch
import torch.nn as nn


class SeparableConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, strides=1, padding=0, dilation=1, bias=False):
        super(SeparableConv, self).__init__()
        
        # https://gaussian37.github.io/dl-pytorch-conv2d/ dilation, groups 설명
        self.pointwiseconv = nn.Conv2d(in_channels, in_channels, kernel_size, strides, padding, dilation, groups=in_channels, bias=bias)
        self.depthwiseconv = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias)
    
    def forward(self,x):
        x = self.pointwiseconv(x)
        x = self.depthwiseconv(x)
        return x

class Block(nn.Module):
    def __init__(self, in_channels, out_channels, repeat, strides=1, start_with_relu=True, sizeup_first=True):
        super(Block, self).__init__()
        
        if out_channels != in_channels or strides != 1:
            self.skip = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=strides, bias=False), # bias를 False를 두는 이유???
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.skip = nn.Identity()
        
        layers = []
        
        channels = in_channels
        if sizeup_first:
            layers.append(nn.ReLU(inplace=True)) # inplace=True 하면, inplace 연산을 수행함, inplace 연산은 결과값을 새로운 변수에 값을 저장하는 대신 기존의 데이터를 대체하는것을 의미 메모리적 이득
            layers.append(SeparableConv(in_channels, out_channels, kernel_size=3, strides=1, padding=1, bias=False))
            layers.append(nn.BatchNorm2d(out_channels))
            channels = out_channels
            
        for i in range(repeat-1):
            layers.append(nn.ReLU(inplace=True)) # 근데 왜 굳이 해야 해? 안하면 안 될 정도?
            layers.append(SeparableConv(channels, channels, kernel_size=3, strides=1, padding=1, bias=False))
            layers.append(nn.BatchNorm2d(channels))
            
        if not sizeup_first:
            layers.append(nn.ReLU(inplace=True))
            layers.append(SeparableConv(channels, out_channels, kernel_size=3, strides=1, padding=1, bias=False))
            layers.append(nn.BatchNorm2d(out_channels))
            
        if not start_with_relu:
            layers = layers[1:]
        else:
            layers[0] = nn.ReLU(inplace=True) # 다시 한번 확실하게 ReLU로 시작하게 하자
                    
        if strides != 1:
            layers.append(nn.MaxPool2d(3, strides, 1)) # padding을 1로 둠으로서 image shape을 논문과 같이 19X19로 만듦. 근데 self.skip에서 18X18로 나오는거 같은데 어케되는거지
            
        self.layers = nn.Sequential(*layers)
        
    def forward(self, img):
        x = self.layers(img)
        x = x + self.skip(img)
        
        return x
    
class Xception(nn.Module):
    def __init__(self, num_classes=2):
        super(Xception, self).__init__()
        
        self.num_classes = num_classes
        
        # Entry flow
        self.module1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, 2, 0, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32,64,3,1,1,bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        
        self.module2 = Block(64, 128, 2, 2, False, True)
        self.module3 = Block(128, 256, 2, 2, True, True)
        self.module4 = Block(256, 728, 2, 2, True, True)
        
        # Middle flow
        self.module5_12 = nn.Sequential(*[Block(728,728,3,1,True,True) for _ in range(8)])
        
        # Exit flow
        self.module13 = Block(728, 1024, 2, 2, True, False)
        self.module14 = nn.Sequential(
            SeparableConv(1024, 1536, 3, 1, 1),
            nn.BatchNorm2d(1536),
            nn.ReLU(inplace=True),
            SeparableConv(1536, 2048, 3, 1, 1),
            nn.BatchNorm2d(2048),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1)
        )
        
        # classifier
        self.fc = nn.Linear(2048, num_classes)
        
    def forward(self, img):
        x = self.module1(img)
        x = self.module2(x)
        x = self.module3(x)
        x = self.module4(x)
        x = self.module5_12(x)
        x = self.module13(x)
        x = self.module14(x)
        
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

import torchsummary
model = Xception(2)
print(torchsummary.summary(model, (3,299,299), device='cpu'))
print(model)
        
        

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 149, 149]             864
       BatchNorm2d-2         [-1, 32, 149, 149]              64
              ReLU-3         [-1, 32, 149, 149]               0
            Conv2d-4         [-1, 64, 149, 149]          18,432
       BatchNorm2d-5         [-1, 64, 149, 149]             128
              ReLU-6         [-1, 64, 149, 149]               0
            Conv2d-7         [-1, 64, 149, 149]             576
            Conv2d-8        [-1, 128, 149, 149]           8,192
     SeparableConv-9        [-1, 128, 149, 149]               0
      BatchNorm2d-10        [-1, 128, 149, 149]             256
             ReLU-11        [-1, 128, 149, 149]               0
           Conv2d-12        [-1, 128, 149, 149]           1,152
           Conv2d-13        [-1, 128, 149, 149]          16,384
    SeparableConv-14        [-1, 128, 1

In [12]:
import datetime
type(datetime)

module

## Train

In [None]:
import argparse
import os

import time
from datetime import datetime as dt
import random

import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.optim.lr_scheduler import StepLR

from utils import *
from Train import *
from model.xception import Xception
from data import get_dataloader

def build_args():
    parser = argparse.ArgumentParser()
    #### dataset ####
    parser.add_argument("--data_name", type=str, default="DeepFake",
                        choices=['DeepFake', 'DeepFakeDetection', 'Face2Face', 'FaceSwap', 'NeuralTextures'])
    parser.add_argument("--data_path", type=str, default='/media/data1/sangyong/df_datasets')
    parser.add_argument("--n_workers", type=int, default=4)
    #### train & test ####
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--model", type=str, default="xception", choices=["xception", "clrnet"])
    parser.add_argument("--optimizer", type=str, default="SGD")
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--weight_decay", type=float, default=5e-4)
    parser.add_argument("--epochs", type=int, default=30)
    #### save & load ####
    parser.add_argument("--save_root_dir", type=str, default='/media/data1/sangyong/deepfake_detection/save')
    parser.add_argument("--model_load_path", default=None)
    parser.add_argument("--print_freq", type=int, default=100)
    parser.add_argument("--world_size", type=int, default=4)
    parser.add_argument("--DDP", action="store_true")
    parser.add_argument("--dist_backend", type=str, default='nccl')
    parser.add_argument("--use_wandb", default=False)
    args = parser.parse_args()
    
    if args.DDP:
        args.local_rank = int(os.environ["LOCAL_RANK"])
        args.batch_size = args.batch_size // args.world_size
    else:
        args.local_rank = 0
    args.save_name = f"[data-{args.data_name}]_[bs-{args.batch_size}]_"+\
                     f"[m-{args.model}]_[optim-{args.optimizer}]_[date-{dt.now().strftime('%Y%m%d')}]"
    args.save_dir = os.path.join(args.save_root_dir, args.save_name)
    args.model_save_dir = os.path.join(args.save_dir, "save_model")
    args.logger_path = os.path.join(args.save_dir, "log.txt")
    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(args.model_save_dir, exist_ok=True)
    return args

def main(args, logger):
    if args.model == "xception":
        model = Xception(num_classes=2).cuda(args.local_rank)
    elif args.model == "clrnet":
        model = None
    
    train_loader, valid_loader, train_sampler = get_dataloader(args)
    
    if args.DDP:
        # model backward pass에 연관되지 않은 parameter들을 mark해서 DDP가 해당 파라미터들의 gradient들을 영원히 기다리는 것을 방지 한다. 
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True) 

    if args.optimizer == "SGD":
        optimizer = torch.optim.SGD(
            params=model.parameters(),
            lr = args.lr,
            momentum=0.9,
            weight_decay=args.weight_decay,
            nesterov=True
        )
    elif args.optimizer == "Adam":
        optimizer = torch.optim.Adam(
            params=model.parameters(),
            lr = args.lr,
            weight_decay=args.weight_decay,
        )
    else:
        raise NotImplementedError(f"optimizer {args.optimzier} is not implemented. please change")
    
    scheduler = StepLR(optimizer, step_size=10, gamma=0.5)
    criterion = nn.CrossEntropyLoss()
    
    start_epoch = 1
    best_acc = -1
    best_loss = float('inf')
    
    if args.model_load_path:
        logger.write(f"model load from {args.model_load_path}\n")
        if not args.DDP:
            checkpoint = torch.load(args.model_load_path)
        else:
            dist.barrier()
            checkpoint = torch.load(args.model_load_path, map_location={"cuda:0": f"cuda:{args.local_rank}"})
            
        start_epoch = checkpoint["epoch"] + 1
        best_acc = checkpoint['best_acc']
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.write(f"model is successfully loaded\n"
                     f"start epoch: {start_epoch}, best_acc: {best_acc}")
        
        del checkpoint
    
    for epoch in range(start_epoch, args.epochs):
        if args.DDP:
            train_sampler.set_epoch(epoch)
            
        train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, args)
        
        valid_loss, valid_acc = validate(valid_loader, model, criterion, args)
        
        if args.local_rank == 0:
            if valid_acc > best_acc:
                logger.write(f"Best accuracy: {best_acc:.4f} -> {valid_acc:.4f}")
                best_acc = valid_acc
                checkpoint_dict = {
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "epoch": epoch,
                    "best_acc": best_acc
                }
                model_save_path = os.path.join(args.model_save_dir, f"{epoch}_{best_acc}.pth")
                torch.save(checkpoint_dict, model_save_path)
            logger.write(f"[Epoch-{epoch}]_[Train accuracy-{train_acc}]_"
                          f"[Train loss-{train_loss:.4f}]_[Valid loss-{valid_loss}]")
            if valid_loss < best_loss:
                best_loss = valid_loss
            
            if args.use_wandb:
                wandb_msg = {"Train acc": train_acc,
                             "valid acc": valid_acc,
                             "Train loss": train_loss,
                             "valid loss": valid_loss}
                wandb.log(wandb_msg)
        logger.write(f"[Best accuracy-{best_acc}]_[Best loss-{best_loss}]")
        scheduler.step()
        dist.barrier()
        
args = build_args()
logger = Logger(args.local_rank)
logger.open(args.logger_path)
print_args(args, logger=logger)
if args.use_wandb and args.local_rank ==0:
    wandb.init(project="Deepfake Detection", name=args.save_name, notes=args.save_name)
    wandb.config.update(args)
if args.DDP:
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend=args.dist_backend)
    logger.write(f'DDP using {args.world_size} GPUS\n')
start_time = time.time()
if args.data_name in ['DeepFake', 'DeepFakeDetection', 'Face2Face', 'FaceSwap', 'NeuralTextures']:
    main(args, logger)
else:
    raise NotImplementedError(f"data {args.data_name} is not implemented")

logger.write(f"total time: {time.time() - start_time}")