Another experimental notebook for holo-tomography, mostly meant for development purposes.

This one uses coherent probe modes (for each projection, the probe is a linear combination of probe modes, with coefficients that differ from one projection to the next).

TODO:
1. Test fixing the probe mode coefficients, e.g. with two modes whose coefficients follow linear ramps from 0 to 1 and from 1 to 0.
  1. The probe modes would not be normalised.
  1. A scale factor could be refined per frame so that, instead of fixing the probe mode coefficients themselves, only their ratio would be fixed.
  1. This would need a new kernel that performs the reduction on the norm of the computed probe, rather than on the probe mode coefficients.
1. Test on the holotomo-1dist data
1. Add an empty-beam frame at the beginning and at the end of the projections?
1. Try again the free mode probe coefficients:
  1. with multiple distances (gives more redundancy)
  1. by using the new object, rather than the old one, for the probe update?

In [None]:
# -*- coding: utf-8 -*-

# PyNX - Python tools for Nano-structures Crystallography
#   (c) 2019-present : ESRF-European Synchrotron Radiation Facility
#       authors:
#         Vincent Favre-Nicolin, favre@esrf.fr
%matplotlib notebook
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from scipy.fftpack import fftshift
import imageio
from skimage.data import shepp_logan_phantom, coins, immunohistochemistry, microaneurysms, retina
import xraydb
from pynx.holotomo import *
from pynx.holotomo.utils import simulate_probe
from pynx.wavefront import Wavefront
from pynx.wavefront import operator as wop
from pynx.test.speed import SpeedTest

from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
# Basic parameters
nrj_ev = 30000  # X-ray energy (eV)
# Wavelength in metres: lambda[Angstrom] = 12398.4 / E[eV]
wavelength = 12398.4e-10 / nrj_ev
# Material and its refractive index decrement (delta) and absorption index (beta)
material, density = xraydb.get_material('Fe')
delta, beta, atten = xraydb.xray_delta_beta(material, density, nrj_ev)
print("%15s: delta=%6e  beta=%6e  delta/beta=%6.1f" % (material, delta, beta, delta / beta))
thickness = 1e-6  # maximum object thickness (m), scaled by the normalised images
pixel_size = 50e-9  # detector pixel size (m)
ny, nx = 512, 512  # frame size (pixels)
nb_proj = 60  # number of projections (the last one is an empty beam)
stack_size = 3  # passed to HoloTomoData as stack_size
nz = 3  # number of propagation distances
# Number of probe modes. Set to 2 to test coherent probe modes: the
# `nb_probe > 1` branch of the probe cell assumes exactly two modes.
nb_probe = 1
dz = 0.02  # base propagation distance (presumably metres, like the other lengths)
# Distances grow geometrically: vz[iz] = dz * 1.04**(3*iz)
vz = dz * 1.04 ** (np.arange(nz) * 3)
print(vz)

In [None]:
# Propagation distances from the empirical formula of the Zabler 2005 article
ps = 0.1  # pixel size in micrometres (alternative: pixel_size*1e6)
print(ps)
zabler_a = 38.4 * ps ** 2 + 0.4 * ps - 0.12
print(zabler_a)
zabler_mult = np.array([1, 2, 5, 6])
# wavelength*1e10 converts the wavelength from metres to Angstroms
print((zabler_a * zabler_mult - 9.84 * ps + 4.67) * 1e-3 / (wavelength * 1e10))

In [None]:
# Simulate the probe: one wavefield per (distance, probe mode) pair
probe0 = simulate_probe((ny, nx), 0.2 * np.ones(nz * nb_probe), nb_line_h=10, nb_line_v=10, nb_spot=20,
                        pixel_size=pixel_size, wavelength=wavelength, amplitude=2)

# Distances and modes were simulated along a single flattened axis:
# split that axis into (nz, nb_probe)
probe0 = probe0.reshape((nz, nb_probe, ny, nx))

if nb_probe <= 1:
    coherent_probe_modes = True
