# Tile Rubin tract 5063 + fetch Euclid counterparts (quota-friendly)

Downloads tile-by-tile: for each patch of the tract, cuts overlapping Rubin tiles,
fetches the Euclid counterpart of each tile, plots progress, archives the tile into a
`.tar.gz`, and pauses so you can download and delete the archive before continuing.

Progress is tracked in `tile_progress.log` for safe resume.


In [None]:
import os, glob, io, tarfile
import numpy as np
import matplotlib.pyplot as plt
from lsst.daf.butler import Butler
import lsst.geom as geom
import lsst.afw.geom as afwGeom
from lsst.daf.base import PropertySet
from astroquery.ipac.irsa import Irsa
from astropy.coordinates import SkyCoord
from astropy import units as u
from astropy.table import Table
from astropy.io import fits
from astropy.nddata import Cutout2D
from astropy.wcs import WCS
import fsspec

# ---- Config ----
# What to tile: Butler repository, collection, dataset type, skymap and tract.
TRACT = 5063
SKYMAP = "lsst_cells_v1"
REPO = "dp1"
COLLECTION = "LSSTComCam/DP1"
DATASETTYPE = "deep_coadd"

# Where outputs go (paths relative to the notebook's working directory).
# DATA_ROOT is absolute; archive member names are made relative to it.
OUT_RUBIN_ROOT = "../data/rubin_tiles_tract5063"
OUT_EUCLID_DIR = "../data/euclid_tiles_tract5063"
ARCHIVE_DIR = "../data/tile_archives"
PROGRESS_FILE = "../data/tile_progress.log"
DATA_ROOT = os.path.abspath("../data")

# Create every output directory up front (no-op if it already exists).
for _out_dir in (OUT_RUBIN_ROOT, OUT_EUCLID_DIR, ARCHIVE_DIR):
    os.makedirs(_out_dir, exist_ok=True)

# Bands requested from each survey.
bands_rubin = ("u", "g", "r", "i", "z", "y")
bands_euclid = ("VIS", "Y", "J", "H")

# Tiling geometry in pixels: a stride of half the tile size gives 50% overlap.
TILE_SIZE = 512
STRIDE = 256
# Side length of the square Euclid cutout centred on each Rubin tile.
EUCLID_SIZE_ARCSEC = 105.0
ROW_BATCH = 5   # patches per batch if row grouping fails

butler = Butler(REPO, collections=COLLECTION)


# ---- Rubin helpers ----

def get_patches_in_tract(butler, tract, band="r"):
    """Return the sorted patch IDs of *tract* that have a ``DATASETTYPE`` in *band*.

    Parameters
    ----------
    butler : Butler
        Butler used for the query (its default collections are searched).
    tract : int
        Tract ID to restrict to.
    band : str
        Band used to decide whether a patch "has data".

    Returns
    -------
    list
        Sorted, de-duplicated ``patch`` values from the matching dataset refs.
    """
    # Bind identifiers deliberately differ from the dimension names. With
    # "tract = tract" the parser may substitute the bind value on both sides,
    # giving "5063 = 5063", so the constraint would silently filter nothing.
    refs = butler.query_datasets(
        DATASETTYPE,
        where="tract = tract_id AND band = band_name AND skymap = skymap_name",
        bind={"tract_id": tract, "band_name": band, "skymap_name": SKYMAP},
        with_dimension_records=True,
    )
    return sorted({ref.dataId["patch"] for ref in refs})


def load_patch_exposures(butler, tract, patch, bands=bands_rubin):
    """Load the coadd exposure of each requested band for a single patch.

    Any band whose ``butler.get`` call raises is reported and skipped.

    Returns
    -------
    exps : dict
        Band name -> exposure for every band that loaded.
    wcs
        The object returned by ``getWcs()`` on the first band that loaded.
    bands_present : tuple of str
        Bands that loaded successfully, in request order.

    Raises
    ------
    RuntimeError
        If none of the requested bands could be loaded.
    """
    exps = {}
    loaded = []
    for b in bands:
        data_id = {"tract": tract, "patch": patch, "band": b, "skymap": SKYMAP}
        try:
            exposure = butler.get(DATASETTYPE, dataId=data_id)
        except Exception as e:
            print(f"    skipping band {b} for patch {patch}: {e}")
            continue
        exps[b] = exposure
        loaded.append(b)

    if len(loaded) == 0:
        raise RuntimeError(f"No bands found for patch {patch}")

    reference_wcs = exps[loaded[0]].getWcs()
    return exps, reference_wcs, tuple(loaded)


