In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

import typing as t

from matplotlib import pyplot
import matplotlib.patches as patches

import numpy
from numpy.typing import NDArray, ArrayLike

import cupy

## Overview

Steps in an LSQ-ML reconstruction:

 - Setup
   - Make incident probe(s)
   - Generate scan positions
   - Generate object, large enough to contain $[1/d, 1/d]$ region around every scan position
   - Load and process measured patterns
 - Running
   - Pre-Calculate probe intensity (sum of probe intensity in object space)
     - High illumination -> good estimate of object wavefunction
   - Pre-Calculate object intensity (sum of object intensity in probe space)
     - High object magnitude -> good signal of probe wavefunction
     - important for opaque regions/obstructions

   - Per iteration
     - Split probe positions into groupings

     - Per grouping
       - Add to probe & object intensity

       - Apply forward model to detector plane
       - Apply modulus constraint at detector plane
       - Calculate exit wavefunction update chi
       - Calculate object and probe update directions, per probe position
       - Average object and probe update directions
       - Calculate probe and object update step size, per probe position
       - Apply probe and object update weighted by object and probe intensity

     - Update probe & object intensity

## Conventions

- Reciprocal space: Zero-frequency in corner
- Real space: (0, 0) in center
- Energy is normalized in both real and reciprocal space:
  - $ \sum \left|f(x, y)\right|^2 = 1 $
  - $ \sum \left|F(k_x, k_y)\right|^2 = N_x N_y $
- Probe and object are stored in real space

## Implications

- FFTs need 'forward' normalization
- Reciprocal space to real space: `f = fftshift(ifft2(F, norm='forward'), axes=(-2, -1))`
- Real space to reciprocal space: `F = fft2(ifftshift(f, axes=(-2, -1)), norm='forward')`
- Normalize energy in real space: `abs2(f) / numpy.prod(f.shape[-2:])`
- Normalize amplitude in real space: `f / numpy.sqrt(numpy.prod(f.shape[-2:]))`
- Never need to fftshift in reciprocal space

In [2]:
from pathlib import Path
from ptycho_lebeau.metadata import AnyMetadata

# Location of the dataset's JSON metadata; "~" is expanded to the current user's home.
# (alternative local copy: "/Users/colin/Downloads/mos2/1/mos2/mos2_0.00_dstep1.0.json")
meta_path = Path("~/mos2/1/mos2/mos2_0.00_dstep1.0.json").expanduser()
meta = AnyMetadata.parse_file(meta_path)
# dump the parsed metadata for inspection
print(meta.json(indent='    '))

In [3]:
from phaser.io.empad import load_4d

# Both the dataset directory and the raw file name must be present in the metadata.
assert meta.path is not None
assert meta.raw_filename is not None
raw = load_4d(meta.path / meta.raw_filename)

# Shift each pattern so the zero frequency sits in the corner, then flatten all
# leading (scan) axes into a single pattern axis: (N, ky, kx).
pattern_shape = raw.shape[-2:]
raw = numpy.fft.fftshift(raw, axes=(-1, -2)).reshape((-1, *pattern_shape))
# Normalize every pattern to unit total intensity (in place).
raw /= numpy.sum(raw, axis=(-2, -1), keepdims=True)

# Simulated dose and the resulting number of electrons per pattern.
dose = 100000  # e/A^2
# area covered by one scan step; the 1e10 factor presumably converts m -> angstrom (confirm)
px_area = numpy.prod(numpy.array(meta.scan_step) * 1e10)
dose_per_pattern = dose * px_area
# e * 1.602e-19 C/e / 1e-3 s * 1e12 pA/A -> pA
print(f"Equiv. beam current: {(dose_per_pattern * 1.6021766e-07) / 1e-3:.3f} pA")

# Poisson-noised copy of the normalized patterns, rescaled back to unit dose.
# NOTE(review): the draw is unseeded, and the reconstruction cell assigns
# `exp_data = raw`, so `noisy` is not what gets reconstructed -- confirm intent.
noisy = numpy.random.poisson(dose_per_pattern * raw).astype(numpy.float32) / dose_per_pattern

## Various helper functions

In [4]:
from phaser.utils.num import fft2, ifft2, abs2

