In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import torch.optim as optim
import numpy as np
import pandas as pd
from tqdm import tqdm
# Seed PyTorch's global RNG so later random draws (layer init, random_split,
# DataLoader shuffling) repeat from run to run.
torch.manual_seed(1)

# Prefer the first GPU when CUDA is available; otherwise everything runs on the CPU.
cpu = torch.device('cpu')
device = torch.device('cuda:0') if torch.cuda.is_available() else cpu
print('Running on device: {}'.format(device))

In [None]:
class Identity(nn.Module):
    """Parameter-free pass-through layer.

    Used to replace a model's classifier head so the model returns the features
    that would otherwise have been fed to that head.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Nothing to compute: hand the input back untouched.
        return x

In [None]:
# mixnet = timm.create_model("efficientnet_b2a", pretrained=True)

In [None]:
import torch.nn as nn

class DFDCNet(nn.Module):
    """Per-video sigmoid classifier: EfficientNet frame features -> LSTM -> small MLP.

    Every frame is embedded by a pretrained timm ``efficientnet_b2a`` backbone whose
    classifier head is replaced by ``Identity``. The sequence of frame embeddings
    runs through an LSTM. The LSTM output at the last time step goes through
    dropout -> Linear -> ELU -> Linear -> sigmoid.

    Args:
        input_size: accepted for backward compatibility but not used; the LSTM input
            width is taken from the backbone's feature dimension.
        output_size: number of sigmoid outputs per video.
        hidden_dim: LSTM hidden size.
        n_layers: number of stacked LSTM layers.
        drop_prob: dropout probability between stacked LSTM layers.
    """

    def __init__(self, input_size, output_size, hidden_dim, n_layers, drop_prob=0.5):
        super(DFDCNet, self).__init__()
        self.output_size = output_size
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim

        self.mixnet = timm.create_model("efficientnet_b2a", pretrained=True)
        # Ask the backbone for its pooled-feature width instead of hard-coding it;
        # 1408 (the previously hard-coded value) is kept as a fallback.
        feature_dim = getattr(self.mixnet, 'num_features', 1408)
        self.mixnet.classifier = Identity()

        self.lstm = nn.LSTM(feature_dim, hidden_dim, n_layers, dropout=drop_prob, batch_first=True)
        self.dropout = nn.Dropout(0.5)
        # Not used in forward(); kept so previously saved state_dicts still load strictly.
        self.batchnorm = nn.BatchNorm1d(hidden_dim)
        self.elu = nn.ELU()
        self.fc1 = nn.Linear(hidden_dim, 32)
        self.fc4 = nn.Linear(32, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, hidden):
        """Score a batch of frame sequences.

        Args:
            x: tensor of shape (batch, seq_len, channels, height, width); cast to float.
            hidden: initial (h0, c0) LSTM state, e.g. from ``init_hidden(batch)``,
                or None to let the LSTM start from zeros.

        Returns:
            (out, hidden). ``out`` is the sigmoid output for the last time step. Its
            shape is (batch,) when output_size == 1, else (batch, output_size).
            ``hidden`` is the LSTM's final (h_n, c_n).
        """
        batch_size, seqlen, c, h, w = x.size()
        # Push every frame through the CNN as one large batch, then restore the sequence axis.
        feats = self.mixnet(x.reshape(batch_size * seqlen, c, h, w).float())
        feats = feats.reshape(batch_size, seqlen, -1)
        lstm_out, hidden = self.lstm(feats, hidden)

        # Only the last time step is scored, so run the head on it alone. The old
        # view(batch, -1)[:, -1] selected a single scalar, which was only right
        # for output_size == 1.
        out = self.dropout(lstm_out[:, -1])
        out = self.elu(self.fc1(out))
        out = self.sigmoid(self.fc4(out))
        if self.output_size == 1:
            out = out.squeeze(-1)  # keep the historical (batch,) shape
        return out, hidden

    def init_hidden(self, batch_size):
        """Return a zeroed (h0, c0) LSTM state for ``batch_size`` sequences.

        Each tensor has shape (n_layers, batch_size, hidden_dim). It is created
        with the dtype and device of the model's parameters, so it follows
        wherever the model was moved. It does not require grad.
        """
        weight = next(self.parameters())
        shape = (self.n_layers, batch_size, self.hidden_dim)
        return (weight.new_zeros(shape), weight.new_zeros(shape))

In [None]:
 # Show the first 50 lines of a long, numeric-uid/gid listing of the pre-extracted data directory.
 !ls -ln data_images | head -n 50

In [None]:
# Gather the saved metadata parts and concatenate them once. Concatenating inside
# the loop copied the growing frame on every iteration, which is quadratic in the
# number of parts.
# NOTE(review): torch.load unpickles these files, so only load trusted data. On
# torch >= 2.6 the default weights_only=True refuses non-tensor objects such as
# DataFrames; `weights_only=False` would then be required.
N_METADATA_PARTS = 30
metadata = pd.concat([
    torch.load('data_images/000metadata_part_' + str(p) + '.pt', map_location=cpu)
    for p in range(N_METADATA_PARTS)
])

In [None]:
# Distribution of the n_face column; only rows with n_face == 1 are kept for training.
metadata['n_face'].value_counts()

In [None]:
# First five rows of the combined metadata.
metadata.head(5)

In [None]:
# Keep only rows with exactly one detected face.
single_face = metadata.n_face == 1
# X: metadata index labels, later used to build the per-video tensor file names.
X = np.array(metadata.index[single_face])
# Y: 1 where label == 'REAL', otherwise 0.
Y = np.array((metadata.label[single_face] == 'REAL').astype(np.int64))

In [None]:
X[:5]  # first five kept video ids

In [None]:
Y[:5]  # their labels (1 = REAL)

In [None]:
X.shape[0]  # number of single-face videos

In [None]:
Y.shape[0]  # should equal the number of videos above

In [None]:
# Fixed-size training set; validation takes what is left, rounded down to a
# multiple of 100 videos.
n_videos = len(X)
n_videos_train = 50_000
n_videos_val = (n_videos - n_videos_train) // 100 * 100
# Fail fast with a readable message rather than a confusing random_split/slicing error later.
assert n_videos_val > 0, 'need at least {} single-face videos for a validation split, found {}'.format(
    n_videos_train + 100, n_videos)
print(str(n_videos_train) + ' for training')
print(str(n_videos_val) + ' for validation')

In [None]:
from torch.utils.data import TensorDataset, DataLoader
# Samples are (position into X, label) pairs. The face tensors themselves are
# loaded per batch by the training loop, using X[position] to build the file name.
dataset = TensorDataset(torch.from_numpy(np.arange(0, n_videos_train + n_videos_val)),
                        torch.from_numpy(Y[0: n_videos_train + n_videos_val]))
train_data, val_data = torch.utils.data.random_split(dataset, [n_videos_train, n_videos_val])
train_batch_size = 20
val_batch_size = 10
train_loader = DataLoader(train_data, shuffle=True, batch_size=train_batch_size, num_workers=16)
# Validation does not need shuffling; a fixed order makes successive evaluations comparable.
val_loader = DataLoader(val_data, shuffle=False, batch_size=val_batch_size, num_workers=16)

In [None]:
# Model hyper-parameters. NOTE: DFDCNet accepts input_size but does not use it.
input_size, output_size = 512, 1
hidden_dim, n_layers = 512, 2

model = DFDCNet(input_size, output_size, hidden_dim=hidden_dim, n_layers=n_layers)
model.to(device)

In [None]:
# Binary cross-entropy on the model's sigmoid outputs; separate instances for
# training and validation.
train_criterion = nn.BCELoss()
val_criterion = nn.BCELoss()

# (name, parameter) pairs left over from an experiment that put 'bias'/'bn'
# parameters in their own param group. That grouping is disabled: Adam
# optimizes every parameter with the same settings.
param_optimizer = list(model.named_parameters())
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [None]:
# scheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau
# The training loop calls scheduler.step() once per epoch, so `patience` counts epochs.
# With 5 epochs, the former patience=500 meant the LR could never be reduced.
# patience=1 halves the LR after two consecutive epochs without improvement.
# `verbose` is deprecated in recent PyTorch; read the current LR from
# optimizer.param_groups[0]['lr'] instead.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)

In [None]:
# Release cached-but-unused GPU memory held by PyTorch's caching allocator before training.
torch.cuda.empty_cache() 

In [None]:
! rm log_training.log

In [None]:
import logging
# Append INFO-level training progress to log_training.log. basicConfig is a no-op
# when the root logger already has handlers (e.g. if this cell is re-run).
logging.basicConfig(level=logging.INFO, filename='log_training.log')


In [None]:
def _load_face_batch(video_indexes):
    """Load and stack the saved single-face tensors for one batch of videos.

    ``video_indexes`` holds positions into ``X``. Video ``X[i]`` is read from
    ``data_images/1face_X_<X[i]>.pt`` directly onto ``device``. Iterating over the
    indexes themselves, rather than ``range(batch_size)``, means a shorter final
    batch does not raise IndexError.
    """
    return torch.stack([
        torch.load('data_images/1face_X_' + X[int(idx)] + '.pt', map_location=device)
        for idx in video_indexes
    ])


def _validation_losses(model, loader):
    """Return the per-batch validation BCE losses of ``model`` over ``loader``.

    Runs in eval mode under ``torch.no_grad()``, so no autograd graph is built,
    and switches the model back to train mode afterwards.
    """
    model.eval()
    losses = []
    with torch.no_grad():
        for val_indexes, lab in loader:
            inp = _load_face_batch(val_indexes)
            # Each batch holds unrelated videos, so each starts from a zero LSTM state.
            val_h = model.init_hidden(len(val_indexes))
            out, _ = model(inp, val_h)
            losses.append(val_criterion(out.reshape(-1), lab.to(device).float()).item())
    model.train()
    return losses


epochs = 5
counter = 0
print_every = 1000
clip = .5
valid_loss_min = .3              # checkpoint only once mean validation loss is at or below this
val_loss = torch.tensor(np.inf)  # np.Inf was removed in NumPy 2.0
model.train()
for i in range(epochs):
    for indexes, labels in train_loader:
        counter += 1
        inputs = _load_face_batch(indexes)
        labels = labels.to(device)
        # Batches contain unrelated videos: start every sequence from a zero state
        # instead of carrying the previous batch's final state forward.
        h = model.init_hidden(len(indexes))
        model.zero_grad()
        output, h = model(inputs, h)
        # reshape(-1) instead of squeeze(): squeeze() turns a batch of one into a
        # 0-d tensor that no longer matches the shape of `labels`.
        loss = train_criterion(output.reshape(-1), labels.float())
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        if counter % 100 == 0:
            logging.info("Epoch: {}/{}...".format(i+1, epochs) +
                         "Step: {}...".format(counter) +
                         "Loss: {:.6f}...".format(loss.item()))
        if counter % print_every == 0:
            logging.info(str(counter))
            val_losses = _validation_losses(model, val_loader)
            mean_val_loss = float(np.mean(val_losses))
            val_loss = torch.tensor(mean_val_loss)
            print("Epoch: {}/{}...".format(i+1, epochs),
                  "Step: {}...".format(counter),
                  "Loss: {:.6f}...".format(loss.item()),
                  "Val Loss: {:.6f}".format(mean_val_loss))
            logging.info("Epoch: {}/{}...".format(i+1, epochs) +
                         "Step: {}...".format(counter) +
                         "Loss: {:.6f}...".format(loss.item()) +
                         "Val Loss: {:.6f}".format(mean_val_loss))
            if mean_val_loss <= valid_loss_min:
                torch.save(model.state_dict(), './model_1face_unfroze.pt')
                print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(valid_loss_min, mean_val_loss))
                logging.info('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(valid_loss_min, mean_val_loss))
                valid_loss_min = mean_val_loss
    # Step the plateau scheduler on the latest *mean* validation loss. The old code
    # used only the last validation batch. Skip the step until a validation has run.
    if torch.isfinite(val_loss):
        scheduler.step(val_loss.item())


```
INFO:root:Epoch: 1/5...Step: 100...Loss: 0.481132...
INFO:root:Epoch: 1/5...Step: 200...Loss: 0.163793...
INFO:root:Epoch: 1/5...Step: 300...Loss: 0.391551...
INFO:root:Epoch: 1/5...Step: 400...Loss: 0.176994...
INFO:root:Epoch: 1/5...Step: 500...Loss: 0.286489...
INFO:root:Epoch: 1/5...Step: 600...Loss: 0.267409...
INFO:root:Epoch: 1/5...Step: 700...Loss: 0.297444...
INFO:root:Epoch: 1/5...Step: 800...Loss: 0.086461...
INFO:root:Epoch: 1/5...Step: 900...Loss: 0.622738...
INFO:root:Epoch: 1/5...Step: 1000...Loss: 0.199401...
INFO:root:1000
INFO:root:Epoch: 1/5...Step: 1000...Loss: 0.199401...Val Loss: 0.127336
INFO:root:Validation loss decreased (0.300000 --> 0.127336).  Saving model ...
```

# Reference
* https://github.com/ronghanghu/pytorch-gve-lrcn/blob/master/models/pretrained_models.py