In [51]:
%cd /kaggle/working
!rm -rf G3-Original

/kaggle/working


In [1]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
import os

# Retrieve the secret and set env var
os.environ['GITHUB_PAT'] = user_secrets.get_secret("GITHUB_Geolocation")

# Then use shell expansion in Kaggle (Bash)
!git clone https://tungduong0708:${GITHUB_PAT}@github.com/tungduong0708/G3-Original.git

fatal: destination path 'G3-Original' already exists and is not an empty directory.


In [2]:
%cd G3-Original/g3

/kaggle/working/G3-Original/g3


In [3]:
!ls

filtered_eval.ipynb		    hparams.yaml
filtered_finetune.ipynb		    mp16_pro_filter.ipynb
filtered_llm_predict_dataset.ipynb  optuna_hyperparams_tune.ipynb
filtered_mp16_pro_index.ipynb	    __pycache__
filtered_train.ipynb		    results
full_train.ipynb		    utils
g3_prediction.py		    zeroshot_prediction.py


In [4]:
!pwd

/kaggle/working/G3-Original/g3


In [5]:
import torch
from torch.utils.data import IterableDataset, DataLoader, get_worker_info
from datasets import load_dataset
from PIL import Image
from io import BytesIO
import torchvision.transforms as T
from typing import Dict, Any, Iterator, Optional, Tuple
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient

# Authenticate to HF Hub (use Kaggle secret)
user_secrets = UserSecretsClient()
login(token=user_secrets.get_secret("HF_KEY"))

__all__ = [
    "MP16StreamingDataset",
    "mp16_collate",
]


class MP16StreamingDataset(IterableDataset):
    """Stream **MP‑16** samples from the HuggingFace Hub and yield a simple
    tuple per example::

        (image, text, longitude, latitude)

    * **image**  – either a tensor (``C×H×W``) if *vision_processor* is set or if
      the fallback transform is used, otherwise a PIL image.
    * **text**   – caption string (either provided by the dataset or generated
      from location fields).
    * **longitude**, **latitude** – floats.

    The class is an :class:`torch.utils.data.IterableDataset`, so wrap it in a
    :class:`~torch.utils.data.DataLoader` for batching.
    """

    def __init__(
        self,
        repo_id: str = "tduongvn/MP16-Pro-shards",
        split: str = "train",
        vision_processor: Optional[Any] = None,
        shuffle_buffer: int = 10_000,
    ) -> None:
        super().__init__()
        self.repo_id = repo_id
        self.split = split
        self.vision_processor = vision_processor
        self.shuffle_buffer = shuffle_buffer

        # Base transform when we *don't* have a fancy processor
        self.fallback_transform = T.Compose(
            [
                T.RandomHorizontalFlip(),
                T.RandomResizedCrop(size=224),
                T.ToTensor(),
            ]
        )

        # Prepare an initial dataset iterator for the main process
        self._base_iter = self._new_iterator()

    # ──────────────────────────────────────────────────────────────────────────
    # Internals                                                               ─┘

    def _new_iterator(self):
        return (
            load_dataset(self.repo_id, split=self.split, streaming=True)
            .shuffle(buffer_size=self.shuffle_buffer)
            .__iter__()
        )

    def _decode_image(self, img_bytes):
        """bytes → PIL.Image or tensor (if processor is set)."""
        img = Image.open(BytesIO(img_bytes)).convert("RGB")
        if self.vision_processor is not None:
            return self.vision_processor(images=img, return_tensors="pt")[
                "pixel_values"
            ].squeeze(0)
        return self.fallback_transform(img)

    def _caption(self, ex_json: Dict[str, Any]) -> str:
        parts = [ex_json.get(k) for k in ("city", "state", "country") if ex_json.get(k)]
        return "A street view photo taken in " + ", ".join(parts)

    # ──────────────────────────────────────────────────────────────────────────
    # IterableDataset API                                                     ─┘

    def __iter__(self) -> Iterator[Tuple[Any, str, float, float]]:
        # Each DataLoader worker gets its own iterator to avoid state clashes.
        worker = get_worker_info()
        iterator = self._new_iterator() if worker is not None else self._base_iter

        for ex in iterator:
            # Dataset structure: {'jpg': <PIL or bytes>, 'json': {...}, ...}
            img_field = ex["jpg"]
            if isinstance(img_field, Image.Image):
                img = img_field.convert("RGB")
                if self.vision_processor is not None:
                    img = self.vision_processor(images=img, return_tensors="pt")[
                        "pixel_values"
                    ].squeeze(0)
                else:
                    img = self.fallback_transform(img)
            else:  # bytes
                img = self._decode_image(img_field)

            meta = ex["json"] if "json" in ex else {}
            lon = float(meta.get("lon", meta.get("LON")))
            lat = float(meta.get("lat", meta.get("LAT")))
            text = meta.get("text") or self._caption(meta)

            yield img, text, lon, lat

    # No __len__ – this is a stream.


