In [1]:
import torch
import torch.nn as nn
import os
from torchvision import transforms, models
from PIL import Image
from sklearn.metrics import classification_report
import numpy as np

# Student Model
class StudentNetwork(nn.Module):
    """MobileNetV2-based student classifier with a frozen feature extractor.

    A pretrained MobileNetV2 is loaded and every one of its parameters is
    frozen.  The final linear layer (``backbone.classifier[1]``) is then
    replaced by a fresh ``nn.Linear`` with ``num_classes`` outputs.  It is
    created after the freeze, so it is the only trainable part.

    Args:
        num_classes: Number of logits produced per input.  Must match the
            head size of any checkpoint later loaded into this model.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.backbone = models.mobilenet_v2(pretrained=True)
        for param in self.backbone.parameters():
            param.requires_grad = False
        # Swap in a task-specific head.  Previously this was hard-coded to 4
        # and ``num_classes`` was ignored.
        self.backbone.classifier[1] = nn.Linear(
            self.backbone.last_channel, num_classes
        )

    def forward(self, x):
        """Return raw (un-normalized) class scores for the input batch ``x``."""
        return self.backbone(x)

# Function to load the checkpoint and prepare the model for inference
def load_checkpoint(model, checkpoint_path='best_model.pth', map_location='cpu'):
    """Load saved weights into ``model`` and switch it to evaluation mode.

    Args:
        model: Module whose parameters are overwritten in place from
            ``checkpoint['model_state_dict']``.
        checkpoint_path: Path to a file written by ``torch.save`` that holds
            a dict with at least the ``'model_state_dict'`` key (and,
            optionally, ``'epoch'``).
        map_location: Forwarded to ``torch.load``.  Defaults to ``'cpu'`` so
            a checkpoint saved on a GPU can be read on a CPU-only machine.
            ``load_state_dict`` copies the tensors into the model's existing
            parameters, so the model ends up on its own device regardless.

    Returns:
        The full checkpoint dict, so callers can inspect other entries.

    Raises:
        KeyError: if the checkpoint has no ``'model_state_dict'`` entry.
    """
    # NOTE(review): depending on the installed torch version's
    # ``weights_only`` default, torch.load may unpickle arbitrary objects;
    # only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # inference mode for layers such as dropout / batch norm
    epoch = checkpoint.get('epoch', 'unknown')
    print(f"Checkpoint loaded from epoch {epoch}")
    return checkpoint

# Image preprocessing function
def preprocess_image(image_path, transform):
    """Read an image file and turn it into a batch containing one sample.

    Args:
        image_path: Path to an image file readable by PIL.
        transform: Callable applied to the RGB ``PIL.Image``.  Presumably it
            returns a tensor (e.g. a pipeline ending in ``ToTensor`` /
            ``Normalize``), since ``.unsqueeze`` is called on its result.

    Returns:
        ``transform(image)`` with a leading batch dimension of size 1.
    """
    # Close the file handle promptly instead of leaving it open until the
    # lazily-loaded Image is garbage-collected.  ``convert`` reads the
    # pixels and returns an independent image, so it is safe to use after
    # the ``with`` block exits.
    with Image.open(image_path) as img:
        rgb = img.convert('RGB')
    return transform(rgb).unsqueeze(0)  # add batch dimension

# Function to predict image class
def predict_image(model, image_path, transform, device='cuda'):
    """Classify a single image file.

    Args:
        model: Module that maps a one-image batch to per-class scores.  It
            should already live on ``device`` and be in eval mode.
        image_path: Path of the image to classify.
        transform: Preprocessing pipeline handed to ``preprocess_image``.
        device: Device the input batch is moved to before the forward pass.

    Returns:
        int: index of the highest score along dimension 1 of the output.
    """
    batch = preprocess_image(image_path, transform)
    batch = batch.to(device)

    # Gradients are never needed here, so skip building the autograd graph.
    with torch.no_grad():
        logits = model(batch)

    best = torch.max(logits, dim=1).indices
    return best.item()

# Function to evaluate the model on the entire folder
def evaluate_folder(model, folder_path, transform, device='cuda'):
    """Run the model on every image of a class-per-directory test set.

    ``folder_path`` is expected to hold one sub-directory per class, named by
    the integer class index (``0/``, ``1/``, ...).  The following are
    skipped rather than crashing the run:

    * top-level entries that are not directories;
    * directories whose names are not integers (e.g. Jupyter's
      ``.ipynb_checkpoints``);
    * hidden files (e.g. ``.DS_Store``) and nested directories inside a
      class folder.

    Entries are visited in sorted order so results are reproducible.

    Args:
        model: Classifier passed through to ``predict_image``.
        folder_path: Root directory of the test set.
        transform: Preprocessing pipeline passed to ``predict_image``.
        device: Device passed to ``predict_image``.

    Returns:
        tuple[list[int], list[int]]: ``(all_labels, all_preds)``, two
        parallel lists of true class indices and predicted class indices.
    """
    all_preds = []
    all_labels = []

    for class_folder in sorted(os.listdir(folder_path)):
        class_folder_path = os.path.join(folder_path, class_folder)
        if not os.path.isdir(class_folder_path):
            continue

        # The folder name is the true label; ignore folders that aren't one.
        try:
            true_label = int(class_folder)
        except ValueError:
            continue

        for img_name in sorted(os.listdir(class_folder_path)):
            if img_name.startswith('.'):
                continue  # hidden OS / editor metadata, not an image
            img_path = os.path.join(class_folder_path, img_name)
            if not os.path.isfile(img_path):
                continue

            predicted_class = predict_image(model, img_path, transform, device)
            all_labels.append(true_label)
            all_preds.append(predicted_class)

    return all_labels, all_preds

# Define image transformation (same as during training)
# Pipeline: fixed 224x224 resize -> tensor -> per-channel normalization.
# The mean/std values appear to be the standard ImageNet statistics.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Main function to run inference and generate classification report
if __name__ == '__main__':
    # Single source of truth for the class count; it sizes the model head
    # and the report's label set.
    num_classes = 4
    class_ids = list(range(num_classes))

    # Fall back to CPU so the script also runs on machines without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load the model and checkpoint
    student_model = StudentNetwork(num_classes=num_classes).to(device)
    load_checkpoint(student_model, checkpoint_path='best_model.pth')

    # Path to the folder containing the test images
    test_folder_path = './syn_vision_dataset/test/'  # Replace with the actual path

    # Get the true labels and predictions for all images in the folder
    true_labels, predicted_labels = evaluate_folder(
        student_model, test_folder_path, transform, device=device
    )

    # Pass ``labels`` explicitly.  Otherwise sklearn infers the label set
    # from the data and raises a ValueError whenever some class appears in
    # neither list, because target_names would be longer than that set.
    report = classification_report(
        true_labels,
        predicted_labels,
        labels=class_ids,
        target_names=[str(i) for i in class_ids],
    )
    print(report)


  checkpoint = torch.load(checkpoint_path)


Checkpoint loaded from epoch 51
              precision    recall  f1-score   support

           0       0.75      1.00      0.86         3
           1       1.00      0.67      0.80         6
           2       1.00      1.00      1.00         6
           3       0.67      1.00      0.80         2

    accuracy                           0.88        17
   macro avg       0.85      0.92      0.86        17
weighted avg       0.92      0.88      0.88        17

