In [None]:
from pathlib import Path
from ovro_lwa_portal.fits_to_zarr_xradio import convert_fits_dir_to_zarr

# All paths are relative to ./notebooks/, so "../" points at the repo root.
INPUT_DIR = Path("../data/fits")
OUT_DIR = Path("../out")
FIXED_DIR = OUT_DIR / "fixed"
ZARR_NAME = "test.zarr"

# Ensure the output tree exists: OUT_DIR first, then its "fixed" subdirectory.
for _target_dir in (OUT_DIR, FIXED_DIR):
    _target_dir.mkdir(parents=True, exist_ok=True)

# Run the FITS -> Zarr conversion. The plotting cell below reads "../out/test.zarr",
# i.e. OUT_DIR / ZARR_NAME.
# NOTE(review): FIXED_DIR is presumably where the converter writes repaired FITS copies — confirm.
out = convert_fits_dir_to_zarr(
    input_dir=INPUT_DIR,
    out_dir=OUT_DIR,
    zarr_name=ZARR_NAME,
    fixed_dir=FIXED_DIR,
    chunk_lm=1024,   # presumably the chunk size along the l/m image axes — confirm
    rebuild=True,    # overwrite if exists
)
print("Zarr:", out)


In [None]:
import xarray as xr, numpy as np
from astropy.io.fits import Header
from astropy.wcs import WCS

def get_wcs_from_zarr(z: xr.Dataset, var: str = "SKY") -> WCS:
    """Return a 2D celestial WCS reconstructed purely from the Zarr store.

    The FITS header text is looked up in this order:

    1. ``z[var].attrs["fits_wcs_header"]``. If ``var`` is not a data variable,
       ``"BEAM"`` is used when present, otherwise the first data variable.
    2. The 0-D variable ``z["wcs_header_str"]``, whose value may be ``str`` or
       bytes (including ``np.bytes_``). Bytes are decoded as UTF-8.

    Parameters
    ----------
    z : xr.Dataset
        Dataset opened from the Zarr store.
    var : str, default "SKY"
        Data variable whose attributes are checked first.

    Returns
    -------
    WCS
        The WCS parsed from the header. If the header describes more than two
        axes and includes celestial ones, only the celestial sub-WCS is
        returned, so callers can use it with 2-D images.

    Raises
    ------
    KeyError
        If no header is found in either location.
    """
    if var not in z.data_vars:
        # fall back to BEAM if SKY missing
        var = "BEAM" if "BEAM" in z.data_vars else list(z.data_vars)[0]
    # prefer per-variable attr (we stored it there)
    hdr_str = z[var].attrs.get("fits_wcs_header")
    if not hdr_str:
        # fallback to the 0-D variable with robust decoding
        if "wcs_header_str" not in z.variables:
            raise KeyError(
                f"No WCS header found: {var!r} has no 'fits_wcs_header' attribute "
                "and the dataset has no 'wcs_header_str' variable."
            )
        val = z["wcs_header_str"].values
        if isinstance(val, np.ndarray):
            val = val.item()
        # np.bytes_ subclasses bytes, so this single check also covers numpy byte scalars.
        if isinstance(val, (bytes, bytearray)):
            hdr_str = val.decode("utf-8", errors="strict")
        else:
            hdr_str = str(val)
    wcs_obj = WCS(Header.fromstring(hdr_str, sep="\n"))
    # Honour the "2D celestial" contract when the stored header carries extra
    # (e.g. spectral/Stokes) axes; 2-D headers are returned unchanged.
    if wcs_obj.naxis > 2 and wcs_obj.has_celestial:
        wcs_obj = wcs_obj.celestial
    return wcs_obj


In [None]:
# --- Zarr WCS plot with white grid + printed RA/Dec labels (no extra installs) ---

%matplotlib inline
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
from astropy.io.fits import Header
from astropy.wcs import WCS
from astropy.coordinates import SkyCoord
from astropy import units as u

# -----------------------
# Config
# -----------------------
ZARR = "../out/test.zarr"   # store written by the conversion cell (OUT_DIR / ZARR_NAME)
# Colour-scale limits for the image, in the colorbar's units (Jy/beam)
VMIN = -2
VMAX = 12
R_MASK = 1833               # circular-trim radius in pixels; set None to disable
GRID_SPACING = 30 * u.deg   # spacing of the RA/Dec grid lines
FONT = 16                   # base font size for labels and tick labels

