### Example of using holotomo operators with simulated data
Support for holo-tomography is still in development. This notebook can be used to experiment with different configurations (single or multiple distances, etc.).

You need the `xraydb` package for the initial simulation.

In [None]:
# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2019-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr
%matplotlib notebook
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from scipy.fftpack import fftshift
import imageio
from skimage.data import shepp_logan_phantom, coins, immunohistochemistry, microaneurysms, retina
import xraydb
from pynx.holotomo import *
from pynx.holotomo.utils import simulate_probe
from pynx.wavefront import Wavefront
from pynx.wavefront import operator as wop
from pynx.test.speed import SpeedTest

# `display` and `HTML` belong to the public `IPython.display` API; importing them
# from `IPython.core.display` is deprecated.
from IPython.display import display, HTML

# Inject a CSS rule so that `.container` elements use the full available width.
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
# --- Simulation parameters ---

# Beam: photon energy and the wavelength derived from it (12398.4e-10 / energy)
nrj_ev = 30000
wavelength = 12398.4e-10 / nrj_ev

# Sample material: looked up in xraydb for 'Fe', then delta/beta evaluated at nrj_ev
material, density = xraydb.get_material('Fe')
delta, beta, atten = xraydb.xray_delta_beta(material, density, nrj_ev)
print("%15s: delta=%6e  beta=%6e  delta/beta=%6.1f" % (material, delta, beta, delta / beta))

# Sample thickness, pixel size and dataset dimensions
thickness = 1e-6
pixel_size = 50e-9
ny = nx = 512
nb_proj = 24
stack_size = 3

# Propagation distances
nz = 4  # Number of distances
dz = 0.03  # Base distance
# Geometric series: dz * 1.04**0, dz * 1.04**3, dz * 1.04**6, ...
vz = dz * 1.04 ** (3 * np.arange(nz))
print(vz)

In [None]:
# Distances formula from Zabler 2005 article
ps = 0.1  # pixel_size*1e6
print(ps)

# Quadratic term in ps: printed on its own, then reused in the final expression
dist_coeff = 38.4 * ps**2 + 0.4 * ps - 0.12
print(dist_coeff)

# The coefficient is scaled by these four factors, offset, and divided by wavelength * 1e10
dist_factors = np.array([1, 2, 5, 6])
print((dist_coeff * dist_factors - 9.84 * ps + 4.67) * 1e-3 / (wavelength * 1e10))

In [None]:
# Simulate probe (propagation distances)
# The second argument holds one value (0.2) per distance; the result is reshaped
# to (nz, 1, ny, nx), i.e. with a singleton second axis.
probe0 = simulate_probe(
    (ny, nx),
    np.full(nz, 0.2),
    nb_line_h=10,
    nb_line_v=10,
    pixel_size=pixel_size,
    wavelength=wavelength,
    amplitude=2,
).reshape((nz, 1, ny, nx))

# One panel per distance, showing the phase of the simulated probe
plt.figure(figsize=(10, 2))
for i, probe_phase in enumerate(np.angle(probe0[:, 0])):
    plt.subplot(1, nz, i + 1)
    plt.imshow(probe_phase, cmap='gray')
    plt.colorbar()
plt.tight_layout()

In [None]:
# Create object from skimage examples
k = 2 * np.pi / wavelength
mu = 2 * k * beta
# Printed for information: mu*thickness and k*delta*thickness for the full sample thickness
print(mu*thickness, k*delta*thickness)

# Image sources, cycled until nb_proj projections are filled.
# A multi-channel image contributes one projection per channel.
img_src = [shepp_logan_phantom, immunohistochemistry, retina, coins]

