In [1]:
%matplotlib widget
%reload_ext autoreload
%autoreload 2

import typing as t

from matplotlib import pyplot
import matplotlib.patches as patches

import numpy
from numpy.typing import NDArray, ArrayLike

In [2]:
from pathlib import Path
from ptycho_lebeau.metadata import AnyMetadata

# Resolve the metadata JSON relative to the user's home directory, then parse it.
meta_path = Path("~/mos2/1/mos2/mos2_0.00_dstep1.0.json").expanduser()
meta = AnyMetadata.parse_file(meta_path)

# Echo the parsed metadata so the acquisition parameters are visible in the notebook.
print(meta.json(indent='    '))

In [8]:
from phaser.io.empad import load_4d

# Load the raw 4D dataset referenced by the metadata file.
assert meta.path is not None
assert meta.raw_filename is not None
raw = load_4d(meta.path / meta.raw_filename)

# Move the pattern centers to the array corners and flatten the scan axes to (N, ky, kx).
# `ifftshift` is the exact inverse of the `fftshift` used for display later on; the two
# are identical for even-sized patterns and differ by one pixel for odd-sized ones.
# (assumes `load_4d` returns patterns with the zero-frequency at the center -- TODO confirm)
pattern_shape = raw.shape[-2:]
raw = numpy.fft.ifftshift(raw, axes=(-1, -2)).reshape((-1, *pattern_shape))
# normalize each pattern to unit total intensity
# NOTE(review): in-place division requires a floating dtype and non-empty patterns
raw /= numpy.sum(raw, axis=(-2, -1))[:, None, None]

dose = 100000  # e/A^2
# area covered by one scan position (scan step presumably in m, converted to A -- confirm)
px_area = numpy.prod(numpy.array(meta.scan_step) * 1e10)
dose_per_pattern = dose * px_area  # electrons per pattern

dwell_time = 1e-3                    # s, dwell time assumed for the equivalent current
electron_charge_pc = 1.6021766e-07   # 1.602e-19 C/e * 1e12 pC/C
# e * 1.602e-19 C/e / 1e-3 s * 1e12 pA/A -> pA
print(f"Equiv. beam current: {(dose_per_pattern * electron_charge_pc) / dwell_time:.3f} pA")

# apply poisson noise, from a seeded generator so that reruns give the same noisy data
noise_seed = 0
noise_rng = numpy.random.RandomState(noise_seed)
noisy = noise_rng.poisson(dose_per_pattern * raw).astype(numpy.float32) / dose_per_pattern

## Various helper functions

In [10]:
from phaser.utils.num import fft2, ifft2, abs2

def calc_groupings(grouping: int = 8, seed: t.Any = None,
                   n_patterns: t.Optional[int] = None) -> list[NDArray[numpy.int_]]:
    """
    Randomly partition pattern indices into groups of roughly `grouping` patterns.

    Parameters:
        grouping: Target number of patterns per group. Must be positive.
        seed: Seed passed to `numpy.random.RandomState`. `None` gives a
            different shuffle on every call.
        n_patterns: Number of patterns to partition. Defaults to the number
            of patterns in the notebook-level `raw` array.

    Returns:
        A list of index arrays which together contain every index in
        `range(n_patterns)` exactly once. Group sizes differ by at most one.
        Empty if there are no patterns.

    Raises:
        ValueError: If `grouping` is not positive.
    """
    if grouping <= 0:
        raise ValueError(f"'grouping' must be positive, got {grouping}")

    if n_patterns is None:
        # fall back to the dataset loaded earlier in the notebook
        n_patterns = raw.shape[0]

    if n_patterns == 0:
        # array_split() rejects zero sections; there is nothing to group
        return []

    rng = numpy.random.RandomState(seed)
    idxs = numpy.arange(n_patterns)
    rng.shuffle(idxs)

    n_groups = int(numpy.ceil(n_patterns / grouping))
    return numpy.array_split(idxs, n_groups)

## Make probe & propagators

In [11]:
from phaser.utils.num import Sampling
from phaser.utils.physics import Electron

# Electron wavelength at the accelerating voltage recorded in the metadata.
wavelength = Electron(meta.voltage).wavelength
print(wavelength)

# Reciprocal-space pixel size: the diffraction step scaled by 1e-3
# (presumably mrad -> rad; confirm against the metadata schema) over the wavelength.
kstep = meta.diff_step*1e-3 / wavelength

