#### Environment: pt_fpad  
Python: 3.10.4     
PyTorch: 2.1.1+cu118

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import transforms, models
import torch.optim as optim
import os
import random
import cv2
import numpy as np

In [None]:
# Binary class labels: 0 marks a real video, 1 marks an attack video
class_label_real, class_label_attack = 0, 1

Replay Attack Dataset:  
Training -> 360 (Real 60, Attack Fixed 150, Attack Hand 150) ->  360/1200 * 100 = 30%  
Validation -> 360 (Real 60, Attack Fixed 150, Attack Hand 150) -> 360/1200 * 100 = 30%  
Testing -> 480 (Real 80, Attack Fixed 200, Attack Hand 200) -> 480/1200 * 100 = 40%  
Total = 1200  

Replay Mobile Dataset:

Training -> 312 (Real 120, Attack 192) ->  312/1030 * 100 = 30.29%  
Validation -> 416 (Real 160, Attack 256) -> 416/1030 * 100 = 40.39%  
Testing -> 302 (Real 110, Attack 192) -> 302/1030 * 100 = 29.32%  
Total = 1030  

Rose-Youtu Dataset:  
Training -> 1397 (Real 358, Attack 1039) -> 1397/3495 * 100 ≈ 40%  
Validation -> 350 (Real 90, Attack 260) -> 350/3495 * 100 ≈ 10%  
Testing -> 1748 (Real 449, Attack 1299) -> 1748/3495 * 100 ≈ 50%  
Total = 3495

In [None]:
# Replay Attack
# data_path_test_real = '/home/taha/FASdatasets/Sohail/Replay_Attack/test/real'
# data_path_test_attack_fixed = '/home/taha/FASdatasets/Sohail/Replay_Attack/test/attack/fixed'
# data_path_test_attack_hand = '/home/taha/FASdatasets/Sohail/Replay_Attack/test/attack/hand'

# Replay Mobile
# data_path_test_real = '/home/taha/FASdatasets/Sohail/Replay-Mobile/test/real'
# data_path_test_attack = '/home/taha/FASdatasets/Sohail/Replay-Mobile/test/attack'

# Rose Youtu
# data_path_test_real = '/home/taha/FASdatasets/Sohail/Rose_Youtu/test/real'
# data_path_test_attack = '/home/taha/FASdatasets/Sohail/Rose_Youtu/test/attack'

In [None]:
def load_samples(path, class_label, transform, frame_index=5):
    """Return one transformed frame of a video together with its class label.

    Args:
        path: Path of the video file, forwarded to ``read_all_frames``.
        class_label: Label attached to the returned sample.
        transform: Callable applied to the selected frame.
        frame_index: Index of the frame to use (negative values count from
            the end). Defaults to 5, the index that used to be hard-coded.

    Returns:
        A ``(transformed_frame, class_label)`` tuple.

    Raises:
        IndexError: If the video has no frame at ``frame_index`` (for example
            because it is too short or no frame could be decoded).
    """
    frames = read_all_frames(path)
    num_frames = frames.shape[0]
    if not -num_frames <= frame_index < num_frames:
        # Same exception type as the plain out-of-range lookup, but naming the
        # offending file so a broken video is easy to find.
        raise IndexError(
            f"frame index {frame_index} is out of range for '{path}' "
            f"({num_frames} frames read)"
        )
    return transform(frames[frame_index].squeeze()), class_label

def read_all_frames(video_path):
    """Read every frame of a video, resized to 224x224, into one NumPy array.

    Args:
        video_path: Path of the video file handed to ``cv2.VideoCapture``.

    Returns:
        ``np.ndarray`` built from the list of resized frames, in the order
        they were read. The array is empty when no frame could be read.
    """
    frame_list = []
    video = cv2.VideoCapture(video_path)
    try:
        while True:
            success, frame = video.read()
            if not success:
                break
            # Original author's note: the paper uses a 40x30 frame size, but
            # 224x224 is also fine.
            frame_list.append(
                cv2.resize(frame, (224, 224), interpolation=cv2.INTER_AREA)
            )
    finally:
        # Always give the capture handle back, even if resizing fails.
        video.release()
    return np.array(frame_list)