# Each projection first holds an image scaled by its maximum (stored as complex64),
# zero-padded to (ny, nx); it is converted to a complex transmission further below.
obj0 = np.zeros((nb_proj, ny, nx), dtype=np.complex64)
i, ct = 0, 0
while ct < nb_proj:
    img = img_src[i]()
    if img.ndim == 2:
        # Give 2D images a trailing channel axis so all images are handled alike
        img = img[..., np.newaxis]
    for j in range(img.shape[-1]):
        # Crop to at most (ny, nx); slicing is a no-op for smaller images
        d = img[:, :, j][:ny, :nx]
        ny1, nx1 = d.shape
        obj0[ct, :ny1, :nx1] = d / d.max()

        # Images smaller than the frame are written in the top-left corner,
        # then rolled by half the free space along each axis to centre them
        if ny1 < ny:
            obj0[ct] = np.roll(obj0[ct], (ny - ny1) // 2, axis=0)
        if nx1 < nx:
            obj0[ct] = np.roll(obj0[ct], (nx - nx1) // 2, axis=1)
        ct += 1
        if ct == nb_proj:
            break
    i = (i + 1) % len(img_src)

# Convert the scaled images (all projections but the last) to a complex transmission:
# exp(-k * (beta - 1j * delta) * thickness * image)
obj0[:-1] = np.exp(-k * (-1j * delta + beta) * thickness * obj0[:-1])

# last projection is without an object (empty beam)
obj0[-1] = 1

nb_mode = 1
obj0 = np.reshape(obj0, (nb_proj, nb_mode, ny, nx))

# Last frame is without sample.
# Use the builtin `bool`: the `np.bool` alias was removed in NumPy 1.24.
sample_flag = np.ones(nb_proj, dtype=bool)
sample_flag[-1] = False

In [None]:
print(obj0.shape)
if True:
    # Phase (left panel) and modulus (right panel) of obj0[i, 0]
    i = 0
    plt.figure(figsize=(9, 4))
    for panel, extract in ((121, np.angle), (122, np.abs)):
        plt.subplot(panel)
        plt.imshow(extract(obj0[i, 0]), cmap='gray')
        plt.colorbar()

In [None]:
# Placeholder observed intensities (all ones), shaped (nb_proj, nz, ny, nx)
iobs_shape = (nb_proj, nz, ny, nx)
iobs = np.ones(iobs_shape, dtype=np.float32)

# Display last propagated wavefront ?
# w = wop.ImshowAbs() * w

# test displacements: none by default. To try random shifts, use instead:
#   dx = np.random.uniform(2,5,(nb_proj, nz)); dx[:,0] = 0
#   dy = np.random.uniform(2,5,(nb_proj, nz)); dy[:,0] = 0
dx = dy = None

# Create HoloTomoData
data = HoloTomoData(
    iobs,
    pixel_size_detector=pixel_size,
    wavelength=wavelength,
    detector_distance=vz,
    stack_size=stack_size,
    dx=dx,
    dy=dy,
    sample_flag=sample_flag,
)


In [None]:
# Create PCI object
# The simulated object is passed as a copy, so obj0 itself is not handed to HoloTomo.
p = HoloTomo(data=data, obj=obj0.copy(), probe=probe0)

# Operators are applied to `p` with `*`, and the result is assigned back to `p`.
# A new figure is created before each Show* operator, which is called with fig_num=-1.
plt.figure(figsize=(9.5,3))
p = ShowObj(fig_num=-1, istack=0, type='phase') * p
# p = LoopStack(ShowObj(i=None), out=False) *p  # Look at objects

plt.figure(figsize=(9.5,5))
# NOTE(review): presumably the chain is evaluated right to left (SwapStack, then ObjProbe2Psi1,
# then PropagateNearField1, then ShowPsi) - confirm against the pynx operator documentation.
p = ShowPsi(fig_num=-1, iproj=None, iz=None, type='amplitude') * PropagateNearField1() * ObjProbe2Psi1() * SwapStack(0) * p


In [None]:
# Copy calculated intensity to obs
p = Calc2Obs(poisson_noise=True, nb_photon=1e4) * p
if True:
    p = PropagateNearField1() * ObjProbe2Psi1() * SwapStack(0) * p
    # Squared modulus of psi, summed over axes 2 and 3, then fft-shifted along the last two axes
    ps = np.fft.fftshift((abs(p._cu_stack.psi.get())**2).sum(axis=(2,3)), axes=(-1,-2))

    # Two figures with the same (stack_size x nz) layout:
    # first the iobs array of stack #istack, then the intensity computed above.
    istack = 0
    for frames in (p.data.stack_v[istack].iobs, ps):
        plt.figure(figsize=(10, 6))
        for iproj in range(p.data.stack_size):
            for iz in range(p.data.nz):
                plt.subplot(p.data.stack_size, p.data.nz, iz + p.data.nz * iproj + 1)
                plt.imshow(frames[iproj, iz], cmap='gray')
                plt.colorbar()
        plt.tight_layout()

In [None]:
# Optional cell, disabled by default: change the condition to True to run it.
if False:
    # Run speed tests
    p = TestParallelFFT(n_iter=10, n_stack=10, n_fft=3) * p

In [None]:
# Optional cell, disabled by default: apply one of the operator chains below to stack 0
# and display the phase of the resulting psi. The commented lines are alternative chains.
if False:
    #p = PropagateNearField1() * ObjProbe2Psi1() * SwapStack(0) * p
    p = PropagateApplyAmplitude1() * ObjProbe2Psi1() * SwapStack(0) * p
    #p = ApplyAmplitude1() * PropagateNearField1() * ObjProbe2Psi1() * SwapStack(0) * p
    plt.figure(figsize=(10,5))
    p = ShowPsi(iproj=None,iz=None, fig_num=-1, type='phase') * p

In [None]:
# Enabled by default. The delta/beta ratio passed here is the one computed from xraydb
# for the simulated material.
if True:
    # Paganin reconstruction
    p = BackPropagatePaganin(delta_beta=delta/beta, alpha=0.0) * p
    # Show the phase of the resulting object for stack 0 in a new figure
    plt.figure(figsize=(16,3))
    p = ShowObj(fig_num=-1, istack=0, type='phase') * p

In [None]:
# You can start with a CTF instead of a Paganin projection - CTF is better for weak phase objects, Paganin for thick ones
# The CTF also works better with multi-distance datasets
# NOTE(review): this cell and the Paganin cell are both enabled, so on a top-to-bottom run this
# operator is applied after the Paganin one - presumably replacing its result; confirm.
if True:
    # Forget positions ?
    #p.set_positions(np.zeros((nb_proj, nz)), np.zeros((nb_proj, nz)))
    # CTF reconstruction
    p = BackPropagateCTF(alpha=0.2,alpha_low=1e-5, delta_beta=delta/beta) * p
    # Show the phase of the resulting object for stack 0 in a new figure
    plt.figure(figsize=(16,3))
    p = ShowObj(fig_num=-1, istack=0, type='phase') * p

In [None]:
# Iterative reconstruction, with different algorithms, with or without delta/beta constraint

# Forget positions ?
# p.set_positions(np.zeros((nb_proj, nz)), np.zeros((nb_proj, nz)))

# Optional restart (disabled by default): replace the probe by ones and the object by random
# values (modulus in [0.95, 1), phase in [0, 1)), then apply ScaleObjProbe.
# NOTE(review): the random object is drawn without a fixed seed, so this start is not reproducible.
if False:
    p.set_probe(np.ones_like(p._probe))
    # p.set_obj(np.ones((nb_proj, ny, nx)))
    p.set_obj(np.random.uniform(0.95, 1, (nb_proj, ny, nx)) * np.exp(1j*np.random.uniform(0., 1, (nb_proj, ny, nx))))
    p = ScaleObjProbe(verbose=True) * p

# Alternative algorithm (DM) with a delta/beta constraint, kept commented out:
#p = DM(update_probe=True, obj_min=None, obj_max=None, calc_llk=50, reg_obj_smooth=0.5,
#       delta_beta=delta/beta)**200 * p
#p = DM(update_probe=True, obj_min=None, obj_max=None, calc_llk=50, reg_obj_smooth=0.,
#      delta_beta=delta/beta)**100 * p
# Active choice: AP operator without delta/beta constraint (delta_beta=None).
# NOTE(review): `**200` presumably repeats the operator for 200 cycles - confirm in the pynx docs.
p = AP(update_probe=True, obj_min=None, obj_max=None, calc_llk=50, reg_obj_smooth=0.1,
       delta_beta=None, weight_empty=10)**200 * p
# p = AP(update_probe=True, obj_min=None, obj_max=None, calc_llk=50, reg_obj_smooth=0.1, # update positions ?
#         delta_beta=None, update_pos=5, pos_history=True)**400 * p
# Show the phase of the reconstructed object for stack 0 in a new figure
plt.figure(figsize=(10,2))
p = ShowObj(fig_num=-1, istack=0, type='phase') * p


In [None]:
# Show the object of stack 0 twice, each in its own figure: phase first, then amplitude
for view_type in ('phase', 'amplitude'):
    plt.figure(figsize=(9, 2))
    p = ShowObj(fig_num=-1, istack=0, type=view_type) * p


In [None]:
# Optional cell, disabled by default. It overwrites the notebook-level `dx` and `dy`
# with the arrays read back from `p`.
if False:
    # modify the shifts and try to find them back
    dx, dy = p._cu_dx.get(), p._cu_dy.get()
    # Known perturbation: dx = iz; dy = iz + iproj, except for iz == 0 where dy is 0
    for iproj in range(p.data.nproj):
        for iz in range(p.data.nz):
            dx [iproj, iz] = iz
            dy [iproj, iz] = (iz + iproj) * (iz > 0)
    # NOTE(review): `cua` is not imported explicitly in this notebook; presumably it is provided
    # by `from pynx.holotomo import *` - confirm, otherwise this raises a NameError.
    p._cu_dx = cua.to_gpu(dx)
    p._cu_dy = cua.to_gpu(dy)
    # Print the shifts before refinement; [:-1] drops the last projection (the empty beam in this notebook)
    with np.printoptions(precision=2, suppress=True, threshold=1000, floatmode='fixed', linewidth=200):
        print(p._cu_dx.get()[:-1].transpose())
        print(p._cu_dy.get()[:-1].transpose())

    # Operator chain with position registration (Psi2PosReg1, keeping the position history), raised to the power 20
    p = (Psi2ProbeMerge() *LoopStack(Psi2PosReg1(upsampling=8, save_position_history=True) *
                                     Psi2ObjProbe1(delta_beta=-1) * PropagateApplyAmplitude1() * ObjProbe2Psi1()))**20 * p

    # Print the shifts again, for comparison with the values printed above
    with np.printoptions(precision=2, suppress=True, threshold=1000, floatmode='fixed', linewidth=200):
        print()
        print(p._cu_dx.get()[:-1].transpose())
        print(p._cu_dy.get()[:-1].transpose())


In [None]:
# Optional cell, disabled by default; it reads p.position_history.
if False:
    # Plot the history of positions
    p._from_pu()
    # Every 4th projection, derived from nb_proj instead of a hard-coded count of 24
    vproj = range(0, nb_proj, 4)
    fig = plt.figure(figsize=(16,2.5*len(vproj)))
    # One row per selected projection: the object phase in the first column,
    # then one position-history panel per distance.
    # (No preliminary plt.subplot(441): recent Matplotlib versions no longer remove
    # overlapping Axes, so it would remain as an empty panel under the grid.)
    for i, iproj in enumerate(vproj):
        plt.subplot(len(vproj), p.data.nz+1, (nz+1) * i +1)
        plt.title("Proj #%d"%iproj)
        # Projection iproj is entry (iproj % stack_size) of stack (iproj // stack_size)
        plt.imshow(np.angle(p.data.stack_v[iproj // p.data.stack_size].obj[iproj % p.data.stack_size, 0]), cmap='gray')
        for iz in range(nz):
            plt.subplot(len(vproj), p.data.nz+1, (nz+1) * i +1 + iz +1)
            # NOTE(review): v[1] and v[2] presumably hold the x and y shifts of each history entry - confirm
            x = [v[1][iz] for v in p.position_history[iproj]]
            y = [v[2][iz] for v in p.position_history[iproj]]
            # Blue dots: v[1] values, red dots: v[2] values, plotted against the history index
            plt.plot(x,'b.')
            plt.plot(y,'r.')

    plt.tight_layout()