In [1]:
%matplotlib widget
%reload_ext autoreload
%autoreload 2

import os
import typing as t

from matplotlib import pyplot
import matplotlib.patches as patches

import numpy
from numpy.typing import NDArray, ArrayLike

In [2]:
from ptycho_lebeau.raw import load_4d
from ptycho_lebeau.metadata import AnyMetadata
import eutils

# Metadata JSON for the dataset. Set the PTYCHO_META_FILE environment variable to point at
# another copy instead of editing this machine-specific default path.
META_FILE = os.environ.get("PTYCHO_META_FILE", "/Users/colin/Downloads/mos2/1/mos2/mos2_0.00_dstep1.0.json")

meta = AnyMetadata.parse_file(META_FILE)
print(meta.json(indent='    '))

In [3]:
raw = load_4d(meta.path / meta.raw_filename)
# Move the centre of each pattern to index 0 (corner) and flatten the scan axes to one.
# `ifftshift` is the exact inverse of the `fftshift` used to re-centre patterns for display;
# the two coincide for even detector sizes, but for odd sizes `fftshift` would leave the
# centre one pixel away from the corner.
raw = numpy.fft.ifftshift(raw, axes=(-1, -2)).reshape((-1, *raw.shape[-2:]))
# Normalize each pattern to unit total intensity. Out-of-place division so integer-typed
# detector data is promoted to float instead of raising on in-place true division.
raw = raw / numpy.sum(raw, axis=(-2, -1))[:, None, None]

## Make probe & propagators

In [4]:
# Electron wavelength for the beam (eutils.Electron receives meta.voltage / 1e3 — presumably kV; TODO confirm).
wavelength = eutils.Electron(meta.voltage / 1e3).wavelength

# Reciprocal-space pixel size of the detector (meta.diff_step * 1e-3, presumably mrad -> rad) and the
# real-space extent of the window it implies; the window is square, so both extents are equal.
kstep = meta.diff_step * 1e-3 / wavelength
a = b = 1 / kstep

# Real-space pixel size of a detector-sized (probe) window.
px_size_y = b / raw.shape[-2]
px_size_x = a / raw.shape[-1]

# Real-space coordinates of that window, from -extent/2 (inclusive) to +extent/2 (exclusive).
yy, xx = numpy.meshgrid(
    numpy.linspace(-b / 2, b / 2, raw.shape[-2], endpoint=False),
    numpy.linspace(-a / 2, a / 2, raw.shape[-1], endpoint=False),
    indexing='ij',
)

# Spatial frequencies in unshifted FFT order (zero frequency at index 0), matching the
# corner-centred layout the patterns were moved into.
ky, kx = numpy.meshgrid(
    numpy.fft.fftfreq(raw.shape[-2], px_size_y),
    numpy.fft.fftfreq(raw.shape[-1], px_size_x),
    indexing='ij',
)

# Scattering angles (theta = k * wavelength) and their squared magnitude.
thetay = ky * wavelength
thetax = kx * wavelength
theta2 = thetay**2 + thetax**2

def fourier_shift_filter(shifts: ArrayLike) -> NDArray[numpy.complex128]:
    """Build Fourier-space phase ramp(s) that translate a wave by the given shift(s).

    Multiplying a spectrum laid out like `kx`/`ky` by the returned array and inverse-transforming
    yields the wave translated by +shift (Fourier shift theorem with an exp(-2*pi*i*k.r) ramp).

    Parameters
    ----------
    shifts : array_like, shape (2,) or (..., 2)
        Shift vector(s) ordered (x, y). `kx`/`ky` come from `numpy.fft.fftfreq(n, px_size)`, so they
        are in cycles per unit length; shifts must be lengths in the same units as
        `px_size_x`/`px_size_y`, not pixel counts.

    Returns
    -------
    ndarray, shape (*shifts.shape[:-1], *ky.shape)
        exp(-2*pi*i*(x*kx + y*ky)) for every shift vector; a single (2,) shift gives one
        array of shape `ky.shape`.

    Raises
    ------
    ValueError
        If `shifts` is a scalar or its last axis does not have length 2.
    """
    shifts = numpy.asarray(shifts)
    if shifts.ndim == 0 or shifts.shape[-1] != 2:
        raise ValueError(f"shifts must have shape (..., 2), got {shifts.shape}")
    # two trailing singleton axes let each (x, y) pair broadcast over the whole frequency grid,
    # replacing the per-shift Python loop (and the Python 3.11-only `a[*idx]` indexing)
    x = shifts[..., 0, None, None]
    y = shifts[..., 1, None, None]
    return numpy.exp(-2.j*numpy.pi*(x*kx + y*ky))

