## Notebook example of ptychographic reconstruction on simulated data

This is a basic example, showing how to:
* simulate data
* compute the object shape from the data
* use the Ptycho operators for reconstruction, including an incoherent background

In [None]:
# Optional: select language and/or GPU name or rank through environment variable
import os
os.environ['PYNX_PU'] = 'opencl'

%matplotlib notebook
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from pynx.ptycho import simulation, shape

# Import Ptycho, PtychoData and operators (automatically selecting OpenCL or CUDA)
from pynx.ptycho import *

## Simulate the Ptychography dataset

In [None]:
# 2D detector size (square)
nxy = 256
# Pixel size in meters
pixel_size_detector = 55e-6
# Wavelength in meters
wavelength = 1.5e-10
# Detector distance in meters
detector_distance = 1
# Average number of photons per frame (also used below to scale the background)
nb_photons_per_frame = 1e9

# Incoherent background: a smooth 2D Gaussian centred on the detector,
# normalised so that its total is twice the average number of photons per frame
x, y = np.arange(nxy), np.arange(nxy)
dx, dy = np.meshgrid((x - x.mean()) / (x.max() - x.min()), (y - y.mean()) / (y.max() - y.min()))
bg = np.exp(-5 * (dx ** 2 + dy ** 2))
bg *= 2 * nb_photons_per_frame / bg.sum()

# Object options 'siemens' simulates Siemens star (with a few holes)
# 'logo' simulates PyNX logo
# obj_info = {'type': 'logo', 'phase_stretch': 1.57, 'alpha_win': .2}
obj_info = {'type': 'siemens', 'phase_stretch': 1.57, 'alpha_win': .2}

# Probe description, either as a Gaussian, or as a focused aperture
probe_info = {'type': 'focus', 'aperture': (150e-6, 150e-6), 'focal_length': .08,
              'defocus': 350e-6, 'shape': (nxy, nxy)}
# probe_info = {'type': 'gauss', 'sigma_pix': (20, 20), 'shape': (nxy, nxy)}

# Spiral scan (for reference: 50 positions = 4 turns, 78 = 5 turns, 113 = 6 turns).
# 200 positions are used here.
scan_info = {'type': 'spiral', 'scan_step_pix': 20, 'n_scans': 200}

# Data info, with the different parameters and using Poisson noise
# nb_photons_per_frame is the average number of photons per frame
data_info = {'nb_photons_per_frame': nb_photons_per_frame, 'bg': bg, 'wavelength': wavelength,
             'detector_distance': detector_distance,
             'detector_pixel_size': pixel_size_detector,
             'noise': 'poisson'}

# Initialisation of the simulation
s = simulation.Simulation(obj_info=obj_info, probe_info=probe_info, scan_info=scan_info, data_info=data_info)
s.make_data()
# Scan positions, in object pixels
posx, posy = s.scan.values
# Object-space pixel size (meters) = wavelength * distance / (detector pixel size * number of pixels)
pixel_size_object = wavelength * detector_distance / pixel_size_detector / nxy
ampl = s.amplitude.values  # square root of the measured diffraction pattern intensity

# Left: simulated background; right: diffraction intensity averaged over all frames
plt.figure(figsize=(9.5, 4))
plt.subplot(121)
plt.imshow(bg)
plt.title('Simulated background')
plt.colorbar()

plt.subplot(122)
plt.imshow((ampl ** 2).mean(axis=0))
plt.title('Mean diffraction intensity')
plt.colorbar();


## Prepare the initial object and probe
This uses the `pynx.ptycho.simulation` module for an explicit simulation of object and probe.

Note that if the initial object array is not supplied to the `Ptycho` object (`obj=None`), its size will be calculated automatically, and the object will be initialised as a homogeneous object (an array of ones).

In [None]:
# Size of the reconstructed object, computed from the scan positions and the frame shape
nyo, nxo = shape.calc_obj_shape(posx, posy, ampl.shape[1:])

# Initial object: random values. 'range' is presumably
# (amplitude min, amplitude max, phase min, phase max) -- confirm in pynx.ptycho.simulation
obj_init_info = {'type': 'random', 'range': (0.9, 1, 0, 0.5), 'shape': (nyo, nxo)}

# Initial probe: same focused aperture as the simulated probe, but with a defocus of
# 250e-6 instead of 350e-6, so the reconstruction starts from an imperfect probe
probe_init_info = {'type': 'focus', 'aperture': (150e-6, 150e-6), 'focal_length': .08,
                   'defocus': 250e-6, 'shape': (nxy, nxy)}

# Basic data info, used to compute the object pixel size.
# Uses its own name so the simulation's `data_info` (with background and noise) is not overwritten.
init_data_info = {'wavelength': wavelength, 'detector_distance': detector_distance,
                  'detector_pixel_size': pixel_size_detector}
# Perform the actual simulation of the initial object and probe
init = simulation.Simulation(obj_info=obj_init_info, probe_info=probe_init_info, data_info=init_data_info)
init.make_obj()
init.make_probe()

## Create the `PtychoData` and `Ptycho` objects

In [None]:
# Observed intensities, with scan positions converted from object pixels to meters.
# Experimental parameters come from the constants defined above, so they cannot drift apart.
data = PtychoData(iobs=ampl ** 2, positions=(posx * pixel_size_object, posy * pixel_size_object),
                  detector_distance=detector_distance, mask=None,
                  pixel_size_detector=pixel_size_detector, wavelength=wavelength)