# Real-space extent of the probe grid is the inverse of the reciprocal pixel size,
# and is the same along both axes.
a = b = 1/kstep

grid = Sampling(raw.shape[-2:], extent=(b, a))

In [24]:
from phaser.utils.optics import make_focused_probe

# Build the initial probe on the reciprocal-space coordinates of the probe grid.
# The defocus is scaled by 1e10 (presumably m -> angstrom; confirm units).
recip_coords = grid.recip_grid()
init_probe = make_focused_probe(
    *recip_coords, wavelength, meta.conv_angle,
    defocus=meta.defocus*1e10,
)

# Left: probe intensity as stored. Right: intensity of its Fourier transform,
# shifted so the zero frequency sits at the center of the image.
probe_intensity = abs2(init_probe)
probe_recip_intensity = numpy.fft.fftshift(abs2(fft2(init_probe)))

fig, probe_axs = pyplot.subplots(ncols=2)
(ax1, ax2) = probe_axs

ax1.imshow(probe_intensity)
ax2.imshow(probe_recip_intensity)
pyplot.show()

## Make scan

In [13]:
from phaser.utils.scan import make_raster_scan

# The metadata is unpacked x-first; the scan itself is built in (y, x) order,
# so both the shape and the step are reversed before being passed on.
n_x, n_y = meta.scan_shape
scan_step_yx = numpy.array(meta.scan_step[::-1]) * 1e10
# flatten to a list of positions, shape (N, 2) with columns (y, x)
scan = make_raster_scan((n_y, n_x), scan_step_yx).reshape(-1, 2)

scan_y = scan[:, 0]
scan_x = scan[:, 1]

fig, ax = pyplot.subplots()
ax.invert_yaxis()
ax.set_aspect(1.)
# all scan positions
ax.scatter(scan_x, scan_y, s=1.)
# highlight the third position to show the scan ordering
ax.scatter([scan_x[2]], [scan_y[2]], c='red', s=2.)
pyplot.show()

## Make object

In [14]:
from phaser.utils.object import ObjectSampling

# Object grid covering every scan position, padded on each side by half the
# probe grid extent plus one extra pixel.
obj_grid = ObjectSampling.from_scan(scan, grid.sampling, pad=grid.extent / 2. + grid.sampling)

(y_min, x_min) = obj_grid.min
(y_max, x_max) = obj_grid.max

fig, ax = pyplot.subplots()
ax.invert_yaxis()
ax.set_aspect(1.)
# outline of the full object grid
ax.add_patch(patches.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min, fill=False, transform=ax.transData))
(yy, xx) = obj_grid.grid()
ax.scatter(scan[:, 1], scan[:, 0], s=1, c='blue')

# Highlight one scan position and the object region its cutout covers.
i = 4
ax.scatter([scan[i, 1]], [scan[i, 0]], s=2, c='red')

# Compute the cutout slices once, using the real pattern shape, so the printed
# slices always describe the rectangle that is drawn.
cutout_slices = obj_grid.slice_at_pos(scan[i], raw.shape[-2:])
slicey, slicex = cutout_slices
# NOTE(review): indexing at `.stop` assumes the cutout ends before the last
# row/column of the object grid (true while the padding above is in place).
(starty, startx) = (yy[slicey.start, 0], xx[0, slicex.start])
(stopy, stopx) = (yy[slicey.stop, 0], xx[0, slicex.stop])

ax.add_patch(patches.Rectangle((startx, starty), stopx - startx, stopy - starty, fill=False, transform=ax.transData, color='red'))

print(cutout_slices)

pyplot.show()

## Iteration loop

In [28]:
from phaser.utils.optics import fourier_shift_filter

# experimental data, shape (N, ky, kx)
# This is the single place that selects what the reconstruction fits to.
# NOTE(review): assign `noisy` here to reconstruct from the Poisson-noised patterns.
exp_data = raw

# random start object: unit amplitude with a very small random phase
rng = numpy.random.RandomState(seed=123456)
obj_angle = rng.normal(0., 1e-6, obj_grid.shape)
obj = (numpy.cos(obj_angle) + numpy.sin(obj_angle) * 1.j)

# probe model (copied so the initial probe stays available for comparison)
probe = init_probe.copy()