def wcs_to_hdr_dict(wcs_lsst):
    """Flatten the FITS metadata of an LSST WCS into a plain ``{key: value}`` dict.

    Each key from ``getFitsMetadata().names()`` is mapped to
    ``getScalar(key)``, so the result can be stored with ``np.savez``.
    """
    metadata = wcs_lsst.getFitsMetadata()
    header = {}
    for key in metadata.names():
        header[key] = metadata.getScalar(key)
    return header


# ---- Euclid helpers ----

def sanitize_rms(rms, huge=1e10):
    """Return a float32 copy of *rms* with unusable values replaced by NaN.

    A value is kept only if it is finite, strictly positive and not above
    *huge*; everything else becomes NaN. The input array is never modified.
    """
    cleaned = rms.astype(np.float32)  # astype copies by default
    valid = np.isfinite(cleaned) & (cleaned > 0) & (cleaned <= huge)
    cleaned[~valid] = np.nan
    return cleaned


def robust_imshow(ax, img, title="", p=(1, 99)):
    """Draw *img* on *ax* with a percentile colour stretch and no ticks.

    The display range is the ``p`` (low, high) NaN-ignoring percentiles of the
    image. When *img* is ``None`` the axis is only titled "(missing)" and
    hidden.
    """
    if img is None:
        ax.set_title(f"{title} (missing)")
        ax.axis("off")
        return

    vmin, vmax = np.nanpercentile(img, p)
    ax.imshow(img, origin="lower", vmin=vmin, vmax=vmax)
    ax.set_title(title)
    for clear_ticks in (ax.set_xticks, ax.set_yticks):
        clear_ticks([])


def load_euclid_cutouts(ra, dec, size_arcsec, bands=bands_euclid,
                        collection="euclid_DpdMerBksMosaic", radius_arcsec=60):
    """Fetch Euclid mosaic cutouts (image and variance) around a sky position.

    Queries IRSA's SIA service for products of *collection* within
    *radius_arcsec* of (ra, dec). For each band in *bands*, cuts a
    *size_arcsec* square around the position from the first "science" row
    and, if present, the first "noise" (RMS) row of that band.

    The files are opened lazily through fsspec and sliced with
    ``ImageHDU.section``, so only the bytes covering the cutout are
    transferred. Previously the whole mosaic was downloaded for every band of
    every tile.

    Parameters
    ----------
    ra, dec : float
        ICRS position in degrees.
    size_arcsec : float
        Side length of the square cutout.
    bands : iterable of str
        Values matched against the ``energy_bandpassname`` column.
    collection : str
        IRSA SIA collection name.
    radius_arcsec : float
        Search radius for the SIA query.

    Returns
    -------
    out_img : dict
        Band -> float32 cutout, or ``None`` if no science row was found.
    out_var : dict
        Band -> float32 variance (sanitized RMS squared; invalid RMS -> NaN),
        or ``None`` if the image or noise row was missing.
    wcs_out : dict
        Band -> astropy ``WCS`` of the science cutout (bands with an image only).

    Raises
    ------
    Exception
        Network, FITS or ``Cutout2D`` errors propagate to the caller (for
        example when the position does not overlap the chosen mosaic).
    """
    coord = SkyCoord(ra=ra * u.deg, dec=dec * u.deg, frame="icrs")
    tab = Irsa.query_sia(pos=(coord, radius_arcsec * u.arcsec), collection=collection)
    if not isinstance(tab, Table):
        tab = tab.to_table()
    cutout_size = size_arcsec * u.arcsec

    out_img = {b: None for b in bands}
    out_var = {b: None for b in bands}
    wcs_out = {}

    def first_row(band, subtype):
        """Return the first table row for (band, subtype), or None."""
        m = (tab["energy_bandpassname"] == band) & (tab["dataproduct_subtype"] == subtype)
        rows = tab[m]
        return rows[0] if len(rows) else None

    def read_cutout(url):
        """Cut the square around *coord* from the primary HDU at *url*.

        ``hdu.section`` fetches only the required byte ranges over fsspec.
        """
        with fits.open(str(url), use_fsspec=True) as hdul:
            hdu = hdul[0]
            cut = Cutout2D(hdu.section, coord, cutout_size, wcs=WCS(hdu.header))
            return np.array(cut.data, dtype=np.float32), cut.wcs

    for b in bands:
        row_sci = first_row(b, "science")
        if row_sci is None:
            continue
        out_img[b], wcs_out[b] = read_cutout(row_sci["access_url"])

        row_rms = first_row(b, "noise")
        if row_rms is None:
            continue
        rms, _ = read_cutout(row_rms["access_url"])
        rms = sanitize_rms(rms, huge=1e10)
        out_var[b] = rms * rms
    return out_img, out_var, wcs_out


