In [4]:
import os
import sys
import csv
import math
import logging
import argparse
from pathlib import Path
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
from tqdm import tqdm

from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.ops import MultiScaleRoIAlign
from torchmetrics.detection import MeanAveragePrecision

# -------------------------
# TorchGeo DOFA bits
# -------------------------
try:
    # DOFA is the model class and DOFABase16_Weights the pretrained-weight enum;
    # both are consumed by DOFABackboneWrapper below.
    from torchgeo.models.dofa import DOFA, DOFABase16_Weights
except ImportError:
    # Fail fast with an install hint instead of a bare ImportError traceback.
    # NOTE(review): sys.exit raises SystemExit; inside a notebook kernel this stops
    # the cell rather than the process — confirm that is the intended behaviour.
    print("Error: torchgeo is not installed or too old for DOFA. Try: pip install 'torchgeo>=0.7'")
    sys.exit(1)


# -------------------------
# Logging
# -------------------------
def setup_logging(log_dir: str):
    """
    Configure the root logger to write INFO+ records to `<log_dir>/training.log` and stdout.

    Args:
        log_dir: Directory for the log file; created if it does not exist.

    `force=True` replaces any handlers already attached to the root logger.
    Without it `logging.basicConfig` silently does nothing when handlers exist,
    which is the normal situation inside a Jupyter kernel or on a second call,
    so the log file would never be written.
    """
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, 'training.log')
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] - %(message)s',
        handlers=[logging.FileHandler(log_file), logging.StreamHandler(sys.stdout)],
        force=True,  # re-configure even if the root logger already has handlers
    )


# -------------------------
# CSV Logger
# -------------------------
class CSVLogger:
    """
    Minimal per-epoch metrics writer.

    Constructing the logger truncates `csv_path` and writes the header row;
    every `log` call re-opens the file in append mode and adds one row.
    """

    def __init__(self, csv_path: str):
        self.csv_path = csv_path
        self._write_row(["epoch", "train_loss", "val_map", "val_map50", "lr"], mode='w')

    def _write_row(self, row, mode):
        # newline='' lets the csv module control line endings itself.
        with open(self.csv_path, mode, newline='') as handle:
            csv.writer(handle).writerow(row)

    def log(self, epoch, train_loss, val_map, val_map50, lr):
        """Append one epoch row; a metric passed as None is recorded as NaN."""
        if val_map is None:
            val_map = float('nan')
        else:
            val_map = float(val_map)
        if val_map50 is None:
            val_map50 = float('nan')
        else:
            val_map50 = float(val_map50)
        self._write_row([epoch, train_loss, val_map, val_map50, lr], mode='a')


