## Example: ptychographic reconstruction with position optimisations
This example uses simulated data, and tracks the positions changes vs the optimisation cycles

In [1]:
#import os
#os.environ['PYNX_PU'] = 'opencl'  # Select language and/or GPU name or rank through environment variable (optional)

%matplotlib notebook
import matplotlib.pyplot as plt
from pynx.ptycho import simulation, shape

# Import Ptycho, PtychoData and operators (automatically selecting OpenCL or CUDA)
from pynx.ptycho import *

## Simulate the Ptychographic data

In [2]:
n = 256
nb_frame = 120
pixel_size_detector = 55e-6
wavelength = 1.5e-10
detector_distance = 1
obj_info = {'type': 'phase_ampl', 'phase_stretch': np.pi, 'ampl_range': (0.8,1.2), 'alpha_win': .2}
probe_info = {'type': 'focus', 'aperture': (120e-6, 120e-6), 'focal_length': .08, 'defocus': 800e-6, 'shape': (n, n)}

# 50 scan positions correspond to 4 turns, 78 to 5 turns, 113 to 6 turns
scan_info = {'type': 'spiral', 'scan_step_pix': 30, 'n_scans': nb_frame, 'integer_values': False}
data_info = {'num_phot_max': 1e9, 'bg': 0, 'wavelength': wavelength, 'detector_distance': detector_distance,
             'detector_pixel_size': pixel_size_detector,
             'noise': 'poisson'}

# Initialisation of the simulation with specified parameters, specific <object>, <probe> or <scan>positions can be passed as:
# s = ptycho.Simulation(obj=<object>, probe=<probe>, scan = <scan>)
# omitting obj_info, probe_info or scan_info (or passing it as empty dictionary "{}")
s = simulation.Simulation(obj_info=obj_info, probe_info=probe_info, scan_info=scan_info, data_info=data_info)

# Data simulation: probe.show(), obj.show(), scan.show() and s.show_illumination_sum() will visualise the integrated total coverage of the beam
s.make_data()

posx, posy = s.scan.values

pixel_size_object = wavelength * detector_distance / pixel_size_detector / n

ampl = s.amplitude.values  # square root of the measured diffraction pattern intensity

Simulating object: phase_ampl
Simulating probe: focus




Simulating scan: spiral
Simulating ptychographic data [120 frames].


Parameters of the simulation:
Data info: {'pix_size_direct_nm': 10, 'num_phot_max': 1000000000.0, 'nb_photons_per_frame': 100000000.0, 'bg': 0, 'beam_stop_transparency': 0, 'noise': 'poisson', 'wavelength': 1.5e-10, 'detector_distance': 1, 'detector_pixel_size': 5.5e-05}
Scan info: {'type': 'spiral', 'scan_step_pix': 30, 'n_scans': 120, 'integer_values': False}
Object info: {'type': 'Custom', 'phase_stretch': 3.141592653589793, 'ampl_range': (0.8, 1.2), 'alpha_win': 0.2}
Probe info: {'type': 'focus', 'shape': (256, 256), 'sigma_pix': (50, 50), 'rotation': 0, 'aperture': (0.00012, 0.00012), 'focal_length': 0.08, 'defocus': 0.0008}


## Create the initial reconstructed object & probe
The initial object is random (amplitude between 0.5 and 1, phase between 0 and 0.5 radians), and the probe is different from the one used to simulate the diffraction patterns.

In [3]:
nyo, nxo = shape.calc_obj_shape(posx, posy, ampl.shape[1:])

# Initial object
# obj_init_info = {'type':'flat','shape':(nx,ny)}
obj_init_info = {'type': 'random', 'range': (0.5, 1, 0, 0.5), 'shape': (nyo, nxo)}
# Initial probe
probe_init_info = {'type': 'focus', 'aperture': (100e-6, 100e-6), 'focal_length': .08,
                   'defocus': 700e-6, 'shape': (n, n)}
data_info = {'wavelength': wavelength, 'detector_distance': detector_distance,
             'detector_pixel_size': pixel_size_detector}
init = simulation.Simulation(obj_info=obj_init_info, probe_info=probe_init_info, data_info=data_info)

init.make_obj()
init.make_probe()

Simulating object: random
Simulating probe: focus




## Alter the positions & create the Ptycho object
We just alter two positions here, but more can be added.

The `p._interpolation` parameter can be used to trigger the use of interpolation - i.e. when a scan position does not correspond to an integer number of pixels, the object is interpolated with a bilinear approximation. Note that it is not necessary to use this interpolation to detetermine the 

In [4]:
posx1, posy1 = posx.copy(), posy.copy()
posx1[10] += 5
posy1[10] += 10
posx1[20] -= 5
posy1[20] -= 5

if False:
    posx1, posy1 = np.round(posx1), np.round(posy1)

data = PtychoData(iobs=ampl ** 2, positions=(posx1 * pixel_size_object, posy1 * pixel_size_object), 
                  detector_distance=1, mask=None, pixel_size_detector=55e-6, wavelength=1.5e-10)

p = Ptycho(probe=s.probe.values, obj=init.obj.values, data=data, background=None) # Random object start


# Use interpolation ?
p._interpolation = False

# Initial scaling of object and probe
p = ScaleObjProbe(verbose=True) * p

ScaleObjProbe: 4342.861 295676.3 22107.34267425547 1226.8418651011557 18.019717004844853


## Initial object and probe optimisation

In [5]:
plt.figure()
p = DM(update_object=True, update_probe=True, calc_llk=20, show_obj_probe=20)**40 * p
p = AP(update_object=True, update_probe=False, calc_llk=20, show_obj_probe=20)**40 * p
p = ML(update_object=True, update_probe=True, calc_llk=20, show_obj_probe=20)**40 * p

