# FTSH simulation of spectrally dependent input

This notebook demonstrates the reconstruction algorithms used to spectrally separate and reconstruct fields from diffraction patterns simulated with wavelength-dependent input.
Because the algorithms work pixel-wise to match the camera pixels of a real experiment, the inputs (both reference and sample) are scaled according to their wavelength.

### Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from skimage import io
from skimage import transform

import regpy
from ftsh import spectroscopic_FTH
from regpy.solvers.nonlinear.newton import NewtonCG

from matplotlib.colors import LogNorm, SymLogNorm, CenteredNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.constants import speed_of_light as sol
from scipy.interpolate import interpn
from copy import deepcopy
from IPython import display

# One matplotlib colormap per harmonic (same order as `harms`), used by the plots below.
cmaps = ['Oranges', 'Greens','Blues','Purples']


### Define functions
* wavelengthscaling_uniformgrid: resamples the input image once per factor in `scales`, on the image axes of a regpy.vecsps.UniformGridFcts grid (the grid coordinates are multiplied by the factor). As the operator works pixel-wise, this function is used to scale the input images according to the wavelengths.
* gauss: 2D Gaussian, used to mimic the illumination of sample and reference by a Gaussian beam profile.
* interfholo: calculates $\sum_{\lambda} |\mathcal{F}(\mathrm{probe}_{\lambda} + \mathrm{ref}_{\lambda} \cdot \mathrm{e}^{-2\pi i f_{\lambda} \tau})|^2$, i.e. the delay-dependent diffraction patterns in the far field (the intensities of all wavelengths are summed).

In [None]:
def wavelengthscaling_uniformgrid(arrin, grid0, scales, imaxes = (-2,-1)):
    """Resample a 2D image on the grid ``grid0`` once per scale factor.

    For every factor in ``scales`` the image ``arrin`` (sampled on the two grid
    axes selected by ``imaxes``) is linearly interpolated at the grid
    coordinates multiplied by that factor. Sample points that fall outside the
    grid are set to 0.

    Parameters
    ----------
    arrin : 2D array sampled on ``(grid0.axes[imaxes[0]], grid0.axes[imaxes[1]])``.
    grid0 : grid exposing ``axes`` (1D coordinate arrays) and ``coords``
        (meshgrid-style coordinate arrays); the first slice along the leading
        axis of each selected ``coords`` entry is used.
    scales : iterable of scale factors.
    imaxes : the two grid axes that carry the image.

    Returns
    -------
    ndarray with one resampled image per factor, stacked along axis 0.
    """
    first_axis, second_axis = imaxes
    # interpn wants C-contiguous coordinate vectors.
    image_axes = (np.ascontiguousarray(grid0.axes[first_axis]),
                  np.ascontiguousarray(grid0.axes[second_axis]))
    base_first = grid0.coords[first_axis][0]
    base_second = grid0.coords[second_axis][0]

    resampled = []
    for factor in scales:
        sample_points = (base_first * factor, base_second * factor)
        resampled.append(interpn(image_axes, arrin, sample_points,
                                 bounds_error=False, fill_value=0))
    return np.array(resampled)

def gauss(xx, yy, center, width, amplitude=1):
    """Evaluate an axis-aligned 2D Gaussian.

    Returns ``amplitude * exp(-((xx-x0)**2/dx**2 + (yy-y0)**2/dy**2))`` with
    ``center = (x0, y0)`` and ``width = (dx, dy)``. There is no factor 1/2 in
    the exponent, so ``dx``/``dy`` are the 1/e half-widths of the profile.
    """
    (x0, y0), (dx, dy) = center, width
    radial_term = (xx - x0) ** 2 / dx ** 2 + (yy - y0) ** 2 / dy ** 2
    return amplitude * np.exp(-radial_term)

def interfholo(probe, ref,delays, frequencies,axes = (-2,-1)):
    """Simulate delay-dependent, spectrally summed far-field intensities.

    For every delay ``tau`` and frequency ``f`` the reference field gets the
    phase factor ``exp(-2*pi*i*f*tau)`` and is added to the probe field. The
    centred, orthonormal inverse FFT of that sum is taken over ``axes``, and
    the squared magnitudes are summed over all frequencies.

    Parameters
    ----------
    probe, ref : 3D arrays, one 2D field per frequency (frequency axis first).
    delays : 1D array of delays.
    frequencies : 1D array of frequencies, one per field in ``probe``/``ref``.
    axes : axes used for the (i)fftshift and ifftn calls. The phase-factor
        broadcasting below assumes the image occupies the last two axes.

    Returns
    -------
    Real array with one summed intensity pattern per delay (delay axis first).
    """
    # Phase factors with shape (n_delays, n_freq, 1, 1).
    phase = np.exp(
        (-1j * 2 * np.pi * frequencies)[np.newaxis, :, np.newaxis, np.newaxis]
        * delays[:, np.newaxis, np.newaxis, np.newaxis])
    fields = probe[np.newaxis, :, :, :] + ref[np.newaxis, :, :, :] * phase

    centred = np.fft.ifftshift(fields, axes=axes)
    transformed = np.fft.ifftn(centred, axes=axes, norm='ortho')
    farfield = np.fft.fftshift(transformed, axes=axes)

    # Incoherent sum over the frequency axis.
    intensity = np.abs(farfield) ** 2
    return np.abs(intensity.sum(axis=1))

