In [2]:
import torch
import torch.nn as nn
import os
from torchvision import transforms, models
from PIL import Image
from sklearn.metrics import classification_report
import numpy as np
import time
from sklearn.metrics import confusion_matrix
import csv
start_time = time.time() 
class StudentNetwork(nn.Module):
    def __init__(self, num_classes=10):
        super(StudentNetwork, self).__init__()
        self.backbone = models.mobilenet_v2(pretrained=True)        
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.classifier[1] = nn.Linear(self.backbone.last_channel, 2)        
    def forward(self, x):
        out = self.backbone(x)        
        return out
def load_checkpoint(model, checkpoint_path='best_model.pth'):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval() 
    print(f"Checkpoint loaded from epoch {checkpoint['epoch']+1}")
def preprocess_image(image_path, transform):
    img = Image.open(image_path).convert('RGB')
    img = transform(img).unsqueeze(0)  
    return img
def predict_image(model, image_path, transform, device='cuda'):
    img = preprocess_image(image_path, transform).to(device)  
    with torch.no_grad():
        output = model(img)
    _, predicted_class = torch.max(output, 1)
    return predicted_class.item()
def evaluate_folder(model, folder_path, transform, device='cuda'):
    all_preds = []
    all_labels = []
    for class_folder in os.listdir(folder_path):
        class_folder_path = os.path.join(folder_path, class_folder)
        if os.path.isdir(class_folder_path):  
            for img_name in os.listdir(class_folder_path):
                img_path = os.path.join(class_folder_path, img_name)
                true_label = int(class_folder) 
                predicted_class = predict_image(model, img_path, transform, device)
                all_labels.append(true_label)
                all_preds.append(predicted_class)
    return all_labels, all_preds
