In [2]:
import os
from pathlib import Path
from typing import Optional, Callable, Dict, Any, List

import numpy as np
import torch
from torch.utils.data import Dataset


# ----------------------------
# helpers
# ----------------------------
def _load_stats_npz(path: Optional[Path]):
    """
    Load stats npz produced by compute_feature_stats_safe().
    Returns dict with 'columns', 'mean', 'std' or None if missing.
    """
    if path is None:
        return None
    path = Path(path)
    if not path.exists():
        return None
    z = np.load(path, allow_pickle=True)
    cols = [c for c in z["columns"]]
    mean = np.array(z["mean"]).astype(np.float32)
    std = np.array(z["std"]).astype(np.float32)
    std[~np.isfinite(std)] = 1.0
    std[std == 0] = 1.0
    return {"columns": cols, "mean": mean, "std": std}


def _align_stats_to_current_columns(stats: Dict[str, Any], current_cols: List[str]):
    """
    Build per-feature mean/std vectors aligned to current column order.
    Unknown columns -> mean=0, std=1 (no-op).
    """
    if stats is None:
        mean = np.zeros(len(current_cols), dtype=np.float32)
        std = np.ones(len(current_cols), dtype=np.float32)
        return mean, std

    ref_cols = stats["columns"]
    ref_mean = stats["mean"]
    ref_std = stats["std"]

    ref_index = {c: i for i, c in enumerate(ref_cols)}
    mean = np.zeros(len(current_cols), dtype=np.float32)
    std = np.ones(len(current_cols), dtype=np.float32)
    for j, c in enumerate(current_cols):
        i = ref_index.get(c, None)
        if i is not None:
            mean[j] = ref_mean[i]
            std[j] = ref_std[i] if (np.isfinite(ref_std[i]) and ref_std[i] != 0) else 1.0
    return mean, std


def _center_crop_chw_np(arr: np.ndarray, size: Optional[int]) -> np.ndarray:
    """
    Center-crop a CHW numpy array to (C, size, size).
    If size is None, returns unchanged.
    """
    if size is None:
        return arr
    C, H, W = arr.shape
    hh = min(size, H)
    ww = min(size, W)
    y0 = (H - hh) // 2
    x0 = (W - ww) // 2
    return arr[:, y0 : y0 + hh, x0 : x0 + ww]


def _normalize_image(arr: np.ndarray, mode: Optional[str]) -> np.ndarray:
    """
    arr: CHW numpy float32
    mode: 'median' | 'L2' | None
    """
    if mode is None:
        return arr
    x = arr.copy()
    if mode.lower() == "median":
        # per-channel median/std
        for c in range(x.shape[0]):
            plane = x[c]
            med = np.median(plane)
            plane = plane - med
            std = float(plane.std())
            if not np.isfinite(std) or std <= 1e-8:
                std = 1.0
            x[c] = plane / std
        return x
    if mode.lower() == "l2":
        denom = float(np.linalg.norm(x.ravel(), ord=2))
        if not np.isfinite(denom) or denom <= 1e-8:
            denom = 1.0
        return x / denom
    return arr