# -------------------------
# Dataset (PNG RGB) + YOLO‑OBB(9) -> AABB
# -------------------------
class BrickKilnDataset(Dataset):
    """
    PNG images with YOLO-OBB labels, exposed as axis-aligned boxes for detection.

    Expected layout under `root`:
        images/<stem>.png
        yolo_obb_labels/<stem>.txt   one object per line, 9 whitespace-separated tokens:
                                     `class x1 y1 x2 y2 x3 y3 x4 y4`
    The eight coordinates are scaled by the resized image width/height, i.e. they
    are treated as normalized. Each oriented box is collapsed to the axis-aligned
    box enclosing its four corners.

    Only images whose label file contains at least one usable box are kept. The
    scan and `__getitem__` share `_parse_obb_line`, so an image that passes the
    scan always yields a non-empty target.

    Args:
        root: Dataset directory containing `images/` and `yolo_obb_labels/`.
        split: Name used only for the progress-bar description.
        input_size: Images are resized to `input_size x input_size`.
    """

    def __init__(self, root: str, split: str, input_size: int = 224):
        self.root = Path(root)
        self.split = split
        self.img_dir = self.root / "images"
        self.label_dir = self.root / "yolo_obb_labels"

        self.transform = transforms.Compose([
            transforms.Resize((input_size, input_size)),
            transforms.ToTensor(),  # [0,1]
        ])

        all_files = sorted(f for f in os.listdir(self.img_dir) if f.lower().endswith(".png"))
        logging.info(f"Scanning {len(all_files)} PNGs in {self.img_dir}...")
        self.img_files = [
            img_name
            for img_name in tqdm(all_files, desc=f"Verify {split} data")
            if self._has_valid_annotations(img_name)
        ]
        logging.info(f"Found {len(self.img_files)} valid images in {self.img_dir}")

    @staticmethod
    def _parse_obb_line(line: str):
        """
        Parse one label line.

        Returns `(cls_id, xs, ys)` where `cls_id` is the file's class index + 1
        (0 is reserved for background) and `xs`/`ys` are the four corner
        coordinates as stored in the file. Returns None for lines that do not
        have 9 tokens or whose tokens are not finite numbers.
        """
        parts = line.split()
        if len(parts) != 9:
            return None
        try:
            cls_id = int(parts[0]) + 1  # 0 reserved for background
            coords = np.array([float(p) for p in parts[1:]])
        except ValueError:
            # Malformed numeric token: skip the line like any other unusable line.
            return None
        if not np.isfinite(coords).all():
            return None
        return cls_id, coords[0::2], coords[1::2]

    def _has_valid_annotations(self, img_name: str) -> bool:
        """Return True if the label file for `img_name` has at least one box with positive width and height."""
        p = self.label_dir / f"{Path(img_name).stem}.txt"
        if not p.exists():
            return False
        with open(p, 'r') as f:
            for line in f:
                parsed = self._parse_obb_line(line)
                if parsed is None:
                    continue
                _, xs, ys = parsed
                if xs.max() > xs.min() and ys.max() > ys.min():
                    return True
        return False

    def _read_target(self, img_name: str, width: int, height: int):
        """
        Build the detection target for `img_name`, scaling coordinates by `width`/`height`.

        Returns a dict with `boxes` (float32, shape [K, 4], xyxy in pixels) and
        `labels` (int64, shape [K]). `boxes` keeps shape [0, 4] when K == 0.
        """
        boxes, labels = [], []
        with open(self.label_dir / f"{Path(img_name).stem}.txt", 'r') as f:
            for line in f:
                parsed = self._parse_obb_line(line)
                if parsed is None:
                    continue
                cls_id, xs, ys = parsed
                xs, ys = xs * width, ys * height
                xmin, ymin, xmax, ymax = xs.min(), ys.min(), xs.max(), ys.max()
                # Drop degenerate boxes (zero width or height).
                if xmax > xmin and ymax > ymin:
                    boxes.append([float(xmin), float(ymin), float(xmax), float(ymax)])
                    labels.append(cls_id)
        return {
            # reshape keeps the [K, 4] contract even when no box survives
            "boxes": torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.as_tensor(labels, dtype=torch.int64),
        }

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx: int):
        """Return `(image_tensor, target)` for sample `idx`; the image is RGB, resized, and scaled to [0, 1]."""
        name = self.img_files[idx]
        # Close the file handle promptly; convert() returns a fully loaded copy.
        with Image.open(self.img_dir / name) as raw:
            img = raw.convert("RGB")
        img_tensor = self.transform(img)
        _, h, w = img_tensor.shape
        return img_tensor, self._read_target(name, w, h)


def collate_fn(batch):
    """
    Collate detection samples, discarding any sample whose target has no boxes.

    Returns the samples transposed field-wise (e.g. `(images, targets)` as
    tuples), or `(None, None)` when every sample in the batch was discarded.
    """
    kept = []
    for sample in batch:
        if sample[1]["boxes"].shape[0] > 0:
            kept.append(sample)
    if len(kept) == 0:
        return None, None
    return tuple(zip(*kept))


