In [1]:
import os
import numpy as np
from lsst.daf.butler import Butler
import lsst.geom as geom

def get_one_deep_coadd_ref(butler, ra, dec, band, datasetType="deep_coadd"):
    refs = list(butler.query_datasets(
        datasetType,
        where="band.name = band AND patch.region OVERLAPS POINT(ra, dec)",
        bind={"band": band, "ra": ra, "dec": dec},
        with_dimension_records=True,
        order_by=["patch.tract"],
    ))
    if not refs:
        raise RuntimeError(f"No {datasetType} found for band={band} at ra,dec={ra},{dec}")
    return refs[0]

def load_patch_exposures(
    ra, dec,
    bands=("u","g","r","i","z","y"),
    repo="dp1",
    collection="LSSTComCam/DP1",
    datasetType="deep_coadd",
):
    butler = Butler(repo, collections=collection)

    ref0 = get_one_deep_coadd_ref(butler, ra, dec, bands[0], datasetType=datasetType)

    # Convert DataCoordinate -> plain python dict safely
    base_dataId = dict(ref0.dataId.mapping)


    exps = {}
    for b in bands:
        dataId = dict(base_dataId)          # normal dict copy
        dataId["band"] = b                  # override just the band
        exps[b] = butler.get(datasetType, dataId=dataId)

    wcs_full = exps[bands[0]].getWcs()
    return exps, wcs_full, base_dataId


def tile_patch_and_save(
    exps, wcs_full,
    out_dir,
    tile_size=512,
    stride=256,
    bands=("u","g","r","i","z","y"),
    max_tiles=None,
):
    os.makedirs(out_dir, exist_ok=True)

    # assume all bands share same array shape
    H, W = exps[bands[0]].image.array.shape
    n_saved = 0

    # iterate tile upper-left corners
    for y0 in range(0, H - tile_size + 1, stride):
        for x0 in range(0, W - tile_size + 1, stride):
            bbox = geom.BoxI(geom.PointI(x0, y0), geom.ExtentI(tile_size, tile_size))
            # cutout-local WCS: pixel origin shifted so (0,0) corresponds to (x0,y0) in full patch
            wcs_local = wcs_full.copyAtShiftedPixelOrigin(geom.Extent2D(-x0, -y0))

            imgs = []
            vars_ = []
            masks = []
            for b in bands:
                exp = exps[b]

                # slice arrays directly (fast)
                img = exp.image.array[y0:y0+tile_size, x0:x0+tile_size].astype(np.float32)
                var = exp.variance.array[y0:y0+tile_size, x0:x0+tile_size].astype(np.float32)
                msk = exp.mask.array[y0:y0+tile_size, x0:x0+tile_size].astype(np.int32)

                imgs.append(img)
                vars_.append(var)
                masks.append(msk)

            imgs = np.stack(imgs, axis=0)   # [B,H,W]
            vars_ = np.stack(vars_, axis=0) # [B,H,W]
            masks = np.stack(masks, axis=0) # [B,H,W]

            # minimal WCS serialization: store CD + CRPIX + CRVAL from the local WCS
            # (enough to reconstruct later; you can expand this if needed)
            md = wcs_local.getFitsMetadata()
            wcs_hdr = {k: md.getScalar(k) for k in md.names()}

            fn = os.path.join(out_dir, f"tile_x{x0:05d}_y{y0:05d}.npz")
            np.savez_compressed(
                fn,
                img=imgs,
                var=vars_,
                mask=masks,
                wcs_hdr=wcs_hdr,
                x0=np.int32(x0),
                y0=np.int32(y0),
            )
            n_saved += 1
            if (max_tiles is not None) and (n_saved >= max_tiles):
                return n_saved
    return n_saved

# ---- Example: ECDFS center (rough) ----
ra_ecdfs  = 53.16   # deg (rough)
dec_ecdfs = -28.10    # deg (rough)

bands = ("u","g","r","i","z","y")  # start small; add u,z later
exps, wcs_full, dataId0 = load_patch_exposures(
    ra_ecdfs, dec_ecdfs, bands=bands,
    repo="dp1", collection="LSSTComCam/DP1",
    datasetType="deep_coadd",
)

