In [1]:
# datasets/io.py
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Tuple

import h5py
import numpy as np
from PIL import Image

# Public API of datasets/io. `read_mosaic` is included because the dataset
# module uses it alongside the other readers.
__all__ = [
    "list_by_ext",
    "build_stem_map",
    "intersect_modalities",
    "read_h5_cube",
    "read_rgb_image",
    "read_mosaic",
]


def list_by_ext(root: Path, exts: Sequence[str]) -> List[Path]:
    """
    List the regular files directly inside `root` whose suffix matches `exts`.

    Args:
        root: directory to scan (str or Path; coerced like `build_stem_map` does).
        exts: accepted suffixes including the dot (e.g. ".png"); matched
            case-insensitively.

    Returns:
        Matching file paths, sorted. Sub-directories are never returned.
    """
    wanted = {e.lower() for e in exts}  # set: O(1) membership per file
    return sorted(
        p for p in Path(root).iterdir() if p.is_file() and p.suffix.lower() in wanted
    )


def build_stem_map(root: Path, exts: Sequence[str]) -> Dict[str, Path]:
    """
    Map each file stem in `root` to its path.

    Only regular files directly inside `root` whose suffix (case-insensitive)
    appears in `exts` are considered. When several files share a stem, the one
    whose suffix comes first in `exts` wins. On a tie, the first file seen wins.
    """
    ranked = [e.lower() for e in exts]
    chosen: Dict[str, Path] = {}
    for entry in Path(root).iterdir():
        suffix = entry.suffix.lower()
        if not entry.is_file() or suffix not in ranked:
            continue
        incumbent = chosen.get(entry.stem)
        # Replace only when the new suffix has strictly higher priority (lower index).
        if incumbent is None or ranked.index(suffix) < ranked.index(incumbent.suffix.lower()):
            chosen[entry.stem] = entry
    return chosen


def intersect_modalities(maps: Mapping[str, Dict[str, Path]]) -> List[str]:
    """
    Return the stems present in ALL modalities, sorted.

    Args:
        maps: {modality_name: {stem: path}}.

    Returns:
        Sorted list of stems that are keys of every inner dict. An empty `maps`
        yields an empty list (previously this leaked a bare StopIteration).
    """
    key_sets = [set(stem_map) for stem_map in maps.values()]
    if not key_sets:
        return []
    return sorted(set.intersection(*key_sets))


def read_h5_cube(path: Path, dataset_name: str = "cube") -> np.ndarray:
    """
    Load a 3D HDF5 dataset as a float32 array in H,W,C order.

    If the leading axis has one of the known band counts (31, 61, 62, 448) and
    is smaller than the last axis, the data is treated as C,H,W and the channel
    axis is moved to the end. Otherwise the stored layout is returned unchanged.

    Raises:
        KeyError: `dataset_name` is not in the file.
        ValueError: the dataset is not 3-dimensional.
    """
    with h5py.File(path, "r") as h5:
        if dataset_name not in h5:
            raise KeyError(f"'{dataset_name}' not in {path}. Keys: {list(h5.keys())}")
        cube = np.array(h5[dataset_name], dtype=np.float32)

    if cube.ndim != 3:
        raise ValueError(f"Expected 3D cube, got {cube.shape} in {cube_path_repr(path) if False else path}")

    leading, trailing = cube.shape[0], cube.shape[-1]
    channel_first = leading in (31, 61, 62, 448) and leading < trailing
    return np.transpose(cube, (1, 2, 0)) if channel_first else cube


def read_rgb_image(path: Path) -> np.ndarray:
    """
    Read an image file as a float32 HxWx3 RGB array in [0, 1].

    Pixel values are divided by 255; there is no per-image min-max
    normalization. Non-RGB modes are converted with PIL's `convert("RGB")`.
    """
    # Open with a context manager so the file handle is released
    # deterministically. `convert` returns an independent, fully loaded image,
    # so it stays usable after the file is closed.
    with Image.open(path) as img:
        rgb = img.convert("RGB")
    return np.asarray(rgb, dtype=np.float32) / 255.0


