In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor

from spottheplace.ml import GradCam
from spottheplace.ml.utils import AddMask

# Classification Task

## Example usage for ResNet model

In [None]:
# ResNet50 model for image classification
MODEL_PATH = hf_hub_download(
    repo_id="titouanlegourrierec/SpotThePlace",
    filename="Classification_ResNet50_4countries.pth"
    )

IMAGE_PATH = "" # Path to the image
if not IMAGE_PATH:
    # Fail with an actionable message instead of the opaque FileNotFoundError
    # that Image.open("") would raise further down.
    raise ValueError("IMAGE_PATH is empty: set it to the path of the image to classify.")

# Define the class labels; the size of the classification head is derived from them
# so the two cannot drift apart.
class_labels = {0: 'France', 1: 'Japan', 2: 'Mexico', 3: 'South Africa'}

# Define the model and load the weights
model = models.resnet50()
model.fc = nn.Linear(model.fc.in_features, len(class_labels)) # one logit per country
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects, which can
# execute code from the downloaded file. A plain state_dict should load with
# weights_only=True; confirm against the published checkpoint and switch if it does.
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu'), weights_only=False))
model.eval() # inference mode

# Load the image and apply the transformations
image = Image.open(IMAGE_PATH).convert('RGB')
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    AddMask(), # project-specific transform from spottheplace.ml.utils
    transforms.ToTensor(),
    # Standard ImageNet per-channel mean/std
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
input_tensor = transform(image).unsqueeze(0) # add the leading batch dimension

# Predict the class of the image
with torch.no_grad():
    output = model(input_tensor)
class_idx = output.argmax(dim=1).item()

# Print the predicted class
output_class = class_labels[class_idx]
print("Predicted class:", output_class)

# Print the probability of the predicted class.
# .tolist() yields plain Python floats, so the per-class list prints as
# [0.123, ...] instead of [np.float32(0.123), ...] under NumPy >= 2.
output_probs = F.softmax(output, dim=1).squeeze(0).tolist()
print(f"Probability of the predicted class: {output_probs[class_idx]:.3f}")
print(f"Probability of each class: {[round(prob, 3) for prob in output_probs]}")

In [None]:
# Generate the Grad-CAM explanation for the ResNet50 classifier.
# MODEL_PATH and IMAGE_PATH are the notebook globals set by the ResNet cell above;
# run this cell before the ViT / regression cells, which rebind both names.
# NOTE(review): assumes GradCam takes a checkpoint path and explain() renders the
# explanation for the given image path; confirm against spottheplace.ml.GradCam.
grad_cam = GradCam(MODEL_PATH)
grad_cam.explain(IMAGE_PATH)

## Example usage for ViT model

In [None]:
# ViT model for image classification
MODEL_PATH = hf_hub_download(
    repo_id="titouanlegourrierec/SpotThePlace",
    filename="Classification_ViT_4countries.pth"
    )

IMAGE_PATH = "" # Path to the image
if not IMAGE_PATH:
    # Fail with an actionable message instead of the opaque FileNotFoundError
    # that Image.open("") would raise further down.
    raise ValueError("IMAGE_PATH is empty: set it to the path of the image to classify.")

# Define the class labels; the number of output labels is derived from them
# so the two cannot drift apart.
class_labels = {0: 'France', 1: 'Japan', 2: 'Mexico', 3: 'South Africa'}

# Define the model and load the weights.
# from_pretrained builds the architecture (with a fresh classification head);
# the fine-tuned weights are then loaded on top from the downloaded checkpoint.
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=len(class_labels))
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects, which can
# execute code from the downloaded file. A plain state_dict should load with
# weights_only=True; confirm against the published checkpoint and switch if it does.
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu'), weights_only=False))
model.eval() # inference mode

# Load the image and apply the transformations.
# NOTE(review): unlike the ResNet pipeline, no AddMask step is applied here;
# confirm this matches how the ViT checkpoint was trained.
image = Image.open(IMAGE_PATH).convert('RGB')
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
inputs = processor(images=image, return_tensors="pt")
input_tensor = inputs['pixel_values']

# Predict the class of the image
with torch.no_grad():
    output = model(pixel_values=input_tensor)
class_idx = output.logits.argmax(dim=1).item()

# Print the predicted class
output_class = class_labels[class_idx]
print("Predicted class:", output_class)

# Print the probability of the predicted class.
# .tolist() yields plain Python floats, so the per-class list prints as
# [0.123, ...] instead of [np.float32(0.123), ...] under NumPy >= 2.
output_probs = F.softmax(output.logits, dim=1).squeeze(0).tolist()
print(f"Probability of the predicted class: {output_probs[class_idx]:.3f}")
print(f"Probability of each class: {[round(prob, 3) for prob in output_probs]}")

# Regression Task

In [None]:
# ResNet50 model for image regression
MODEL_PATH = hf_hub_download(
    repo_id="titouanlegourrierec/SpotThePlace",
    filename="Regression_ResNet50_4countries.pth"
    )

IMAGE_PATH = "" # Path to the image
if not IMAGE_PATH:
    # Fail with an actionable message instead of the opaque FileNotFoundError
    # that Image.open("") would raise further down.
    raise ValueError("IMAGE_PATH is empty: set it to the path of the image to locate.")

# Define the model and load the weights
model = models.resnet50()
model.fc = nn.Linear(model.fc.in_features, 2) # 2 outputs: longitude and latitude
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects, which can
# execute code from the downloaded file. A plain state_dict should load with
# weights_only=True; confirm against the published checkpoint and switch if it does.
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu'), weights_only=False))
model.eval() # inference mode

# Load the image and apply the transformations
image = Image.open(IMAGE_PATH).convert('RGB')
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    AddMask(), # project-specific transform from spottheplace.ml.utils
    transforms.ToTensor(),
    # Standard ImageNet per-channel mean/std
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
input_tensor = transform(image).unsqueeze(0) # add the leading batch dimension

# Predict the longitude and latitude of the image
with torch.no_grad():
    output = model(input_tensor)

# NOTE(review): the (longitude, latitude) output order and the units of the values
# are not visible in this notebook; confirm against the training code.
longitude, latitude = output.squeeze(0).tolist()
print("Predicted longitude:", longitude)
print("Predicted latitude:", latitude)