In [None]:
def _add_shape_pixel(ax, segments, tr, alpha=0.8, color="k"):
    """Add shapefile to pixel-coordinate panels"""
    if segments:
        segments_pixel = []
        for seg in segments:
            if len(seg) > 0:
                pixel_coords = []
                for x_geo, y_geo in seg:
                    col, row = ~tr * (x_geo, y_geo)
                    pixel_coords.append([col, row])
                segments_pixel.append(np.array(pixel_coords))
        ax.add_collection(LineCollection(segments_pixel, colors=color, linewidths=0.6, zorder=6, alpha=alpha))
        


        
def _fmt_stat(value, spec):
    """Format an optional summary statistic, rendering ``None`` as ``"N/A"``."""
    return "N/A" if value is None else format(value, spec)


def _denormalize_pandora_values(normalizer, values):
    """Bring Pandora map values into the units of the denormalized prediction.

    Prefers ``normalizer.denormalize_pandora`` (applied per value), then
    ``normalizer.denormalize_pandora_array`` (applied to the whole array);
    when the normalizer has neither hook the values are returned unchanged.
    Always returns an ndarray.
    """
    values = np.asarray(values)
    if hasattr(normalizer, "denormalize_pandora"):
        return np.array([normalizer.denormalize_pandora(v) for v in values])
    if hasattr(normalizer, "denormalize_pandora_array"):
        return np.asarray(normalizer.denormalize_pandora_array(values))
    return values


def _pandora_point_metrics(p_mask, p_val_map, pred, normalizer):
    """RMSE and Spearman rho between Pandora pixels and the prediction.

    Only pixels where both the (denormalized) Pandora value and the prediction
    are finite are used. Returns ``(rmse, rho)``; ``rmse`` is ``None`` when no
    such pixel exists and ``rho`` is ``None`` unless at least two exist.
    """
    if not p_mask.any():
        return None, None
    rows, cols = np.where(p_mask)  # assumes p_mask is 2-D (H, W)
    obs = _denormalize_pandora_values(normalizer, p_val_map[rows, cols])
    est = pred[rows, cols]
    ok = np.isfinite(obs) & np.isfinite(est)
    if not ok.any():
        return None, None
    obs, est = obs[ok], est[ok]
    rmse = np.sqrt(np.mean((obs - est) ** 2))
    rho = None
    if obs.size > 1:
        from scipy import stats
        rho, _ = stats.spearmanr(obs, est)
    return rmse, rho