In [5]:
# TODO check sign of defocus

# Defocus-only aberration function chi = defocus * theta^2 / 2.
# meta.defocus * 1e10 presumably converts metres to the length unit of `wavelength` — TODO confirm.
chi = meta.defocus * 1e10 * (theta2 / 2)
# Initial probe in reciprocal space: a pure phase factor exp(-2*pi*i*chi/wavelength).
init_probe = numpy.exp(chi * -2.j*numpy.pi / wavelength)
# mask reciprocal space to aperture (meta.conv_angle * 1e-3, presumably mrad -> rad)
aperture_mask = theta2 <= (meta.conv_angle * 1e-3)**2
init_probe *= aperture_mask
# Normalize to unit total intensity, sum(|P|^2) == 1. With the norm='forward' transforms used
# in the reconstruction and a unit-modulus object, this gives modelled patterns the same unit
# total as the normalized measured patterns. (The previous sum(|P|) only agreed because |P| is 0 or 1.)
init_probe /= numpy.sqrt(numpy.sum(numpy.abs(init_probe)**2))

## Make scan

In [6]:
# Scan coordinates, offset so the grid sits roughly around zero. The 1e10 factor presumably
# converts metres to Å. numpy.float64 replaces the numpy.float_ alias removed in NumPy 2.0.
scanx = (numpy.arange(meta.scan_shape[0], dtype=numpy.float64) - meta.scan_shape[0] / 2.) * meta.scan_step[0] * 1e10
scany = (numpy.arange(meta.scan_shape[1], dtype=numpy.float64) - meta.scan_shape[1] / 2.) * meta.scan_step[1] * 1e10
# 'ij' indexing makes y the slow axis: flattened position k has y index k // len(scanx).
# NOTE(review): this must match the order in which `raw`'s scan axes were flattened — confirm for non-square scans.
scany, scanx = map(numpy.ravel, numpy.meshgrid(scany, scanx, indexing='ij'))
# shape (n, 2), columns (x, y)
scan = numpy.stack((scanx, scany), axis=-1)

fig, ax = pyplot.subplots()
ax.set_aspect(1.)
ax.scatter(scan[:, 0], scan[:, 1])
ax.set(title="Scan positions", xlabel="x (Å)", ylabel="y (Å)")
pyplot.show()

## Make object

In [7]:
x_min, x_max = numpy.nanmin(scanx), numpy.nanmax(scanx)
y_min, y_max = numpy.nanmin(scany), numpy.nanmax(scany)

# pad by half the probe window (plus two pixels) so views at the outermost positions stay inside the object
sim_r = numpy.max((a, b)) / 2. + numpy.max((px_size_x, px_size_y)) * 2.
x_min -= sim_r
y_min -= sim_r
x_max += sim_r
y_max += sim_r

# TODO check this for off-by-one

# keep x_min and y_min, calculate number of object pixels required along each axis
n_x = numpy.ceil((x_max - x_min) / px_size_x).astype(numpy.int_) + 1
n_y = numpy.ceil((y_max - y_min) / px_size_y).astype(numpy.int_) + 1

# update x_max and y_max so the extent is a whole number of pixels
x_max = x_min + n_x * px_size_x
y_max = y_min + n_y * px_size_y

fig, ax = pyplot.subplots()
ax.set_aspect(1.)
ax.add_patch(patches.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min, fill=False, transform=ax.transData))
# row (i) and column (j) index of every object pixel, shape (n_y, n_x) each
i, j = numpy.indices((n_y, n_x))
ax.pcolormesh(x_min + j * px_size_x, y_min + i*px_size_y, i)
ax.scatter(scanx, scany, s=1, c='red')
ax.set(title="Object grid and scan positions", xlabel="x", ylabel="y")
pyplot.show()

In [8]:
# handle object cutouts and subpixel probe shifting

