In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="0,2,3"
import numpy as np
import torch
import torch.nn as nn
from facenet_pytorch import fixed_image_standardization
#from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm import tqdm
from vit_pytorch import ViT
from byol_pytorch import BYOL
from data_loader import get_loader, read_dataset, CompositeDataset
from model import FaceRecognitionCNN
from utils import write_json, copy_file, count_parameters

In [2]:
transform = transforms.Compose([
        transforms.Resize((160, 160)),
        np.float32,
        transforms.ToTensor(),
        fixed_image_standardization
    ])

In [3]:
datasets = read_dataset(
    '../dataset/mtcnn/', transform=transform,
    max_images_per_video=10, max_videos=1000,
    window_size=1, splits_path='../dataset/splits/'
)
# only neural textures c40 and original c40
datasets = {
    k: v for k, v in datasets.items() 
    if ('original' in k or 'neural' in k or 'face2face' in k or 'faceswap' in k or 'deepfakes' in k) and 'c23' in k
}
print('Using training data: ')
print('\n'.join(sorted(datasets.keys())))

trains, vals, tests = [], [], []
for data_dir_name, dataset in datasets.items():
    train, val, test = dataset
    # repeat original data multiple times to balance out training data
    compression = data_dir_name.split('_')[-1]
    num_tampered_with_same_compression = len({x for x in datasets.keys() if compression in x}) - 1
    count = 1 if 'original' not in data_dir_name else num_tampered_with_same_compression
    for _ in range(count):
        trains.append(train)
    vals.append(val)
    tests.append(test)
    
train_dataset, val_dataset, test_dataset = CompositeDataset(*trains), CompositeDataset(*vals), CompositeDataset(*tests)

['deepfakes_faces_c23', 'original_faces_c23', 'face2face_faces_c23', 'neural_textures_faces_c23', 'faceswap_faces_c23']
Using training data: 
deepfakes_faces_c23
face2face_faces_c23
faceswap_faces_c23
neural_textures_faces_c23
original_faces_c23


In [4]:
tqdm.write('train data size: {}, validation data size: {}'.format(len(train_dataset), len(val_dataset)))

train data size: 57600, validation data size: 7000


In [5]:
train_loader = get_loader(
    train_dataset, 64, shuffle=True, num_workers=2
)
val_loader = get_loader(
    val_dataset, 256, shuffle=True, num_workers=2
)

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('training on', device)

training on cuda


In [7]:
PATCH_SIZE = 16
DIM = 1024
MLP_DIM = 2048
DEPTH = 12
HEADS = 16

In [8]:
model = ViT(
    image_size = 160,
    patch_size = PATCH_SIZE,
    num_classes = 5,
    dim = DIM,
    depth = DEPTH,
    heads = HEADS,
    mlp_dim = MLP_DIM,
    dropout = 0.1,
    emb_dropout = 0.1
)
learner = BYOL(
    model,
    image_size = 256,
    hidden_layer = 'to_latent'
)
model = nn.DataParallel(model)
model.to(device)
#if args.freeze_first_epoch:
#for m in model.resnet.parameters():
#    m.requires_grad_(False)