else:
    # Linear ramps of the mode coefficients across projections, identical for
    # every distance: mode 0 goes 1 -> 0 and mode 1 goes 0 -> 1, to see if
    # this variation can be recovered. Only modes 0 and 1 are set.
    coherent_probe_modes = np.zeros((nb_proj, nz, nb_probe))
    coherent_probe_modes[:, :, 0] = np.linspace(1, 0, nb_proj)[:, np.newaxis]
    coherent_probe_modes[:, :, 1] = np.linspace(0, 1, nb_proj)[:, np.newaxis]
    # Mode 1 is mode 0 rolled by (2, 1) pixels along axes (-1, -2),
    # so that Paganin still works
    probe0[:, 1] = np.roll(probe0[:, 0], (2, 1), axis=(-1, -2))

# Phase (left) and amplitude (right) of each probe:
# one row per mode, one pair of columns per distance
plt.figure(figsize=(18, 4))
for iz in range(nz):
    for imode in range(nb_probe):
        panel = 2 * (nz * imode + iz) + 1
        for offset, part in ((0, np.angle), (1, np.abs)):
            plt.subplot(nb_probe, nz * 2, panel + offset)
            plt.imshow(part(probe0[iz, imode]), cmap='gray')
            plt.colorbar()
plt.tight_layout()


In [None]:
# Create object from skimage examples
k = 2 * np.pi / wavelength
mu = 2 * k * beta
print(mu * thickness, k * delta * thickness)

img_src = [shepp_logan_phantom, immunohistochemistry, retina, coins]

