In [1]:
import torch
from Retrival_Dino_Salad.model import SaladFaissGPSDB
from preprocess import CampusGPSDataset
import torchvision.transforms.v2 as v2


# Extract Train Test

In [7]:
from torch.utils.data import DataLoader, random_split 
transform = v2.Compose([
    v2.Resize((4004, 3010), interpolation=v2.InterpolationMode.BILINEAR),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

full_dataset = CampusGPSDataset(csv_path="data\photo_locations.csv", image_dir="data\indexed_photos", transform=transform)

train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42))

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Train Phase

In [9]:
# 1) Load model
model = torch.hub.load("serizba/salad", "dinov2_salad")
db = SaladFaissGPSDB(model, use_cosine=True, normalize=True)
# 2) Build database from train loader (must yield images + gps tensor [B,2])
db.build_from_loader(train_loader)

# 3) Save index + gps/meta
db.save("salad_faiss_db")


Using cache found in C:\Users\PC/.cache\torch\hub\serizba_salad_main
Using cache found in C:\Users\PC/.cache\torch\hub\facebookresearch_dinov2_main
0it [00:33, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 4.31 GiB. GPU 0 has a total capacty of 4.00 GiB of which 0 bytes is free. Of the allocated memory 9.92 GiB is allocated by PyTorch, and 116.77 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [5]:
# 4) Load later
db2 = SaladFaissGPSDB.load("salad_faiss_db", model=model)

# 5) Predict GPS for a new image tensor [3,H,W]
query_img, query_gps = val_dataset[0]  # just an example
pred_lat, pred_lon = db2.predict_gps(query_img, k=5, weighted=True)

# # 6) Or inspect top matches (for debugging)
# matches = db2.query_image(query_img, k=5)
# for m in matches:
#     print(m.idx, m.score, m.gps)

print("real_gpez:", query_gps.tolist())
print("pred_gpez:", pred_lat, pred_lon)


real_gpez: [31.262210845947266, 34.80329513549805]
pred_gpez: 31.262016615941867 34.80351222407907
