# Tile Rubin tract 5063 (all patches) and fetch Euclid counterparts

This mirrors `01_getdata_patch.ipynb` but loops over every patch in a tract. It tolerates missing bands/patches by skipping what's unavailable.


In [None]:
import os, glob
import numpy as np
from lsst.daf.butler import Butler
import lsst.geom as geom

# ---- Config ----
# Butler query parameters
TRACT = 5063
SKYMAP = "lsst_cells_v1"
REPO = "dp1"
COLLECTION = "LSSTComCam/DP1"
DATASETTYPE = "deep_coadd"

# Per-patch Rubin tiles are written below this directory (created if missing).
OUT_RUBIN_ROOT = "../data/rubin_tiles_tract5063"
os.makedirs(OUT_RUBIN_ROOT, exist_ok=True)

# Rubin bands to load, in stacking order -- reorder / trim if you only need a subset.
bands_rubin = tuple("ugrizy")

# Tiling geometry in pixels; STRIDE < TILE_SIZE makes neighbouring tiles overlap.
TILE_SIZE = 512
STRIDE = 256
MAX_TILES = None  # None => all tiles in that patch image

butler = Butler(REPO, collections=COLLECTION)


def get_patches_in_tract(butler, tract, band="r", datasetType=DATASETTYPE, skymap=SKYMAP):
    """List patches of *tract* that have at least one *datasetType* in *band*.

    Querying is safer than assuming every patch of the tract exists.

    Parameters
    ----------
    butler : lsst.daf.butler.Butler
        Butler used for the query.
    tract : int
        Tract id.
    band : str
        Band that must be present for a patch to be listed.
    datasetType : str
        Dataset type to query.
    skymap : str
        Sky map name.

    Returns
    -------
    list
        Sorted, de-duplicated ``patch`` values of the matching dataset refs.
    """
    # Bind identifiers deliberately differ from the dimension names: with
    # bind={"tract": ...}, "tract = tract" can resolve *both* sides to the bound
    # literal (a tautology that matches every tract), and some butler versions
    # reject bind keys that shadow dimensions outright.
    refs = butler.query_datasets(
        datasetType,
        where="tract = tract_id AND band = band_name AND skymap = skymap_name",
        bind={"tract_id": tract, "band_name": band, "skymap_name": skymap},
    )
    # Only the data ID is needed, so dimension records are not requested.
    return sorted({ref.dataId["patch"] for ref in refs})


def load_patch_exposures_by_id(butler, tract, patch, bands=bands_rubin, datasetType=DATASETTYPE, skymap=SKYMAP):
    """Load one coadd exposure per requested band for a single (tract, patch).

    Any band whose ``butler.get`` raises is reported and skipped.

    Returns
    -------
    exps : dict
        Band -> exposure, for the bands that loaded.
    wcs_full : object
        ``getWcs()`` of the first band that loaded.
    available : list of str
        Loaded bands, in the order they appear in *bands*.

    Raises
    ------
    RuntimeError
        If none of the requested bands could be loaded.
    """
    loaded = {}
    available = []
    for band in bands:
        data_id = dict(tract=tract, patch=patch, band=band, skymap=skymap)
        try:
            exposure = butler.get(datasetType, dataId=data_id)
        except Exception as err:
            print(f"  skipping band {band} for patch {patch}: {err}")
        else:
            loaded[band] = exposure
            available.append(band)

    if not available:
        raise RuntimeError(f"No bands found for patch {patch}")

    reference_wcs = loaded[available[0]].getWcs()
    return loaded, reference_wcs, available


def wcs_to_hdr_dict_lsst(wcs_lsst):
    """Flatten a SkyWcs's FITS metadata into a plain ``{keyword: scalar}`` dict."""
    metadata = wcs_lsst.getFitsMetadata()
    header = {}
    for keyword in metadata.names():
        header[keyword] = metadata.getScalar(keyword)
    return header


In [None]:

import lsst.afw.geom as afwGeom
from lsst.daf.base import PropertySet
import numpy as np
import glob
import matplotlib.pyplot as plt


