#### Environment: pt_fpad  
Python: 3.10.4     
Pytorch: 2.1.1+cu118

In [1]:
#To provide server to VPN_access via PC
import os
os.environ['http_proxy'] = 'http://10.162.15.186:5555'
os.environ['https_proxy'] = 'http://10.162.15.186:5555'

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import transforms, models
import torch.optim as optim
import os
import random
import cv2
import numpy as np
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class_label_real = 0
class_label_attack = 1

Rose-Youtu Dataset:  
Training -> 1397 (Real 358, Attack 1039) -> 1397/3495 * 100 = 40%  
Validation -> 350 (Real 90, Attack 260) -> 350/3495 * 100 = 10%  
Testing -> 1748 (Real 449, Attack 1299) -> 1748/3495 * 100 = 50%   
Total = 3495

In [4]:
data_path_train_real = '/home/data/taha/FASdatasets/Rose_Youtu/train/real/'
data_path_train_attack = '/home/data/taha/FASdatasets/Rose_Youtu/train/attack/'

data_path_devel_real = '/home/data/taha/FASdatasets/Rose_Youtu/devel/real/'
data_path_devel_attack = '/home/data/taha/FASdatasets/Rose_Youtu/devel/attack/'

data_path_GAB_fb = '/home/taha/Taha26/All_Experiments/cyclegan_fpad_SRA_TRB_lsgan_res9/pytorch-CycleGAN-and-pix2pix-master-selftrain-fpad-RY/saved_results/RY_trainattack_GAB_fakebon6234'

In [5]:
def load_samples_image(image_path, class_label, transform):
    image = Image.open(image_path).convert('RGB') # It uses PIL (Pillow) library to open the image, convert it to the RGB mode
    sample = (transform(image), class_label) # Apply transformation
    return sample

class ImageDataset(Dataset):
    def __init__(self, data_path, class_label):
        self.data_path = data_path
        self.image_files = [file for file in os.listdir(data_path) if file.endswith(('.png', '.jpg', '.jpeg'))]
        self.class_label = class_label
        self.data_length = len(self.image_files)
        self.transform = transforms.Compose([transforms.Resize((256, 256)),
                                             transforms.ToTensor()])

    def __len__(self):
        return self.data_length

    def __getitem__(self, idx):
        file = self.image_files[idx]
        path = os.path.join(self.data_path, file)
        sample = load_samples_image(path, self.class_label, self.transform)
        return sample

In [6]:
def load_samples(path, class_label, transform): #Select N frames returned from read_all_frames and assign labels to all samples of same class
        frames = read_all_frames(path)
        total_frames = list(range(0, frames.shape[0], 1))
        selected_samples = random.sample(total_frames, 1)
        samples =[]
        # Assign the same class label to all samples
        label = class_label
        samples =(transform(frames[selected_samples].squeeze()), label)     
        return samples

def read_all_frames(video_path): # reads all frames from a particular video and converts them to PyTorch tensors.
    frame_list = []
    video = cv2.VideoCapture(video_path)
    success = True
    while success:
        success, frame = video.read()
        if success:
            frame = cv2.resize(frame, (256, 256), interpolation=cv2.INTER_AREA) #framesize kept 40, 30 as mentioned in paper but 224, 224 is also fine 
            frame_list.append(frame)
    frame_list = np.array(frame_list)
    return frame_list

class VideoDataset(Dataset):
    def __init__(self, data_path, class_label):
        self.data_path = data_path #path for directory containing video files
        self.video_files = [file for file in os.listdir(data_path) if file.endswith('.mp4')] #list of video files in the specified directory #.mov for RA and RM, .mp4 for RY
        self.class_label = class_label #manually assign class_label for your desired class while loading
        self.data_length = len(self.video_files) 
        self.transform = transforms.Compose([transforms.ToTensor()])

    def __len__(self): # returns the total number of samples in the dataset
        return self.data_length

    def __getitem__(self, idx): # loads and returns a sample from the dataset at the given index
        file = self.video_files[idx]
        path = os.path.join(self.data_path, file)
        frames= load_samples(path, self.class_label, self.transform)

        return frames

In [7]:
train_dataset_real = VideoDataset(data_path_train_real, class_label_real)
train_dataset_attack = VideoDataset(data_path_train_attack, class_label_attack)

val_dataset_real = VideoDataset(data_path_devel_real, class_label_real)
val_dataset_attack = VideoDataset(data_path_devel_attack, class_label_attack)

train_dataset_GAB_fb = ImageDataset(data_path_GAB_fb, class_label_attack)

In [8]:
train_loader_real = DataLoader(train_dataset_real, batch_size=1, shuffle=True)
train_loader_attack = DataLoader(train_dataset_attack, batch_size=1, shuffle=True)

val_loader_real = DataLoader(val_dataset_real, batch_size=1, shuffle=False)
val_loader_attack = DataLoader(val_dataset_attack, batch_size=1, shuffle=False)

train_loader_GAB_fb = DataLoader(train_dataset_GAB_fb, batch_size=1, shuffle=True)

In [9]:
concatenated_train_dataset = ConcatDataset([train_dataset_real, train_dataset_attack, train_dataset_GAB_fb])
concatenated_val_dataset = ConcatDataset([val_dataset_real, val_dataset_attack])

concatenated_train_loader = DataLoader(concatenated_train_dataset, batch_size=64, shuffle=True, pin_memory=True, num_workers=8)
concatenated_val_loader = DataLoader(concatenated_val_dataset, batch_size=64, shuffle=False, pin_memory=True, num_workers=8)

In [10]:
# Print dataset sizes
print(f"Training set size: {len(concatenated_train_dataset)}")
print(f"Validation set size: {len(concatenated_val_dataset)}")

