In [7]:
import os
import sys
import logging
import argparse
from pathlib import Path
from collections import OrderedDict

import torch
import numpy as np
from PIL import Image
from tqdm import tqdm

# --- PyTorch Imports ---
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.tensorboard import SummaryWriter

# --- TorchVision & TorchMetrics Imports ---
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchmetrics.detection import MeanAveragePrecision
import torch.nn as nn

In [10]:

# --- Add Galileo to Path and Import ---
# Make the local `models/` directory importable ahead of everything else.
# Rebuilding the list (instead of a bare insert) keeps the entry at the front
# while preventing duplicate entries from piling up each time this cell is re-run.
_galileo_models_dir = str(Path.cwd() / "models")
sys.path[:] = [_galileo_models_dir] + [p for p in sys.path if p != _galileo_models_dir]
print(sys.path)
try:
    from galileo import (
        Encoder as GalileoEncoder,
        SPACE_TIME_BANDS, SPACE_TIME_BANDS_GROUPS_IDX,
        SPACE_BANDS, SPACE_BAND_GROUPS_IDX,
        TIME_BANDS, TIME_BAND_GROUPS_IDX,
        STATIC_BANDS, STATIC_BAND_GROUPS_IDX
    )
except ImportError as err:
    # Raise instead of calling sys.exit(1): inside a notebook sys.exit only
    # raises SystemExit, and it hides the underlying cause (which may be a
    # missing dependency of galileo.py rather than a missing file). Chaining
    # with `from err` keeps the original traceback visible.
    raise ImportError(
        "Error: Could not import Galileo model. Make sure 'models/galileo.py' exists."
    ) from err


['/home/rishabh.mondal/Brick-Kilns-project/ijcai_2025_kilns/iclr_2026/models', '/home/rishabh.mondal/Brick-Kilns-project/ijcai_2025_kilns/iclr_2026/models', '/opt/anaconda3/envs/rishabh_sat/lib/python312.zip', '/opt/anaconda3/envs/rishabh_sat/lib/python3.12', '/opt/anaconda3/envs/rishabh_sat/lib/python3.12/lib-dynload', '', '/home/rishabh.mondal/.local/lib/python3.12/site-packages', '/opt/anaconda3/envs/rishabh_sat/lib/python3.12/site-packages', '/opt/anaconda3/envs/rishabh_sat/lib/python3.12/site-packages/ISR-2.2.0-py3.12.egg', '/opt/anaconda3/envs/rishabh_sat/lib/python3.12/site-packages/setuptools/_vendor', '/home/rishabh.mondal/solo-learn', '/tmp/tmpb_rwcmur']


