In [8]:
from pathlib import Path
import torch
import torchvision
from torchvision import datasets, transforms
import time
from sklearn import metrics
import os
import time
import copy

In [9]:

git_dir = Path().parent.absolute().parent


train_dir = os.path.join(git_dir, "data/train")
test_dir = os.path.join(git_dir, "data/test")


train_transforms = transforms.Compose([
        torchvision.transforms.RandomResizedCrop(224, 224),
        torchvision.transforms.RandomRotation(10),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ColorJitter(brightness=0.6,
                                           contrast=0.4,
                                           saturation=0.7),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
                [0.485, 0.456, 0.406], 
                [0.229, 0.224, 0.225])
])

        
test_transforms = transforms.Compose([
        transforms.Resize((224,224))
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
                [0.485, 0.456, 0.406], 
                [0.229, 0.224, 0.225])
])





#datasets
train_data = datasets.ImageFolder(train_dir,transform=train_transforms)
test_data = datasets.ImageFolder(test_dir,transform=test_transforms)

#dataloader
train_loader = torch.utils.data.DataLoader(train_data, shuffle = True, batch_size=16)
test_loader = torch.utils.data.DataLoader(test_data, shuffle = True, batch_size=16)

In [10]:
def make_train_step(model, optimizer, loss_fn):
  def train_step(x,y):
    #make prediction
    yhat = model(x)
    #enter train mode
    model.train()
    #compute loss
    loss = loss_fn(yhat,y)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    #optimizer.cleargrads()

    return loss
  return train_step

In [15]:
 for x_batch, y_batch in test_loader:
    print(x_batch, y_batch)
    continue

RuntimeError: stack expects each tensor to be equal size, but got [3, 298, 224] at entry 0 and [3, 224, 224] at entry 1

In [11]:

device = "cuda" if torch.cuda.is_available() else "cpu"

model = torchvision.models.resnet50(pretrained=True)
model.aux_logits = True # auxilary outputs to avoid gradient atenuation

# Initialize parameters
for param in model.parameters():
  param.requires_grad = True

nb_layers = len([param.requires_grad for param in model.parameters()])

layers_to_unfreeze = 70

# Freeze params until 70
for i, param in enumerate(model.parameters()):
  if i<= nb_layers - layers_to_unfreeze:
    param.requires_grad = False
  
model.fc = torch.nn.Sequential(
                        #torch.nn.Linear(25088, 4096),#model.classifier.in_features, 4096),
                        torch.nn.Linear(model.fc.in_features, 4096),
                        torch.nn.LeakyReLU(),
                        torch.nn.Dropout(0.2),
                        torch.nn.Linear(4096, 4096),
                        torch.nn.LeakyReLU(),
                        torch.nn.Dropout(0.2),
                        torch.nn.Linear(4096, 1)
)

In [12]:

from torch.nn.modules.loss import BCEWithLogitsLoss
from torch.optim import lr_scheduler

#loss
loss_fn = BCEWithLogitsLoss() #binary cross entropy with sigmoid, so no need to use sigmoid in the model

#optimizer
optimizer = torch.optim.Adam(model.fc.parameters()) 

#train step
train_step = make_train_step(model, optimizer, loss_fn)

In [13]:
#%%capture
!pip install tqdm
from tqdm import tqdm


losses = []
val_losses = []

epoch_train_losses = []
epoch_test_losses = []

n_epochs = 10
early_stopping_tolerance = 3
early_stopping_threshold = 0.03

for epoch in range(n_epochs):
  epoch_loss = 0
  for i ,data in tqdm(enumerate(train_loader), total = len(train_loader)): #iterate ove batches
    x_batch , y_batch = data
    x_batch = x_batch.to(device) #move to gpu
    y_batch = y_batch.unsqueeze(1).float() #convert target to same nn output shape
    y_batch = y_batch.to(device) #move to gpu


    loss = train_step(x_batch, y_batch)
    epoch_loss += loss/len(train_loader)
    losses.append(loss)
    
  epoch_train_losses.append(epoch_loss)
  print('\nEpoch : {}, train loss : {}'.format(epoch+1,epoch_loss))

  #validation doesnt requires gradient
  with torch.no_grad():
    cum_loss = 0
    for x_batch, y_batch in test_loader:
      x_batch = x_batch.to(device)
      y_batch = y_batch.unsqueeze(1).float() #convert target to same nn output shape
      y_batch = y_batch.to(device)

      #model to eval mode
      model.eval()

      yhat = model(x_batch)
      val_loss = loss_fn(yhat,y_batch)
      cum_loss += loss/len(test_loader)
      val_losses.append(val_loss.item())


    epoch_test_losses.append(cum_loss)
    print('Epoch : {}, val loss : {}'.format(epoch+1,cum_loss))  
    
    best_loss = min(epoch_test_losses)
    
    #save best model
    if cum_loss <= best_loss:
      best_model_wts = model.state_dict()
    
    #early stopping
    early_stopping_counter = 0
    if cum_loss > best_loss:
      early_stopping_counter +=1

    if (early_stopping_counter == early_stopping_tolerance) or (best_loss <= early_stopping_threshold):
      print("/nTerminating: early stopping")
      break #terminate training
    
#load best model
model.load_state_dict(best_model_wts)





  1%|          | 8/1218 [00:32<1:22:00,  4.07s/it]


KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt 

def inference(test_data):
  idx = torch.randint(1, len(test_data), (1,))
  sample = torch.unsqueeze(test_data[idx][0], dim=0).to(device)

  
  if torch.sigmoid(model(sample)) < 0.5:
    print("Prediction : live")
  else:
    print("Prediction : spoof")


  plt.imshow(test_data[idx][0].permute(1, 2, 0))

inference(test_data)

In [None]:
model.eval()  # Set the model to evaluation mode

total_loss = 0
total_correct = 0
total_samples = 0

with torch.no_grad():
    for x_batch, y_batch in test_loader:
        x_batch = x_batch.to(device)
        y_batch = y_batch.unsqueeze(1).float().to(device)  # Reshape target to match model output

        yhat = model(x_batch)  # Perform forward pass
        loss = loss_fn(yhat, y_batch)  # Calculate loss

        total_loss += loss.item() * x_batch.size(0)  # Accumulate loss

        total_correct += ((yhat.sigmoid() > 0.5) ==  (y_batch > 0.5)).sum().item()  # Count correct predictions
        total_samples += x_batch.size(0)  # Count total samples

avg_loss = total_loss / total_samples
accuracy = total_correct / total_samples

print('Evaluation loss: {:.4f}, Accuracy: {:.4f}'.format(avg_loss, accuracy))

In [None]:
torch.save(model, 'models/transfer_learning.pt')

In [None]:
model = torch.load('models/transfer_learning.pt')
model.eval()