In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="1"
import numpy as np
import torch
import torch.nn as nn
from facenet_pytorch import fixed_image_standardization
#from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm import tqdm

from data_loader import get_loader, read_dataset, CompositeDataset
from resnet3d import resnet10
from utils import write_json, copy_file, count_parameters

In [2]:
transform = transforms.Compose([
        transforms.Resize((160, 160)),
        np.float32,
        transforms.ToTensor(),
        fixed_image_standardization
    ])

In [3]:
datasets = read_dataset(
    '../dataset/mtcnn/', transform=transform,
    max_images_per_video=10, max_videos=1000,
    window_size=5, splits_path='../dataset/splits/'
)
# only neural textures c40 and original c40
datasets = {
    k: v for k, v in datasets.items() 
    if ('original' in k or 'neural' in k or 'face2face' in k or 'faceswap' in k or 'deepfakes' in k) and 'c23' in k
}
print('Using training data: ')
print('\n'.join(sorted(datasets.keys())))

trains, vals, tests = [], [], []
for data_dir_name, dataset in datasets.items():
    train, val, test = dataset
    # repeat original data multiple times to balance out training data
    compression = data_dir_name.split('_')[-1]
    num_tampered_with_same_compression = len({x for x in datasets.keys() if compression in x}) - 1
    count = 1 if 'original' not in data_dir_name else num_tampered_with_same_compression
    for _ in range(count):
        trains.append(train)
    vals.append(val)
    tests.append(test)
    
train_dataset, val_dataset, test_dataset = CompositeDataset(*trains), CompositeDataset(*vals), CompositeDataset(*tests)

['deepfakes_faces_c23', 'original_faces_c23', 'face2face_faces_c23', 'neural_textures_faces_c23', 'faceswap_faces_c23']
Using training data: 
deepfakes_faces_c23
face2face_faces_c23
faceswap_faces_c23
neural_textures_faces_c23
original_faces_c23


In [4]:
tqdm.write('train data size: {}, validation data size: {}'.format(len(train_dataset), len(val_dataset)))

train data size: 57584, validation data size: 6990


In [5]:
train_loader = get_loader(
    train_dataset, 64, shuffle=True, num_workers=2
)
val_loader = get_loader(
    val_dataset, 64, shuffle=True, num_workers=2
)

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('training on', device)

training on cuda


In [7]:
model = resnet10(num_classes=5).to(device)
#if args.freeze_first_epoch:
#for m in model.resnet.parameters():
#    m.requires_grad_(False)

In [8]:
input_shape = next(iter(train_loader))[2].shape
print('input shape', input_shape)
# need to call this before summary!!!
model.eval()
# summary(model, input_shape[1:], batch_size=input_shape[0], device=device)
print('model params (trainable, total):', count_parameters(model))

input shape torch.Size([64, 3, 5, 160, 160])
model params (trainable, total): (14401989, 14401989)


In [9]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(), lr=1e-5, weight_decay=1e-3
)

# decrease learning rate if validation accuracy has not increased
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=1/4, patience=2, verbose=True,
)

In [10]:
def save_model_checkpoint(epoch, model, val_acc):
    
    model_dir = os.path.join('./model', 'resnet3d')
    os.makedirs(model_dir, exist_ok=True)

    model_path = os.path.join(model_dir, f'model.pt')
    torch.save(model.state_dict(), model_path)

    model_info = {
        'epoch': epoch,
        'val_acc': val_acc[0],
        'model_str': str(model)
    }
    json_path = os.path.join(model_dir, 'info.json')
    write_json(model_info, json_path)

    #src_model_file = os.path.join('facenet', 'model.py')
    #dest_model_file = os.path.join(model_dir, 'model.py')
    #copy_file(src_model_file, dest_model_file)

    tqdm.write(f'New checkpoint saved at {model_path}')


def print_training_info(batch_accuracy, loss, step):
    log_info = 'Training - Loss: {:.4f}, Accuracy: {:.4f}'.format(loss.item(), batch_accuracy)
    tqdm.write(log_info)

    #writer.add_scalar('training loss', loss.item(), step)
    #writer.add_scalar('training acc', batch_accuracy, step)