print("Using patch:", dataId0)

n = tile_patch_and_save(
    exps, wcs_full,
    out_dir="../data/rubin_tiles_ecdfs",
    tile_size=512,
    stride=256,
    bands=bands,
    max_tiles=200,   # just to start
)
print("Saved tiles:", n)


Using patch: {'band': 'u', 'skymap': 'lsst_cells_v1', 'tract': 5063, 'patch': 14}
Saved tiles: 144


# Adding Euclid and all other Rubin bands:

In [5]:
import os, glob
import numpy as np
import astropy.units as u
from astropy.io.fits import Header
from astropy.wcs import WCS
from astropy.coordinates import SkyCoord
from astroquery.ipac.irsa import Irsa

from astropy.nddata import Cutout2D
from astropy.io import fits
import fsspec
from tqdm.auto import tqdm


# ---- helpers ----
def wcs_from_hdr_dict(wcs_hdr: dict) -> WCS:
    hdr = Header()
    for k, v in wcs_hdr.items():
        if isinstance(v, np.generic):
            v = v.item()
        hdr[k] = v
    return WCS(hdr)

def wcs_to_hdr_dict(wcs: WCS) -> dict:
    hdr = wcs.to_header(relax=True)
    return {k: hdr[k] for k in hdr.keys()}

def rubin_tile_center_radec(rubin_npz_path: str):
    f = np.load(rubin_npz_path, allow_pickle=True)
    img = f["img"]  # [B,H,W]
    H, W = img.shape[-2], img.shape[-1]
    w = wcs_from_hdr_dict(f["wcs_hdr"].item())
    cx, cy = (W - 1) / 2.0, (H - 1) / 2.0
    ra, dec = w.all_pix2world(cx, cy, 0)
    return float(ra), float(dec), (H, W)

# ---- Euclid cutout loader (FIXED: query_sia may already return a Table) ----
def load_euclid_cutouts(
    ra, dec, size_arcsec,
    bands=("VIS", "Y", "J", "H"),
    collection="euclid_DpdMerBksMosaic",
    radius_arcsec=60,
):
    coord = SkyCoord(ra=ra * u.deg, dec=dec * u.deg, frame="icrs")

    q = Irsa.query_sia(pos=(coord, radius_arcsec * u.arcsec), collection=collection)
    tab = q.to_table() if hasattr(q, "to_table") else q  # <-- FIX

    if len(tab) == 0:
        return {b: None for b in bands}, {}

    if "dataproduct_subtype" in tab.colnames:
        tab = tab[tab["dataproduct_subtype"] == "science"]

    out = {b: None for b in bands}
    wcs_out = {}

    for b in bands:
        rows = tab[tab["energy_bandpassname"] == b] if "energy_bandpassname" in tab.colnames else tab
        if len(rows) == 0:
            continue

        url = rows["access_url"][0]

        with fsspec.open(url, "rb") as f:
            with fits.open(f, memmap=False) as hdul:
                wcs0 = WCS(hdul[0].header)
                cut = Cutout2D(hdul[0].data, coord, size_arcsec * u.arcsec, wcs=wcs0)
                out[b] = np.array(cut.data, dtype=np.float32)
                wcs_out[b] = cut.wcs  # cutout-aware WCS

    return out, wcs_out


