In [None]:
import logging
from pathlib import Path
import numpy as np
import xarray as xr
import astropy
from astropy.io import fits
import xradio
import math

from ovro_lwa_portal.fits_to_zarr_xradio import convert_fits_dir_to_zarr


# Send INFO-level log records to the console (stderr by default), presumably so
# progress messages from convert_fits_dir_to_zarr are visible — not verifiable here.
logging.basicConfig(level=logging.INFO)


## Conversion

In [None]:
from pathlib import Path
import os, re
from collections import defaultdict

# --- config ---
N_TIME_STEPS = 10
BASE_DIR     = Path("/lustre/nkosogor/Xarray/")
INPUT_DIR    = BASE_DIR                           # where your FITS live
SUBSET_DIR   = BASE_DIR / f"fits_subset_first{N_TIME_STEPS}_times"
SUBSET_DIR.mkdir(parents=True, exist_ok=True)

# Regex for OVRO-LWA filenames of the form
#   YYYYMMDD_HHMMSS_<N>MHz_averaged_<...>-I-image[_fixed].fits
PAT = re.compile(
    r"^(?P<date>\d{8})_(?P<hms>\d{6})_(?P<sb>\d+)MHz_averaged_.*-I-image(?:_fixed)?\.fits$"
)

# 1) Group FITS by time key (YYYYMMDD_HHMMSS)
by_time = defaultdict(list)
for f in sorted(INPUT_DIR.glob("*.fits")):
    m = PAT.match(f.name)
    if not m:
        continue
    tkey = f"{m.group('date')}_{m.group('hms')}"
    by_time[tkey].append(f)

if not by_time:
    raise FileNotFoundError(f"No matching FITS found in {INPUT_DIR}")

# 2) First N time keys (fixed-width YYYYMMDD_HHMMSS keys sort chronologically)
time_keys = sorted(by_time.keys())[:N_TIME_STEPS]
print(f"Selected {len(time_keys)} times:", time_keys[:5], "..." if len(time_keys) > 5 else "")

# 3) Symlink all subbands for those times, so SUBSET_DIR holds exactly this selection
wanted = {src.name: src for tk in time_keys for src in by_time[tk]}

# Remove links left by an earlier run that selected different files; otherwise
# they would be converted too and the Zarr would hold more than N_TIME_STEPS
# times. Only symlinks are removed — real files in SUBSET_DIR are never touched.
stale = 0
for old in SUBSET_DIR.glob("*.fits"):
    if old.is_symlink() and old.name not in wanted:
        old.unlink()
        stale += 1

linked = 0
for name, src in wanted.items():
    dst = SUBSET_DIR / name
    # Path.exists() follows the link and is False for a dangling link, yet the
    # link entry itself would make os.symlink raise FileExistsError: recreate it.
    if dst.is_symlink() and not dst.exists():
        dst.unlink()
    if not dst.exists():
        os.symlink(src, dst)
        linked += 1
print(f"Symlinked {linked} files into {SUBSET_DIR} (removed {stale} stale links)")

# 4) Convert subset to a separate Zarr so you don't overwrite the full run
subset_zarr_name = f"ovro_lwa_first{N_TIME_STEPS}times.zarr"
out_subset = convert_fits_dir_to_zarr(
    input_dir=SUBSET_DIR,
    out_dir=BASE_DIR / "zarr_out",
    zarr_name=subset_zarr_name,
    fixed_dir=BASE_DIR / "fixed_fits",   # reuse OK
    chunk_lm=1024,
    rebuild=True,
)
out_subset


## Summary of the zarr file

In [None]:
from pathlib import Path
BASE_DIR = Path("/lustre/nkosogor/Xarray/")
ZARR_DIR = BASE_DIR / "zarr_out"

# "first10" must match N_TIME_STEPS used in the Conversion cell
TARGET = ZARR_DIR / "ovro_lwa_first10times.zarr"   # or use ZARR_DIR / ZARR_NAME
print("Target exists?", TARGET.exists(), "->", TARGET)
if not TARGET.exists():
    raise FileNotFoundError(f"Zarr store not found: {TARGET} (run the Conversion cell first)")

import xarray as xr
z = xr.open_zarr(str(TARGET), chunks={"l": 512, "m": 512})

# tidy print: clear global attrs and per-variable encodings on the opened Dataset
z.attrs = {}
for v in z.data_vars:
    z[v].encoding = {}
