In [None]:
!pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124
!pip install transformers accelerate huggingface_hub pandas optuna pyproj einops geopy matplotlib kaggle  geopandas cartopy

In [None]:
import os

from huggingface_hub import snapshot_download, hf_hub_download, login

# Never hardcode the Hugging Face token in the notebook: take it from HF_TOKEN, and when that is
# unset login(token=None) asks for one interactively.
login(token=os.environ.get("HF_TOKEN"))

# MP16 image shards (*.tar) from the geolocation dataset repo.
snapshot_path = snapshot_download(
    repo_id="tduongvn/ACMMM25-Geolocation",
    repo_type="dataset",
    allow_patterns=["*.tar"],
    local_dir="data/mp16/",
    token=True,  # use the locally stored token (replaces the deprecated use_auth_token=True)
)
print(f"Downloaded dataset to {snapshot_path}")

# MP16-Pro metadata CSVs; they land under data/mp16/metadata/ and are moved up by the next cell.
mp16_metadata_files = [
    "metadata/MP16_Pro_filtered.csv",
    "metadata/MP16_Pro_places365.csv",
    "metadata/mp16_urls.csv",
]
for file in mp16_metadata_files:
    csv_path = hf_hub_download(
        repo_id="Jia-py/MP16-Pro",
        filename=file,
        repo_type="dataset",
        local_dir="data/mp16/",
        token=True,  # use the locally stored token (replaces the deprecated use_auth_token=True)
    )
    print(f"Downloaded {file} to {csv_path}")


In [None]:
from pathlib import Path

# Move data/mp16/metadata/*.csv up into data/mp16/. Path.replace overwrites an existing target
# like `mv` does; when there is nothing left to move (e.g. on a re-run) the loop is a no-op
# instead of a shell error.
metadata_dir = Path("data/mp16/metadata")
for csv_file in metadata_dir.glob("*.csv"):
    csv_file.replace(metadata_dir.parent / csv_file.name)

In [None]:
import os

# Kaggle credentials come from the KAGGLE_USERNAME / KAGGLE_KEY environment variables or, when
# those are absent, from kaggle.json. The kaggle client only consults kaggle.json if the
# variables are missing, so blank values must not be exported (or left over from an earlier
# run): they would shadow a valid kaggle.json and authentication would fail.
for var in ("KAGGLE_USERNAME", "KAGGLE_KEY"):
    if not os.environ.get(var):
        os.environ.pop(var, None)

# NOTE(review): the kaggle package is expected to authenticate as a side effect of import,
# which is why the environment is settled before this line — confirm for the installed version.
from kaggle.api.kaggle_api_extended import KaggleApi

api = KaggleApi()
api.authenticate()

api.dataset_download_files('lctngdng/im2gps3k', path='data/im2gps3k', unzip=True)

In [None]:
import geopandas as gpd
from shapely.geometry import Point
import cartopy.io.shapereader as shpreader

# Country polygons from Natural Earth (1:110m admin-0), reduced to the name and shape columns
# and expressed in lon/lat degrees (EPSG:4326).
countries_shp = shpreader.natural_earth(
    category='cultural',
    name='admin_0_countries',
    resolution='110m',
)
world_raw = gpd.read_file(countries_shp)
world = world_raw[["ADMIN", "geometry"]].to_crs(epsg=4326)

# Materialize the spatial index once so every point lookup can reuse it.
sindex = world.sindex

# Reverse-geocode a latitude/longitude pair to a country name.
def point_to_country(lat, lon, countries=world, index=sindex):
    """Return the ``ADMIN`` name of the country polygon at (lat, lon), or None.

    Candidate polygons come from a bounding-box query on the spatial index. A strict
    ``contains`` match is preferred; only if none is found is ``intersects`` tried, which
    additionally accepts a point lying exactly on a polygon boundary. A point outside every
    polygon (for example open sea) yields None.

    Args:
        lat: latitude of the point.
        lon: longitude of the point.
        countries: frame with ``geometry`` and ``ADMIN`` columns, addressed positionally.
        index: spatial index over ``countries`` whose ``intersection`` yields positions.
    """
    point = Point(lon, lat)  # shapely points are (x, y), i.e. (lon, lat)
    hit_positions = list(index.intersection(point.bounds))
    for predicate in ("contains", "intersects"):
        for pos in hit_positions:
            if getattr(countries.geometry.iloc[pos], predicate)(point):
                return countries.ADMIN.iloc[pos]
    return None

In [None]:
import pandas as pd