def pos_to_object_idx(pos: ArrayLike) -> NDArray[numpy.float64]:
    """Convert probe position(s) to the fractional object index of the first pixel of their view.

    Parameters
    ----------
    pos : array_like, shape (2,) or (..., 2)
        Position(s) ordered (x, y), in the same units as `x_min`/`px_size_x`.

    Returns
    -------
    ndarray, same shape as `pos`
        Fractional (column j, row i) index into the object of the first pixel of a
        `ky.shape`-sized window centred on each position.
    """
    # every term is ordered (x, y) so it lines up with the (x, y) columns of `pos`;
    # previously the origin and pixel size were given in (y, x) order
    origin = numpy.array([x_min, y_min])
    px_size = numpy.array([px_size_x, px_size_y])
    half_view = numpy.array([ky.shape[-1], ky.shape[-2]]) / 2.
    return (numpy.array(pos) - origin) / px_size - half_view

def slice_around_position(x: float, y: float) -> t.Tuple[slice, slice]:
    """Return (row slice, column slice) selecting the probe-sized object window around (x, y).

    The window start is `pos_to_object_idx` rounded to the nearest whole pixel; its size is
    `ky.shape`. Any sub-pixel remainder is left for the caller to handle.
    """
    col0, row0 = numpy.round(pos_to_object_idx((x, y))).astype(numpy.int_)
    n_rows, n_cols = ky.shape[-2], ky.shape[-1]
    return slice(row0, row0 + n_rows), slice(col0, col0 + n_cols)

def get_subpx_shifts(pos: ArrayLike) -> NDArray[numpy.float64]:
    """Sub-pixel offset of each probe position from its integer-aligned object view, as a length.

    `slice_around_position` snaps each view to the nearest whole object pixel. The remainder
    (between -0.5 and 0.5 pixels per axis) has to be applied to the probe instead, via
    `fourier_shift_filter`.

    Parameters
    ----------
    pos : array_like, shape (2,) or (..., 2)
        Probe position(s) ordered (x, y).

    Returns
    -------
    ndarray, same shape as `pos`
        (x, y) sub-pixel shifts in the units of `px_size_x`/`px_size_y`. The fractional pixel
        offset is multiplied by the pixel size because `fourier_shift_filter` pairs shifts with
        `kx`/`ky`, which are in cycles per unit length (fftfreq with the pixel size as spacing),
        not cycles per pixel.
    """
    frac_idx = pos_to_object_idx(pos)
    frac_px = frac_idx - numpy.round(frac_idx)
    return frac_px * numpy.array([px_size_x, px_size_y])

def get_view_at_pos(obj: numpy.ndarray, pos: ArrayLike) -> numpy.ndarray:
    """Extract the probe-sized window(s) of `obj` around one or many positions.

    A single (x, y) position returns the slice of `obj` directly (a view). A batch of shape
    (..., 2) returns a new array of shape (..., *ky.shape[-2:]) with `obj`'s dtype, holding
    a copy of each window.
    """
    positions = numpy.array(pos)
    if positions.ndim != 1:
        batch_shape = positions.shape[:-1]
        views = numpy.empty(batch_shape + ky.shape[-2:], dtype=obj.dtype)
        for batch_idx in numpy.ndindex(batch_shape):
            views[batch_idx] = obj[slice_around_position(*positions[batch_idx])]
        return views
    return obj[slice_around_position(*positions)]

def set_view_at_pos(obj: numpy.ndarray, pos: ArrayLike, view: numpy.ndarray):
    """Write probe-sized window(s) back into `obj` in place.

    `pos` is a single (x, y) position (then `view` is one window) or a batch of shape
    (..., 2) (then `view[idx]` is written at `pos[idx]`). Batched windows are written in
    order, so where windows overlap the later one overwrites the earlier one.
    """
    positions = numpy.array(pos)
    if positions.ndim == 1:
        obj[slice_around_position(*positions)] = view
    else:
        for batch_idx in numpy.ndindex(positions.shape[:-1]):
            obj[slice_around_position(*positions[batch_idx])] = view[batch_idx]


## Iteration loop

### Grouping loop

