In [None]:
# # This Python 3 environment comes with many helpful analytics libraries installed
# # It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# # For example, here's several helpful packages to load

# import numpy as np # linear algebra
# import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# # Input data files are available in the read-only "../input/" directory
# # For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# # You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# # You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [1]:
import os
import cv2
import time
import copy
from tqdm import tqdm
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.nn.functional import softmax
import torch.nn.functional as F
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, ReLU, Sigmoid, \
    AdaptiveAvgPool2d, Sequential, Module
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset
from torchvision.models import mobilenet_v2

from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score

from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline

import warnings
warnings.filterwarnings('ignore')

# Prepare data and dataloader

In [3]:
# Function to extract face with OpenCV
def extract_face(image_path, show=False):
    image = Image.open(image_path)
    image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    
    bbox_file = image_path[:-4] + '_BB.txt'
    if os.path.exists(bbox_file):
        bbox = open(bbox_file).readlines()[0]
        bbox = [int(_) for _ in bbox.strip().split()[:4]]
        real_w, real_h = image.size
        x1 = int(bbox[0] * (real_w / 224))
        y1 = int(bbox[1] * (real_h / 224))
        w1 = int(bbox[2] * (real_w / 224))
        h1 = int(bbox[3] * (real_h / 224))
    else:
        face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
        gray = cv2.cvtColor(image_cv, cv2.COLOR_BGR2GRAY)
        faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))
        if len(faces) == 0:
            raise ValueError("No face detected in the image.")
        x1, y1, w1, h1 = faces[0]

    face = image.crop((x1, y1, x1 + w1, y1 + h1))
    if show:
        plt.imshow(face)
        plt.show()
    return face

In [4]:
data_dir = '/kaggle/input/celeba-spoof-for-face-antispoofing/CelebA_Spoof_/CelebA_Spoof/Data/'
train_size = len(os.listdir(os.path.join(data_dir, 'train')))
test_size = len(os.listdir(os.path.join(data_dir, 'test')))

print('train: {}; test: {}'.format(train_size, test_size))

train: 8192; test: 1004


In [5]:
path_train_json = '/kaggle/input/celeba-spoof-for-face-antispoofing/CelebA_Spoof_/CelebA_Spoof/metas/intra_test/train_label.json'
path_test_json = '/kaggle/input/celeba-spoof-for-face-antispoofing/CelebA_Spoof_/CelebA_Spoof/metas/intra_test/test_label.json'
path_local = '/kaggle/input/celeba-spoof-for-face-antispoofing/CelebA_Spoof_/CelebA_Spoof/'

In [6]:
df_train = pd.read_json(path_train_json, orient='index')
df_test = pd.read_json(path_test_json, orient='index')

df_train = df_train.reset_index()
df_test = df_test.reset_index()
df_train.rename(columns={'index': 'Filepath'}, inplace=True)
df_test.rename(columns={'index': 'Filepath'}, inplace=True)

In [7]:
df_train['Filepath'] = df_train['Filepath'].apply(lambda x: path_local +  x)
df_test['Filepath'] = df_test['Filepath'].apply(lambda x: path_local  + x)

In [None]:
df_train

In [8]:
invalid_file_name = '/kaggle/input/celeba-spoof-for-face-antispoofing/CelebA_Spoof_/CelebA_Spoof/Data/train/3329/spoof/004046.jpg'

df_train.drop(df_train[df_train['Filepath']==invalid_file_name].index, inplace=True)
df_train

Unnamed: 0,Filepath,0,1,2,3,4,5,6,7,8,...,34,35,36,37,38,39,40,41,42,43
0,/kaggle/input/celeba-spoof-for-face-antispoofi...,0,1,1,0,0,0,0,0,0,...,0,0,1,0,0,1,0,0,0,0
1,/kaggle/input/celeba-spoof-for-face-antispoofi...,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,6,1,1,1
2,/kaggle/input/celeba-spoof-for-face-antispoofi...,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,7,1,1,1
3,/kaggle/input/celeba-spoof-for-face-antispoofi...,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,5,1,2,1
4,/kaggle/input/celeba-spoof-for-face-antispoofi...,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,7,4,2,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
494400,/kaggle/input/celeba-spoof-for-face-antispoofi...,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,3,1,1,1
494401,/kaggle/input/celeba-spoof-for-face-antispoofi...,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,3,3,2,1
494402,/kaggle/input/celeba-spoof-for-face-antispoofi...,0,0,0,0,0,1,1,0,0,...,0,0,1,1,0,0,0,0,0,0
494403,/kaggle/input/celeba-spoof-for-face-antispoofi...,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,7,2,2,1


