In [1]:
import torch
import torch.nn as nn
import os
from torchvision import transforms, models
from PIL import Image
from sklearn.metrics import classification_report
import numpy as np
class StudentNetwork(nn.Module):
    """MobileNetV2 backbone with a frozen feature extractor and a new linear head.

    The pretrained ``torchvision`` MobileNetV2 is loaded, every one of its
    existing parameters is frozen, and ``classifier[1]`` is then replaced by a
    freshly created ``nn.Linear``. Because the new layer is created after the
    freezing loop, its parameters keep ``requires_grad=True`` and remain
    trainable.

    Args:
        num_classes: Number of output logits produced by the new head.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.backbone = models.mobilenet_v2(pretrained=True)
        # Freeze all pretrained weights; only the head created below trains.
        for param in self.backbone.parameters():
            param.requires_grad = False
        # Size the head from num_classes so the model matches the caller's
        # label count (it was hard-coded to 2, silently ignoring the argument).
        self.backbone.classifier[1] = nn.Linear(self.backbone.last_channel, num_classes)

    def forward(self, x):
        """Return the head's raw (un-normalised) logits for the input batch ``x``."""
        return self.backbone(x)
def load_checkpoint(model, checkpoint_path='best_model.pth', map_location=None):
    """Load saved weights into ``model`` and switch it to evaluation mode.

    The checkpoint must be a dict with a ``'model_state_dict'`` entry. If it
    also has an ``'epoch'`` entry, that value is printed; otherwise
    ``unknown`` is printed.

    Args:
        model: Module whose parameters are overwritten in place.
        checkpoint_path: Path of the file previously written with ``torch.save``.
        map_location: Forwarded to ``torch.load``. For example, pass ``'cpu'``
            to load a GPU-saved checkpoint on a machine without CUDA. The
            default ``None`` matches ``torch.load``'s own default.

    Returns:
        The full checkpoint dict, so callers can read any extra entries.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    # NOTE: depending on the torch version / ``weights_only`` default,
    # torch.load may unpickle arbitrary objects; only load trusted files.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    # A missing epoch should not turn a successful load into a crash.
    epoch = checkpoint.get('epoch', 'unknown')
    print(f"Checkpoint loaded from epoch {epoch}")
    return checkpoint
def preprocess_image(image_path, transform):
    """Load one image file and turn it into a single-item batch tensor.

    Args:
        image_path: Path to the image on disk.
        transform: Callable applied to the RGB ``PIL.Image``. Assumed to return
            a tensor, e.g. the module-level ``transforms.Compose`` pipeline.

    Returns:
        ``transform(img)`` with a new leading batch dimension of size 1.
    """
    # The context manager closes the file handle deterministically.
    # convert() returns a new in-memory image that stays valid afterwards.
    with Image.open(image_path) as img:
        rgb = img.convert('RGB')
    return transform(rgb).unsqueeze(0)
def predict_image(model, image_path, transform, device='cuda'):
    """Classify a single image file and return the winning class index.

    Args:
        model: Network returning per-class scores along dim 1 for the
            one-image batch. Assumed to be in eval mode already.
        image_path: Path to the image on disk.
        transform: Preprocessing pipeline passed through to ``preprocess_image``.
        device: Device the input batch is moved to before the forward pass.

    Returns:
        Python ``int`` index of the highest-scoring class.
    """
    batch = preprocess_image(image_path, transform)
    batch = batch.to(device)
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = model(batch)
    best = logits.max(dim=1)
    return best.indices.item()
def evaluate_folder(model, folder_path, transform, device='cuda'):
    """Predict every image under ``folder_path`` and collect labels/predictions.

    Expected layout: one sub-directory per class, named by its integer class
    index (e.g. ``0/``, ``1/``), each containing image files. The following
    entries are skipped instead of crashing the run:

    - directories whose names are not integers (e.g. ``.ipynb_checkpoints``);
    - stray top-level files;
    - hidden files and nested directories inside a class folder.

    Args:
        model: Model passed to ``predict_image`` for each image.
        folder_path: Root directory of the labelled test set.
        transform: Preprocessing pipeline passed to ``predict_image``.
        device: Device passed to ``predict_image``.

    Returns:
        ``(all_labels, all_preds)``: two parallel lists of ints. They are
        ordered by class-folder name, then file name (lexicographic), so runs
        are reproducible.
    """
    all_preds = []
    all_labels = []
    # os.listdir order is arbitrary; sort for deterministic output.
    for class_folder in sorted(os.listdir(folder_path)):
        class_folder_path = os.path.join(folder_path, class_folder)
        if not os.path.isdir(class_folder_path):
            continue
        try:
            # The label is the folder name, so compute it once per folder.
            true_label = int(class_folder)
        except ValueError:
            # Not a class directory (notebook/OS metadata etc.); ignore it.
            continue
        for img_name in sorted(os.listdir(class_folder_path)):
            img_path = os.path.join(class_folder_path, img_name)
            # Hidden files (e.g. .DS_Store) and sub-directories are not images.
            if img_name.startswith('.') or not os.path.isfile(img_path):
                continue
            predicted_class = predict_image(model, img_path, transform, device)
            all_labels.append(true_label)
            all_preds.append(predicted_class)
    return all_labels, all_preds
# Evaluation preprocessing applied to every image:
#   1. resize to a fixed 224x224;
#   2. convert to a tensor;
#   3. normalise each channel with the standard ImageNet mean/std, which
#      presumably matches the pretrained backbone.
_eval_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_eval_steps)
if __name__ == '__main__':
    # Evaluate the student model on the test split and print a per-class
    # precision / recall / F1 report.
    num_classes = 2
    # Fall back to CPU so the script still starts on machines without a GPU.
    # NOTE(review): a checkpoint saved from GPU tensors additionally needs
    # map_location='cpu' in torch.load to load on a CPU-only machine.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    student_model = StudentNetwork(num_classes=num_classes).to(device)
    load_checkpoint(student_model, checkpoint_path='best_model.pth')
    test_folder_path = './syn_vision_dataset/test/'
    true_labels, predicted_labels = evaluate_folder(
        student_model, test_folder_path, transform, device=device
    )
    class_ids = list(range(num_classes))
    # Passing labels explicitly keeps target_names aligned with the classes
    # even when a class never appears in either the true or predicted labels.
    report = classification_report(
        true_labels,
        predicted_labels,
        labels=class_ids,
        target_names=[str(i) for i in class_ids],
    )
    print(report)

  checkpoint = torch.load(checkpoint_path)


Checkpoint loaded from epoch 4
              precision    recall  f1-score   support

           0       0.84      1.00      0.92        27
           1       1.00      0.71      0.83        17

    accuracy                           0.89        44
   macro avg       0.92      0.85      0.87        44
weighted avg       0.90      0.89      0.88        44