### Parameters
In the following we set the parameters used for the simulation and reconstruction.
We need to set the delays, the wavelengths (calculated from the fundamental wavelength) and their ratios for the wavelength-dependent scaling, as well as the number of pixels (detectorsize), the region of interest in which the holograms appear, and an intensity used to calculate the noise level.

We also define functions to update the shown images live during the field reconstructions.

In [None]:
delaynum = 4 ## the number of used delay steps
detectorsize = (1024, 1024)
wl = 1030e-9 ## fundamental wavelength
harms = [15,17,19,21] ##harmonic index
intensity=4*1e7 #used to calculate noise
# Harmonic wavelengths wl/q (metres).
wls = wl/np.asarray(harms)
# Period and frequency of the fundamental.
# NOTE(review): T0 and f0 are not referenced elsewhere in this notebook.
T0 = wl/sol
f0 = sol/wl
# Harmonic frequencies (Hz).
freqs = sol*np.asarray(harms)/wl
# Wavelength of each harmonic relative to the first (longest) one; used to
# rescale the inputs per harmonic.
wlscaling_factor=wls/wls[0]

# Region of interest [[row_start, row_end], [col_start, col_end]] used to
# crop the holograms (rows first, then columns).
roi = np.array([[400,675],[655,1024]])
#%% delay construction
# 150 delays spanning one period of the fundamental (wl/sol).
delayfull = np.linspace(0,wl/sol,150)
# `delaynum` evenly spaced indices taken from the first half of `delayfull`.
used_inds=np.linspace(0, len(delayfull)/2,delaynum,endpoint=False).astype(int)
delay = delayfull[used_inds]
# Image axes over which the 2D FFTs are taken.
imax = (-2,-1)

In [None]:
## functions to dynamically update the plots while the high-res. reconstruction is running.
def updatelineplot(hdisplay,fig,ax,lines,ydatas):
    """Replace the data of existing line plots and push the figure to a display.

    Each line gets the matching entry of ``ydatas`` as y-data and
    ``0 .. len-1`` as x-data. The axes limits are then recomputed, the canvas
    is redrawn, and ``hdisplay.update(fig)`` re-renders the figure (for
    example through an IPython display handle).
    """
    for curve, values in zip(lines, ydatas):
        index = np.arange(len(values))
        curve.set_ydata(values)
        curve.set_xdata(index)
    # Recompute the data limits so that autoscaling sees the new values.
    ax.relim()
    ax.autoscale_view(True, True, True)
    fig.canvas.draw()
    hdisplay.update(fig)
def updateimages(disp, figi, axi, imis, dati, cbars, step):
    """Show the current per-wavelength reconstructions in existing image artists.

    For each panel the title is set from the notebook globals ``harms`` and
    ``wls`` plus ``step``. The image shows ``abs`` of the matching component
    of ``dati``, is autoscaled and redrawn, and its colorbar is refreshed.
    Finally ``disp.update(figi)`` re-renders the figure.
    """
    panels = zip(axi.flatten(), dati, wls[:], harms[:], imis, cbars)
    for axis, component, wavelength, order, artist, colorbar in panels:
        title = '$\\lambda_{{{}}}$ = {:1.1f}nm, step {}'.format(order, 1e9*wavelength, step)
        axis.set_title(title)
        artist.set_data(np.abs(component))
        artist.autoscale()
        artist.draw(figi._get_renderer())
        colorbar.update_normal(artist)
    disp.update(figi)
def updatedataimages(disp, figi, axs, ims, dati, cbars, step,refdat, useim=3):
    """Refresh the "reconstructed data" and "difference" panels of a data figure.

    Parameters
    ----------
    disp : display handle; ``disp.update(figi)`` re-renders the figure.
    figi : figure that holds the panels; its renderer is used for redrawing.
    axs : sequence of axes; ``axs[1]`` and ``axs[2]`` receive new titles.
    ims : pair of image artists (reconstructed data, difference).
    dati : current reconstructed data, indexable by ``useim``.
    cbars : pair of colorbars belonging to ``ims``.
    step : iteration number shown in the titles.
    refdat : reference (measured) data, indexable by ``useim``.
    useim : index of the component that is displayed.
    """
    recon_img, diff_img = ims[0], ims[1]

    recon_img.set_data(np.abs(dati[useim]))
    recon_img.autoscale()
    recon_img.draw(figi._get_renderer())
    axs[1].set_title('reco. data [{}] step {}'.format(useim,step))
    axs[2].set_title('Difference step {}'.format(step))

    diff_img.set_array(np.abs(refdat[useim]-dati[useim]))
    diff_img.autoscale()
    # Draw with the renderer of the figure passed in, not a global figure.
    diff_img.draw(figi._get_renderer())

    cbars[0].update_normal(recon_img)
    cbars[1].update_normal(diff_img)
    disp.update(figi)

