In [None]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import torch.nn as nn
from tqdm import tqdm
from sklearn.metrics import classification_report, confusion_matrix

# --- Evaluation setup: preprocessing, test data, device and model ---------

# Preprocessing applied to every test image: resize to a fixed 224x224 and
# convert to a tensor. No mean/std normalization is applied here.
# NOTE(review): this should match the preprocessing used when the checkpoint
# was fine-tuned -- confirm against the training code.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to 224x224
    transforms.ToTensor(),  # Convert images to PyTorch tensors
])

# Load the test dataset; test_data.classes is later used as the list of
# class names for the classification report.
test_dir = 'TEST_DIR'
test_data = datasets.ImageFolder(root=test_dir, transform=transform)

# Create the DataLoader for test data. shuffle=False keeps the evaluation
# order deterministic.
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

# Define the device (GPU if available, otherwise CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the ResNet-50 architecture without downloading ImageNet weights.
# Calling the constructor with no arguments gives randomly initialised weights
# on every torchvision release; the explicit `pretrained=False` keyword is
# deprecated since torchvision 0.13 and only triggers a warning there.
model_path = 'models/resnet50_MC_fine_tuned.pth'
model = models.resnet50()
num_ftrs = model.fc.in_features
# Replace the final layer so its output size matches the checkpoint.
# NOTE(review): 6 is hard-coded; presumably it equals both the number of
# classes the checkpoint was trained on and len(test_data.classes) -- confirm.
model.fc = nn.Linear(num_ftrs, 6)

# Load model weights. map_location lets a checkpoint saved on GPU be loaded on
# a CPU-only machine.
# SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source; on PyTorch >= 1.13
# consider passing weights_only=True.
model.load_state_dict(torch.load(model_path, map_location=device))
model = model.to(device)
model.eval()  # Set model to evaluation mode

# Accumulators for the predicted and ground-truth class indices of every
# test sample, filled batch by batch in loader order.
all_preds = []
all_labels = []

# Inference only: no gradients are needed, so disable autograd tracking.
with torch.no_grad():
    progress = tqdm(test_loader, desc="Testing")  # Progress bar over batches
    for inputs, labels in progress:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Forward pass; the predicted class is the index of the largest
        # output score along the class dimension.
        outputs = model(inputs)
        predicted = torch.max(outputs, dim=1).indices

        # Move results back to the host before storing them.
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Generate classification report and confusion matrix.
#
# The full list of class indices is passed explicitly. Without `labels`,
# scikit-learn infers the classes from the values that actually occur in
# all_labels / all_preds. If some class never occurs, classification_report
# raises ValueError (class count no longer matches target_names) and
# confusion_matrix shrinks, so row/column i would no longer correspond to
# class index i.
# NOTE(review): assumes the model's output indices line up with the indices
# of test_data.classes -- confirm against the training setup.
class_indices = list(range(len(test_data.classes)))
report = classification_report(
    all_labels,
    all_preds,
    labels=class_indices,
    target_names=test_data.classes,
    output_dict=True,
)
matrix = confusion_matrix(all_labels, all_preds, labels=class_indices)

# Extract weighted average F1-score
weighted_avg_f1 = report['weighted avg']['f1-score']

print("Classification Report:")
print(report)  # A nested dict, because output_dict=True above
print("Confusion Matrix:")
print(matrix)

# One-vs-rest view of the multi-class confusion matrix: class index 0 is
# treated as the positive class and every other class is pooled as negative.
row_totals = np.sum(matrix, axis=1)  # Samples per row of the matrix
col_totals = np.sum(matrix, axis=0)  # Samples per column of the matrix

TP = matrix[0, 0]
FN = row_totals[0] - TP  # Row 0 entries that fall outside column 0
FP = col_totals[0] - TP  # Column 0 entries that fall outside row 0
TN = np.sum(matrix) - TP - FN - FP  # Everything not touching row/column 0

# Binary confusion matrix laid out as [[TN, FP], [FN, TP]]
binary_conf_matrix = np.array([
    [TN, FP],
    [FN, TP],
])

# Precision, recall and F1 for class 0; each falls back to 0 when its
# denominator would be zero.
predicted_positive = TP + FP
actual_positive = TP + FN

if predicted_positive > 0:
    precision = TP / predicted_positive
else:
    precision = 0

if actual_positive > 0:
    recall = TP / actual_positive
else:
    recall = 0

if (precision + recall) > 0:
    f1_score = 2 * (precision * recall) / (precision + recall)
else:
    f1_score = 0

# Report the binary results
print(f"Binary F1 score: {f1_score:.4f}")
print("Binary Confusion Matrix:")
print(binary_conf_matrix)