def wcs_from_hdr_dict(hdr_dict):
    """Rebuild a SkyWcs from a ``{FITS keyword: value}`` dict.

    Every entry is copied into a ``PropertySet``, which is then passed to
    ``afwGeom.makeSkyWcs``. Intended as the inverse of ``wcs_to_hdr_dict_lsst``.
    """
    metadata = PropertySet()
    for keyword in hdr_dict:
        metadata.set(keyword, hdr_dict[keyword])
    sky_wcs = afwGeom.makeSkyWcs(metadata)
    return sky_wcs


def tile_corners_from_npz(npz_path):
    """Return the sky outline of one saved Rubin tile as a closed polygon.

    Reads the tile's ``wcs_hdr`` and ``tile_size`` entries and maps the corner
    pixels (0,0), (S,0), (S,S), (0,S) with S = tile_size - 1 to the sky; the
    first corner is repeated so the polygon plots closed.

    Parameters
    ----------
    npz_path : str
        Path to a tile ``.npz`` written by the batch-download cell.

    Returns
    -------
    ra, dec : numpy.ndarray
        Five corner coordinates each, in degrees.
    """
    # allow_pickle is needed because wcs_hdr is stored as a Python dict (object
    # array). Only load tiles this notebook wrote: pickle is unsafe on untrusted
    # files. The context manager closes the underlying zip file right away.
    with np.load(npz_path, allow_pickle=True) as d:
        wcs = wcs_from_hdr_dict(d["wcs_hdr"].item())
        tile_size = int(d["tile_size"])

    last = tile_size - 1
    corners_pix = [(0, 0), (last, 0), (last, last), (0, last), (0, 0)]
    sky = [wcs.pixelToSky(x, y) for x, y in corners_pix]
    ra = np.array([sp.getRa().asDegrees() for sp in sky])
    dec = np.array([sp.getDec().asDegrees() for sp in sky])
    return ra, dec

files = sorted(glob.glob(f"{OUT_RUBIN_ROOT}/**/tile_*.npz", recursive=True))
print("Found", len(files), "tiles")

# Tiles are only written by the batch-download cell (and deleted again once
# archived), so finding none here is a normal situation, not an error.
if files:
    fig, ax = plt.subplots(figsize=(7, 7))
    for fn in files:
        ra, dec = tile_corners_from_npz(fn)
        ax.plot(ra, dec, linewidth=0.5, alpha=0.5)
    ax.set_xlabel("RA [deg]")
    ax.set_ylabel("Dec [deg]")
    ax.set_title(f"Rubin {TILE_SIZE}×{TILE_SIZE} tile footprints (tract {TRACT})")
    ax.invert_xaxis()  # RA increases to the left, as on the sky
    ax.set_aspect("equal", adjustable="box")
    plt.show()
else:
    print("No tiles on disk to plot.")


# Batch download: Rubin tiles + Euclid counterparts (quota-friendly)

Processes one tile at a time: **Rubin tile -> Euclid fetch -> next tile**.
Every `BATCH_SIZE` tiles, creates a `.tar.gz` archive and deletes the
originals to free disk space, then pauses for you to download the archive
and delete it before continuing.

Progress is tracked in `tile_progress.log` so re-running skips already-archived tiles.


In [None]:
import tarfile
from astroquery.ipac.irsa import Irsa
from astropy.coordinates import SkyCoord
from astropy import units as u
from astropy.table import Table
from astropy.io import fits
from astropy.nddata import Cutout2D
from astropy.wcs import WCS
import fsspec
import os, glob, numpy as np

from scipy.ndimage import maximum_filter, median_filter, zoom, gaussian_filter
from scipy.optimize import linear_sum_assignment
from scipy.stats import gaussian_kde
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse, Circle
import matplotlib.patches as mpatches
from astropy.coordinates import match_coordinates_sky, search_around_sky