def calc_groupings(grouping: int = 8, seed: t.Any = None,
                   n_positions: t.Optional[int] = None) -> list[NDArray[numpy.int_]]:
    """
    Randomly partition scan-position indices into groups of roughly `grouping` positions.

    The indices `0 .. n_positions - 1` are shuffled and split into
    `ceil(n_positions / grouping)` nearly equal chunks, so every index appears in
    exactly one group and no group holds more than `grouping` indices.

    Parameters
    ----------
    grouping
        Target (maximum) number of positions per group. Must be at least 1.
    seed
        Seed passed to `numpy.random.RandomState`. The same seed and position
        count always give the same groups.
    n_positions
        Number of positions to partition. Defaults to the number of loaded
        patterns, `raw.shape[0]`.

    Returns
    -------
    List of 1D integer index arrays. Empty when there are no positions.

    Raises
    ------
    ValueError
        If `grouping` is smaller than 1.
    """
    if grouping < 1:
        raise ValueError(f"'grouping' must be at least 1, got {grouping}")

    if n_positions is None:
        # default: one index per measured pattern
        n_positions = raw.shape[0]
    if n_positions == 0:
        # array_split() rejects zero sections, so handle the empty scan explicitly
        return []

    rng = numpy.random.RandomState(seed)
    idxs = numpy.arange(n_positions)
    rng.shuffle(idxs)

    n_groups = int(numpy.ceil(n_positions / grouping))
    return numpy.array_split(idxs, n_groups)

## Make probe & propagators

In [5]:
from phaser.utils.num import Sampling
from phaser.utils.physics import Electron

# Electron wavelength at the accelerating voltage recorded in the metadata.
wavelength = Electron(meta.voltage).wavelength
print(wavelength)

# Reciprocal-space step of one detector pixel.
# (the 1e-3 factor presumably converts mrad -> rad; confirm against the metadata schema)
kstep = meta.diff_step*1e-3 / wavelength
# The real-space extent is the inverse of the reciprocal step, identical along both axes.
a = b = 1/kstep

pattern_shape = raw.shape[-2:]
grid = Sampling(pattern_shape, extent=(b, a))

In [6]:
from phaser.utils.optics import make_focused_probe

# Build the focused probe on the reciprocal grid (evaluated with cupy).
recip_coords = grid.recip_grid(xp=cupy)
init_probe = make_focused_probe(*recip_coords, wavelength, meta.conv_angle, defocus=meta.defocus*1e10)

# Host copies of the probe intensity and of the intensity of its transform.
probe_intensity = abs2(init_probe).get()
probe_recip_intensity = abs2(fft2(init_probe)).get()

fig, (ax1, ax2) = pyplot.subplots(ncols=2)
ax1.imshow(probe_intensity)
# fftshift only for display: moves the zero frequency from the corner to the centre
ax2.imshow(numpy.fft.fftshift(probe_recip_intensity))
pyplot.show()

## Probe modes

In [7]:
from phaser.utils.optics import make_hermetian_modes

# Expand the single focused probe into several incoherent modes.
n_probes = 4
init_probes = make_hermetian_modes(init_probe, n_probes, powers=0.02)

fig, axs = pyplot.subplots(ncols=n_probes, sharex=True, sharey=True)

# One panel per mode, titled with the summed intensity of the mode's transform.
for (ax, mode) in zip(axs.flat, init_probes):
    mode_power = cupy.sum(abs2(fft2(mode)))
    mode_image = abs2(mode).get()

    ax.set_axis_off()
    ax.set_title("{:.2%}".format(mode_power))
    ax.imshow(mode_image)

pyplot.show()

In [8]:
# debug mode correlations and intensities, should be orthogonal
#
# img[i, j] = |sum(P_i * conj(P_j))| / (Ny * Nx)
# Off-diagonal entries should vanish for orthogonal modes; the diagonal holds
# each mode's energy under the real-space normalization.

img = numpy.zeros((len(init_probes),) * 2)

for i in range(len(init_probes)):
    for j in range(len(init_probes)):
        overlap = numpy.sum(init_probes[i] * numpy.conj(init_probes[j]))
        # The modes may live on the GPU, in which case the reduction yields a 0-d
        # device array. CuPy refuses implicit conversion to numpy, so convert the
        # scalar explicitly before storing it in the host-side matrix.
        img[i, j] = float(numpy.abs(overlap))

energy_norm = numpy.prod(init_probe.shape[-2:])
img /= energy_norm