In [13]:
class GalileoBackboneWrapper(nn.Module):
    """Expose a pretrained Galileo encoder as a single-level torchvision detection backbone.

    A 3-channel image is packed into Galileo's space-time input tensor, only the
    ``S2_RGB`` band group is unmasked, and the encoder's ``S2_RGB`` tokens are
    returned as a channels-first feature map under the key ``"0"``.

    Args:
        pretrained_path: Folder handed to ``GalileoEncoder.load_from_folder``.
        patch_size: Patch size forwarded to the encoder on every forward call.
        rgb_bands: Sentinel-2 band names that input channels 0, 1 and 2 are
            written to, in that order. Sentinel-2 B4/B3/B2 are red/green/blue,
            so the default matches RGB-ordered input (the channel order implied
            by the ImageNet mean/std used in the detector's transform). Pass
            ``("B2", "B3", "B4")`` to reproduce the previous channel placement.
        debug: When True (default, matching the earlier always-on behaviour),
            print the shape of every intermediate tensor. Pass False for training.

    Raises:
        ValueError: If a name in ``rgb_bands`` is not in ``SPACE_TIME_BANDS``.
    """

    def __init__(self, pretrained_path: str, patch_size: int = 8,
                 rgb_bands=("B4", "B3", "B2"), debug: bool = True):
        super().__init__()
        logging.info(f"Loading Galileo encoder from {pretrained_path}...")
        self.encoder = GalileoEncoder.load_from_folder(Path(pretrained_path), device='cpu')
        # FasterRCNN reads `out_channels` from the backbone to size its heads.
        self.out_channels = self.encoder.embedding_size
        self.patch_size = patch_size
        self.debug = debug
        # Resolve band names once; list.index raises ValueError on unknown names.
        self.rgb_band_indices = [SPACE_TIME_BANDS.index(band) for band in rgb_bands]
        self.projection = nn.Conv2d(self.encoder.embedding_size, self.out_channels, kernel_size=1)
        logging.info(f"Galileo backbone initialized with output channels: {self.out_channels}")

    def _trace(self, label: str, tensor: torch.Tensor) -> None:
        """Print a tensor's shape when debug tracing is enabled."""
        if self.debug:
            print(f"[Backbone] {label}: {tensor.shape}")

    def forward(self, x: torch.Tensor) -> OrderedDict:
        """Encode a ``[B, 3, H, W]`` batch and return ``OrderedDict({"0": feature_map})``."""
        self._trace("Input x", x)

        batch, _, height, width = x.shape
        data_args = {'device': x.device, 'dtype': x.dtype}
        mask_args = {'device': x.device, 'dtype': torch.long}

        # Space-time input [B, H, W, T=1, bands]: every band is zero except the
        # three that receive the image channels.
        s_t_x = torch.zeros(batch, height, width, 1, len(SPACE_TIME_BANDS), **data_args)
        s_t_x[..., self.rgb_band_indices] = x.permute(0, 2, 3, 1).unsqueeze(-2)
        self._trace("s_t_x", s_t_x)

        # Mask is 1 everywhere except the S2_RGB group, which is set to 0.
        # NOTE(review): presumably 0 means "visible to the encoder" and 1 means
        # "masked out" - confirm against galileo.py.
        s_t_m = torch.ones(batch, height, width, 1, len(SPACE_TIME_BANDS_GROUPS_IDX), **mask_args)
        s2_rgb_group_idx = list(SPACE_TIME_BANDS_GROUPS_IDX.keys()).index('S2_RGB')
        s_t_m[..., s2_rgb_group_idx] = 0
        self._trace("s_t_m", s_t_m)

        # The remaining modalities (space-only, time-only, static) carry no data:
        # all-zero values with all-one masks.
        sp_x = torch.zeros(batch, height, width, len(SPACE_BANDS), **data_args)
        t_x = torch.zeros(batch, 1, len(TIME_BANDS), **data_args)
        st_x = torch.zeros(batch, len(STATIC_BANDS), **data_args)
        sp_m = torch.ones(batch, height, width, len(SPACE_BAND_GROUPS_IDX), **mask_args)
        t_m = torch.ones(batch, 1, len(TIME_BAND_GROUPS_IDX), **mask_args)
        st_m = torch.ones(batch, len(STATIC_BAND_GROUPS_IDX), **mask_args)
        # Single timestep with a fixed month value of 6.
        # NOTE(review): verify whether the encoder expects 0- or 1-indexed months.
        months = torch.full((batch, 1), 6, device=x.device, dtype=torch.long)

        # Only the first return value (space-time tokens) is needed; star-unpacking
        # avoids hard-coding how many values the encoder returns.
        s_t_out, *_ = self.encoder(
            s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m, months,
            patch_size=self.patch_size, add_layernorm_on_exit=True
        )
        self._trace("s_t_out", s_t_out)  # [B, H', W', 1, groups, C] per the recorded trace

        # Keep the single timestep and the S2_RGB group -> [B, H', W', C].
        feature_tokens = s_t_out[:, :, :, 0, s2_rgb_group_idx, :]
        self._trace("feature_tokens", feature_tokens)

        # Channels-first layout expected by the detection heads -> [B, C, H', W'].
        feature_map = feature_tokens.permute(0, 3, 1, 2).contiguous()
        self._trace("feature_map", feature_map)

        projected_map = self.projection(feature_map)
        self._trace("projected_map", projected_map)

        return OrderedDict([("0", projected_map)])


In [14]:
def create_model(weights_path: str, num_classes: int,
                 min_size: int = 800, max_size: int = 1333) -> nn.Module:
    """Build a Faster R-CNN detector on top of the Galileo backbone.

    Args:
        weights_path: Folder containing the pretrained Galileo encoder.
        num_classes: Number of detector output classes (torchvision counts the
            background as one of them).
        min_size: Target size of the shorter image side used by the detector's
            internal resize. The default (800) is torchvision's own default, which
            is why a 96x96 input reaches the backbone as 800x800. Pass the native
            chip size to avoid that upscaling.
        max_size: Upper bound for the longer image side after the internal resize
            (torchvision default 1333).

    Returns:
        The assembled ``FasterRCNN`` model.
    """
    backbone = GalileoBackboneWrapper(pretrained_path=weights_path)
    # One feature level, so one tuple of anchor sizes with one matching ratio tuple.
    anchor_sizes = ((16, 32, 64, 96, 128),)
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    anchor_generator = AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)
    return FasterRCNN(
        backbone,
        num_classes=num_classes,
        rpn_anchor_generator=anchor_generator,
        min_size=min_size,
        max_size=max_size,
    )