def sanitize_rms(rms, huge=1e10):
    """Return a float32 copy of *rms* with unusable pixels replaced by NaN.

    A pixel is unusable when it is non-finite, non-positive, or larger than
    *huge*. The input array is never modified.
    """
    cleaned = rms.astype(np.float32)  # astype copies by default
    usable = np.isfinite(cleaned) & (cleaned > 0) & (cleaned <= huge)
    cleaned[~usable] = np.nan
    return cleaned

def robust_imshow(ax, img, title="", p=(1, 99)):
    """Show *img* on *ax* with a percentile-clipped colour stretch.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    img : array or None
        Image to show. ``None`` marks a missing image: the panel is titled
        "<title> (missing)" and its axes are hidden.
    title : str
        Panel title.
    p : (low, high)
        NaN-ignoring percentiles of *img* used as vmin / vmax.
    """
    if img is None:
        ax.set_title(f"{title} (missing)")
        ax.axis("off")
        return

    vmin, vmax = np.nanpercentile(img, p)
    ax.imshow(img, origin="lower", vmin=vmin, vmax=vmax)
    ax.set_title(title)
    for hide_ticks in (ax.set_xticks, ax.set_yticks):
        hide_ticks([])

def save_bundle(path, **kw):
    """Write every keyword argument that is not None to a compressed ``.npz`` at *path*."""
    arrays = {}
    for name, value in kw.items():
        if value is None:
            continue
        arrays[name] = value
    np.savez_compressed(path, **arrays)
    print("saved:", path)


def load_euclid_cutouts(ra, dec, size_arcsec, bands=("VIS","Y","J","H"), collection="euclid_DpdMerBksMosaic", radius_arcsec=60):
    """Fetch Euclid mosaic cutouts (science + variance) around one sky position.

    Parameters
    ----------
    ra, dec : float
        Cutout centre in degrees (ICRS).
    size_arcsec : float
        Side length of the square cutout, in arcseconds.
    bands : sequence of str
        Values matched against the SIA ``energy_bandpassname`` column.
    collection : str
        IRSA SIA collection to query.
    radius_arcsec : float
        Search radius of the SIA query, in arcseconds.

    Returns
    -------
    out_img : dict
        Band -> float32 science cutout, or None if no science product matched.
    out_var : dict
        Band -> float32 variance (sanitized noise value squared), or None.
    wcs_out : dict
        Band -> astropy WCS of the science cutout (only bands with an image).
    """
    coord = SkyCoord(ra=ra*u.deg, dec=dec*u.deg, frame="icrs")
    size = size_arcsec * u.arcsec
    tab = Irsa.query_sia(pos=(coord, radius_arcsec*u.arcsec), collection=collection)
    if not isinstance(tab, Table):
        tab = tab.to_table()
    out_img = {b: None for b in bands}
    out_var = {b: None for b in bands}
    wcs_out = {}

    def get_row(band, subtype):
        """First SIA row for *band* with the given ``dataproduct_subtype``, or None."""
        m = (tab["energy_bandpassname"] == band) & (tab["dataproduct_subtype"] == subtype)
        rows = tab[m]
        return rows[0] if len(rows) else None

    def cutout(row):
        """Cut a *size* square around *coord* out of the image at ``row["access_url"]``."""
        # ``.section`` fetches only the byte ranges covering the cutout; reading
        # ``hdul[0].data`` would download the entire mosaic for every tile.
        with fits.open(str(row["access_url"]), use_fsspec=True) as hdul:
            hdu = hdul[0]
            cut = Cutout2D(hdu.section, coord, size, wcs=WCS(hdu.header))
            return np.array(cut.data, dtype=np.float32), cut.wcs

    for b in bands:
        row_sci = get_row(b, "science")
        if row_sci is None:
            continue
        out_img[b], wcs_out[b] = cutout(row_sci)

        row_rms = get_row(b, "noise")
        if row_rms is None:
            continue
        rms, _ = cutout(row_rms)
        # NOTE(review): the science and noise rows are chosen independently; a
        # shape mismatch means they do not cover the same pixels, so the
        # variance would be misaligned -- drop it rather than save it.
        if rms.shape != out_img[b].shape:
            print(f"  Euclid {b}: noise cutout {rms.shape} != science {out_img[b].shape}; skipping variance")
            continue
        rms = sanitize_rms(rms, huge=1e10)
        out_var[b] = rms * rms
    return out_img, out_var, wcs_out