### Define domain

regpy operators used as the forward model for the reconstructions require a domain and a codomain. The domain used here is a uniform grid (regpy.vecsps.UniformGridFcts) that has the shape (frequencies, pixel_x, pixel_y). The codomain is implicitly constructed when the operator is defined.
As the spectral separation uses only the region of interest defined in the parameters, we define multiple uniform grids (roi, and full detectorsize).

In [None]:
## define the domain
# Normalised coordinates in [-0.5, 0.5] used for the Gaussian beam profiles.
# NOTE: np.meshgrid defaults to 'xy' indexing, so yy2 varies along axis 1
# (columns) and xx2 along axis 0 (rows); keep this in mind when choosing the
# (dx, dy) widths passed to gauss.
yy2,xx2 = np.meshgrid(np.linspace(-0.5,0.5,detectorsize[0]), np.linspace(-0.5,0.5,detectorsize[1]))
nf = len(freqs)
# Pixel indices of the region of interest (rows, columns).
y_ax_sm = np.arange(roi[0,0], roi[0,1])
x_ax_sm = np.arange(roi[1,0], roi[1,1])
# Centred pixel axes of the full detector.
# NOTE(review): linspace includes both end points, so the spacing is
# N/(N-1), slightly larger than one pixel.
y_ax = np.linspace(-detectorsize[0]/2, detectorsize[0]/2, detectorsize[0])
x_ax = np.linspace(-detectorsize[1]/2, detectorsize[1]/2, detectorsize[1])
f_indices = np.arange(nf)
# Full-detector grid whose spectral axis is the index 0..nf-1.
detgrid = regpy.vecsps.UniformGridFcts(f_indices,y_ax, x_ax, dtype = complex)
# ROI grid used for the spectral separation.
detgrid_small = regpy.vecsps.UniformGridFcts(f_indices, y_ax_sm, x_ax_sm, dtype = complex)
# Full-detector grid whose spectral axis holds the actual frequencies.
detgrid3 = regpy.vecsps.UniformGridFcts(freqs,y_ax, x_ax, dtype = complex)

### Load images for reference and sample
The simulation uses a large reference hole and the CRC1465 logo, from which different squares have been removed, as wavelength-dependent input; both are illuminated with Gaussian beam profiles (flat phase).
The samples are shifted horizontally by 193 pixels from the center and scaled according to their wavelength. The reference is scaled as well.

In [None]:
# Stack of masks indexed as (frame, row, column); frames 1..4 are used as
# reference masks (presumably one per harmonic — confirm).
semmask1024 = np.load(r'imagefiles/semmask1024_11.npy')
# deepcopy so that zeroing below does not modify semmask1024 (slices are views).
refmask = deepcopy(semmask1024[1:5])
# Keep only the reference region: columns < 600 and rows >= 480.
refmask[:,:,600:] = 0
refmask[:,:480,:] = 0

# Sample images: the full logo and three variants with different squares removed.
new_sample = io.imread(r"imagefiles/SFB_Logo_centered.png") # use appropriate example image here
new_sample1 = io.imread(r"imagefiles/SFB_Logo_centeredmissing1.png") # use appropriate example image here
new_sample2 = io.imread(r"imagefiles/SFB_Logo_centeredmissing2.png") # use appropriate example image here
new_sample3 = io.imread(r"imagefiles/SFB_Logo_centeredmissing3.png") # use appropriate example image here
# Scale to [0, 1] and resize to the detector with nearest-neighbour
# interpolation (order=0). Assumes 8-bit, single-channel PNGs — confirm, the
# result is later multiplied by a 2D beam profile.
fullsample0 = transform.resize(new_sample/255,detectorsize,order=0) 
fullsample1 = transform.resize(new_sample1/255,detectorsize,order=0) 
fullsample2 = transform.resize(new_sample2/255,detectorsize,order=0) 
fullsample3 = transform.resize(new_sample3/255,detectorsize,order=0) 

In [None]:
# NOTE(review): xx2/yy2 span [-0.5, 0.5], but the probe widths here are
# int(1024/32)=32 and int(1024/64)=16, so this Gaussian stays above ~0.998
# over the whole detector, i.e. the sample illumination is effectively flat.
# The reference below uses widths of 0.025 in the same normalised units.
# Confirm whether these widths were meant in normalised units.
probe_beam_in = gauss(xx2, yy2, (0, 0), 
                       (int(detectorsize[0]/32), int(detectorsize[1]/64)))
ref_beam_in = gauss(xx2, yy2, (0, 0), 
                     (0.025, 0.025))


# Illuminate each sample variant with the probe beam.
probe_beam_full0 = (fullsample0) * probe_beam_in
probe_beam_full1 = (fullsample1) * probe_beam_in
probe_beam_full2 = (fullsample2) * probe_beam_in
probe_beam_full3 = (fullsample3) * probe_beam_in