class VideoDataset(Dataset):
    """Dataset that yields one ``(frame, label)`` sample per video in a directory.

    Every ``.mov`` / ``.mp4`` file found directly inside ``data_path`` is one
    item, and all items share the single ``class_label`` given at construction.
    """

    # .mov is used by Replay-Attack and Replay-Mobile, .mp4 by Rose-Youtu
    VIDEO_EXTENSIONS = ('.mov', '.mp4')

    def __init__(self, data_path, class_label):
        """Index the video files in ``data_path`` and store the shared label.

        Args:
            data_path: Directory containing the video files.
            class_label: Label assigned to every sample of this dataset.
        """
        self.data_path = data_path  # directory containing the video files
        self.video_files = self._list_videos(data_path)
        self.class_label = class_label  # label shared by all samples
        self.data_length = len(self.video_files)
        self.transform = transforms.Compose([transforms.ToTensor()])

    @classmethod
    def _list_videos(cls, data_path):
        """Return the video file names in ``data_path``, sorted by name.

        ``os.listdir`` returns entries in arbitrary order, so the names are
        sorted to make the index-to-video mapping reproducible across runs.
        """
        return sorted(
            name for name in os.listdir(data_path)
            if name.endswith(cls.VIDEO_EXTENSIONS)
        )

    def __len__(self):
        """Return the number of video files in the dataset."""
        return self.data_length

    def __getitem__(self, idx):
        """Load the sample for the video at position ``idx``.

        Returns:
            Whatever ``load_samples`` returns for that video, called with this
            dataset's label and transform.
        """
        path = os.path.join(self.data_path, self.video_files[idx])
        return load_samples(path, self.class_label, self.transform)

In [None]:
# Replay Attack Dataset
# test_dataset_real = VideoDataset(data_path_test_real, class_label_real)
# test_dataset_attack_fixed = VideoDataset(data_path_test_attack_fixed, class_label_attack)
# test_dataset_attack_hand = VideoDataset(data_path_test_attack_hand, class_label_attack)

# Replay Mobile & Rose Youtu Dataset

# test_dataset_real = VideoDataset(data_path_test_real, class_label_real)
# test_dataset_attack = VideoDataset(data_path_test_attack, class_label_attack)

In [None]:
# Replay Attack Datasets
# test_loader_real = DataLoader(test_dataset_real, batch_size=1, shuffle=False)
# test_loader_attack_fixed = DataLoader(test_dataset_attack_fixed, batch_size=1, shuffle=False)
# test_loader_attack_hand = DataLoader(test_dataset_attack_hand, batch_size=1, shuffle=False)

# Replay Mobile & Rose Youtu Dataset
# test_loader_real = DataLoader(test_dataset_real, batch_size=1, shuffle=False)
# test_loader_attack = DataLoader(test_dataset_attack, batch_size=1, shuffle=False)

In [None]:
# Replay Attack Dataset
# concatenated_test_dataset = ConcatDataset([test_dataset_real, test_dataset_attack_fixed, test_dataset_attack_hand])
# concatenated_test_loader = DataLoader(concatenated_test_dataset, batch_size=64, shuffle=False, pin_memory=True, num_workers=8)

# Replay Mobile & Rose Youtu Dataset
# concatenated_test_dataset = ConcatDataset([test_dataset_real, test_dataset_attack])
# concatenated_test_loader = DataLoader(concatenated_test_dataset, batch_size=64, shuffle=False, pin_memory=False, num_workers=8)

In [None]:
# Report how many samples the concatenated test set holds
test_set_size = len(concatenated_test_dataset)
print(f"Test set size: {test_set_size}")

In [None]:
# Load pre-trained ResNet18
# model = models.resnet18(pretrained=True)
# num_ftrs = model.fc.in_features
# model.fc = nn.Linear(num_ftrs, 2)

# Load pre-trained MobileNetV2.
# `weights=` replaces the deprecated `pretrained=True`; IMAGENET1K_V1 is the
# weight set that `pretrained=True` selected.
model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
# Swap the 1000-class ImageNet head for a 2-class (real / attack) head, reading
# the input width from the existing layer (1280 by default) instead of hard-coding it.
classifier_in_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(in_features=classifier_in_features, out_features=2)
# print(model)

