In [None]:
import os
import torch
import timm
from sklearn.metrics import f1_score
from torchvision import transforms
from PIL import Image

def evaluate_model(model, data_dir, class_names, average=None):
    """Evaluate ``model`` on an image-folder dataset and return its F1 score.

    ``data_dir`` must contain one sub-directory per entry in ``class_names``;
    every image file inside sub-directory ``class_names[i]`` is scored with
    ground-truth label ``i``. Hidden files, nested directories and files whose
    extension Pillow does not recognise are skipped.

    Args:
        model: classifier, passed through to ``predict_image``.
        data_dir: root directory containing one folder per class.
        class_names: ordered class names; list position is the label index.
        average: ``average`` argument for ``sklearn.metrics.f1_score``.
            ``None`` (default) means ``"binary"`` for two classes (positive
            class = ``class_names[1]``, as before) and ``"macro"`` otherwise.

    Returns:
        Tuple ``(f1, total_images)``: the F1 score and the number of images
        that were actually scored.

    Raises:
        FileNotFoundError: if a class sub-directory does not exist.
        ValueError: if no scorable images were found under ``data_dir``.
    """
    model.eval()  # switch to evaluation mode before inference
    # File extensions (lower-case, with leading dot) that Pillow has a plugin
    # for; other files (notes, Thumbs.db, ...) would make Image.open raise.
    image_exts = set(Image.registered_extensions())
    all_preds = []
    all_labels = []

    for class_idx, class_name in enumerate(class_names):
        class_dir = os.path.join(data_dir, class_name)
        # sorted() makes the visit order deterministic across platforms.
        for image_name in sorted(os.listdir(class_dir)):
            image_path = os.path.join(class_dir, image_name)
            if (
                image_name.startswith(".")  # hidden files, e.g. .DS_Store
                or not os.path.isfile(image_path)  # nested directories
                or os.path.splitext(image_name)[1].lower() not in image_exts
            ):
                continue
            prediction = predict_image(image_path, model, class_names)
            # predict_image returns a class name; map it back to its index.
            all_preds.append(class_names.index(prediction))
            all_labels.append(class_idx)

    total_images = len(all_labels)
    if total_images == 0:
        raise ValueError(
            f"No images found under {data_dir!r} for classes {class_names}"
        )

    if average is None:
        average = "binary" if len(class_names) == 2 else "macro"
    f1 = f1_score(all_labels, all_preds, average=average)

    return f1, total_images

def predict_image(image_path, model, class_names, device=None):
    """Classify a single image file with ``model`` and return its class name.

    The image is converted to 3-channel RGB, resized to exactly 224x224,
    turned into a float tensor in [0, 1] by ``ToTensor`` and normalised with
    mean = std = 0.5 per channel (mapping values to [-1, 1]).

    NOTE(review): this preprocessing must match the transform used when the
    model was fine-tuned — confirm mean/std against the training code.

    Args:
        image_path: path of the image file to classify.
        model: ``torch.nn.Module`` whose output is reduced with argmax over
            dim 1 — presumably logits of shape (1, len(class_names)).
        class_names: sequence mapping output index -> class name.
        device: device to run inference on. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        The entry of ``class_names`` with the highest model output.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # fixed size; aspect ratio is not kept
        transforms.ToTensor(),  # PIL image -> CHW float tensor in [0, 1]
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # -> [-1, 1]
    ])

    if device is None:
        # Run wherever the model's weights live instead of relying on a
        # module-level global.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # The context manager closes the file handle Image.open keeps open.
    # convert("RGB") guarantees 3 channels so grayscale / RGBA / palette
    # images match the 3-value Normalize above.
    with Image.open(image_path) as img:
        batch = transform(img.convert("RGB")).unsqueeze(0).to(device)  # add batch dim

    model.eval()
    with torch.no_grad():  # inference only; no autograd bookkeeping
        outputs = model(batch)
    pred_idx = outputs.argmax(dim=1).item()

    return class_names[pred_idx]

# Run on the first CUDA GPU when available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Class names; list position is the label index, and each name is also the
# sub-directory of data_dir that holds that class's images.
class_names = ['class_0', 'class_1']

# Build the Swin-L architecture with a classification head sized to
# class_names, then load the fine-tuned weights. pretrained=False because
# all weights come from model_path.
model_path = 'models/swin_large_fine-tuned.pth'
model = timm.create_model('swin_large_patch4_window7_224', pretrained=False, num_classes=len(class_names))
# map_location lets a checkpoint saved on a GPU load on a CPU-only host.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# checkpoints (consider weights_only=True on torch >= 1.13).
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model = model.to(device)

# Directory to evaluate; must contain one sub-folder per class name.
data_dir = "TEST_DIR"

# Evaluate the model on the specified directory
f1, total_images = evaluate_model(model, data_dir, class_names)

# Print results
print(f"F1 score for {data_dir}: {f1:.4f}, Total images: {total_images}")
