In [2]:
%matplotlib widget
from smpr3d.core import Sparse4DData, Metadata4D
from numpy.fft import fftshift, fft2
import matplotlib.pyplot as plt 
import numpy as np
from matplotlib_scalebar.scalebar import ScaleBar
from pathlib import Path
from smpr3d.util import *
from mpl_toolkits.axes_grid1 import make_axes_locatable
from tifffile import imsave, imwrite

scan_number = 147

base_path = Path('/home/philipp/nvme/insync_berkeley/2020-09-22/')
adfpath = base_path / 'adf'

sparse_path = base_path / '4d/'

results_path = base_path / 'results/'
results_path_scan = results_path / f'scan{scan_number}'
results_path_vadf = results_path / f'vadf'
results_path_adf = results_path / f'adf'
results_path_abf = results_path / f'abf'
results_path_bf = results_path / f'bf'
results_path_dpc = results_path / f'dpc'
results_path_ptychography = results_path / f'dpc'

if not results_path.exists(): results_path.mkdir()
if not results_path_scan.exists(): results_path_scan.mkdir()
if not results_path_vadf.exists(): results_path_vadf.mkdir()
if not results_path_adf.exists(): results_path_adf.mkdir()
if not results_path_abf.exists(): results_path_abf.mkdir()
if not results_path_bf.exists(): results_path_bf.mkdir()
if not results_path_dpc.exists(): results_path_dpc.mkdir()
if not results_path_ptychography.exists(): results_path_ptychography.mkdir()
    
filename4d = sparse_path / f'data_scan{scan_number}_th4.0_electrons.h5'
filenameadf = adfpath / f'scan{scan_number}.dm4'

In [3]:
da = []
mda = []

results_path_scan = results_path / f'scan{scan_number}'
if not results_path_scan.exists():
    results_path_scan.mkdir()
filename4d = sparse_path / f'data_scan{scan_number}_th4.0_electrons.h5'
# filename4d = sparse_path / f'data_scan{scan_number}_th4.5_electrons.h5'
# filename4d = sparse_path / f'data_scan{scan_number}_th4_electrons.h5'
filenameadf = adfpath / f'scan{scan_number}.dm4'
# filenameadf = adfpath / f'scan{scan_number}_ADF.dm4'
d = Sparse4DData.from_4Dcamera_file(filename4d)
print('1')
metadata = Metadata4D.from_dm4_file(filenameadf)
metadata.alpha_rad = 25e-3
metadata.rotation_deg = 0
metadata.wavelength =  wavelength(metadata.E_ev)  
mda.append(metadata)

center, radius = d.determine_center_and_radius(manual=False, size=200) 
# center = [292.0,169.0]
print(center,radius)
# center[0] -= 2
# center[1] -= 3
print('2')
d.crop_symmetric_center_(center)
print('3')
s = d.sum_diffraction()
print('4')

da.append(d)

f,ax = plt.subplots(figsize=(4,4))
imax = ax.imshow(s)
ax.set_title(f'Scan {scan_number} sum after cropping')
plt.colorbar(imax)
plt.tight_layout()
f.savefig(results_path_scan /f'sum_crop.png')

1
[292.20951103 160.73790153] 123.21053157858441
2
old frames shape: (512, 512, 2572)
new frames shape: (512, 512, 2572)
old frames frame_dimensions: [576 576]
new frames frame_dimensions: [320 320]
3
4


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [4]:
abf = d.virtual_annular_image(radius/2, radius, d.frame_dimensions/2)
bf = d.virtual_annular_image(0, radius/2, d.frame_dimensions/2)
eabf = abf - bf
adf = d.virtual_annular_image(radius, d.frame_dimensions[0]/2, d.frame_dimensions/2)

bf[bf==0] = bf.mean()
abf[abf==0] = abf.mean()

fig, ax = plt.subplots(dpi=150)
im = ax.imshow(abf, cmap= plt.cm.get_cmap('bone'))
ax.set_title(f'Scan {scan_number} ABF')
ax.set_xticks([])
ax.set_yticks([])
ax.add_artist(ScaleBar(metadata.dr[0]/10,'nm'))
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
plt.colorbar(im, cax=cax)
plt.tight_layout()
fig.savefig(results_path_abf /f'scan{scan_number}_abf.png')
fig.savefig(results_path_scan /f'scan{scan_number}_abf.png')
fig.savefig(results_path_adf /f'scan{scan_number}_abf.pdf')
fig, ax = plt.subplots(dpi=150)
im = ax.imshow(bf, cmap= plt.cm.get_cmap('bone'))
ax.set_title(f'Scan {scan_number} BF')
ax.set_xticks([])
ax.set_yticks([])
ax.add_artist(ScaleBar(metadata.dr[0]/10,'nm'))
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
plt.colorbar(im, cax=cax)
plt.tight_layout()
fig.savefig(results_path_bf /f'scan{scan_number}_bf.png')
fig.savefig(results_path_scan /f'scan{scan_number}_bf.png')
fig.savefig(results_path_adf /f'scan{scan_number}_bf.pdf')
fig, ax = plt.subplots(dpi=150)
im = ax.imshow(adf, cmap= plt.cm.get_cmap('bone'))
ax.set_title(f'Scan {scan_number} ADF')
ax.set_xticks([])
ax.set_yticks([])
ax.add_artist(ScaleBar(metadata.dr[0]/10,'nm'))
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
plt.colorbar(im, cax=cax)
plt.tight_layout()