fig, ax = pyplot.subplots()
ax.imshow(img)

pyplot.show()

## Make scan

In [8]:
# Preview the first three measured patterns, shifted back so that the
# zero frequency is displayed in the centre.
fig, axs = pyplot.subplots(ncols=3, sharex=True, sharey=True)

for pattern_idx in range(3):
    axs[pattern_idx].imshow(numpy.fft.fftshift(raw[pattern_idx]))
pyplot.show()

In [15]:
from phaser.utils.scan import make_raster_scan

# The metadata lists the scan shape and step x-first; the scan is built y-first.
n_x, n_y = meta.scan_shape
# reversed step, scaled by 1e10 (presumably m -> angstrom; confirm)
scan_step_yx = numpy.array(meta.scan_step[::-1]) * 1e10
# flatten to a list of (y, x) positions
scan = make_raster_scan((n_y, n_x), scan_step_yx).reshape(-1, 2)

print(scan[1])

# Plot all positions (x horizontal, y increasing downwards) and mark position #2 in red.
fig, ax = pyplot.subplots()
ax.invert_yaxis()
ax.set_aspect(1.)
scan_y, scan_x = scan[..., 0], scan[..., 1]
ax.scatter(scan_x, scan_y, s=1.)
ax.scatter([scan[2, 1]], [scan[2, 0]], c='red', s=2.)
pyplot.show()

## Make object

In [18]:
from phaser.utils.object import ObjectSampling

# we pad by 1/2 the object size, plus an extra pixel
# Object sampling: pad the scan region by half the probe window plus one extra pixel
obj_pad = grid.extent / 2. + grid.sampling
obj_grid = ObjectSampling.from_scan(scan, grid.sampling, pad=obj_pad)

# corners of the object region, in (y, x) order
(y_min, x_min) = obj_grid.min
(y_max, x_max) = obj_grid.max

fig, ax = pyplot.subplots()
ax.invert_yaxis()
ax.set_aspect(1.)

# outline of the full object region
object_outline = patches.Rectangle(
    (x_min, y_min), x_max - x_min, y_max - y_min,
    fill=False, transform=ax.transData,
)
ax.add_patch(object_outline)
(yy, xx) = obj_grid.grid()
ax.scatter(scan[:, 1], scan[:, 0], s=1, c='blue')

# highlight one scan position together with the object cutout it addresses
i = 4
ax.scatter([scan[i, 1]], [scan[i, 0]], s=2, c='red')
slicey, slicex = obj_grid.slice_at_pos(scan[i], raw.shape[-2:])
# translate the slice bounds into coordinates using the object grid
starty, stopy = yy[slicey.start, 0], yy[slicey.stop, 0]
startx, stopx = xx[0, slicex.start], xx[0, slicex.stop]

cutout_outline = patches.Rectangle(
    (startx, starty), stopx - startx, stopy - starty,
    fill=False, transform=ax.transData, color='red',
)
ax.add_patch(cutout_outline)

# NOTE(review): the cutout shape is hard-coded to (128, 128) here, while the
# rectangle above uses raw.shape[-2:]; the two only agree for 128x128 patterns.
print(obj_grid.slice_at_pos(scan[i], (128, 128)))

pyplot.show()

## LSQ-ML method

The LSQ-ML update step proceeds in three parts for each scan position $i$:

1. Calculating the optimal exit wavefront $\psi^{opt}_i$
2. Calculating the object and probe update directions $\Delta P_{i}$ and $\Delta O_{i}$
3. Calculating the object and probe optimal update steps $\alpha_{P,i}$ and $\alpha_{O,i}$

The optimal exit wavefront $\psi^{opt}_i$ is calculated to minimize the distance between expected/modeled intensity on the diffraction plane $I^e_{i}$ and the measured intensity $I^m_{i}$. Specifically, we choose the exit wavefront $\psi^{opt}_i$ to maximize the likelihood $p(I^m_i|\psi^{opt}_i)$. For this reference implementation, we replace this with a simple modulus constraint, which is the output of the ML optimization in the case where only Poisson noise is present in the amplitude likelihood model.

To calculate real-space object and probe update, given an optimal exit wavefront, we calculate an 'exit wave update' $\chi_{i,r}$ as the difference between our optimal and modeled exit wave:
$$
\chi_{i,r} = \psi^{opt}_{i,r} - \psi^{fwd}_{i,r}
$$