def make_euclid_tiles_from_rubin_tiles(
    rubin_dir,
    out_vis_dir="euclid_tiles_VIS",
    out_nisp_dir="euclid_tiles_NISP",
    size_arcsec=102.4,  # matches Rubin 512 * 0.2"/pix
    bands_vis=("VIS",),
    bands_nisp=("Y","J","H"),
    collection="euclid_DpdMerBksMosaic",
    radius_arcsec=120,          # search radius for SIA query
    max_tiles=50,
):
    os.makedirs(out_vis_dir, exist_ok=True)
    os.makedirs(out_nisp_dir, exist_ok=True)

    rubin_files = sorted(glob.glob(os.path.join(rubin_dir, "*.npz")))
    if not rubin_files:
        raise ValueError(f"No Rubin NPZ tiles found in {rubin_dir}")

    n_done = 0
    
    for rpath in tqdm(rubin_files[:max_tiles], desc="Fetching Euclid cutouts", unit="tile"):
        tile_id = os.path.splitext(os.path.basename(rpath))[0]
        ra, dec, _ = rubin_tile_center_radec(rpath)

        # --- VIS cutout ---
        out_vis, wcs_vis = load_euclid_cutouts(
            ra, dec, size_arcsec,
            bands=bands_vis,
            collection=collection,
            radius_arcsec=radius_arcsec,
        )

        if out_vis.get("VIS") is None:
            print(f"[skip] {tile_id}: no VIS found at ra,dec={ra:.6f},{dec:.6f}")
            continue

        vis_img = out_vis["VIS"][None, ...].astype(np.float32)   # [1,H,W]
        vis_var = np.ones_like(vis_img, dtype=np.float32)        # placeholder
        vis_msk = np.zeros_like(vis_img, dtype=np.int32)         # placeholder

        vis_hdr = wcs_to_hdr_dict(wcs_vis["VIS"])

        np.savez_compressed(
            os.path.join(out_vis_dir, f"{tile_id}.npz"),
            img=vis_img,
            var=vis_var,
            mask=vis_msk,
            wcs_hdr=vis_hdr,
            ra_center=np.float64(ra),
            dec_center=np.float64(dec),
            size_arcsec=np.float32(size_arcsec),
            source_rubin=rpath,
        )

        # --- NISP cutouts (Y/J/H) ---
        out_n, wcs_n = load_euclid_cutouts(
            ra, dec, size_arcsec,
            bands=bands_nisp,
            collection=collection,
            radius_arcsec=radius_arcsec,
        )

        n_imgs, n_hdrs = [], []
        have_all = True
        for b in bands_nisp:
            if out_n.get(b) is None:
                have_all = False
                break
            n_imgs.append(out_n[b].astype(np.float32))
            n_hdrs.append(wcs_to_hdr_dict(wcs_n[b]))

        if not have_all:
            print(f"[warn] {tile_id}: missing one of NISP {bands_nisp} at ra,dec={ra:.6f},{dec:.6f} (still saved VIS)")
            n_done += 1
            continue

        nisp_img = np.stack(n_imgs, axis=0)
        nisp_var = np.ones_like(nisp_img, dtype=np.float32)  # placeholder
        nisp_msk = np.zeros_like(nisp_img, dtype=np.int32)   # placeholder

        np.savez_compressed(
            os.path.join(out_nisp_dir, f"{tile_id}.npz"),
            img=nisp_img,
            var=nisp_var,
            mask=nisp_msk,
            wcs_hdr=n_hdrs,  # list of dicts
            bands=np.array(list(bands_nisp)),
            ra_center=np.float64(ra),
            dec_center=np.float64(dec),
            size_arcsec=np.float32(size_arcsec),
            source_rubin=rpath,
        )

        n_done += 1
        print(f"[ok] {tile_id}: saved VIS + NISP at ra,dec={ra:.6f},{dec:.6f}")

    print("Done. Created Euclid tiles for:", n_done, "Rubin tiles")


# ---- run it ----
make_euclid_tiles_from_rubin_tiles(
    rubin_dir="../data/rubin_tiles_ecdfs",
    out_vis_dir="../data/euclid_tiles_VIS",
    out_nisp_dir="../data/euclid_tiles_NISP",
    size_arcsec=102.4,
    collection="euclid_DpdMerBksMosaic",
    radius_arcsec=120,
    max_tiles=50,
)


[ok] tile_x00000_y00000: saved VIS + NISP at ra,dec=54.018000,-28.336532


KeyboardInterrupt: 

In [1]:
# --- Rubin + Euclid tile reader (separate NPZ per survey/instrument), with safe batching ---
import os, glob
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from astropy.io.fits import Header
from astropy.wcs import WCS

# Bands (as requested)
bands_rubin  = ["u","g","r","i","z","y"]
bands_euclid = ["VIS","Y","J","H"]

