# Model Testing
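This notebook evaluates the fine-tuned ViT model on the held-out `valid` image folders: it reports the classification accuracy and saves every misclassified image for inspection.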

In [None]:
from glob import glob
import os
import torch
from PIL import Image
import matplotlib.pyplot as plt
from transformers import ViTImageProcessor
from transformers import ViTForImageClassification

In [None]:
# Restore the fine-tuned checkpoint: the image processor and the classifier
# are both loaded from the same saved directory.
fine_tuned_model = '/content/drive/MyDrive/models/vit-gender'
processor, model = (
    ViTImageProcessor.from_pretrained(fine_tuned_model),
    ViTForImageClassification.from_pretrained(fine_tuned_model),
)

In [None]:
def classify_image(model, processor, img_path):
    """
    Classify a single image with a fine-tuned image-classification model.

    Parameters
    ----------
    model : transformers.ViTForImageClassification
        A fine-tuned classification model; it is called under
        ``torch.no_grad()``, so no gradients are recorded.
    processor : transformers.ViTImageProcessor
        The image processor matching ``model``.
    img_path : str
        Path to the image to classify.

    Returns
    -------
    str
        ``'Divers'`` when the predicted class index is 0, otherwise
        ``'norm-beauty'``.
    """
    # Open the file in a context manager so the handle is released, and
    # convert to RGB so grayscale / RGBA / palette images provide the three
    # channels the processor normalizes.
    with Image.open(img_path) as img:
        rgb_img = img.convert('RGB')
    inputs = processor(rgb_img, return_tensors='pt')

    # Prediction needs no gradients; skipping autograd bookkeeping saves
    # memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits

    # softmax is monotonic, so the argmax of the raw logits is already the
    # predicted class index.
    pred_idx = logits.argmax(-1).item()

    # Return the class label as a string
    return 'Divers' if pred_idx == 0 else 'norm-beauty'

In [None]:
# Validation directories to evaluate; each directory's name identifies the
# true class of the images inside it.
folders = [
    f'/content/drive/MyDrive/folder/valid/{class_dir}'
    for class_dir in ('IdealisiertNormschön', 'Divers')
]

In [None]:
# Calculate accuracy.
# classify_image returns the model's label names ('Divers' / 'norm-beauty'),
# which differ from the folder names, so translate each folder name into the
# label it should be predicted as before comparing. An unknown folder name
# raises KeyError instead of silently counting every image as wrong.
FOLDER_TO_LABEL = {
    'IdealisiertNormschön': 'norm-beauty',
    'Divers': 'Divers',
}

correct = 0
total = 0

for folder in folders:
    # normpath strips a trailing slash so basename never returns ''.
    expected_label = FOLDER_TO_LABEL[os.path.basename(os.path.normpath(folder))]
    for img_path in glob(os.path.join(folder, '*.jpg')):
        pred = classify_image(model, processor, img_path)
        total += 1
        if pred == expected_label:
            correct += 1

# Guard against an empty dataset (e.g. wrong paths) instead of dividing by 0.
if total:
    print(f'Accuracy: {correct / total:.2f}')
else:
    print('Accuracy: no images found')


In [None]:
# Save and show every misclassified image with its predicted and true label.
# classify_image returns 'Divers' / 'norm-beauty', which differ from the folder
# names, so map each folder name to the label it should be predicted as.
FOLDER_TO_LABEL = {
    'IdealisiertNormschön': 'norm-beauty',
    'Divers': 'Divers',
}
wrong_predictions_dir = '/content/drive/MyDrive/folder/wrong-predictions'
# savefig does not create missing directories.
os.makedirs(wrong_predictions_dir, exist_ok=True)

# Loop through each directory
for folder in folders:
    # normpath strips a trailing slash so basename never returns ''.
    true_class = FOLDER_TO_LABEL[os.path.basename(os.path.normpath(folder))]
    # Loop through each file in the directory
    for img_path in glob(os.path.join(folder, '*.jpg')):  # Add more formats if you have, e.g. png, jpeg
        pred_class = classify_image(model, processor, img_path)
        # Only misclassified images are plotted.
        if pred_class == true_class:
            continue

        image_name = os.path.basename(img_path)
        # A fresh figure per image: drawing on the implicit current axes would
        # stack every misclassified image on top of the previous ones.
        fig, ax = plt.subplots()
        # imshow copies the pixel data, so the file can be closed right after.
        with Image.open(img_path) as img:
            ax.imshow(img.convert('RGB'))
        ax.set_title(f"The image {image_name} is classified as {pred_class} but is actually {true_class}")
        ax.axis('off')
        # image_name already carries its .jpg extension; appending another one
        # produced names like 'x.jpg.jpg'.
        fig.savefig(os.path.join(wrong_predictions_dir, image_name))
        plt.show()
        # Release the figure so memory does not grow with the number of errors.
        plt.close(fig)