# ---- Config ----
# Euclid counterpart settings
OUT_EUCLID_DIR = "../data/euclid_tiles_tract5063"
bands_euclid = ("VIS", "Y", "J", "H")
EUCLID_SIZE_ARCSEC = 105.0  # cutout side length, arcsec

# Batching / archiving
BATCH_SIZE = 10            # tiles per archive before pausing
ARCHIVE_DIR = "../data/tile_archives"
PROGRESS_FILE = "../data/tile_progress.log"
DATA_ROOT = os.path.abspath("../data")  # archive member names are made relative to this

os.makedirs(OUT_EUCLID_DIR, exist_ok=True)
os.makedirs(ARCHIVE_DIR, exist_ok=True)

# ---- Resume support ----
# PROGRESS_FILE holds one "patchNN/tile_xXXXXX_yYYYYY" key per archived tile;
# blank lines are ignored.
processed = set()
if os.path.exists(PROGRESS_FILE):
    with open(PROGRESS_FILE) as f:
        processed = {key for key in map(str.strip, f) if key}
print(f"Resuming: {len(processed)} tiles already archived")

# ---- Main loop ----
def _archive_batch(files, keys, num):
    """Tar up one batch of tile files, delete them, and record their keys as done.

    Parameters
    ----------
    files : list of str
        Rubin/Euclid ``.npz`` paths written for this batch.
    keys : list of str
        Progress keys (``patchNN/tile_id``) of the tiles in this batch.
    num : int
        Running batch number, used as the archive-name prefix.

    Returns
    -------
    str
        Path of the written ``.tar.gz``.
    """
    # The user deletes archives after downloading them, and ``num`` is derived
    # from the archives still on disk, so it restarts on a re-run. Appending the
    # batch's first progress key keeps every archive name unique.
    first_key = keys[0].replace("/", "_")
    archive_path = os.path.join(ARCHIVE_DIR, f"batch_{num:04d}_{first_key}.tar.gz")
    unique_files = list(dict.fromkeys(files))  # never add or remove a path twice
    with tarfile.open(archive_path, "w:gz") as tar:
        for fn in unique_files:
            # Member names relative to DATA_ROOT: extracting there restores the layout.
            tar.add(fn, arcname=os.path.relpath(fn, DATA_ROOT))
    for fn in unique_files:
        os.remove(fn)
    with open(PROGRESS_FILE, "a") as f:
        f.writelines(key + "\n" for key in keys)
    processed.update(keys)
    return archive_path


patch_ids = get_patches_in_tract(butler, TRACT, band="r")
print(f"Tract {TRACT}: {len(patch_ids)} patches with r-band")

batch_files = []   # .npz paths written since the last archive
batch_keys = []    # progress keys of the tiles in the current batch
batch_num = len(glob.glob(os.path.join(ARCHIVE_DIR, "batch_*.tar.gz")))
total_new = 0
total_skipped = 0