def wcs_from_hdr_dict(wcs_hdr: dict) -> WCS:
    hdr = Header()
    for k, v in wcs_hdr.items():
        if isinstance(v, np.generic):
            v = v.item()
        hdr[k] = v
    return WCS(hdr)

# Rubin mask is a bitmask; set these later once you decide which planes to reject.
BAD_BITS_RUBIN = 0

def valid_from_bitmask(mask_int: torch.Tensor, bad_bits: int) -> torch.Tensor:
    if bad_bits == 0:
        return torch.ones_like(mask_int, dtype=torch.bool)
    return (mask_int & bad_bits) == 0

class SeparateSurveyTileDataset(Dataset):
    """
    Reads NPZ tiles saved like your Rubin code:
      img: [B,H,W] float32
      var: [B,H,W] float32
      mask: [B,H,W] int32
      wcs_hdr: dict (object)
      x0,y0: optional

    Works for multiple roots (Rubin + Euclid). Keeps files separate (no merging).
    """
    def __init__(self, rubin_dir=None, euclid_dir=None,
                 bands_rubin=bands_rubin, bands_euclid=bands_euclid):
        self.items = []
        if rubin_dir:
            for fn in sorted(glob.glob(os.path.join(rubin_dir, "*.npz"))):
                self.items.append(("rubin", fn))
        if euclid_dir:
            for fn in sorted(glob.glob(os.path.join(euclid_dir, "*.npz"))):
                self.items.append(("euclid", fn))

        if not self.items:
            raise ValueError("No NPZ tiles found. Check rubin_dir/euclid_dir paths.")

        self.bands_rubin = list(bands_rubin)
        self.bands_euclid = list(bands_euclid)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        survey, fn = self.items[idx]
        f = np.load(fn, allow_pickle=True)

        img = torch.from_numpy(f["img"]).float()     # [B,H,W]
        var = torch.from_numpy(f["var"]).float()
        msk = torch.from_numpy(f["mask"]).int()      # Rubin: bitmask; Euclid: whatever you saved

        wcs_hdr = f["wcs_hdr"].item()
        x0 = int(f["x0"]) if "x0" in f else 0
        y0 = int(f["y0"]) if "y0" in f else 0

        bands = self.bands_rubin if survey == "rubin" else self.bands_euclid

        return {
            "survey": survey,
            "bands": bands,
            "img": img,
            "var": var,
            "mask": msk,
            "wcs_hdr": wcs_hdr,  # keep as dict; build WCS only when needed
            "x0": x0,
            "y0": y0,
            "path": fn,
        }

def collate_tiles(batch):
    # Split by survey so we never try to stack Rubin and Euclid together
    out = {"rubin": None, "euclid": None}

    for survey in ("rubin", "euclid"):
        items = [b for b in batch if b["survey"] == survey]
        if not items:
            continue

        out[survey] = {
            "survey": survey,
            "bands": items[0]["bands"],                     # band list for this survey
            "img": torch.stack([x["img"] for x in items]),  # [N,B,H,W]
            "var": torch.stack([x["var"] for x in items]),
            "mask": torch.stack([x["mask"] for x in items]),
            "wcs_hdr": [x["wcs_hdr"] for x in items],       # list, not stacked
            "x0": torch.tensor([x["x0"] for x in items], dtype=torch.int32),
            "y0": torch.tensor([x["y0"] for x in items], dtype=torch.int32),
            "path": [x["path"] for x in items],
        }

    return out

# ------------------ Exact usage: all bands, and use Euclid VIS WCS as the reference frame ------------------

# 1) Set your tile directories (update these)
RUBIN_DIR      = "rubin_tiles_u_g_r_i_z_y"      # contains NPZ with img[6,H,W], var, mask, wcs_hdr
EUCLID_VIS_DIR = "euclid_tiles_VIS"            # contains NPZ with img[1,H,W], var, mask, wcs_hdr
EUCLID_NISP_DIR= "euclid_tiles_NISP"          # contains NPZ with img[3,H,W], var, mask, wcs_hdr