# -----------------------
# Robust WCS-from-Zarr
# -----------------------
def get_wcs_from_zarr(z: xr.Dataset, var: str = "SKY") -> WCS:
    """Return a 2D celestial WCS reconstructed purely from the Zarr store.

    Re-defined here so this plotting cell runs on its own; keep it in sync with
    the helper of the same name in the earlier cell.

    The FITS header text is looked up in this order:

    1. ``z[var].attrs["fits_wcs_header"]``. If ``var`` is not a data variable,
       ``"BEAM"`` is used when present, otherwise the first data variable.
    2. The 0-D variable ``z["wcs_header_str"]``, whose value may be ``str`` or
       bytes (including ``np.bytes_``). Bytes are decoded as UTF-8.

    Parameters
    ----------
    z : xr.Dataset
        Dataset opened from the Zarr store.
    var : str, default "SKY"
        Data variable whose attributes are checked first.

    Returns
    -------
    WCS
        The WCS parsed from the header. If the header describes more than two
        axes and includes celestial ones, only the celestial sub-WCS is
        returned, so callers can use it with 2-D images.

    Raises
    ------
    KeyError
        If no header is found in either location.
    """
    if var not in z.data_vars:
        var = "BEAM" if "BEAM" in z.data_vars else list(z.data_vars)[0]
    hdr_str = z[var].attrs.get("fits_wcs_header")
    if not hdr_str:
        if "wcs_header_str" not in z.variables:
            raise KeyError(
                f"No WCS header found: {var!r} has no 'fits_wcs_header' attribute "
                "and the dataset has no 'wcs_header_str' variable."
            )
        val = z["wcs_header_str"].values
        if isinstance(val, np.ndarray):
            val = val.item()
        # np.bytes_ subclasses bytes, so this single check also covers numpy byte scalars.
        if isinstance(val, (bytes, bytearray)):
            hdr_str = val.decode("utf-8", errors="strict")
        else:
            hdr_str = str(val)
    wcs_obj = WCS(Header.fromstring(hdr_str, sep="\n"))
    # Honour the "2D celestial" contract when the stored header carries extra
    # (e.g. spectral/Stokes) axes; 2-D headers are returned unchanged.
    if wcs_obj.naxis > 2 and wcs_obj.has_celestial:
        wcs_obj = wcs_obj.celestial
    return wcs_obj

# -----------------------
# Load plane from Zarr
# -----------------------
z = xr.open_zarr(ZARR, consolidated=False)
var = "SKY" if "SKY" in z.data_vars else ("BEAM" if "BEAM" in z.data_vars else list(z.data_vars)[0])
w   = get_wcs_from_zarr(z, var)