# ─────────────────────────────────────────────────────────────────────────────
# Collate                                                                     ─┘

def make_mp16_collate(text_processor):
    def collate(batch):
        images, texts, lons, lats = zip(*batch)

        images = torch.stack(images)  # (B, C, H, W)

        token_out = text_processor(
            list(texts),
            padding="longest",
            truncation=True,
            max_length=77,
            return_tensors="pt"
        )

        lons = torch.tensor(lons, dtype=torch.float32)
        lats = torch.tensor(lats, dtype=torch.float32)

        return images, token_out, lons, lats

    return collate

In [22]:
ds = MP16StreamingDataset()
dl = DataLoader(ds, batch_size=1, num_workers=2)
imgs, texts, lons, lats = next(iter(dl))
print(type(imgs), len(texts), lons[:3], lats[:3])


Resolving data files:   0%|          | 0/363 [00:00<?, ?it/s]

<class 'torch.Tensor'> 1 tensor([10.7612], dtype=torch.float64) tensor([59.9240], dtype=torch.float64)


In [6]:
# Standard library
import os
import time
import warnings
import yaml

# Third-party libraries
import torch
import optuna
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import DataLoader
from accelerate import Accelerator, DistributedDataParallelKwargs

# Local application imports
from utils.utils import MP16Dataset
from utils.G3 import G3
from zeroshot_prediction import ZeroShotPredictor


warnings.filterwarnings('ignore')

def train_1epoch(dataloader, eval_dataloader, earlystopper, model, vision_processor, text_processor, optimizer, scheduler, device, accelerator=None):
    model.train()
    t = tqdm(dataloader, disable=not accelerator.is_local_main_process)
    for i, (images, texts, longitude, latitude) in enumerate(t):
        texts = texts.to(device)
        images = images.to(device)
        texts = texts.to(device)
        longitude = longitude.to(device).float()
        latitude = latitude.to(device).float()
        optimizer.zero_grad()

        output = model(images, texts, longitude, latitude, return_loss=True)
        loss = output['loss']

        # loss.backward()
        accelerator.backward(loss)
        optimizer.step()
        if i % 1 == 0:
            t.set_description('step {}, loss {}, lr {}'.format(i, loss.item(), scheduler.get_last_lr()[0]))
    scheduler.step()


2025-07-02 04:05:15.828085: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751429116.019046     824 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751429116.075549     824 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [7]:
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

# fine-tune
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = 'cpu'
hparams = yaml.safe_load(open('hparams.yaml', 'r'))
pe = "projection_mercator"
nn = "rffmlp"
model = G3(
    device=device,
    positional_encoding_type=pe,
    neural_network_type=nn,
    hparams=hparams[f'{pe}_{nn}'],
).to(device)

# model = torch.load('g3_5_.pth')
# location_encoder_dict = torch.load('checkpoints/location_encoder_weights.pth') # from geoclip
# model.location_encoder.load_state_dict(location_encoder_dict)

dataset = MP16StreamingDataset(vision_processor = model.vision_processor)
collate_fn = make_mp16_collate(model.text_processor)
dataloader = DataLoader(
    dataset, 
    batch_size=64, 
    shuffle=False, 
    num_workers=16, 
    pin_memory=True, 
    prefetch_factor=5, 
    collate_fn=collate_fn
)


params = []
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.size())
        params.append(param)

optimizer = torch.optim.AdamW([param for name,param in model.named_parameters() if param.requires_grad], lr=hparams[f'{pe}_{nn}']['lr'], weight_decay=hparams[f'{pe}_{nn}']['wd'])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.87)

model, optimizer, dataloader, scheduler = accelerator.prepare(
    model, optimizer, dataloader, scheduler
)