# ----------------------------
# Dataset
# ----------------------------
class MultiModalDataset(Dataset):
    """
    One row per object (manifest row). For each item:
      - Select all events with dt <= horizon (or all if horizon is None)
      - Pick the last such event -> return its image & metadata
      - Return the full event sequence up to (and including) that last event

    Returns dict:
      {
        'events':   FloatTensor [T, Fe],  (sequence until horizon)
        'events_mask': BoolTensor [T],    (all True; kept for compatibility with padding)
        'image':    FloatTensor [3, H, W] (triplet at last event)
        'metadata': FloatTensor [Fm],     (meta row at last event)
        'label':    LongTensor [],        (class id)
        'label_str': str,
        'obj_id':   str,
        'dt':       FloatTensor [T],      (optional convenience: dt sequence)
      }
    """

    def __init__(
        self,
        manifest_df,
        horizon: Optional[float] = None,
        event_stats_path: Optional[Path] = None,
        meta_stats_path: Optional[Path] = None,
        normalize_events: bool = True,
        normalize_meta: bool = True,
        image_norm: Optional[str] = "median",  # 'median' | 'L2' | None
        crop_size: Optional[int] = None,  # center crop to (crop_size, crop_size)
        augment: bool = False,
        image_transform: Optional[Callable] = None,  # e.g. torchvision transforms for CHW tensor
        return_dt: bool = True,
    ):
        self.df = manifest_df.reset_index(drop=True).copy()
        self.horizon = horizon
        self.normalize_events = normalize_events
        self.normalize_meta = normalize_meta
        self.crop_size = crop_size
        self.image_norm = image_norm
        self.augment = augment
        self.image_transform = image_transform
        self.return_dt = return_dt

        # Load stats
        self.event_stats = _load_stats_npz(event_stats_path)
        self.meta_stats = _load_stats_npz(meta_stats_path)

    def __len__(self):
        return len(self.df)

    def _select_last_idx_within_horizon(self, dt: np.ndarray) -> int:
        """
        dt: array of event times since first photometry
        Returns last index <= horizon; if horizon is None use last; if none <= horizon use 0.
        """
        if self.horizon is None or not np.isfinite(self.horizon):
            return len(dt) - 1
        m = dt <= float(self.horizon)
        if not np.any(m):
            return 0
        return int(np.where(m)[0][-1])

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        filepath = Path(row.filepath)
        z = np.load(filepath, allow_pickle=True)

        images = z["images"]  # (T, 3, H, W)
        events = z["event_data"]  # (T, Fe)
        meta = z["meta_data"]  # (T, Fm)
        ecols = [c for c in z["event_columns"]]
        mcols = [c for c in z["meta_columns"]]
        label = int(row.label)
        label_str = row.label_str if "label_str" in row else ""
        obj_id = row.object_id if "object_id" in row else filepath.stem

        # select up to horizon
        dt_idx = ecols.index("dt")
        dt = events[:, dt_idx].astype(np.float32)

        last_idx = self._select_last_idx_within_horizon(dt)
        T_h = last_idx + 1

        events_seq = events[:T_h].astype(np.float32)  # [T_h, Fe]
        meta_last = meta[last_idx].astype(np.float32)  # [Fm]
        image_last = images[last_idx].astype(np.float32)  # [3, H, W]

        # normalize events / meta (align stats by column name)
        if self.normalize_events:
            e_mean, e_std = _align_stats_to_current_columns(self.event_stats, ecols)
            # broadcast [Fe] over first dim
            mask_valid = np.isfinite(events_seq)
            events_seq = (events_seq - e_mean) / e_std

            sent_mask = events_seq == -999.0
            events_seq[~mask_valid] = 0.0
            events_seq[sent_mask] = -999.0

        if self.normalize_meta:
            m_mean, m_std = _align_stats_to_current_columns(self.meta_stats, mcols)
            mask_valid = np.isfinite(meta_last)
            meta_last = (meta_last - m_mean) / m_std
            sent_mask = meta_last == -999.0
            meta_last[~mask_valid] = 0.0
            meta_last[sent_mask] = -999.0

        # crop + image normalization
        image_last = _center_crop_chw_np(image_last, self.crop_size)
        image_last = _normalize_image(image_last, self.image_norm)

        # -> torch tensors
        events_t = torch.from_numpy(events_seq)  # [T_h, Fe]
        meta_t = torch.from_numpy(meta_last)  # [Fm]
        image_t = torch.from_numpy(image_last)  # [3, H, W]
        label_t = torch.tensor(label, dtype=torch.long)

        # optional augmentation on image
        if self.augment and self.image_transform is not None:
            # transforms expect CHW float tensor
            image_t = self.image_transform(image_t)

        sample = {
            "events": events_t,
            "events_mask": torch.ones((events_t.shape[0],), dtype=torch.bool),
            "image": image_t,
            "metadata": meta_t,
            "label": label_t,
            "label_str": label_str,
            "obj_id": str(obj_id),
        }
        if self.return_dt:
            sample["dt"] = torch.from_numpy(dt[:T_h])

        return sample

    # -------- collate utility for variable-length sequences --------
    @staticmethod
    def pad_collate(batch: List[Dict[str, Any]], pad_value: float = 0.0):
        """
        Pads 'events' to the max T in batch and stacks everything else.
        Returns:
          events: FloatTensor [B, Tmax, Fe]
          events_mask: BoolTensor [B, Tmax]
          image: FloatTensor [B, 3, H, W]
          metadata: FloatTensor [B, Fm]
          label: LongTensor [B]
          label_str: list[str]
          obj_id: list[str]
          dt: FloatTensor [B, Tmax]  (if present)
        """
        B = len(batch)
        Tmax = max(x["events"].shape[0] for x in batch)
        Fe = batch[0]["events"].shape[1]

        events_pad = batch[0]["events"].new_full((B, Tmax, Fe), pad_value)
        mask_pad = torch.zeros((B, Tmax), dtype=torch.bool)
        if "dt" in batch[0]:
            dt_pad = batch[0]["dt"].new_full((B, Tmax), pad_value)
        else:
            dt_pad = None

        images = []
        metas = []
        labels = []
        label_strs = []
        obj_ids = []

        for i, ex in enumerate(batch):
            T = ex["events"].shape[0]
            events_pad[i, :T] = ex["events"]
            mask_pad[i, :T] = True
            if dt_pad is not None:
                dt_pad[i, :T] = ex["dt"]
            images.append(ex["image"])
            metas.append(ex["metadata"])
            labels.append(ex["label"])
            label_strs.append(ex["label_str"])
            obj_ids.append(ex["obj_id"])

        images = torch.stack(images, dim=0)
        metas = torch.stack(metas, dim=0)
        labels = torch.stack(labels, dim=0)

        out = {
            "events": events_pad,
            "events_mask": mask_pad,
            "image": images,
            "metadata": metas,
            "label": labels,
            "label_str": label_strs,
            "obj_id": obj_ids,
        }
        if dt_pad is not None:
            out["dt"] = dt_pad
        return out