tsel = 0
fsel = int(z.sizes.get("frequency", 1) // 2)
# missing_dims="ignore" lets stores without a time or frequency axis through
# (the .get(..., 1) above already anticipates a missing frequency axis).
da   = z[var].isel(time=tsel, frequency=fsel, missing_dims="ignore").squeeze()
if {"m","l"}.issubset(da.dims):
    da = da.transpose("m","l", ...)
# np.array always copies, so the in-place NaN masking below can never write back into `z`.
img  = np.array(da.values, dtype=float)
if img.ndim != 2:
    raise ValueError(
        f"Expected a 2-D image after selecting time/frequency, got dims {da.dims} "
        f"with shape {img.shape}"
    )
ny, nx = img.shape

# Optional rim mask: blank pixels farther than R_MASK pixels from the image centre.
if R_MASK is not None:
    # Open grids broadcast to (ny, nx) without materialising two full index arrays.
    yy, xx = np.ogrid[:ny, :nx]
    cy, cx = ny//2, nx//2
    img[((yy-cy)**2 + (xx-cx)**2) > R_MASK**2] = np.nan

# -----------------------
# Colormap: black for NaNs/underflow
# -----------------------
cmap = plt.get_cmap("inferno").copy()
cmap.set_bad("black", 1.0)
cmap.set_under("black", 1.0)

# -----------------------
# Figure (black theme)
# -----------------------
fig = plt.figure(figsize=(12, 12), facecolor="black")
# Attach the WCSAxes to `fig` explicitly instead of relying on pyplot's "current figure".
ax  = fig.add_subplot(projection=w, facecolor="black")

# Image
im = ax.imshow(img, origin="lower", cmap=cmap, vmin=VMIN, vmax=VMAX)

# Axes labels/ticks in white
ax.set_xlabel("RA",  color="white", fontsize=FONT, labelpad=15)
ax.set_ylabel("Dec", color="white", fontsize=FONT, labelpad=15)

# --- Auto check RA direction ---
# Use the (world axis 0, pixel x) element of the pixel-scale matrix, which folds in
# CDELT*PC or a CD matrix. Reading w.wcs.cdelt alone is wrong for CD-matrix headers,
# where cdelt stays at 1.0 and would trigger a spurious flip.
try:
    dworld0_dx = float(w.pixel_scale_matrix[0, 0])
except Exception:  # best effort: if the scale can't be read, leave the orientation as is
    dworld0_dx = np.nan
if np.isfinite(dworld0_dx) and dworld0_dx > 0:
    ax.invert_xaxis()   # Only flip if RA increases to the right

# White WCS grid + white tick labels
overlay = ax.get_coords_overlay("fk5")  # stored WCS is FK5/J2000
overlay.grid(color="white", ls=":", lw=1.0, alpha=0.8)
for c in overlay:
    c.set_ticks(spacing=GRID_SPACING)
    c.set_ticklabel_visible(True)
    c.set_ticklabel(color="white", size=FONT-2)
    c.tick_params(width=1, color="white")

# Colorbar (white)
cbar = plt.colorbar(im, ax=ax, fraction=0.026, pad=0.01, extend="both")
cbar.set_label("Jy/beam", color="white", fontsize=FONT)
cbar.ax.tick_params(color="white", labelcolor="white", labelsize=FONT-2)
cbar.outline.set_edgecolor("white")
cbar.ax.set_facecolor("black")

# -----------------------
# Compute convenient RA/Dec label lines
# -----------------------
def visible_world_box(wcs: WCS, nx: int, ny: int):
    """Return ``(ra_min, ra_max, dec_min, dec_max)`` in degrees from the four image corners.

    RA is wrapped to [0, 360) before the min/max are taken. Corners that fall
    outside the projection may produce NaN world coordinates, which then
    propagate into the result.
    """
    corner_x = np.array([0, nx - 1, 0, nx - 1])
    corner_y = np.array([0, 0, ny - 1, ny - 1])
    corners = wcs.pixel_to_world(corner_x, corner_y)
    ra_deg = corners.ra.wrap_at(360 * u.deg).deg
    dec_deg = corners.dec.deg
    return (
        float(ra_deg.min()),
        float(ra_deg.max()),
        float(dec_deg.min()),
        float(dec_deg.max()),
    )

# Corner-based sky extent; these four values are not used by the labelling that follows.
ra_min, ra_max, dec_min, dec_max = visible_world_box(wcs=w, nx=nx, ny=ny)
# World coordinate at the geometric image centre (0-based pixel convention).
center = w.pixel_to_world(0.5 * nx, 0.5 * ny)

def nearest_multiple(x, step):
    """Round ``x`` to the closest integer multiple of ``step``.

    Uses Python's ``round``, so exact ties go to the even multiple
    (e.g. 45 with step 30 -> 60, 15 with step 30 -> 0).
    """
    n_steps = round(x / step)
    return n_steps * step

# Printed labels use the same spacing as the overlay grid, so GRID_SPACING controls both.
label_step = float(GRID_SPACING.to_value(u.deg))

# Lines of constant Dec / RA nearest the image centre along which the labels are placed.
dec_line = nearest_multiple(float(center.dec.deg), label_step)
ra_line  = nearest_multiple(float(center.ra.wrap_at(360*u.deg).deg), label_step)

# Integer multiples of label_step: RA in [0, 360), Dec in [-90, +90].
ra_ticks  = label_step * np.arange(0, np.ceil(360.0 / label_step))
dec_ticks = label_step * np.arange(np.ceil(-90.0 / label_step), np.floor(90.0 / label_step) + 1)

# RA labels within the visible image (NaN pixels from off-projection points are skipped)
for ra_deg in ra_ticks:
    coord = SkyCoord(ra=ra_deg*u.deg, dec=dec_line*u.deg, frame="fk5")
    x, y  = w.world_to_pixel(coord)
    if np.isfinite(x) and np.isfinite(y) and 0 <= x < nx and 0 <= y < ny:
        ax.text(x, y, f"{ra_deg % 360:g}°", color="white", fontsize=FONT, ha="center", va="center")

# Dec labels within the visible image, signed (e.g. +30°, -60°)
for dec_deg in dec_ticks:
    coord = SkyCoord(ra=ra_line*u.deg, dec=dec_deg*u.deg, frame="fk5")
    x, y  = w.world_to_pixel(coord)
    if np.isfinite(x) and np.isfinite(y) and 0 <= x < nx and 0 <= y < ny:
        ax.text(x, y, f"{dec_deg:+g}°", color="white", fontsize=FONT, ha="center", va="center")

# Optional white border circle tracing the mask edge (drawn in pixel coordinates)
if R_MASK is not None:
    border = Circle(
        (cx, cy),
        R_MASK,
        transform=ax.get_transform("pixel"),
        fill=False,
        edgecolor="white",
        linewidth=0.8,
        alpha=0.95,
    )
    ax.add_patch(border)

# Corner info box with the time stamp and frequency of the plotted plane.
# The /1e6 -> "MHz" conversion assumes the frequency coordinate is in Hz — confirm.
if "frequency" in z.coords:
    freq_mhz = float(z.frequency.isel(frequency=fsel).values) / 1e6
else:
    freq_mhz = float("nan")

if "time" in z.coords:
    time_str = str(z.time.isel(time=tsel).values)
else:
    time_str = f"time index {tsel}"

info_text = f"t = {time_str}\nν ≈ {freq_mhz:.2f} MHz"
ax.text(
    0.02, 0.98, info_text,
    transform=ax.transAxes,
    va="top",
    ha="left",
    fontsize=FONT - 2,
    color="white",
    bbox={"facecolor": "black", "alpha": 0.6, "edgecolor": "white", "boxstyle": "round,pad=0.3"},
)

plt.tight_layout()
plt.show()
