In [4]:
import torch
from Retrival_Dino_Salad.model import SaladFaissGPSDB
from preprocess import CampusGPSDataset
import torchvision.transforms.v2 as v2


# Build the train/validation split

In [None]:
from torch.utils.data import DataLoader, random_split

# --- Split / loader configuration -------------------------------------------
SEED = 42              # fixed so re-runs reproduce the same split the saved DB was built from
TRAIN_FRACTION = 0.8   # share of images indexed into the retrieval DB; the rest is held out
BATCH_SIZE = 32
# Forward slashes work on Windows and POSIX alike and avoid backslash escape
# sequences ("\p", "\i") that Python 3.12+ flags with a SyntaxWarning.
CSV_PATH = "data/photo_locations.csv"
IMAGE_DIR = "data/indexed_photos"

# Preprocessing: 224x224 center crops normalized with the ImageNet mean/std.
transform = v2.Compose([
    v2.Resize(224, interpolation=v2.InterpolationMode.BICUBIC),  # resize the smaller edge to 224
    v2.CenterCrop(224),                                           # crop a 224x224 square from the center
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),                        # scale=True rescales integer pixels to [0, 1]
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

full_dataset = CampusGPSDataset(csv_path=CSV_PATH, image_dir=IMAGE_DIR, transform=transform)

# Seeded split: validation images never leak into the indexed train set across re-runs.
train_size = int(TRAIN_FRACTION * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(
    full_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(SEED)
)

# No shuffling: keeps the order in which images are indexed deterministic.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Index phase: build the retrieval database from the train split

In [6]:
# 1) Load the pretrained DINOv2 + SALAD model from torch.hub (served from the local cache after the first download).
model = torch.hub.load("serizba/salad", "dinov2_salad")
# This notebook only extracts descriptors, so force inference mode: any dropout /
# normalization layers left in train mode would make the indexed descriptors
# non-deterministic. Idempotent, so harmless if SaladFaissGPSDB also calls it.
model.eval()

# 2) Wrap the model in a FAISS-backed GPS database and index the train split.
#    use_cosine / normalize are passed straight to SaladFaissGPSDB
#    (presumably cosine similarity over normalized descriptors — see its implementation).
#    The loader must yield images plus a gps tensor of shape [B, 2].
db = SaladFaissGPSDB(model, use_cosine=True, normalize=True)
db.build_from_loader(train_loader)

# 3) Persist the index plus gps/meta so later cells can reload it without re-embedding.
db.save("salad_faiss_db")


Using cache found in C:\Users\user1/.cache\torch\hub\serizba_salad_main
Using cache found in C:\Users\user1/.cache\torch\hub\facebookresearch_dinov2_main
2it [00:06,  3.04s/it]


In [14]:
import math

TOP_K = 5  # number of nearest database images combined into the GPS estimate


def haversine_m(lat1: float, lon1: float, lat2: float, lon2: float,
                radius_m: float = 6_371_000.0) -> float:
    """Return the great-circle distance in meters between two points given in decimal degrees.

    radius_m defaults to the mean Earth radius (6371 km).
    """
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    d_phi = phi2 - phi1
    d_lambda = math.radians(lon2 - lon1)
    a = math.sin(d_phi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(d_lambda / 2) ** 2
    # min() guards against float round-off pushing sqrt(a) slightly above 1.0 near antipodes.
    return 2 * radius_m * math.asin(min(1.0, math.sqrt(a)))


# 4) Reload the database saved earlier; passing model= presumably lets it embed new queries.
db2 = SaladFaissGPSDB.load("salad_faiss_db", model=model)

# 5) Predict GPS for one held-out validation image (tensor [3, H, W]).
query_img, query_gps = val_dataset[0]
pred_lat, pred_lon = db2.predict_gps(query_img, k=TOP_K, weighted=True)

# Assumes query_gps is ordered (lat, lon), matching predict_gps's return order — TODO confirm in CampusGPSDataset.
real_lat, real_lon = query_gps.tolist()
error_m = haversine_m(real_lat, real_lon, pred_lat, pred_lon)

print("real_gps:", query_gps.tolist())
print("pred_gps:", pred_lat, pred_lon)
print(f"error: {error_m:.2f} m")


real_gpez: [31.2620849609375, 34.80344009399414]
pred_gpez: 31.26202450055479 34.80343739580511