# Shift the samples by 193 pixels along axis 1 (columns).
probe_beam_full0= np.roll(probe_beam_full0, (0,193), axis =(0,1))
probe_beam_full1= np.roll(probe_beam_full1, (0,193), axis =(0,1))
probe_beam_full2= np.roll(probe_beam_full2, (0,193), axis =(0,1))
probe_beam_full3= np.roll(probe_beam_full3, (0,193), axis =(0,1))

# Rescale every sample variant with every harmonic's wavelength factor.
probe_beams_full0 = wavelengthscaling_uniformgrid(probe_beam_full0, detgrid3, wlscaling_factor)
probe_beams_full1 = wavelengthscaling_uniformgrid(probe_beam_full1, detgrid3, wlscaling_factor)
probe_beams_full2 = wavelengthscaling_uniformgrid(probe_beam_full2, detgrid3, wlscaling_factor)
probe_beams_full3 = wavelengthscaling_uniformgrid(probe_beam_full3, detgrid3, wlscaling_factor)

# Harmonic k uses sample variant k (scaled for harmonic k) with its own
# relative amplitude.
probe_beams_full = np.array((1*probe_beams_full0[0],
                               0.5*probe_beams_full1[1],
                               0.8*probe_beams_full2[2],
                               1.2*probe_beams_full3[3]))

# Normalise the probe fields to a maximum of 2.
probe_beams_full = 2*probe_beams_full/probe_beams_full.max()

# Reference: Gaussian rescaled per harmonic, cut out by the masks and
# normalised to a maximum of 0.5. Using detgrid instead of detgrid3 makes no
# difference here, since wavelengthscaling_uniformgrid only reads the two
# image axes of the grid.
ref_beams_in = wavelengthscaling_uniformgrid(ref_beam_in, detgrid, wlscaling_factor) 
ref_beams_full = 0.5 * (refmask * ref_beams_in)/(refmask * ref_beams_in).max()


### Calculate diffraction patterns for delays in delay

Diffraction patterns are calculated for the 4 delay steps. Afterwards, Poissonian noise is applied according to the intensity given in the parameters.

In [None]:
# Noise-free diffraction patterns, one per delay (delay axis first).
simdata = interfholo(probe_beams_full, ref_beams_full, delay, freqs)
# Poisson noise: rescale so that the total equals `intensity` counts, draw
# counts, then rescale back to the original total. The RNG is unseeded, so
# the noise differs between runs.
data = np.sum(simdata)/intensity*np.random.poisson(intensity * simdata/np.sum(simdata)) ## with noise
# data = simdata ## without noise

### Create mask and operator for spectral separation

A mask spanning the full region of interest is calculated and the operator for spectral separation is defined. The mask operator does not change anything in this case; it shows how masking would be applied in the $N_{\tau} < N_{\lambda}$ case.

In [None]:
## create mask for embedding operator (no masking, just roi)
# All-True mask on the full probe shape; used later for the embedding of the
# high-resolution reconstruction.
maskprobe = np.ones(probe_beams_full.shape, bool)
# One mask entry per spectral component and ROI pixel. The leading axis must
# match the nf spectral components of detgrid_small, so it is derived from nf
# instead of being hard-coded.
semprobe_small=np.ones((nf,roi[0,1]-roi[0,0],roi[1,1]-roi[1,0]))
masking_operator= regpy.operators.CoordinateMask(detgrid_small, semprobe_small)
# NOTE(review): the delays are passed with a negative sign, presumably to
# match the phase convention of spectroscopic_FTH versus interfholo — confirm.
fts_forward1 = spectroscopic_FTH(detgrid_small, delays=-delay, frequencies=freqs, masking_operator=masking_operator)

The following cell calculates multi-wavelength holograms from the simulated data, sets the initial guess and defines the domain and codomain norms. The solver and the stop rule (here based only on the number of iterations) are also set.

In [None]:
# Centred orthonormal FFT of the (noisy) diffraction patterns: the multi-wavelength holograms.
datsfull = np.fft.fftshift(np.fft.fftn(np.fft.ifftshift(data, axes=imax), axes=imax, norm='ortho'), axes=imax)

# Crop the holograms to the region of interest.
dats = datsfull[:,roi[0,0]:roi[0,1],roi[1,0]:roi[1,1]]
# Tiny non-zero start value (presumably to avoid an exactly-zero initial guess — confirm).
initguess = fts_forward1.domain.ones()*1e-16
h_codomain = regpy.hilbert.L2(fts_forward1.codomain)
h_domain = regpy.hilbert.L2(fts_forward1.domain)
setting = regpy.solvers.RegularizationSetting(op=fts_forward1, penalty=h_domain, data_fid=h_codomain)
# Join the per-delay ROI holograms into one element of the operator's codomain.
datsi = fts_forward1.codomain.join(*tuple(dats.astype(np.complex128)))