In [3]:
import pandas as pd
from torch.utils.data import DataLoader

# manifest produced by build_all_preprocessed(cfg)
manifest = pd.read_csv("/home/feradofogo/multimodal_preprocessed/built_all.csv")

ds = MultiModalDataset(
    manifest_df=manifest,
    horizon=100.0,  # days since first photometry
    event_stats_path=Path("/home/feradofogo/multimodal_preprocessed/feature_stats_event.npz"),
    meta_stats_path=Path("/home/feradofogo/multimodal_preprocessed/feature_stats_meta.npz"),
    image_norm="median",
    crop_size=63,  # or None to keep full cutout size
    augment=False,
    image_transform=None,  # todo: plug a torchvision pipeline
)

loader = DataLoader(ds, batch_size=16, shuffle=True, collate_fn=MultiModalDataset.pad_collate)

batch = next(iter(loader))
print(batch["events"].shape, batch["image"].shape, batch["metadata"].shape)
# should give torch.Size([16, Tmax, Fe]) torch.Size([16, 3, H, W]) torch.Size([16, Fm])

torch.Size([16, 108, 14]) torch.Size([16, 3, 63, 63]) torch.Size([16, 46])


In [5]:
batch.keys()

dict_keys(['events', 'events_mask', 'image', 'metadata', 'label', 'label_str', 'obj_id', 'dt'])

In [8]:
ds.__getitem__(0)