<IPython.core.display.Javascript object>

DM/o/p     #  0 LLK= 297781.53(p) 256194918.40(g) 469775.50(e), nb photons=2.371257e+13, dt/cycle=0.641s
DM/o/p     # 20 LLK= 15184.06(p) 2646404.27(g) 24559.07(e), nb photons=2.476473e+13, dt/cycle=0.047s
DM/o/p     # 39 LLK= 20275.21(p) 3006576.53(g) 31904.90(e), nb photons=2.457154e+13, dt/cycle=0.025s
AP/o       # 40 LLK= 20275.21(p) 3006573.87(g) 31904.91(e), nb photons=2.457154e+13, dt/cycle=0.263s
AP/o       # 60 LLK=  6575.07(p) 4313313.87(g) 12982.36(e), nb photons=2.492320e+13, dt/cycle=0.022s
AP/o       # 79 LLK=  6174.58(p) 4210672.00(g) 12279.62(e), nb photons=2.492448e+13, dt/cycle=0.018s
ML/o/p     # 81 LLK=  5824.31(p) 4167488.00(g) 11584.07(e), nb photons=2.492529e+13, dt/cycle=0.294s
ML/o/p     #101 LLK=  5480.71(p) 4878354.67(g) 11352.55(e), nb photons=2.494789e+13, dt/cycle=0.029s
ML/o/p     #120 LLK=  5390.74(p) 5029602.67(g) 11232.44(e), nb photons=2.494745e+13, dt/cycle=0.027s


## Optimise positions
This works best using AP or ML algorithms. DM tends to be more unstable.

We use the `pos_history` option so that we can plot the position history vs the cycle number later. It slows down the optimisation as data needs to be retreived from the GPU for each cycle.

In [6]:
plt.figure()  # Use a new figure
#p = ShowObjProbe() *DM(update_object=True, update_probe=True, update_pos=True,  pos_threshold=0.1,
#                       pos_min_shift=0.0, pos_max_shift=2, pos_history=True, calc_llk=20,
#                       show_obj_probe=20)**100 * p
p = ShowObjProbe() *AP(update_object=True, update_probe=True, update_pos=5, pos_mult=5, pos_threshold=0.2,
                       pos_min_shift=0.0, pos_max_shift=2, pos_history=True, calc_llk=50,
                       show_obj_probe=50)**500 * p
p = ShowObjProbe() * ML(update_object=True, update_probe=True, update_pos=5, 
                        pos_history=True, calc_llk=20, show_obj_probe=20)**100 * p


<IPython.core.display.Javascript object>

AP/o/p/t   #121 LLK=  5388.35(p) 5051173.87(g) 11238.50(e), nb photons=2.494854e+13, dt/cycle=0.305s
AP/o/p/t   #171 LLK=  2109.15(p) 1201524.80(g)  4425.31(e), nb photons=2.494329e+13, dt/cycle=0.014s
AP/o/p/t   #221 LLK=  1972.25(p) 1164917.20(g)  4166.59(e), nb photons=2.494085e+13, dt/cycle=0.014s
AP/o/p/t   #271 LLK=  1849.93(p) 1114876.40(g)  3933.56(e), nb photons=2.494059e+13, dt/cycle=0.014s
AP/o/p/t   #321 LLK=  1922.20(p) 1072479.93(g)  4090.06(e), nb photons=2.494237e+13, dt/cycle=0.014s
AP/o/p/t   #371 LLK=  1906.82(p) 1084910.67(g)  4062.19(e), nb photons=2.494021e+13, dt/cycle=0.014s
AP/o/p/t   #421 LLK=  2105.22(p) 1325889.33(g)  4459.50(e), nb photons=2.494053e+13, dt/cycle=0.014s
AP/o/p/t   #471 LLK=  2126.75(p) 1091558.53(g)  4504.19(e), nb photons=2.493369e+13, dt/cycle=0.014s
AP/o/p/t   #521 LLK=  1961.12(p) 1060565.80(g)  4176.46(e), nb photons=2.493595e+13, dt/cycle=0.014s
AP/o/p/t   #571 LLK=  2262.64(p) 1078594.33(g)  4777.55(e), nb photons=2.493523e+13, dt/cyc

## Plot the position shifts
The recorded position shifts can be manually plotted

In [9]:
ipos = [10,20]  # 10 or 20
fig = plt.figure(figsize=(9.5,4))
ax = plt.subplot(121)
#for i in range(default_processing_unit.get_stack_size()):
for i in range(50):
    x = [v[1] for v in p.position_history[i]]
    y = [v[2] for v in p.position_history[i]]
    plt.scatter(x,y, 1)
    plt.text(x[0], y[0], '%d' % i)
    #print("%3d  dr = %5.3f" % (i, np.sqrt((x[0]-x[-1])**2 + (y[0]-y[-1])**2)))
ax.set_aspect(1)

for i in range(len(ipos)):
    plt.subplot(2,2,2 + 2 *i)
    ix, x, y = [v[0] for v in p.position_history[ipos[i]]], [v[1] for v in p.position_history[ipos[i]]], \
               [v[2] for v in p.position_history[ipos[i]]]
    plt.plot(ix,x,'b.', label='x[%d]'%ipos[i])
    plt.xlabel('cycle')
    plt.twinx()
    plt.plot(ix,y,'r.', label='y[%d]'%ipos[i])
    fig.legend(loc="center right")
plt.tight_layout()

<IPython.core.display.Javascript object>