solver = NewtonCG(setting, data=datsi, init=initguess)
# Error history list for the spectral-separation solver loop.
errreco = []
# Stop after a fixed number of Newton iterations.
stoprule = (
    regpy.stoprules.CountIterations(max_iterations=2)
    #  +
    # rules.RelativeChangeData(
    #     setting.h_codomain.norm,
    #     dats,
    #     1e5
    )

The next cell runs the solver that spectrally separates the multi-wavelength holograms. Afterwards, the reconstructions (the spectrally separated holograms) are padded back to the detector size and Fourier transformed to obtain the input for the reconstruction of high-quality images.

In [None]:
norm = np.linalg.norm(datsi)
previous_reco = 0
for reco, reco_data in solver.until(stoprule):
    newton_step = solver.iteration_step_nr
    # Append [relative change of the reconstruction, relative data residual]
    # to the history list `errreco` created together with the solver. The
    # list must not be re-created inside the loop, otherwise only the last
    # step would be kept.
    errreco.append([
        np.linalg.norm(reco-previous_reco)*delaynum/nf/norm,
        np.linalg.norm(reco_data-datsi)/norm])
    previous_reco = reco.copy()
    # Split the reconstructed data into one ROI hologram per delay so that it
    # can be compared with `dats`.
    reco_data_s = np.asarray(fts_forward1.codomain.split(reco_data))
    print("Newton Step = {}".format(newton_step))
    print('abs. reconstruction errors step {}: {:1.4f}'.format(
                newton_step,
                np.linalg.norm(reco_data_s-dats)/np.linalg.norm(dats)))

    # Plot the summed |hologram| and the current spectral components, each
    # normalised to its own maximum.
    fig= plt.figure(figsize=(4.7,6),layout='constrained')
    subfigs = fig.subfigures(2,1, height_ratios=[1,2])
    axs1 = subfigs[0].subplots(1,1, sharex=True, sharey=True)
    subfigs[0].suptitle('Hologram')
    subfigs[0].supxlabel('$p_{i,x}$')
    subfigs[0].supylabel('$p_{i,y}$')
    axs1.imshow(np.sum(np.abs(dats),axis=0)/np.sum(np.abs(dats),axis=0).max())
    axs = subfigs[1].subplots(2,2, sharex=True, sharey=True)
    subfigs[1].suptitle('Spectral components (abs)')

    for (ax, recocom, wavel) in zip(axs.flatten(),reco, wls):
        ax.set_title('$\\lambda$ = {:1.1f}nm'.format(1e9*wavel))
        im = ax.imshow(np.abs(recocom)/np.abs(recocom).max())

In [None]:
# Embed the ROI reconstruction into the full detector (zeros outside the ROI).
reco_pad = np.pad(reco,((0,0),(roi[0][0], detectorsize[0]-roi[0][1]),(roi[1][0], detectorsize[1]-roi[1][1])) )
# Centred orthonormal FFT of the padded components: the data for the high-resolution reconstruction.
dats2 = np.fft.fftshift(np.fft.fftn(np.fft.ifftshift(reco_pad, axes=imax), axes=imax, norm='ortho'), axes=imax)
# Centred orthonormal FFT of the reference fields, used by the forward operator below.
referftshifted = np.fft.fftshift(
                np.fft.fftn(
                    np.fft.ifftshift(
                        ref_beams_full,
                        axes=imax
                        ),
                    axes=imax,
                    norm='ortho'
                    ),
                axes=imax)
# Mask on the full detector grid built from the all-True `maskprobe`; its
# adjoint is used as the embedding (padding) operator.
proj = regpy.operators.CoordinateMask(detgrid, maskprobe)
embedding = proj.adjoint

### Define forward operator used  to reconstruct high-quality images
In the following, the high-resolution reconstruction is prepared. First, the forward model is defined and the operator is constructed. Then the domain and codomain of the operator are equipped with the L2 norm.

In [None]:
def deconvrecoreffixed(domain, reference:np.ndarray, freqaxis = 0, imageaxes=(-2,-1),
                       withifft = False,
                       paddingop = None,
                       ):
    """Build the forward operator for the high-resolution reconstruction.

    The operator is ``[IFFT] * M(conj(reference)) * FFT * paddingop``. The
    input is first mapped by ``paddingop`` (e.g. an embedding) and then
    Fourier transformed (centred) over ``imageaxes``. The result is multiplied
    pointwise by ``conj(reference)`` and, if ``withifft`` is true, transformed
    back with the adjoint Fourier transform over the same axes.

    Parameters
    ----------
    domain : regpy vector space; must be complex. It is only validated here,
        the operator's domain is ``paddingop.domain``.
    reference : array multiplied (conjugated) onto the Fourier-transformed
        field; presumably its shape must match the FFT codomain — confirm.
    freqaxis : axis of the spectral dimension; must not be one of ``imageaxes``.
    imageaxes : axes over which the Fourier transforms act.
    withifft : if true, append the inverse (adjoint) Fourier transform.
    paddingop : operator applied first; required.

    Returns
    -------
    The composed regpy operator.

    Raises
    ------
    ValueError
        If ``domain`` is not complex, ``freqaxis`` is one of ``imageaxes``, or
        ``paddingop`` is None.
    """
    # Explicit checks instead of asserts, which are stripped under `python -O`.
    if not domain.is_complex:
        raise ValueError('deconvrecoreffixed requires a complex domain')
    if np.any(np.isin(imageaxes, freqaxis)):
        raise ValueError('freqaxis must not be one of imageaxes')
    if paddingop is None:
        raise ValueError('paddingop is required (e.g. the adjoint of a CoordinateMask)')

    prop = regpy.operators.FourierTransform(paddingop.codomain, centered=True, axes=imageaxes)
    combine_ref_probe = regpy.operators.PtwMultiplication(prop.codomain, reference.conj())
    if withifft:
        # Transform back over the same axes as the forward transform.
        ifftop = regpy.operators.FourierTransform(
            combine_ref_probe.codomain, centered=True, axes=imageaxes).adjoint
        return ifftop * combine_ref_probe * prop * paddingop
    return combine_ref_probe * prop * paddingop

