In [5]:
import torch
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import os
from tqdm import tqdm
import json

# -------------------------------------------------------------------
# Paths
# -------------------------------------------------------------------
# Locate the project root: the directory that contains 'data/training'.
# The current working directory is tried first, then its parent.
cwd = os.getcwd()
_training_subpath = os.path.join('data', 'training')
_parent_dir = os.path.join(cwd, '..')

if os.path.isdir(os.path.join(cwd, _training_subpath)):
    project_root = cwd
elif os.path.isdir(os.path.join(_parent_dir, _training_subpath)):
    # Normalise away the trailing '..' so derived paths are clean.
    project_root = os.path.abspath(_parent_dir)
else:
    raise RuntimeError("Impossibile trovare la cartella 'data/training'.")

# Dataset folders, both expected under <project_root>/data.
_data_dir = os.path.join(project_root, 'data')
root_dir_train = os.path.join(_data_dir, 'training')
root_dir_test = os.path.join(_data_dir, 'test')

# -------------------------------------------------------------------
# 1) Data loader
# -------------------------------------------------------------------
def get_data(batch_size, test_batch_size=16, num_workers=2, mean=None, std=None, num_train_samples=None):
    """Build the train / validation / test ``DataLoader`` objects.

    Images are read with ``ImageFolder`` from the module-level
    ``root_dir_train`` and ``root_dir_test`` folders, resized to 224x224,
    converted to tensors and normalized with one scalar mean/std that is
    replicated over the three channels.

    Args:
        batch_size: Batch size of the (shuffled) training loader.
        test_batch_size: Batch size of the validation and test loaders.
        num_workers: Number of worker processes used by every loader.
        mean: Scalar normalization mean. If ``mean`` or ``std`` is ``None``,
            both are recomputed from the training images.
        std: Scalar normalization standard deviation (see ``mean``).
        num_train_samples: Number of samples kept for training; the rest of
            the training folder becomes the validation split. Defaults to
            70% of the training folder.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.

    Raises:
        ValueError: If ``num_train_samples`` is negative or larger than the
            number of images in the training folder.
    """
    target_size = (224, 224)

    # 1a) compute mean/std if needed
    if mean is None or std is None:
        tmp_tf = transforms.Compose([
            transforms.Resize(target_size),
            transforms.ToTensor()
        ])
        tmp_ds = ImageFolder(root_dir_train, transform=tmp_tf)

        # Accumulate sum and sum of squares one image at a time instead of
        # stacking the whole dataset into a single tensor, so peak memory
        # stays at one image regardless of the dataset size.
        pixel_sum = 0.0
        pixel_sq_sum = 0.0
        pixel_count = 0
        for img, _ in tmp_ds:
            img = img.double()  # float64 keeps the running sums accurate
            pixel_sum += img.sum().item()
            pixel_sq_sum += (img * img).sum().item()
            pixel_count += img.numel()

        mean = pixel_sum / pixel_count
        # Unbiased (n - 1) variance, matching the default of Tensor.std().
        variance = (pixel_sq_sum - pixel_count * mean * mean) / max(pixel_count - 1, 1)
        std = max(variance, 0.0) ** 0.5  # clamp tiny negative rounding residue

    # 1b) transforms
    tf = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[mean] * 3, std=[std] * 3)  # 3 RGB channels
    ])

    # 1c) datasets
    full_train_ds = ImageFolder(root_dir_train, transform=tf)
    test_ds       = ImageFolder(root_dir_test,  transform=tf)

    # 1d) split
    N = len(full_train_ds)
    train_N = int(N * 0.7) if num_train_samples is None else num_train_samples
    if not 0 <= train_N <= N:
        # A negative val_N would still sum to N and slip past random_split's
        # own length check, so reject it explicitly.
        raise ValueError(
            f"num_train_samples must be between 0 and {N}, got {train_N}"
        )
    val_N = N - train_N
    train_ds, val_ds = torch.utils.data.random_split(full_train_ds, [train_N, val_N])

    # 1e) loaders
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  num_workers=num_workers)
    val_loader   = DataLoader(val_ds,   batch_size=test_batch_size, shuffle=False, num_workers=num_workers)
    test_loader  = DataLoader(test_ds,  batch_size=test_batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader, test_loader

# -------------------------------------------------------------------
# 2) Build GoogLeNet feature extractor
# -------------------------------------------------------------------
def build_googlenet_extractor(device='cuda'):
    """Return a pretrained GoogLeNet with its classifier heads removed.

    The two auxiliary classifiers and the final fully connected layer are
    replaced by ``torch.nn.Identity`` so the network outputs the pooled
    feature vector (1024-d according to the original author's note) instead
    of class logits.

    Args:
        device: Device the model is moved to.

    Returns:
        The modified model, placed on ``device`` and switched to eval mode.
    """
    net = torchvision.models.googlenet(
        weights=torchvision.models.GoogLeNet_Weights.DEFAULT,
        aux_logits=True,  # marked as mandatory by the original author
    )
    # Swap every classification head (aux1, aux2, then fc) for a pass-through.
    for head_name in ('aux1', 'aux2', 'fc'):
        setattr(net, head_name, torch.nn.Identity())
    return net.to(device).eval()


# -------------------------------------------------------------------
# 3) Extract embeddings
# -------------------------------------------------------------------
@torch.no_grad()
def extract_embeddings(loader, model, device='cuda'):
    """Run ``model`` over every batch of ``loader`` and collect its outputs.

    Gradient tracking is disabled for the whole call by the decorator.

    Args:
        loader: Iterable yielding ``(images, target)`` pairs; the second
            element of each pair is ignored.
        model: Callable applied to each image batch.
        device: Device each image batch is moved to before the forward pass.

    Returns:
        A CPU tensor with the per-batch outputs concatenated along dim 0,
        in loader order (e.g. [N, 1024] for the GoogLeNet extractor, per
        the original comment).
    """
    progress = tqdm(loader, desc="Extracting embeddings")
    # Each batch's features go straight to the CPU so the device only ever
    # holds one batch of outputs at a time.
    per_batch = [model(batch.to(device)).cpu() for batch, _ in progress]
    return torch.cat(per_batch, dim=0)

# -------------------------------------------------------------------
# 4) Top-k retrieval
# -------------------------------------------------------------------
def retrieve_topk(query_embs, gallery_embs, query_paths, gallery_paths, k=5):
    """Return, for every query, its ``k`` most cosine-similar gallery images.

    A gallery entry whose path equals the query's path is treated as the
    query itself and is never returned. Self-matches are excluded explicitly
    by path rather than by dropping the rank-0 neighbour, so the best match
    is kept when query and gallery are different sets, and a query is never
    returned as its own neighbour when duplicate images tie on similarity.

    Args:
        query_embs: 2-D float tensor with one query embedding per row.
        gallery_embs: 2-D float tensor with one gallery embedding per row.
        query_paths: Sequence aligned with ``query_embs`` rows; element ``[0]``
            of each entry is the file path (e.g. ``ImageFolder.imgs`` entries).
        gallery_paths: Same structure, aligned with ``gallery_embs`` rows.
        k: Maximum number of neighbours to return per query.

    Returns:
        A list with one dict per query: ``'query'`` holds the query file's
        basename and ``'retrieved'`` the basenames of its neighbours, ordered
        by decreasing similarity. Fewer than ``k`` names are returned when
        the gallery does not contain enough other items.
    """
    query_embs = F.normalize(query_embs, dim=1)
    gallery_embs = F.normalize(gallery_embs, dim=1)
    sim = query_embs @ gallery_embs.t()  # cosine similarity

    # Mask every (query, gallery) pair that refers to the same file so it
    # can never be selected as a neighbour.
    neg_inf = float('-inf')
    positions_by_path = {}
    for gi, entry in enumerate(gallery_paths):
        positions_by_path.setdefault(entry[0], []).append(gi)
    rows, cols = [], []
    for qi, entry in enumerate(query_paths):
        for gi in positions_by_path.get(entry[0], ()):
            rows.append(qi)
            cols.append(gi)
    if rows:
        row_idx = torch.tensor(rows, dtype=torch.long, device=sim.device)
        col_idx = torch.tensor(cols, dtype=torch.long, device=sim.device)
        sim[row_idx, col_idx] = neg_inf

    # Never ask topk for more columns than the gallery has.
    k_eff = max(0, min(k, sim.shape[1]))
    if k_eff > 0:
        top_vals, top_idxs = sim.topk(k_eff, dim=1, largest=True)
        top_vals, top_idxs = top_vals.tolist(), top_idxs.tolist()
    else:
        top_vals = top_idxs = [[] for _ in range(sim.shape[0])]

    results = []
    for qi, (vals, idxs) in enumerate(zip(top_vals, top_idxs)):
        # Masked self-matches only surface when the gallery is too small
        # to fill k slots; drop them.
        neigh = [gi for gi, v in zip(idxs, vals) if v != neg_inf]
        qpath = query_paths[qi][0]
        results.append({
            'query':    os.path.basename(qpath),
            'retrieved': [os.path.basename(gallery_paths[gi][0]) for gi in neigh]
        })
    return results

# -------------------------------------------------------------------
# 5) Main script
# -------------------------------------------------------------------
# Script entry point: embed every test image with GoogLeNet and print, for
# each one, its 5 nearest test images by cosine similarity.
if __name__ == '__main__':
    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # a) load the data
    # Only test_loader is used in the visible part of this cell; mean/std are
    # not passed, so get_data computes them from the training images.
    train_loader, val_loader, test_loader = get_data(
        batch_size=16,
        test_batch_size=32,
        num_workers=2
    )

    # b) build the feature extractor
    feat_ext = build_googlenet_extractor(device)

    # c) extract the embeddings
    test_embs = extract_embeddings(test_loader, feat_ext, device)

    # d) get the image paths
    # test_loader is created with shuffle=False, so the rows of test_embs
    # follow the same order as dataset.imgs.
    test_paths = test_loader.dataset.imgs

    # e) top-k retrieval (the test set is both query set and gallery)
    topk_results = retrieve_topk(test_embs, test_embs, test_paths, test_paths, k=5)

    # f) print the results
    print(json.dumps(topk_results, indent=2))



Extracting embeddings: 100%|██████████| 1/1 [00:07<00:00,  7.89s/it]

[
  {
    "query": "n01855672_1037.jpg",
    "retrieved": [
      "n01855672_4393.jpg",
      "n01855672_10973.jpg",
      "n01855672_4197.jpg",
      "painting_085_000084.jpg",
      "painting_085_000118.jpg"
    ]
  },
  {
    "query": "n01855672_4197.jpg",
    "retrieved": [
      "n01855672_10973.jpg",
      "n01855672_4393.jpg",
      "painting_085_000118.jpg",
      "n01855672_1037.jpg",
      "painting_085_000084.jpg"
    ]
  },
  {
    "query": "n01855672_4393.jpg",
    "retrieved": [
      "n01855672_1037.jpg",
      "n01855672_4197.jpg",
      "n01855672_10973.jpg",
      "painting_085_000084.jpg",
      "painting_085_000118.jpg"
    ]
  },
  {
    "query": "painting_085_000045.jpg",
    "retrieved": [
      "painting_085_000118.jpg",
      "4597118805213184.jpg",
      "painting_085_000084.jpg",
      "n01855672_4197.jpg",
      "n01855672_4393.jpg"
    ]
  },
  {
    "query": "painting_085_000084.jpg",
    "retrieved": [
      "painting_085_000118.jpg",
      "n01855672_109


