In [1]:
%matplotlib widget
%reload_ext autoreload
%autoreload 2

import typing as t

from matplotlib import pyplot
import matplotlib.patches as patches

import numpy
from numpy.typing import NDArray, ArrayLike

## Overview

Steps in the LSQ-ML reconstruction:

 - Setup
   - Make incident probe(s)
   - Generate scan positions
   - Generate object, large enough to contain $[1/d, 1/d]$ region around every scan position
   - Load and process measured patterns
 - Running
  - Pre-Calculate probe intensity (sum of probe intensity in object space)
    - High illumination -> good est. of object wavefunction
  - Pre-Calculate object intensity (sum of object intensity in probe space)
    - High object magnitude -> good signal of probe wavefunction
    - important for opaque regions/obstructions

  - Per iteration
    - Split probe positions into groupings

    - Per grouping
      - Add to probe & object intensity

      - Apply forward model to detector plane
      - Apply modulus constraint at detector plane
      - Calculate exit wavefunction update chi
      - Calculate object and probe update directions, per probe position
      - Average object and probe update directions
      - Calculate probe and object update step size, per probe position
      - Apply probe and object update weighted by object and probe intensity

    - Update probe & object intensity

## Conventions

- Reciprocal space: Zero-frequency in corner
- Real space: (0, 0) in center
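- Energy is normalized in both real and reciprocal space (the two are related by Parseval's theorem with `norm='forward'`):
  - $ \sum \left|F(k_x, k_y)\right|^2 = 1 $ (e.g. each measured pattern sums to 1)
  - $ \sum \left|f(x, y)\right|^2 = N_x N_y $
- Probe and object are stored in real space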

## Implications

- FFTs need 'forward' normalization
- Reciprocal space to real space: `f = fftshift(ifft2(F, norm='forward'), axes=(-2, -1))`
- Real space to reciprocal space: `F = fft2(ifftshift(f, axes=(-2, -1)), norm='forward')`
- Normalize energy in real space: `abs2(f) / numpy.prod(f.shape[-2:])`
- Normalize amplitude in real space: `f / numpy.sqrt(numpy.prod(f.shape[-2:]))`
- Never need to fftshift in reciprocal space

In [2]:
from ptycho_lebeau.raw import load_4d
from ptycho_lebeau.metadata import AnyMetadata
import eutils

import os

# Dataset metadata file. Set the PTYCHO_META_PATH environment variable to point at a
# different dataset instead of editing this machine-specific default path.
META_PATH = os.environ.get(
    "PTYCHO_META_PATH",
    "/Users/colin/Downloads/mos2/1/mos2/mos2_0.00_dstep1.0.json",
)

meta = AnyMetadata.parse_file(META_PATH)
print(meta.json(indent='    '))

In [4]:
raw = load_4d(meta.path / meta.raw_filename)
# move the zero frequency of each pattern from the centre to the corner, then flatten
# the scan axes. ifftshift is the exact inverse of a centred layout; fftshift only
# coincides with it for even detector sizes.
raw = numpy.fft.ifftshift(raw, axes=(-1, -2)).reshape((-1, *raw.shape[-2:]))
# normalize each pattern to unit total intensity; all-zero patterns stay zero
# instead of becoming NaN (which would make the Poisson sampling below raise)
pattern_totals = numpy.sum(raw, axis=(-2, -1))
raw /= numpy.where(pattern_totals > 0, pattern_totals, 1.)[:, None, None]

dose = 100000  # e/A^2
# scan-pixel area in A^2 (scan_step presumably in m, hence * 1e10)
px_area = numpy.prod(numpy.array(meta.scan_step) * 1e10)
dose_per_pattern = dose * px_area
# beam current = electrons per pattern * charge per electron / dwell time
electron_charge_pC = 1.6021766e-07  # 1.602e-19 C/e * 1e12 pC/C
dwell_time = 1e-3  # s, assumed dwell time per pattern
print(f"Equiv. beam current: {(dose_per_pattern * electron_charge_pC) / dwell_time:.3f} pA")

# apply poisson noise, with a fixed seed so Restart & Run All reproduces the same data
noise_rng = numpy.random.RandomState(seed=2718)
noisy = noise_rng.poisson(dose_per_pattern * raw).astype(numpy.float32) / dose_per_pattern

## Various helper functions

In [5]:
from scipy.linalg import lstsq

def ifft2(a: ArrayLike) -> NDArray[numpy.complex128]:
    """Transform from reciprocal space to real space over the last two axes.

    Uses ``norm='forward'`` (no 1/N scaling on the inverse transform), and shifts the
    result so the real-space origin sits at the array centre. This matches the notebook
    conventions: zero frequency in the corner, real-space (0, 0) in the centre.

    The annotation uses ``numpy.complex128`` because the ``numpy.complex_`` alias was
    removed in NumPy 2.0, where evaluating it at definition time raises AttributeError.
    """
    return numpy.fft.fftshift(numpy.fft.ifft2(a, norm='forward'), axes=(-2, -1))

def fft2(a: ArrayLike) -> NDArray[numpy.complex128]:
    """Transform from real space to reciprocal space over the last two axes.

    This is the inverse of ``ifft2``. It un-centres the real-space origin with
    ``ifftshift``, then applies a ``norm='forward'`` FFT (scaled by 1/N). The result
    keeps zero frequency in the corner.

    The annotation uses ``numpy.complex128`` because ``numpy.complex_`` no longer
    exists in NumPy 2.0.
    """
    return numpy.fft.fft2(numpy.fft.ifftshift(a, axes=(-2, -1)), norm='forward')

def fourier_shift_filter(ky: numpy.ndarray, kx: numpy.ndarray, shifts: ArrayLike) -> NDArray[numpy.complex128]:
    """Phase ramp(s) that translate an array when multiplied onto its spectrum.

    Parameters
    ----------
    ky, kx
        Spatial-frequency grids of identical shape (e.g. from ``numpy.meshgrid``).
    shifts
        Shift(s) ``(x, y)`` on the last axis, in units reciprocal to ``kx``/``ky``.
        Any leading axes are a batch.

    Returns
    -------
    ``exp(-2*pi*i*(x*kx + y*ky))``, with shape ``(*shifts.shape[:-1], *ky.shape)``.
    """
    shifts = numpy.asarray(shifts)
    # append one length-1 axis per k-grid dimension so every (x, y) pair broadcasts
    # against the full grid; this handles a single shift and batches alike
    k_axes = (None,) * numpy.ndim(ky)
    x = shifts[(Ellipsis, 0, *k_axes)]
    y = shifts[(Ellipsis, 1, *k_axes)]
    return numpy.exp(-2.j*numpy.pi*(x*kx + y*ky))

def remove_phase_ramp(data: numpy.ndarray) -> numpy.ndarray:
    """Subtract the least-squares best-fit plane ``c0 + c1*x + c2*y`` from each 2-D layer.

    The plane is fitted independently for every layer over the last two axes
    (rows = y, columns = x); any leading axes are a batch. Values are used as given,
    so a ramp in *wrapped* phase (with 2*pi jumps) is not removed cleanly.

    Integer/boolean input is promoted to float64, so the (generally non-integer)
    residual is not truncated. Inexact (float/complex) input keeps its dtype.
    """
    data = numpy.asarray(data)
    out_dtype = data.dtype if numpy.issubdtype(data.dtype, numpy.inexact) else numpy.float64
    (ny, nx) = data.shape[-2:]

    # design matrix [1, x, y] for every pixel, shape (ny * nx, 3)
    (yy, xx) = (arr.ravel() for arr in numpy.indices((ny, nx), dtype=float))
    pts = numpy.stack((numpy.ones_like(xx), xx, yy), axis=-1)

    # one right-hand-side column per layer, so a single lstsq call fits every layer
    layers = data.reshape(-1, ny * nx)
    if layers.shape[0] == 0:
        return numpy.empty(data.shape, dtype=out_dtype)
    coeffs = lstsq(pts, layers.T)[0]
    planes = (pts @ coeffs).T.reshape(data.shape)
    return (data - planes).astype(out_dtype, copy=False)

def calc_groupings(grouping: int = 8, seed: t.Any = None, n_positions: t.Optional[int] = None) -> list[NDArray[numpy.int_]]:
    """Randomly partition scan-position indices into groups of at most ``grouping``.

    Parameters
    ----------
    grouping
        Target group size. There are ``ceil(n_positions / grouping)`` groups, with
        sizes as equal as ``numpy.array_split`` makes them.
    seed
        Seed for ``numpy.random.RandomState``; the same seed gives the same partition.
    n_positions
        Number of scan positions to partition. Defaults to ``raw.shape[0]`` (the number
        of loaded diffraction patterns).

    Raises
    ------
    ValueError
        If ``grouping < 1``.
    """
    if grouping < 1:
        raise ValueError(f"grouping must be at least 1, got {grouping}")
    if n_positions is None:
        n_positions = raw.shape[0]

    rng = numpy.random.RandomState(seed)
    idxs = numpy.arange(n_positions)
    rng.shuffle(idxs)
    # integer ceil division; keep at least one group so an empty scan doesn't make
    # array_split raise on zero sections
    n_groups = max(1, -(-n_positions // grouping))
    return numpy.array_split(idxs, n_groups)

def abs2(x: NDArray[numpy.complexfloating]) -> NDArray[numpy.floating]:
    """Squared magnitude ``|x|**2``, computed without the square root ``abs`` would take."""
    real_part = x.real
    imag_part = x.imag
    return real_part**2. + imag_part**2.

## Make probe & propagators

In [6]:
wavelength = eutils.Electron(meta.voltage / 1e3).wavelength

# reciprocal-space pixel size (diff_step presumably in mrad, hence * 1e-3)
kstep = meta.diff_step*1e-3 / wavelength
# real-space field of view of one probe cutout, identical along both axes
b = a = 1/kstep

px_size_y = b/raw.shape[-2]
px_size_x = a/raw.shape[-1]

# real-space sample coordinates, origin at the array centre
yy, xx = numpy.meshgrid(
    numpy.linspace(-b/2, b/2, raw.shape[-2], endpoint=False),
    numpy.linspace(-a/2, a/2, raw.shape[-1], endpoint=False),
    indexing='ij',
)

# spatial frequencies in unshifted fftfreq order (zero frequency in the corner)
ky, kx = numpy.meshgrid(
    numpy.fft.fftfreq(raw.shape[-2], px_size_y),
    numpy.fft.fftfreq(raw.shape[-1], px_size_x),
    indexing='ij',
)

# scattering angles theta = k * wavelength, and their squared magnitude
thetay, thetax = ky * wavelength, kx * wavelength
theta2 = thetay**2 + thetax**2

In [7]:
# TODO check sign of defocus

# defocus aberration (path difference) at each scattering angle
chi = meta.defocus * 1e10 * (theta2 / 2)
init_probe = numpy.exp(chi * -2.j*numpy.pi / wavelength)
# mask reciprocal space to aperture (conv_angle presumably in mrad, hence * 1e-3)
aperture_mask = theta2 <= (meta.conv_angle * 1e-3)**2
init_probe *= aperture_mask
# normalize energy in reciprocal space so that sum |F|^2 == 1. Uses abs2 (energy);
# the previous sum(abs) only agreed because every unmasked value has unit modulus.
init_probe /= numpy.sqrt(numpy.sum(abs2(init_probe)))

# to real space (origin centred)
init_probe = ifft2(init_probe)

In [8]:
fig, (ax_real, ax_recip) = pyplot.subplots(ncols=2)

# left: real-space probe intensity (origin centred)
ax_real.imshow(abs2(init_probe))
# right: reciprocal-space intensity, shifted so zero frequency is centred for display
recip_intensity = abs2(fft2(init_probe))
ax_recip.imshow(numpy.fft.fftshift(recip_intensity))
pyplot.show()

## Probe modes

In [9]:
def make_hermetian_modes(base_probe: numpy.ndarray, n_modes: int, powers: ArrayLike = 0.02):
    """Build ``n_modes`` orthogonal probe modes from ``base_probe``, weighted by power.

    Parameters
    ----------
    base_probe
        Real-space probe used as the template for ``hermetian_modes``.
    n_modes
        Total number of modes to return (>= 1).
    powers
        Fractional power of each higher-order mode (modes 1 .. n_modes-1). A scalar or a
        short sequence is extended by repeating its last value; surplus values are
        ignored. The base mode receives ``1 - sum(powers)``.

    Returns
    -------
    Array of shape ``(n_modes, *base_probe.shape)``. Assuming ``hermetian_modes``
    returns unit-energy modes, mode ``k`` has ``sum |mode_k|^2 == power_k * Ny * Nx``
    (the notebook's real-space energy convention).

    Raises
    ------
    ValueError
        If ``n_modes < 1``, if ``powers`` is empty while higher-order modes are
        requested, or if the higher-order powers sum to more than 1.
    """
    if n_modes < 1:
        raise ValueError(f"n_modes must be at least 1, got {n_modes}")

    n_extra = n_modes - 1
    powers = numpy.array(powers, dtype=numpy.float64).ravel()
    if len(powers) < n_extra:
        if len(powers) == 0:
            raise ValueError("`powers` must not be empty when n_modes > 1")
        # extend by repeating the last given power (same as 'edge' padding)
        powers = numpy.concatenate((powers, numpy.full(n_extra - len(powers), powers[-1])))
    # exactly n_extra entries; for n_modes == 1 this leaves no higher-order modes
    powers = powers[:n_extra]
    base_power = 1. - numpy.sum(powers)
    if base_power < 0.:
        raise ValueError(f"higher-order mode powers sum to {numpy.sum(powers):.3f}, which exceeds 1")
    powers = numpy.concatenate(([base_power], powers))

    # near-square grid of polynomial orders with n_y * n_x >= n_modes
    # (the previous n_x = ceil(n / (n_y + 1)) produced too few modes for n = 3, 7, 10, ...)
    n_y = int(numpy.ceil(numpy.sqrt(n_modes)))
    n_x = int(numpy.ceil(n_modes / n_y))

    print(f"{n_y=}, {n_x=}")

    # rescale unit-energy modes to the real-space convention, weighted by each mode's power
    realspace_norm = numpy.sqrt(numpy.prod(base_probe.shape[-2:]))
    modes = hermetian_modes(base_probe, n_y, n_x)[:n_modes] * numpy.sqrt(powers)[:, None, None] * realspace_norm

    return modes

def hermetian_modes(base_probe: numpy.ndarray, n_y: int, n_x: int) -> numpy.ndarray:
    """Build ``n_y * n_x`` orthonormal Hermite-like modes from a 2-D base probe.

    Mode ``(i, j)`` starts as ``y**i * x**j * base_probe``, where ``(y, x)`` are pixel
    coordinates measured from the intensity centre of mass of ``base_probe``.
    - Modes with ``i > 1`` or ``j > 1`` are also damped by a Gaussian envelope whose
      widths are the intensity variances along each axis.
    - Each mode is then orthogonalized against all previous modes (modified
      Gram-Schmidt) and normalized to ``sum |mode|^2 == 1``.

    Returns an array of shape ``(n_y * n_x, *base_probe.shape)``, ordered with ``j``
    (the power of x) varying fastest.
    """
    (yy, xx) = numpy.indices(base_probe.shape, dtype=numpy.float64)

    base_probe_mag = numpy.abs(base_probe)**2.
    total_mag = numpy.sum(base_probe_mag)

    # centre coordinates on the intensity centre of mass, then take per-axis variances
    yy -= numpy.sum(yy * base_probe_mag) / total_mag
    xx -= numpy.sum(xx * base_probe_mag) / total_mag
    var_y = numpy.sum(yy**2. * base_probe_mag) / total_mag
    var_x = numpy.sum(xx**2. * base_probe_mag) / total_mag

    modes = []

    for i in range(n_y):
        for j in range(n_x):
            mode = yy**i * xx**j * base_probe
            if i > 1 or j > 1:
                # symmetric Gaussian envelope, so high polynomial orders don't blow up
                # towards the array edges
                mode = mode * numpy.exp(-xx**2./(2 * var_x) - yy**2./(2 * var_y))
                mode /= numpy.sqrt(numpy.vdot(mode, mode).real)

            # orthogonalize to previous modes: subtract <prev|mode> * prev.
            # vdot conjugates its first argument, which gives the correct
            # projection coefficient for complex modes.
            for prev_mode in modes:
                mode -= prev_mode * numpy.vdot(prev_mode, mode)

            mode /= numpy.sqrt(numpy.vdot(mode, mode).real)
            modes.append(mode)

    return numpy.stack(modes, axis=0)

In [10]:
n_probes = 4

modes = make_hermetian_modes(init_probe, n_probes)

# squeeze=False keeps `axs` a 2-D array, so `axs.flat` also works when n_probes == 1
fig, axs = pyplot.subplots(ncols=n_probes, sharex=True, sharey=True, squeeze=False)

for (ax, mode) in zip(axs.flat, modes):
    ax.set_axis_off()
    ax.imshow(abs2(mode))
    # reciprocal-space energy of this mode (should match its requested power)
    print(numpy.sum(abs2(fft2(mode))))

pyplot.show()

In [11]:
# debug mode correlations and intensities: |<P_i|P_j>| / (Ny * Nx) should be diagonal,
# with the mode powers on the diagonal and ~0 elsewhere if the modes are orthogonal

init_probes = make_hermetian_modes(init_probe, n_probes)

energy_norm = numpy.prod(init_probe.shape[-2:])
flat_probes = init_probes.reshape(len(init_probes), -1)
# img[i, j] = |sum(P_i * conj(P_j))| / energy_norm for all pairs at once
img = numpy.abs(flat_probes @ flat_probes.conj().T) / energy_norm

fig, ax = pyplot.subplots()
ax.set_title("Mode overlap |<P_i|P_j>| / N")
fig.colorbar(ax.imshow(img))

pyplot.show()

## Make scan

In [15]:
# scan positions in Angstrom (scan_step presumably in m, hence * 1e10), centred on 0.
# numpy.float64 replaces numpy.float_, which was removed in NumPy 2.0.
scanx = (numpy.arange(meta.scan_shape[0], dtype=numpy.float64) - meta.scan_shape[0] / 2.) * meta.scan_step[0] * 1e10
scany = (numpy.arange(meta.scan_shape[1], dtype=numpy.float64) - meta.scan_shape[1] / 2.) * meta.scan_step[1] * 1e10
# flatten the grid row-major (x varies fastest) to one entry per scan position
scany, scanx = map(numpy.ravel, numpy.meshgrid(scany, scanx, indexing='ij'))
# shape (n, 2), columns (x, y)
scan = numpy.stack((scanx, scany), axis=-1)

fig, ax = pyplot.subplots()
ax.invert_yaxis()
ax.set_aspect(1.)
ax.scatter(scan[:, 0], scan[:, 1], s=1.)
# mark scan position index 2 (presumably to check the scan ordering)
ax.scatter([scan[2, 0]], [scan[2, 1]], c='red', s=2.)
pyplot.show()

## Make object

In [19]:
x_min, x_max = numpy.nanmin(scanx), numpy.nanmax(scanx)
y_min, y_max = numpy.nanmin(scany), numpy.nanmax(scany)

# pad for outer positions: half a probe field of view plus two pixels of margin
sim_r = numpy.max((a, b)) / 2. + numpy.max((px_size_x, px_size_y)) * 2.
x_min -= sim_r
y_min -= sim_r
x_max += sim_r
y_max += sim_r

# TODO check this for off-by-one

# keep x_min and y_min, calculate number of object pixels required along each axis
n_x = numpy.ceil((x_max - x_min) / px_size_x).astype(numpy.int_) + 1
n_y = numpy.ceil((y_max - y_min) / px_size_y).astype(numpy.int_) + 1

# update x_max and y_max to the edges of the pixel grid
x_max = x_min + n_x * px_size_x
y_max = y_min + n_y * px_size_y

# object pixel coordinates: row index i along y, column index j along x
i, j = numpy.indices((n_y, n_x))
xx = x_min + j * px_size_x
yy = y_min + i * px_size_y

fig, ax = pyplot.subplots()
ax.set_aspect(1.)
ax.add_patch(patches.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min, fill=False, transform=ax.transData))
ax.pcolormesh(xx, yy, i)
ax.scatter(scanx, scany, s=1, c='red')
pyplot.show()

In [42]:
# Scratch call `slice_around_position(*group_scan[0])` removed from here: both names are
# only defined in later cells, so this cell raised NameError on Restart & Run All.

In [49]:
# handle object cutouts and subpixel probe shifting

def pos_to_object_idx(pos: ArrayLike) -> NDArray[numpy.float64]:
    """Convert scan position(s) to the fractional object-pixel index ``(j, i)`` of the
    top-left corner of the probe cutout centred on that position.

    ``pos`` holds ``(x, y)`` on its last axis, in the same length units as
    ``x_min``/``y_min``; leading axes broadcast. The half-cutout offset is applied
    per axis: x (columns) uses the cutout width ``ky.shape[-1]``, and y (rows) uses
    its height ``ky.shape[-2]``.
    """
    origin = numpy.array([x_min, y_min])
    px_size = numpy.array([px_size_x, px_size_y])
    # cutout (width, height), in the same (x, y) order as `pos`
    half_cutout = numpy.array(ky.shape[-2:][::-1]) / 2.
    return (numpy.array(pos) - origin) / px_size - half_cutout

def slice_around_position(x: float, y: float) -> t.Tuple[slice, slice]:
    """Return ``(row_slice, col_slice)`` selecting the probe-sized object cutout centred
    (to the nearest pixel) on scan position ``(x, y)``.

    Uses the same ``numpy.round`` as ``get_subpx_shifts``, which returns the remaining
    fractional-pixel offset.

    Raises
    ------
    ValueError
        If the cutout would start before the first object pixel. A negative start
        would otherwise wrap around and silently select a wrong or empty region.
    """
    (start_j, start_i) = numpy.round(pos_to_object_idx((x, y))).astype(numpy.int_)
    if start_i < 0 or start_j < 0:
        raise ValueError(
            f"Cutout around position ({x}, {y}) starts outside the object (row {start_i}, col {start_j})"
        )
    return (
        slice(start_i, start_i + ky.shape[-2]),
        slice(start_j, start_j + ky.shape[-1]),
    )

def get_subpx_shifts(pos: ArrayLike) -> NDArray[numpy.float64]:
    """Fractional-pixel remainder ``(dx, dy)`` of scan position(s) after snapping to the
    integer cutout grid used by ``slice_around_position``.

    Results are in object-pixel units, each component in [-0.5, 0.5]. (The annotation
    uses ``numpy.float64``; ``numpy.float_`` was removed in NumPy 2.0.)
    """
    idx = pos_to_object_idx(pos)
    return idx - numpy.round(idx)

def get_view_at_pos(obj: numpy.ndarray, pos: ArrayLike) -> numpy.ndarray:
    """Extract the probe-sized cutout(s) of ``obj`` around scan position(s) ``pos``.

    A single ``(x, y)`` position returns a slice (view) of ``obj``. A batch of positions
    (last axis of length 2) returns a new array of shape
    ``(*pos.shape[:-1], *ky.shape[-2:])``.
    """
    positions = numpy.array(pos)
    if positions.ndim != 1:
        cutouts = numpy.empty((*positions.shape[:-1], *ky.shape[-2:]), dtype=obj.dtype)
        for batch_idx in numpy.ndindex(positions.shape[:-1]):
            cutouts[batch_idx] = obj[slice_around_position(*positions[batch_idx])]
        return cutouts

    return obj[slice_around_position(*positions)]

def set_view_at_pos(obj: numpy.ndarray, pos: ArrayLike, view: numpy.ndarray):
    """Overwrite, in place, the probe-sized region(s) of ``obj`` around scan position(s) ``pos``.

    For a single ``(x, y)`` position the whole ``view`` is written. For a batch,
    ``view[idx]`` is written at ``pos[idx]``; where regions overlap, later positions win.
    """
    positions = numpy.array(pos)
    if positions.ndim == 1:
        obj[slice_around_position(*positions)] = view
    else:
        for batch_idx in numpy.ndindex(positions.shape[:-1]):
            obj[slice_around_position(*positions[batch_idx])] = view[batch_idx]

def add_view_at_pos(obj: numpy.ndarray, pos: ArrayLike, view: numpy.ndarray):
    """Accumulate ``view`` into ``obj`` (in place) around scan position(s) ``pos``.

    For a single ``(x, y)`` position the whole ``view`` is added. For a batch,
    ``view[idx]`` is added at ``pos[idx]``. Positions are handled one at a time, so
    overlapping regions accumulate every contribution.
    """
    positions = numpy.array(pos)
    if positions.ndim == 1:
        obj[slice_around_position(*positions)] += view
    else:
        for batch_idx in numpy.ndindex(positions.shape[:-1]):
            obj[slice_around_position(*positions[batch_idx])] += view[batch_idx]


In [50]:
# sanity check for add_view_at_pos: accumulate probe-sized cutouts of ones at a few
# scan positions; where the regions overlap the sum exceeds 1
obj = numpy.zeros((n_y, n_x), dtype=numpy.float64)

pos = scan[[0, 20 + 64, 120, -1]]

# one cutout per position (add_view_at_pos adds cutouts[k] at pos[k]), sized to the
# probe grid rather than a hard-coded 128 x 128
cutouts = numpy.ones((len(pos), *ky.shape[-2:]), dtype=numpy.float64)

add_view_at_pos(obj, pos, cutouts)

fig, ax = pyplot.subplots()
ax.invert_yaxis()
ax.set_aspect(1.)
fig.colorbar(ax.pcolormesh(xx, yy, obj))
ax.scatter(pos[:, 0], pos[:, 1])
pyplot.show()

## LSQ-ML method

The LSQ-ML update step proceeds in three parts for each scan position $i$:

1. Calculating the optimal exit wavefront $\psi^{opt}_i$
2. Calculating the object and probe update directions $\Delta P_{i}$ and $\Delta O_{i}$
3. Calculating the object and probe optimal update steps $\alpha_{P,i}$ and $\alpha_{O,i}$

The optimal exit wavefront $\psi^{opt}_i$ is calculated to minimize the distance between expected/modeled intensity on the diffraction plane $I^e_{i}$ and the measured intensity $I^m_{i}$. Specifically, we choose the exit wavefront $\psi^{opt}_i$ to maximize the likelihood $p(I^m_i|\psi^{opt}_i)$. For this reference implementation, we replace this with a simple modulus constraint, which is the output of the ML optimization in the case where only Poisson noise is present in the amplitude likelihood model.

To calculate real-space object and probe update, given an optimal exit wavefront, we calculate an 'exit wave update' $\chi_{i,r}$ as the difference between our optimal and modeled exit wave:
$$
\chi_{i,r} = \psi^{opt}_{i,r} - \psi^{fwd}_{i,r}
$$

We attempt to split $\chi_{i,r}$ into probe and object updates, so as to minimize a real-space cost function $\mathcal{L}_r$, defined as the sum squared error between the optimal exit wavefront and the updated wavefront:
$$\begin{aligned}
\mathcal{L}_r &= \sum_{r} \left| \psi^{updated}_{i,r} - \psi^{opt}_{i,r} \right|^2
\end{aligned}$$

Object and probe update directions are calculated as the gradient of the real-space cost function $\mathcal{L}_r$:
$$\begin{aligned}
\Delta P_{i,r} &= -\nabla_P \mathcal{L}_r = \chi_{i,r} O^*_{r+r_i} \\
\Delta O_{i,r+r_i} &= -\nabla_O \mathcal{L}_r = \chi_{i,r} P^*_{r} \\
\end{aligned}$$

(These gradients are the Wirtinger derivatives of $\mathcal{L}_r$ with respect to $P^*$ and $O^*$, using $\psi_{i, r} = P_{r} O_{r+r_i}$, which is holomorphic in $P$ and $O$, i.e. its conjugate Wirtinger derivative $\partial \psi / \partial \bar{z}$ vanishes.)

Over a batch $\mathcal{N} \subset \mathcal{N}_0$, we average object and probe update directions:
$$\begin{aligned}
\Delta P_{r} &= \frac{\sum_{i \in \mathcal{N}} \Delta P_{i,r}}{\sum_{i \in \mathcal{N}_0} \left| O_{r+r_i} \right|^2 + \delta_P} \\
\Delta O_{r+r_i} &= \frac{\sum_{i \in \mathcal{N}} \Delta O_{i,r+r_i}}{\sum_{i \in \mathcal{N}_0} \left| P_{r} \right|^2 + \delta_O}
\end{aligned}$$

"$\delta_O$ and $\delta_P$ can be seen as preconditioners of the gradient descent task", and "penalize large values in the update, particularly in the weakly illuminated regions."
The denominator sums can be considered the object and probe illumination, and are calculated for all positions to avoid noise amplification.

Next, we calculate the step sizes $\alpha_P$ and $\alpha_O$. These can be calculated by solving the linear system with the matrix in eq. (22), or, approximating that matrix as diagonal, the step sizes become:
$$\begin{aligned}
\alpha_{P,i} = \frac{\sum_r Re\left[ \chi_{i,r} (\Delta P_{i, r} O_{r+r_i} )^* \right]}{\sum_r \left| \Delta P_{i,r} O_{r+r_i} \right|^2 + \gamma} \\
\alpha_{O,i} = \frac{\sum_r Re\left[ \chi_{i,r} (\Delta O_{i, r+r_i} P_{r} )^* \right]}{\sum_r \left| \Delta O_{i,r+r_i} P_r \right|^2 + \gamma}
\end{aligned}$$

**TODO:** confirm the value of the regularizer $\gamma$ (fold_slice appears to use $\gamma = 0.1$).

Finally, we update the probe and object as a weighted sum of updates:
$$\begin{aligned}
P_r^{updated} &= P_r + \frac{\sum_{i \in \mathcal{N}} \alpha_{P,i} \Delta P_r \left| O_{r+r_i} \right|^2}{\sum_{i \in \mathcal{N}} \left| O_{r+r_i} \right|^2} \\
O_r^{updated} &= O_r + \frac{\sum_{i \in \mathcal{N}} \alpha_{O,i} \Delta O_r \left| P_{r-r_i} \right|^2}{\sum_{i \in \mathcal{N}} \left| P_{r-r_i} \right|^2}
\end{aligned}$$

Note that step sizes are calculated per scan position (and then averaged as a function of illumination), but step directions are averaged together into a single step direction per grouping. In practice, we need to add a small epsilon in the denominator to prevent 0/0 conditions.

## Iteration loop

In [53]:
# experimental data, shape (N, ky, kx)
# NOTE(review): this reconstructs from the noise-free `raw`; the Poisson-noised `noisy` is not used
exp_data = raw

# random start object: unit modulus with a tiny random phase (normal, std 1e-8)
rng = numpy.random.RandomState(seed=123456)
obj_angle = rng.normal(0., 1e-8, (n_y, n_x))
obj = (numpy.cos(obj_angle) + numpy.sin(obj_angle) * 1.j)

# probe model: only the first (base) mode of `init_probes` is reconstructed
# NOTE(review): make_hermetian_modes(init_probe, 4) with the default 0.02 mode powers gives the
# base mode power 1 - 3 * 0.02 = 0.94. Dividing its *amplitude* by 0.94 scales its power by
# 1 / 0.94**2 (~1.06); if the intent is to restore unit power, numpy.sqrt(0.94) is needed — confirm.
probes = init_probes[:1].copy() / 0.94
#probes = init_probes.copy()

# delta_P and delta-O in eq. 25
illum_reg_P = 1e-3 # * numpy.prod(probe.shape[-2:])
illum_reg_O = 1e-3 # * numpy.prod(probe.shape[-2:])
# gamma in eq. 23, 0.1 in fold_slice (scaled here by the number of probe pixels)
gamma = 0.1 * numpy.prod(probes.shape[-2:])
# small epsilon to prevent 0/0
eps = 1e-16

# illumination sums used as the eq. 25 denominators:
#   obj_mag   - object intensity summed over all positions, in probe (cutout) space, shape (Ny, Nx)
#   probe_mag - probe intensity (summed over modes and positions), in full object space
# NOTE(review): both are divided by sqrt(Ny * Nx) below; confirm this scaling relative to illum_reg_*.
obj_mag = numpy.zeros(probes.shape[-2:], dtype=numpy.float64)
probe_mag = numpy.zeros_like(obj, dtype=numpy.float64)

# pre-calculate obj_mag and probe_mag

# this loop should also fix incident probe intensity
groups = calc_groupings(16, 1234567)
for (group_i, group) in enumerate(groups):
    group_scan = scan[group]
    # group probes in real space
    # shape (len(group), 1, Ny, Nx)
    # NOTE(review): get_subpx_shifts returns object-pixel units, while kx/ky are in inverse length
    # units (fftfreq with d=px_size); the shifts probably need scaling by (px_size_x, px_size_y) — confirm.
    group_subpx_filters = fourier_shift_filter(ky, kx, get_subpx_shifts(group_scan))[:, None, ...]
    # shape (len(group), probe modes, Ny, Nx)
    group_probes = ifft2(fft2(probes) * group_subpx_filters)
    add_view_at_pos(probe_mag, group_scan, numpy.sum(abs2(group_probes), axis=1) / numpy.sqrt(numpy.prod(probes.shape[-2:])))
    obj_mag += numpy.sum(get_view_at_pos(abs2(obj), group_scan), axis=0) / numpy.sqrt(numpy.prod(probes.shape[-2:]))

for i in range(5):
    # fresh random grouping of scan positions every iteration
    groups = calc_groupings(8, 1234567 + i)

    # illumination sums accumulated during this iteration, swapped in at its end
    new_obj_mag = numpy.zeros(probes.shape[-2:], dtype=numpy.float64)
    new_probe_mag = numpy.zeros_like(obj, dtype=numpy.float64)

    for (group_i, group) in enumerate(groups):
        group_scan = scan[group]
        # group probes in real space (same sub-pixel shift units NOTE as the pre-calculation loop)
        # shape (len(group), 1, Ny, Nx)
        group_subpx_filters = fourier_shift_filter(ky, kx, get_subpx_shifts(group_scan))[:, None, ...]
        # shape (len(group), probe modes, Ny, Nx)
        group_probes = ifft2(fft2(probes) * group_subpx_filters)
        # add to illumination
        group_probe_mag = numpy.zeros_like(obj, dtype=numpy.float64)
        add_view_at_pos(group_probe_mag, group_scan, numpy.sum(abs2(group_probes), axis=1) / numpy.sqrt(numpy.prod(probes.shape[-2:])))
        new_probe_mag += group_probe_mag

        # experimental data
        group_patterns = exp_data[group]

        # shape (len(group), 1, Ny, Nx)
        group_objs = get_view_at_pos(obj, group_scan)[:, None, ...]
        # add to illumination; summed over the group -> shape (1, Ny, Nx)
        group_obj_mag = numpy.sum(abs2(group_objs), axis=0) / numpy.sqrt(numpy.prod(probes.shape[-2:]))
        new_obj_mag += group_obj_mag[0]

        # forward model: exit wave per position and probe mode, shape (len(group), probe modes, Ny, Nx)
        exit_wave = group_objs * group_probes

        pattern_model = fft2(exit_wave)
        # sum over incoherent modes
        intensity_model = numpy.sum(abs2(pattern_model), axis=1)

        # modulus constraint: sqrt(measured / modelled) - 1, so pattern_model * (1 + intensity_update)
        # has the measured amplitude; chi is therefore psi_opt - psi_fwd in real space
        intensity_update = numpy.sqrt(group_patterns / (intensity_model + 1e-18)) - 1.0
        chi = ifft2(pattern_model * intensity_update[:, None, ...])

        # eq. 24, update directions per probe position
        delta_P = chi * numpy.conj(group_objs) # and probe mode
        delta_O = chi * numpy.conj(group_probes)

        # eq. 23, calculate step size per probe position and mode
        # we take the diagonal approximation
        # NOTE(review): the denominators below have no `axis`, so they sum over the whole group (and modes)
        # into a single scalar, while the numerators are per position; eq. 23 sums per position only.
        alpha_P = numpy.sum(numpy.real(chi * numpy.conj(delta_P * group_objs)), axis=(-1, -2), keepdims=True) / (numpy.sum(abs2(delta_P * group_objs)) + gamma)
        # sum over probe modes as well
        alpha_O = numpy.sum(numpy.sum(numpy.real(chi * numpy.conj(delta_O * group_probes)), axis=(-1, -2), keepdims=True), axis=1) / (numpy.sum(abs2(delta_O * group_probes)) + gamma)
        #print(f"alpha_P: {list(alpha_P.ravel())}")
        #print(f"alpha_O: {list(alpha_O.ravel())}")

        # eq. 25, average update directions
        # apply subpx shifts in reciprocal space (dividing by the unit-modulus filter undoes each shift)
        # delta_P_avg is in cutout space, shape (probe modes, Ny, Nx)
        delta_P_avg = ifft2(numpy.sum(fft2(delta_P) / group_subpx_filters, axis=0))
        delta_P_avg /= (obj_mag + illum_reg_P)
        # delta_O_avg is in object space
        delta_O_avg = numpy.zeros_like(obj)
        # sum over probe modes as well
        add_view_at_pos(delta_O_avg, group_scan, numpy.sum(delta_O, axis=1))
        delta_O_avg /= (probe_mag + illum_reg_O)

        # eq. 27, final probe and object update
        # note that probe update is performed in cutout space, while object update is performed in full object space
        # a small epsilon is added to prevent 0/0 -> nan
        # NOTE(review): group_obj_mag / group_probe_mag are already summed over the group, so the
        # `* mag / (mag + eps)` factors are ~1 and these reduce to sum_i(alpha_i) * delta_avg rather than the
        # per-position |O|^2- / |P|^2-weighted average of eq. 27. The object update also materializes a
        # (len(group), n_y, n_x) temporary from `alpha_O * delta_O_avg`.
        probe_update = numpy.sum(alpha_P * delta_P_avg * group_obj_mag, axis=0) / (group_obj_mag + eps)
        obj_update = numpy.sum(alpha_O * delta_O_avg * group_probe_mag, axis=0) / (group_probe_mag + eps)

        probes += probe_update
        obj += obj_update

    # NOTE(review): leftover debug `break` — only the first iteration runs, and the illumination
    # refresh and progress print below are never reached.
    break
    # update current obj_mag and probe_mag
    obj_mag = new_obj_mag
    probe_mag = new_probe_mag

    print(f"Finished iter #{i+1}")

In [54]:
fig, axs = pyplot.subplots(ncols=2, nrows=3)
fig.set_size_inches(4, 7)

for ax in axs.ravel():
    ax.set_axis_off()

axs[0, 1].sharex(axs[0, 0])
axs[0, 1].sharey(axs[0, 0])
for idxs in ((1, 1), (2, 0), (2, 1)):
    axs[*idxs].sharex(axs[1, 0])
    axs[*idxs].sharey(axs[1, 0])

pattern_max = max(map(numpy.nanmax, (group_patterns[0], intensity_model[0])))

axs[0, 0].set_title("Object")
if True:
    crop = numpy.ceil(sim_r / numpy.array([px_size_y, px_size_x])).astype(numpy.int_)
    crop = (slice(crop[0], -crop[0]), slice(crop[0], -crop[0]))
else:
    crop = (slice(0, None), slice(0, None))

axs[0, 0].imshow(remove_phase_ramp(numpy.angle(obj[*crop])))

#axs[0].imshow(numpy.angle(obj))
axs[1, 0].set_title("Measured Pattern")
axs[1, 0].imshow(numpy.fft.fftshift(group_patterns[0]), vmin=0., vmax=pattern_max)
axs[1, 1].set_title("Recons. Pattern")
axs[1, 1].imshow(numpy.fft.fftshift(intensity_model[0]), vmin=0., vmax=pattern_max)
axs[0, 1].set_title("Obj update phase")
axs[0, 1].imshow(numpy.angle(obj_update[*crop]))
axs[2, 0].set_title("Probe")
axs[2, 0].imshow(numpy.sum(abs2(probes), axis=0))
axs[2, 1].set_title("Probe update phase")
axs[2, 1].imshow(numpy.angle(probe_update[0]))

In [46]:
fig, axs = pyplot.subplots(ncols=probes.shape[0], sharex=True, sharey=True, squeeze=False)

# real-space intensity of each reconstructed probe mode,
# titled with its power normalized by the number of pixels
for ax, probe in zip(axs.flat, probes):
    probe_intensity = abs2(probe)
    ax.set_axis_off()
    ax.set_title("{:.3f}".format(numpy.sum(probe_intensity) / numpy.prod(probe.shape[-2:])))
    ax.imshow(probe_intensity)

pyplot.show()

In [72]:
# close every open figure so pyplot releases the figures created by earlier cells
pyplot.close('all')

In [33]:
fig, axs = pyplot.subplots(ncols=3, sharex=True, sharey=True)

# first three measured patterns, with zero frequency moved back to the centre for display
for pattern_idx, ax in enumerate(axs):
    ax.imshow(numpy.fft.fftshift(raw[pattern_idx]))
pyplot.show()

In [44]:
# Plot the object update directions (delta_O) of each position in the last group
# processed, first probe mode only

fig, axs = pyplot.subplots(ncols=len(delta_O), sharex=True, sharey=True, squeeze=False)

for (ax, upd) in zip(axs.flat, delta_O):
    ax.set_axis_off()
    ax.imshow(abs2(upd[0]))

pyplot.show()

In [48]:
fig, ax = pyplot.subplots()

# measured pattern of the first position in the last group processed, zero frequency centred
centred_pattern = numpy.fft.fftshift(group_patterns[0])
ax.imshow(centred_pattern)
pyplot.show()