In [2]:
import torch
from torch import nn
import torch.optim as optim
from train.train import train
from train.dataloader import GeoDataLoader, img_train_transform
from torch.utils.data import DataLoader
from geoclip import GeoCLIP
import os
import wandb
import random
from datetime import datetime

In [3]:
# Dataset location. "~" is expanded once here so that the CSV path and the
# image root are both absolute and derived from the same base directory
# (previously only the CSV path was expanded and the folder was passed as a
# literal "~/..." string).
dataset_folder = os.path.expanduser("~/mnt/cluster_storage/ai_geolocation")
dataset_file = os.path.join(dataset_folder, "combined_train_geolocations.csv")
batch_size = 32

train_dataset = GeoDataLoader(dataset_file, dataset_folder, transform=img_train_transform())

# drop_last=True keeps every batch at exactly `batch_size` samples. train() is
# handed a fixed batch_size (presumably to size its targets -- confirm in
# train/train.py), and the dataset size is not guaranteed to be a multiple of
# it, so a short final batch would not match what train() is told to expect.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    drop_last=True,
)

Loading image paths and coordinates: 0it [00:00, ?it/s]

Loading image paths and coordinates: 11024it [00:00, 22055.20it/s]

File not found: /home/ray/mnt/cluster_storage/ai_geolocation/googlestreetview/googlestreetview_all_imgs/Kyiv_50.43660429_30.52477515_heading90.jpg
File not found: /home/ray/mnt/cluster_storage/ai_geolocation/googlestreetview/googlestreetview_all_imgs/Kyiv_50.45544446_30.52440171_heading270.jpg
File not found: /home/ray/mnt/cluster_storage/ai_geolocation/googlestreetview/googlestreetview_all_imgs/Kharkiv_50.00115190_36.30741662_heading180.jpg
File not found: /home/ray/mnt/cluster_storage/ai_geolocation/googlestreetview/googlestreetview_all_imgs/Kharkiv_49.98259653_36.30778771_heading180.jpg
File not found: /home/ray/mnt/cluster_storage/ai_geolocation/googlestreetview/googlestreetview_all_imgs/Kharkiv_49.98948679_36.30714889_heading180.jpg


Loading image paths and coordinates: 15475it [00:00, 22172.05it/s]

File not found: /home/ray/mnt/cluster_storage/ai_geolocation/googlestreetview/googlestreetview_all_imgs/Kyiv_50.40888245_30.52532464_heading0.jpg
File not found: /home/ray/mnt/cluster_storage/ai_geolocation/googlestreetview/googlestreetview_all_imgs/Kyiv_50.37440137_30.52299690_heading0.jpg
File not found: /home/ray/mnt/cluster_storage/ai_geolocation/googlestreetview/googlestreetview_all_imgs/Kyiv_50.42963426_30.52491331_heading270.jpg
File not found: /home/ray/mnt/cluster_storage/ai_geolocation/googlestreetview/googlestreetview_all_imgs/Kyiv_50.45012498_30.52450715_heading0.jpg


Loading image paths and coordinates: 19827it [00:00, 22084.12it/s]

File not found: /home/ray/mnt/cluster_storage/ai_geolocation/eyesonrussia/eyesonrussia_imgs/img10126_0.jpg
File not found: /home/ray/mnt/cluster_storage/ai_geolocation/eyesonrussia/eyesonrussia_imgs/img2080_1.jpg
File not found: /home/ray/mnt/cluster_storage/ai_geolocation/eyesonrussia/eyesonrussia_imgs/img3441_1.jpg
Total images found: 19815





In [4]:
# Initialize model

# Hyperparameters for this fine-tuning run.
lr = 0.0001
num_epochs = 10
# StepLR settings: multiply the learning rate by `gamma` every `step_size`
# scheduler steps.
# NOTE(review): if train() steps the scheduler once per epoch, step_size=30
# exceeds num_epochs=10 and the learning rate never decays -- confirm intended.
step_size = 30
gamma = 0.1

model = GeoCLIP(from_pretrained=True)
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# wandb setup
# Config values are passed as plain scalars. Wrapping them in braces (e.g.
# `{lr}`) builds a one-element Python set, which is not what should be
# recorded as a hyperparameter.
wandb.init(
    project="ai-geolocation",
    config={
        "learning_rate": lr,
        "step_size": step_size,
        "epochs": num_epochs,
        "optimizer_gamma": gamma,
    },
)

  self.load_state_dict(torch.load(f"{file_dir}/weights/location_encoder_weights.pth"))
  self.image_encoder.mlp.load_state_dict(torch.load(f"{self.weights_folder}/image_encoder_mlp_weights.pth"))
  self.location_encoder.load_state_dict(torch.load(f"{self.weights_folder}/location_encoder_weights.pth"))
  self.logit_scale = nn.Parameter(torch.load(f"{self.weights_folder}/logit_scale_weights.pth"))
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mselena-sun[0m ([33mvannevar[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
# Train
# Runs `num_epochs` epochs, writes a weight snapshot after each one, then saves
# the final fine-tuned weights with a timestamp suffix. The wandb run is always
# closed, and is marked as failed if training raises or is interrupted.
os.makedirs("snapshots", exist_ok=True)

try:
    for epoch in range(num_epochs):
        train(train_dataloader, model, optimizer, epoch, batch_size, device, scheduler=scheduler)

        snapshot_dir = f"snapshots/epoch_{epoch}"
        os.makedirs(snapshot_dir, exist_ok=True)

        # Actually persist the snapshot: the same three pieces that are saved
        # at the end of training, one file each inside the epoch directory.
        torch.save(
            model.image_encoder.mlp.state_dict(),
            os.path.join(snapshot_dir, "image_encoder_mlp_weights.pth"),
        )
        torch.save(
            model.location_encoder.state_dict(),
            os.path.join(snapshot_dir, "location_encoder_weights.pth"),
        )
        torch.save(
            model.logit_scale,
            os.path.join(snapshot_dir, "logit_scale_weights.pth"),
        )

        print(f"Saved snapshot for epoch {epoch}")

    # Save fine-tuned weights
    current_time = datetime.now().strftime("%m-%d-%H:%M")
    torch.save(model.image_encoder.mlp.state_dict(), f"fine_tuned_image_encoder_mlp_weights_{current_time}.pth")
    torch.save(model.location_encoder.state_dict(), f"fine_tuned_location_encoder_weights_{current_time}.pth")
    torch.save(model.logit_scale, f"fine_tuned_logit_scale_weights_{current_time}.pth")
except BaseException:
    # Includes KeyboardInterrupt (stopping the cell): close the run as failed
    # instead of leaving it open, then re-raise so the error is still visible.
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()

Starting Epoch 0


Epoch 0 loss: 6.18077:  97%|█████████▋| 600/620 [13:58<00:27,  1.39s/it]