In [35]:
import cv2
import numpy as np 
from skimage.feature import local_binary_pattern
import matplotlib.pyplot as plt


def extract_lbp_features(ImagePath):
    
    image = cv2.resize(cv2.cvtColor(ImagePath , cv2.COLOR_BGR2GRAY) , (255 , 255))
    # Compute the Local Binary Pattern (LBP)
    lbp = local_binary_pattern(image, P=8, R=1, method='uniform')
    
    return np.expand_dims(lbp , axis = 2)


def extract_fourier_features(ImagePath):

    image = cv2.resize(cv2.cvtColor(ImagePath , cv2.COLOR_BGR2GRAY) , (255 , 255))
     # Compute the 2D Fast Fourier Transform (FFT)
    f = np.fft.fft2(image)
    
    # Shift the zero frequency component to the center
    fshift = np.fft.fftshift(f)
    
    # Compute the magnitude spectrum
    magnitude_spectrum = np.abs(fshift)
    
    # Normalize the magnitude spectrum
    magnitude_spectrum = np.log(1 + magnitude_spectrum)  # Use log scaling for better visualization
    
    
    return np.expand_dims(magnitude_spectrum , axis = 2)


def extract_hsv_features(ImagePath, bins=8):  

    
    image = cv2.resize(ImagePath , (255 , 255))
    # Convert the image to HSV color space
    hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    
    return hsv_image


def extract_features(ImagePath):

    ImagePath = np.array(ImagePath)
    lbp = extract_lbp_features(ImagePath)
    fourier = extract_fourier_features(ImagePath)
    hsv = extract_hsv_features(ImagePath)

    return np.concatenate((lbp , fourier , hsv) , axis = 2)

In [None]:
# transformations
transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
#         transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
    ]),
    'test': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
}

In [10]:
# down sample the size of trainning set
df_train_sample = df_train.sample(frac=0.01, random_state=43)
df_train_sample[43].value_counts()

43
1    3329
0    1615
Name: count, dtype: int64

In [11]:
df_1 = df_train_sample[df_train_sample[43]==1][:1000]
df_2 = df_train_sample[df_train_sample[43]==0][:1000]
df_train_sample_balanced = pd.concat([df_1, df_2])
df_train_sample_balanced = df_train_sample_balanced.sample(frac=1, random_state=42).reset_index(drop=True)

In [39]:
# prepare data
class FASDataset(Dataset):
    def __init__(self, df, transforms=None, ft_width=32, ft_height=32):
        self.df = df
        self.transforms = transforms
        self.ft_width = ft_width
        self.ft_height = ft_height
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        img_path = self.df.iloc[idx]['Filepath']
        bbox_path = img_path[:-4] + '_BB.txt'
        
        face = extract_face(img_path)
        sample = np.transpose(extract_features(face) , (2 , 1 , 0))
        target = df_train.iloc[idx][43]
                
        # Generate the FT picture of the sample
        ft_sample = self.generate_FT(face)
        
        if self.transforms is not None:
            try:
                sample = self.transforms(face)
                sample = np.transpose(extract_features(sample) , (2 , 1 , 0))
            except Exception as err:
                print('Error Occured: %s' % err, img_path)
        
        assert sample is not None

        ft_sample = cv2.resize(ft_sample, (self.ft_width, self.ft_height))
        ft_sample = torch.from_numpy(ft_sample).float()
        ft_sample = torch.unsqueeze(ft_sample, 0)

        return sample, ft_sample, target
    
    def generate_FT(self, image):
        image = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2GRAY)
        f = np.fft.fft2(image)
        fshift = np.fft.fftshift(f)
        fimg = np.log(np.abs(fshift)+1)
        maxx = -1
        minn = 100000
        for i in range(len(fimg)):
            if maxx < max(fimg[i]):
                maxx = max(fimg[i])
            if minn > min(fimg[i]):
                minn = min(fimg[i])
        fimg = (fimg - minn+1) / (maxx - minn+1)
        return fimg

In [13]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score