# ---- Geometry & grouping helpers ----

def _bbox_sky_outline(bbox, wcs, n_pts=40):
    """Trace the perimeter of a pixel-space box on the sky.

    Samples ``n_pts`` points along each edge of *bbox*, in the order bottom,
    right, top, left. Corner points appear twice. Each point is converted with
    ``wcs.pixelToSky``.

    Returns
    -------
    (ra, dec) : tuple of two lists of float
        Sky coordinates of the samples, in degrees.
    """
    x_lo, y_lo = bbox.getMin()
    x_hi, y_hi = bbox.getMax()
    edges = (
        [(x, y_lo) for x in np.linspace(x_lo, x_hi, n_pts)],  # bottom edge
        [(x_hi, y) for y in np.linspace(y_lo, y_hi, n_pts)],  # right edge
        [(x, y_hi) for x in np.linspace(x_hi, x_lo, n_pts)],  # top edge
        [(x_lo, y) for y in np.linspace(y_hi, y_lo, n_pts)],  # left edge
    )
    sky_points = [wcs.pixelToSky(geom.Point2D(x, y)) for edge in edges for x, y in edge]
    ra = [pt.getRa().asDegrees() for pt in sky_points]
    dec = [pt.getDec().asDegrees() for pt in sky_points]
    return ra, dec


def get_patch_rows(tract_info, patch_ids):
    """Group patch IDs into rows keyed by the y component of their patch index.

    For each ID, ``tract_info[pid].getIndex()`` is consulted. If the lookup
    fails or does not yield a 2-element index, the row key falls back to
    ``pid // ROW_BATCH``. The result is ordered by row key, and IDs keep their
    input order within a row.
    """
    def row_key_for(pid):
        try:
            index = tract_info[pid].getIndex()
            if hasattr(index, '__len__') and len(index) == 2:
                return index[1]  # (ix, iy) -> iy
        except Exception:
            pass  # deliberate: any lookup problem means "use the batch fallback"
        return pid // ROW_BATCH

    grouped = {}
    for pid in patch_ids:
        grouped.setdefault(row_key_for(pid), []).append(pid)
    return {key: grouped[key] for key in sorted(grouped)}


# ---- Progress plot ----

def _tile_centers_on_disk():
    """Return ``(ra, dec)`` lists (deg) of Rubin tile centres still on disk."""
    ra, dec = [], []
    for fn in sorted(glob.glob(f"{OUT_RUBIN_ROOT}/**/tile_*.npz", recursive=True)):
        # Only the scalar centre fields are read, so no pickled member is ever
        # loaded and allow_pickle can stay off. The context manager closes the file.
        with np.load(fn) as d:
            ra.append(float(d["ra_center"]))
            dec.append(float(d["dec_center"]))
    return ra, dec


def _tile_centers_in_archives():
    """Return ``(ra, dec)`` lists (deg) of Rubin tile centres inside ARCHIVE_DIR archives."""
    # The download loop writes "tile_NNNN.tar.gz". The original "batch_*.tar.gz"
    # pattern is kept for backward compatibility.
    patterns = ("tile_*.tar.gz", "batch_*.tar.gz")
    archives = sorted({path for pat in patterns
                       for path in glob.glob(os.path.join(ARCHIVE_DIR, pat))})
    ra, dec = [], []
    for arc in archives:
        with tarfile.open(arc, "r:gz") as tar:
            for member in tar.getmembers():
                if "rubin_tiles" not in member.name or not member.name.endswith(".npz"):
                    continue
                fobj = tar.extractfile(member)
                if fobj is None:  # not a regular file
                    continue
                with fobj, np.load(io.BytesIO(fobj.read())) as d:
                    ra.append(float(d["ra_center"]))
                    dec.append(float(d["dec_center"]))
    return ra, dec