# -------------------------
# Utilities: interpolate DOFA positional embeddings for any img_size
# -------------------------
def _interp_pos_embed(sd_pos: torch.Tensor, new_side: int) -> torch.Tensor:
    """
    sd_pos: [1, N+1, C] from weights (N = old_side^2, old_side=14 for 224)
    Returns resized [1, newN+1, C] with bicubic interpolation (cls token kept).
    """
    cls_tok = sd_pos[:, :1, :]          # [1,1,C]
    tok     = sd_pos[:, 1:, :]          # [1,N,C]
    C = tok.shape[-1]
    old_side = int(math.sqrt(tok.shape[1]))
    tok = tok.reshape(1, old_side, old_side, C).permute(0, 3, 1, 2)        # [1,C,old,old]
    tok = F.interpolate(tok, size=(new_side, new_side), mode="bicubic", align_corners=False)
    tok = tok.permute(0, 2, 3, 1).reshape(1, new_side * new_side, C)       # [1,newN,C]
    return torch.cat([cls_tok, tok], dim=1)                                 # [1,1+newN,C]


# -------------------------
# DOFA Backbone Wrapper (manual weight load, flexible img_size)
# -------------------------
class DOFABackboneWrapper(nn.Module):
    """
    DOFA base/16 encoder wrapped as a multi-scale backbone for torchvision detectors.

    - Builds DOFA at arbitrary img_size (multiple of 16) WITHOUT loading via helper (to avoid assertions).
    - Manually loads DOFABase16 weights, resizing pos_embed if needed.
    - Exposes a 4-level feature pyramid ("0".."3") for Faster R-CNN, every level
      with `out_channels` channels.

    Args:
        image_size: Square input side in pixels; must be a multiple of 16.
        freeze_dofa: If True, all DOFA parameters get `requires_grad=False`.
            If False, gradients flow through the encoder so it is fine-tuned.

    Raises:
        ValueError: if `image_size` is not a multiple of 16.
    """
    def __init__(self, image_size: int = 224, freeze_dofa: bool = False):
        super().__init__()
        if image_size % 16 != 0:
            raise ValueError("input_size must be a multiple of 16 for DOFA base (patch_size=16).")
        self.image_size = image_size

        # Construct a bare DOFA (no weights yet)
        self.dofa = DOFA(
            img_size=image_size,
            patch_size=16,
            drop_rate=0.0,
            embed_dim=768,
            depth=12,
            num_heads=12,
            dynamic_embed_dim=128,
            num_classes=45,
            global_pool=False,     # we use tokens as features, no global pool
            mlp_ratio=4.0,
        )

        # Load pretrained weights (manual to avoid TorchGeo assert)
        sd = DOFABase16_Weights.DOFA_MAE.get_state_dict(progress=True)

        # Remove classification head/fc_norm (task-specific)
        for k in ("fc_norm.weight", "fc_norm.bias", "head.weight", "head.bias"):
            sd.pop(k, None)

        # Resize positional embeddings if needed
        with torch.no_grad():
            new_side = self.image_size // self.dofa.patch_size
            pe = sd.get("pos_embed", None)
            if pe is not None and pe.shape[1] != (new_side * new_side + 1):
                sd["pos_embed"] = _interp_pos_embed(pe, new_side)

        missing, unexpected = self.dofa.load_state_dict(sd, strict=False)
        if unexpected:
            logging.warning(f"Unexpected DOFA keys: {unexpected}")
        # fc_norm/head are missing by design (popped above); report anything else
        # so silently uninitialised encoder weights do not go unnoticed.
        other_missing = [k for k in missing if not k.startswith(("fc_norm.", "head."))]
        if other_missing:
            logging.info(f"DOFA keys not present in pretrained weights: {other_missing}")

        if freeze_dofa:
            for p in self.dofa.parameters():
                p.requires_grad = False

        self.embed_dim   = self.dofa.embed_dim     # 768
        self.patch_size  = self.dofa.patch_size    # 16
        self.grid_side   = self.image_size // self.patch_size
        self.out_channels = self.embed_dim         # REQUIRED by torchvision detectors

        # Small FPN pyramid (keep channels constant)
        self.down4 = nn.Conv2d(self.embed_dim, self.embed_dim, kernel_size=3, stride=2, padding=1)
        self.down5 = nn.Conv2d(self.embed_dim, self.embed_dim, kernel_size=3, stride=2, padding=1)
        self.down6 = nn.Conv2d(self.embed_dim, self.embed_dim, kernel_size=3, stride=2, padding=1)

        self.norm3 = nn.GroupNorm(32, self.embed_dim)
        self.norm4 = nn.GroupNorm(32, self.embed_dim)
        self.norm5 = nn.GroupNorm(32, self.embed_dim)
        self.norm6 = nn.GroupNorm(32, self.embed_dim)

        # RGB wavelengths (μm) for PIL order R,G,B
        self.rgb_wavelengths = [0.665, 0.560, 0.490]

    def _tokens_from_dofa(self, x: torch.Tensor) -> torch.Tensor:
        """
        Extract patch tokens BEFORE pooling.
        Returns tokens [B, N, C] with N=(H/patch)^2, C=embed_dim.

        This method must NOT run under `torch.no_grad()`: doing so cuts the
        autograd graph and the encoder would never receive gradients, making
        `freeze_dofa=False` ineffective. When the encoder is frozen its
        parameters have `requires_grad=False`, so no graph is recorded for it
        anyway.
        """
        w = torch.tensor(self.rgb_wavelengths, device=x.device).float()
        # Patch embedding (dynamic conv across wavelengths)
        x_tokens, _ = self.dofa.patch_embed(x, w)              # [B, N, C]
        # Add pos embed (skip cls)
        x_tokens = x_tokens + self.dofa.pos_embed[:, 1:, :]    # [B, N, C]
        # Prepend cls token
        cls_token = self.dofa.cls_token + self.dofa.pos_embed[:, :1, :]
        x = torch.cat((cls_token.expand(x_tokens.shape[0], -1, -1), x_tokens), dim=1)
        # Transformer blocks
        for block in self.dofa.blocks:
            x = block(x)
        # Remove cls
        return x[:, 1:, :]                                      # [B, N, C]

    def forward(self, x: torch.Tensor) -> "OrderedDict[str, torch.Tensor]":
        """
        Map a batch [B, 3, image_size, image_size] to four feature maps.

        "0" is the normalized token grid (stride ~16); "1".."3" are produced by
        successive stride-2 convolutions (~32, ~64, ~128).
        """
        # x: [B,3,H,W] (H=W=image_size enforced by RCNN transform)
        B, C, H, W = x.shape
        assert H == self.image_size and W == self.image_size, (
            f"expected {self.image_size}x{self.image_size} input, got {H}x{W}"
        )
        tokens = self._tokens_from_dofa(x)                      # [B, N, C]
        side = self.grid_side
        feat = tokens.permute(0, 2, 1).reshape(B, self.embed_dim, side, side)

        p3 = self.norm3(feat)            # stride ~16
        p4 = self.norm4(self.down4(p3))  # ~32
        p5 = self.norm5(self.down5(p4))  # ~64
        p6 = self.norm6(self.down6(p5))  # ~128

        return OrderedDict({"0": p3, "1": p4, "2": p5, "3": p6})