train_df, val_df = train_test_split(df_train_sample_balanced, test_size=0.2, random_state=42, stratify=df_train_sample_balanced[43])

In [40]:
train_dataset = FASDataset(train_df)
dataloader_train = DataLoader(train_dataset, batch_size=32)

val_dataset = FASDataset(val_df)
dataloader_val = DataLoader(val_dataset, batch_size=32)

# Create Model

In [15]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('using \'{}\' device'.format(device))

using 'cuda' device


In [16]:
keep_dict = {'1.8M': [32, 32, 103, 103, 64, 13, 13, 64, 26, 26,
                      64, 13, 13, 64, 52, 52, 64, 231, 231, 128,
                      154, 154, 128, 52, 52, 128, 26, 26, 128, 52,
                      52, 128, 26, 26, 128, 26, 26, 128, 308, 308,
                      128, 26, 26, 128, 26, 26, 128, 512, 512],

             '1.8M_': [32, 32, 103, 103, 64, 13, 13, 64, 13, 13, 64, 13,
                       13, 64, 13, 13, 64, 231, 231, 128, 231, 231, 128, 52,
                       52, 128, 26, 26, 128, 77, 77, 128, 26, 26, 128, 26, 26,
                       128, 308, 308, 128, 26, 26, 128, 26, 26, 128, 512, 512]
             }

class Conv_block(Module):
    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(Conv_block, self).__init__()
        self.conv = Conv2d(in_c, out_c, kernel_size=kernel, groups=groups,
                           stride=stride, padding=padding, bias=False)
        self.bn = BatchNorm2d(out_c)
        self.prelu = PReLU(out_c)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.prelu(x)
        return x
    
