# In[ ]:
import os
import json
from PIL import Image
import torch
import torchvision.transforms as transforms
import torchvision.models as models
import faiss
import numpy as np

# Configuration
gallery_folder = 'Data_example/test/gallery'
query_folder   = 'Data_example/test/query'
top_k = 3  # number of gallery matches returned per query image

# Load an ImageNet-pretrained ResNet-50. torchvision 0.13 replaced the
# deprecated `pretrained=True` flag with the `weights=` API; IMAGENET1K_V1 is
# the checkpoint `pretrained=True` refers to, so fall back to the old flag only
# on torchvision releases that predate the weights enums.
try:
    _resnet_weights = models.ResNet50_Weights.IMAGENET1K_V1
except AttributeError:  # torchvision < 0.13
    model = models.resnet50(pretrained=True)
else:
    model = models.resnet50(weights=_resnet_weights)
# Drop the last child module (the final fully-connected classifier) so the
# network outputs pooled convolutional features instead of class logits.
model = torch.nn.Sequential(*list(model.children())[:-1])
model.eval()  # inference mode: BatchNorm layers use their running statistics

# Image preprocessing. The pretrained weights were trained on inputs
# normalized with the ImageNet per-channel mean/std, so the same normalization
# must be applied here; without it the network sees out-of-distribution
# inputs and the retrieval embeddings degrade.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Feature extraction
def extract_features(image_path):
    """Return the feature vector of the image at *image_path*.

    The image is converted to RGB, preprocessed with the module-level
    ``transform``, and passed through the module-level ``model`` with
    gradient tracking disabled.

    Args:
        image_path: Path to an image file readable by Pillow.

    Returns:
        The model output flattened to a 1-D NumPy array.
    """
    # Use a context manager so the file handle is closed deterministically
    # rather than whenever the garbage collector gets to it; this matters
    # when looping over large image folders.
    with Image.open(image_path) as img:
        rgb = img.convert('RGB')
    batch = transform(rgb).unsqueeze(0)  # add a batch dimension of size 1
    with torch.no_grad():
        feature = model(batch)
    # flatten() always yields a 1-D vector, whereas squeeze() would also
    # collapse any feature dimension that happens to have size 1.
    return feature.flatten().numpy()

# Extract features for the gallery.
# os.listdir returns entries in arbitrary order and includes everything in
# the folder (hidden files such as .DS_Store, sub-directories, ...), which
# would make Image.open fail. Keep only regular, non-hidden files whose
# extension Pillow can read, and sort them so gallery indices (and therefore
# the results) are reproducible across runs and machines.
_gallery_image_exts = Image.registered_extensions()
gallery_paths = [
    os.path.join(gallery_folder, fname)
    for fname in sorted(os.listdir(gallery_folder))
    if not fname.startswith('.')
    and os.path.splitext(fname)[1].lower() in _gallery_image_exts
    and os.path.isfile(os.path.join(gallery_folder, fname))
]
if not gallery_paths:
    # Fail with a clear message instead of an IndexError on .shape[1] below.
    raise RuntimeError(f'No images found in gallery folder {gallery_folder!r}')
gallery_features = np.stack([extract_features(p) for p in gallery_paths]).astype('float32')

# Build the FAISS index: exact (brute-force) L2 search over all gallery vectors.
index = faiss.IndexFlatL2(gallery_features.shape[1])
index.add(gallery_features)

# Extract features for each query image and retrieve the most similar gallery images.
results = []
# List query images in deterministic (sorted) order, skipping hidden files,
# sub-directories, and files whose extension Pillow cannot read.
_query_image_exts = Image.registered_extensions()
query_paths = [
    os.path.join(query_folder, fname)
    for fname in sorted(os.listdir(query_folder))
    if not fname.startswith('.')
    and os.path.splitext(fname)[1].lower() in _query_image_exts
    and os.path.isfile(os.path.join(query_folder, fname))
]

# FAISS pads the result with index -1 when fewer than k vectors are indexed,
# and gallery_paths[-1] would then silently report the *last* gallery image
# as a match. Never request more neighbours than the index actually holds.
search_k = min(top_k, index.ntotal)

for q_path in query_paths:
    q_feat = extract_features(q_path).astype('float32').reshape(1, -1)
    _, indices = index.search(q_feat, search_k)  # distances are not needed
    # Defensive: drop any -1 padding entries FAISS may still return.
    similar_images = [gallery_paths[i] for i in indices[0] if i >= 0]
    results.append({
        "filename": q_path,
        "gallery_images": similar_images
    })


# Persist the retrieval results as pretty-printed JSON (2-space indent).
# Each entry pairs a query image path with its nearest gallery image paths.
output_path = 'submission2.json'
with open(output_path, 'w') as out_file:
    json.dump(results, out_file, indent=2)

print('submission ok')