In [None]:
########################################################################
#
# Example of 2D Bragg ptycho reconstruction from simulated data
# (c) ESRF 2019-present
# Authors: Vincent Favre-Nicolin <favre@esrf.fr>
#
########################################################################
%matplotlib ipympl
import os
import timeit
import numpy as np
from matplotlib.colors import LogNorm

# Use only OpenCL - this must be done before PyNX imports
os.environ['PYNX_PU']='opencl'  # Force using OpenCL only (even for Wavefront)

from pynx.ptycho.bragg2d import *
from pynx.ptycho.bragg.cpu_operator import show_3d
from pynx.ptycho.simulation import spiral_archimedes
from pynx.wavefront import Wavefront, ImshowRGBA, ImshowAbs, PropagateFarField, PropagateNearField
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider

# Full width: widen the notebook container to the whole browser window (classic Notebook UI).
# `display`/`HTML` come from IPython.display: importing `display` from IPython.core.display
# is deprecated since IPython 7.14.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
# Experiment parameters (lengths in metres, angles in radians)
wavelength = 1.5e-10
delta = np.deg2rad(45)   # detector rotation about x (see `detector` below)
eta = delta / 2          # sample rotation: half the detector angle
nu = np.deg2rad(0)       # detector rotation about y
pixel_size_detector = 55e-6
ny = nx = 160
detector_distance = 1

# Spiral scan positions; all nb frames go into a single OpenCL stack
nb = 64
default_processing_unit.cl_stack_size = nb
xs, ys = spiral_archimedes(100e-9, nb)
zs = np.zeros_like(xs)

# The piezo motors move in the sample frame, which is rotated by eta about x:
# rotate the (y, z) positions into the laboratory frame.
cos_eta = np.cos(eta)
sin_eta = np.sin(eta)
ys_rot = cos_eta * ys + sin_eta * zs
zs_rot = cos_eta * zs - sin_eta * ys
ys, zs = ys_rot, zs_rot

# Shift z so the positions lie on the average sample surface (keeps Psi centred)
zs = zs + ys / np.tan(eta)

# Detector geometry description passed to Bragg2DPtychoData
detector = dict(rotation_axes=(('x', delta), ('y', nu)),
                pixel_size=pixel_size_detector,
                distance=detector_distance)

In [None]:
# Create a placeholder dataset: the observed intensities are filled in later by Calc2Obs.
# np.zeros (instead of np.empty) gives deterministic contents, so nothing that reads `iobs`
# before the simulation sees uninitialised memory (negative values would look like masked pixels).
data = Bragg2DPtychoData(iobs=np.zeros((nb, ny, nx), dtype=np.float32), positions=(xs, ys, zs), mask=None,
                         wavelength=wavelength, detector=detector)
# Scan positions as stored by the data object
plt.figure()
plt.plot(data.posx, data.posy)
plt.xlabel('posx')
plt.ylabel('posy')
plt.title('Scan positions');

In [None]:
# Build the illumination `pr` (a Wavefront): either import a probe from an earlier 2D ptycho
# result (disabled toggle) or simulate one from a circular aperture focused by propagation.
if False:
    # Import existing probe from 2D ptycho
    # NOTE(review): hard-coded absolute user paths; only reachable when this toggle is True.
    d = np.load("/Users/favre/Analyse/201606id01-FZP/ResultsScan0013/latest.npz")
    #d = np.load("/Users/vincent/Analyse/201606id01-FZP/ResultsScan0000/latest.npz")
    # The stored probe is fftshift-ed on its last two axes before being wrapped in a Wavefront
    pr = Wavefront(d=np.fft.fftshift(d['probe'],axes=(-2,-1)), z=0, pixel_size=d['pixelsize'], wavelength=wavelength)