class Linear_block(Module):
    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(Linear_block, self).__init__()
        self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel,
                           groups=groups, stride=stride, padding=padding, bias=False)
        self.bn = BatchNorm2d(out_c)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class Depth_Wise(Module):
     def __init__(self, c1, c2, c3, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
        super(Depth_Wise, self).__init__()
        c1_in, c1_out = c1
        c2_in, c2_out = c2
        c3_in, c3_out = c3
        self.conv = Conv_block(c1_in, out_c=c1_out, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        self.conv_dw = Conv_block(c2_in, c2_out, groups=c2_in, kernel=kernel, padding=padding, stride=stride)
        self.project = Linear_block(c3_in, c3_out, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        self.residual = residual

     def forward(self, x):
        if self.residual:
            short_cut = x
        x = self.conv(x)
        x = self.conv_dw(x)
        x = self.project(x)
        if self.residual:
            output = short_cut + x
        else:
            output = x
        return output
    

class Residual(Module):
    def __init__(self, c1, c2, c3, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
        super(Residual, self).__init__()
        modules = []
        for i in range(num_block):
            c1_tuple = c1[i]
            c2_tuple = c2[i]
            c3_tuple = c3[i]
            modules.append(Depth_Wise(c1_tuple, c2_tuple, c3_tuple, residual=True,
                                      kernel=kernel, padding=padding, stride=stride, groups=groups))
        self.model = Sequential(*modules)

    def forward(self, x):
        return self.model(x)
    

# Define the Fourier Transform Generator
class FTGenerator(nn.Module):
    def __init__(self, in_channels=128, out_channels=1):
        super(FTGenerator, self).__init__()

        self.ft = nn.Sequential(
            nn.Conv2d(in_channels, 128, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, out_channels, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.ft(x)

# Modify the MobileNetV2 model
class MobileNetV2_FT(nn.Module):
    def __init__(self, num_classes=2, freeze_backbone=True):
        super(MobileNetV2_FT, self).__init__()
        keep = keep_dict['1.8M_']
        self.conv1 = Conv_block(5, keep[0], kernel=(3, 3), stride=(2, 2), padding=(1, 1))
        self.conv2_dw = Conv_block(keep[0], keep[1], kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=keep[1])

        c1 = [(keep[1], keep[2])]
        c2 = [(keep[2], keep[3])]
        c3 = [(keep[3], keep[4])]

        self.conv_23 = Depth_Wise(c1[0], c2[0], c3[0], kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=keep[3])

        c1 = [(keep[4], keep[5]), (keep[7], keep[8]), (keep[10], keep[11]), (keep[13], keep[14])]
        c2 = [(keep[5], keep[6]), (keep[8], keep[9]), (keep[11], keep[12]), (keep[14], keep[15])]
        c3 = [(keep[6], keep[7]), (keep[9], keep[10]), (keep[12], keep[13]), (keep[15], keep[16])]

        self.conv_3 = Residual(c1, c2, c3, num_block=4, groups=keep[4], kernel=(3, 3), stride=(1, 1), padding=(1, 1))

        c1 = [(keep[16], keep[17])]
        c2 = [(keep[17], keep[18])]
        c3 = [(keep[18], keep[19])]

        self.conv_34 = Depth_Wise(c1[0], c2[0], c3[0], kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=keep[19])

        c1 = [(keep[19], keep[20]), (keep[22], keep[23]), (keep[25], keep[26]), (keep[28], keep[29]),
              (keep[31], keep[32]), (keep[34], keep[35])]
        c2 = [(keep[20], keep[21]), (keep[23], keep[24]), (keep[26], keep[27]), (keep[29], keep[30]),
              (keep[32], keep[33]), (keep[35], keep[36])]
        c3 = [(keep[21], keep[22]), (keep[24], keep[25]), (keep[27], keep[28]), (keep[30], keep[31]),
              (keep[33], keep[34]), (keep[36], keep[37])]

        self.conv_4 = Residual(c1, c2, c3, num_block=6, groups=keep[19], kernel=(3, 3), stride=(1, 1), padding=(1, 1))
        self.model = mobilenet_v2(pretrained=True)
        original_conv = self.model.features[0][0]
        new_conv = nn.Conv2d(5, original_conv.out_channels, kernel_size=original_conv.kernel_size,
                             stride=original_conv.stride, padding=original_conv.padding, bias=original_conv.bias)

        # Copy the weights from the original conv layer to the new conv layer
        with torch.no_grad():
            new_conv.weight[:, :3, :, :] = original_conv.weight
            new_conv.weight[:, 3:, :, :] = original_conv.weight.mean(dim=1, keepdim=True)

        # Replace the original conv layer with the new conv layer in the model
        self.model.features[0][0] = new_conv
        
        # Freeze the backbone layers
        if freeze_backbone:
            for param in self.model.parameters():
                param.requires_grad = False
                
        # Unfreeze the last two layers
        for param in self.model.features[-2:].parameters():
            param.requires_grad = True
                
        self.model.classifier[1] = nn.Linear(self.model.classifier[1].in_features, num_classes)
        self.FTGenerator = FTGenerator(in_channels=128)

    def forward(self, x):
        features = self.model.features(x)
        # x1 = self.model.avgpool(x)
        # Global average pooling
        x1 = nn.functional.adaptive_avg_pool2d(features, (1, 1))
        x1 = torch.flatten(x1, 1)
        cls_output = self.model.classifier(x1)
        
        if self.training:
            x = self.conv1(x)
            x = self.conv2_dw(x)
            x = self.conv_23(x)
            x = self.conv_3(x)
            x = self.conv_34(x)
            x = self.conv_4(x)
            ft_output = self.FTGenerator(x)
            return cls_output , ft_output
        else:
            return cls_output

In [17]:
model = MobileNetV2_FT(num_classes=2).to(device)
model

MobileNetV2_FT(
  (conv1): Conv_block(
    (conv): Conv2d(5, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=32)
  )
  (conv2_dw): Conv_block(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=32)
  )
  (conv_23): Depth_Wise(
    (conv): Conv_block(
      (conv): Conv2d(32, 103, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(103, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu): PReLU(num_parameters=103)
    )
    (conv_dw): Conv_block(
      (conv): Conv2d(103, 103, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=103, bias=False)
      (bn): BatchNorm2d(103, eps=1e-05, momentum=0.1, affine=True, track_running_s

In [18]:
model.cuda()

MobileNetV2_FT(
  (conv1): Conv_block(
    (conv): Conv2d(5, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=32)
  )
  (conv2_dw): Conv_block(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=32)
  )
  (conv_23): Depth_Wise(
    (conv): Conv_block(
      (conv): Conv2d(32, 103, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(103, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu): PReLU(num_parameters=103)
    )
    (conv_dw): Conv_block(
      (conv): Conv2d(103, 103, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=103, bias=False)
      (bn): BatchNorm2d(103, eps=1e-05, momentum=0.1, affine=True, track_running_s

In [19]:
# Initialize loss functions, and optimizer
criterion_cls = nn.CrossEntropyLoss()
criterion_ft = nn.MSELoss()
# optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
optimizer = optim.AdamW(model.parameters(), lr=0.001)

In [20]:
# Function to save the checkpoint
def save_checkpoint(model, optimizer, epoch, path, loss):
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss
    }
    torch.save(state, path)

# Function to load the checkpoint
def load_checkpoint(model, optimizer, path, device):
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    return model, optimizer, epoch, loss


# Function to evaluate model on validation data
def evaluate_model(model, dataloader, criterion_cls, device):
    model.eval()
    running_loss = 0.0
    running_acc = 0.0
    all_labels = []
    all_preds = []
    
    with torch.no_grad():
        for inputs,ft_inputs, labels in dataloader:
            inputs, ft_inputs, labels = inputs.to(device).float(),ft_inputs.to(device), labels.to(device)
            outputs_cls = model(inputs)
            
            loss = criterion_cls(outputs_cls, labels)
            running_loss += loss.item()
            
            preds = torch.argmax(outputs_cls, dim=1)
            running_acc += (preds == labels).sum().item()
            
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
    
    avg_loss = running_loss / len(dataloader)
    avg_acc = running_acc / len(dataloader.dataset)
    precision = precision_score(all_labels, all_preds, average='weighted')
    recall = recall_score(all_labels, all_preds, average='weighted')
    f1 = f1_score(all_labels, all_preds, average='weighted')
    
    return avg_loss, avg_acc, precision, recall, f1


# Training Loop with Checkpointing
def train_model(model, train_loader, val_loader, criterion_cls,criterion_ft, optimizer, device, num_epochs=10, checkpoint_path='checkpoint.pth'):
    start_epoch = 0
    best_loss = float('inf')
    scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
    
    if os.path.exists(checkpoint_path):
        model, optimizer, start_epoch, best_loss = load_checkpoint(model, optimizer, checkpoint_path, device)
        print(f"Resuming from epoch {start_epoch+1}")

    for epoch in range(start_epoch, num_epochs):
        model.train()
        running_loss = 0.0
        running_acc = 0.0
        all_labels = []
        all_preds = []
        
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
        
        for inputs,ft_inputs, labels in progress_bar:
            inputs,ft_inputs, labels = inputs.to(device).float(), ft_inputs.to(device) , labels.to(device)
            optimizer.zero_grad()
            outputs_cls ,outputs_ft = model(inputs)
            
            loss_cls = criterion_cls(outputs_cls, labels)
            loss_ft = criterion_ft(outputs_ft, ft_inputs)
            loss = 0.5 * loss_cls + 0.5 * loss_ft
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            
            preds = torch.argmax(outputs_cls, dim=1)
            running_acc += (preds == labels).sum().item()
            
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            
            progress_bar.set_postfix({'loss': running_loss/len(train_loader), 'acc': running_acc/len(train_loader.dataset)})
        
        avg_loss = running_loss / len(train_loader)
        avg_acc = running_acc / len(train_loader.dataset)
        
        precision = precision_score(all_labels, all_preds, average='weighted')
        recall = recall_score(all_labels, all_preds, average='weighted')
        f1 = f1_score(all_labels, all_preds, average='weighted')
        
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}, Acc: {avg_acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}')
        
        val_loss, val_acc, val_precision, val_recall, val_f1 = evaluate_model(model, val_loader, criterion_cls, device)
        print(f'Validation - Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, F1 Score: {val_f1:.4f}')
        
        if avg_loss < best_loss:
            best_loss = avg_loss
            save_checkpoint(model, optimizer, epoch, checkpoint_path, best_loss)
            print(f"Checkpoint saved at epoch {epoch+1} with loss {best_loss:.4f}")

    scheduler.step()

In [None]:
train_df['feature'][0].shape

In [42]:
# Train the model with checkpointing and learning rate scheduler
train_model(model, dataloader_train, dataloader_val, criterion_cls,criterion_ft, optimizer, device, num_epochs=50, checkpoint_path='/kaggle/working/fas_mobilenetv2_v2.pth')

Resuming from epoch 19


Epoch 19/50:   4%|▍         | 2/50 [00:05<02:01,  2.53s/it, loss=8.25e-5, acc=0.04]


KeyboardInterrupt: 