Training set size: 7623
Validation set size: 359


In [11]:
from transformers import MobileViTV2ForImageClassification
model = MobileViTV2ForImageClassification.from_pretrained("apple/mobilevitv2-1.0-imagenet1k-256")
model.classifier = nn.Linear(in_features=512, out_features=2)
print(model)

MobileViTV2ForImageClassification(
  (mobilevitv2): MobileViTV2Model(
    (conv_stem): MobileViTV2ConvLayer(
      (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (normalization): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): SiLU()
    )
    (encoder): MobileViTV2Encoder(
      (layer): ModuleList(
        (0): MobileViTV2MobileNetLayer(
          (layer): ModuleList(
            (0): MobileViTV2InvertedResidual(
              (expand_1x1): MobileViTV2ConvLayer(
                (convolution): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (normalization): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (activation): SiLU()
              )
              (conv_3x3): MobileViTV2ConvLayer(
                (convolution): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
 

  return self.fget.__get__(instance, owner)()


In [12]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [13]:
epochs = 50

device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
model.to(device)

train_losses = []
val_losses = []

train_accuracies = []
val_accuracies = []

# Set up early stopping parameters
patience = 5  # Number of epochs with no improvement after which training will be stopped
best_loss = float('inf') #set to positive infinity to ensure that the first validation loss encountered will always be considered an improvement
counter = 0  # Counter to keep track of consecutive epochs with no improvement

#Training loop
for epoch in range(epochs):
    model.train()  # Set the model to training mode
    running_loss = 0.0
    train_correct_predictions = 0
    total_train_samples = 0

    for train_images, train_labels in concatenated_train_loader:
        train_images, train_labels = train_images.to(device), train_labels.to(device)
        # Clear the gradients
        optimizer.zero_grad()
        # Forward Pass
        train_outputs = model(train_images)

        # Extract logits
        train_logits = train_outputs.logits

        # Find the Loss
        train_loss = criterion(train_logits, train_labels)
        # Calculate gradients
        train_loss.backward()
        # Update Weights
        optimizer.step()

        # accumulate the training loss
        running_loss += train_loss.item()

        # calculate training accuracy
        _, train_predicted = torch.max(train_logits, 1) # _ contain max value, train_predicted contain the indices where maximum value occured
        train_correct_predictions += (train_predicted == train_labels).sum().item() 
        total_train_samples += train_labels.size(0)
            
    train_total_loss = running_loss / len(concatenated_train_loader)
    train_accuracy = train_correct_predictions / total_train_samples * 100
    train_losses.append(train_total_loss)
    train_accuracies.append(train_accuracy)

    val_running_loss = 0.0
    val_correct_prediction = 0
    total_val_samples = 0

    #Validation
    model.eval()
    with torch.no_grad():
        for val_images, val_labels in concatenated_val_loader:
            val_images, val_labels = val_images.to(device), val_labels.to(device)
            val_op = model(val_images)

            # Assuming val_op is the model's output, which is of type ImageClassifierOutputWithNoAttention
            val_logits = val_op.logits

            val_loss = criterion(val_logits, val_labels)
            val_running_loss += val_loss.item()

            _, val_predicted = torch.max(val_logits, 1)
            val_correct_prediction += (val_predicted == val_labels).sum().item()
            total_val_samples += val_labels.size(0)
        
        val_total_loss = val_running_loss / len(concatenated_val_loader)
        val_accuracy = val_correct_prediction / total_val_samples * 100
        val_losses.append(val_total_loss)
        val_accuracies.append(val_accuracy)

    # Check if validation loss has improved
    if val_total_loss < best_loss:
        best_loss = val_total_loss
        counter = 0
        # Save the model if needed
        torch.save(model.state_dict(), 'mobilevitv2_RY_GAB_FB.pth')

    else:
        counter += 1

        # Check if training should be stopped
        if counter >= patience:
            print(f'Early stopping at epoch {epoch}')
            break

    print(f'Epoch {epoch + 1}/{epochs}, Training Loss: {train_total_loss:.4f}, Training Accuracy: {train_accuracy:.2f}%, Validation Loss: {val_total_loss: .4f}, Best Loss: {best_loss: .4f}, Validation Accuracy: {val_accuracy:.2f}%')

Epoch 1/50, Training Loss: 0.1412, Training Accuracy: 97.31%, Validation Loss:  0.0562, Best Loss:  0.0562, Validation Accuracy: 96.94%
Epoch 2/50, Training Loss: 0.0068, Training Accuracy: 99.90%, Validation Loss:  0.2603, Best Loss:  0.0562, Validation Accuracy: 92.20%
Epoch 3/50, Training Loss: 0.0062, Training Accuracy: 99.82%, Validation Loss:  0.0063, Best Loss:  0.0063, Validation Accuracy: 100.00%
Epoch 4/50, Training Loss: 0.0006, Training Accuracy: 100.00%, Validation Loss:  0.0169, Best Loss:  0.0063, Validation Accuracy: 98.89%
Epoch 5/50, Training Loss: 0.0046, Training Accuracy: 99.86%, Validation Loss:  0.1615, Best Loss:  0.0063, Validation Accuracy: 95.54%
Epoch 6/50, Training Loss: 0.0014, Training Accuracy: 99.96%, Validation Loss:  0.1190, Best Loss:  0.0063, Validation Accuracy: 96.94%
Epoch 7/50, Training Loss: 0.0068, Training Accuracy: 99.78%, Validation Loss:  0.0316, Best Loss:  0.0063, Validation Accuracy: 98.61%
Early stopping at epoch 7