# imwrite(results_path_scan /f'scan{scan_number}_abf.tif', abf.astype('float32'), imagej=True, resolution=(1./(metadata.dr[0]/10), 1./(metadata.dr[1]/10)),
#             metadata={'spacing': 1 / 10, 'unit': 'nm', 'axes': 'ZYX'})
# imwrite(results_path_scan /f'scan{scan_number}_bf.tif', bf.astype('float32'), imagej=True, resolution=(1./(metadata.dr[0]/10), 1./(metadata.dr[1]/10)),
#             metadata={'spacing': 1 / 10, 'unit': 'nm', 'axes': 'ZYX'})
# imwrite(results_path_scan /f'scan{scan_number}_eabf.tif', eabf.astype('float32'), imagej=True, resolution=(1./(metadata.dr[0]/10), 1./(metadata.dr[1]/10)),
#             metadata={'spacing': 1 / 10, 'unit': 'nm', 'axes': 'ZYX'})
# imwrite(results_path_scan /f'scan{scan_number}_adf', adf.astype('float32'), imagej=True, resolution=(1./(metadata.dr[0]/10), 1./(metadata.dr[1]/10)),
#             metadata={'spacing': 1 / 10, 'unit': 'nm', 'axes': 'ZYX'})

fig.savefig(results_path_adf /f'scan{scan_number}_adf.png')
fig.savefig(results_path_scan /f'scan{scan_number}_adf.png')
fig.savefig(results_path_adf /f'scan{scan_number}_adf.pdf')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [5]:
dwell_time = 1/87e3
detector_to_real_fluence_80kv = 1#1/0.56

fluence = d.fluence(metadata.dr[0]) * detector_to_real_fluence_80kv
flux = d.flux(metadata.dr[0], dwell_time) * detector_to_real_fluence_80kv


area = np.pi * radius **2

# print(f"E               = {metadata.E_ev/1e3}             keV")
# print(f"λ               = {metadata.wavelength * 1e2:2.2}   pm")
print(f'radius: {radius}')
print(f'[min,max] electrons to avoid coincidence losses: [{area//100},{area//40}]')
print(f"dR              = {metadata.dr} Å")
print(f"scan       size = {d.scan_dimensions}")
print(f"detector   size = {d.frame_dimensions}")
print(f"scan       FOV  = {d.scan_dimensions*metadata.dr/10} nm")
print(f"fluence         ~ {fluence} e/Å^2")
print(f"flux            ~ {flux} e/Å^2/s")

radius: 123.21053157858441
[min,max] electrons to avoid coincidence losses: [476.0,1192.0]
dR              = [0.31626087 0.31626087] Å
scan       size = [512 512]
detector   size = [320 320]
scan       FOV  = [16.19255676 16.19255676] nm
fluence         ~ 22120.367672679142 e/Å^2
flux            ~ 7341.278028576223 e/Å^2/s


In [6]:
alpha_max_factor = 1.2
alpha_max_factor = 1.05
dssb = d.crop_symmetric_center(d.frame_dimensions/2, radius*alpha_max_factor)
metadata.k_max = metadata.alpha_rad * alpha_max_factor / metadata.wavelength

s = dssb.sum_diffraction()

f,ax = plt.subplots(figsize=(4,4))
imax = ax.imshow(s)
plt.colorbar(imax)

old frames shape: (512, 512, 2572)
new frames shape: (512, 512, 2572)
old frames frame_dimensions: [320 320]
new frames frame_dimensions: [258 258]


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.colorbar.Colorbar at 0x7f9e9c3c09d0>

In [7]:
metadata.dr, dssb.scan_dimensions

(array([0.31626087, 0.31626087], dtype=float32), array([512, 512]))

In [11]:
import time
from smpr3d.util import get_qx_qy_1D, ZernikeProbeSingle, fourier_coordinates_2D, imsave
from ipywidgets import AppLayout, FloatSlider, GridspecLayout, VBox, Tab, Box, HBox
import ipywidgets as widgets
from matplotlib_scalebar.scalebar import ScaleBar
from numpy.fft import fftshift
from math import sin, cos
import time
import cupy as cp
import matplotlib.pyplot as plt
from cupy.cuda import Device
import cupyx.scipy.fft as fft
from IPython.display import display
from skimage.filters import gaussian
plt.ioff()

out0 = widgets.Output(layout={'border': '1px solid black'})

