In [None]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from transformers import ViTForImageClassification, ViTImageProcessor
from geopy.distance import geodesic

from spottheplace import data_to_dataframe
from spottheplace import france_region_to_dataframe
from spottheplace.ml.utils import AddMask

# Test-set locations. Fill these in before running the notebook; they are kept
# as named constants (like MODEL_PATH in the evaluation cells) so that the
# placeholders are easy to spot and edit in one place.
COUNTRIES_TEST_DIR = ""  # path to the test data folder
FRANCE_TEST_DIR = ""  # path to the test data folder for France

# The evaluation cells read the 'image_path' and 'label' columns of these
# frames; the regression cell additionally reads 'long' and 'lat' from
# df_FranceRegions.
df_4countries = data_to_dataframe(COUNTRIES_TEST_DIR)
df_FranceRegions = france_region_to_dataframe(FRANCE_TEST_DIR)

# Class index -> country name. The classification cells use this to turn the
# argmax of the model output into a label string.
countries_labels = dict(enumerate([
    'France',
    'Japan',
    'Mexico',
    'South Africa',
]))

# Class index -> French region name (13 classes), used the same way for the
# France regions classifier.
regions_labels = dict(enumerate([
    'Auvergne-Rhône-Alpes',
    'Bourgogne-Franche-Comté',
    'Bretagne',
    'Centre-Val de Loire',
    'Corse',
    'Grand Est',
    'Hauts-de-France',
    'Normandie',
    'Nouvelle-Aquitaine',
    'Occitanie',
    'Pays de la Loire',
    "Provence-Alpes-Côte-d'Azur",
    'Île-de-France',
]))

def compute_metrics(true_labels, predicted_labels):
    """Compute, print and return classification metrics.

    Args:
        true_labels: Ground-truth labels, one per sample.
        predicted_labels: Predicted labels, aligned with ``true_labels``.

    Returns:
        dict with the keys ``"accuracy"``, ``"precision"``, ``"recall"``,
        ``"f1"`` (the last three use ``average='weighted'``) and
        ``"confusion_matrix"``.
    """
    # Scorers that share the same weighted-average call signature.
    weighted_scorers = (
        ("precision", precision_score),
        ("recall", recall_score),
        ("f1", f1_score),
    )

    metrics = {"accuracy": accuracy_score(true_labels, predicted_labels)}
    for key, scorer in weighted_scorers:
        metrics[key] = scorer(true_labels, predicted_labels, average='weighted')
    metrics["confusion_matrix"] = confusion_matrix(true_labels, predicted_labels)

    # Report the scalar metrics first, then the matrix on its own lines.
    for heading, key in (("Accuracy:", "accuracy"),
                         ("Precision:", "precision"),
                         ("Recall:", "recall"),
                         ("F1:", "f1")):
        print(heading, metrics[key])
    print("Confusion matrix:\n", metrics["confusion_matrix"])

    return metrics

# Classification task

## ResNet Model

### 4 Countries classification (France, Japan, South Africa, Mexico)

In [None]:
MODEL_PATH = ""

# ResNet-50 backbone with its final layer replaced by a 4-way head (one output
# per country), restored from the checkpoint at MODEL_PATH and run on CPU.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects,
# so only load checkpoints from a trusted source. If the file is a plain
# state_dict, weights_only=True should be sufficient -- confirm and tighten.
model = models.resnet50()
model.fc = nn.Linear(model.fc.in_features, 4)
model.load_state_dict(
    torch.load(MODEL_PATH, map_location=torch.device('cpu'), weights_only=False)
)
model.eval()

# Preprocessing for ResNet: resize, apply the project's AddMask transform,
# convert to a tensor and normalise with the usual ImageNet channel statistics.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    AddMask(),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# Predict one image at a time; only the image path is needed from each row.
label_pred = []
for image_path in df_4countries['image_path']:
    rgb_image = Image.open(image_path).convert('RGB')
    batch = transform(rgb_image).unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        logits = model(batch)

    label_pred.append(countries_labels[logits.argmax(dim=1).item()])

# Attach the predictions to a copy of the test frame and score them.
df = df_4countries.assign(label_pred=label_pred)

metrics = compute_metrics(df['label'], df['label_pred'])

### France Regions classification (13 classes)

In [None]:
MODEL_PATH = ""

# ResNet-50 backbone with a 13-way head (one output per French region),
# restored from the checkpoint at MODEL_PATH and run on CPU.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects,
# so only load checkpoints from a trusted source. If the file is a plain
# state_dict, weights_only=True should be sufficient -- confirm and tighten.
model = models.resnet50()
model.fc = nn.Linear(model.fc.in_features, 13)
model.load_state_dict(
    torch.load(MODEL_PATH, map_location=torch.device('cpu'), weights_only=False)
)
model.eval()

# Preprocessing for ResNet: resize, apply the project's AddMask transform,
# convert to a tensor and normalise with the usual ImageNet channel statistics.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    AddMask(),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

label_pred = []
top3_correct = 0

samples = zip(df_FranceRegions['image_path'], df_FranceRegions['label'])
for image_path, true_label in samples:
    rgb_image = Image.open(image_path).convert('RGB')
    batch = transform(rgb_image).unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        logits = model(batch)

    # Top-1 prediction is what gets scored by compute_metrics.
    label_pred.append(regions_labels[logits.argmax(dim=1).item()])

    # Top-3 accuracy: count the sample as a hit when the true region is among
    # the three highest-scoring classes.
    top3_idx = logits.topk(3, dim=1).indices[0].tolist()
    top3_names = [regions_labels[idx] for idx in top3_idx]
    if true_label in top3_names:
        top3_correct += 1