{'events': tensor([[-1.0967, -0.1433,  0.6397, -1.8856,  1.9521, -0.9138,  1.0431, -0.2659,
           0.0000,  0.0000,  0.0000,  0.0000, -0.6341, -0.2331],
         [-1.0953, -0.1276,  0.6397, -1.6093,  1.7139, -0.9138,  1.0431, -0.2659,
           0.0000,  0.0000,  0.0000,  0.0000, -0.6341, -0.2331],
         [-1.0952, -0.1417, -1.0048, -1.7877,  1.6806,  1.0943, -0.9586, -0.2659,
              nan,     nan,  0.0000,  0.0000,  1.5770, -0.2331],
         [-1.0939, -0.1278,  0.6397, -1.4104,  0.7560, -0.9138,  1.0431, -0.2659,
           0.0000,  0.0000,  0.0000,  0.0000, -0.6341, -0.2331],
         [-1.0939, -0.1423, -1.0048, -1.6770,  1.4607,  1.0943, -0.9586, -0.2659,
              nan,     nan,  0.0000,  0.0000,  1.5770, -0.2331],
         [-1.0927, -0.1288, -1.0048, -1.4137,  1.4790,  1.0943, -0.9586, -0.2659,
              nan,     nan,  0.0000,  0.0000,  1.5770, -0.2331],
         [-1.0926, -0.1421,  0.6397, -1.2048,  0.9509, -0.9138,  1.0431, -0.2659,
           0.0000,  0.0000

In [10]:
filepath = Path("/home/feradofogo/multimodal_preprocessed/all/ZTF19aaxzdtw.npz")
z = np.load(filepath, allow_pickle=True)

images = z["images"]  # (T, 3, H, W)
events = z["event_data"]  # (T, Fe)
meta = z["meta_data"]  # (T, Fm)
ecols = [c for c in z["event_columns"]]
mcols = [c for c in z["meta_columns"]]
# label  = int(row.label)
# label_str = row.label_str if 'label_str' in row else ''
# obj_id = row.object_id if 'object_id' in row else filepath.stem

# select up to horizon
dt_idx = ecols.index("dt")
dt = events[:, dt_idx].astype(np.float32)

In [13]:
# todo: keep only dt, dt_prev, logflux, logflux_err, build one-hot band; drop colors

array(['dt', 'dt_prev', 'band_id', 'logflux', 'logflux_err', 'band_ztfg',
       'band_ztfr', 'band_ztfi', 'g_r', 'g_r_err', 'r_i', 'r_i_err',
       'has_g_r', 'has_r_i'], dtype='<U11')

In [16]:
z["event_columns"]

array(['dt', 'dt_prev', 'band_id', 'logflux', 'logflux_err', 'band_ztfg',
       'band_ztfr', 'band_ztfi', 'g_r', 'g_r_err', 'r_i', 'r_i_err',
       'has_g_r', 'has_r_i'], dtype='<U11')

In [51]:
col_to_idx = {name: i for i, name in enumerate(ecols)}
cols_to_keep = ["dt", "dt_prev", "logflux", "logflux_err"]
band_cols = ["band_ztfg", "band_ztfr", "band_ztfi"]
idxs = [col_to_idx[c] for c in cols_to_keep]

events_subset = events[:, idxs].astype(np.float32)

In [52]:
events_subset.shape

(67, 7)

In [55]:
events_subset[2]

array([1.0430189 , 0.09318638, 1.351563  , 0.0906086 , 1.        ,
       0.        , 0.        ], dtype=float32)

In [43]:
_BAND_OH = np.eye(3, dtype=np.float32)

In [44]:
_BAND_OH

array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)

In [48]:
_BAND_OH = np.eye(3, dtype=np.float32)


def build_event_tensor(arr):
    dt = np.log1p(arr[:, 0])
    dt_prev = np.log1p(arr[:, 1])
    logf, logfe = arr[:, 3], arr[:, 4]
    oh = _BAND_OH[arr[:, 2].astype(np.int64)]
    vec4 = np.stack([dt, dt_prev, logf, logfe], 1)
    return torch.from_numpy(np.concatenate([vec4, oh], 1))

array([[1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.],
       [0., 1., 0.],
       [0., 0., 1.],
       [0., 1., 0.],
       [1., 0., 0.],
       [0., 0., 1.]], dtype=float32)