We attempt to split $\chi_{i,r}$ into probe and object updates, so as to minimize a real-space cost function $\mathcal{L}_r$, defined as the sum squared error between the optimal exit wavefront and the updated wavefront:
$$\begin{aligned}
\mathcal{L}_r &= \sum_{r} \left| \psi^{updated}_{i,r} - \psi^{opt}_{i,r} \right|^2
\end{aligned}$$

Object and probe update directions are calculated as the gradient of the real-space cost function $\mathcal{L}_r$:
$$\begin{aligned}
\Delta P_{i,r} &= -\nabla_P \mathcal{L}_r = \chi_{i,r} O^*_{r+r_i} \\
\Delta O_{i,r+r_i} &= -\nabla_O \mathcal{L}_r = \chi_{i,r} P^*_{r} \\
\end{aligned}$$

(These gradients come from the complex derivative of $\psi_{i, r} = P_{r} O_{r+r_i}$, assuming the second Wirtinger derivative vanishes, $\partial/\partial \bar{z} = 0$)

Over a batch $\mathcal{N} \subset \mathcal{N}_0$, we average object and probe update directions:
$$\begin{aligned}
\Delta P_{r} &= \frac{\sum_{i \in \mathcal{N}} \Delta P_{i,r}}{\sum_{i \in \mathcal{N}_0} \left| O_{r+r_i} \right|^2 + \delta_P} \\
\Delta O_{r+r_i} &= \frac{\sum_{i \in \mathcal{N}} \Delta O_{i,r+r_i}}{\sum_{i \in \mathcal{N}_0} \left| P_{r} \right|^2 + \delta_O}
\end{aligned}$$

"$\delta_O$ and $\delta_P$ can be seen as preconditioners of the gradient descent task", and "penalize large values in the update, particularly in the weakly illuminated regions."
The denominator sums can be considered the object and probe illumination, and are calculated for all positions to avoid noise amplification.

Next, we calculate the step sizes $\alpha_P$ and $\alpha_O$. These can be calculated using the matrix (22), or assuming a diagonal matrix, in which case the step sizes become:
$$\begin{aligned}
\alpha_{P,i} = \frac{\sum_r Re\left[ \chi_{i,r} (\Delta P_{i, r} O_{r+r_i} )^* \right]}{\sum_r \left| \Delta P_{i,r} O_{r+r_i} \right|^2 + \gamma} \\
\alpha_{O,i} = \frac{\sum_r Re\left[ \chi_{i,r} (\Delta O_{i, r} P_{r} )^* \right]}{\sum_r \left| \Delta O_{i,r} P_r \right|^2 + \gamma}
\end{aligned}$$

$\gamma = 0.1$ in fold_slice?

Finally, we update the probe and object as a weighted sum of updates:
$$\begin{aligned}
P_r^{updated} &= P_r + \frac{\sum_{i \in \mathcal{N}} \alpha_{P,i} \Delta P_r \left| O_{r+r_i} \right|^2}{\sum_{i \in \mathcal{N}} \left| O_{r+r_i} \right|^2} \\
O_r^{updated} &= O_r + \frac{\sum_{i \in \mathcal{N}} \alpha_{O,i} \Delta O_r \left| P_{r-r_i} \right|^2}{\sum_{i \in \mathcal{N}} \left| P_{r-r_i} \right|^2}
\end{aligned}$$

Note that step sizes are calculated per scan position (and then averaged as a function of illumination), but step directions are averaged together into a single step direction per grouping. In practice, we need to add a small epsilon in the denominator to prevent 0/0 conditions.

In [62]:
def check_nan(*arrs) -> bool:
    """
    Return True if any of `arrs` contains a non-finite value.

    Despite the name, infinities are detected as well as NaNs, because the
    check uses `numpy.isfinite`.

    Each argument may be a device array exposing `.get()` (e.g. a cupy array,
    which is copied to the host for the check) or anything `numpy.isfinite`
    accepts directly, such as a numpy array. With no arguments the result is
    False.
    """
    for arr in arrs:
        # device arrays are transferred to the host; host arrays are used as-is
        host_arr = arr.get() if hasattr(arr, 'get') else arr
        if not numpy.isfinite(host_arr).all():
            return True
    return False