# Attach the predictions to a copy of the test frame and score them.
df = df_FranceRegions.assign(label_pred=label_pred)

metrics = compute_metrics(df['label'], df['label_pred'])
top3_accuracy = top3_correct / len(df)
print(f"Top-3 Accuracy: {top3_accuracy:.4f}")

## Vision Transformer Model

### 4 Countries classification (France, Japan, South Africa, Mexico)

In [None]:
MODEL_PATH = ""

# Build a ViT classifier with a 4-label head and overwrite its parameters with
# the checkpoint stored at MODEL_PATH (loaded on CPU).
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects,
# so only load checkpoints from a trusted source. If the file is a plain
# state_dict, weights_only=True should be sufficient -- confirm and tighten.
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=4)
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu'), weights_only=False))
model.eval()

# Hugging Face image processor matching the ViT checkpoint; it produces the
# 'pixel_values' tensor the model consumes.
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')

label_pred = []

# Predict one image at a time. Per-image predictions are not printed here
# (unlike the previous revision of this cell): that produced one output line
# per test image and was inconsistent with the ResNet evaluation cells. The
# predictions remain available in df['label_pred'] below.
for image_path in df_4countries['image_path']:
    image = Image.open(image_path).convert('RGB')
    inputs = processor(images=image, return_tensors="pt")

    with torch.no_grad():
        output = model(pixel_values=inputs['pixel_values'])

    class_idx = output.logits.argmax(dim=1).item()
    label_pred.append(countries_labels[class_idx])

# Attach the predictions to a copy of the test frame and score them.
df = df_4countries.copy()
df['label_pred'] = label_pred

metrics = compute_metrics(df['label'], df['label_pred'])

# Regression task

### Coordinate regression (longitude, latitude)

In [None]:
MODEL_PATH = ""

# ResNet-50 backbone with a 2-output head, restored from the checkpoint at
# MODEL_PATH and run on CPU. The loop below reads output 0 as longitude and
# output 1 as latitude.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects,
# so only load checkpoints from a trusted source. If the file is a plain
# state_dict, weights_only=True should be sufficient -- confirm and tighten.
model = models.resnet50()
model.fc = nn.Linear(model.fc.in_features, 2)
model.load_state_dict(
    torch.load(MODEL_PATH, map_location=torch.device('cpu'), weights_only=False)
)
model.eval()

# Preprocessing for ResNet: resize, apply the project's AddMask transform,
# convert to a tensor and normalise with the usual ImageNet channel statistics.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    AddMask(),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

long_preds = []
lat_preds = []
dist_pred = []

samples = zip(df_FranceRegions['image_path'],
              df_FranceRegions['long'],
              df_FranceRegions['lat'])
for image_path, long, lat in samples:
    rgb_image = Image.open(image_path).convert('RGB')
    batch = transform(rgb_image).unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        output = model(batch)

    raw_long, raw_lat = output[0].tolist()

    # Clamp to valid coordinate ranges so geodesic() receives legal values.
    clamped_long = max(min(raw_long, 180), -180)
    clamped_lat = max(min(raw_lat, 90), -90)
    long_preds.append(clamped_long)
    lat_preds.append(clamped_lat)

    # geopy expects (latitude, longitude) ordering.
    true_point = (lat, long)
    pred_point = (clamped_lat, clamped_long)
    dist_pred.append(geodesic(true_point, pred_point).kilometers)

# Attach predictions and per-sample distance (km) to a copy of the test frame.
df = df_FranceRegions.assign(
    long_pred=long_preds,
    lat_pred=lat_preds,
    dist_pred=dist_pred,
)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Signed per-sample errors (true - predicted), in degrees. The sign shows the
# direction of the miss, so these columns are suited to histograms and to
# measuring bias, but their plain mean lets opposite errors cancel out.
df['long_error'] = df['long'] - df['long_pred']
df['lat_error'] = df['lat'] - df['lat_pred']

# One histogram per quantity, side by side, each with its own y-scale.
fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharey=False)

# (column, title, x-label, y-label, colour) for each panel, left to right.
# 'dist_pred' is the geodesic distance in km between the true and predicted
# coordinates, computed in the regression cell.
panels = [
    ('long_error', 'Distribution of Longitude Error', 'Longitude Error (°)', 'Frequency', 'blue'),
    ('lat_error', 'Distribution of Latitude Error', 'Latitude Error (°)', '', 'green'),
    ('dist_pred', 'Distribution of Distance Error', 'Distance Error (km)', '', 'red'),
]
for ax, (column, title, xlabel, ylabel, color) in zip(axes, panels):
    sns.histplot(df[column], kde=True, ax=ax, color=color)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

plt.tight_layout()
plt.show()

# Mean of the signed errors: this is the bias of the predictions (a value near
# zero only means over- and under-shoots balance out, not that they are small).
print("Mean Longitude Error:", round(df['long_error'].mean(), 3), "degrees")
print("Mean Latitude Error:", round(df['lat_error'].mean(), 3), "degrees")
print("Mean Distance Error:", int(df['dist_pred'].mean()), "km")

# Mean absolute errors: the typical magnitude of the miss on each axis.
print("Mean Absolute Longitude Error:", round(df['long_error'].abs().mean(), 3), "degrees")
print("Mean Absolute Latitude Error:", round(df['lat_error'].abs().mean(), 3), "degrees")