def print_validation_info(criterion, device, model, val_loader, epoch, step):
    model.eval()
    with torch.no_grad():
        loss_values = []
        all_predictions = []
        all_targets = []
        for video_ids, frame_ids, images, targets in val_loader:
            images = images.to(device)
            targets = targets.to(device)
            targets = targets.long()
            outputs = model(images)
            loss = criterion(outputs, targets)
            loss_values.append(loss.item())

            #predictions = outputs > 0.0
            #all_predictions.append(predictions)
            #all_targets.append(targets)
            #if args.debug:
            #    tqdm.write(outputs)
            #    tqdm.write(predictions)
            #    tqdm.write(targets)

        val_loss = sum(loss_values) / len(loss_values)

        val_accuracy = float((outputs.argmax(1)).eq(targets).sum()) / len(targets)
        
        total_target = targets.unique(return_counts=True)[1]
        pristine = ((outputs.argmax(1) == 0) * (targets == 0)).sum() / total_target[0]
        face2face = ((outputs.argmax(1) == 1) * (targets == 1)).sum() / total_target[1]
        faceswap = ((outputs.argmax(1) == 2) * (targets == 2)).sum() / total_target[2]
        neural = ((outputs.argmax(1) == 3) * (targets == 3)).sum() / total_target[3]
        deepfake = ((outputs.argmax(1) == 4) * (targets == 4)).sum() / total_target[4]
        
        tqdm.write(
            'Validation - Loss: {:.3f}, Acc: {:.3f}, Pr: {:.3f}, Ff: {:.3f}, Fs: {:.3f}, Nt: {:.3f}, Df: {:.3f}'.format(
                val_loss, val_accuracy, pristine, face2face, faceswap, neural, deepfake
            )
        )
        
    return val_accuracy, pristine, face2face, faceswap, neural, deepfake


In [11]:
total_step = len(train_loader)
step = 1
best_val_acc = 0.5
for epoch in range(15):
    for i, (video_ids, frame_ids, images, targets) in \
            tqdm(enumerate(train_loader), desc=f'training epoch {epoch}', total=len(train_loader)):
        model.train()
        # Set mini-batch dataset
        images = images.to(device)
        targets = targets.to(device)

        # Forward, backward and optimize
        outputs = model(images)
        targets = targets.long()
        loss = criterion(outputs, targets)
        model.zero_grad()
        loss.backward()
        optimizer.step()

        batch_accuracy = float((outputs.argmax(1)).eq(targets).sum()) / len(targets)
        
        

        # Print log info
        step += 1
        
        if (i + 1) % 300 == 0:
            print_training_info(batch_accuracy, loss, step)

        if (i + 1) % 300 == 0:
            val_acc, pr_acc, ff_acc, fs_acc, nt_acc, df_acc = print_validation_info(
                criterion, device, model, val_loader, epoch, step
            )
            if val_acc > best_val_acc:
                save_model_checkpoint(epoch, model, (val_acc, pr_acc, ff_acc, fs_acc, nt_acc, df_acc))
                best_val_acc = val_acc

    # validation step after full epoch
    val_acc, pr_acc, ff_acc, fs_acc, nt_acc, df_acc = print_validation_info(
        criterion, device, model, val_loader, epoch, step
    )
    lr_scheduler.step(val_acc)
    if val_acc > best_val_acc:
        save_model_checkpoint(epoch, model, (val_acc, pr_acc, ff_acc, fs_acc, nt_acc, df_acc))
        best_val_acc = val_acc

training epoch 0:  33%|███▎      | 299/899 [02:28<04:55,  2.03it/s]

Training - Loss: 1.4584, Accuracy: 0.3906


training epoch 0:  33%|███▎      | 300/899 [03:11<2:15:49, 13.61s/it]

Validation - Loss: 1.693, Acc: 0.297, Pr: 0.947, Ff: 0.000, Fs: 0.000, Nt: 0.000, Df: 0.111


