In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="0,2,3"
import numpy as np
import torch
import torch.nn as nn
from facenet_pytorch import fixed_image_standardization
#from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm import tqdm
from vit_pytorch import ViT
from byol_pytorch import BYOL
from data_loader import get_loader, read_dataset, CompositeDataset
from model import FaceRecognitionCNN
from utils import write_json, copy_file, count_parameters

In [2]:
transform = transforms.Compose([
        transforms.Resize((160, 160)),
        np.float32,
        transforms.ToTensor(),
        fixed_image_standardization
    ])

In [3]:
datasets = read_dataset(
    '../dataset/mtcnn/', transform=transform,
    max_images_per_video=10, max_videos=1000,
    window_size=1, splits_path='../dataset/splits/'
)
# only neural textures c40 and original c40
datasets = {
    k: v for k, v in datasets.items() 
    if ('original' in k or 'neural' in k or 'face2face' in k or 'faceswap' in k or 'deepfakes' in k) and 'c23' in k
}
print('Using training data: ')
print('\n'.join(sorted(datasets.keys())))

trains, vals, tests = [], [], []
for data_dir_name, dataset in datasets.items():
    train, val, test = dataset
    # repeat original data multiple times to balance out training data
    compression = data_dir_name.split('_')[-1]
    num_tampered_with_same_compression = len({x for x in datasets.keys() if compression in x}) - 1
    count = 1 if 'original' not in data_dir_name else num_tampered_with_same_compression
    for _ in range(count):
        trains.append(train)
    vals.append(val)
    tests.append(test)
    
train_dataset, val_dataset, test_dataset = CompositeDataset(*trains), CompositeDataset(*vals), CompositeDataset(*tests)

['deepfakes_faces_c23', 'original_faces_c23', 'face2face_faces_c23', 'neural_textures_faces_c23', 'faceswap_faces_c23']
Using training data: 
deepfakes_faces_c23
face2face_faces_c23
faceswap_faces_c23
neural_textures_faces_c23
original_faces_c23


In [4]:
tqdm.write('train data size: {}, validation data size: {}'.format(len(train_dataset), len(val_dataset)))

train data size: 57600, validation data size: 7000


In [5]:
train_loader = get_loader(
    train_dataset, 64, shuffle=True, num_workers=2
)
val_loader = get_loader(
    val_dataset, 256, shuffle=True, num_workers=2
)

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('training on', device)

training on cuda


In [7]:
PATCH_SIZE = 16
DIM = 1024
MLP_DIM = 2048
DEPTH = 6
HEADS = 16

In [8]:
model = ViT(
    image_size = 160,
    patch_size = PATCH_SIZE,
    num_classes = 5,
    dim = DIM,
    depth = DEPTH,
    heads = HEADS,
    mlp_dim = MLP_DIM,
    dropout = 0.1,
    emb_dropout = 0.1
)
learner = BYOL(
    model,
    image_size = 160,
    hidden_layer = 'to_latent'
)
#learner = nn.DataParallel(learner)
learner.to(device)
#if args.freeze_first_epoch:
#for m in model.resnet.parameters():
#    m.requires_grad_(False)