# Dataset.sizes is the dim-name -> length mapping; using dict(z.dims) for this is deprecated
print("Dims:", dict(z.sizes))
print("Vars:", list(z.data_vars))
print("Coords:", list(z.coords))


### Helpers

In [None]:
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

# Default display settings shared by the plotting helpers below.
CMAP = "inferno"   # matplotlib colormap name
VMIN = -1.0        # lower colour limit (colorbar is labelled Jy/beam)
VMAX = 16.0        # upper colour limit
POL = 0            # default polarization index

def nearest_freq_index(freq_values_hz: np.ndarray, target_mhz: float) -> int:
    """Position of the channel whose frequency (given in Hz) is closest to ``target_mhz`` (MHz).

    Ties resolve to the first matching position.
    """
    offsets_hz = np.abs(freq_values_hz - target_mhz * 1e6)
    return int(np.argmin(offsets_hz))

def auto_pick_tf(z: xr.Dataset, pol: int = 0, max_times: int | None = None) -> tuple[int, int]:
    """
    Find the first (time, freq) plane with any finite data for given polarization.

    At the earliest time step that has any finite pixel, the lowest valid
    frequency is picked. The finite-fraction reduction is dask-backed; only the
    small (time, frequency) result is computed.

    Parameters
    ----------
    z : xr.Dataset
        Dataset with a ``SKY`` variable having time, frequency, polarization,
        l and m dimensions, and a ``frequency`` coordinate.
    pol : int
        Polarization index.
    max_times : int or None
        If given, scan only the first ``max_times`` time steps, so the whole
        cube does not have to be reduced.

    Returns
    -------
    tuple[int, int]
        ``(t_idx, f_idx)`` positional indices into the time and frequency dims.

    Raises
    ------
    RuntimeError
        If no scanned plane contains finite data.
    """
    sky = z.SKY.isel(polarization=pol)
    if max_times is not None:
        sky = sky.isel(time=slice(0, max_times))
    finite_frac = np.isfinite(sky).mean(dim=("l", "m"))
    # Pin the axis order so row i really is time step i.
    has_data = finite_frac.transpose("time", "frequency").compute().values > 0
    freqs = z.frequency.values
    for ti, row in enumerate(has_data):
        valid = np.flatnonzero(row)
        if valid.size:
            # Lowest frequency among the valid channels, as a position in the
            # frequency dim (the axis is not assumed to be sorted).
            return ti, int(valid[np.argmin(freqs[valid])])
    scope = "" if max_times is None else f" (scanned first {max_times} times)"
    raise RuntimeError("All SKY planes are NaN; no valid (time, freq) found." + scope)

def lm_cutout(z: xr.Dataset, l_center: float, m_center: float,
              dl: float, dm: float,
              t_idx: int, f_idx: int, pol: int = 0) -> xr.DataArray:
    """Extract an l/m rectangle centered at (l_center, m_center) with half-sizes dl, dm.

    ``.sel`` with a label slice follows the coordinate's stored order, so on a
    descending l or m axis the bounds must be swapped. Otherwise the box
    silently comes back empty.

    Parameters
    ----------
    z : xr.Dataset
        Dataset with a ``SKY`` variable and ``l``/``m`` coordinates.
    l_center, m_center : float
        Box centre in l/m coordinate units.
    dl, dm : float
        Half-widths of the box along l and m.
    t_idx, f_idx, pol : int
        Positional time, frequency and polarization indices.

    Returns
    -------
    xr.DataArray
        The (l, m) sub-plane inside the box (may be empty if the box misses the grid).
    """
    plane = z.SKY.isel(time=t_idx, frequency=f_idx, polarization=pol)

    def _bounds(coord, lo, hi):
        # Descending (non-empty) coordinate -> slice from the high label to the low one.
        if coord.size and not float(coord[0]) <= float(coord[-1]):
            return slice(hi, lo)
        return slice(lo, hi)

    return plane.sel(
        l=_bounds(plane["l"], l_center - dl, l_center + dl),
        m=_bounds(plane["m"], m_center - dm, m_center + dm),
    )