# -------------------------
# Build Faster R‑CNN
# -------------------------
def create_model(num_classes: int, image_size: int, freeze_dofa: bool = False):
    """
    Build a Faster R-CNN detector on top of the DOFA backbone.

    Args:
        num_classes: Number of output classes INCLUDING background (index 0).
        image_size: Square input side in pixels; must be a multiple of 16.
        freeze_dofa: Forwarded to DOFABackboneWrapper; defaults to False
            (encoder trainable), matching the previous hard-coded behaviour.

    Returns:
        A torchvision `FasterRCNN` whose internal transform resizes to
        `image_size` and applies identity normalization (mean 0, std 1).
    """
    backbone = DOFABackboneWrapper(image_size=image_size, freeze_dofa=freeze_dofa)

    # One anchor size per pyramid level, three aspect ratios each.
    anchor_generator = AnchorGenerator(
        sizes=((16,), (32,), (64,), (128,)),
        aspect_ratios=((0.5, 1.0, 2.0),) * 4
    )
    roi_pooler = MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'],
                                    output_size=7, sampling_ratio=2)

    model = FasterRCNN(
        backbone,
        num_classes=num_classes,
        rpn_anchor_generator=anchor_generator,
        box_roi_pool=roi_pooler,
        # lock size to avoid torchvision’s default 800‑px resize
        min_size=image_size,
        max_size=image_size,
        image_mean=[0.0, 0.0, 0.0],
        image_std=[1.0, 1.0, 1.0],
    )

    # The RCNN transform pads each batch up to a multiple of `size_divisible`
    # (32 by default). The backbone requires H == W == image_size exactly, so a
    # size such as 272 (multiple of 16, not of 32) would be padded to 288 and
    # rejected. Align the padding granularity with the backbone's patch size.
    # No effect for sizes that are already multiples of 32 (e.g. 224).
    model.transform.size_divisible = backbone.patch_size
    return model