def _pandora_station_markers(pandora_df, ts, tr, crs, H, W, p_val_map, normalizer,
                             pred, vmin, vmax, cmap):
    """Compute one marker per Pandora station observed within +/-30 min of *ts*.

    Each station keeps its record closest in time, is projected to pixel
    (row, col) and dropped if it falls outside the H x W raster. Its fill colour
    is the (denormalized) value of the nearest non-zero, finite pixel in
    *p_val_map*. The label carries the absolute error against *pred* at the
    station pixel, reported as "RMSE".

    Returns a list of ``(x, y, fill_color, outline_color, legend_label)`` tuples
    in map coordinates; empty when nothing can be placed.
    """
    if pandora_df is None:
        return []

    window = pd.Timedelta("30min")
    dfw = pandora_df[
        (pandora_df["datetime"] >= ts - window)
        & (pandora_df["datetime"] <= ts + window)
    ].copy()
    if dfw.empty or "station" not in dfw.columns:
        return []

    # Closest record in time per station.
    dfw["abs_dt"] = (dfw["datetime"] - ts).abs()
    dfw = dfw.sort_values(["station", "abs_dt"]).groupby("station", as_index=False).first()

    # Colours are assigned over *all* stations so they stay stable across plots.
    # Keys are str because the labels looked up below are str (ids may be numeric).
    stations_all = pandora_df["station"].unique()
    color_map = dict(zip(map(str, stations_all), cm.tab20c(np.linspace(0, 1, len(stations_all)))))

    lons = _wrap_lon_180(pd.to_numeric(dfw["lon"], errors="coerce").to_numpy())
    lats = pd.to_numeric(dfw["lat"], errors="coerce").to_numpy()
    ok_ll = np.isfinite(lons) & np.isfinite(lats) & (lats >= -90) & (lats <= 90)
    if not ok_ll.any():
        return []

    rr, cc = _lonlat_to_rowcol_vec(lons[ok_ll], lats[ok_ll], tr, crs)
    labels = dfw.loc[ok_ll, "station"].astype(str).to_numpy()

    # floor (not truncation toward zero) so indices in (-1, 0) count as outside.
    rr_i = np.floor(np.asarray(rr, dtype=float)).astype(int)
    cc_i = np.floor(np.asarray(cc, dtype=float)).astype(int)
    ok_in = (rr_i >= 0) & (rr_i < H) & (cc_i >= 0) & (cc_i < W)
    rr_i, cc_i, labels = rr_i[ok_in], cc_i[ok_in], labels[ok_in]
    if rr_i.size == 0:
        return []

    xs, ys = _rowcol_to_xy_vec(rr_i, cc_i, tr)

    # KD-tree over Pandora pixels; zeros are treated as empty pixels.
    p_val_arr = np.array(p_val_map, dtype=np.float64)
    p_val_arr[p_val_arr == 0.0] = np.nan
    valid = np.isfinite(p_val_arr)
    if not valid.any():
        return []
    from scipy.spatial import cKDTree
    tree = cKDTree(np.argwhere(valid))
    # Denormalize once so colours and errors use the same units as `pred`/vmin/vmax.
    valid_values = _denormalize_pandora_values(normalizer, p_val_arr[valid])

    norm = plt.Normalize(vmin=vmin, vmax=vmax)
    markers = []
    for x, y, lab, r, c in zip(xs, ys, labels, rr_i, cc_i):
        _, idx = tree.query([r, c])
        pandora_val = valid_values[idx]
        finite = np.isfinite(pandora_val)
        fill_color = cmap(norm(pandora_val)) if finite else "black"

        rmse_str = ""
        if finite:
            pred_val = pred[r, c]
            if np.isfinite(pred_val):
                rmse = np.sqrt((pandora_val - pred_val) ** 2)
                rmse_str = f" (RMSE={rmse:.2E})"

        markers.append((x, y, fill_color, color_map.get(lab, "red"), f"{lab}\n{rmse_str}"))
    return markers


def _draw_station_markers(ax, markers):
    """Scatter precomputed station *markers* (see ``_pandora_station_markers``) on *ax*."""
    for x, y, fill_color, outline_color, _ in markers:
        ax.scatter(
            x, y, s=100, marker='D',
            facecolor=fill_color,
            edgecolor=outline_color, linewidth=1.0,
            zorder=6
        )


def _station_legend_handles(markers):
    """Build one legend proxy artist per station marker."""
    return [
        Line2D(
            [0], [0], marker="D", color="none",
            markerfacecolor=fill_color, markeredgecolor=outline_color, markeredgewidth=1,
            markersize=9, label=label
        )
        for _, _, fill_color, outline_color, label in markers
    ]