In [105]:
# NOTE(review): scratch check of the number of pixels in one probe array (Ny * Nx).
# `probes` is only assigned in the iteration-loop cell further down, so this cell
# fails on a fresh kernel (Restart & Run All) -- move it below that cell or remove it.
numpy.prod(probes.shape[-2:])

In [104]:
# NOTE(review): scratch check of the summed abs2() of the most recent group's
# object cutouts. `xp` and `group_objs` are only defined by the iteration-loop cell
# further down, so this cell fails on a fresh kernel -- move it below or remove it.
xp.sum(abs2(group_objs))

## Iteration loop

In [117]:
from phaser.utils.optics import fourier_shift_filter

# array module used for the reconstruction
xp = cupy

# experimental data, shape (N, ky, kx)
exp_data = raw

# --- initial object: unit magnitude with a tiny random phase ---
rng = numpy.random.RandomState(seed=123456)
#obj_angle = gaussian_filter(rng.normal(0., 1e-4, obj_grid.shape), 4.)
obj_angle = rng.normal(loc=0., scale=1e-6, size=obj_grid.shape)
obj = xp.array(numpy.cos(obj_angle) + 1.j * numpy.sin(obj_angle))

# --- initial probe model (copied so the starting modes are left untouched) ---
#probes = init_probes[:1].copy() / 0.94
probes = init_probes.copy()
cutout_shape = probes.shape[-2:]
# sqrt of the number of pixels in one probe cutout
realspace_norm = numpy.sqrt(numpy.prod(cutout_shape))

# reciprocal-space coordinates, used to build the sub-pixel shift filters
ky, kx = grid.recip_grid(xp=xp)

# --- regularization ---
# delta_P and delta_O in eq. 25
illum_reg_P = 1e-2  # strongly affects unilluminated regions
illum_reg_O = 1e-1  # not really required because obj_mag is uniformly large
# gamma in eq. 23, 0.1 in fold_slice?
# (was 1e-10) don't really know where this has an effect. When large it heavily damps reconstruction
gamma = 1e-3
# small epsilon to prevent 0/0
eps = 1e-16

# Illumination maps, accumulated once over every scan position before iterating:
#   obj_mag   -- object intensity summed over all cutouts (cutout space)
#   probe_mag -- probe intensity summed at every scan position (object space)
obj_mag = xp.zeros(cutout_shape, dtype=numpy.float64)
probe_mag = xp.zeros_like(obj, dtype=numpy.float64)

# TODO (original author): this loop should also fix incident probe intensity
groups = calc_groupings(16, 1234567)
for (group_i, group) in enumerate(groups):
    group_scan = scan[group]

    # sub-pixel shift filter per position, with a singleton axis for the probe modes
    # shape (len(group), 1, Ny, Nx)
    subpx_shifts = obj_grid.get_subpx_shifts(group_scan, cutout_shape)
    group_subpx_filters = fourier_shift_filter(ky, kx, subpx_shifts)[:, None, ...]
    # shifted probes in real space
    # shape (len(group), probe modes, Ny, Nx)
    group_probes = ifft2(fft2(probes) * group_subpx_filters)

    # probe intensity, summed over modes, added into the object-sized map
    group_probe_intensity = xp.sum(abs2(group_probes), axis=1) / realspace_norm
    obj_grid.add_view_at_pos(probe_mag, group_scan, group_probe_intensity)

    # object intensity under each probe window, summed over the group
    group_obj_intensity = obj_grid.get_view_at_pos(abs2(obj), group_scan, cutout_shape)
    obj_mag += xp.sum(group_obj_intensity, axis=0) / realspace_norm