def plot_progress():
    """Plot the tract, the patches with data, and the centres of downloaded tiles.

    Draws the tract boundary, the outer bounding box of every patch that
    ``get_patches_in_tract`` reports for band "r", and scatter points for tile
    centres found on disk (OUT_RUBIN_ROOT) and in archives still present in
    ARCHIVE_DIR. Also prints a summary count.
    """
    skymap_obj = butler.get("skyMap", skymap=SKYMAP)
    tract_info = skymap_obj[TRACT]
    tract_wcs = tract_info.getWcs()
    patch_ids = get_patches_in_tract(butler, TRACT, band="r")

    fig, ax = plt.subplots(figsize=(8, 8))

    # Tract outline: best-effort, but report why it is missing instead of
    # failing silently.
    try:
        verts = tract_info.getVertexList()
        vra = [v.getRa().asDegrees() for v in verts] + [verts[0].getRa().asDegrees()]
        vdec = [v.getDec().asDegrees() for v in verts] + [verts[0].getDec().asDegrees()]
        ax.plot(vra, vdec, color="k", linewidth=1.5, alpha=0.3, label="Tract boundary")
    except Exception as e:
        print(f"  Could not plot tract boundary: {e}")

    # Patches with data
    for pid in patch_ids:
        try:
            pinfo = tract_info[pid]
            bbox = geom.Box2D(pinfo.getOuterBBox())
            ra, dec = _bbox_sky_outline(bbox, tract_wcs)
            ax.plot(ra, dec, color="C3", linewidth=0.7, alpha=0.5)
        except Exception as e:
            print(f"  Could not plot patch {pid}: {e}")

    ra_disk, dec_disk = _tile_centers_on_disk()
    ra_arch, dec_arch = _tile_centers_in_archives()

    if ra_disk:
        ax.scatter(ra_disk, dec_disk, s=3, color="C0", alpha=0.7,
                   label=f"On disk ({len(ra_disk)})")
    if ra_arch:
        ax.scatter(ra_arch, dec_arch, s=3, color="C1", alpha=0.7,
                   label=f"Archived ({len(ra_arch)})")

    n_done = len(ra_disk) + len(ra_arch)
    print(f"{len(patch_ids)} patches with data, "
          f"{n_done} tiles downloaded ({len(ra_disk)} disk, {len(ra_arch)} archived)")

    ax.set_xlabel("RA (deg)")
    ax.set_ylabel("Dec (deg)")
    ax.set_title(f"Tract {TRACT}: {len(patch_ids)} patches with data")
    ax.legend(markerscale=3)
    ax.invert_xaxis()  # RA increases to the left on the sky
    ax.set_aspect("equal", adjustable="box")
    fig.tight_layout()
    plt.show()


In [None]:
# Show the tract outline, the patches with data, and the tile centres found
# on disk / in archives before starting (or resuming) the download loop.
plot_progress()

# Tile-by-tile download

For each tile: save Rubin -> fetch Euclid -> plot progress -> compress -> pause for download.


In [None]:
# ---- Resume support ----
# PROGRESS_FILE holds one "<patch_label>/<tile_id>" key per archived tile;
# tiles whose key is already present are skipped.
processed = set()
if os.path.exists(PROGRESS_FILE):
    with open(PROGRESS_FILE) as f:
        processed = {line.strip() for line in f if line.strip()}

patch_ids = get_patches_in_tract(butler, TRACT, band="r")
print(f"Already archived: {len(processed)} tiles")
print(f"Tract {TRACT}: {len(patch_ids)} patches with data: {patch_ids}")

# Archive numbering resumes after the tiles already recorded in the log.
tile_num = len(processed)
# Directory (relative to DATA_ROOT) under which Euclid files are stored in archives.
euclid_arc_dir = os.path.relpath(OUT_EUCLID_DIR, DATA_ROOT)

