In [None]:
!pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124
!pip install transformers accelerate huggingface_hub pandas optuna pyproj einops geopy matplotlib kaggle  geopandas cartopy

In [None]:
import os
from getpass import getpass
from pathlib import Path

# Kaggle credentials. Never hardcode them in the notebook.
# The Kaggle client reads KAGGLE_USERNAME / KAGGLE_KEY from the environment first and
# falls back to kaggle.json only when they are absent. Assigning '' as a placeholder
# therefore clobbers real env credentials AND hides a valid kaggle.json, so
# authentication fails. Instead: drop blank values (e.g. left over from an earlier run
# in this kernel), then prompt only if no other credential source is available.
for _var in ("KAGGLE_USERNAME", "KAGGLE_KEY"):
    if not os.environ.get(_var):
        os.environ.pop(_var, None)

_kaggle_dir = os.environ.get("KAGGLE_CONFIG_DIR") or Path.home() / ".kaggle"
_kaggle_json = Path(_kaggle_dir).expanduser() / "kaggle.json"
if not _kaggle_json.exists():
    if "KAGGLE_USERNAME" not in os.environ:
        os.environ["KAGGLE_USERNAME"] = input("Kaggle username: ")
    if "KAGGLE_KEY" not in os.environ:
        # getpass so the key is neither echoed nor stored in the cell output.
        os.environ["KAGGLE_KEY"] = getpass("Kaggle API key: ")

# Imported only after the environment is prepared: the kaggle package authenticates
# at import time (kaggle/__init__.py).
from kaggle.api.kaggle_api_extended import KaggleApi

api = KaggleApi()
api.authenticate()

# Downloads and unzips the im2gps3k dataset into data/im2gps3k.
api.dataset_download_files('lctngdng/im2gps3k', path='data/im2gps3k', unzip=True)

In [None]:
import geopandas as gpd
from shapely.geometry import Point
import cartopy.io.shapereader as shpreader

# Country polygons from Natural Earth (admin-0 countries, 1:110m scale). cartopy
# returns the local path to the shapefile, downloading it on first use.
shp_path = shpreader.natural_earth(
    category='cultural',
    name='admin_0_countries',
    resolution='110m',
)
# Keep only the country name (ADMIN) and its polygon, expressed in lat/lon (EPSG:4326).
world = (
    gpd.read_file(shp_path)[["ADMIN", "geometry"]]
    .to_crs(epsg=4326)
)
# Spatial index over the polygons, so point lookups test only nearby candidates.
sindex = world.sindex

# Function returns country name for a given latitude/longitude
def point_to_country(lat, lon, countries=world, index=sindex):
    """Return the ``ADMIN`` country name whose polygon covers (lat, lon), or None.

    Parameters
    ----------
    lat, lon : float
        Latitude and longitude. The shapely point is built as (x=lon, y=lat).
    countries : GeoDataFrame
        Polygons with an ``ADMIN`` column. Defaults to the module-level ``world``;
        the default is bound when this function is defined.
    index : spatial index over ``countries``
        Defaults to ``sindex`` (bound when this function is defined).

    Returns
    -------
    str or None
        The first candidate whose polygon strictly contains the point. Failing that,
        the first candidate whose polygon touches it (i.e. the point lies exactly on a
        border). None when no polygon covers the point, e.g. open water.
    """
    point = Point(lon, lat)
    # Bounding-box query: a cheap prefilter that yields integer positions (hence .iloc);
    # it may include polygons that do not actually cover the point.
    hits = list(index.intersection(point.bounds))
    # Two passes over the candidates: strict interior first, then boundary-inclusive.
    for predicate in ("contains", "intersects"):
        for pos in hits:
            if getattr(countries.geometry.iloc[pos], predicate)(point):
                return countries.ADMIN.iloc[pos]
    return None

In [None]:
import pandas as pd

# im2gps3k test-set metadata; the cells below rely on its IMG_ID, LAT and LON columns.
IM2GPS3K_CSV = "data/im2gps3k/im2gps3k_places365.csv"
df = pd.read_csv(IM2GPS3K_CSV)
df.head()

In [None]:
from pathlib import Path
import shutil

# Keep only images geolocated in these countries (names as returned by
# point_to_country, i.e. values of the Natural Earth ADMIN column).
TARGET_COUNTRIES = ("Ukraine", "Israel", "Russia", "Palestine")
PLACES_CSV = Path("data/im2gps3k/im2gps3k_places365.csv")
# Pristine copy of the unfiltered metadata. The filtered frame is written back to
# PLACES_CSV (the evaluation cell reads that path), which would otherwise destroy the
# raw data on the first run.
PLACES_CSV_FULL = PLACES_CSV.with_name("im2gps3k_places365_full.csv")


def is_target_country(row, targets=TARGET_COUNTRIES):
    """Return True if the row's coordinates fall inside one of ``targets``.

    Parameters
    ----------
    row : pandas.Series
        One metadata row (as passed by ``DataFrame.apply(axis=1)``) with ``LAT`` and
        ``LON`` entries.
    targets : collection of str
        Country names to keep; defaults to TARGET_COUNTRIES.
    """
    return point_to_country(row["LAT"], row["LON"]) in targets


# Row-wise boolean mask; rows with no matching country (None) evaluate to False.
mask = df.apply(is_target_country, axis=1)
df_filtered = df[mask]

# Back up the raw CSV once before overwriting it. NOTE: if the file was already
# overwritten by an earlier run without this backup, re-download the dataset first.
if not PLACES_CSV_FULL.exists():
    shutil.copyfile(PLACES_CSV, PLACES_CSV_FULL)
# index=False: the default also writes the index as an unnamed column. That column
# reappears as "Unnamed: 0" on the next read, and another copy is added on every re-run.
df_filtered.to_csv(PLACES_CSV, index=False)

# Last expression, so the notebook actually displays it.
df_filtered.head()

In [None]:
from pathlib import Path
import shutil

# Copy only the images that survived the country filter into a separate folder; the
# evaluation cell points at it via image_data_path='filtered_im2gps3k'.
src_dir = "data/im2gps3k/im2gps3ktest"
dst_dir = "data/im2gps3k/filtered_im2gps3k"
Path(dst_dir).mkdir(parents=True, exist_ok=True)
for img_id in df_filtered["IMG_ID"]:
    src, dst = f"{src_dir}/{img_id}", f"{dst_dir}/{img_id}"
    shutil.copyfile(src, dst)

In [None]:
import torch
from zeroshot_prediction import ZeroShotPredictor

# Use the GPU when PyTorch can see one; otherwise run on CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
predictor = ZeroShotPredictor(model='checkpoints/g3_9_.pth', device=device)

print("Starting prediction on im2gps3k dataset...")
# df_path is the CSV that the filtering cell wrote back with only the target countries;
# image_data_path is the folder the copy cell populated.
eval_kwargs = dict(
    df_path='data/im2gps3k/im2gps3k_places365.csv',
    top_k=5,
    root_path='data/im2gps3k',
    image_data_path='filtered_im2gps3k',
    text_data_path='im2gps3k_places365.csv',
)
# evaluate_im2gps3k returns a pair; only the first element is kept (as results_df).
results_df, _ = predictor.evaluate_im2gps3k(**eval_kwargs)