In [3]:
import torch
from torch import nn
import torch.optim as optim
from train.train import train
from train.dataloader import GeoDataLoader, img_train_transform
from torch.utils.data import DataLoader
from geoclip import GeoCLIP
import os
import wandb
import random
from datetime import datetime

In [4]:
# Dataset location and loader configuration.
# Expand "~" once for the folder and derive the CSV path from it: Python's
# path functions never expand "~" on their own, so an unexpanded folder would
# make every joined image path point into a directory literally named "~".
dataset_folder = os.path.expanduser("~/mnt/cluster_storage/ai_geolocation")
dataset_file = os.path.join(dataset_folder, "combined_train_geolocations.csv")
batch_size = 32

# Training dataset with the training-time image transform, wrapped in a
# shuffling DataLoader that uses 4 worker processes.
train_dataset = GeoDataLoader(dataset_file, dataset_folder, transform=img_train_transform())
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

Loading image paths and coordinates: 0it [00:00, ?it/s]

Loading image paths and coordinates: 15236it [00:00, 22233.25it/s]

Total images found: 15236





In [5]:
# Initialize model

# Hyperparameters for this fine-tuning run.
step_size = 30   # StepLR period
lr = 0.0001      # Adam learning rate
num_epochs = 10  # epochs run by the training cell
gamma = 0.1      # StepLR multiplicative decay factor

model = GeoCLIP(from_pretrained=True)
optimizer = optim.Adam(model.parameters(), lr=lr)
# NOTE(review): step_size (30) is larger than num_epochs (10); if train() steps
# the scheduler once per epoch the learning rate is never decayed - confirm
# this is intended.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
# NOTE(review): the model is not moved to `device` in this cell; presumably
# train() takes care of that - confirm.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# wandb setup
# Config values must be plain scalars. Writing {lr} builds a one-element set,
# which is logged as a collection instead of a number.
wandb.init(
    project="ai-geolocation",
    config={
        "learning_rate": lr,
        "step_size": step_size,
        "epochs": num_epochs,
        "optimizer_gamma": gamma,
    },
)

  self.load_state_dict(torch.load(f"{file_dir}/weights/{location_encoder_path}"))
  self.image_encoder.mlp.load_state_dict(torch.load(f"{self.weights_folder}/{image_encoder_path}"))
  self.location_encoder.load_state_dict(torch.load(f"{self.weights_folder}/{location_encoder_path}"))
  self.logit_scale = nn.Parameter(torch.load(f"{self.weights_folder}/{logit_scale_path}"))
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mselena-sun[0m ([33mvannevar[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [6]:
# Train
# Run num_epochs of training, write a weight snapshot after every epoch, save
# the final fine-tuned weights, and always close the wandb run.
os.makedirs("snapshots", exist_ok=True)
# Create the final output directory up front so a missing folder cannot make
# the save fail after the whole training run has already completed.
os.makedirs("model/weights", exist_ok=True)

try:
    for epoch in range(num_epochs):
        train(train_dataloader, model, optimizer, epoch, batch_size, device, scheduler=scheduler)

        snapshot_dir = f"snapshots/epoch_{epoch}"
        os.makedirs(snapshot_dir, exist_ok=True)

        # Actually write the snapshot (previously only the empty directory was
        # created). The same three components as the final save are stored so
        # any epoch can be restored if a later epoch diverges.
        torch.save(model.image_encoder.mlp.state_dict(), f"{snapshot_dir}/image_encoder_mlp_weights.pth")
        torch.save(model.location_encoder.state_dict(), f"{snapshot_dir}/location_encoder_weights.pth")
        torch.save(model.logit_scale, f"{snapshot_dir}/logit_scale_weights.pth")

        print(f"Saved snapshot for epoch {epoch}")

    # Save fine-tuned weights
    current_time = datetime.now().strftime("%m-%d-%H:%M")
    torch.save(model.image_encoder.mlp.state_dict(), f"model/weights/fine_tuned_image_encoder_mlp_weights_{current_time}.pth")
    torch.save(model.location_encoder.state_dict(), f"model/weights/fine_tuned_location_encoder_weights_{current_time}.pth")
    torch.save(model.logit_scale, f"model/weights/fine_tuned_logit_scale_weights_{current_time}.pth")
except BaseException:
    # Covers errors and manual kernel interrupts: close the run as failed
    # instead of leaving it open, then re-raise so the failure stays visible.
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()

Starting Epoch 0


Epoch 0 loss: 0.43922: 100%|██████████| 477/477 [11:47<00:00,  1.48s/it]

Saved snapshot for epoch 0
Starting Epoch 1



Epoch 1 loss: 0.42820: 100%|██████████| 477/477 [10:54<00:00,  1.37s/it]

Saved snapshot for epoch 1
Starting Epoch 2



Epoch 2 loss: 0.41495: 100%|██████████| 477/477 [10:54<00:00,  1.37s/it]

Saved snapshot for epoch 2
Starting Epoch 3



Epoch 3 loss: 4.18579:  67%|██████▋   | 318/477 [07:16<03:37,  1.37s/it]