eval_dataloader = None
earlystopper = None
for epoch in range(10):
    train_1epoch(dataloader, eval_dataloader, earlystopper, model, model.vision_processor, model.text_processor, optimizer, scheduler, device, accelerator)
    unwrapped_model = accelerator.unwrap_model(model)
    os.makedirs('checkpoints', exist_ok=True)
    torch.save(unwrapped_model, 'checkpoints/g3_{}_.pth'.format(epoch))

    # with open("threshold_accuracy.txt", "a") as f:
    #     f.write('checkpoints/g3_{}_.pth\n'.format(epoch))


    # predictor = ZeroShotPredictor(model='checkpoints/g3_{}_.pth'.format(epoch), device=device)
    # df, res = predictor.evaluate_im2gps3k(
    #     df_path="data/im2gps3k/im2gps3k_places365.csv",
    #     top_k=5,
    #     root_path="data/im2gps3k/",
    #     image_data_path="filtered_im2gps3k",
    #     text_data_path= "im2gps3k_places365.csv"
    # )

Resolving data files:   0%|          | 0/363 [00:00<?, ?it/s]

logit_scale1 torch.Size([])
logit_scale2 torch.Size([])
logit_scale3 torch.Size([])
location_encoder.neural_network.0.LocEnc0.capsule.1.weight torch.Size([1024, 512])
location_encoder.neural_network.0.LocEnc0.capsule.1.bias torch.Size([1024])
location_encoder.neural_network.0.LocEnc0.capsule.3.weight torch.Size([1024, 1024])
location_encoder.neural_network.0.LocEnc0.capsule.3.bias torch.Size([1024])
location_encoder.neural_network.0.LocEnc0.capsule.5.weight torch.Size([1024, 1024])
location_encoder.neural_network.0.LocEnc0.capsule.5.bias torch.Size([1024])
location_encoder.neural_network.0.LocEnc0.head.0.weight torch.Size([512, 1024])
location_encoder.neural_network.0.LocEnc0.head.0.bias torch.Size([512])
location_encoder.neural_network.0.LocEnc1.capsule.1.weight torch.Size([1024, 512])
location_encoder.neural_network.0.LocEnc1.capsule.1.bias torch.Size([1024])
location_encoder.neural_network.0.LocEnc1.capsule.3.weight torch.Size([1024, 1024])
location_encoder.neural_network.0.LocEnc1.

step 28, loss 7.859357833862305, lr 3e-05: : 29it [02:55,  6.04s/it] 


KeyboardInterrupt: 

In [6]:
from datasets import load_dataset

iter_ds = load_dataset(
    "tduongvn/MP16-Pro-shards",
    split="train",
    streaming=True
)

print(next(iter(iter_ds)))

Resolving data files:   0%|          | 0/363 [00:00<?, ?it/s]

{'jpg': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x640 at 0x7BABB6C34410>, 'json': {'lat': 37.790855, 'lon': -122.399485, 'text': 'A street view photo taken in San Francisco, California, United States'}, '__key__': '00_00_12844544494', '__url__': 'hf://datasets/tduongvn/MP16-Pro-shards@f2ffe6fecc33e044463f4e9d7a10c30bbc771ee6/mp16-0000.tar'}


In [8]:
from datasets import load_dataset

repo_id = "tduongvn/MP16-Pro-shards"
split    = "train"

# This uses the Arrow cache, not streaming, but just for metadata:
num_rows = load_dataset(repo_id, split=split, streaming=False).num_rows
print("Total rows:", num_rows)


Resolving data files:   0%|          | 0/363 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/363 [00:00<?, ?files/s]

mp16-0000.tar:   0%|          | 0.00/1.12G [00:00<?, ?B/s]

mp16-0001.tar:   0%|          | 0.00/1.12G [00:00<?, ?B/s]

mp16-0002.tar:   0%|          | 0.00/1.12G [00:00<?, ?B/s]

mp16-0003.tar:   0%|          | 0.00/1.12G [00:00<?, ?B/s]

mp16-0004.tar:   0%|          | 0.00/1.12G [00:00<?, ?B/s]

mp16-0005.tar:   0%|          | 0.00/1.12G [00:00<?, ?B/s]

mp16-0006.tar:   0%|          | 0.00/1.12G [00:00<?, ?B/s]

mp16-0007.tar:   0%|          | 0.00/1.12G [00:00<?, ?B/s]

mp16-0008.tar:   0%|          | 0.00/1.12G [00:00<?, ?B/s]

KeyboardInterrupt: 