for patch in patch_ids:
    patch_label = f"patch{int(patch):02d}"
    rubin_out_dir = os.path.join(OUT_RUBIN_ROOT, patch_label)
    # Tile ids only encode the pixel offset inside a patch, so the Euclid
    # cutouts get the same per-patch sub-directory as the Rubin tiles. In a flat
    # directory every patch's tile_x00000_y00000 shares one file name, both on
    # disk and inside the archives.
    euclid_out_dir = os.path.join(OUT_EUCLID_DIR, patch_label)
    os.makedirs(rubin_out_dir, exist_ok=True)
    os.makedirs(euclid_out_dir, exist_ok=True)

    try:
        exps, wcs_full, bands_present = load_patch_exposures_by_id(
            butler, tract=TRACT, patch=patch)
    except Exception as e:
        print(f"Skipping {patch_label}: {e}")
        continue

    bands_tup = tuple(bands_present)
    ref_exp = exps[bands_tup[0]]
    patch_origin = ref_exp.getXY0()
    x0_patch, y0_patch = patch_origin.getX(), patch_origin.getY()
    H, W = ref_exp.image.array.shape

    # Pixel offsets of every full TILE_SIZE tile in the patch, row by row.
    # MAX_TILES (if set) caps the count per patch; it counts already-archived
    # tiles too, so a re-run selects the same tiles.
    tile_origins = [
        (x0, y0)
        for y0 in range(0, H - TILE_SIZE + 1, STRIDE)
        for x0 in range(0, W - TILE_SIZE + 1, STRIDE)
    ]
    if MAX_TILES is not None:
        tile_origins = tile_origins[:MAX_TILES]

    for x0, y0 in tile_origins:
        tile_id = f"tile_x{x0:05d}_y{y0:05d}"
        progress_key = f"{patch_label}/{tile_id}"

        if progress_key in processed:
            total_skipped += 1
            continue

        # ---- Save Rubin tile ----
        rubin_fn = os.path.join(rubin_out_dir, f"{tile_id}.npz")
        # Tile centre in parent pixel coordinates (the frame wcs_full uses).
        gcx = x0_patch + x0 + (TILE_SIZE - 1) / 2.0
        gcy = y0_patch + y0 + (TILE_SIZE - 1) / 2.0
        sp = wcs_full.pixelToSky(gcx, gcy)
        ra_c = sp.getRa().asDegrees()
        dec_c = sp.getDec().asDegrees()
        # Tile-local WCS: pixel origin shifted by the tile's parent-pixel offset.
        wcs_local = wcs_full.copyAtShiftedPixelOrigin(
            geom.Extent2D(-(x0_patch + x0), -(y0_patch + y0)))

        tile_slice = (slice(y0, y0 + TILE_SIZE), slice(x0, x0 + TILE_SIZE))
        imgs = [exps[b].image.array[tile_slice].astype(np.float32) for b in bands_tup]
        vars_ = [exps[b].variance.array[tile_slice].astype(np.float32) for b in bands_tup]
        masks = [exps[b].mask.array[tile_slice].astype(np.int32) for b in bands_tup]

        np.savez_compressed(rubin_fn,
            img=np.stack(imgs), var=np.stack(vars_), mask=np.stack(masks),
            wcs_hdr=wcs_to_hdr_dict_lsst(wcs_local),
            x0=np.int32(x0), y0=np.int32(y0),
            tile_id=np.bytes_(tile_id),
            tract=np.int32(TRACT), patch=np.int32(int(patch)),
            ra_center=np.float64(ra_c), dec_center=np.float64(dec_c),
            tile_size=np.int32(TILE_SIZE), stride=np.int32(STRIDE),
            bands=np.array(list(bands_tup)))
        batch_files.append(rubin_fn)

        # ---- Fetch & save Euclid tile ----
        euclid_fn = os.path.join(euclid_out_dir, f"{tile_id}_euclid.npz")
        # Skip the fetch when this tile's cutout is already on disk (e.g. left
        # over from an interrupted run).
        if not os.path.exists(euclid_fn):
            try:
                eu_imgs, eu_var, eu_wcss = load_euclid_cutouts(
                    ra_c, dec_c, size_arcsec=EUCLID_SIZE_ARCSEC, bands=bands_euclid)
                save_dict = {"ra_center": ra_c, "dec_center": dec_c, "tile_id": tile_id,
                             "tract": TRACT, "patch": int(patch)}
                for b in bands_euclid:
                    if eu_imgs[b] is not None:
                        save_dict[f"img_{b}"] = eu_imgs[b]
                        save_dict[f"wcs_{b}"] = eu_wcss[b].to_header_string()
                    if eu_var[b] is not None:
                        save_dict[f"var_{b}"] = eu_var[b]
                np.savez_compressed(euclid_fn, **save_dict)
            except Exception as e:
                print(f"  Euclid failed for {progress_key}: {e}")
        if os.path.exists(euclid_fn):
            batch_files.append(euclid_fn)

        batch_keys.append(progress_key)
        total_new += 1
        print(f"  {progress_key} ({total_new} new, {total_skipped} skipped)")

        # ---- Archive every BATCH_SIZE tiles ----
        if len(batch_keys) >= BATCH_SIZE:
            archive_path = _archive_batch(batch_files, batch_keys, batch_num)

            print(f"\n{'='*60}")
            print(f"  Batch {batch_num} -> {archive_path}")
            print(f"  {len(batch_keys)} tiles archived, {total_new + total_skipped} total")
            print(f"{'='*60}")

            batch_files = []
            batch_keys = []
            batch_num += 1
            input(">>> Download the archive, delete it, then press Enter to continue... ")

    del exps, ref_exp  # free patch memory