# Main reconstruction loop: 5 passes over the data. Each pass reshuffles the scan
# positions into groups of ~8 (a different seed per pass) and updates `probes` and
# `obj` in place after every group.
for i in range(5):
    groups = calc_groupings(8, 1234567 + i)

    # Illumination maps re-accumulated during this pass; they replace `obj_mag` /
    # `probe_mag` only once the pass has completed (see the `else` clause below).
    new_obj_mag = xp.zeros(probes.shape[-2:], dtype=numpy.float64)
    new_probe_mag = xp.zeros_like(obj, dtype=numpy.float64)

    for (group_i, group) in enumerate(groups):
        group_scan = scan[group]
        # group probes in real space
        # shape (len(group), 1, Ny, Nx)
        group_subpx_filters = fourier_shift_filter(ky, kx, obj_grid.get_subpx_shifts(group_scan, cutout_shape))[:, None, ...]
        # shape (len(group), probe modes, Ny, Nx)
        group_probes = ifft2(fft2(probes) * group_subpx_filters)
        # add to illumination
        # NOTE(review): intensities are divided by realspace_norm = sqrt(Ny * Nx), while
        # the "Implications" notes normalize energy by Ny * Nx -- confirm which is intended.
        group_probe_mag = xp.zeros_like(obj, dtype=numpy.float64)
        obj_grid.add_view_at_pos(group_probe_mag, group_scan, xp.sum(abs2(group_probes), axis=1) / realspace_norm)
        new_probe_mag += group_probe_mag

        # experimental data
        group_patterns = xp.array(exp_data[group])

        # shape (len(group), 1, Ny, Nx)
        group_objs = obj_grid.get_view_at_pos(obj, group_scan, cutout_shape)[:, None, ...]
        # add to illumination
        # (summed over the group axis; the singleton mode axis is kept, hence the [0] below)
        group_obj_mag = xp.sum(abs2(group_objs), axis=0) / realspace_norm
        new_obj_mag += group_obj_mag[0]

        # forward model: exit wave per position and probe mode, then propagate to the detector
        exit_wave = group_objs * group_probes

        pattern_model = fft2(exit_wave)
        # sum over incoherent modes
        intensity_model = xp.sum(abs2(pattern_model), axis=1)

        # modulus constraint
        # intensity_update = sqrt(measured / modelled) - 1; the 1e-10 guards against division by zero.
        # pattern_model * intensity_update is therefore the rescaled pattern minus the modelled one,
        # and chi is that difference brought back to real space (the exit-wave update).
        intensity_update = xp.sqrt(group_patterns / (intensity_model + 1e-10)) - 1.0
        chi = ifft2(pattern_model * intensity_update[:, None, ...])
        if check_nan(chi):
            raise ValueError("NaN encountered in chi")

        # eq. 24, update directions per probe position
        delta_P = chi * xp.conj(group_objs) # and probe mode
        delta_O = chi * xp.conj(group_probes)

        # eq. 23, calculate step size per probe position and mode
        # we take the diagonal approximation
        # NOTE(review): the denominator's xp.sum() has no `axis`, so it reduces over every
        # position, mode and pixel of the group, and `gamma` is added to each element before
        # summing. The markdown formula for alpha instead uses a per-position sum over r plus a
        # single gamma. The final update below sums (rather than averages) alpha over positions,
        # which may compensate -- confirm before changing either.
        alpha_P = xp.sum(xp.real(chi * xp.conj(delta_P * group_objs)), axis=(-1, -2), keepdims=True) / (xp.sum(abs2(delta_P * group_objs) + gamma))
        # sum over probe modes as well
        alpha_O = xp.sum(xp.sum(xp.real(chi * xp.conj(delta_O * group_probes)), axis=(-1, -2), keepdims=True), axis=1) / (xp.sum(abs2(delta_O * group_probes) + gamma))
        #print(f"alpha_P: {list(alpha_P.ravel())}")
        #print(f"alpha_O: {list(alpha_O.ravel())}")
        
        if check_nan(alpha_O, alpha_P):
            raise ValueError("NaN encountered in alpha_O / alpha_P")

        # eq. 25, average update directions
        # apply subpx shifts in reciprocal space
        # (dividing by the shift filter undoes the per-position sub-pixel shift before summing over positions)
        # delta_P_avg is in cutout space
        delta_P_avg = ifft2(xp.sum(fft2(delta_P) / group_subpx_filters, axis=0))
        # normalized by the object illumination from the previous pass (or the pre-calculation)
        delta_P_avg /= (obj_mag + illum_reg_P)
        # delta_O_avg is in object space
        delta_O_avg = xp.zeros_like(obj)
        # sum over probe modes as well
        obj_grid.add_view_at_pos(delta_O_avg, group_scan, xp.sum(delta_O, axis=1))
        delta_O_avg /= (probe_mag + illum_reg_O)

        # eq. 27, final probe and object update
        # note that probe update is performed in cutout space, while object update is performed in full object space
        # a small epsilon is added to prevent 0/0 -> nan
        # NOTE(review): only alpha_* varies along axis 0 here, so (assuming the shapes noted above)
        # each update reduces to sum_i(alpha_i) * delta_avg * mag / (mag + eps), i.e. the summed
        # step sizes rather than the illumination-weighted average in the markdown.
        probe_update = xp.sum(alpha_P * delta_P_avg * group_obj_mag, axis=0) / (group_obj_mag + eps)
        obj_update = xp.sum(alpha_O * delta_O_avg * group_probe_mag, axis=0) / (group_probe_mag + eps)

        if check_nan(probe_update, obj_update):
            raise ValueError("NaN encountered in probe_update / obj_update")
        
        #if i > 0 and group_i > 200:
        #    break

        # apply the updates in place; later groups in this pass see the updated probe and object
        probes += probe_update
        obj += obj_update

    else:
        # for/else: reached only when the group loop finished without `break`
        # (the only `break` is currently commented out above)
        # update current obj_mag and probe_mag
        obj_mag = new_obj_mag
        probe_mag = new_probe_mag

        print(f"Finished iter #{i+1}")