transform = transforms.Compose([
    transforms.Resize((224, 224)), 
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def calculate_sensitivity_specificity(conf_matrix):
    sensitivity = {}
    specificity = {}
    f1_score = {}
    for i in range(conf_matrix.shape[0]):
        TP = conf_matrix[i, i]
        FN = conf_matrix[:, i].sum() - TP
        FP = conf_matrix[i, :].sum() - TP
        TN = conf_matrix.sum() - (FP + FN + TP)
        sensitivity[i] = TP / (TP + FN) if (TP + FN) != 0 else 0
        specificity[i] = TN / (TN + FP) if (TN + FP) != 0 else 0
        precision = TP / (TP + FP) if (TP + FP) != 0 else 0
        f1_score[i] = 2 * (precision * sensitivity[i]) / (precision + sensitivity[i]) if (precision + sensitivity[i]) != 0 else 0
    sensitivity = (sensitivity[0]+sensitivity[1])/2
    specificity = (specificity[0]+specificity[1])/2    
    f1_score = (f1_score[0]+f1_score[1])/2
    return sensitivity, specificity, f1_score
student_model = StudentNetwork(num_classes=2).cuda() 
test_folder_path = './syn_vision_dataset/test/' 
sens=[]
spec=[]
f1=[]
acc=[]
with open('30_results.csv', mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['N', 'Sensitivity', 'Specificity', 'Accuracy', 'Macro F1'])
    for i in range(71, 101):
        c_path = './checkpoints/' + str(i) + '_best_model.pth'
        load_checkpoint(student_model, checkpoint_path=c_path) 
        true_labels, predicted_labels = evaluate_folder(student_model, test_folder_path, transform, device='cuda')
        report = classification_report(true_labels, predicted_labels, target_names=[str(i) for i in range(2)], digits=4)
        print(report)
        conf_matrix = confusion_matrix(true_labels, predicted_labels)
        sensitivity, specificity, f1_score = calculate_sensitivity_specificity(conf_matrix)
        accuracy = np.trace(conf_matrix) / np.sum(conf_matrix)
        print(f"N {i}")
        print("Sensitivity: ", sensitivity)
        print("Specificity: ", specificity)
        print("Accuracy: ", accuracy)
        print("Macro F1: ", f1_score)
        writer.writerow([i, sensitivity, specificity, accuracy, f1_score])

  checkpoint = torch.load(checkpoint_path)


Checkpoint loaded from epoch 71
              precision    recall  f1-score   support

           0     1.0000    0.9762    0.9880        84
           1     0.9636    1.0000    0.9815        53

    accuracy                         0.9854       137
   macro avg     0.9818    0.9881    0.9847       137
weighted avg     0.9859    0.9854    0.9854       137

N 71
Sensitivity:  0.9818181818181818
Specificity:  0.9818181818181818
Accuracy:  0.9854014598540146
Macro F1:  0.9847166443551986
Checkpoint loaded from epoch 72


  checkpoint = torch.load(checkpoint_path)


              precision    recall  f1-score   support

           0     1.0000    0.9762    0.9880        84
           1     0.9636    1.0000    0.9815        53

    accuracy                         0.9854       137
   macro avg     0.9818    0.9881    0.9847       137
weighted avg     0.9859    0.9854    0.9854       137

N 72
Sensitivity:  0.9818181818181818
Specificity:  0.9818181818181818
Accuracy:  0.9854014598540146
Macro F1:  0.9847166443551986
Checkpoint loaded from epoch 73


  checkpoint = torch.load(checkpoint_path)


              precision    recall  f1-score   support

           0     0.9762    0.9762    0.9762        84
           1     0.9623    0.9623    0.9623        53

    accuracy                         0.9708       137
   macro avg     0.9692    0.9692    0.9692       137
weighted avg     0.9708    0.9708    0.9708       137

N 73
Sensitivity:  0.9692273135669363
Specificity:  0.9692273135669363
Accuracy:  0.9708029197080292
Macro F1:  0.9692273135669363
Checkpoint loaded from epoch 74


  checkpoint = torch.load(checkpoint_path)


              precision    recall  f1-score   support

           0     0.9762    0.9762    0.9762        84
           1     0.9623    0.9623    0.9623        53

    accuracy                         0.9708       137
   macro avg     0.9692    0.9692    0.9692       137
weighted avg     0.9708    0.9708    0.9708       137

N 74
Sensitivity:  0.9692273135669363
Specificity:  0.9692273135669363
Accuracy:  0.9708029197080292
Macro F1:  0.9692273135669363
Checkpoint loaded from epoch 75


  checkpoint = torch.load(checkpoint_path)


              precision    recall  f1-score   support

           0     0.9880    0.9762    0.9820        84
           1     0.9630    0.9811    0.9720        53

    accuracy                         0.9781       137
   macro avg     0.9755    0.9787    0.9770       137
weighted avg     0.9783    0.9781    0.9781       137

N 75
Sensitivity:  0.9754573850959393
Specificity:  0.9754573850959393
Accuracy:  0.9781021897810219
Macro F1:  0.9769992724830712
Checkpoint loaded from epoch 76


  checkpoint = torch.load(checkpoint_path)


              precision    recall  f1-score   support

           0     1.0000    0.9762    0.9880        84
           1     0.9636    1.0000    0.9815        53

    accuracy                         0.9854       137
   macro avg     0.9818    0.9881    0.9847       137
weighted avg     0.9859    0.9854    0.9854       137

N 76
Sensitivity:  0.9818181818181818
Specificity:  0.9818181818181818
Accuracy:  0.9854014598540146
Macro F1:  0.9847166443551986
Checkpoint loaded from epoch 77


  checkpoint = torch.load(checkpoint_path)


              precision    recall  f1-score   support

           0     0.9762    0.9762    0.9762        84
           1     0.9623    0.9623    0.9623        53

    accuracy                         0.9708       137
   macro avg     0.9692    0.9692    0.9692       137
weighted avg     0.9708    0.9708    0.9708       137

N 77
Sensitivity:  0.9692273135669363
Specificity:  0.9692273135669363
Accuracy:  0.9708029197080292
Macro F1:  0.9692273135669363
Checkpoint loaded from epoch 78


  checkpoint = torch.load(checkpoint_path)


              precision    recall  f1-score   support

           0     0.9762    0.9762    0.9762        84
           1     0.9623    0.9623    0.9623        53

    accuracy                         0.9708       137
   macro avg     0.9692    0.9692    0.9692       137
weighted avg     0.9708    0.9708    0.9708       137

N 78
Sensitivity:  0.9692273135669363
Specificity:  0.9692273135669363
Accuracy:  0.9708029197080292
Macro F1:  0.9692273135669363
Checkpoint loaded from epoch 79


  checkpoint = torch.load(checkpoint_path)


              precision    recall  f1-score   support

           0     0.9762    0.9762    0.9762        84
           1     0.9623    0.9623    0.9623        53

    accuracy                         0.9708       137
   macro avg     0.9692    0.9692    0.9692       137
weighted avg     0.9708    0.9708    0.9708       137

N 79
Sensitivity:  0.9692273135669363
Specificity:  0.9692273135669363
Accuracy:  0.9708029197080292
Macro F1:  0.9692273135669363
Checkpoint loaded from epoch 80


  checkpoint = torch.load(checkpoint_path)


              precision    recall  f1-score   support

           0     0.9762    0.9762    0.9762        84
           1     0.9623    0.9623    0.9623        53

    accuracy                         0.9708       137
   macro avg     0.9692    0.9692    0.9692       137
weighted avg     0.9708    0.9708    0.9708       137

N 80
Sensitivity:  0.9692273135669363
Specificity:  0.9692273135669363
Accuracy:  0.9708029197080292
Macro F1:  0.9692273135669363
Checkpoint loaded from epoch 81


  checkpoint = torch.load(checkpoint_path)


              precision    recall  f1-score   support

           0     0.9762    0.9762    0.9762        84
           1     0.9623    0.9623    0.9623        53

    accuracy                         0.9708       137
   macro avg     0.9692    0.9692    0.9692       137
weighted avg     0.9708    0.9708    0.9708       137

N 81
Sensitivity:  0.9692273135669363
Specificity:  0.9692273135669363
Accuracy:  0.9708029197080292
Macro F1:  0.9692273135669363
Checkpoint loaded from epoch 82


  checkpoint = torch.load(checkpoint_path)


              precision    recall  f1-score   support

           0     0.9762    0.9762    0.9762        84
           1     0.9623    0.9623    0.9623        53

    accuracy                         0.9708       137
   macro avg     0.9692    0.9692    0.9692       137
weighted avg     0.9708    0.9708    0.9708       137

N 82
Sensitivity:  0.9692273135669363
Specificity:  0.9692273135669363
Accuracy:  0.9708029197080292
Macro F1:  0.9692273135669363
Checkpoint loaded from epoch 83


  checkpoint = torch.load(checkpoint_path)


              precision    recall  f1-score   support

           0     0.9762    0.9762    0.9762        84
           1     0.9623    0.9623    0.9623        53

    accuracy                         0.9708       137
   macro avg     0.9692    0.9692    0.9692       137
weighted avg     0.9708    0.9708    0.9708       137

N 83
Sensitivity:  0.9692273135669363
Specificity:  0.9692273135669363
Accuracy:  0.9708029197080292
Macro F1:  0.9692273135669363
Checkpoint loaded from epoch 84


  checkpoint = torch.load(checkpoint_path)


              precision    recall  f1-score   support

           0     0.9880    0.9762    0.9820        84
           1     0.9630    0.9811    0.9720        53

    accuracy                         0.9781       137
   macro avg     0.9755    0.9787    0.9770       137
weighted avg     0.9783    0.9781    0.9781       137

N 84
Sensitivity:  0.9754573850959393
Specificity:  0.9754573850959393
Accuracy:  0.9781021897810219
Macro F1:  0.9769992724830712
Checkpoint loaded from epoch 85


  checkpoint = torch.load(checkpoint_path)


              precision    recall  f1-score   support

           0     0.9880    0.9762    0.9820        84
           1     0.9630    0.9811    0.9720        53

    accuracy                         0.9781       137
   macro avg     0.9755    0.9787    0.9770       137
weighted avg     0.9783    0.9781    0.9781       137

N 85
Sensitivity:  0.9754573850959393
Specificity:  0.9754573850959393
Accuracy:  0.9781021897810219
Macro F1:  0.9769992724830712
Checkpoint loaded from epoch 86


  checkpoint = torch.load(checkpoint_path)


              precision    recall  f1-score   support

           0     0.9880    0.9762    0.9820        84
           1     0.9630    0.9811    0.9720        53

    accuracy                         0.9781       137
   macro avg     0.9755    0.9787    0.9770       137
weighted avg     0.9783    0.9781    0.9781       137

N 86
Sensitivity:  0.9754573850959393
Specificity:  0.9754573850959393
Accuracy:  0.9781021897810219
Macro F1:  0.9769992724830712
Checkpoint loaded from epoch 87


  checkpoint = torch.load(checkpoint_path)


              precision    recall  f1-score   support

           0     0.9762    0.9762    0.9762        84
           1     0.9623    0.9623    0.9623        53

    accuracy                         0.9708       137
   macro avg     0.9692    0.9692    0.9692       137
weighted avg     0.9708    0.9708    0.9708       137

N 87
Sensitivity:  0.9692273135669363
Specificity:  0.9692273135669363
Accuracy:  0.9708029197080292
Macro F1:  0.9692273135669363
Checkpoint loaded from epoch 88


  checkpoint = torch.load(checkpoint_path)


              precision    recall  f1-score   support

           0     0.9762    0.9762    0.9762        84
           1     0.9623    0.9623    0.9623        53

    accuracy                         0.9708       137
   macro avg     0.9692    0.9692    0.9692       137
weighted avg     0.9708    0.9708    0.9708       137

N 88
Sensitivity:  0.9692273135669363
Specificity:  0.9692273135669363
Accuracy:  0.9708029197080292
Macro F1:  0.9692273135669363
Checkpoint loaded from epoch 89


  checkpoint = torch.load(checkpoint_path)


              precision    recall  f1-score   support

           0     0.9880    0.9762    0.9820        84
           1     0.9630    0.9811    0.9720        53

    accuracy                         0.9781       137
   macro avg     0.9755    0.9787    0.9770       137
weighted avg     0.9783    0.9781    0.9781       137

N 89
Sensitivity:  0.9754573850959393
Specificity:  0.9754573850959393
Accuracy:  0.9781021897810219
Macro F1:  0.9769992724830712
Checkpoint loaded from epoch 90


  checkpoint = torch.load(checkpoint_path)


              precision    recall  f1-score   support

           0     0.9880    0.9762    0.9820        84
           1     0.9630    0.9811    0.9720        53

    accuracy                         0.9781       137
   macro avg     0.9755    0.9787    0.9770       137
weighted avg     0.9783    0.9781    0.9781       137

N 90
Sensitivity:  0.9754573850959393
Specificity:  0.9754573850959393
Accuracy:  0.9781021897810219
Macro F1:  0.9769992724830712
Checkpoint loaded from epoch 91


  checkpoint = torch.load(checkpoint_path)


              precision    recall  f1-score   support

           0     0.9880    0.9762    0.9820        84
           1     0.9630    0.9811    0.9720        53

    accuracy                         0.9781       137
   macro avg     0.9755    0.9787    0.9770       137
weighted avg     0.9783    0.9781    0.9781       137

N 91
Sensitivity:  0.9754573850959393
Specificity:  0.9754573850959393
Accuracy:  0.9781021897810219
Macro F1:  0.9769992724830712
Checkpoint loaded from epoch 92


  checkpoint = torch.load(checkpoint_path)


              precision    recall  f1-score   support

           0     1.0000    0.9762    0.9880        84
           1     0.9636    1.0000    0.9815        53

    accuracy                         0.9854       137
   macro avg     0.9818    0.9881    0.9847       137
weighted avg     0.9859    0.9854    0.9854       137

N 92
Sensitivity:  0.9818181818181818
Specificity:  0.9818181818181818
Accuracy:  0.9854014598540146
Macro F1:  0.9847166443551986
Checkpoint loaded from epoch 93


  checkpoint = torch.load(checkpoint_path)


              precision    recall  f1-score   support

           0     0.9762    0.9762    0.9762        84
           1     0.9623    0.9623    0.9623        53

    accuracy                         0.9708       137
   macro avg     0.9692    0.9692    0.9692       137
weighted avg     0.9708    0.9708    0.9708       137

N 93
Sensitivity:  0.9692273135669363
Specificity:  0.9692273135669363
Accuracy:  0.9708029197080292
Macro F1:  0.9692273135669363
Checkpoint loaded from epoch 94


  checkpoint = torch.load(checkpoint_path)


              precision    recall  f1-score   support

           0     1.0000    0.9762    0.9880        84
           1     0.9636    1.0000    0.9815        53

    accuracy                         0.9854       137
   macro avg     0.9818    0.9881    0.9847       137
weighted avg     0.9859    0.9854    0.9854       137

N 94
Sensitivity:  0.9818181818181818
Specificity:  0.9818181818181818
Accuracy:  0.9854014598540146
Macro F1:  0.9847166443551986
Checkpoint loaded from epoch 95


  checkpoint = torch.load(checkpoint_path)


              precision    recall  f1-score   support

           0     1.0000    0.9762    0.9880        84
           1     0.9636    1.0000    0.9815        53

    accuracy                         0.9854       137
   macro avg     0.9818    0.9881    0.9847       137
weighted avg     0.9859    0.9854    0.9854       137

N 95
Sensitivity:  0.9818181818181818
Specificity:  0.9818181818181818
Accuracy:  0.9854014598540146
Macro F1:  0.9847166443551986
Checkpoint loaded from epoch 96


  checkpoint = torch.load(checkpoint_path)


              precision    recall  f1-score   support

           0     1.0000    0.9762    0.9880        84
           1     0.9636    1.0000    0.9815        53

    accuracy                         0.9854       137
   macro avg     0.9818    0.9881    0.9847       137
weighted avg     0.9859    0.9854    0.9854       137

N 96
Sensitivity:  0.9818181818181818
Specificity:  0.9818181818181818
Accuracy:  0.9854014598540146
Macro F1:  0.9847166443551986
Checkpoint loaded from epoch 97


  checkpoint = torch.load(checkpoint_path)


              precision    recall  f1-score   support

           0     1.0000    0.9762    0.9880        84
           1     0.9636    1.0000    0.9815        53

    accuracy                         0.9854       137
   macro avg     0.9818    0.9881    0.9847       137
weighted avg     0.9859    0.9854    0.9854       137

N 97
Sensitivity:  0.9818181818181818
Specificity:  0.9818181818181818
Accuracy:  0.9854014598540146
Macro F1:  0.9847166443551986
Checkpoint loaded from epoch 98


  checkpoint = torch.load(checkpoint_path)


              precision    recall  f1-score   support

           0     1.0000    0.9762    0.9880        84
           1     0.9636    1.0000    0.9815        53

    accuracy                         0.9854       137
   macro avg     0.9818    0.9881    0.9847       137
weighted avg     0.9859    0.9854    0.9854       137

N 98
Sensitivity:  0.9818181818181818
Specificity:  0.9818181818181818
Accuracy:  0.9854014598540146
Macro F1:  0.9847166443551986
Checkpoint loaded from epoch 99


  checkpoint = torch.load(checkpoint_path)


              precision    recall  f1-score   support

           0     0.9880    0.9762    0.9820        84
           1     0.9630    0.9811    0.9720        53

    accuracy                         0.9781       137
   macro avg     0.9755    0.9787    0.9770       137
weighted avg     0.9783    0.9781    0.9781       137

N 99
Sensitivity:  0.9754573850959393
Specificity:  0.9754573850959393
Accuracy:  0.9781021897810219
Macro F1:  0.9769992724830712
Checkpoint loaded from epoch 100


  checkpoint = torch.load(checkpoint_path)


              precision    recall  f1-score   support

           0     1.0000    0.9762    0.9880        84
           1     0.9636    1.0000    0.9815        53

    accuracy                         0.9854       137
   macro avg     0.9818    0.9881    0.9847       137
weighted avg     0.9859    0.9854    0.9854       137

N 100
Sensitivity:  0.9818181818181818
Specificity:  0.9818181818181818
Accuracy:  0.9854014598540146
Macro F1:  0.9847166443551986


In [3]:
f1


[]