# im2gps3k metadata; later cells read its IMG_ID, LAT and LON columns.
im2gps3k_csv = "data/im2gps3k/im2gps3k_places365.csv"
df = pd.read_csv(im2gps3k_csv)
df.head()

In [None]:
TARGET_COUNTRIES = ("Ukraine", "Israel", "Russia", "Palestine")


def is_target_country(row):
    """Tell whether a metadata row's LAT/LON falls inside one of TARGET_COUNTRIES.

    Rows whose coordinates map to no country (point_to_country returns None) count as
    not in the target set.
    """
    country = point_to_country(row["LAT"], row["LON"])
    return country in TARGET_COUNTRIES

import os
import shutil

# Row-wise country lookup -> boolean mask aligned with df's index.
mask = df.apply(is_target_country, axis=1)
df_filtered = df[mask]

IM2GPS3K_CSV = "data/im2gps3k/im2gps3k_places365.csv"
IM2GPS3K_FULL_CSV = "data/im2gps3k/im2gps3k_places365_full.csv"

# The filtered metadata intentionally replaces im2gps3k_places365.csv, because the evaluation in
# the training loop reads that path. Keep a byte-for-byte copy of the file as it was before the
# first overwrite so the unfiltered metadata remains recoverable.
if not os.path.exists(IM2GPS3K_FULL_CSV):
    shutil.copyfile(IM2GPS3K_CSV, IM2GPS3K_FULL_CSV)

# index=False keeps exactly the input CSV's columns; writing the pandas index would add an
# unnamed column (read back as "Unnamed: 0") and yet another one on every re-run.
df_filtered.to_csv(IM2GPS3K_CSV, index=False)
df_filtered.head()

In [None]:
import shutil
from pathlib import Path

# Copy only the images whose metadata rows survived the country filter into their own folder.
src_dir = "data/im2gps3k/im2gps3ktest"
dst_dir = "data/im2gps3k/filtered_im2gps3k"
Path(dst_dir).mkdir(parents=True, exist_ok=True)
for img_name in df_filtered["IMG_ID"]:
    shutil.copyfile(f"{src_dir}/{img_name}", f"{dst_dir}/{img_name}")

In [None]:
# Standard library
import os
import time
import warnings
import yaml

# Third-party libraries
import torch
import optuna
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import DataLoader
from accelerate import Accelerator, DistributedDataParallelKwargs

# Local application imports
from utils.utils import MP16Dataset
from utils.G3 import G3
from zeroshot_prediction import ZeroShotPredictor


# NOTE(review): this silences every warning for the rest of the session (including library
# deprecation notices); consider narrowing it to specific categories/modules.
warnings.filterwarnings('ignore')

def train_1epoch(dataloader, eval_dataloader, earlystopper, model, vision_processor, text_processor, optimizer, scheduler, device, accelerator=None):
    """Run one training epoch over ``dataloader`` and advance the LR scheduler once.

    Args:
        dataloader: yields ``(images, texts, longitude, latitude)`` batches; ``texts`` is
            tokenized here with ``text_processor`` (presumably raw strings — confirm against
            the dataset).
        eval_dataloader, earlystopper, vision_processor: accepted for interface compatibility;
            not used by this function.
        model: called as ``model(images, texts, longitude, latitude, return_loss=True)`` and
            must return a mapping with a ``'loss'`` entry.
        text_processor: tokenizer-style callable returning an object with ``.to(device)``.
        optimizer: zeroed and stepped once per batch.
        scheduler: stepped once per epoch, after the loop, so its step size counts epochs.
        device: target device for every batch tensor.
        accelerator: optional ``Accelerator``. When given, backward goes through it and the
            progress bar is shown only on the local main process; when None, plain
            ``loss.backward()`` is used and the progress bar is always shown.
    """
    model.train()
    # In multi-process runs only one process per node draws the progress bar.
    show_progress = accelerator is None or accelerator.is_local_main_process
    t = tqdm(dataloader, disable=not show_progress)
    for i, (images, texts, longitude, latitude) in enumerate(t):
        # max_length=77: presumably the text encoder's context length — TODO confirm.
        texts = text_processor(text=texts, padding='max_length', truncation=True, return_tensors='pt', max_length=77)
        images = images.to(device)
        texts = texts.to(device)
        longitude = longitude.to(device).float()
        latitude = latitude.to(device).float()
        optimizer.zero_grad()

        output = model(images, texts, longitude, latitude, return_loss=True)
        loss = output['loss']

        if accelerator is not None:
            accelerator.backward(loss)
        else:
            loss.backward()
        optimizer.step()
        t.set_description('step {}, loss {}, lr {}'.format(i, loss.item(), scheduler.get_last_lr()[0]))
    scheduler.step()