for patch in patch_ids:
    patch_label = f"patch{int(patch):02d}"
    rubin_out_dir = os.path.join(OUT_RUBIN_ROOT, patch_label)
    os.makedirs(rubin_out_dir, exist_ok=True)

    try:
        exps, wcs_full, bands_present = load_patch_exposures(
            butler, tract=TRACT, patch=patch)
    except Exception as e:
        print(f"Skipping {patch_label}: {e}")
        continue

    # Tile offsets (x0, y0) index the patch arrays; adding the patch XY0 gives
    # the pixel coordinates that wcs_full expects.
    ref_exp = exps[bands_present[0]]
    patch_origin = ref_exp.getXY0()
    x0_patch, y0_patch = patch_origin.getX(), patch_origin.getY()
    H, W = ref_exp.image.array.shape

    for y0 in range(0, H - TILE_SIZE + 1, STRIDE):
        for x0 in range(0, W - TILE_SIZE + 1, STRIDE):
            tile_id = f"tile_x{x0:05d}_y{y0:05d}"
            progress_key = f"{patch_label}/{tile_id}"

            if progress_key in processed:
                continue

            # ---- 1) Rubin tile ----
            rubin_fn = os.path.join(rubin_out_dir, f"{tile_id}.npz")
            gcx = x0_patch + x0 + (TILE_SIZE - 1) / 2.0
            gcy = y0_patch + y0 + (TILE_SIZE - 1) / 2.0
            sp = wcs_full.pixelToSky(gcx, gcy)
            ra_c = sp.getRa().asDegrees()
            dec_c = sp.getDec().asDegrees()
            # Shift the pixel origin so tile pixel (0, 0) maps to the tile corner.
            wcs_local = wcs_full.copyAtShiftedPixelOrigin(
                geom.Extent2D(-(x0_patch + x0), -(y0_patch + y0)))

            imgs, vars_, masks = [], [], []
            for b in bands_present:
                exp = exps[b]
                imgs.append(exp.image.array[y0:y0+TILE_SIZE, x0:x0+TILE_SIZE].astype(np.float32))
                vars_.append(exp.variance.array[y0:y0+TILE_SIZE, x0:x0+TILE_SIZE].astype(np.float32))
                masks.append(exp.mask.array[y0:y0+TILE_SIZE, x0:x0+TILE_SIZE].astype(np.int32))

            np.savez_compressed(rubin_fn,
                img=np.stack(imgs), var=np.stack(vars_), mask=np.stack(masks),
                wcs_hdr=wcs_to_hdr_dict(wcs_local),
                x0=np.int32(x0), y0=np.int32(y0),
                tile_id=np.bytes_(tile_id),
                ra_center=np.float64(ra_c), dec_center=np.float64(dec_c),
                tile_size=np.int32(TILE_SIZE), stride=np.int32(STRIDE),
                bands=np.array(list(bands_present)))

            # ---- 2) Euclid tile ----
            euclid_fn = os.path.join(OUT_EUCLID_DIR, f"{tile_id}_euclid.npz")
            # The same tile_id occurs in every patch, but Euclid files are stored
            # flat. A leftover file (e.g. from an interrupted run) may therefore
            # belong to another patch: reuse it only if it was fetched for this
            # tile's sky position.
            if os.path.exists(euclid_fn):
                try:
                    with np.load(euclid_fn) as old:
                        same_tile = bool(
                            np.isclose(float(old["ra_center"]), ra_c, rtol=0.0, atol=1e-9)
                            and np.isclose(float(old["dec_center"]), dec_c, rtol=0.0, atol=1e-9))
                except Exception:
                    same_tile = False  # unreadable / partial file: fetch again
                if not same_tile:
                    os.remove(euclid_fn)
            if not os.path.exists(euclid_fn):
                try:
                    eu_imgs, eu_var, eu_wcss = load_euclid_cutouts(
                        ra_c, dec_c, size_arcsec=EUCLID_SIZE_ARCSEC, bands=bands_euclid)
                    save_dict = {"ra_center": ra_c, "dec_center": dec_c, "tile_id": tile_id}
                    for b in bands_euclid:
                        if eu_imgs[b] is not None:
                            save_dict[f"img_{b}"] = eu_imgs[b]
                            save_dict[f"wcs_{b}"] = eu_wcss[b].to_header_string()
                        if eu_var[b] is not None:
                            save_dict[f"var_{b}"] = eu_var[b]
                    np.savez_compressed(euclid_fn, **save_dict)
                    print("  Euclid OK")
                except Exception as e:
                    print(f"  Euclid FAIL: {e}")

            # ---- 3) Plot progress ----
            plot_progress()

            # ---- 4) Compress ----
            # Entries are (path on disk, name inside archive). Euclid files are
            # flat on disk, so they are placed under the patch label inside the
            # archive. Otherwise extracting several archives would overwrite
            # same-named Euclid tiles from different patches.
            members = [(rubin_fn, os.path.relpath(rubin_fn, DATA_ROOT))]
            if os.path.exists(euclid_fn):
                members.append((euclid_fn, os.path.join(
                    euclid_arc_dir, patch_label, os.path.basename(euclid_fn))))

            archive_path = os.path.join(ARCHIVE_DIR, f"tile_{tile_num:04d}.tar.gz")
            with tarfile.open(archive_path, "w:gz") as tar:
                for fn, arcname in members:
                    tar.add(fn, arcname=arcname)
            for fn, _ in members:
                os.remove(fn)

            # Record progress only after the archive is complete.
            with open(PROGRESS_FILE, "a") as f:
                f.write(progress_key + "\n")
            processed.add(progress_key)
            tile_num += 1

            # ---- 5) Pause ----
            print(f"Tile {tile_num}: {progress_key} -> {os.path.basename(archive_path)}")
            input(">>> Download the archive, delete it, then press Enter... ")

    del exps