class InteractiveSSB:
    def __init__(self, sparse_data, radius, meta):
        self.out = widgets.Output(layout={'border': '1px solid black'})
        self.meta = meta
        self.data = sparse_data
        self.ssb_size = 32
        self.slic = np.s_[:,:]
        self.rmax = sparse_data.frame_dimensions[-1] // 2
        self.alpha_max = self.rmax / radius * meta.alpha_rad

        r_min = meta.wavelength / (2 * self.alpha_max)
        r_min = np.array([r_min, r_min])
        self.k_max = [self.alpha_max / meta.wavelength, self.alpha_max / meta.wavelength]
        self.r_min = np.array(r_min)
        self.dxy = np.array(meta.dr)
        
        self.margin = sparse_data.frame_dimensions // 2
        
        self.Qx1d, self.Qy1d = get_qx_qy_1D(sparse_data.scan_dimensions, self.dxy, cp.float32, fft_shifted=False)
        
        self.out.append_stdout(f"alpha_max       = {self.alpha_max * 1e3:2.2f} mrad\n")
        self.out.append_stdout(f"E               = {meta.E_ev/1e3} keV\n")
        self.out.append_stdout(f"λ               = {meta.wavelength * 1e2:2.2}   pm\n")
        self.out.append_stdout(f"dR              = {self.dxy[0]:2.2f}             Å\n")
        self.out.append_stdout(f"dK              = {self.k_max[0]:2.2f}          Å^-1\n")
        self.out.append_stdout(f"scan       size = {sparse_data.scan_dimensions}\n")
        self.out.append_stdout(f"detector   size = {sparse_data.frame_dimensions}\n")
        
        self.rotation_slider = widgets.FloatSlider(
            value=0,
            min=-180,
            max=180,
            step=0.1,
            description='STEM rotation angle:',
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='.1f',
        )
        self.rotation_slider.observe(self.rotation_changed, 'value')
        self.rotation_deg = 0.0
        
        self.aberration_text = widgets.HTML(
                value="1",
                placeholder='',
                description='',
            )
        
        
        self.gs = GridspecLayout(4,9)
        self.Cslider_box = VBox(width=50)
        self.scale_slider_box = VBox()
        children= []
        sliders =  []

        self.probe_output = widgets.Output()
        self.overlaps_output = widgets.Output()
        
        qn = fourier_coordinates_2D([128,128],r_min)
        q = th.as_tensor(qn).cuda()
        alpha = q * meta.wavelength
        self.probe_gen = ZernikeProbeSingle(q, meta.wavelength, fft_shifted=False)
        self.A = np.linalg.norm(qn,axis=0) * meta.wavelength < meta.alpha_rad
        self.A = gaussian(self.A,1)
                
        self.C = cp.zeros((12,))
        self.C_names = ['C1','C12a' ,'C12b','C21a','C21b','C23a','C23b','C3','C32a','C32b','C34a','C34b']
        self.C_min = [-100,-20,-20,-50,-50,-50,-50,-20,-20,-20,-20,-20]
        self.C_max = [100,20,20,50,50,50,50,20,20,20,20,20]
        self.C_multiplier = [1e1,1e1,1e1,1e1,1e1,1e1,1e4,1e4,1e4,1e4,1e4]
        
        self.Psi_Qp = cp.zeros(sparse_data.scan_dimensions, dtype=np.complex64)
        self.Psi_Qp_left_sb = cp.zeros(sparse_data.scan_dimensions, dtype=np.complex64)
        self.Psi_Qp_right_sb = cp.zeros(sparse_data.scan_dimensions, dtype=np.complex64)
        self.Psi_Rp = cp.zeros(sparse_data.scan_dimensions, dtype=np.complex64)
        self.Psi_Rp_left_sb = cp.zeros(sparse_data.scan_dimensions, dtype=np.complex64)
        self.Psi_Rp_right_sb = cp.zeros(sparse_data.scan_dimensions, dtype=np.complex64)
        
        Psi = self.probe_gen(th.tensor(self.C.get()).cuda(),th.tensor(self.A).cuda())
        self.phases = th.angle(th.fft.fftshift(self.probe_gen(th.tensor(self.C.get()).cuda(),th.tensor(np.ones_like(self.A)).cuda())))
        self.Psi_shifted = th.fft.fftshift(Psi)
        self.psi = th.fft.fftshift(th.fft.ifft2(Psi))
        
        with self.probe_output:
            self.probe_figure = plt.figure(constrained_layout=True,figsize=(9,3))
            gs1 = self.probe_figure.add_gridspec(1, 3, wspace=0.05,hspace=0.05)
            self.f1ax0 = self.probe_figure.add_subplot(gs1[0])
            self.f1ax1 = self.probe_figure.add_subplot(gs1[1])
            self.f1ax2 = self.probe_figure.add_subplot(gs1[2])
            
            self.f1ax0.set_title('Probe  (real space)')
            self.f1ax1.set_title('Probe  (Fourier space)')
            self.f1ax2.set_title('Phase profile (Fourier space)')
            
            self.probe_realspace_imax = self.f1ax0.imshow(imsave(self.psi.cpu().numpy()))
            self.probe_fourier_imax = self.f1ax1.imshow(imsave(self.Psi_shifted.cpu().numpy()))
            self.probe_phases_imax = self.f1ax2.imshow(self.phases.cpu().numpy())
            #,cmap=plt.cm.get_cmap('Greys')

            self.f1ax0.set_xticks([])
            self.f1ax0.set_yticks([])
            self.f1ax1.set_xticks([])
            self.f1ax1.set_yticks([])
            self.f1ax2.set_xticks([])
            self.f1ax2.set_yticks([])
            
        self.overlap_figure_axes = []
        with self.overlaps_output:
            self.overlap_figure = plt.figure(constrained_layout=True,figsize=(9,9))
            gs1 = self.overlap_figure.add_gridspec(3, 3, wspace=0.05,hspace=0.05)
            for ggs in gs1:
                f3_ax1 = self.overlap_figure.add_subplot(ggs)
                imax2 = f3_ax1.imshow(np.zeros((40,40)))
                f3_ax1.set_xticks([])
                f3_ax1.set_yticks([])
                self.overlap_figure_axes.append(imax2)
                
        self.plot_box = VBox(children =[self.probe_figure.canvas,self.overlap_figure.canvas])    
        self.recon_fig, self.recon_axes = plt.subplots(figsize=(9,9))
        self.recon_img = self.recon_axes.imshow(np.zeros(sparse_data.scan_dimensions), cmap=plt.get_cmap('bone'))
        self.recon_axes.set_xticks([])
        self.recon_axes.set_yticks([])
        scalebar = ScaleBar(meta.dr[0]/10,'nm') # 1 pixel = 0.2 meter
        self.recon_axes.add_artist(scalebar)

        for i, (name, mins, maxs, multiplier) in enumerate(zip(self.C_names, self.C_min, self.C_max, self.C_multiplier)):
            s = FloatSlider(description=name,
                           min=mins, max = maxs)
            s.observe(self.create_function(f'slider_changed_{i}', i, multiplier), names='value')
            sliders.append(s)
            children.append(s)

        self.Cslider_box.children = children + [self.aberration_text]

        self.gs[:2,0] = self.Cslider_box
        self.gs[2:,0] = self.scale_slider_box
        self.gs[:,1:5] = self.plot_box
        self.gs[:,5:9] = self.recon_fig.canvas
        
        self.first_time_calc = True
    def __del__(self):
        del self.recon_fig
        del self.overlap_figure
        del self.probe_figure
        
    def rotation_changed(change):
        self.rotation_deg = change['new']
            
    @out0.capture()
    def _get_G(self, size):
        bin_factor = int(np.min(np.floor(self.data.frame_dimensions/size)))
        start = time.perf_counter()
        data = self.data.slice(self.slic)
        dc = data.to_dense(bin_factor)
        self.out.append_stdout(f"Bin by {bin_factor} for ssb took {time.perf_counter() - start:2.2g}s\n")
        
        M = cp.array(dc).astype(cp.float32)
        ny, nx, nky, nkx = M.shape
        start = time.perf_counter()
        G = fft.fft2(M, axes=(0, 1), overwrite_x=True)
        G /= cp.sqrt(np.prod(G.shape[:2]))
        self.out.append_stdout(f"FFT along scan coordinate took {time.perf_counter() - start:2.2g}s\n")
        return G
    
    @out0.capture()
    def update_variables(self):
        self.Psi_Qp[:] = 0
        self.Psi_Qp_left_sb[:] = 0
        self.Psi_Qp_right_sb[:] = 0
            
        eps=1e-3
        start = time.perf_counter()
        single_sideband_reconstruction(
                self.G,
                self.Qx1d,
                self.Qy1d,
                self.Kx,
                self.Ky,
                self.C,
                np.deg2rad(self.rotation_deg),
                self.meta.alpha_rad,
                self.Psi_Qp,
                self.Psi_Qp_left_sb,
                self.Psi_Qp_right_sb,
                eps,
                self.meta.wavelength,
            )
        m = 2

        self.Psi_Rp[:] = fft.ifft2(self.Psi_Qp, norm="ortho")
        self.Psi_Rp_left_sb[:] = fft.ifft2(self.Psi_Qp_left_sb, norm="ortho")
        self.Psi_Rp_right_sb[:] = fft.ifft2(self.Psi_Qp_right_sb, norm="ortho")
        
        self.Gamma = disk_overlap_function(self.Qx_max1d, self.Qy_max1d, self.Kx, self.Ky, self.C, np.deg2rad(self.rotation_deg), self.meta.alpha_rad, self.meta.wavelength)
        
        Psi = self.probe_gen(th.tensor(self.C.get()).cuda(),th.tensor(self.A).cuda())
        self.phases = th.angle(th.fft.fftshift(self.probe_gen(th.tensor(self.C.get()).cuda(),th.tensor(self.A).cuda())))
        self.Psi_shifted = th.fft.fftshift(Psi)
        self.psi = th.fft.fftshift(th.fft.ifft2(Psi))
            
    @out0.capture()
    def update_gui(self):
        gg = self.Gamma * self.G_max

        m=10
        img = np.angle(self.Psi_Rp_left_sb.get()[m:-m,m:-m])
        self.recon_img.set_data(img)
        self.recon_img.set_clim(img.min(),img.max())

        for ax, g in zip(self.overlap_figure_axes,gg):
            ax.set_data(imsave(g.get()))
        
        self.probe_realspace_imax.set_data(imsave(self.psi.cpu().numpy()))
        self.probe_fourier_imax.set_data(imsave(self.Psi_shifted.cpu().numpy()))
        self.probe_phases_imax.set_data(self.phases.cpu().numpy())
        self.probe_phases_imax.set_clim(self.phases.min(),self.phases.max())
        
        self.recon_fig.canvas.draw()
        self.overlap_figure.canvas.draw()
        self.probe_figure.canvas.draw()
        
        self.recon_fig.canvas.flush_events()
        self.overlap_figure.canvas.flush_events()
        self.probe_figure.canvas.flush_events()
    
    @out0.capture()
    def create_function(self, name, i, multiplier):
        def func1(change):
            self.C[i] = change['new']*multiplier
            w = change['new']*multiplier
            self.out.append_stdout(f"C[{i}] = {self.C[i]}\n")
            
            self.update_variables()
            self.update_gui()
            
        func1.__name__ = name
        return func1
    
    
    @out0.capture()
    def selected_index_changed(self, change):
        w = change['new']
        if w == 1:
            if self.first_time_calc:
                n_fit = 9
                self.G = self._get_G(self.ssb_size)
                self.Kx, self.Ky = get_qx_qy_1D(self.G.shape[-2:], self.r_min, cp.float32, fft_shifted=True)

                self.Gabs = cp.sum(cp.abs(self.G), (2, 3))
                sh = np.array(self.Gabs.shape)
                mask = ~np.array(fftshift(sector_mask(sh, sh // 2, 15, (0, 360))))
                mask[:,-1] = 0
                mask[:,0] = 0
                mask[:,1] = 0
                
                gg = self.Gabs.get()
                gg[~mask] = gg.mean()
                
                inds = np.argsort((gg).ravel())
                strongest_object_frequencies = np.unravel_index(inds[-1 - n_fit:-1], self.G.shape[:2])
                
                self.out.append_stdout(f"strongest_object_frequencies: {strongest_object_frequencies}\n")
                
                self.G_max = self.G[strongest_object_frequencies]
                self.Qy_max1d = self.Qy1d[strongest_object_frequencies[0]]
                self.Qx_max1d = self.Qx1d[strongest_object_frequencies[1]]
                
                self.first_time_calc = False
            self.update_variables()
            self.update_gui()
            
        elif w==2:
            data = dssb 
            bright_field_radius = radius

            meta = metadata

            defocus_list_nm = np.linspace(-300,300,60,endpoint=True)  + 0
            print(f'defocus_list_nm: {defocus_list_nm}')
            defocal = []
            defocal_right = []
            defocal_left = []

            for df in defocus_list_nm:
                print(df)
                C = xp.zeros((12,))
                C[0] = df
                Psi_Qp = cp.zeros((ny, nx), dtype=G.dtype)
                Psi_Qp_left_sb = cp.zeros((ny, nx), dtype=np.complex64)
                Psi_Qp_right_sb = cp.zeros((ny, nx), dtype=np.complex64)
                start = time.perf_counter()
                eps = 1e-3
                single_sideband_reconstruction(
                        self.G,
                        self.Qx1d,
                        self.Qy1d,
                        self.Kx,
                        self.Ky,
                        self.C,
                        np.deg2rad(self.rotation_deg),
                        self.meta.alpha_rad,
                        self.Psi_Qp,
                        self.Psi_Qp_left_sb,
                        self.Psi_Qp_right_sb,
                        eps,
                        self.meta.wavelength,
                    )
                self.Psi_Rp[:] = fft.ifft2(self.Psi_Qp, norm="ortho")
                self.Psi_Rp_left_sb[:] = fft.ifft2(self.Psi_Qp_left_sb, norm="ortho")
                self.Psi_Rp_right_sb[:] = fft.ifft2(self.Psi_Qp_right_sb, norm="ortho")

                defocal.append(Psi_Rp.get())
                defocal_right.append(Psi_Rp_right_sb.get())
                defocal_left.append(Psi_Rp_left_sb.get())

            ssb_defocal = np.array(defocal)
            ssb_defocal_right = np.array(defocal_right)
            ssb_defocal_left = np.array(defocal_left)
            
            fig, ax = plt.subplots(dpi=150)

            var1 = np.var(ssb_defocal[...,m:-m,m:-m],(1,2))
            var2 = np.var(ssb_defocal_right[...,m:-m,m:-m],(1,2))

            for i, v in enumerate([var1,var2,var3]):
                ax.scatter(np.arange(len(v)), v, label=f'{i}', s=3)
                ax.plot(v, label=f'{i}')
            plt.legend()
            
            mxs = np.argmax(var1)
            best_df = defocus_list_nm[mxs]
            mxs2 = np.argmax(var2)
            best_df2 = defocus_list_nm[mxs2]
            
            self.out.append_stdout((f'best_df ssb_defocal: {mxs}:{best_df}\n')
            self.out.append_stdout((f'best_df ssb_defocal_right: {mxs2}:{best_df2}\n')
    
    def show(self):
        n_fit = 49
        ranges=[360, 30]
        partitions = [144, 120]
        ranges = [360]
        partitions=[360]
        manual_frequencies=None
        manual_frequencies = None
        
        self.G = self._get_G(12)

        self.Gabs = cp.sum(cp.abs(self.G), (2, 3))
        sh = np.array(self.Gabs.shape)
        mask = ~np.array(fftshift(sector_mask(sh, sh // 2, 15, (0, 360))))
        mask[:,-1] = 0
        mask[:,0] = 0
        mask[:,1] = 0
        
        gg = np.log10(self.Gabs.get()+1)
        gg[~mask] = gg.mean()
        gg = fftshift(gg)
        gg1 = gg[self.margin[0]:-self.margin[0],self.margin[1]:-self.margin[1]]
        
        self.fig_power_spectrum, self.ax_power_spectrum = plt.subplots(1,1,figsize=(8,8))
        self.imax_power_spectrum = self.ax_power_spectrum.imshow(gg1, cmap=plt.cm.get_cmap('magma'))
        

        best, thetas, intensities = find_rotation_angle_with_double_disk_overlap(self.G, self.meta.wavelength, self.r_min, self.dxy, self.meta.alpha_rad, mask=cp.array(mask), ranges=ranges,
                                                                  partitions=partitions, n_fit=n_fit, verbose=False,
                                                                  manual_frequencies=manual_frequencies, aberrations=self.C)

        self.out.append_stdout(f"Best rotation angle: {np.rad2deg(thetas[best])}\n")
        fig, ax = plt.subplots()
        ax.scatter(np.rad2deg(thetas),intensities)
        ax.set_xlabel('STEM rotation [degrees]')
        ax.set_ylabel('Integrated G amplitude over double overlap')
                      
        rot_box = VBox([fig.canvas,self.rotation_slider])
        
        canvas_box = HBox([self.fig_power_spectrum.canvas, rot_box])
        gsl1 = VBox([canvas_box])
        gsl2 = GridspecLayout(1,1)
        
        tab_contents = ['P0', 'P1','P2']
        children = [gsl1,self.gs]
        tab = widgets.Tab()
        tab.children = children
        tab.titles = ['fft','00','00']
        
        tab.observe(self.selected_index_changed, 'selected_index')
        
        gsl00 = GridspecLayout(20,20)
        gsl00[:,:3] = self.out
        gsl00[:,3:] = tab
#         gsl00[:3,:] = out0
        return gsl00

t = InteractiveSSB(dssb,radius, alpha_max_factor,metadata)
t.show()

GridspecLayout(children=(Output(layout=Layout(border='1px solid black', grid_area='widget001'), outputs=({'out…

In [88]:
probe_radius_multiplier = 2

s= t.psi.cpu().numpy()

abss = np.abs(s)
w = abss > abss.max() * 5e-2
t.r_min.max()
probe_radius = np.sqrt(w.sum()*1.0/np.pi) * t.r_min.max()
print(probe_radius,t.r_min.max())

alpha_max = da[0].frame_dimensions[-1] // 2 / radius * metadata.alpha_rad
alpha_max_det = alpha_max + metadata.alpha_rad
k_max_probe = metadata.alpha_rad / metadata.wavelength
print('k_max_probe     ',k_max_probe)
print('alpha_max_det   ',alpha_max_det)
k_max_det = alpha_max_det / metadata.wavelength
k_max_sampling = 1/2/metadata.dr.min()

print('k_max_det       ',k_max_det)
print('k_max_sampling  ',k_max_sampling)
print()
k_max = np.max([k_max_sampling,k_max_det])
print('k_max           ',k_max)

dx = 1/2/k_max
print('dx              ',dx)

FOV = probe_radius * 2 * probe_radius_multiplier
dk = 1/FOV

smatrix_bf_radius_big = int(np.ceil(k_max_probe/dk))

bin_factor_big = int(np.min(np.ceil(radius/smatrix_bf_radius_big)))

print('dk              ',dk)
print('pixels_bf       ',pixels_bf)
print('FOV [A]         ',FOV)

radius_data_int_big = int(np.ceil(radius / bin_factor_big) * bin_factor_big)
radius_max_int_big = int(np.ceil(radius_factor_big * radius / bin_factor_big) * bin_factor_big)
frame_size_big = 2 * radius_max_int_big // bin_factor_big

alpha_max_big = radius_factor_big * metadata.alpha_rad

dx1_big = metadata.wavelength / 2 / alpha_max_big
dxy_big = [dx1_big,dx1_big]


k_max1_big = alpha_max_big / metadata.wavelength
k_max_big = [k_max1_big, k_max1_big]


pixel_step_y = metadata.dr[0]/dxy_big[0]
pixel_step_x = metadata.dr[1]/dxy_big[1]
pixel_step = np.array([pixel_step_y,pixel_step_x])

print(f'radius: {radius}')
print()
print('=== For iterative reconstruction: ')
print(f'target BF radius    : {smatrix_bf_radius_big}')
print(f'bin_factor          : {bin_factor_big}')
print(f'radius_max_int_big after bin: {radius_max_int_big/bin_factor_big}')
print(f'frame_size after bin: {frame_size_big}')
print(f'rotation (degrees)  : {metadata.rotation_deg}')
print(f'dxy                 : {dx1_big}')
print(f'k_max               : {k_max_big}')
print(f'pixel_step          : {pixel_step}')
print()

fig, ax = plt.subplots()
ax.imshow(w)
AppLayout(center=fig.canvas)

3.307741768388865 0.8375456621688379
k_max_probe      0.5986995486483224
alpha_max_det    0.057464757263455006
k_max_det        1.3761649694726368
k_max_sampling   1.5809733061413571

k_max            1.5809733061413571
dx               0.3162608742713928
dk               0.07558026517945807
pixels_bf        12
FOV [A]          13.23096707355546
radius: 123.21053157858441

=== For iterative reconstruction: 
target BF radius    : 8
bin_factor          : 16
radius_max_int_big after bin: 21.0
frame_size after bin: 42
rotation (degrees)  : 0
dxy                 : 0.31467348870517636
k_max               : [1.5889486021126478, 1.5889486021126478]
pixel_step          : [1.00504455 1.00504455]



AppLayout(children=(Canvas(layout=Layout(grid_area='center'), toolbar=Toolbar(toolitems=[('Home', 'Reset origi…

In [89]:
dbig = da[0].crop_symmetric_center(d.frame_dimensions/2,  radius*radius_factor_big)

old frames shape: (512, 512, 2572)
new frames shape: (512, 512, 2572)
old frames frame_dimensions: [320 320]
new frames frame_dimensions: [654 654]


In [94]:
ddense = dbig.to_dense(bin_factor_big)

slic_edge = np.s_[:,-50:,:,:]
sds = np.sum(ddense[slic_edge],(0,1)) / np.prod(dssb.scan_dimensions)
vacuum_probe = sds * (sds > sds.max() * 20e-2) 

fig,ax = plt.subplots()
ax.imshow(vacuum_probe)
AppLayout(center=fig.canvas)

radius_data_int : 336 
radius_max_int  : 336 
Dense frame size: 42x 42
sparse_to_dense_datacube_crop_gain_mask dtypes: torch.float32 torch.int32 torch.float32 int64


AppLayout(children=(Canvas(layout=Layout(grid_area='center'), toolbar=Toolbar(toolitems=[('Home', 'Reset origi…

In [91]:
from smpr3d.util import advanced_raster_scan
from smpr3d.util import h5write, h5append

sh = ddense.shape
ddense2 = ddense.reshape((sh[0]*sh[1],sh[2],sh[3]))
ddense2 = np.fft.fftshift(ddense2,(1,2))

r = advanced_raster_scan(dssb.scan_dimensions[0], dssb.scan_dimensions[1], fast_axis=1, mirror=[1, 1], theta=0,
                                 dy=pixel_step[0], dx=pixel_step[1])



d = {}
d['r'] = r
d['data'] = ddense2
d['vacuum_probe'] = vacuum_probe

ImportError: cannot import name 'h5write' from 'smpr3d.util' (/home/philipp/projects/smpr3d/smpr3d/util.py)

In [93]:
from smpr3d.util import ZernikeProbeSingle
from smpr3d.util import *

qnp = fourier_coordinates_2D(ddense.shape[2:], dxy_big, centered=False)
q = th.as_tensor(qnp)
Psi_gen = ZernikeProbeSingle(q, metadata.wavelength, fft_shifted=True)
Ap0 = th.as_tensor(fftshift(vacuum_probe))
C1 = th.as_tensor(t.C.get())
Psi_model = Psi_gen(C1, Ap0).detach()
Psi = Psi_model
psi = th.fft.ifft2(Psi_model,norm='ortho')
fig, ax = plt.subplots(1,2,figsize=(10,5))
ax[0].imshow(np.angle(psi))
ax[1].imshow(np.abs(psi))

# d['probe_real'] = psi.real
# d['probe_imag'] = psi.imag

AppLayout(center=fig.canvas)

AppLayout(children=(Canvas(layout=Layout(grid_area='center'), toolbar=Toolbar(toolitems=[('Home', 'Reset origi…

In [None]:
from smpr3d.core import ReconstructionOptions, reconstruct

metadata2 = Metadata4D(E_ev = E_ev, 
                       alpha_rad = metadata.alpha_rad, 
                       dr=metadata.dr,
                       k_max = k_max,
                       rotation_deg = t.rotation_deg)

options = ReconstructionOptions()

out = sm.reconstruct(data, metadata, options)

S = out.smatrix
r = out.r
Psi = out.Psi
R_factor = out.R_factors

In [96]:
k_max, rotation_deg

NameError: name 'rotation_deg' is not defined

In [None]:
import torch as th
import numpy as np
from torch.autograd import Function
from numba import cuda
from numpy.fft import fftfreq

@cuda.jit
def sparse_amplitude_loss_kernel(a_model, indices_target, counts_target, loss, grad, frame_dimensions,
                                 no_count_indicator):
    k = cuda.grid(1)
    K, _ = indices_target.shape
    MY, MX = frame_dimensions
    if k < K:
        for i in range(indices_target[k].shape[0]):
            idx1d = indices_target[k, i]
            my = idx1d // MX
            mx = idx1d - my * MX
            if idx1d != no_count_indicator:
                grad[k, my, mx] = 1 - (counts_target[k, i] / a_model[k, my, mx])
                cuda.atomic.add(loss, (0), (a_model[k, my, mx] - counts_target[k, i]) ** 2)


def sparse_amplitude_loss(a_model, indices_target, counts_target, frame_dimensions):
    """

    :param a_model:             K x M1 x M2
    :param indices_target:      K x num_max_counts
    :param counts_target:       K x num_max_counts
    :param frame_dimensions:    2
    :return: loss (1,), grad (K x M1 x M2)
    """
    threadsperblock = (256,)
    blockspergrid = tuple(np.ceil(np.array(indices_target.shape[0]) / threadsperblock).astype(np.int))

    loss = th.zeros((1,), device=a_model.device, dtype=th.float32)
    grad = th.ones_like(a_model)
    no_count_indicator = th.iinfo(indices_target.dtype).max
    sparse_amplitude_loss_kernel[blockspergrid, threadsperblock](a_model.detach(), indices_target.detach(),
                                                                 counts_target.detach(), loss.detach(), grad.detach(),
                                                                 frame_dimensions, no_count_indicator)
    return loss, grad

class SparseAmplitudeLoss(Function):

    @staticmethod
    def forward(ctx, a_model, indices_target, counts_target):
        frame_dimensions = th.as_tensor(a_model.shape[1:], device=a_model.device)
        loss, grad = sparse_amplitude_loss(a_model, indices_target, counts_target, frame_dimensions)
        loss.requires_grad = True
        ctx.save_for_backward(grad)
        return loss

    @staticmethod
    def backward(ctx, *grad_outputs):
        grad_input, = ctx.saved_tensors
        grad_indices = None
        grad_counts = None
        return grad_input, grad_indices, grad_counts

class PtychoSubpix(th.autograd.Function):
    @staticmethod
    def forward(ctx, S: th.Tensor, M: th.Tensor, Psi, pos: th.Tensor) -> th.Tensor:
        """
        :param S: N1 x N2 x 2     tensor
        :param psi: M1 x M2 x 2     tensor in fourier space
        :param pos: K x 2 real                  tensor

        :return: D x K x M1 x M2                measured amplitudes in the far field
        """
        qy, qx = np.meshgrid(fftfreq(M[0].item()), fftfreq(M[1].item()), indexing='ij')
        # M1 x M2 x 2
        q = th.stack([th.as_tensor(qy), th.as_tensor(qx)]).float().cuda()
        rs = pos - pos.int()
        ramp = complex_expi(
            -2 * np.pi * (q[0][None, ...] * rs[:, 0][:, None, None] + q[1][None, ...] * rs[:, 1][:, None, None]))
        # K x B x M x M x 2
        swap = th.LongTensor([1, 0])
        pos2 = pos[:, swap]
        # K x M1 x M2 x 2
        frames_exit = gather_patches(S, axes=[0, 1], positions=pos2.long(), patch_size=M, out=None)
        # frames_exit = th.stack([S[:, r[1]:r[1] + M[1].item(), r[0]:r[0] + M[0].item()] for r in pos], 0)
        # K x x M x M x 2
        # psi = th.fft(psi, 2)
        Psi = Psi.unsqueeze(0).repeat(ramp.shape[0], 1, 1, 1)
        psi = th.ifft(complex_mul(Psi, ramp), 2, True)
        # psi = th.ifft(Psi, 2, True)
        # plotcxmosaic(complex_numpy(psi.detach().cpu()))
        psi_exit = complex_mul(frames_exit, psi)
        # K x M x M x 2
        Psi = th.fft(psi_exit, 2, True)
        # K x M x M x 2
        a_model = cabs(Psi)
        # plot(a_model[100].cpu(),'a_mdel')

        ctx.save_for_backward(frames_exit, Psi, psi, pos2, ramp, th.Tensor([S.shape])[0].int())

        return a_model

    def backward(ctx, grad_output):
        frames_exit, Psi_exit, psi, pos2, ramp, ss = ctx.saved_tensors

        M = [grad_output.shape[1], grad_output.shape[2]]

        # K x M1 x M2 x 2
        grad_Psi = complex_mul_real(Psi_exit, grad_output)
        # K x M1 x M2 x 2
        grad_Psi = th.ifft(grad_Psi, 2, True)
        # M1 x M2 x 2

        tmp = complex_mul_conj(grad_Psi, psi)
        grad_S = th.zeros(tuple(ss.numpy()), device=Psi_exit.device, dtype=th.float32)
        grad_S = scatter_add_patches(tmp, grad_S, [0, 1], pos2.long(), M, reduce_dim=None)
        # same = th.sum((grad_S[re] == grad_S[im]) * (grad_S[re] != 0)* (grad_S[im] != 0))
        # print(f'same: {same}')
        # zplot([frames_exit_shifted[10, :, :, 0].cpu(), frames_exit_shifted[10, :, :, 1].cpu()], cmap=['inferno', 'inferno'], figsize=(9, 5))
        # zplot([frames_exit_shifted[20, :, :, 0].cpu(), frames_exit_shifted[20, :, :, 1].cpu()], cmap=['inferno', 'inferno'], figsize=(9, 5))
        grad_psi = complex_mul_conj(grad_Psi, frames_exit)
        grad_psi = th.fft(grad_psi, 2, True)
        ramp[im] *= -1
        # K x M1 x M2 x 2
        grad_psi = th.ifft(complex_mul(grad_psi, ramp), 2, True)
        grad_psi = th.sum(grad_psi, 0)
        grad_Psi = th.fft(grad_psi, 2, True)
        # print('i got here 2')


        grad_M = None
        grad_pos = None

        return grad_S, grad_M, grad_Psi, grad_pos

In [None]:
Ap = th.as_tensor(mean_dp * (mean_dp > mean_dp.max() * 5e-2)).cuda()

Psi_gen = ZernikeProbeSingle(q, metadata.wavelength, fft_shifted=True)

C_target = th.zeros((12)).cuda()
C_target[0] = 0

C_model = t.C
C_model[0] = 0

psi_target = Psi_gen(C_target, Ap)
psi_model = Psi_gen(C_target, Ap)
psi_model = psi_model.cuda()
psi_model.requires_grad_(False)

plotcx(complex_numpy(th.ifft(psi_model, 2, True).cpu()),'Psi_model')
# %%
size = [128,128]
start = [80,300]
d2 = d1.slice(np.s_[start[0]:start[0] + size[0], start[1]:start[1] + size[1]])

plot(d2.sum_diffraction())
# %%
dx = metadata.dr / r_min
theta = -66.1
r = advanced_raster_scan(d2.scan_dimensions[0], d2.scan_dimensions[1], fast_axis=1, mirror=[1, 1], theta=theta,
                         dy=dx[0], dx=dx[1])
margin = 10
M = th.tensor(d2.frame_dimensions).int()
N = th.tensor(np.ceil(r.max(axis=0)).astype(np.int32)) + margin + M
K = r.shape[0]

print('N:', N)
print('M:', M)

# A = Ptycho.apply
A = PtychoSubpix.apply
sparse_amplitude_loss_funtion = SparseAmplitudeLoss.apply

T_model = th.ones(tuple(N) + (2,), requires_grad=True, device=th.device('cuda:0'))

pos = th.as_tensor(r + margin/2, device=T_model.device)

indices_target = th.as_tensor(d2.indices, device=T_model.device)
counts_target = th.as_tensor(d2.counts, device=T_model.device)
ish = indices_target.shape

indices_target = indices_target.view((K, ish[-1]))
counts_target = th.sqrt(counts_target.view((K, ish[-1])).type(th.float32))

C_model.requires_grad = False
pos.requires_grad = False
optimizer = optim.Adam([T_model, C_model, pos], lr=1e-2)
# loss_function = F.MSELoss(reduction='sum')
# loss_function = F.PoissonNLLLoss(log_input=False, reduction='sum')
it = 30

n_batches = 4
divpoints = array_split_divpoints_ntotal(K, n_batches)
from tqdm import trange
i = 0

for i in trange(it):
    sum_loss = 0
    random_order = th.randperm(K)
    for b in range(n_batches):
    # b = 1
        take_ind = random_order[divpoints[b]:divpoints[b+1]]
        if i > 10:
            C_model.requires_grad = True

        # print(pos[100])
        # if i == 100:
        #     t = T_model.clone().detach().cpu().numpy()
        #     zplot([t[re], t[im]], cmap=['inferno', 'inferno'], figsize=(9, 5))
        optimizer.zero_grad()
        if i == 0:
            Ap1 = Ap + th.randn_like(Ap) *1e-4
        else:
            Ap1 = Ap
        Psi_model = Psi_gen(C_model, Ap1)
        a_model = A(T_model, M[:2], Psi_model, pos[take_ind])
        # a_model = A(T_model, a, pos)
        loss = sparse_amplitude_loss_funtion(a_model, indices_target[take_ind], counts_target[take_ind])
        sum_loss += loss.item()
        loss.backward()
        # if i > 10:
            # C_model.grad[1:] = 0
        # # print(pos.grad)
        # # print(psi_model.grad.sum())
        # t = complex_numpy(T_model.grad.clone().cpu())
        # t = rotate(t.real, -theta) + 1j * rotate(t.real, -theta)
        # plotcx(t[m:-m,m:-m])
        # zplot([t.real[m:-m, m:-m], t.imag[m:-m, m:-m]], title=['T_model.grad Re', 'T_model.grad Im'], cmap=['inferno', 'inferno'], figsize=(9, 5))
        optimizer.step()
    print(f'i {i} loss {sum_loss}, C_model = {C_model[0]} , C_target = {C_target[0]}')
#%%
from skimage.transform import rotate
m = 80
t = complex_numpy(T_model.clone().detach().cpu())
t = rotate(t.real,-theta) + 1j * rotate(t.imag,-theta)
# plotcx(t[m:-m,m:-m])
zplot([np.abs(t)[m:-m,m:-m], np.angle(t)[m:-m,m:-m]], title=['Abs', 'Angle'], cmap=['inferno', 'inferno'], figsize=(9, 5))