In [21]:
# Random image batch with values in [0, 1), the range torchvision detection
# models expect before their internal normalisation.
dummy = torch.rand(1, 3, 96, 96)
model = create_model(weights_path="/home/rishabh.mondal/Brick-Kilns-project/ijcai_2025_kilns/iclr_2026/nano", num_classes=4)

model.eval()
with torch.no_grad():
    # Detection models take a list of [C, H, W] tensors (one per image) rather
    # than a single batched tensor.
    out = model(list(dummy))
# print("Final output:", out)


[Backbone] Input x: torch.Size([1, 3, 800, 800])
[Backbone] s_t_x: torch.Size([1, 800, 800, 1, 13])
[Backbone] s_t_m: torch.Size([1, 800, 800, 1, 7])
[Backbone] s_t_out: torch.Size([1, 100, 100, 1, 7, 128])
[Backbone] feature_tokens: torch.Size([1, 100, 100, 128])
[Backbone] feature_map: torch.Size([1, 128, 100, 100])
[Backbone] projected_map: torch.Size([1, 128, 100, 100])


In [22]:
def _summarize_output(obj):
    """Return a printable shape summary of ``obj``.

    Objects with a ``shape`` attribute yield that shape; dicts and lists/tuples
    are summarised element by element; anything else yields its type.
    """
    if hasattr(obj, "shape"):
        return obj.shape
    if isinstance(obj, dict):
        return {key: _summarize_output(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_summarize_output(item) for item in obj]
    return type(obj)


def print_shape_hook(module, input, output):
    """Forward hook that prints the shapes found in a module's output.

    Nested containers are descended into, so outputs made of lists and dicts of
    tensors show their tensor shapes instead of just ``<class 'list'>`` or
    ``<class 'dict'>``.

    Args:
        module: The module whose forward pass just ran; its class name is the tag.
        input: Positional inputs to the module (unused).
        output: The module's output.

    Returns:
        None, so the module's output is passed on unmodified.
    """
    name = module.__class__.__name__
    if isinstance(output, (list, tuple)):
        shapes = [_summarize_output(item) for item in output]
    else:
        shapes = _summarize_output(output)
    print(f"[{name}] output shape: {shapes}")

# Attach hooks to key parts of FasterRCNN.
# Handles from a previous run of this cell are removed first, so re-running it
# does not stack duplicate hooks (each duplicate would print every shape again).
for _stale_handle in globals().get("shape_hook_handles", []):
    _stale_handle.remove()

HOOKED_NAME_PARTS = ("rpn", "roi_heads", "box_head", "box_predictor")
# Keep the handles so the hooks can be detached later with `handle.remove()`.
shape_hook_handles = [
    module.register_forward_hook(print_shape_hook)
    for name, module in model.named_modules()
    if any(part in name.lower() for part in HOOKED_NAME_PARTS)
]


In [23]:
# Dump the full module hierarchy of the detector (transform, backbone, ...) for inspection.
print(model)


FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): GalileoBackboneWrapper(
    (encoder): Encoder(
      (blocks): ModuleListWithInit(
        (0-3): 4 x Block(
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (attn): Attention(
            (q): Linear(in_features=128, out_features=128, bias=True)
            (k): Linear(in_features=128, out_features=128, bias=True)
            (v): Linear(in_features=128, out_features=128, bias=True)
            (q_norm): Identity()
            (k_norm): Identity()
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=128, out_features=128, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): Identity()
          (drop_path): DropPath()
          (norm2): LayerNorm((128,), eps=1e-05, ele

In [25]:
import torch
import torch.nn as nn
from collections import OrderedDict
from pathlib import Path
import logging

from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.detection.transform import GeneralizedRCNNTransform

# --- Your Galileo imports must already work ---
# from galileo import (
#     Encoder as GalileoEncoder,
#     SPACE_TIME_BANDS, SPACE_TIME_BANDS_GROUPS_IDX,
#     SPACE_BANDS, SPACE_BAND_GROUPS_IDX,
#     TIME_BANDS, TIME_BAND_GROUPS_IDX,
#     STATIC_BANDS, STATIC_BAND_GROUPS_IDX
# )

# ===== Backbone (optional debug prints; explicit RGB -> Sentinel-2 band mapping) =====
class GalileoBackboneWrapper(nn.Module):
    """Expose a pretrained Galileo encoder as a single-level torchvision detection backbone.

    A 3-channel image is packed into Galileo's space-time input tensor, only the
    ``S2_RGB`` band group is unmasked, and the encoder's ``S2_RGB`` tokens are
    returned as a channels-first feature map under the key ``"0"``.

    Args:
        pretrained_path: Folder handed to ``GalileoEncoder.load_from_folder``.
        patch_size: Patch size forwarded to the encoder on every forward call.
        debug: When True, print the shape of every intermediate tensor.
        rgb_bands: Sentinel-2 band names that input channels 0, 1 and 2 are
            written to, in that order. Sentinel-2 B4/B3/B2 are red/green/blue,
            so the default matches RGB-ordered input (the channel order implied
            by the ImageNet mean/std used in the detector's transform). Pass
            ``("B2", "B3", "B4")`` to reproduce the previous channel placement.

    Raises:
        ValueError: If a name in ``rgb_bands`` is not in ``SPACE_TIME_BANDS``.
    """

    def __init__(self, pretrained_path: str, patch_size: int = 8, debug: bool = False,
                 rgb_bands=("B4", "B3", "B2")):
        super().__init__()
        self.debug = debug
        logging.info(f"Loading Galileo encoder from {pretrained_path}...")
        self.encoder = GalileoEncoder.load_from_folder(Path(pretrained_path), device='cpu')
        # FasterRCNN reads `out_channels` from the backbone to size its heads.
        self.out_channels = self.encoder.embedding_size
        self.patch_size = patch_size
        # Resolve band names once; list.index raises ValueError on unknown names.
        self.rgb_band_indices = [SPACE_TIME_BANDS.index(band) for band in rgb_bands]
        self.projection = nn.Conv2d(self.encoder.embedding_size, self.out_channels, kernel_size=1)
        logging.info(f"Galileo backbone initialized with output channels: {self.out_channels}")

    def _trace(self, label: str, tensor: torch.Tensor) -> None:
        """Print a tensor's shape as a tuple when debug tracing is enabled."""
        if self.debug:
            print(f"[Backbone] {label}: {tuple(tensor.shape)}")

    def forward(self, x: torch.Tensor) -> OrderedDict:
        """Encode a ``[B, 3, H, W]`` batch and return ``OrderedDict({"0": feature_map})``."""
        self._trace("Input x", x)

        b, _, h, w = x.shape
        data_args = {'device': x.device, 'dtype': x.dtype}
        mask_args = {'device': x.device, 'dtype': torch.long}

        # Space-time input [B, H, W, T=1, bands]: every band is zero except the
        # three that receive the image channels.
        s_t_x = torch.zeros(b, h, w, 1, len(SPACE_TIME_BANDS), **data_args)
        s_t_x[..., self.rgb_band_indices] = x.permute(0, 2, 3, 1).unsqueeze(-2)
        self._trace("s_t_x", s_t_x)

        # Mask is 1 everywhere except the S2_RGB group, which is set to 0.
        # NOTE(review): presumably 0 means "visible to the encoder" and 1 means
        # "masked out" - confirm against galileo.py.
        s_t_m = torch.ones(b, h, w, 1, len(SPACE_TIME_BANDS_GROUPS_IDX), **mask_args)
        s2_rgb_group_idx = list(SPACE_TIME_BANDS_GROUPS_IDX.keys()).index('S2_RGB')
        s_t_m[..., s2_rgb_group_idx] = 0
        self._trace("s_t_m", s_t_m)

        # The remaining modalities (space-only, time-only, static) carry no data:
        # all-zero values with all-one masks.
        sp_x = torch.zeros(b, h, w, len(SPACE_BANDS), **data_args)
        t_x = torch.zeros(b, 1, len(TIME_BANDS), **data_args)
        st_x = torch.zeros(b, len(STATIC_BANDS), **data_args)
        sp_m = torch.ones(b, h, w, len(SPACE_BAND_GROUPS_IDX), **mask_args)
        t_m = torch.ones(b, 1, len(TIME_BAND_GROUPS_IDX), **mask_args)
        st_m = torch.ones(b, len(STATIC_BAND_GROUPS_IDX), **mask_args)
        # Single timestep with a fixed month value of 6; adjust if 0/1-indexed months differ.
        months = torch.full((b, 1), 6, device=x.device, dtype=torch.long)

        # Only the first return value (space-time tokens) is needed.
        s_t_out, *_ = self.encoder(
            s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m, months,
            patch_size=self.patch_size, add_layernorm_on_exit=True
        )
        self._trace("s_t_out", s_t_out)  # [B, H', W', 1, Groups, C]

        # Keep the single timestep and the S2_RGB group.
        feature_tokens = s_t_out[:, :, :, 0, s2_rgb_group_idx, :]  # [B, H', W', C]
        self._trace("feature_tokens", feature_tokens)

        # Channels-first layout expected by the detection heads.
        feature_map = feature_tokens.permute(0, 3, 1, 2).contiguous()  # [B, C, H', W']
        self._trace("feature_map", feature_map)

        projected_map = self.projection(feature_map)  # [B, C(out), H', W']
        self._trace("projected_map", projected_map)

        return OrderedDict([("0", projected_map)])


# ===== Patched create_model: caps the detector's internal resize at img_size (default 128) =====
def create_model(
    weights_path: str,
    num_classes: int,
    img_size: int = 128,
    patch_size: int = 8,
    debug_backbone: bool = False,
    anchor_sizes: tuple = (16, 32, 64),
    aspect_ratios: tuple = (0.5, 1.0, 2.0),
) -> nn.Module:
    """
    Builds Faster R-CNN with:
      - GalileoBackboneWrapper as a single-level backbone
      - the internal resize configured with min_size = max_size = img_size
        (instead of torchvision's default 800 / 1333)
      - anchors chosen for small images

    Note that the internal transform still rescales each image (preserving its
    aspect ratio) so that its sides fit ``img_size``, and pads the batch; only
    square ``img_size`` x ``img_size`` inputs pass through without resizing.

    Args:
        weights_path: Folder containing the pretrained Galileo encoder.
        num_classes: Number of detector output classes (torchvision counts the
            background as one of them).
        img_size: Value used for both ``min_size`` and ``max_size`` of the transform.
        patch_size: Patch size given to the backbone; the recorded trace shows a
            128x128 input producing a 16x16 feature map with patch_size=8.
        debug_backbone: Forwarded to the backbone's ``debug`` flag.
        anchor_sizes: Anchor side lengths in input-image pixels for the single
            feature level. Adjust if your objects are much smaller/larger.
        aspect_ratios: Anchor aspect ratios for the single feature level.

    Returns:
        The assembled ``FasterRCNN`` model.
    """
    # Backbone
    backbone = GalileoBackboneWrapper(pretrained_path=weights_path, patch_size=patch_size, debug=debug_backbone)

    # The backbone returns one feature map, so AnchorGenerator needs exactly one
    # tuple of sizes and one tuple of ratios.
    anchor_generator = AnchorGenerator(
        sizes=(tuple(anchor_sizes),),
        aspect_ratios=(tuple(aspect_ratios),),
    )

    # Configure the transform through the constructor rather than building the
    # default one and overwriting `model.transform` afterwards.
    # NOTE(review): ImageNet mean/std assume inputs in [0, 1]; whether these
    # statistics suit the Galileo encoder's pretraining should be confirmed.
    model = FasterRCNN(
        backbone,
        num_classes=num_classes,
        rpn_anchor_generator=anchor_generator,
        min_size=img_size,   # shorter side
        max_size=img_size,   # longer side
        image_mean=[0.485, 0.456, 0.406],
        image_std=[0.229, 0.224, 0.225],
        # box_nms_thresh, rpn_nms_thresh, etc. can be tuned later
    )

    return model


# ===== Quick smoke test (run this once to verify shapes) =====
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = create_model(
        weights_path="/home/rishabh.mondal/Brick-Kilns-project/ijcai_2025_kilns/iclr_2026/nano",
        num_classes=2,
        img_size=128,
        patch_size=8,
        debug_backbone=True,   # set False to silence prints
    ).to(device).eval()

    # Values in [0, 1), the range torchvision detection models expect.
    dummy = torch.rand(1, 3, 128, 128, device=device)
    with torch.no_grad():
        # Detection models take a list of [C, H, W] tensors, one per image;
        # the input stays at 128×128 internally.
        outputs = model(list(dummy))

    # In eval mode the detector returns one prediction dict per input image.
    assert len(outputs) == dummy.shape[0], "expected one prediction dict per image"
    assert {"boxes", "labels", "scores"} <= set(outputs[0]), "missing expected prediction keys"
    print("OK. Forward pass complete.")


[Backbone] Input x: (1, 3, 128, 128)
[Backbone] s_t_x: (1, 128, 128, 1, 13)
[Backbone] s_t_m: (1, 128, 128, 1, 7)
[Backbone] s_t_out: (1, 16, 16, 1, 7, 128)
[Backbone] feature_tokens: (1, 16, 16, 128)
[Backbone] feature_map: (1, 128, 16, 16)
[Backbone] projected_map: (1, 128, 16, 16)
OK. Forward pass complete.