print(f"\nDone! {tile_num} tiles archived total.")


In [None]:

import os, numpy as np, matplotlib.pyplot as plt
from astropy.wcs import WCS

# Pick one tile for visualization. The download loop deletes tile files after
# archiving them, so this tile must be on disk (e.g. extracted from its archive
# into DATA_ROOT) for the cell to run.
patch_label_str = "patch00"
tile_id_str = "tile_x00000_y00000"
rubin_path = os.path.join(OUT_RUBIN_ROOT, patch_label_str, f"{tile_id_str}.npz")
# Tile ids repeat across patches, so look for the Euclid file in a per-patch
# subdirectory first, then in the flat OUT_EUCLID_DIR layout.
euclid_candidates = (
    os.path.join(OUT_EUCLID_DIR, patch_label_str, f"{tile_id_str}_euclid.npz"),
    os.path.join(OUT_EUCLID_DIR, f"{tile_id_str}_euclid.npz"),
)
euclid_path = next((p for p in euclid_candidates if os.path.exists(p)), None)

rubin_bands_full = ["u", "g", "r", "i", "z", "y"]
euclid_bands = ["VIS", "Y", "J", "H"]

with np.load(rubin_path) as r_data:
    rubin_imgs = r_data["img"]
    nb_rubin = rubin_imgs.shape[0]
    # Tiles record which bands actually loaded (bands can be skipped), so use
    # that for labels. Fall back to positional labels for files without it.
    if "bands" in r_data.files:
        rubin_labels = [str(b) for b in r_data["bands"]]
    else:
        rubin_labels = [rubin_bands_full[i] if i < len(rubin_bands_full) else f"b{i}"
                        for i in range(nb_rubin)]
    ra_c = float(r_data["ra_center"])
    dec_c = float(r_data["dec_center"])

euclid_imgs = {}
if euclid_path is not None:
    with np.load(euclid_path) as e_data:
        for band in euclid_bands:
            img_key = f"img_{band}"
            if img_key in e_data.files:
                euclid_imgs[band] = e_data[img_key]
else:
    print(f"No Euclid file found for {patch_label_str}/{tile_id_str}")

fig, axes = plt.subplots(2, 5, figsize=(20, 8)); axes = axes.flatten()

# Rubin panels occupy the first 6 slots; slots without a band are blanked.
for i in range(len(rubin_bands_full)):
    if i < nb_rubin:
        robust_imshow(axes[i], rubin_imgs[i], title=f"Rubin {rubin_labels[i]}")
    else:
        axes[i].axis("off")

# Euclid panels fill the remaining 4 slots.
for ax, band in zip(axes[len(rubin_bands_full):], euclid_bands):
    if band in euclid_imgs:
        robust_imshow(ax, euclid_imgs[band], title=f"Euclid {band}")
    else:
        ax.set_title(f"Euclid {band} (Missing)"); ax.axis('off')

# "\n" escape: a raw line break inside a single-quoted f-string is a SyntaxError.
plt.suptitle(f"Multi-band view: {tile_id_str}\nRA: {ra_c:.4f}, Dec: {dec_c:.4f}", fontsize=16)
plt.tight_layout(); plt.show()