else:
    # Simulate probe from a focused aperture, with some defocusing
    # Focal-plane pixel size: wavelength * detector distance / (nx * detector pixel size)
    pixel_size_focus = wavelength * detector_distance / (nx * pixel_size_detector)
    focal_length = 0.09
    # NOTE: defocus is currently 0, so the near-field propagation below is a zero-distance step
    defocus = 0e-6
    # Aperture-plane pixel size, presumably chosen so that far-field propagation over
    # focal_length yields pixel_size_focus in the focal plane
    pixel_size_aperture = wavelength * focal_length / (nx * pixel_size_focus)
    pr = Wavefront(d=np.ones((ny, nx)), wavelength=wavelength, pixel_size=pixel_size_aperture)
    # Aperture-plane coordinates; x and y remain defined as notebook globals after this cell
    x, y = pr.get_x_y()
    r = np.sqrt(x**2+y**2)
    print(x.min(),x.max(), y.min(), y.max(), wavelength, pixel_size_aperture)
    #widthy = 200e-6
    #widthx = 200e-6
    #pr.set((abs(y) < (widthy / 2)) * (abs(x) < (widthx / 2)))
    # Circular aperture of radius 50e-6 (a rectangular alternative is commented out above)
    pr.set(r < 50e-6)
    # PyNX operators apply right-to-left: first back-propagate over focal_length
    # (forward=False), then near-field propagate by `defocus`
    pr = PropagateNearField(dz=defocus) * PropagateFarField(focal_length, forward=False) * pr

print('Probe pixel size: %6.2fnm'%(pr.pixel_size*1e9))
# Display the probe; the operator presumably returns the same Wavefront, used below as `probe=pr`
pr = ImshowRGBA()*pr

In [None]:
# Create main Bragg Ptycho object
p = Bragg2DPtycho(probe=pr, data=data, support=None)
pxyz = p.voxel_size_object()
# Real-space pixel size implied by the detector geometry: wavelength * distance / (nx * pixel size)
print("Real-space pixel size from detector geometry: %6.2fnm"
      % (wavelength * detector_distance / (pixel_size_detector * nx) * 1e9))
print("Object voxel size: %6.2fnm x %6.2fnm x %6.2fnm" % (pxyz[0] * 1e9, pxyz[1] * 1e9, pxyz[2] * 1e9))
# p.m is passed as `ortho_m` to show_3d in later cells (presumably the orthogonalisation matrix)
print(p.m)

In [None]:
# Base parallelepiped object: (min, max) bounds along x, y and z, in metres
x0, x1 = -1e-6, 1e-6
y0, y1 = -200e-9, 200e-9
z0, z1 = -400e-9, 400e-9
# Support scale relative to the object bounds (1.0: same size; >1: larger than the object)
rs = 1.0
# Equation for the GPU initialisation of the support (Monte-Carlo integration)
support_bounds = tuple(rs * v for v in (x0, x1, y0, y1, z0, z1))
eq = "(x >= %g) * (x <= %g) * (y >= %g) * (y <= %g) * (z >= %g) * (z <= %g)" % support_bounds
print(eq)
p = InitSupport(eq, rotation_axes=[('x', -eta)], shrink_object_around_support=True) * p

show_support = True
if show_support:
    # Show the support with the ('x', -eta) rotation, then without rotation
    for rot in (('x', -eta), None):
        plt.figure(figsize=(9,4))
        show_3d(p.support, ortho_m=p.m, rotation=rot)


In [None]:
# Reference object for the simulation: the support values divided by 100
obj0 = p.support / 100
# Work on a copy so that obj0 (reused later as the starting point) is left untouched
obj1 = obj0.copy()
if False:
    # Add some strain
    # NOTE(review): `z` is not defined anywhere in this notebook, and `x` is the 2D probe-plane
    # grid from the probe cell, not object coordinates - this branch would fail if enabled.
    obj1 = obj0 * np.exp(1j * 8 * np.exp(-(x ** 2 + z ** 2) / 200e-9 ** 2))