In [None]:
# High-resolution forward operator: embedding -> FFT -> multiplication with the
# conjugated reference spectrum (no inverse FFT). Its output is compared with
# `dats2` by the solver below.
recoop = deconvrecoreffixed(detgrid3, referftshifted,
                paddingop= embedding,
                withifft = False
                )

In [None]:
# Choose L2 Hilbert spaces for the codomain and domain of `recoop`.
# When a space is not a plain UniformGridFcts, one L2 space is created per
# component and they are combined with `+` (presumably regpy's direct sum of
# Hilbert spaces — confirm).
# NOTE(review): the codomain branch is selected by len(delay) > 1 rather than
# by the structure of recoop.codomain (the domain branch uses .ndim), and the
# `type(...) is not` checks do not accept subclasses of UniformGridFcts.
if type(recoop.codomain) is not regpy.vecsps.UniformGridFcts:

    if len(delay)>1:
        h_codomain2 = regpy.hilbert.L2(recoop.codomain[0])
    else: 
        h_codomain2 = regpy.hilbert.L2(recoop.codomain)
    # Add the L2 spaces of the remaining components.
    for ind, grid in enumerate(recoop.codomain):
        if ind > 0:
            h_codomain2 = h_codomain2 + regpy.hilbert.L2(grid)
else:
    h_codomain2 = regpy.hilbert.L2(recoop.codomain)

if type(recoop.domain) is not regpy.vecsps.UniformGridFcts:
    if recoop.domain.ndim > 1:
        h_domain2 = regpy.hilbert.L2(recoop.domain[0])
        # Add the L2 spaces of the remaining components.
        for ind, grid in enumerate(recoop.domain):
            if ind > 0:
                h_domain2 = h_domain2 + regpy.hilbert.L2(grid)
    else:
        h_domain2 = regpy.hilbert.L2(recoop.domain)
else:
        h_domain2 = regpy.hilbert.L2(recoop.domain)


In [None]:
# Start the high-resolution reconstruction from zero.
initguess2 = recoop.domain.zeros()
# Initial guess embedded into the full detector grid (used for the initial error and plots).
initguess_split = embedding(initguess2)
setting2 = regpy.solvers.RegularizationSetting(op=recoop, penalty=h_domain2, data_fid=h_codomain2)

The solver for the high-resolution reconstruction and a stoprule need to be defined. The stoprule in this example is again only determined by the number of iterations.
Then the solver is run and the current reconstruction and the difference between exemplary input data and reconstructed data is shown.

In [None]:
#%% solver params
NewtonCG_cgmaxit = 100 #max CG iterations per Newton step
# `rho` parameter forwarded to regpy's NewtonCG (see its documentation).
NewtonCG_rho = 0.85 
max_its = 20 #max number of newton steps


solver2 = NewtonCG(
            setting2, dats2, init=initguess2,
            cgmaxit=NewtonCG_cgmaxit, rho=NewtonCG_rho,
            )
# Error history of the high-resolution solver, filled in the next cell.
errreco2 = []
# Stop after a fixed number of Newton steps.
stoprule2 = (
    regpy.stoprules.CountIterations(max_iterations=max_its)
)

In [None]:
#%%
#%matplotlib widget

plt.ioff()

# Error history of the high-resolution solver. Column 0: relative residual of
# the reconstructed data w.r.t. dats2; column 1: relative error of the
# embedded reconstruction w.r.t. the ground truth probe_beams_full.
# The first row describes the zero initial guess (data residual set to 1).
errreco2.append([1, #np.linalg.norm(dats2-initguess_split)/np.linalg.norm(dats2),
                 np.linalg.norm(probe_beams_full-initguess_split)/np.linalg.norm(probe_beams_full)]
        )