training epoch 0:  67%|██████▋   | 599/899 [05:39<02:28,  2.02it/s]  

Training - Loss: 1.2205, Accuracy: 0.5000


training epoch 0:  67%|██████▋   | 600/899 [06:23<1:07:52, 13.62s/it]

Validation - Loss: 1.580, Acc: 0.219, Pr: 1.000, Ff: 0.000, Fs: 0.000, Nt: 0.000, Df: 0.182


training epoch 0: 100%|██████████| 899/899 [08:51<00:00,  1.69it/s]  


Validation - Loss: 1.569, Acc: 0.219, Pr: 0.833, Ff: 0.143, Fs: 0.000, Nt: 0.000, Df: 0.500


training epoch 1:  33%|███▎      | 299/899 [02:28<04:55,  2.03it/s]

Training - Loss: 1.0410, Accuracy: 0.6250


training epoch 1:  33%|███▎      | 300/899 [03:10<2:11:00, 13.12s/it]

Validation - Loss: 1.509, Acc: 0.344, Pr: 0.800, Ff: 0.083, Fs: 0.200, Nt: 0.000, Df: 0.538


training epoch 1:  67%|██████▋   | 599/899 [05:37<02:27,  2.04it/s]  

Training - Loss: 0.9252, Accuracy: 0.6875


training epoch 1:  67%|██████▋   | 600/899 [06:22<1:09:41, 13.99s/it]

Validation - Loss: 1.565, Acc: 0.406, Pr: 1.000, Ff: 0.077, Fs: 0.294, Nt: 0.000, Df: 0.600


training epoch 1: 100%|██████████| 899/899 [08:48<00:00,  1.70it/s]  


Validation - Loss: 1.330, Acc: 0.453, Pr: 0.350, Ff: 0.625, Fs: 0.600, Nt: 0.100, Df: 0.625


training epoch 2:  33%|███▎      | 299/899 [02:27<04:53,  2.04it/s]

Training - Loss: 0.6427, Accuracy: 0.7969


training epoch 2:  33%|███▎      | 300/899 [03:06<2:01:56, 12.21s/it]

Validation - Loss: 1.340, Acc: 0.406, Pr: 0.471, Ff: 0.417, Fs: 0.625, Nt: 0.053, Df: 0.875


training epoch 2:  67%|██████▋   | 599/899 [05:33<02:27,  2.04it/s]  

Training - Loss: 0.5557, Accuracy: 0.7969


training epoch 2:  67%|██████▋   | 600/899 [06:14<1:03:34, 12.76s/it]

Validation - Loss: 1.352, Acc: 0.250, Pr: 0.467, Ff: 0.100, Fs: 0.333, Nt: 0.053, Df: 0.500


training epoch 2: 100%|██████████| 899/899 [08:40<00:00,  1.73it/s]  


Validation - Loss: 1.330, Acc: 0.562, Pr: 0.417, Ff: 0.545, Fs: 0.812, Nt: 0.182, Df: 0.714
New checkpoint saved at ./model/resnet3d/model.pt


training epoch 3:  33%|███▎      | 299/899 [02:27<04:55,  2.03it/s]

Training - Loss: 0.4241, Accuracy: 0.8594


training epoch 3:  33%|███▎      | 300/899 [03:05<1:58:12, 11.84s/it]

Validation - Loss: 1.382, Acc: 0.531, Pr: 0.333, Ff: 0.706, Fs: 0.583, Nt: 0.357, Df: 0.583


training epoch 3:  67%|██████▋   | 599/899 [05:32<02:28,  2.02it/s]  

Training - Loss: 0.3276, Accuracy: 0.9219


training epoch 3:  67%|██████▋   | 600/899 [06:09<58:04, 11.65s/it]

Validation - Loss: 1.348, Acc: 0.438, Pr: 0.444, Ff: 0.273, Fs: 0.667, Nt: 0.059, Df: 0.800


training epoch 3:  78%|███████▊  | 703/899 [07:00<01:57,  1.67it/s]


KeyboardInterrupt: 