DataParallel(
  (module): ViT(
    (patch_to_embedding): Linear(in_features=768, out_features=1024, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (transformer): Transformer(
      (layers): ModuleList(
        (0): ModuleList(
          (0): Residual(
            (fn): PreNorm(
              (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              (fn): Attention(
                (to_qkv): Linear(in_features=1024, out_features=3072, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=1024, out_features=1024, bias=True)
                  (1): Dropout(p=0.1, inplace=False)
                )
              )
            )
          )
          (1): Residual(
            (fn): PreNorm(
              (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              (fn): FeedForward(
                (net): Sequential(
                  (0): Linear(in_features=1024, out_features=2048, bias=True)
             

In [9]:
input_shape = next(iter(train_loader))[2].shape
print('input shape', input_shape)
# need to call this before summary!!!
model.eval()
# summary(model, input_shape[1:], batch_size=input_shape[0], device=device)
print('model params (trainable, total):', count_parameters(model))

input shape torch.Size([64, 3, 160, 160])
model params (trainable, total): (101660677, 101660677)


In [10]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(), lr=1e-4, weight_decay=1e-3
)

# decrease learning rate if validation accuracy has not increased
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=1/4, patience=2, verbose=True,
)

In [11]:
def save_model_checkpoint(epoch, model, val_acc):
    
    model_dir = os.path.join('./model', 'vit')
    os.makedirs(model_dir, exist_ok=True)

    model_path = os.path.join(model_dir, f'model.pt')
    torch.save(model.state_dict(), model_path)

    model_info = {
        'epoch': epoch,
        'val_acc': val_acc[0],
        'model_str': str(model)
    }
    json_path = os.path.join(model_dir, 'info.json')
    write_json(model_info, json_path)

    #src_model_file = os.path.join('facenet', 'model.py')
    #dest_model_file = os.path.join(model_dir, 'model.py')
    #copy_file(src_model_file, dest_model_file)

    tqdm.write(f'New checkpoint saved at {model_path}')


def print_training_info(batch_accuracy, loss, step):
    log_info = 'Training - Loss: {:.4f}, Accuracy: {:.4f}'.format(loss.item(), batch_accuracy)
    tqdm.write(log_info)

    #writer.add_scalar('training loss', loss.item(), step)
    #writer.add_scalar('training acc', batch_accuracy, step)


def print_validation_info(criterion, device, model, val_loader, epoch, step):
    model.eval()
    with torch.no_grad():
        loss_values = []
        all_predictions = []
        all_targets = []
        targets = []
        outputs = []
        for video_ids, frame_ids, images, target in val_loader:
            images = images.to(device)
            target = target.to(device)
            #target = target.long()
            #output = model(images)
            loss = criterion(output, target)
            loss_values.append(loss.item())
            targets.append(target)
            outputs.append(output)
            #predictions = outputs > 0.0
            #all_predictions.append(predictions)
            #all_targets.append(targets)
            #if args.debug:
            #    tqdm.write(outputs)
            #    tqdm.write(predictions)
            #    tqdm.write(targets)
        
        val_loss = sum(loss_values) / len(loss_values)
        
        outputs = torch.cat(outputs, 0)
        targets = torch.cat(targets, 0)
        
        val_accuracy = float((outputs.argmax(1)).eq(targets).sum()) / len(targets)
        
        total_target = targets.unique(return_counts=True)[1]
        pristine = ((outputs.argmax(1) == 0) * (targets == 0)).sum() / total_target[0]
        face2face = ((outputs.argmax(1) == 1) * (targets == 1)).sum() / total_target[1]
        faceswap = ((outputs.argmax(1) == 2) * (targets == 2)).sum() / total_target[2]
        neural = ((outputs.argmax(1) == 3) * (targets == 3)).sum() / total_target[3]
        deepfake = ((outputs.argmax(1) == 4) * (targets == 4)).sum() / total_target[4]
        
        tqdm.write(
            'Validation - Loss: {:.3f}, Acc: {:.3f}, Pr: {:.3f}, Ff: {:.3f}, Fs: {:.3f}, Nt: {:.3f}, Df: {:.3f}'.format(
                val_loss, val_accuracy, pristine, face2face, faceswap, neural, deepfake
            )
        )
        
    return val_accuracy, pristine, face2face, faceswap, neural, deepfake


In [12]:
total_step = len(train_loader)
step = 1
best_val_acc = 0.5
for epoch in range(30):
    for i, (video_ids, frame_ids, images, targets) in \
            tqdm(enumerate(train_loader), desc=f'training epoch {epoch}', total=len(train_loader)):
        model.train()
        # Set mini-batch dataset
        images = images.to(device)
        targets = targets.to(device)

        # Forward, backward and optimize
        outputs = model(images)
        targets = targets.long()
        loss = criterion(outputs, targets)
        model.zero_grad()
        loss.backward()
        optimizer.step()

        batch_accuracy = float((outputs.argmax(1)).eq(targets).sum()) / len(targets)
        
        

        # Print log info
        step += 1
        
        if (i + 1) % 300 == 0:
            print_training_info(batch_accuracy, loss, step)

        if (i + 1) % 300 == 0:
            val_acc, pr_acc, ff_acc, fs_acc, nt_acc, df_acc = print_validation_info(
                criterion, device, model, val_loader, epoch, step
            )
            if val_acc > best_val_acc:
                save_model_checkpoint(epoch, model, (val_acc, pr_acc, ff_acc, fs_acc, nt_acc, df_acc))
                best_val_acc = val_acc

    # validation step after full epoch
    val_acc, pr_acc, ff_acc, fs_acc, nt_acc, df_acc = print_validation_info(
        criterion, device, model, val_loader, epoch, step
    )
    lr_scheduler.step(val_acc)
    if val_acc > best_val_acc:
        save_model_checkpoint(epoch, model, (val_acc, pr_acc, ff_acc, fs_acc, nt_acc, df_acc))
        best_val_acc = val_acc

    #if epoch == 0:
    #    for m in model.resnet.parameters():
    #        m.requires_grad_(True)
    #    tqdm.write('Fine tuning on')

training epoch 0:  33%|███▎      | 299/900 [01:51<03:32,  2.82it/s]

Training - Loss: 1.4234, Accuracy: 0.4844


training epoch 0:  33%|███▎      | 300/900 [02:10<1:01:18,  6.13s/it]

Validation - Loss: 1.766, Acc: 0.200, Pr: 1.000, Ff: 0.000, Fs: 0.000, Nt: 0.000, Df: 0.000


training epoch 0:  67%|██████▋   | 599/900 [03:57<01:46,  2.81it/s]  

Training - Loss: 1.2458, Accuracy: 0.6250


training epoch 0:  67%|██████▋   | 600/900 [04:15<29:14,  5.85s/it]

Validation - Loss: 1.908, Acc: 0.200, Pr: 1.000, Ff: 0.000, Fs: 0.000, Nt: 0.000, Df: 0.000


training epoch 0: 100%|█████████▉| 899/900 [06:01<00:00,  2.81it/s]

Training - Loss: 1.5065, Accuracy: 0.4375


training epoch 0: 100%|██████████| 900/900 [06:24<00:00,  2.34it/s]

Validation - Loss: 1.986, Acc: 0.200, Pr: 1.000, Ff: 0.000, Fs: 0.000, Nt: 0.000, Df: 0.000





KeyboardInterrupt: 

In [None]:
model = ViT(
    image_size = 160,
    patch_size = PATCH_SIZE,
    num_classes = 5,
    dim = DIM,
    depth = DEPTH,
    heads = HEADS,
    mlp_dim = MLP_DIM,
    dropout = 0.1,
    emb_dropout = 0.1
)
model = nn.DataParallel(model)
state_dict = torch.load('./model/vit/model.pt', map_location='cpu')
model.load_state_dict(state_dict)
model.to(device)

test_loader = get_loader(
    test_dataset, 64, shuffle=True, num_workers=2, drop_last=False
)
with torch.no_grad():
    loss_values = []
    all_predictions = []
    all_targets = []
    targets = []
    outputs = []
    for video_ids, frame_ids, images, target in tqdm(test_loader):
        images = images.to(device)
        target = target.to(device)
        target = target.long()
        output = model(images)
        targets.append(target)
        outputs.append(output)
        loss = criterion(output, target)
        loss_values.append(loss.item())

#                 predictions = outputs > 0.0
#                 all_predictions.append(predictions)
#                 all_targets.append(targets)

    val_loss = sum(loss_values) / len(loss_values)

    outputs = torch.cat(outputs, 0)
    targets = torch.cat(targets, 0)
        
    val_accuracy = float((outputs.argmax(1)).eq(targets).sum()) / len(targets)

    total_target = targets.unique(return_counts=True)[1]
    pristine = ((outputs.argmax(1) == 0) * (targets == 0)).sum() / total_target[0]
    face2face = ((outputs.argmax(1) == 1) * (targets == 1)).sum() / total_target[1]
    faceswap = ((outputs.argmax(1) == 2) * (targets == 2)).sum() / total_target[2]
    neural = ((outputs.argmax(1) == 3) * (targets == 3)).sum() / total_target[3]
    deepfake = ((outputs.argmax(1) == 4) * (targets == 4)).sum() / total_target[4]
    tqdm.write(
        'Test - Loss: {:.3f}, Acc: {:.3f}, Pr: {:.3f}, Ff: {:.3f}, Fs: {:.3f}, Nt: {:.3f}, Df: {:.3f}'.format(
            val_loss, val_accuracy, pristine, face2face, faceswap, neural, deepfake
        )
    )