# 2) Build separate datasets so Rubin and Euclid are never tensor-stacked together by accident
ds_rubin = SeparateSurveyTileDataset(
    rubin_dir=RUBIN_DIR,
    euclid_dir=None,
    bands_rubin=bands_rubin,   # all 6: u,g,r,i,z,y
    bands_euclid=bands_euclid
)

ds_vis = SeparateSurveyTileDataset(
    rubin_dir=None,
    euclid_dir=EUCLID_VIS_DIR,
    bands_rubin=bands_rubin,
    bands_euclid=["VIS"]       # VIS only
)

ds_nisp = SeparateSurveyTileDataset(
    rubin_dir=None,
    euclid_dir=EUCLID_NISP_DIR,
    bands_rubin=bands_rubin,
    bands_euclid=["Y","J","H"] # NISP only
)

# 3) Make three loaders (you will "drive" everything off VIS)
dl_rubin = DataLoader(ds_rubin, batch_size=2, shuffle=True, num_workers=2, collate_fn=collate_tiles)
dl_vis   = DataLoader(ds_vis,   batch_size=2, shuffle=True, num_workers=2, collate_fn=collate_tiles)
dl_nisp  = DataLoader(ds_nisp,  batch_size=2, shuffle=True, num_workers=2, collate_fn=collate_tiles)

# 4) Pull one batch from each (you can zip these once filenames/tile-ids are aligned)
b_rubin = next(iter(dl_rubin))["rubin"]     # dict with img/var/mask/wcs_hdr list
b_vis   = next(iter(dl_vis))["euclid"]     # VIS batch stored under "euclid"
b_nisp  = next(iter(dl_nisp))["euclid"]    # NISP batch stored under "euclid"

print("Rubin:", b_rubin["img"].shape, b_rubin["bands"])
print("Euclid VIS:", b_vis["img"].shape, b_vis["bands"])
print("Euclid NISP:", b_nisp["img"].shape, b_nisp["bands"])

# 5) Reference WCS = Euclid VIS WCS for each sample in the VIS batch
# Build WCS objects only when needed (cheap enough for batch_size~2; cache later if needed)
vis_wcs_list = [wcs_from_hdr_dict(h) for h in b_vis["wcs_hdr"]]

# 6) Example: mapping Rubin/NISP pixels -> sky -> VIS pixel coordinates (for matching/alignment)
# (This is the correct way to "match everything to VIS" without resampling.)
from astropy.coordinates import SkyCoord
import astropy.units as u

def pix_to_sky(wcs: WCS, x: np.ndarray, y: np.ndarray) -> SkyCoord:
    ra, dec = wcs.all_pix2world(x, y, 0)
    return SkyCoord(ra*u.deg, dec*u.deg)

def sky_to_pix(wcs: WCS, sc: SkyCoord):
    x, y = wcs.all_world2pix(sc.ra.deg, sc.dec.deg, 0)
    return x, y

# Choose sample 0 in the VIS batch
wcs_vis  = vis_wcs_list[0]
wcs_rub  = wcs_from_hdr_dict(b_rubin["wcs_hdr"][0])
wcs_nisp = wcs_from_hdr_dict(b_nisp["wcs_hdr"][0])

# Example pixel (center of the Rubin tile) -> sky -> VIS pixel
Hr, Wr = b_rubin["img"].shape[-2:]
x_r, y_r = np.array([Wr/2.0]), np.array([Hr/2.0])
sc = pix_to_sky(wcs_rub, x_r, y_r)
x_v, y_v = sky_to_pix(wcs_vis, sc)

print("Rubin center maps to VIS pixel:", float(x_v[0]), float(y_v[0]))

# Same for NISP center -> VIS pixel
Hn, Wn = b_nisp["img"].shape[-2:]
x_n, y_n = np.array([Wn/2.0]), np.array([Hn/2.0])
sc2 = pix_to_sky(wcs_nisp, x_n, y_n)
x_v2, y_v2 = sky_to_pix(wcs_vis, sc2)

print("NISP center maps to VIS pixel:", float(x_v2[0]), float(y_v2[0]))



ValueError: No NPZ tiles found. Check rubin_dir/euclid_dir paths.