# Instantiate the detector for inspection in the next cell: 4 classes (index 0 is
# background, see BrickKilnDataset) at 224x224 input. Constructing the backbone
# fetches the pretrained DOFA weights (download progress appears in the cell output).
detector = create_model(num_classes=4, image_size=224)

# # -------------------------
# # Train / Validate
# # -------------------------
# def train_one_epoch(model, optimizer, data_loader, device):
#     model.train()
#     total_loss = 0.0
#     steps = 0
#     for images, targets in tqdm(data_loader, desc="Training"):
#         if images is None:
#             continue
#         images = [img.to(device) for img in images]
#         targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

#         loss_dict = model(images, targets)
#         losses = sum(loss for loss in loss_dict.values())

#         optimizer.zero_grad(set_to_none=True)
#         losses.backward()
#         optimizer.step()

#         total_loss += losses.item()
#         steps += 1

#     return total_loss / max(1, steps)


# @torch.no_grad()
# def validate(model, data_loader, device):
#     model.eval()
#     metric = MeanAveragePrecision(box_format='xyxy', class_metrics=False)
#     for images, targets in tqdm(data_loader, desc="Validation"):
#         if images is None:
#             continue
#         images = [img.to(device) for img in images]
#         preds = model(images)
#         preds = [{k: v.to('cpu') for k, v in p.items()} for p in preds]
#         t_cpu = [{k: v.to('cpu') for k, v in t.items()} for t in targets]
#         metric.update(preds, t_cpu)
#     res = metric.compute()
#     return res.get('map', torch.tensor(0.)).item(), res.get('map_50', torch.tensor(0.)).item()


# def main(args):
#     os.makedirs(args.output_dir, exist_ok=True)
#     setup_logging(args.output_dir)
#     csv_logger = CSVLogger(os.path.join(args.output_dir, "results.csv"))

#     device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
#     logging.info(f"Using device: {device}")

#     # Data
#     train_dataset = BrickKilnDataset(args.train_path, 'train', args.input_size)
#     val_dataset   = BrickKilnDataset(args.val_path,   'val',   args.input_size)

#     train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
#                               num_workers=args.num_workers, collate_fn=collate_fn, pin_memory=True)
#     val_loader   = DataLoader(val_dataset,   batch_size=args.batch_size, shuffle=False,
#                               num_workers=args.num_workers, collate_fn=collate_fn, pin_memory=True)

#     # Model
#     num_classes = 4  # background + 3 kiln classes
#     model = create_model(num_classes=num_classes, image_size=args.input_size).to(device)

#     # Two‑group LR: lower on DOFA, higher on heads
#     backbone_params, head_params = [], []
#     for name, p in model.named_parameters():
#         if not p.requires_grad:
#             continue
#         if "dofa" in name:
#             backbone_params.append(p)
#         else:
#             head_params.append(p)