In [None]:
# The StepLR below is epoch-based: train_1epoch calls scheduler.step() once per epoch. By default
# accelerate ties a prepared scheduler to the optimizer and steps it num_processes times per
# call, so on multi-GPU runs the LR would decay num_processes times per epoch.
# step_scheduler_with_optimizer=False makes each call exactly one step.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs], step_scheduler_with_optimizer=False)

NUM_EPOCHS = 10

# fine-tune
device = "cuda" if torch.cuda.is_available() else "cpu"
with open('hparams.yaml', 'r') as hparams_file:
    hparams = yaml.safe_load(hparams_file)
pe_type = "projection_mercator"
nn_type = "rffmlp"
model_hparams = hparams[f'{pe_type}_{nn_type}']
model = G3(
    device=device,
    positional_encoding_type=pe_type,
    neural_network_type=nn_type,
    hparams=model_hparams,
).to(device)

# Grab the processors from the bare model. After accelerator.prepare() the model may be wrapped
# in DistributedDataParallel, which does not forward custom attributes such as these.
vision_processor = model.vision_processor
text_processor = model.text_processor

dataset = MP16Dataset(vision_processor=vision_processor, text_processor=text_processor, root_path='data/mp16/', image_data_path='filtered_mp16.tar')
dataloader = DataLoader(dataset, batch_size=256, shuffle=False, num_workers=16, pin_memory=True, prefetch_factor=5)

# Optimize only the trainable parameters; list them once per node.
params = []
for name, param in model.named_parameters():
    if param.requires_grad:
        accelerator.print(name, param.size())
        params.append(param)

optimizer = torch.optim.AdamW(params, lr=model_hparams['lr'], weight_decay=model_hparams['wd'])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.87)

model, optimizer, dataloader, scheduler = accelerator.prepare(
    model, optimizer, dataloader, scheduler
)

eval_dataloader = None
earlystopper = None
os.makedirs('checkpoints', exist_ok=True)
for epoch in range(NUM_EPOCHS):
    train_1epoch(dataloader, eval_dataloader, earlystopper, model, vision_processor, text_processor, optimizer, scheduler, device, accelerator)

    # Save and evaluate on the main process only; otherwise every rank writes the same checkpoint
    # file concurrently and runs its own copy of the evaluation.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        ckpt_path = 'checkpoints/g3_{}_.pth'.format(epoch)
        torch.save(accelerator.unwrap_model(model), ckpt_path)

        with open("threshold_accuracy.txt", "a") as f:
            f.write('{}\n'.format(ckpt_path))

        predictor = ZeroShotPredictor(model=ckpt_path, device=device)
        # Distinct names so the im2gps3k metadata frame `df` is not clobbered.
        eval_df, eval_res = predictor.evaluate_im2gps3k(
            df_path="data/im2gps3k/im2gps3k_places365.csv",
            top_k=5,
            root_path="data/im2gps3k/",
            image_data_path="filtered_im2gps3k",
            text_data_path="im2gps3k_places365.csv",
        )
    # Keep ranks in lockstep so the others don't start the next epoch while rank 0 evaluates.
    accelerator.wait_for_everyone()

In [None]:
from pathlib import Path

# g3_5_.pth is the checkpoint saved after epoch index 5, i.e. the 6th epoch.
src_ckpt = Path("checkpoints/g3_5_.pth")
dst_ckpt = Path("checkpoints/sirensh_6epoch_tmp.pth")
if src_ckpt.exists():
    src_ckpt.replace(dst_ckpt)
elif not dst_ckpt.exists():
    # `!mv` would only print an error and let the notebook continue; stop here instead so the
    # upload is not attempted without a checkpoint.
    raise FileNotFoundError(f"{src_ckpt} not found and {dst_ckpt} does not exist yet")

In [None]:
from huggingface_hub import upload_file

# Single source of truth for the checkpoint name, so the local path, repo path and commit message
# cannot drift apart (the local path previously had a typo: "..._tmph.pth").
CHECKPOINT_NAME = "sirensh_6epoch_tmp.pth"

upload_file(
    path_or_fileobj=f"checkpoints/{CHECKPOINT_NAME}",  # local file produced by the rename cell
    path_in_repo=CHECKPOINT_NAME,                       # destination path inside the repo
    repo_id="tduongvn/Checkpoints-ACMMM25",
    repo_type="model",
    commit_message=f"Upload {CHECKPOINT_NAME}",
)