def read_mosaic(path: Path) -> np.ndarray:
    """
    Read a mosaic saved as .npy and return it as float32.

    - A 2D (H, W) array gets a trailing channel axis, giving (H, W, 1). Arrays
      that are already 3D keep their layout.
    - If the maximum value exceeds 1.0, the data is assumed to be on an 8-bit
      scale and divided by 255. Values are not clipped, so data on a wider
      scale (e.g. 12/16-bit) will NOT end up in [0, 1].

    Raises:
        ValueError: the stored array is empty (zero elements).
    """
    arr = np.load(path).astype(np.float32)

    # `arr.max()` on a zero-size array fails with an opaque numpy reduction
    # error, so report the real problem instead (same exception type).
    if arr.size == 0:
        raise ValueError(f"Empty mosaic in {path} (shape {arr.shape})")

    if arr.ndim == 2:
        arr = arr[..., None]

    if arr.max() > 1.0:
        arr = arr / 255.0

    return arr


In [2]:
# datasets/pairing.py
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Sequence



@dataclass(frozen=True)
class ModalitySpec:
    """
    Immutable description of where one modality's files live.

    `build_index` passes both fields to `build_stem_map`. Files directly under
    `root` whose case-insensitive suffix is in `exts` are indexed by stem, and
    earlier entries in `exts` win when a stem has several matching files.
    """

    # Directory holding this modality's files.
    root: Path
    # Accepted suffixes including the leading dot, in priority order.
    exts: Sequence[str]


def build_index(
    specs: Dict[str, ModalitySpec],
    id_list_path: Optional[Path] = None,
) -> tuple[List[str], Dict[str, Dict[str, Path]]]:
    """
    Index the samples (file stems) present in ALL modalities.

    Args:
        specs: {modality_name: ModalitySpec}. Each spec's folder is indexed
            with `build_stem_map`.
        id_list_path: optional text file with one stem per line (blank lines
            are ignored). When given, only stems listed there are kept.

    Returns:
        (stems, maps):
        - `stems` is the sorted list of stems common to every modality, after
          optional filtering.
        - `maps` is {modality_name: {stem: path}} for ALL indexed files,
          including stems that were filtered out.

    Raises:
        ValueError: `specs` is empty.
        RuntimeError: no stem survives intersection/filtering.
    """
    if not specs:
        raise ValueError("build_index needs at least one modality spec.")

    maps = {name: build_stem_map(spec.root, spec.exts) for name, spec in specs.items()}
    stems = intersect_modalities(maps)

    if id_list_path is not None:
        keep = {
            line.strip()
            for line in Path(id_list_path).read_text(encoding="utf-8").splitlines()
            if line.strip()
        }
        # Filtering preserves the sorted order from intersect_modalities.
        stems = [s for s in stems if s in keep]

    if not stems:
        raise RuntimeError("No common stems across modalities (after filtering).")
    return stems, maps


In [3]:
# datasets/base.py
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

import torch
from torch.utils.data import Dataset

# Public names exported by datasets/base.
__all__ = [
    "HSIDataset",
    "JointTransform",
]