#     optimizer = AdamW([
#         {"params": backbone_params, "lr": args.backbone_lr},
#         {"params": head_params, "lr": args.head_lr},
#     ], weight_decay=args.weight_decay)

#     lr_scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs)

#     logging.info(f"RCNN transform min_size={model.transform.min_size}, max_size={model.transform.max_size}")

#     best_map50 = 0.0
#     for epoch in range(1, args.epochs + 1):
#         train_loss = train_one_epoch(model, optimizer, train_loader, device)
#         val_map, val_map50 = validate(model, val_loader, device)
#         lr_scheduler.step()

#         current_lr = optimizer.param_groups[0]['lr']
#         logging.info(f"Epoch {epoch} - Loss: {train_loss:.4f}, mAP: {val_map:.4f}, mAP@50: {val_map50:.4f} - lr: {current_lr:.6f}")
#         csv_logger.log(epoch, train_loss, val_map, val_map50, current_lr)

#         if val_map50 > best_map50:
#             best_map50 = val_map50
#             torch.save(model.state_dict(), os.path.join(args.output_dir, "best_model.pth"))
#             logging.info(f"Saved new best model (mAP@50={best_map50:.4f})")

#     logging.info("Training finished.")


# if __name__ == "__main__":
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--train_path', type=str, default="/home/rishabh.mondal/Brick-Kilns-project/ijcai_2025_kilns/data/processed_data/sentinel/sentinelkilndb_bechmarking_data/train")
#     parser.add_argument('--val_path',   type=str, default="/home/rishabh.mondal/Brick-Kilns-project/ijcai_2025_kilns/data/processed_data/sentinel/sentinelkilndb_bechmarking_data/val")
#     parser.add_argument('--test_path',  type=str, default="/home/rishabh.mondal/Brick-Kilns-project/ijcai_2025_kilns/data/processed_data/sentinel/sentinelkilndb_bechmarking_data/test")
#     parser.add_argument('--output_dir', type=str, default="/home/suruchi.hardaha/train_galelio/notebooks/work_dirs/dofa_train")
#     parser.add_argument('--device',     type=str, default="cuda:0")

#     # You can change this to any multiple of 16 (e.g., 224, 256, 272, ...).
#     parser.add_argument('--input_size', type=int, default=224)

#     parser.add_argument('--epochs',       type=int,   default=10)
#     parser.add_argument('--batch_size',   type=int,   default=16)
#     parser.add_argument('--num_workers',  type=int,   default=8)
#     parser.add_argument('--head_lr',      type=float, default=1e-4)
#     parser.add_argument('--backbone_lr',  type=float, default=1e-5)
#     parser.add_argument('--weight_decay', type=float, default=0.05)

#     args = parser.parse_args()
#     main(args)


Downloading: "https://hf.co/torchgeo/dofa/resolve/b8db318b64a90b9e085ec04ba8851233c5893666/dofa_base_patch16_224-a0275954.pth" to /home/rishabh.mondal/.cache/torch/hub/checkpoints/dofa_base_patch16_224-a0275954.pth
100%|██████████| 425M/425M [00:07<00:00, 57.6MB/s] 


In [5]:
detector

FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0])
      Resize(min_size=(224,), max_size=224, mode='bilinear')
  )
  (backbone): DOFABackboneWrapper(
    (dofa): DOFA(
      (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (patch_embed): DOFAEmbedding(
        (weight_generator): TransformerWeightGenerator(
          (transformer_encoder): TransformerEncoder(
            (layers): ModuleList(
              (0): TransformerEncoderLayer(
                (self_attn): MultiheadAttention(
                  (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
                )
                (linear1): Linear(in_features=128, out_features=2048, bias=True)
                (dropout): Dropout(p=False, inplace=False)
                (linear2): Linear(in_features=2048, out_features=128, bias=True)
                (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    