def plot_da(da: xr.DataArray, title: str,
            vmin: float = VMIN, vmax: float = VMAX, cmap: str = CMAP,
            figsize=(6,5), add_colorbar=True):
    """Plot a 2D (l, m) DataArray with l on the x-axis and m on the y-axis.

    Triggers compute only for this slice.

    Parameters
    ----------
    da : xr.DataArray
        2-D image, normally with ``l`` and ``m`` dims.
    title : str
        Figure title.
    vmin, vmax : float
        Fixed colour limits (defaults VMIN/VMAX).
    cmap : str
        Matplotlib colormap name.
    figsize : tuple
        Figure size in inches.
    add_colorbar : bool
        Whether to draw a colorbar labelled "Jy/beam".
    """
    # For 2-D data, xarray defaults to y = first dim and x = second dim. Pin
    # the axes so they agree with the "l"/"m" labels set below.
    xy = {"x": "l", "y": "m"} if {"l", "m"} <= set(da.dims) else {}
    fig, ax = plt.subplots(figsize=figsize)
    im = da.plot.imshow(ax=ax, cmap=cmap, vmin=vmin, vmax=vmax, add_colorbar=False, **xy)
    if add_colorbar:
        fig.colorbar(im, ax=ax, label="Jy/beam")
    ax.set_title(title)
    ax.set_xlabel("l"); ax.set_ylabel("m")
    fig.tight_layout()
    plt.show()


## Example Plots

### Single snapshot (one time step and one subband)

In [None]:
AUTO_PICK = False
TIME_IDX  = 1
FREQ_MHZ  = 46

if AUTO_PICK:
    # First time step with finite data, lowest valid subband (see auto_pick_tf).
    t_idx, f_idx = auto_pick_tf(z, pol=POL)
else:
    if TIME_IDX is None:
        raise ValueError("Set TIME_IDX if not using AUTO_PICK.")
    t_idx = TIME_IDX
    if FREQ_MHZ is not None:
        f_idx = nearest_freq_index(z.frequency.values, FREQ_MHZ)
    else:
        # Lowest subband with any finite pixel at this time. Reduce only the
        # selected time step rather than the whole cube.
        frac = np.isfinite(z.SKY.isel(time=t_idx, polarization=POL)).mean(dim=("l", "m")).compute()
        mask = frac.values > 0
        if not mask.any():
            raise RuntimeError(f"No finite data at time index {t_idx}.")
        valid_idx = np.flatnonzero(mask)
        f_idx = int(valid_idx[np.argmin(z.frequency.values[valid_idx])])

freq_mhz = float(z.frequency.values[f_idx] / 1e6)
try:
    t_val = float(z.time.values[t_idx])
    title_time = f"{t_val:.8f} MJD"
except Exception:
    # Time value not convertible to float: show it as-is.
    title_time = str(z.time.values[t_idx])

sky = z.SKY.isel(time=t_idx, frequency=f_idx, polarization=POL)

# Path for an optional PNG export. plot_da only displays the figure, so nothing
# is written here yet (TODO: save to out_png when running headless).
out_png = str(BASE_DIR / f"snapshot_t{t_idx}_f{freq_mhz:.1f}MHz.png")
plot_da(sky, f"(snapshot) t={title_time}, f={freq_mhz:.1f} MHz")


### LM cutout (small patch)

In [None]:
import numpy as np
import xarray as xr

def _slice_any_order(coord: xr.DataArray, lo: float, hi: float) -> slice:
    """Build a label slice spanning [lo, hi] that works for either coordinate direction.

    ``.sel`` walks a label slice in the coordinate's stored order, so a
    descending coordinate needs its bounds given high-to-low. An empty
    coordinate gets the plain ``slice(lo, hi)``.
    """
    if coord.size == 0:
        return slice(lo, hi)
    first, last = float(coord[0]), float(coord[-1])
    ascending = first <= last
    return slice(lo, hi) if ascending else slice(hi, lo)