(ky, kx) = grid.recip_grid()
cutout_shape = probe.shape[-2:]

alpha = 1.0        # beta in 2004 paper, alpha in 2009 paper (object update factor)
beta = 1.0         # beta in 2009 paper (probe update factor)
illum_reg = 0.2    # alpha in 2004 paper. not used by the ePIE update below

n_iters = 5
group_size = 16

for i in range(n_iters):
    # Draw each iteration's shuffle seed from the seeded `rng`, so that the
    # whole reconstruction is reproducible from run to run.
    groups = calc_groupings(group_size, seed=rng.randint(2**31 - 1))

    for (group_i, group) in enumerate(groups):
        group_scan = scan[group]
        # group probes: the shared probe shifted by each position's subpixel
        # offset, applied as a filter in reciprocal space
        group_subpx_filters = fourier_shift_filter(ky, kx, obj_grid.get_subpx_shifts(group_scan, cutout_shape))
        group_probes = ifft2(group_subpx_filters * fft2(probe))
        group_patterns = exp_data[group]

        group_objs = obj_grid.get_view_at_pos(obj, group_scan, cutout_shape)

        exit_wave = group_objs * group_probes

        pattern_model = fft2(exit_wave)
        intensity_model = numpy.abs(pattern_model)**2

        # modulus constraint: rescale the modelled amplitude to the measured one
        # and keep only the difference (the small constant avoids division by zero)
        intensity_update = numpy.sqrt(group_patterns / (intensity_model + 1e-18)) - 1.0
        wave_diff = ifft2(pattern_model * intensity_update)

        # ePIE: each update is the exit-wave difference weighted by the conjugate
        # of the other factor, normalized by that factor's peak intensity
        group_obj_update = alpha * group_probes.conj() / numpy.max(numpy.abs(group_probes)**2, axis=(-1, -2), keepdims=True) * wave_diff
        probe_update = beta * group_objs.conj() / numpy.max(numpy.abs(group_objs)**2, axis=(-1, -2), keepdims=True) * wave_diff
        # undo the subpixel shifts, then average the probe update over the group
        probe_update = ifft2(numpy.mean(fft2(probe_update) / group_subpx_filters, axis=0))

        # NOTE(review): if `set_view_at_pos` assigns rather than accumulates,
        # overlapping views within a group overwrite each other -- confirm.
        obj_update = numpy.zeros_like(obj)
        obj_grid.set_view_at_pos(obj_update, group_scan, group_obj_update)

        obj += obj_update
        probe += probe_update

    print(f"Finished iter #{i+1}")

In [29]:
from phaser.utils.filter import remove_linear_ramp

fig, axs = pyplot.subplots(ncols=2, nrows=3)
fig.set_size_inches(4, 7)

for ax in axs.ravel():
    ax.set_axis_off()

axs[0, 1].sharex(axs[0, 0])
axs[0, 1].sharey(axs[0, 0])
for idxs in ((1, 1), (2, 0), (2, 1)):
    axs[*idxs].sharex(axs[1, 0])
    axs[*idxs].sharey(axs[1, 0])

pattern_max = max(map(numpy.nanmax, (group_patterns[0], intensity_model[0])))

axs[0, 0].set_title("Object")
if True:
    crop = obj_grid.get_region_crop()
    #crop = numpy.ceil(sim_r / numpy.array([px_size_y, px_size_x])).astype(numpy.int_)
    #crop = (slice(crop[0], -crop[0]), slice(crop[0], -crop[0]))
else:
    crop = (slice(0, None), slice(0, None))

axs[0, 0].imshow(remove_linear_ramp(numpy.angle(obj[*crop])))

#axs[0].imshow(numpy.angle(obj))
axs[1, 0].set_title("Measured Pattern")
axs[1, 0].imshow(numpy.fft.fftshift(group_patterns[0]), vmin=0., vmax=pattern_max)
axs[1, 1].set_title("Recons. Pattern")
axs[1, 1].imshow(numpy.fft.fftshift(intensity_model[0]), vmin=0., vmax=pattern_max)
axs[0, 1].set_title("Obj update phase")
axs[0, 1].imshow(numpy.angle(obj_update[*crop]))
axs[2, 0].set_title("Probe")
axs[2, 0].imshow(abs2(probe))
axs[2, 1].set_title("Probe update phase")
axs[2, 1].imshow(numpy.angle(probe_update))