# ---- Final partial batch ----
if batch_keys:
    archive_path = _archive_batch(batch_files, batch_keys, batch_num)
    print(f"\nFinal batch {batch_num} -> {archive_path}")
    input(">>> Download this last archive and delete it. Press Enter when done... ")

print(f"\nDone! {total_new} new tiles archived, {total_skipped} previously done.")


In [None]:

import os, numpy as np, matplotlib.pyplot as plt
from astropy.wcs import WCS

# pick one tile id for visualization
tile_id_str = "tile_x00000_y00000"
patch_label_str = "patch00"
# NOTE: the batch cell deletes tiles once they are archived; if these files are
# gone, extract the matching archive into DATA_ROOT first.
rubin_path = os.path.join(OUT_RUBIN_ROOT, patch_label_str, f"{tile_id_str}.npz")
# Look for the Euclid cutout in a per-patch sub-directory first, then directly
# under OUT_EUCLID_DIR.
euclid_candidates = [
    os.path.join(OUT_EUCLID_DIR, patch_label_str, f"{tile_id_str}_euclid.npz"),
    os.path.join(OUT_EUCLID_DIR, f"{tile_id_str}_euclid.npz"),
]
euclid_path = next((p for p in euclid_candidates if os.path.exists(p)), euclid_candidates[0])

rubin_bands_full = ["u", "g", "r", "i", "z", "y"]
euclid_bands = ["VIS", "Y", "J", "H"]

with np.load(rubin_path) as r_data:
    rubin_imgs = r_data["img"]
    nb_rubin = rubin_imgs.shape[0]
    # Bands missing for a patch are skipped when tiling, so image i is not
    # necessarily rubin_bands_full[i]; prefer the band list stored in the tile.
    if "bands" in r_data.files:
        rubin_band_names = [str(b) for b in r_data["bands"]]
    else:
        rubin_band_names = [rubin_bands_full[i] if i < len(rubin_bands_full) else f"b{i}"
                            for i in range(nb_rubin)]
    ra_center = float(r_data["ra_center"])
    dec_center = float(r_data["dec_center"])

euclid_imgs = {}
if os.path.exists(euclid_path):
    with np.load(euclid_path) as e_data:
        euclid_imgs = {b: e_data[f"img_{b}"] for b in euclid_bands if f"img_{b}" in e_data.files}
else:
    print(f"No Euclid cutout found at {euclid_path}")

fig, axes = plt.subplots(2, 5, figsize=(20, 8)); axes = axes.flatten()

# Rubin panels occupy slots 0-5; slots without a band are blanked.
for i in range(len(rubin_bands_full)):
    if i < nb_rubin:
        robust_imshow(axes[i], rubin_imgs[i], title=f"Rubin {rubin_band_names[i]}")
    else:
        axes[i].axis("off")

# Euclid panels occupy slots 6-9.
for i, band in enumerate(euclid_bands):
    ax = axes[i + 6]
    if band in euclid_imgs:
        robust_imshow(ax, euclid_imgs[band], title=f"Euclid {band}")
    else:
        ax.set_title(f"Euclid {band} (Missing)"); ax.axis('off')

plt.suptitle(f"Multi-band view: {tile_id_str}\nRA: {ra_center:.4f}, Dec: {dec_center:.4f}", fontsize=16)
plt.tight_layout(); plt.show()
