In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image

from spottheplace.ml import GradCam
from spottheplace.ml.utils import AddMask

# Path to the model checkpoint (downloaded from the Hugging Face Hub).
MODEL_PATH = hf_hub_download(
    filename="Classification_ResNet_4countries.pth",
    repo_id="titouanlegourrierec/SpotThePlace",
)

# Path to the image to classify and explain; left empty here, so fill it in
# before running the cells that read it.
IMAGE_PATH = ""

In [None]:
# Define the model and load the weights
# A ResNet-50 whose final fully connected layer is swapped for a 4-way classifier.
model = models.resnet50()
model.fc = nn.Linear(model.fc.in_features, 4) # 4 countries
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python objects from
# the downloaded checkpoint. Only a state dict is consumed here, so confirm whether the
# file loads with weights_only=True and switch to it if so.
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu'), weights_only=False))
model.eval()  # evaluation behaviour for dropout / batch-norm layers
class_labels = {0: 'France', 1: 'Japan', 2: 'Mexico', 3: 'South Africa'} # Define the class labels

# Load the image and apply the transformations
# The context manager releases the file handle as soon as the RGB copy exists.
with Image.open(IMAGE_PATH) as image_file:
    image = image_file.convert('RGB')
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    AddMask(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
input_tensor = transform(image).unsqueeze(0)  # add a leading batch dimension

# Predict the class of the image
# Gradients are not needed for a plain prediction, so skip building the autograd graph.
with torch.no_grad():
    output = model(input_tensor)
class_idx = output.argmax(dim=1).item()

# Print the predicted class
output_class = class_labels[class_idx]
print("Predicted class:", output_class)

# Print the probability of the predicted class
output_probs = F.softmax(output, dim=1).numpy().squeeze()
print(f"Probability of the predicted class: {output_probs[class_idx]:.3f}")
# Convert to built-in floats before rounding so the list prints as plain numbers
# instead of NumPy scalar reprs.
print(f"Probability of each class: {[round(prob, 3) for prob in output_probs.tolist()]}")

In [None]:
# Generate the Grad-CAM
# GradCam is constructed from the checkpoint path (not from an in-memory model object),
# and explain() is given the image path rather than a preprocessed tensor.
# NOTE(review): presumably explain() renders or returns the heatmap for the image —
# confirm against spottheplace.ml. It is the cell's final expression, so any returned
# value is displayed by the notebook.
grad_cam = GradCam(MODEL_PATH)
grad_cam.explain(IMAGE_PATH)