In [118]:
from phaser.utils.filter import remove_linear_ramp

# 3x2 summary figure of the state left behind by the last processed group.
fig, axs = pyplot.subplots(ncols=2, nrows=3)
fig.set_size_inches(4, 7)

for ax in axs.ravel():
    ax.set_axis_off()

# the two object panels (top row) share axes with each other
axs[0, 1].sharex(axs[0, 0])
axs[0, 1].sharey(axs[0, 0])
# the pattern and probe panels all share axes with the measured pattern
for idxs in ((1, 1), (2, 0), (2, 1)):
    axs[idxs].sharex(axs[1, 0])
    axs[idxs].sharey(axs[1, 0])

# host copies of the first pattern of the last group, measured and modelled
measured_pattern = group_patterns[0].get()
modelled_pattern = intensity_model[0].get()
# common colour scale for both pattern panels
pattern_max = max(numpy.nanmax(measured_pattern), numpy.nanmax(modelled_pattern))

# Crop object-space images to the region returned by the object grid.
# (use `crop = (slice(0, None), slice(0, None))` to display the full object instead)
#crop = numpy.ceil(sim_r / numpy.array([px_size_y, px_size_x])).astype(numpy.int_)
#crop = (slice(crop[0], -crop[0]), slice(crop[0], -crop[0]))
crop = obj_grid.get_region_crop()
crop_idx = tuple(crop)

obj_phase = remove_linear_ramp(xp.angle(obj[crop_idx]).get())
obj_update_phase = numpy.angle(obj_update[crop_idx].get())
probe_intensity = xp.sum(abs2(probes), axis=0).get()
probe_update_phase = xp.angle(probe_update[0]).get()

axs[0, 0].set_title("Object")
axs[0, 0].imshow(obj_phase)

#axs[0].imshow(numpy.angle(obj))
axs[1, 0].set_title("Measured Pattern")
axs[1, 0].imshow(numpy.fft.fftshift(measured_pattern), vmin=0., vmax=pattern_max)

axs[1, 1].set_title("Recons. Pattern")
axs[1, 1].imshow(numpy.fft.fftshift(modelled_pattern), vmin=0., vmax=pattern_max)

axs[0, 1].set_title("Obj update phase")
axs[0, 1].imshow(obj_update_phase)

axs[2, 0].set_title("Probe")
axs[2, 0].imshow(probe_intensity)

axs[2, 1].set_title("Probe update phase")
axs[2, 1].imshow(probe_update_phase)

In [119]:
# Plot probe modes

# One panel per reconstructed probe mode, titled with the mode's summed
# intensity divided by the number of pixels.
n_modes = probes.shape[0]
fig, axs = pyplot.subplots(ncols=n_modes, sharex=True, sharey=True, squeeze=False)

for (ax, probe) in zip(axs.flat, probes):
    mode_energy = xp.sum(abs2(probe)) / numpy.prod(probe.shape[-2:])

    ax.set_axis_off()
    ax.set_title("{:.3f}".format(mode_energy))
    #ax.imshow(abs2(numpy.fft.fftshift(fft2(probe))))
    ax.imshow(abs2(probe).get())

pyplot.show()