# Random initial object (from `init`) + almost no initial background:
# ~1 count per pixel, negligible compared to the simulated background
bg0 = np.random.uniform(0.5, 1.5, bg.shape)
p = Ptycho(probe=init.probe.values, obj=init.obj.values, data=data, background=bg0)

# Initial scaling of object and probe
p = ScaleObjProbe(verbose=True) * p

## Optimise the Ptycho object 
This can use different algorithms:
* Difference Map
* Alternating Projections
* Maximum Likelihood conjugate gradient

For each algorithm it is possible to update object, probe, positions, and to display the result.

Each algorithm operator is raised to the power of the number of cycles, e.g. `DM()**40` will perform 40 cycles.

In [None]:
# Main optimisation: Difference Map cycles followed by Alternating Projections cycles,
# both updating object, probe and background.
# A new figure is opened first, presumably used by the periodic object/probe display (show_obj_probe).
plt.figure()
dm_cycles = DM(update_object=True, update_probe=True, update_background=1, calc_llk=10,
               show_obj_probe=10, center_probe_n=0) ** 200
p = dm_cycles * p
ap_cycles = AP(update_object=True, update_probe=True, calc_llk=10, update_background=1) ** 40
p = ap_cycles * p
# Optional maximum-likelihood refinement:
# p = ML(update_object=True, update_probe=True, calc_llk=20, show_obj_probe=20, update_background=0)**100 * p

In [None]:
# Look at the reconstructed background
# (pass norm=LogNorm() to imshow to view it on a logarithmic scale)
plt.figure()
plt.imshow(p.get_background())
plt.title('Reconstructed background')
plt.colorbar();

## Add probe modes and continue optimising
A label such as `DM/o/3p` indicates:
* the algorithm (DM or AP or ML)
* the parts which are optimised (o for object, p for probe, t for translations)
* the number of probe modes (when >1), e.g. `3p` for 3 probe modes

In [None]:
# Current probe, stored as (nb_modes, ny, nx)
pr = p.get_probe()
_, ny, nx = pr.shape
# New number of probe modes
nb_probe = 5
# Keep the main (first) mode and fill the extra modes with weak random fields:
# amplitude uniform in [0, mean main-mode amplitude / 10], phase uniform in [0, 2*pi).
# Computed from pr[0] only, so re-running this cell does not depend on previously added modes.
mode_amplitude_max = abs(pr[0]).mean() / 10
pr1 = np.empty((nb_probe, ny, nx), dtype=np.complex64)
pr1[0] = pr[0]
for i in range(1, nb_probe):
    pr1[i] = (np.random.uniform(0, mode_amplitude_max, (ny, nx))
              * np.exp(1j * np.random.uniform(0, 2 * np.pi, (ny, nx))))

p.set_probe(pr1)

# Continue optimising with the multi-mode probe: DM, then AP, then ML refinement
plt.figure()
p = DM(update_object=True, update_probe=True, update_background=1, calc_llk=10, show_obj_probe=10)**40 * p
p = AP(update_object=True, update_probe=True, update_background=1, calc_llk=10, show_obj_probe=10)**40 * p
p = ML(update_object=True, update_probe=True, update_background=1, calc_llk=20, show_obj_probe=20)**100 * p


In [None]:
# Manual decomposition of algorithms (examples, kept commented out).
# Low-level operators are chained with `*` and applied to `p`; the rightmost operator is
# presumably applied first (e.g. ObjProbe2Psi before PropagateApplyAmplitude) -- confirm in pynx docs.
# First line: one cycle on a single stack (SelectStack(0)).
# Second line: the same cycle looped over all stacks (LoopStack), followed by Psi2ObjMerge.
#p = Psi2Obj() * PropagateApplyAmplitude()* ObjProbe2Psi() * SelectStack(0) * p
#p = Psi2ObjMerge() * LoopStack(Psi2Obj() * PropagateApplyAmplitude() * ObjProbe2Psi()) * p

## Export data and/or result object & probe to CXI (hdf5) files

In [None]:
# Set to True to write the reconstructed object & probe, and the simulated data, to CXI (hdf5) files
save_cxi = False
if save_cxi:
    p.save_obj_probe_cxi('obj_probe.cxi')
    save_ptycho_data_cxi('data.cxi', ampl ** 2, pixel_size_detector, wavelength, detector_distance,
                         posx * pixel_size_object, posy * pixel_size_object, z=None, monitor=None,
                         mask=None, instrument='simulation', overwrite=True)


## View background
...and test filtering it with `BackgroundFilter`.

In [None]:
# Reconstructed background after the multi-mode optimisation
plt.figure()
plt.imshow(p.get_background())
plt.title('Reconstructed background')
plt.colorbar();


In [None]:
# Keep an independent copy of the unfiltered background, so this reference stays intact
# even if the array returned by get_background() is later modified in place
b0 = np.array(p.get_background(), copy=True)

In [None]:
# Restore the unfiltered background first (so this cell can be re-run), then filter it.
# NOTE(review): the meaning of the BackgroundFilter argument (200) is not shown here -- see pynx docs.
p.set_background(b0)
p = BackgroundFilter(200) * p
b1 = p.get_background()

# Side-by-side comparison of the background before and after filtering
plt.figure(figsize=(9.5, 4))
plt.subplot(121)
plt.imshow(b0)
plt.title('Background before filtering')
plt.colorbar()
plt.subplot(122)
plt.imshow(b1)
plt.title('Background after BackgroundFilter(200)')
plt.colorbar();