class HSIDataset(Dataset):
    """
    Minimal Dataset base modelled on torchvision's dataset semantics.

    Stores `root` plus either a joint `transforms` callable or separate
    `transform` / `target_transform` callables; the two styles are mutually
    exclusive. This base only stores the callables and never applies them.
    Subclasses provide `__getitem__` and `__len__`.
    """

    _repr_indent = 4

    def __init__(
        self,
        root: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        self.root = root
        uses_separate = transform is not None or target_transform is not None
        if uses_separate and transforms is not None:
            raise ValueError(
                "Pass either `transforms` or (`transform` and/or `target_transform`), not both."
            )
        self.transform = transform
        self.target_transform = target_transform
        self.transforms = transforms

    # subclasses implement __getitem__, __len__

    def __repr__(self) -> str:
        details = [f"Number of datapoints: {len(self)}"]
        if self.root is not None:
            details.append(f"Root location: {self.root}")
        joint = getattr(self, "transforms", None)
        if joint is not None:
            details.append(repr(joint))
        pad = " " * self._repr_indent
        return "\n".join(["Dataset " + type(self).__name__, *(pad + item for item in details)])


class JointTransform:
    """
    Wrap a single callable so it is applied to a whole sample dict at once.

    The wrapped callable receives the sample dict (e.g. inputs and target
    together) and must return a dict. The wrapper returns that result as-is.
    """

    def __init__(self, fn: Callable[[Dict[str, Any]], Dict[str, Any]]):
        # Stored unchanged; invoked on every call.
        self.fn = fn

    def __call__(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        transformed = self.fn(batch)
        return transformed

    def __repr__(self) -> str:
        return "{}({})".format(type(self).__name__, self.fn)


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# datasets/rgb_hsi.py
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Dict, Optional, Tuple

import numpy as np
import torch

__all__ = ["HyperObjectDataset"]


class HyperObjectDataset(HSIDataset):
    """
    Paired input -> hyperspectral cube dataset for the Hyper-Object challenge.

    Layout relative to `data_root`::

        <split>/mosaic/*.npy      track 1 input
        <split>/rgb_2/*.png|.jpg  track 2 input (.png preferred on stem clashes)
        <split>/hsi_61/*.h5       target cube (HDF5 dataset "cube")

    <split> is "train" or "test-public" depending on `train`, or
    "test-private" when `submission=True`. In submission mode `train` is
    ignored and no targets are indexed.

    Each item is a dict::

        {
          "input":  (1,H,W) float32 mosaic (track 1, for 2D .npy mosaics)
                    or (3,H,W) float32 RGB in [0,1] (track 2),
          "output": (C,H,W) float32 cube, or an empty tensor in submission mode,
          "id":     str sample stem,
        }

    If `transforms` is given, it is called with
    {"input_data": ..., "output_data": ..., "id": ...} and must return a dict
    containing at least "input_data" and "output_data".

    Raises:
        ValueError: `track` is not 1 or 2.
    """

    # track -> (input modality name == sub-directory, accepted extensions in priority order)
    _INPUT_MODALITIES = {
        1: ("mosaic", (".npy",)),
        2: ("rgb_2", (".png", ".jpg")),
    }

    def __init__(
        self,
        track: int,
        data_root: str,
        train: bool = True,
        transforms: Optional[Callable] = None,
        submission: bool = False,
    ) -> None:
        super().__init__(root=data_root, transforms=transforms)
        if track not in self._INPUT_MODALITIES:
            raise ValueError(
                f"Unsupported track {track!r}; expected one of {sorted(self._INPUT_MODALITIES)}."
            )
        self.track = track

        if submission:
            split = "test-private"
        else:
            split = "train" if train else "test-public"
        split_root = Path(data_root) / split

        input_name, input_exts = self._INPUT_MODALITIES[track]
        specs = {input_name: ModalitySpec(root=split_root / input_name, exts=input_exts)}
        # The private test split has no ground truth, so targets are indexed only
        # outside submission mode. This works for both tracks (track 1 +
        # submission used to hit an unbound `hsi_61_path`).
        if not submission:
            specs["hsi"] = ModalitySpec(root=split_root / "hsi_61", exts=(".h5",))
        self.ids, self._maps = build_index(specs)

    def __len__(self) -> int:
        return len(self.ids)

    @staticmethod
    def _to_chw(arr_hwc: np.ndarray) -> torch.Tensor:
        """Convert an (H, W, C) numpy array to a (C, H, W) tensor view (no copy)."""
        return torch.from_numpy(np.transpose(arr_hwc, (2, 0, 1)))

    def _load_(self, stem: str):
        """Load `(input_tensor, target_tensor)` for one sample stem."""
        if "hsi" in self._maps:
            cube_t = self._to_chw(read_h5_cube(self._maps["hsi"][stem], "cube"))
        else:
            cube_t = torch.empty(0)  # placeholder: no targets in submission mode

        if self.track == 1:
            inp = read_mosaic(self._maps["mosaic"][stem])  # (H,W,1) for 2D mosaics
        else:
            inp = read_rgb_image(self._maps["rgb_2"][stem])  # (H,W,3) float32 in [0,1]
        return self._to_chw(inp), cube_t

    def __getitem__(self, idx: int):
        """Return {"input", "output", "id"} for sample `idx`, after optional joint transforms."""
        stem = self.ids[idx]
        input_data, output_data = self._load_(stem)

        if self.transforms is not None:
            # The joint transform takes and returns a dict.
            out = self.transforms(
                {"input_data": input_data, "output_data": output_data, "id": stem}
            )
            input_data, output_data = out["input_data"], out["output_data"]

        return {
            "input": input_data,  # mosaic (track 1) or rgb_2 (track 2)
            "output": output_data,  # hsi cube, or empty tensor in submission mode
            "id": stem,
        }


In [5]:
import os
import zipfile
import numpy as np
import torch
from tqdm import tqdm
from torch.utils.data import Subset
import pandas as pd
from datetime import datetime

# from baselines import mstpp_up
# from SPECAT import SPECAT
from train_code.architecture import MST_Plus_Plus

# Inference device. Use "cuda" (or `"cuda" if torch.cuda.is_available() else "cpu"`)
# to run on a GPU.
device = "cpu"
print(f"Using device: {device}")

# Smoke-test filter: only these sample IDs are predicted and zipped.
# Set to None to process the full private test set. Any other value produces
# a PARTIAL submission.
TARGET_IDS = set(
    [
        "Category-1_a_0007",
        "Category-2_a_0009",
        "Category-3_a_0035",
        "Category-4_a_0018",
    ]
)

# Inputs: dataset root and trained checkpoint.
data_dir = "/ssd7/ICASSP_2026_Hyper-Object_Challenge/track2/dataset"
model_path = (
    "/ssd7/ICASSP_2026_Hyper-Object_Challenge/track2/MST-plus-plus/exp/"
    "mst_plus_plus_sr2x/2025_10_08_16_56_56/MSTPP_3_35.249081.pth"
)

# Outputs. The zip name is timestamped to the minute, so two runs within the
# same minute write to the same archive.
submission_files_dir = "submission/files"
submission_zip_path = "submission/zip/submission_{}.zip".format(
    datetime.now().strftime("%Y%m%d%H%M")
)

# Model / inference settings. MODEL_NAME is only used in log output;
# UPSCALE_FACTOR presumably must match how the checkpoint was trained.
MODEL_NAME = "MST_Plus_Plus"
UPSCALE_FACTOR = 2
BATCH_SIZE = 1


def create_submission():
    """
    Generate HSI predictions for the track-2 private test split and package
    them for Kaggle.

    Reads the module-level config (`data_dir`, `model_path`, `TARGET_IDS`,
    `BATCH_SIZE`, `UPSCALE_FACTOR`, `MODEL_NAME`, `device`,
    `submission_files_dir`, `submission_zip_path`).

    Side effects:
      - writes one `<id>.npz` per sample (key "cube", HxWxC, clipped to [0, 1])
        into `submission_files_dir`;
      - writes `submission.csv` (columns id, prediction=0) next to them;
      - writes a zip with the CSV and every .npz to `submission_zip_path`,
        creating its parent directory if needed.

    NOTE: when `TARGET_IDS` is not None, only those samples are predicted and
    zipped, so the archive is NOT a complete submission.
    """
    os.makedirs(submission_files_dir, exist_ok=True)
    print(f"Individual prediction files will be saved in: '{submission_files_dir}'")

    full_ds_test = HyperObjectDataset(
        data_root=f"{data_dir}",
        track=2,
        train=False,
        submission=True,
    )

    if TARGET_IDS is not None:
        print(f"Filtering dataset to {len(TARGET_IDS)} specific IDs for testing.")
        # Select via the dataset's id index. Iterating the dataset itself would
        # decode every image just to read its "id" field.
        desired_indices = [i for i, stem in enumerate(full_ds_test.ids) if stem in TARGET_IDS]
        missing = sorted(set(TARGET_IDS).difference(full_ds_test.ids))
        if missing:
            print(f"Warning: {len(missing)} TARGET_IDS not found in the test set: {missing}")
        submission_dataset = Subset(full_ds_test, desired_indices)
    else:
        print("Processing the full test dataset for final submission.")
        submission_dataset = full_ds_test

    test_loader = torch.utils.data.DataLoader(
        dataset=submission_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        pin_memory=str(device).startswith("cuda"),  # pinning only helps host->GPU copies
    )
    print(f"DataLoader created for {len(submission_dataset)} samples.")

    print(f"Loading checkpoint from: {model_path}")
    model = MST_Plus_Plus(
        in_channels=3, out_channels=61, n_feat=31, stage=3, upscale_factor=UPSCALE_FACTOR
    )
    checkpoint = torch.load(model_path, map_location=device)
    # Strip only a leading "module." (as added by DataParallel/DDP wrappers).
    # str.replace would also mangle keys containing "module." mid-name.
    prefix = "module."
    state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in checkpoint["state_dict"].items()
    }
    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    model.eval()
    # NOTE(review): only has an effect if MST_Plus_Plus reads this attribute.
    model.return_hr = True
    print(f"Model '{MODEL_NAME}' loaded and set to HR evaluation mode.")

    print("\nGenerating predictions...")
    processed_ids = []
    for data in tqdm(test_loader, desc="Generating predictions"):
        x = data["input"].float().to(device)
        with torch.no_grad():
            pred_batch = model(x).cpu().numpy()  # assumed (B, C, H, W)

        # Save every sample of the batch so any BATCH_SIZE works.
        for sample_id, pred_chw in zip(data["id"], pred_batch):
            pred_hwc = np.clip(np.transpose(pred_chw, (1, 2, 0)), 0.0, 1.0)
            output_npz_path = os.path.join(submission_files_dir, f"{sample_id}.npz")
            np.savez_compressed(output_npz_path, cube=pred_hwc)
            processed_ids.append(sample_id)

    print(f"\nAll {len(processed_ids)} predictions saved.")

    print("\nCreating submission.csv...")
    submission_df = pd.DataFrame({"id": processed_ids, "prediction": 0})
    csv_path = os.path.join(submission_files_dir, "submission.csv")
    submission_df.to_csv(csv_path, index=False)
    print(f"submission.csv created with {len(submission_df)} entries.")

    print(f"Creating submission zip file at: '{submission_zip_path}'")
    zip_dir = os.path.dirname(submission_zip_path)
    if zip_dir:
        os.makedirs(zip_dir, exist_ok=True)
    with zipfile.ZipFile(submission_zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        zf.write(csv_path, arcname="submission.csv")
        # Every listed .npz was just written above. A missing file is a real
        # error and must not silently produce an incomplete archive.
        for sid in tqdm(processed_ids, desc="Zipping .npz files"):
            filename = f"{sid}.npz"
            zf.write(os.path.join(submission_files_dir, filename), arcname=filename)

    print("\n" + "=" * 50)
    print("Submission process complete!")
    print(f"File to submit to Kaggle: {submission_zip_path}")
    print("=" * 50)


# In a notebook kernel __name__ is "__main__", so running this cell runs the pipeline.
if __name__ == "__main__":
    create_submission()


Using device: cpu
Individual prediction files will be saved in: 'submission/files'
Filtering dataset to 4 specific IDs for testing.
DataLoader created for 4 samples.
Loading checkpoint from: /ssd7/ICASSP_2026_Hyper-Object_Challenge/track2/MST-plus-plus/exp/mst_plus_plus_sr2x/2025_10_08_16_56_56/MSTPP_3_35.249081.pth
Model 'MST_Plus_Plus' loaded and set to HR evaluation mode.

Generating predictions...


Generating predictions: 100%|██████████| 4/4 [01:00<00:00, 15.14s/it]



All 4 predictions saved.

Creating submission.csv...
submission.csv created with 4 entries.
Creating submission zip file at: 'submission/zip/submission_202510082142.zip'


Zipping .npz files: 100%|██████████| 4/4 [00:24<00:00,  6.19s/it]


Submission process complete!
File to submit to Kaggle: submission/zip/submission_202510082142.zip





In [13]:
# Upload the zip built by create_submission() to the Kaggle competition.
# Uses an argument list (no shell quoting/interpolation of the message or path).
# A non-zero exit status raises instead of silently "succeeding" as `!cmd` does.
import subprocess

message = "MST++ with MRAE in 3 epochs"
if TARGET_IDS is not None:
    print(
        f"Warning: TARGET_IDS is set ({len(TARGET_IDS)} IDs); "
        f"{submission_zip_path} only contains those predictions."
    )
result = subprocess.run(
    [
        "kaggle", "competitions", "submit",
        "-c", "2026-icassp-hyper-object-challenge-track-2",
        "-f", submission_zip_path,
        "-m", message,
    ],
    capture_output=True,  # child stdout is not shown in the notebook otherwise
    text=True,
)
print(result.stdout, end="")
print(result.stderr, end="")
result.check_returncode()

100%|████████████████████████████████████████| 828M/828M [00:31<00:00, 27.7MB/s]
Successfully submitted to 2026 ICASSP Hyper-Object Challenge: Track 2