def visualize_batch(epoch, model, normalizer, dataloader, batch_idx=0, sample_idx=0,
                    device="cuda", save=False, train=True, shp_path=shp_path, avg_thr=0.2,
                    pandora_df=pandora_df):
    """
    Visualize model predictions from a DataLoader following the dataset's pandora setup,
    with Pandora station plotting (consistent colors + per-station RMSE).

    The figure has three panels: the observed input, the lightly smoothed
    reconstruction, and the reconstruction restricted to unobserved pixels.
    Each panel has shapefile outlines, and the first two also show Pandora
    station markers. The title reports RMSE / Spearman rho between Pandora
    pixels and the reconstruction ("N/A" when they cannot be computed).

    Parameters
    ----------
    epoch : unused (kept for interface compatibility).
    model : called as ``model(img, mask)`` and expected to return ``(pred, mask)``.
    normalizer : must provide ``denormalize_image``; ``denormalize_pandora`` or
        ``denormalize_pandora_array`` are applied to Pandora values when present.
    dataloader : iterable of dict batches with at least ``"masked_img"``,
        ``"known_mask"``, ``"target"`` and ``"path"``. ``"p_mask"`` and
        ``"p_val_mask"`` are optional; ``"known_and_fake_mask"`` is required
        when ``train`` is True.
    batch_idx : index of the batch to plot (the loader is iterated from the start).
    sample_idx : sample within the batch; falls back to 0 if out of range.
    device : device the inputs are moved to for the forward pass.
    save : NOTE(review): currently unused -- the figure is shown, never saved.
    train : if True, feed ``"known_and_fake_mask"`` to the model instead of
        ``"known_mask"``.
    shp_path : shapefile whose outlines are overlaid on every panel.
    avg_thr : unused.
    pandora_df : station table with ``"datetime"``, ``"station"``, ``"lat"``
        and ``"lon"`` columns; ``None`` disables the station markers.

    Returns
    -------
    None. Problems fetching the batch are printed and the function returns early.
    """
    model.eval()

    # --- Get a batch ---
    try:
        for i, batch in enumerate(dataloader):
            if i == batch_idx:
                break
        else:
            print(f"Batch index {batch_idx} not found in dataloader")
            return
    except Exception as e:
        print(f"Error getting batch from dataloader: {e}")
        return

    # --- Extract sample ---
    batch_size = batch["masked_img"].shape[0]
    if sample_idx >= batch_size:
        print(f"Sample index {sample_idx} not available in batch of size {batch_size}")
        sample_idx = 0
        print(f"Using sample index {sample_idx} instead")

    sample = {
        key: value[sample_idx] if isinstance(value, (torch.Tensor, list, tuple)) else value
        for key, value in batch.items()
    }

    # --- Pandora data ---
    p_mask = sample.get('p_mask', torch.zeros_like(sample["known_mask"])).numpy().astype(bool)
    p_val_map = sample.get('p_val_mask', torch.zeros_like(sample["known_mask"])).numpy()

    # --- Inputs ---
    img = sample["masked_img"].unsqueeze(0).to(device)
    mask_obs = sample["known_mask"].unsqueeze(0).to(device)
    target = sample["target"].unsqueeze(0).to(device)
    mask = sample["known_and_fake_mask"].unsqueeze(0).to(device) if train else mask_obs

    # --- Prediction ---
    with torch.no_grad():
        pred, _ = model(img, mask)

    inp_np = normalizer.denormalize_image(img[0, 0].cpu().numpy())
    mask_obs_np = mask_obs[0, 0].cpu().numpy().astype(bool)
    pred_np = normalizer.denormalize_image(pred[0, 0].cpu().numpy())
    tgt_np = normalizer.denormalize_image(target[0, 0].cpu().numpy())

    from scipy.ndimage import gaussian_filter
    # Light smoothing; the smoothed field is both displayed and scored against
    # Pandora. NOTE(review): no input/prediction compositing is done -- observed
    # pixels also show the model output.
    pred_np_final = gaussian_filter(pred_np, sigma=0.5)

    # --- Metadata (file stem is the timestamp, e.g. 20240101120000.tif) ---
    path = sample["path"]
    ts = datetime.strptime(path.split('/')[-1].split('.')[0], "%Y%m%d%H%M%S")
    date = ts.strftime("%Y-%m-%d %H:%M:%S")

    with rasterio.open(path) as src:
        tr = src.transform
        crs = src.crs
        H, W = src.height, src.width

    xmin, ymin, xmax, ymax = array_bounds(H, W, tr)
    segments = load_shapefile_segments_pyshp(shp_path, crs)

    pandora_rmse, pandora_rho = _pandora_point_metrics(p_mask, p_val_map, pred_np_final, normalizer)

    # --- Colormap (2nd-98th percentile of the target) ---
    finite_vals = tgt_np[np.isfinite(tgt_np)]
    if finite_vals.size:
        vmin, vmax = np.percentile(finite_vals, [2, 98])
    else:
        vmin, vmax = 0.0, 1.0

    cmap_v = plt.cm.viridis.copy()
    cmap_v.set_bad(color="white")

    # Station markers are identical on both geo panels: compute them once.
    markers = _pandora_station_markers(
        pandora_df, ts, tr, crs, H, W, p_val_map, normalizer,
        pred_np_final, vmin, vmax, cmap_v,
    )

    fig, ax = plt.subplots(1, 3, figsize=(14, 6))

    geo_panels = (
        (ax[0], np.ma.masked_invalid(np.ma.array(inp_np, mask=~mask_obs_np)), "Input (N/A = white)"),
        (ax[1], np.ma.masked_invalid(pred_np_final), "Reconstruction"),
    )
    for panel, data, title in geo_panels:
        rio_show(data, transform=tr, ax=panel, cmap=cmap_v, vmin=vmin, vmax=vmax)
        panel.set_xlim(xmin, xmax)
        panel.set_ylim(ymin, ymax)
        panel.set_aspect('equal', adjustable='box')
        panel.margins(0)
        panel.autoscale(False)
        panel.set_title(title)
        panel.axis("off")
        panel.add_collection(LineCollection(segments, colors='k', linewidths=0.5, zorder=3))
        _draw_station_markers(panel, markers)

    # Third panel is in pixel coordinates (plain imshow), hence _add_shape_pixel.
    filled_only = np.full_like(pred_np_final, np.nan, dtype=np.float32)
    filled_only[~mask_obs_np] = pred_np_final[~mask_obs_np]
    ax[2].imshow(np.ma.array(filled_only, mask=np.isnan(filled_only)), cmap=cmap_v, vmin=vmin, vmax=vmax)
    ax[2].set_title("Filled Values in Holes")
    ax[2].axis("off")
    _add_shape_pixel(ax[2], segments, tr)

    legend_handles = _station_legend_handles(markers)
    if legend_handles:
        ax[0].legend(
            handles=legend_handles,
            bbox_to_anchor=(-0.85, 1),
            loc="upper left",
            frameon=True, fontsize=13, markerscale=1.3
        )

    # rmse/rho are None when there are no (or too few) valid Pandora pixels.
    rmse_txt = _fmt_stat(pandora_rmse, ".4E")
    rho_txt = _fmt_stat(pandora_rho, ".2f")
    fig.suptitle(f"{date}\nStation RMSE: {rmse_txt} | Station ρ: {rho_txt}", fontsize=18)
    # Lay out the subplots before adding the manually placed colorbar axes,
    # which tight_layout cannot handle (it would only warn and skip it).
    fig.tight_layout()

    cbar_ax = fig.add_axes([0.29, 0.05, 0.6, 0.04])  # [left, bottom, width, height] in figure coords
    cbar = fig.colorbar(ax[1].get_images()[0], cax=cbar_ax, orientation="horizontal")
    cbar.set_label("NO₂ (molec·cm$^{-2}$)", fontsize=15)

    plt.show()
    plt.close(fig)
    
val_ds = TempoPandoraInpaintDataset(
    tif_dir=tif_dir,
    normalizer=normalizer,
    train=False,
    file_list=val_files,
    pandora_csv=pandora_df,
)

val_loader = DataLoader(
    val_ds,
    batch_size=4,
    shuffle=False,
    collate_fn=custom_collate_fn,
)

model = OriginalPlusMinimalAttentionDeep(base_ch=32).cuda()
# Load on CPU: load_state_dict copies the tensors into the model's CUDA
# parameters, so this works no matter which device the checkpoint was saved
# from and avoids a second GPU copy of the weights.
model.load_state_dict(torch.load('/hpc/home/srs108/final_model.pt', map_location="cpu"))

# Plot sample 3 of the first five validation batches.
# NOTE(review): `save=True` has no effect inside visualize_batch as written.
for i in range(5):
    visualize_batch(epoch=5, model=model, normalizer=normalizer, dataloader=val_loader,
                    batch_idx=i, sample_idx=3, device="cuda", save=True, train=False,
                    shp_path=shp_path)
    