In [1]:
# Use the interactive ipympl backend so the optional plots further down render as widgets.
%matplotlib widget

In [132]:
import glob
import numpy as np
from image_plane_correction.flow import Flow
from image_plane_correction.catalogs import theoretical_sky
from image_plane_correction.preprocessing import normalize, normalize_high, preprocess
from image_plane_correction.postprocessing import rotational_translational_component, integrate_2d_fourier, gradient
from image_plane_correction.util import indices, rescale_quantile, gaussian_filter, circular_mask
from image_plane_correction import data
from image_plane_correction.interactive import plot_flow, plot_image, toggle_images

In [133]:
# Master switch for the interactive image/flow comparison cells: every plotting cell
# below is guarded by `if visualize:`, so False skips all plotting.
visualize = False

## Load image and PSF

In [147]:
# Input location for this run. The run directory and frequency band are set once here
# instead of being repeated inside each glob pattern; edit these to process another run/band.
RUN_DIR = '/lustre/gh/main/10h/2024-12-18/Run_Rescue_20260205_060823'
BAND = '55MHz'
SNAPSHOT_DIR = f'{RUN_DIR}/{BAND}/snapshots'

# glob returns files in arbitrary order, so sort both lists. Later cells pair
# image_fns[i] with psf_fns[i], which assumes the image and PSF names sort identically
# (i.e. differ only in their suffix) -- NOTE(review): confirm for this naming scheme.
image_fns = sorted(glob.glob(f'{SNAPSHOT_DIR}/{BAND}*I-image.fits'))
psf_fns = sorted(glob.glob(f'{SNAPSHOT_DIR}/{BAND}*-psf.fits'))

In [149]:
# Images and PSFs are paired by position in the two sorted lists, so there must be at
# least one image and exactly one PSF per image. An empty glob would otherwise only
# surface later as a bare IndexError when loading.
assert image_fns, "No '*I-image.fits' snapshots matched the image glob; check the input path"
assert len(image_fns) == len(psf_fns), (
    f"Mismatched inputs: {len(image_fns)} images vs {len(psf_fns)} PSFs"
)

In [135]:
# Choose one snapshot by its position in the sorted file lists, then read the
# image and its matching PSF (each together with its WCS).
ind = 0
image_fn = image_fns[ind]
image, imwcs = data.fits_image(image_fn)
psf_fn = psf_fns[ind]
psf, psfwcs = data.fits_image(psf_fn)

## Calculate model sky, preprocess, and run flow

In [136]:
# Tuning values for this run, named instead of inlined (values unchanged).
# NOTE(review): meanings are inferred from the keyword names only; alpha/gamma/scale_factor
# presumably follow the Brox et al. optical-flow parameterisation -- confirm against
# image_plane_correction before retuning.
VLSSR_CATALOG_PATH = '/home/claw/vlssr_radecpeak.txt'  # personal path; make configurable if shared
SKY_MAX_FLUX = 10           # forwarded as theoretical_sky(max_flux=...)
PREPROCESS_WEIGHT = 1.5     # forwarded as preprocess(weight=...)
BROX_KWARGS = dict(alpha=1.3, gamma=150, scale_factor=0.7)

# Predict the sky from the VLSSR catalog on the image's WCS using the PSF, bring the
# observed and model images into a comparable form, and fit the flow between them.
sky = theoretical_sky(
    imwcs,
    psf,
    catalog="VLSSR",
    max_flux=SKY_MAX_FLUX,
    path=VLSSR_CATALOG_PATH,
)
image_processed, sky_processed = preprocess(image, sky, weight=PREPROCESS_WEIGHT)
flow = Flow.brox(image_processed, sky_processed, **BROX_KWARGS)

In [137]:
# Optional: flip between the observed image and the model sky, both passed through normalize_high.
if visualize:
    toggle_images(
        normalize_high(image),
        normalize_high(sky),
    )

In [138]:
# Optional: draw the fitted flow field, restricted by circular_mask(r=0.65)
# (radius units are whatever circular_mask uses -- presumably a fraction of the image; confirm).
if visualize:
    plot_flow(
        flow,
        mask=circular_mask(r=0.65),
    )

## Evaluate dewarping quality

In [139]:
# Warp the observed (unprocessed) image with the fitted flow; the residuals of this
# result against `sky` are compared with those of the original image below.
dewarped = flow.apply(image)

In [140]:
# Optional: flip between the image before and after dewarping, both passed through normalize_high.
if visualize:
    toggle_images(
        normalize_high(image),
        normalize_high(dewarped),
    )

In [141]:
# Sanity check: NaN offsets are zeroed here only to ask "is there any shift at all?".
offsets = np.nan_to_num(flow.offsets)
if not offsets.any():
    print("Warning: All offsets zero")

# Per-pixel shift magnitude in pixels. The norm over axis 2 assumes offsets are laid out
# as (rows, cols, components) -- TODO confirm against Flow. Non-finite pixels are left
# out of the statistics instead of being counted as zero shift, which biased them low.
shift_mag = np.linalg.norm(flow.offsets, axis=2)
finite_shift = shift_mag[np.isfinite(shift_mag)]
n_excluded = shift_mag.size - finite_shift.size

if finite_shift.size == 0:
    print("Warning: No finite offsets; skipping shift statistics")
else:
    shift_mean = np.mean(finite_shift)
    shift_5, shift_median, shift_95 = np.percentile(finite_shift, [5, 50, 95])
    print(f"Shift magnitude mean {shift_mean:.1f} pix (5, 50, 95 percentiles: {shift_5:.1f}, {shift_median:.1f}, {shift_95:.1f} pix)")
    if n_excluded:
        print(f"({n_excluded} non-finite offset pixels excluded from the statistics)")

Shift magnitude mean 2.9 pix (5, 50, 95 percentiles: 0.7, 2.0, 6.4 pix)


In [142]:
pcts = [5, 32, 50, 68, 95]

# Residuals against the model sky before and after dewarping. np.percentile propagates
# NaN, so a single non-finite pixel (possible from the warp or a blanked FITS region --
# not confirmed) would make every residual NaN and the assertion below fail with an
# uninformative message. Compare both on the same set of finite pixels instead;
# if everything is finite this is identical to using all pixels.
resid_after = dewarped - sky
resid_before = image - sky
residual_mask = np.isfinite(resid_after) & np.isfinite(resid_before)
assert residual_mask.any(), "No finite pixels in both residual maps"

# Negative entries mean |residual percentile| shrank after dewarping.
residuals = (
    np.abs(np.percentile(resid_after[residual_mask], pcts))
    - np.abs(np.percentile(resid_before[residual_mask], pcts))
)

In [143]:
# Dewarping must shrink |residual| at every tracked percentile; on failure, list each
# (percentile, change) pair so the offending percentiles are visible.
assert np.all(residuals < 0), f"Not all residuals reduced. (Percentile, residual difference): {list(zip(pcts, residuals.tolist()))}"