In [9]:
def calc_groupings(
    grouping: int = 8,
    n_positions: t.Optional[int] = None,
    rng: t.Optional[numpy.random.Generator] = None,
) -> list[NDArray[numpy.int_]]:
    """Randomly partition scan-position indices into groups of about `grouping` positions.

    Parameters
    ----------
    grouping : int
        Target group size (>= 1). Groups differ in size by at most one.
    n_positions : int, optional
        Number of positions to partition; defaults to the number of patterns, `raw.shape[0]`.
    rng : numpy.random.Generator, optional
        Generator used for shuffling. If omitted, the legacy global numpy RNG is used
        (seed it with `numpy.random.seed` for reproducible groupings).

    Returns
    -------
    list of int arrays
        Disjoint index arrays that together cover `range(n_positions)`. With zero
        positions, a single empty group is returned instead of raising.

    Raises
    ------
    ValueError
        If `grouping` is less than 1.
    """
    if grouping < 1:
        raise ValueError(f"grouping must be >= 1, got {grouping}")
    if n_positions is None:
        n_positions = raw.shape[0]

    idxs = numpy.arange(n_positions)
    if rng is None:
        numpy.random.shuffle(idxs)
    else:
        rng.shuffle(idxs)

    # integer ceil(n / grouping); at least one section because array_split rejects 0
    n_groups = max(1, -(-n_positions // grouping))
    return numpy.array_split(idxs, n_groups)

# Reconstruction settings.
N_ITERS = 10       # passes over all scan positions
GROUP_SIZE = 16    # positions updated together per step
SEED = 0           # seeds the legacy global RNG used for the start object and the group shuffling
numpy.random.seed(SEED)

# random start object: unit modulus with a tiny random phase
obj_angle = numpy.random.normal(0., 1e-8, (n_y, n_x))
obj = numpy.exp(1.j * obj_angle)

# probe model (reciprocal space, same layout as init_probe); copied so init_probe is never aliased
probe = init_probe.copy()

alpha = 1.0        # beta in 2004 paper, alpha in 2009 paper
beta = 1.0         # beta in 2009 paper (probe update factor)
illum_reg = 0.2    # alpha in 2004 paper. used only by the commented-out 2004 PIE update below, not by ePIE

for i in range(N_ITERS):
    groups = calc_groupings(GROUP_SIZE)

    for (group_i, group) in enumerate(groups):
        group_scan = scan[group]
        # sub-pixel shift ramps for this group, shared by the forward model and the probe-update back-transform
        shift_filters = fourier_shift_filter(get_subpx_shifts(group_scan))
        # group probes in real space: apply the shift in reciprocal space, transform, centre within the view
        group_probes = numpy.fft.fftshift(numpy.fft.ifft2(shift_filters * probe, norm='forward'), axes=(-1, -2))
        group_patterns = raw[group]

        group_objs = get_view_at_pos(obj, group_scan)

        exit_wave = group_objs * group_probes

        pattern_model = numpy.fft.fft2(exit_wave, norm='forward')
        intensity_model = numpy.abs(pattern_model)**2

        # modulus constraint: replace model amplitudes with measured ones where the model is non-negligible.
        # Measured intensities are clipped at 0 so negative detector values (e.g. after background
        # subtraction) cannot turn into NaN through sqrt and poison obj/probe.
        mask = intensity_model >= 1e-10
        measured = numpy.clip(group_patterns[mask], 0., None)
        pattern_model[mask] *= numpy.sqrt(measured / intensity_model[mask])
        wave_diff = numpy.fft.ifft2(pattern_model, norm='forward') - exit_wave

        # old PIE stuff
        #illum_factor = numpy.abs(group_probes) / numpy.max(numpy.abs(group_probes), axis=(-1, -2), keepdims=True)
        #obj_update = alpha * illum_factor * group_probes.conj() / (numpy.abs(group_probes)**2 + illum_reg) * wave_diff

        # ePIE
        obj_update = alpha * group_probes.conj() / numpy.max(numpy.abs(group_probes)**2, axis=(-1, -2), keepdims=True) * wave_diff
        probe_update = beta * group_objs.conj() / numpy.max(numpy.abs(group_objs)**2, axis=(-1, -2), keepdims=True) * wave_diff
        # back to the probe's reciprocal-space layout: undo the centring, transform, remove the sub-pixel ramp
        probe_update = numpy.fft.fft2(numpy.fft.ifftshift(probe_update, axes=(-1, -2)), norm='forward') / shift_filters

        # TODO combine object updates in a smarter way here: overlapping views within a group are
        # written back in order, so later positions overwrite earlier ones
        set_view_at_pos(obj, group_scan, group_objs + obj_update)

        probe = probe + numpy.mean(probe_update, axis=0)  # average probe updates for each in group

    print(f"Finished iter #{i+1}")

In [10]:
from scipy.linalg import lstsq

def remove_phase_ramp(data: numpy.ndarray) -> numpy.ndarray:
    """Subtract the least-squares best-fit plane from every 2D image in `data`.

    A plane c0 + c1*col + c2*row (in pixel-index coordinates) is fitted independently to
    each trailing 2D slice and removed; any leading axes are treated as a batch.

    Parameters
    ----------
    data : ndarray, shape (..., H, W)

    Returns
    -------
    ndarray with the same shape and dtype as `data`.

    The fit is linear in the input values, so for wrapped phases (e.g. from `numpy.angle`)
    a ramp steep enough to wrap within the image will not be removed cleanly.
    """
    rows, cols = (axis.ravel() for axis in numpy.indices(data.shape[-2:], dtype=float))
    # design matrix with columns [1, col, row]
    design = numpy.column_stack((numpy.ones_like(cols), cols, rows))

    result = numpy.empty_like(data)
    for batch_idx in numpy.ndindex(data.shape[:-2]):
        image = data[batch_idx]
        coeffs = lstsq(design, image.ravel())[0]
        result[batch_idx] = image - (coeffs @ design.T).reshape(image.shape)
    return result

# Padding pixels (sim_r on each side) to trim along (rows, columns) before display.
crop = numpy.ceil(sim_r / numpy.array([px_size_y, px_size_x])).astype(numpy.int_)

fig, ax = pyplot.subplots()
# reconstructed phase with the padding trimmed and the best-fit plane removed
ax.imshow(remove_phase_ramp(numpy.angle(obj[crop[0]:-crop[0], crop[1]:-crop[1]])))

In [11]:
fig, ax = pyplot.subplots()
# equal data aspect (as in the scan and object-grid plots) so the phase map is drawn in true
# proportions; set_box_aspect only squared the axes box and distorted non-square objects
ax.set_aspect(1.)
i, j = numpy.indices(obj.shape)
ax.pcolormesh(x_min + j * px_size_x, y_min + i*px_size_y, numpy.angle(obj))
ax.scatter(scanx, scany, s=1, c='red')
ax.set(title="Object phase with scan positions", xlabel="x", ylabel="y")
pyplot.show()

In [12]:
fig, axs = pyplot.subplots(ncols=2, nrows=3)
fig.set_size_inches(4, 7)

for ax in axs.ravel():
    ax.set_axis_off()

# Object-space panels (top row) share axes with each other; the detector-sized panels
# (patterns and probe) share axes with the experimental pattern.
axs[0, 1].sharex(axs[0, 0])
axs[0, 1].sharey(axs[0, 0])
for idxs in ((1, 1), (2, 0), (2, 1)):
    axs[idxs].sharex(axs[1, 0])
    axs[idxs].sharey(axs[1, 0])

# common colour limit for the measured vs. modelled pattern of the first position in the
# last group processed by the reconstruction loop
pattern_max = max(map(numpy.nanmax, (group_patterns[0], intensity_model[0])))

# Trim the sim_r padding around the object; set CROP_OBJECT = False to show the full array.
CROP_OBJECT = True
if CROP_OBJECT:
    crop = numpy.ceil(sim_r / numpy.array([px_size_y, px_size_x])).astype(numpy.int_)
    # crop[0] is the row (y) margin, crop[1] the column (x) margin
    crop = (slice(crop[0], -crop[0]), slice(crop[1], -crop[1]))
else:
    crop = (slice(0, None), slice(0, None))

axs[0, 0].set_title("Object")
axs[0, 0].imshow(remove_phase_ramp(numpy.angle(obj[crop])))

axs[1, 0].set_title("Exp. Pattern")
axs[1, 0].imshow(numpy.fft.fftshift(group_patterns[0]), vmin=0., vmax=pattern_max)
axs[1, 1].set_title("Sim. Pattern")
axs[1, 1].imshow(numpy.fft.fftshift(intensity_model[0]), vmin=0., vmax=pattern_max)
axs[0, 1].set_title("Obj update")
# place the last group's object updates into an object-sized canvas so they can share the object panel's crop
padded_obj_update = numpy.zeros_like(obj)
set_view_at_pos(padded_obj_update, group_scan, obj_update)
axs[0, 1].imshow(numpy.angle(padded_obj_update[crop]))
axs[2, 0].set_title("Probe")
axs[2, 0].imshow(numpy.fft.fftshift(numpy.abs(probe)))
axs[2, 1].set_title("Probe update")
axs[2, 1].imshow(numpy.fft.fftshift(numpy.angle(probe_update[0])))