def lm_cutout(z: xr.Dataset, l_center: float, m_center: float,
              dl: float, dm: float,
              t_idx: int, f_idx: int, pol: int = 0) -> xr.DataArray:
    """Extract the l/m box centred on (l_center, m_center) with half-widths dl, dm.

    Works with ascending or descending l/m axes. If the label-based box comes
    back empty (e.g. it falls between pixels or outside the coordinate range),
    fall back to a fixed-size pixel window around the nearest pixel.

    Raises
    ------
    RuntimeError
        If the resulting cutout contains no finite values.
    """
    da = z.SKY.isel(time=t_idx, frequency=f_idx, polarization=pol)
    l_slice = _slice_any_order(da["l"], l_center - dl, l_center + dl)
    m_slice = _slice_any_order(da["m"], m_center - dm, m_center + dm)
    sub = da.sel(l=l_slice, m=m_slice)

    # if empty, fall back to nearest pixel box around (l_center, m_center)
    if sub.size == 0 or sub.sizes.get("l", 0) == 0 or sub.sizes.get("m", 0) == 0:
        li = int(np.argmin(np.abs(da["l"].values - l_center)))
        mi = int(np.argmin(np.abs(da["m"].values - m_center)))
        # Half-width in pixels is a tenth of each axis, capped at 200 and at
        # least 1. dl/dm are not used here.
        w = max(1, min(200, da.sizes["l"] // 10))
        h = max(1, min(200, da.sizes["m"] // 10))
        # +1 on the upper bound so the window is symmetric about the centre pixel.
        li0, li1 = max(0, li - w), min(da.sizes["l"], li + w + 1)
        mi0, mi1 = max(0, mi - h), min(da.sizes["m"], mi + h + 1)
        sub = da.isel(l=slice(li0, li1), m=slice(mi0, mi1))

    # Check the whole cutout, not just one corner: NaN padding along an edge
    # must not hide valid pixels elsewhere in the box.
    if not bool(np.isfinite(sub).any().compute()):
        raise RuntimeError("Cutout has no finite data. Try a different (l,m) region or (time,freq).")
    return sub


In [None]:
def plot_da(da: xr.DataArray, title: str,
            vmin: float | None = None, vmax: float | None = None, cmap: str = "inferno",
            figsize=(6,5), add_colorbar=True):
    """Plot a 2D (l, m) DataArray with l on the x-axis and m on the y-axis.

    With neither ``vmin`` nor ``vmax`` given, colour limits come from xarray's
    ``robust=True`` percentile scaling. Otherwise the given limits are passed
    through, and a ``None`` bound is derived from the data by xarray.
    """
    import matplotlib.pyplot as plt
    # For 2-D data, xarray defaults to y = first dim and x = second dim. Pin
    # the axes so they agree with the "l"/"m" labels set below.
    xy = {"x": "l", "y": "m"} if {"l", "m"} <= set(da.dims) else {}
    if vmin is None and vmax is None:
        scale = {"robust": True}
    else:
        scale = {"vmin": vmin, "vmax": vmax}
    fig, ax = plt.subplots(figsize=figsize)
    im = da.plot.imshow(ax=ax, add_colorbar=False, cmap=cmap, **scale, **xy)
    if add_colorbar:
        fig.colorbar(im, ax=ax, label="Jy/beam")
    ax.set_title(title); ax.set_xlabel("l"); ax.set_ylabel("m")
    fig.tight_layout(); plt.show()


In [None]:
# Centre and half-widths of the l/m box to cut out.
L_CENTER = 0.0
M_CENTER = 0.0
DL = 0.10
DM = 0.10

cut = lm_cutout(z, L_CENTER, M_CENTER, DL, DM, t_idx=t_idx, f_idx=f_idx, pol=POL)
l_range = f"l∈[{L_CENTER - DL:+.2f},{L_CENTER + DL:+.2f}]"
m_range = f"m∈[{M_CENTER - DM:+.2f},{M_CENTER + DM:+.2f}]"
plot_da(cut, f"(cutout) t={title_time}, f={freq_mhz:.1f} MHz, {l_range}, {m_range}")


In [None]:

# ---- helpers ----
def _slice_any_order(coord: xr.DataArray, lo: float, hi: float) -> slice:
    """Order-aware label slice: (lo, hi) for ascending or empty coords, (hi, lo) for descending ones."""
    if coord.size and not float(coord[0]) <= float(coord[-1]):
        return slice(hi, lo)
    return slice(lo, hi)

def _to_float(val):
    try:
        return float(val)
    except Exception:
        return val

# ---- build cutout cube once (time, frequency, l, m) ----
l_slice = _slice_any_order(z["l"], L_CENTER - DL, L_CENTER + DL)
m_slice = _slice_any_order(z["m"], M_CENTER - DM, M_CENTER + DM)
cube = z.SKY.isel(polarization=POL).sel(l=l_slice, m=m_slice)

T = cube.sizes["time"]
F = cube.sizes["frequency"]
N = T * F
if N == 0:
    # plt.subplots cannot build a grid with zero rows
    raise ValueError("Cutout cube has no (time, frequency) planes to plot.")
tvals = cube.time.values
fvals_mhz = cube.frequency.values / 1e6

# quick availability map, indexed [time, frequency]: True where the cutout has any finite pixel
finite_map = (
    (np.isfinite(cube).mean(dim=("l", "m")) > 0)
    .transpose("time", "frequency")
    .compute()
    .values
)

# ---- manual grid: 4 columns per row ----
COLS = 4
ROWS = math.ceil(N / COLS)

# choose a sane figure size per panel
panel_w, panel_h = 3.0, 2.6
# squeeze=False keeps `axes` a 2-D (ROWS, COLS) array for every grid shape
fig, axes = plt.subplots(ROWS, COLS, figsize=(panel_w*COLS, panel_h*ROWS), squeeze=False)

# Panels run row-major over (time, frequency): panel k shows time k // F, frequency k % F.
for k, ax in enumerate(axes.flat):
    if k >= N:
        ax.axis("off")  # hide leftover empty panels
        continue
    ti, fi = divmod(k, F)

    # Only apply the numeric :.6f format when the time value really converted to float;
    # _to_float hands non-convertible values back unchanged.
    t_num = _to_float(tvals[ti])
    t_txt = f"{t_num:.6f} MJD" if isinstance(t_num, float) else str(tvals[ti])
    title = f"t={t_txt}\nf={float(fvals_mhz[fi]):.2f} MHz"

    if finite_map[ti, fi]:
        # Small cutout: load it. Transposed to (m, l) so l runs along x and m
        # along y, matching the l/m axis labels used for the single-panel plots.
        arr = cube.isel(time=ti, frequency=fi).transpose("m", "l").values
        ax.imshow(arr, origin="lower", vmin=VMIN, vmax=VMAX, cmap=CMAP, aspect="auto")
    else:
        ax.text(0.5, 0.5, "NaN", ha="center", va="center", fontsize=10, transform=ax.transAxes)
    ax.set_title(title, fontsize=9)
    ax.set_xticks([]); ax.set_yticks([])

plt.tight_layout()
plt.show()


In [None]:


def _nearest_index(coord: xr.DataArray, x: float) -> int:
    """Position of the ``coord`` value closest to ``x``; NaN entries are skipped and coord order is irrelevant."""
    distance = np.abs(coord.values - x)
    return int(np.nanargmin(distance))

def dynamic_spectrum_center_pixel(z: xr.Dataset, l0: float, m0: float, pol: int = 0) -> xr.DataArray:
    """
    Return a (time, frequency) dynamic spectrum for the **center pixel**
    nearest to (l0, m0). Uses index selection to ensure we grab the exact pixel.

    Time and frequency are sorted ascending. Both stay dimensions even when one
    of them has length 1, so the result is always 2-D for plotting.

    Parameters
    ----------
    z : xr.Dataset
        Dataset with a ``SKY`` variable and ``l``/``m`` coordinates.
    l0, m0 : float
        Target position in l/m coordinate units.
    pol : int
        Polarization index.

    Returns
    -------
    xr.DataArray
        Spectrum with ``attrs`` ``pixel_l``/``pixel_m`` (coordinate values of the
        chosen pixel), ``l_idx``/``m_idx`` (its indices) and ``pol_idx``.
    """
    # Find nearest pixel indices globally
    li = _nearest_index(z["l"], l0)
    mi = _nearest_index(z["m"], m0)

    # Slice exact pixel across all time × freq for given polarization
    spec = z.SKY.isel(l=li, m=mi, polarization=pol)

    # Canonical sort for plotting
    if "time" in spec.coords:
        spec = spec.sortby("time")
    if "frequency" in spec.coords:
        spec = spec.sortby("frequency")

    # Squeeze other length-1 dims, but never time/frequency. A plain squeeze()
    # would turn a single-time or single-channel dataset into 1-D data that the
    # x="time"/y="frequency_mhz" image plot cannot draw.
    extra = [d for d in spec.dims if d not in ("time", "frequency") and spec.sizes[d] == 1]
    if extra:
        spec = spec.squeeze(dim=extra, drop=True)
    # The integer isel() above leaves l, m and polarization as scalar coords;
    # drop all 0-d coords so the result is a clean 2-D array.
    spec = spec.drop_vars([name for name, c in spec.coords.items() if c.ndim == 0])

    # Attach convenience attrs for reporting
    spec.attrs["pixel_l"] = float(z["l"].values[li])
    spec.attrs["pixel_m"] = float(z["m"].values[mi])
    spec.attrs["pol_idx"] = pol
    spec.attrs["l_idx"] = li
    spec.attrs["m_idx"] = mi
    return spec

def plot_dynspec(ds2d: xr.DataArray, title: str,
                 figsize=(7,4), cmap="inferno", vmin=None, vmax=None, add_colorbar=True):
    """Show a (time, frequency) dynamic spectrum as an image, with time on x and frequency (MHz) on y.

    When neither vmin nor vmax is given, xarray's robust percentile scaling is
    used; otherwise the supplied limits are forwarded unchanged.
    """
    autoscale = vmin is None and vmax is None
    limits = {} if autoscale else {"vmin": vmin, "vmax": vmax}

    spec = ds2d
    if "frequency" in spec.coords:
        # Relabel the frequency dimension in MHz so the y-axis reads naturally.
        spec = spec.assign_coords(frequency_mhz=(spec.frequency / 1e6))
        spec = spec.swap_dims({"frequency": "frequency_mhz"})

    plt.figure(figsize=figsize)
    im = spec.plot.imshow(x="time", y="frequency_mhz", cmap=cmap,
                          add_colorbar=False, robust=autoscale, **limits)
    ax = im.axes
    ax.set_aspect("auto")  # set on the axes rather than through an xarray kwarg
    if add_colorbar:
        plt.colorbar(im, label="Jy/beam", ax=ax)
    ax.set_ylabel("frequency (MHz)")
    ax.set_title(title)
    plt.tight_layout()
    plt.show()


In [None]:
center_l = L_CENTER
center_m = M_CENTER
dyn = dynamic_spectrum_center_pixel(z, center_l, center_m, pol=POL)

meta = dyn.attrs
print(f"center pixel -> l={meta['pixel_l']:+.6f}, m={meta['pixel_m']:+.6f} "
      f"(indices l={meta['l_idx']}, m={meta['m_idx']}), pol={meta['pol_idx']}")

# Add vmin=/vmax= (e.g. vmin=-1.0, vmax=16.0) for fixed scaling; leave them out for robust autoscale.
plot_dynspec(dyn, title=f"(dynspec @ pixel near l={center_l:+.3f}, m={center_m:+.3f})")


### Quick difference map (adjacent subbands or adjacent time steps)

In [None]:
MODE = "freq"   # "freq" to diff adjacent freqs at same time; "time" to diff adjacent times


def _adjacent_pair(idx: int, size: int, axis: str) -> tuple[int, int]:
    """Return a (lower, upper) pair of distinct neighbouring indices around ``idx``.

    Normally (idx - 1, idx). At idx == 0 the pair is (0, 1), so the map is never
    a plane minus itself, which would be identically zero.

    Raises
    ------
    ValueError
        If the axis has fewer than two entries.
    """
    if size < 2:
        raise ValueError(f"Need at least two {axis} planes for a difference map, found {size}.")
    return (idx - 1, idx) if idx > 0 else (idx, idx + 1)


if MODE == "freq":
    f1, f2 = _adjacent_pair(f_idx, z.sizes["frequency"], "frequency")
    a = z.SKY.isel(time=t_idx, frequency=f1, polarization=POL)
    b = z.SKY.isel(time=t_idx, frequency=f2, polarization=POL)
    f1_mhz = float(z.frequency.values[f1]/1e6)
    f2_mhz = float(z.frequency.values[f2]/1e6)
    title = f"(diff) t={title_time}, f={f2_mhz:.1f}-{f1_mhz:.1f} MHz"
elif MODE == "time":
    t1, t2 = _adjacent_pair(t_idx, z.sizes["time"], "time")
    a = z.SKY.isel(time=t1, frequency=f_idx, polarization=POL)
    b = z.SKY.isel(time=t2, frequency=f_idx, polarization=POL)
    title = f"(diff) time Δ at f={freq_mhz:.1f} MHz (t{t2}-t{t1})"
else:
    raise ValueError(f"MODE must be 'freq' or 'time', got {MODE!r}")

diff = b - a
plot_da(diff, title, vmin=-2.0, vmax=2.0)