## error plot
figu, ax1 = plt.subplots(1,1, figsize=(4,3),layout='constrained')
figuid = display.display("",display_id=True)
ax1.set_ylabel('relative error')
ax1.set_xlabel('Iteration step')
ax1.set_yscale('log')
# Labels follow the column order of errreco2 described above.
errline1, = ax1.plot(np.asarray(errreco2)[:,0], label = 'Data residual')
errline2, = ax1.plot(np.asarray(errreco2)[:,1], label = 'Reconstruction')
ax1.legend()
## curr. reconstruction plot, one panel per harmonic
fix, axs = plt.subplots(2,2, sharex=True, sharey=True, figsize=(6,4),layout='constrained',gridspec_kw={'wspace':0.1})
fixid = display.display("",display_id=True)
fix.get_layout_engine().set(rect=(0, 0, 0.95, 1))
fix.suptitle('Reconstruction (abs)')
fix.supxlabel('$p_{i,x}$')
fix.supylabel('$p_{i,y}$')
axs[0][0].set_title('Wavelength = {:1.1f}nm'.format(1e9*wls[0]))
# Zoom into a sub-region of the ROI (limits are shared by all panels).
axs[0][0].set_xlim(roi[1][0],roi[1][1]-150)
axs[0][0].set_ylim(roi[0][1]-100,roi[0][0]+50)
ims = []
cbs = []
for (ax, recocom, wavel, harm,cmi) in zip(axs.flatten(),initguess_split, wls[:], harms [:],cmaps):
    ax.set_title('$\\lambda_{{{}}}$ = {:1.1f}nm'.format(harm,1e9*wavel))
    imi = ax.imshow(np.abs(recocom), cmap=cmi)
    ims.append(imi)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%",pad=0.1)
    cbar = fix.colorbar(imi, cax=cax)
    cbs.append(cbar)

## data vs. curr. reconstructed data (component 3 is shown)
fig2, axs2 = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(8, 4),layout='constrained')
fig2id = display.display("",display_id=True)
im = axs2[0].imshow(np.abs(dats2[3]), interpolation='none',norm=LogNorm())
fig2.colorbar(im, ax=axs2[0], location = 'bottom')
imdats = axs2[1].imshow(np.zeros(dats2[3].shape)+1e-16, interpolation='none',norm=LogNorm())
axs2[1].set_title('reco. data [3] step {}'.format(0))
diffim = axs2[2].imshow(np.abs(dats2[3])-np.zeros(dats2[3].shape), interpolation='none',norm=LogNorm())
axs2[2].set_title('Difference step {}'.format(0))
cbar6 = fig2.colorbar(imdats, ax=axs2[1], location = 'bottom')
cbar5 = fig2.colorbar(diffim, ax=axs2[2], location = 'bottom')
axs2[0].set_title('Data [3]')

plot_every = 1  # refresh the live plots every `plot_every` Newton steps
print('solver start')
for reco2, reco_data2 in solver2.until(stoprule2):
    newton_step = solver2.iteration_step_nr
    # Current reconstruction embedded into the full detector grid.
    ereco = embedding(reco2)
    data_residual = np.linalg.norm(dats2-reco_data2)/np.linalg.norm(dats2)
    errreco2.append([data_residual,
                    np.linalg.norm(probe_beams_full-ereco)/np.linalg.norm(probe_beams_full)]
        )
    print('abs. reconstruction errors step {}: {:1.4f}'.format(
        newton_step,
        data_residual))
    # Plot results
    if newton_step % plot_every == 0 or stoprule2.triggered:
        errs = np.asarray(errreco2)
        updatelineplot(figuid,figu, ax1,(errline1,errline2), (errs[:,0],errs[:,1]))
        updateimages(fixid,fix, axs, ims, ereco, cbs, newton_step)
        reco_data_comp = reco_data2
        updatedataimages(fig2id,fig2, axs2, (imdats,diffim), reco_data_comp, (cbar6, cbar5), newton_step,dats2)

In [None]:
#%%
plt.ioff()
# Summary figure. Left sub-figure: |reference + probe| input and the final
# reconstruction per harmonic. Right sub-figure: simulated holograms per delay.
fig = plt.figure(layout='constrained', 
                 figsize = (1.5*4.3,1.1*4.3)
                 )
sfig,sfig2 = fig.subfigures(1,2, width_ratios=(1,0.414),
                 wspace=0.0)
axsl = sfig.subplots(4,2, sharex='col', sharey=True, 
                          #figsize=(1*optica_colwidth,optica_colwidth*0.98),
                           width_ratios=[1,0.5],#layout='constrained', 
                           gridspec_kw={'wspace':0.0,
                                        'hspace':0.0
                          }
                          )
sfig.text(0.68,0.96,'Reconstruction', fontsize ='large')
sfig.suptitle('Input', fontsize ='large',ha='right')
fig.supxlabel('$\\mathrm{x\'}_\\mathrm{x}$ [pixel]',fontsize='medium')
fig.supylabel('$\\mathrm{x\'}_\\mathrm{y}$ [pixel]',fontsize='medium')
# Zoom windows (x is shared per column, y by all panels).
axsl[0][0].set_xlim(475,int(roi[1][1]-350/2))
axsl[0][1].set_xlim(int(roi[1][0]),int(roi[1][1]-350/2))
axsl[0][1].set_ylim(int(roi[0][1]-100),int(roi[0][0]+50))