In [None]:
# Prefer GPU 2 as before, but only if that device exists; otherwise use the
# first GPU, and fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    gpu_index = 2 if torch.cuda.device_count() > 2 else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')
model.to(device)

# Define the path to your saved model file
model_path = '/home/taha/Taha26/pt_fpad/cross_ds_merge/model2/RA_RYR_best_model.pth'

# Load the saved model. map_location remaps the stored tensors onto `device`,
# so a checkpoint saved on a GPU still loads when running on another GPU or the CPU.
checkpoint = torch.load(model_path, map_location=device)

# Load the model's state dictionary
model.load_state_dict(checkpoint)

In [None]:
# Evaluate on the test set: count correct predictions and keep every label /
# prediction so the metrics can be computed afterwards.
model.eval()

test_correct = 0
test_total = 0

# Start from empty int64 tensors on the target device; per-batch results are
# collected here and concatenated once after the loop.
label_batches = [torch.empty(0, dtype=torch.int64, device=device)]
prediction_batches = [torch.empty(0, dtype=torch.int64, device=device)]

with torch.no_grad():
    for test_images, test_labels in concatenated_test_loader:
        test_images = test_images.to(device)
        test_labels = test_labels.to(device)

        test_model_op = model(test_images)
        # Index of the highest-scoring class for each sample
        test_predicted = torch.max(test_model_op, 1)[1]

        test_correct += (test_predicted == test_labels).sum().item()
        test_total += test_labels.size(0)

        label_batches.append(test_labels)
        prediction_batches.append(test_predicted)

    test_cat_labels = torch.cat(label_batches)
    test_predicted_cat_labels = torch.cat(prediction_batches)

    test_accuracy = test_correct / test_total * 100
    print(f'Test Accuracy: {test_accuracy:.2f}%')

In [None]:
# Copy the accumulated labels and predictions to the CPU for the metric functions
test_cat_labels_cpu, test_predicted_cat_labels_cpu = (
    tensor.cpu() for tensor in (test_cat_labels, test_predicted_cat_labels)
)

In [None]:
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score, balanced_accuracy_score

In [None]:
# Confusion-matrix counts with the attack class (label 1) as the positive class.
# The label order is pinned so the matrix is always 2x2 and unpacks into four
# counts, even if one class is missing from the labels or the predictions.
tn, fp, fn, tp = confusion_matrix(
    test_cat_labels_cpu,
    test_predicted_cat_labels_cpu,
    labels=[class_label_real, class_label_attack],
).ravel()

print(f'TN: {tn}, FP: {fp}, FN: {fn}, TP: {tp}')

acc_score = accuracy_score(test_cat_labels_cpu, test_predicted_cat_labels_cpu)
prec_score = precision_score(test_cat_labels_cpu, test_predicted_cat_labels_cpu)
recall = recall_score(test_cat_labels_cpu, test_predicted_cat_labels_cpu)

sensitivity_val = tp / (tp + fn)  # true positive rate
specificity_val = tn / (tn + fp)  # true negative rate
Y_I_val = sensitivity_val + specificity_val - 1  # Youden's index
f1score_val = 2 * tp / (2 * tp + fp + fn)
FAR = fp / (fp + tn)  # share of real samples predicted as attack
FRR = fn / (fn + tp)  # share of attack samples predicted as real
HTER_val = (FAR + FRR) / 2
# NOTE(review): this is the overall misclassification rate (1 - accuracy) at the
# argmax decision, not an equal error rate. A true EER needs per-sample scores
# and a threshold sweep to the point where FAR == FRR. The name and the printed
# label are kept unchanged so existing result logs stay comparable.
EER = (fp + fn) / (tn + fp + fn + tp)
val_bacc = balanced_accuracy_score(test_cat_labels_cpu, test_predicted_cat_labels_cpu)


print('Testing Results')
print(30*'-')
print('Acc:', acc_score, '\nSen:', sensitivity_val, '\nSpec:', specificity_val, '\nYI:', Y_I_val, '\nF1:', f1score_val, '\nPrec:', prec_score, '\nRecall:', recall, '\nHTER:', HTER_val, '\nEER:', EER, '\nBACC:', val_bacc)