BYOL(
  (net): ViT(
    (patch_to_embedding): Linear(in_features=768, out_features=1024, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (transformer): Transformer(
      (layers): ModuleList(
        (0): ModuleList(
          (0): Residual(
            (fn): PreNorm(
              (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              (fn): Attention(
                (to_qkv): Linear(in_features=1024, out_features=3072, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=1024, out_features=1024, bias=True)
                  (1): Dropout(p=0.1, inplace=False)
                )
              )
            )
          )
          (1): Residual(
            (fn): PreNorm(
              (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              (fn): FeedForward(
                (net): Sequential(
                  (0): Linear(in_features=1024, out_features=2048, bias=True)
                  (1): G

In [9]:
input_shape = next(iter(train_loader))[2].shape
print('input shape', input_shape)
# need to call this before summary!!!
learner.eval()
# summary(model, input_shape[1:], batch_size=input_shape[0], device=device)
print('model params (trainable, total):', count_parameters(learner))

input shape torch.Size([64, 3, 160, 160])
model params (trainable, total): (58644997, 115180298)


In [10]:
optimizer = torch.optim.Adam(
    learner.parameters(), lr=1e-5, weight_decay=1e-3
)

# decrease learning rate if validation accuracy has not increased
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=1/4, patience=2, verbose=True,
)

In [11]:
def save_model_checkpoint(epoch, model, val_loss):
    
    model_dir = os.path.join('./model', 'byol')
    os.makedirs(model_dir, exist_ok=True)

    model_path = os.path.join(model_dir, f'model.pt')
    torch.save(model.state_dict(), model_path)

    model_info = {
        'epoch': epoch,
        'val_loss': val_loss,
        'model_str': str(model)
    }
    json_path = os.path.join(model_dir, 'info.json')
    write_json(model_info, json_path)

    #src_model_file = os.path.join('facenet', 'model.py')
    #dest_model_file = os.path.join(model_dir, 'model.py')
    #copy_file(src_model_file, dest_model_file)

    tqdm.write(f'New checkpoint saved at {model_path}')


def print_training_info(loss, step):
    log_info = 'Training - Loss: {:.4f}'.format(loss.item())
    tqdm.write(log_info)

    #writer.add_scalar('training loss', loss.item(), step)
    #writer.add_scalar('training acc', batch_accuracy, step)


def print_validation_info(device, model, val_loader, epoch, step):
    model.eval()
    with torch.no_grad():
        loss_values = []
        all_predictions = []
        all_targets = []
        targets = []
        outputs = []
        for video_ids, frame_ids, images, target in val_loader:
            images = images.to(device)
            loss = learner(images)
            loss_values.append(loss.item())

        val_loss = sum(loss_values) / len(loss_values)
        
        tqdm.write(
            'Validation - Loss: {:.3f}'.format(val_loss)
        )
        
    return val_loss


In [None]:
total_step = len(train_loader)
step = 1
best_val_loss = 1
for epoch in range(30):
    for i, (video_ids, frame_ids, images, targets) in \
            tqdm(enumerate(train_loader), desc=f'training epoch {epoch}', total=len(train_loader)):
        learner.train()
        # Set mini-batch dataset
        images = images.to(device)
        
        # Forward, backward and optimize
        loss = learner(images)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        learner.update_moving_average()

        # Print log info
        step += 1
        
        if (i + 1) % 300 == 0:
            print_training_info(loss, step)

        if (i + 1) % 300 == 0:
            val_loss = print_validation_info(device, learner, val_loader, epoch, step
            )
            if val_loss < best_val_loss:
                save_model_checkpoint(epoch, learner, val_loss)
                best_val_loss = val_loss

    # validation step after full epoch
    val_loss = print_validation_info(device, learner, val_loader, epoch, step)
    #lr_scheduler.step(val_loss)
    if val_loss < best_val_loss:
        save_model_checkpoint(epoch, learner, val_loss)
        best_val_loss = val_loss

training epoch 0:  33%|███▎      | 299/900 [04:17<08:56,  1.12it/s]

Training - Loss: 0.7737


training epoch 0:  33%|███▎      | 300/900 [05:08<2:40:54, 16.09s/it]

Validation - Loss: 1.794


training epoch 0:  67%|██████▋   | 599/900 [09:29<04:20,  1.15it/s]  

Training - Loss: 1.5848


training epoch 0:  67%|██████▋   | 600/900 [10:20<1:19:46, 15.95s/it]

Validation - Loss: 1.693


training epoch 0: 100%|█████████▉| 899/900 [14:42<00:00,  1.16it/s]  

Training - Loss: 0.5381


training epoch 0: 100%|██████████| 900/900 [15:33<00:00,  1.04s/it]

Validation - Loss: 1.619





Validation - Loss: 2.283


training epoch 1:  33%|███▎      | 299/900 [04:23<08:48,  1.14it/s]

Training - Loss: 0.8399


training epoch 1:  33%|███▎      | 300/900 [05:13<2:40:10, 16.02s/it]

Validation - Loss: 1.636


training epoch 1:  67%|██████▋   | 599/900 [09:36<04:22,  1.15it/s]  

Training - Loss: 0.8121


training epoch 1:  67%|██████▋   | 600/900 [10:27<1:20:32, 16.11s/it]

Validation - Loss: 1.275


training epoch 1: 100%|█████████▉| 899/900 [14:49<00:00,  1.15it/s]  

Training - Loss: 0.4783


training epoch 1: 100%|██████████| 900/900 [15:40<00:00,  1.05s/it]

Validation - Loss: 1.174





Validation - Loss: 1.777


training epoch 2:  33%|███▎      | 299/900 [04:22<08:57,  1.12it/s]

Training - Loss: 0.3861


training epoch 2:  33%|███▎      | 300/900 [05:12<2:39:12, 15.92s/it]

Validation - Loss: 1.374


training epoch 2:  67%|██████▋   | 599/900 [09:35<04:23,  1.14it/s]  

Training - Loss: 0.3017


training epoch 2:  67%|██████▋   | 600/900 [10:24<1:18:34, 15.72s/it]

Validation - Loss: 1.527


training epoch 2: 100%|█████████▉| 899/900 [14:46<00:00,  1.14it/s]  

Training - Loss: 0.2067


training epoch 2: 100%|██████████| 900/900 [15:36<00:00,  1.04s/it]

Validation - Loss: 1.442





Validation - Loss: 1.127


training epoch 3:  33%|███▎      | 299/900 [04:23<08:39,  1.16it/s]

Training - Loss: 0.2841


training epoch 3:  33%|███▎      | 300/900 [05:13<2:39:41, 15.97s/it]

Validation - Loss: 1.360


training epoch 3:  67%|██████▋   | 599/900 [09:35<04:17,  1.17it/s]  

Training - Loss: 0.2694


training epoch 3:  67%|██████▋   | 600/900 [10:26<1:20:32, 16.11s/it]

Validation - Loss: 1.252


training epoch 3: 100%|█████████▉| 899/900 [14:49<00:00,  1.15it/s]  

Training - Loss: 0.4352


training epoch 3: 100%|██████████| 900/900 [15:39<00:00,  1.04s/it]

Validation - Loss: 1.508





Validation - Loss: 1.296


training epoch 4:  33%|███▎      | 299/900 [04:22<08:33,  1.17it/s]

Training - Loss: 0.3090


training epoch 4:  33%|███▎      | 300/900 [05:12<2:39:45, 15.98s/it]

Validation - Loss: 1.308


training epoch 4:  67%|██████▋   | 599/900 [09:35<04:26,  1.13it/s]  

Training - Loss: 1.4596


training epoch 4:  67%|██████▋   | 600/900 [10:25<1:18:59, 15.80s/it]

Validation - Loss: 1.279


training epoch 4: 100%|█████████▉| 899/900 [14:47<00:00,  1.13it/s]  

Training - Loss: 0.7518


training epoch 4: 100%|██████████| 900/900 [15:38<00:00,  1.04s/it]

Validation - Loss: 1.437





Validation - Loss: 1.452


training epoch 5:  33%|███▎      | 299/900 [04:23<08:44,  1.15it/s]

Training - Loss: 0.2401


training epoch 5:  33%|███▎      | 300/900 [05:13<2:38:27, 15.85s/it]

Validation - Loss: 1.512


training epoch 5:  67%|██████▋   | 599/900 [09:35<04:24,  1.14it/s]  

Training - Loss: 0.7365


training epoch 5:  67%|██████▋   | 600/900 [10:25<1:20:06, 16.02s/it]

Validation - Loss: 1.170


training epoch 5: 100%|█████████▉| 899/900 [14:48<00:00,  1.16it/s]  

Training - Loss: 0.4299


training epoch 5: 100%|██████████| 900/900 [15:37<00:00,  1.04s/it]

Validation - Loss: 1.186





Validation - Loss: 1.190


training epoch 6:  33%|███▎      | 299/900 [04:22<08:42,  1.15it/s]

Training - Loss: 0.3916


training epoch 6:  33%|███▎      | 300/900 [05:12<2:36:54, 15.69s/it]

Validation - Loss: 1.413


training epoch 6:  67%|██████▋   | 599/900 [09:34<04:26,  1.13it/s]  

Training - Loss: 0.2962


training epoch 6:  67%|██████▋   | 600/900 [10:25<1:20:36, 16.12s/it]

Validation - Loss: 1.125


training epoch 6: 100%|█████████▉| 899/900 [14:48<00:00,  1.13it/s]  

Training - Loss: 0.8421


training epoch 6: 100%|██████████| 900/900 [15:39<00:00,  1.04s/it]

Validation - Loss: 1.672





Validation - Loss: 1.124


training epoch 7:  18%|█▊        | 159/900 [02:19<10:47,  1.14it/s]