# Left column: |reference + probe| input field per harmonic.
for (axi, recocom, wavel, harm, cm) in zip(axsl[:,0].flatten(),(ref_beams_full+probe_beams_full), wls, harms,cmaps):
    axi.set_title('$\\lambda_{{{}}}$ = {:1.1f}nm'.format(harm,1e9*wavel), 
                  fontsize = 'medium')

    im = axi.imshow(np.abs(recocom), cmap=cm, interpolation='None')
    cax = axi.inset_axes([1.015,0,0.05,1])
    cbar = sfig.colorbar(im, cax=cax)
    
    cbar.ax.tick_params(axis='y', direction='out',which='both')
    cbar.ax.yaxis.set_offset_position('left')
# Right column: |reconstruction| per harmonic from the last solver step (ereco).
for (axo, recocom, wavel, harm, cm) in zip(axsl[:,1].flatten(),ereco, wls, harms,cmaps):
    im = axo.imshow(np.abs(recocom), cmap=cm, interpolation='None')
    cax = axo.inset_axes([1.015,0,0.05,1])
    cbar = sfig.colorbar(im, cax=cax)
    
    cbar.ax.tick_params(axis='y', direction='out',which='both')
    cbar.ax.yaxis.set_offset_position('left')

axsr = sfig2.subplots(4,1, sharex=True, sharey=True, 
                          gridspec_kw={#'wspace':0.14, 
                                         'hspace':0
                          }
                          )
sfig2.suptitle('Sim. Hologram', fontsize ='large'
        , ha='center', 
        # va='center_baseline'
         )
axsr[0].set_xlim(int(roi[1][0]),int(roi[1][1]-350/2))
axsr[0].set_ylim(int(roi[0][1]-100),int(roi[0][0]+50))

# Full hologram per delay; the colour range is taken from the ROI crop (dats).
for kk,(ax, recocom,datsio, dela) in enumerate(zip(axsr.flatten(),datsfull, dats, delay)):
    ax.set_title('$\\tau_{{{}}}$ = {:1.1f}fs'.format(kk,1e15*dela),
                 fontsize='medium')
    vma=np.abs(datsio).max()
    vmi = np.abs(datsio).min()
    im = ax.imshow(np.abs(recocom), cmap='Greys',vmax=vma,vmin =vmi, interpolation='None')
    cax = ax.inset_axes([1.015,0,0.05,1])
    
    cbar = sfig2.colorbar(im, cax=cax)
    
    cbar.ax.tick_params(axis='y', direction='out',which='both')

plt.show(block=False)


In [None]:
# Comparison per harmonic: ground-truth probe (column 0), final
# reconstruction (column 1) and |difference| (column 2). zip truncates to
# the 3 rows of the grid.
fig, axs = plt.subplots(3,3, figsize = (7,4.5),layout = 'constrained', sharey= True, sharex = True)
axs[0][0].set_xlim(int(roi[1][0]),int(roi[1][1]-350/2))
axs[0][0].set_ylim(int(roi[0][1]-100),int(roi[0][0]+50))
fig.supxlabel('$\\mathrm{x\'}_\\mathrm{x}$ [pixel]',fontsize='medium')
fig.supylabel('$\\mathrm{x\'}_\\mathrm{y}$ [pixel]',fontsize='medium')
for (axi, recocom, wavel, harm, cm) in zip(axs[:,0].flatten(),(probe_beams_full), wls, harms,cmaps):
    axi.set_title('$\\lambda_{{{}}}$ = {:1.1f}nm'.format(harm,1e9*wavel), 
                  fontsize = 'medium')

    im = axi.imshow(np.abs(recocom), cmap=cm, interpolation='None')
    cax = axi.inset_axes([1.015,0,0.05,1])
    cbar = fig.colorbar(im, cax=cax)
    
    cbar.ax.tick_params(axis='y', direction='out',which='both')
    cbar.ax.yaxis.set_offset_position('left')
for (axo, recocom, wavel, harm, cm) in zip(axs[:,1].flatten(),ereco, wls, harms,cmaps):
    im = axo.imshow(np.abs(recocom), cmap=cm, interpolation='None')
    cax = axo.inset_axes([1.015,0,0.05,1])
    cbar = fig.colorbar(im, cax=cax)
    
    cbar.ax.tick_params(axis='y', direction='out',which='both')
    cbar.ax.yaxis.set_offset_position('left')
for (axo, recocom, wavel, harm, cm) in zip(axs[:,2].flatten(),probe_beams_full-ereco, wls, harms,cmaps):
    im = axo.imshow(np.abs(recocom), cmap=cm, interpolation='None')
    cax = axo.inset_axes([1.015,0,0.05,1])
    cbar = fig.colorbar(im, cax=cax)
    
    cbar.ax.tick_params(axis='y', direction='out',which='both')
    cbar.ax.yaxis.set_offset_position('left')
plt.show()