if False:
    # a few random twin domains
    # NOTE(review): same undefined `z` / probe-plane `x` problem as the strain branch; also unseeded.
    nb_domain = 20
    # Random domain centres within the object bounds, each with a random 0/1 label
    cx = np.random.uniform(x0, x1, nb_domain)
    cz = np.random.uniform(z0, z1, nb_domain)
    c = (np.random.uniform(0, 1, nb_domain) > 0.5).astype(np.float32)
    # squared distance to the nearest domain centre found so far (starts at 1, far larger
    # than any squared distance inside the micron-sized object)
    dist2 = np.ones_like(obj0, dtype=np.float32)
    ph = np.zeros_like(obj0, dtype=np.float32)
    for i in range(nb_domain):
        d2 = (x - cx[i]) ** 2 + (z - cz[i]) ** 2
        # Where centre i is closer than the current nearest one, take its label and distance
        ph = ph * (d2 >= dist2) + c[i] * (d2 < dist2)
        dist2 = dist2 * (d2 >= dist2) + d2 * (d2 < dist2)
    obj1 = obj0 * (2 * ph - 1)  # +/-1
    # obj1 = obj0 * np.exp(1j * np.pi / 2 * ph)  # 0 or pi/2
if True:
    # Put hole to get an idea about the shape & orientation of pixels
    # Zeroes a (2*dn)^3 voxel cube at the centre of obj1 (in place; obj0 is a separate array)
    nzo,nyo,nxo = p.support.shape
    dn = 4
    obj1[nzo//2-dn:nzo//2+dn,nyo//2-dn:nyo//2+dn,nxo//2-dn:nxo//2+dn]=0
    #obj1[nzo//2-20:nzo//2+20,nyo//2-2:nyo//2+2,nxo//2-2:nxo//2+2]=0

p.set_obj(obj1)
if False:
    # Debug view of a slice at x index nxo//2
    # NOTE(review): uses the undefined `z` and the 2D probe-plane `y` with 3D indexing;
    # would fail if enabled.
    plt.figure(figsize=(9,4))
    plt.pcolormesh(z[:,:,nxo//2]*1e6,y[:,:,nxo//2]*1e6,abs(p._obj[0,:,:,nxo//2]))
    plt.xlabel('z')
    plt.ylabel('y')
    plt.xlim(-.5,.5)
    plt.ylim(-.5,.5)
    plt.gca().set_aspect('equal')

    #show_3d(p._probe[0],ortho_m=p.m)
    show_3d(p.support,ortho_m=p.m, rotation=('x', eta))
plt.figure(figsize=(9,4))
p = ShowObj(rotation=None, title='Object (eta=%6.2f°, delta=%6.2f°)' % (np.rad2deg(eta), np.rad2deg(delta))) * p
#p = ShowObj(rotation=('x', eta), title='Object rotated back to eta=0') * p
# Object bounds in micrometres
print(x0*1e6,x1*1e6,y0*1e6,y1*1e6,z0*1e6,z1*1e6)

In [None]:
# For each frame, compute the pixel coordinates of the centre of object*probe
p = CalcCenterObjProbe() * p
obs_view = p._cl_obs_v[0]
print(obs_view.cl_cixo, obs_view.cl_ciyo)

In [None]:
# Per-frame object*probe centre pixel coordinates (copied back from the OpenCL arrays)
obs_view = p._cl_obs_v[0]
plt.figure()
plt.scatter(obs_view.cl_cixo.get(), obs_view.cl_ciyo.get())
plt.xlabel('cixo (object pixel index)')
plt.ylabel('ciyo (object pixel index)')
plt.title('Object*probe centre for each frame');

In [None]:
p = ObjProbe2Psi() * p
# Display calculated Psi before propagation to the detector

if True:
    plt.figure(figsize=(9,4))
    # Select the [0, 0] leading indices before shifting: fftshift on the last two axes commutes
    # with that indexing, and only the displayed sub-array has to be shifted.
    psi = np.fft.fftshift(p._cl_psi.get()[0,0], axes=(-1,-2))

    def plot_psi(i):
        """Show |Psi| for frame ``i`` of the stack held in ``psi``."""
        plt.clf()
        plt.imshow(np.abs(psi[i]),origin='lower')
        plt.title('Psi #%d'%i)
        plt.colorbar()

    # psi presumably holds only the frames of one OpenCL stack (cl_stack_size), which can be
    # fewer than len(p.data.iobs): bound the slider by the array actually displayed.
    interact(plot_psi, i=IntSlider(min=0,max=len(psi)-1,step=1,value=0))

In [None]:
# Simulate observed intensities (Poisson noise, 1e7 photons per frame) and copy them into
# p.data.iobs; operators apply right-to-left: ObjProbe2Psi, then FT, then Calc2Obs.
p = Calc2Obs(poisson_noise=True, nb_photons_per_frame=1e7) * FT() * ObjProbe2Psi() * p

show_simulated = True
if show_simulated:
    # Display simulated intensities, centred with fftshift, on a log colour scale
    plt.figure(figsize=(9,4))
    iobs = np.fft.fftshift(p.data.iobs, axes=(-1,-2))

    def plot_simul(i):
        plt.clf()
        frame = np.abs(iobs[i])
        plt.imshow(frame, origin='lower', norm=LogNorm(vmin=0.01))
        plt.title('Simulated frame #%d' % i)
        plt.colorbar()

    frame_slider = IntSlider(min=0, max=len(p.data.iobs) - 1, step=1, value=0)
    interact(plot_simul, i=frame_slider)

In [None]:
# Set object starting point: the support-shaped obj0 multiplied by random factors in [0.2, 1.0).
# A fixed seed makes the starting guess reproducible across Restart & Run All.
start_rng = np.random.default_rng(0)
p.set_obj(obj0 * start_rng.uniform(.2, 1.0, obj0.shape))
p = CalcCenterObjProbe() * p
plt.figure(figsize=(9,4))
p = ShowObj(rotation=('x',eta)) * p

if True:
    # Set starting probe by cutting low amplitudes: keep |probe| where it exceeds 10% of its
    # maximum (the phase is discarded). A dedicated name avoids overwriting `pr`, the probe
    # Wavefront built earlier, and a single abs() is enough.
    probe_amp = np.abs(p.get_probe())
    p.set_probe(probe_amp * (probe_amp > (probe_amp.max()/10)))

# Scale object and probe with observed intensity before any optimisation
p = ScaleObjProbe() * p

In [None]:
# Solve this: DM ** 20 then AP ** 40 (`** n` repeats a PyNX operator n times)
plt.figure(figsize=(9,4))

# Options shared by both algorithms: update object and probe, calc_llk=5 (presumably the
# log-likelihood evaluation interval), no object/probe display, same object regularisation.
solver_opts = dict(update_object=True, update_probe=True, calc_llk=5, show_obj_probe=0,
                   reg_fac_obj_a=1e5, reg_fac_obj_c=100)
p = DM(**solver_opts) ** 20 * p
p = AP(**solver_opts) ** 40 * p

# Alternative refinement (disabled): p = ML(calc_llk=5, show_obj_probe=0) ** 40 * p
p = ShowObj() * p

In [None]:
# Compare obs and calc frames
plt.figure(figsize=(9,4))
p = FT(scale=False) * ObjProbe2Psi() * p
calc = np.fft.fftshift(p._cl_psi.get()[0,0], axes=(-1,-2))
obs0 = np.fft.fftshift(p.data.iobs, axes=(-1,-2))
obs = obs0 * (obs0>=0)  # remove <0 (masked) pixels
# Overall observed / calculated intensity ratio for frame 0
print('Frame 0 sum(obs) / sum(|calc|^2): %g' % (obs[0].sum() / (np.abs(calc[0])**2).sum()))

def plot_obs_calc(i):
    """Plot calculated |Psi|^2 (left) and observed intensity (right) for frame ``i``.

    Both panels share a log colour scale capped at the 99th percentile of the observed frame.
    """
    plt.clf()
    vmax = np.percentile(np.abs(obs[i]),99)
    plt.subplot(121)
    plt.imshow(np.abs(calc[i])**2,origin='lower',norm=LogNorm(vmin=0.01, vmax=vmax))
    plt.title('calc')
    plt.colorbar()
    plt.subplot(122)
    plt.imshow(np.abs(obs[i]),origin='lower',norm=LogNorm(vmin=0.01, vmax=vmax))
    plt.title('obs')
    plt.colorbar()

# calc presumably holds only one OpenCL stack of frames, which can be fewer than the observed
# frames: bound the slider so that both arrays can be indexed.
interact(plot_obs_calc, i=IntSlider(min=0,max=min(len(obs), len(calc))-1,step=1,value=0))

In [None]:
# psi0: Psi after the Fourier amplitude projection; psi1: the same Psi after a round trip
# Psi2ObjProbe followed by ObjProbe2Psi.
p = FourierApplyAmplitude() * ObjProbe2Psi() * p
psi0 = np.fft.fftshift(p._cl_psi.get()[0,0], axes=[-1,-2])
p = ObjProbe2Psi() * Psi2ObjProbe() * p
psi1 = np.fft.fftshift(p._cl_psi.get()[0,0], axes=[-1,-2])
plt.figure(figsize=(9,4))

def plot_psi0_psi1(i):
    """Show |psi0|^2 and |psi1|^2 for frame ``i`` on the same log colour scale.

    Also prints the intensity of psi1 restricted to the non-zero pixels of psi0, relative to the
    total intensity of psi0.
    """
    plt.clf()
    # Both panels are intensities and the scale is taken from |psi0|^2, so they are directly
    # comparable (mixing amplitude and intensity with one vmax made the panels inconsistent).
    int0 = np.abs(psi0[i])**2
    int1 = np.abs(psi1[i])**2
    vmax = np.percentile(int0, 99)
    plt.subplot(121)
    plt.imshow(int0,origin='lower',norm=LogNorm(vmin=vmax/1e4, vmax=vmax))
    plt.title('|psi0|^2')
    plt.colorbar()
    plt.subplot(122)
    plt.imshow(int1,origin='lower',norm=LogNorm(vmin=vmax/1e4, vmax=vmax))
    plt.title('|psi1|^2')
    plt.colorbar()
    ratio = (int1 * (np.abs(psi0[i]) > 0)).sum() / int0.sum()
    print('frame %d: sum(|psi1|^2 where psi0 != 0) / sum(|psi0|^2) = %g' % (i, ratio))

# Bound the slider by the frames actually held in psi0/psi1 (one OpenCL stack)
interact(plot_psi0_psi1, i=IntSlider(min=0,max=len(psi0)-1,step=1,value=0))

In [None]:
# 3D Psi / object / probe debug arrays, XY slices.
# This uses a lot of memory if npsi=None (all frames) or too large
p = ObjProbe2PsiDebug(npsi=16) * p

# The 3D arrays are indexed [0, 0, frame, ...]; only the frames they actually contain
# (presumably npsi) can be shown, so bound the frame slider by axis 2 instead of
# len(p.data.iobs), which raised IndexError for frames >= npsi.
nb_frames_3d = p._psi3d.shape[2]
n_slice = p._psi3d.shape[-3]

plt.figure(figsize=(15,4))
def plot_debug_xy(i, i1):
    """Show |Psi3d|, |Obj3d| and |Probe3d| for frame ``i`` at index ``i1`` of axis -3."""
    plt.clf()
    panels = ((p._psi3d, 'Psi3d'), (p._obj3d, 'Obj3d'), (p._probe3d, 'Probe3d'))
    for k, (arr, title) in enumerate(panels):
        plt.subplot(1, 3, k + 1)
        plt.imshow(np.abs(arr[0,0,i,i1]),origin='lower')
        plt.title(title)
        plt.colorbar()
    plt.tight_layout()
    plt.suptitle('XY')

interact(plot_debug_xy, i=IntSlider(min=0,max=nb_frames_3d-1,step=1,value=0),
         i1=IntSlider(min=0,max=n_slice-1,step=1,value=n_slice//2))

In [None]:
# 3D Psi / object / probe debug arrays, XZ slices.
# This uses a lot of memory if npsi=None (all frames) or too large
p = ObjProbe2PsiDebug() * p

# The 3D arrays are indexed [0, 0, frame, ...]: bound the frame slider by the number of frames
# they actually contain (axis 2) rather than len(p.data.iobs), to avoid IndexError.
nb_frames_3d = p._psi3d.shape[2]
n_slice = p._psi3d.shape[-2]

plt.figure(figsize=(15,4))
def plot_debug_xz(i, i1):
    """Show |Psi3d|, |Obj3d| and |Probe3d| for frame ``i`` at index ``i1`` of axis -2."""
    plt.clf()
    panels = ((p._psi3d, 'Psi3d'), (p._obj3d, 'Obj3d'), (p._probe3d, 'Probe3d'))
    for k, (arr, title) in enumerate(panels):
        plt.subplot(1, 3, k + 1)
        plt.imshow(np.abs(arr[0,0,i,:,i1]),origin='lower')
        plt.title(title)
        plt.colorbar()
    plt.tight_layout()
    plt.suptitle('XZ')

interact(plot_debug_xz, i=IntSlider(min=0,max=nb_frames_3d-1,step=1,value=0),
         i1=IntSlider(min=0,max=n_slice-1,step=1,value=n_slice//2))

In [None]:
# 3D Psi / object / probe debug arrays, YZ slices.
# This uses a lot of memory if npsi=None (all frames) or too large
p = ObjProbe2PsiDebug() * p

# The 3D arrays are indexed [0, 0, frame, ...]: bound the frame slider by the number of frames
# they actually contain (axis 2) rather than len(p.data.iobs), to avoid IndexError.
nb_frames_3d = p._psi3d.shape[2]
n_slice = p._psi3d.shape[-1]

plt.figure(figsize=(18,3))
def plot_debug_yz(i, ix):
    """Show |Psi3d|, |Obj3d| and |Probe3d| for frame ``i`` at index ``ix`` of the last axis."""
    plt.clf()
    panels = ((p._psi3d, 'Psi3d'), (p._obj3d, 'Obj3d'), (p._probe3d, 'Probe3d'))
    for k, (arr, title) in enumerate(panels):
        plt.subplot(1, 3, k + 1)
        plt.imshow(np.abs(arr[0,0,i,:,:,ix]),origin='lower')
        plt.title(title)
        plt.colorbar()
    plt.tight_layout()
    plt.suptitle('YZ')

interact(plot_debug_yz, i=IntSlider(min=0,max=nb_frames_3d-1,step=1,value=0),
         ix=IntSlider(min=0,max=n_slice-1,step=1,value=n_slice//2))

In [None]:
# Current probe and probe gradient (log10 of the amplitude), and the slice at x index nxo//2
# of the object and object gradient.
# NOTE(review): _cl_probe_grad / _cl_obj_grad are presumably only filled by a gradient (ML)
# run, and the ML line above is commented out - confirm they exist after DM/AP.
plt.figure(figsize=(13,4))
nzo, nyo, nxo = p.support.shape
# Zero-valued pixels (e.g. from the probe amplitude cut) give log10 -> -inf: silence the
# divide-by-zero warning; matplotlib masks non-finite values when displaying.
with np.errstate(divide='ignore'):
    probe_log = np.log10(abs(p._cl_probe.get()[0]))
    probe_grad_log = np.log10(abs(p._cl_probe_grad.get()[0]))
panels = ((probe_log, 'log10|probe|', {}),
          (probe_grad_log, 'log10|probe grad|', {}),
          (abs(p._cl_obj.get()[0,:,:,nxo//2]), '|obj| (x index nxo//2)', {'origin': 'lower'}),
          (abs(p._cl_obj_grad.get()[0,:,:,nxo//2]), '|obj grad| (x index nxo//2)', {'origin': 'lower'}))
for k, (img, title, imshow_kw) in enumerate(panels):
    plt.subplot(1, 4, k + 1)
    plt.imshow(img, **imshow_kw)
    plt.title(title)
    plt.colorbar()