# Fill nb_proj frames by cycling through the example images. Each channel of a
# multi-channel image becomes a separate frame, normalised to a maximum of 1.
obj0 = np.zeros((nb_proj, ny, nx), dtype=np.complex64)
i, ct = 0, 0
while ct < nb_proj:
    img = img_src[i]()
    if img.ndim == 2:
        img = img[..., np.newaxis]
    for j in range(img.shape[-1]):
        # Crop to at most (ny, nx)
        d = img[:ny, :nx, j]
        ny1, nx1 = d.shape
        obj0[ct, :ny1, :nx1] = d / d.max()

        # Centre images smaller than the frame
        if ny1 < ny:
            obj0[ct] = np.roll(obj0[ct], (ny - ny1) // 2, axis=0)
        if nx1 < nx:
            obj0[ct] = np.roll(obj0[ct], (nx - nx1) // 2, axis=1)

        # Add a random shift for more independent images once every source has
        # been used; draw the x shift in [0, nx) and the y shift in [0, ny).
        if ct >= len(img_src):
            obj0[ct] = np.roll(obj0[ct], np.random.randint(0, (nx, ny)), axis=(-1, -2))

        ct += 1
        if ct == nb_proj:
            break
    i = (i + 1) % len(img_src)

# Convert every frame but the last into a complex transmission
# exp(-k * (beta - 1j * delta) * thickness * img). Done frame by frame to
# avoid a full-size temporary array.
for i in range(nb_proj - 1):
    obj0[i] = np.exp(-k * (-1j * delta + beta) * thickness * obj0[i])

# last projection is without an object (empty beam)
obj0[-1] = 1

nb_mode = 1
obj0 = np.reshape(obj0, (nb_proj, nb_mode, ny, nx))

# Last frame is without sample. The builtin `bool` is used because the
# `np.bool` alias was removed in NumPy 1.24.
sample_flag = np.ones(nb_proj, dtype=bool)
sample_flag[-1] = False

In [None]:
# Placeholder observed intensities: one (ny, nx) frame per projection and per
# distance. They are replaced later by Calc2Obs, which copies the calculated
# intensity to obs.
iobs = np.ones((nb_proj, nz, ny, nx), dtype=np.float32)

# Display last propagated wavefront ?
# NOTE(review): `w` is not defined anywhere in the visible notebook, so the
# commented-out line below looks stale.
# w = wop.ImshowAbs() * w

# test displacements
# None presumably disables per-(projection, distance) shifts. The commented-out
# variants draw random shifts in [2, 5) and keep the first distance unshifted.
dx = None  #np.random.uniform(2,5,(nb_proj, nz))
#dx[:,0] = 0
dy = None  #np.random.uniform(2,5,(nb_proj, nz))
#dy[:,0] = 0

# Create HoloTomoData
# detector_distance receives the vz array (one value per distance). sample_flag
# is False for frames without sample (here the last, empty-beam frame).
data = HoloTomoData(iobs, pixel_size_detector=pixel_size, wavelength=wavelength, detector_distance=vz,
               stack_size=stack_size, dx=dx, dy=dy, sample_flag=sample_flag)


In [None]:
# Create PCI object
# obj0.copy() is passed so the simulated object obj0 is not modified by the
# reconstruction.
# NOTE(review): the probe-simulation cell computes `coherent_probe_modes`
# (per-projection mode coefficients when nb_probe > 1), but it is not passed
# here; coherent_probe_modes=False is hard-coded. Confirm this is intended.
p = HoloTomo(data=data, obj=obj0.copy(), probe=probe0, coherent_probe_modes=False)
plt.figure(figsize=(9.5,3))
# fig_num=-1 presumably draws into the current figure created just above
p = ShowObj(fig_num=-1, istack=0, type='phase') * p
# p = LoopStack(ShowObj(i=None), out=False) *p  # Look at objects

if True:
    # Show the near-field propagated wavefield amplitude for projection 0
    # (iz=None, presumably all distances). Operators are chained with `*` and
    # presumably applied right-to-left (SwapStack first, ShowPsi last).
    plt.figure(figsize=(9.5,5))
    p = ShowPsi(fig_num=-1, iproj=0, iz=None, type='amplitude') * PropagateNearField1() * ObjProbe2Psi1() * SwapStack(0) * p


In [None]:
# Copy calculated intensity to obs
# Simulated measurement with Poisson noise. nb_photon=1e5 presumably sets the
# photon-count scale of the noise; confirm against the Calc2Obs documentation.
p = Calc2Obs(poisson_noise=True, nb_photon=1e5) * p

In [None]:
#p = LoopStack(ObjProbe2Psi1(), copy_psi=True) * p
#for s in p.data.stack_v:
#    print(s.istack,s.psi.sum())
#print(p._cu_probe_mode_coeff)
#print(p._cu_probe.get().mean(axis=(-1,-2)), p._probe.mean(axis=(-1,-2)), p.nb_probe)
#p.probe_mode_coeff.dtype

In [None]:
if True:
    # Paganin reconstruction
    # Initial object estimate from the simulated intensities, using the
    # delta/beta ratio computed from xraydb in the parameters cell.
    p = BackPropagatePaganin(delta_beta=delta/beta, alpha=0.0) * p
    #p = BackPropagateCTF(alpha=0.2,alpha_low=1e-5, delta_beta=delta/beta) * p
    #plt.figure(figsize=(16,3))
    #p = ShowObj(fig_num=-1, istack=0, type='phase') * p
    plt.figure(figsize=(9,2))
    p = ShowObj(fig_num=-1, istack=0, type='phase') * p
    #plt.figure(figsize=(9,2))
    #p = ShowObj(fig_num=-1, istack=0, type='amplitude') * p
    
    # Probe amplitude at index [0, j]. Presumably the layout matches probe0,
    # i.e. (distance, mode, ny, nx); the second panel is shown only when
    # there is more than one probe mode.
    plt.figure(figsize=(9,2))
    plt.subplot(121)
    plt.imshow(abs(p.get_probe()[0,0]))
    #plt.imshow(abs(probe0[0,0]))
    if p.nb_probe > 1:
        plt.subplot(122)
        plt.imshow(abs(p.get_probe()[0,1]))
        #plt.imshow(abs(probe0[0,1]))

    #plt.figure(figsize=(9,2))
    #p = ShowObj(fig_num=-1, istack=0, type='phase') * p
#for s in p.data.stack_v:
#    print(abs(s.obj).mean(axis=(1,2,3)), s.iobs.sum(axis=(1,2,3)))
#print(abs(p.get_probe()).mean(axis=(1,2,3)))


In [None]:
# 100 cycles (`**100`) of the AP operator, updating both object and probe.
# calc_llk=20 presumably computes the log-likelihood every 20 cycles.
p = AP(delta_beta=delta/beta, reg_obj_smooth=0, update_object=True, update_probe=True, calc_llk=20,weight_empty=1, probe_inertia=0)**100 * p
#p = AP(delta_beta=None, reg_obj_smooth=0, update_object=True, update_probe=True)**50 * p
plt.figure(figsize=(12,3))
# Phase of the objects in stack 0 after refinement
p = ShowObj(fig_num=-1, istack=0, type='phase') * p
#print(p._cu_probe_mode_coeff)

In [None]:
if False:
    # Disabled: hand-assembled operator sequence kept for experimentation
    # (compute Psi, apply observed amplitudes, update object/probe, merge probe).
    #plt.figure(figsize=(15,4))
    p = ShowPsi(iproj=None,iz=None, fig_num=-1, type='phase') * PropagateApplyAmplitude1() * ObjProbe2Psi1()* SwapStack(0) * p

    #p = Psi2ObjProbeCoherent1(delta_beta=delta/beta) * PropagateApplyAmplitude1() * ObjProbe2Psi1() * p
    #p = OrthoProbe() * (LoopStack(Psi2ObjProbeCoherent1(delta_beta=delta/beta) * PropagateApplyAmplitude1() * ObjProbe2Psi1()))**10 * p
    #p = (OrthoProbe() * Psi2ProbeMergeCoh() * LoopStack(Psi2ObjProbeCoherent1() * PropagateApplyAmplitude1() * ObjProbe2Psi1()))**1 * p

    p = Psi2ProbeMerge() * LoopStack(Psi2ObjProbeCoherent1() * PropagateApplyAmplitude1() * ObjProbe2Psi1()) * p
    #p = Psi2ObjProbeCoherent1() * p
    plt.figure(figsize=(9,2))
    p = ShowObj(fig_num=-1, istack=0, type='phase') * p
if True:
    # Amplitude of the updated probe held in the private _cu_probe_new array
    # (presumably a GPU array; .get() copies it to host memory).
    plt.figure(figsize=(9, 2))
    plt.subplot(121)
    pr = p._cu_probe_new.get()  # alternative: p.get_probe()
    plt.imshow(abs(pr[0, 0]))
    # Use the mode count of the HoloTomo object itself, as the Paganin cell
    # does, rather than the global nb_probe. The global may no longer match p
    # if the parameters cell was re-run without recreating p.
    if p.nb_probe > 1:
        plt.subplot(122)
        plt.imshow(abs(pr[0, 1]))
print(p._cu_stack.obj.shape)

In [None]:
if True:
    # Amplitude of the wavefield after PropagateApplyAmplitude1, for
    # iproj=None and iz=None (presumably all projections and distances).
    plt.figure(figsize=(9.5,5))
    p = ShowPsi(fig_num=-1, iproj=None, iz=None, type='amplitude') * PropagateApplyAmplitude1() * ObjProbe2Psi1() * SwapStack(0) * p


In [None]:
#p._cu_probe_mode_coeff = p._cu_probe_mode_coeff_new
#p = ProbeNorm(option='probe') * p

In [None]:
# Sum of |probe|^2 over the last two (image) axes, i.e. the squared norm of
# each probe frame (layout presumably (distance, mode, ny, nx)).
np.sum(np.abs(p._cu_probe.get()) ** 2, axis=(-1, -2))

In [None]:
# Further refinement: 100 AP cycles with reg_obj_smooth=0.5 (presumably an
# object smoothness regularisation weight), then 200 cycles without it.
p = AP(delta_beta=delta/beta, reg_obj_smooth=0.5, update_object=True, update_probe=True, calc_llk=20,weight_empty=1, probe_inertia=0)**100 * p
p = AP(delta_beta=delta/beta, reg_obj_smooth=0, update_object=True, update_probe=True, calc_llk=20,weight_empty=1, probe_inertia=0)**200 * p
#p = AP(delta_beta=None, reg_obj_smooth=0, update_object=True, update_probe=True)**50 * p
plt.figure(figsize=(12,3))
# Phase of the objects in stack 0 after refinement
p = ShowObj(fig_num=-1, istack=0, type='phase') * p
#print(p._cu_probe_mode_coeff)