In [1]:
%matplotlib widget
from smpr3d.core import Sparse4DData, Metadata4D
from numpy.fft import fftshift, fft2
import matplotlib
import matplotlib.pyplot as plt 
import numpy as np
from matplotlib_scalebar.scalebar import ScaleBar
from pathlib import Path
from smpr3d.util import *
from mpl_toolkits.axes_grid1 import make_axes_locatable
from tifffile import imsave, imwrite

# Scan to analyse; the HDF5 file name is derived from it below.
scan_number = 147

# NOTE(review): machine-specific absolute path — point this at your own data directory.
base_path = Path('/home/philipp/nvme/insync_berkeley/2020-09-22/results/ptycho/')
# Built from scan_number so the scan label used in plot titles and the file actually
# loaded cannot drift apart.
filename4d = Path(f'{scan_number}.h5')
h5_path = base_path / filename4d

# Sparse 4D data (group 'data', followed by fftshift_()) and its metadata (group 'meta').
data = Sparse4DData.from_h5(h5_path, 'data').fftshift_()
meta = Metadata4D.from_h5(h5_path, 'meta')
meta

Metadata4D(scan_step=array([0.31626087, 0.31626087], dtype=float32), pixel_step=array([0.39762501, 0.39762501]), k_max=array([0.62863453, 0.62863453]), alpha_rad=0.025, rotation_deg=0.0, E_ev=80000.0, wavelength=0.041757171951176904, aberrations=array([-62.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.]))

In [5]:
# Summed diffraction pattern over all scan positions, with a colour bar.
fig_sum, ax_sum = plt.subplots(figsize=(3, 3))
sum_image = ax_sum.imshow(data.sum_diffraction())
ax_sum.set_title(f'Scan {scan_number} sum after cropping')
fig_sum.colorbar(sum_image, ax=ax_sum)
fig_sum.tight_layout()


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [6]:
# Bright-field disk radius in detector pixels (also handed to InteractiveSSB below).
radius = 12
center = data.frame_dimensions / 2
inner_radius = radius / 2

# Virtual detectors: the outer annulus of the BF disk (ABF) and its inner disk (BF).
abf = data.virtual_annular_image(inner_radius, radius, center)
bf = data.virtual_annular_image(0, inner_radius, center)
# eABF = ABF - BF, formed before the zero filling below.
eabf = abf - bf

# Replace exactly-zero pixels (presumably scan positions without counts) by the image mean.
for image in (bf, abf):
    image[image == 0] = image.mean()

In [7]:

# ABF and BF virtual images side by side, each with a scale bar.
# The scale bar uses scan_step / 10 nm per pixel (scan_step presumably in Å).
# The colormap is passed by name: plt.cm.get_cmap was removed in matplotlib 3.9.
w, h = plt.figaspect(bf)
fig, ax = plt.subplots(1, 2, dpi=150, figsize=(w * 2, h))
for panel, image, label in zip(ax, (abf, bf), ('ABF', 'BF')):
    panel.imshow(image, cmap='bone')
    panel.set_title(f'Scan {scan_number} {label}')
    panel.set_xticks([])
    panel.set_yticks([])
    panel.add_artist(ScaleBar(meta.scan_step[0] / 10, 'nm'))


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib_scalebar.scalebar.ScaleBar at 0x7f4d2e6290a0>

In [10]:
import time
from smpr3d.util import get_qx_qy_1D, ZernikeProbeSingle, fourier_coordinates_2D, imsave
from ipywidgets import AppLayout, FloatSlider, GridspecLayout, VBox, Tab, Box, HBox, IntSlider
import ipywidgets as widgets
from matplotlib_scalebar.scalebar import ScaleBar
from numpy.fft import fftshift
from math import sin, cos
import time
import cupy as cp
import matplotlib.pyplot as plt
from cupy.cuda import Device
import cupyx.scipy.fft as fft
from IPython.display import display
from skimage.filters import gaussian
from matplotlib.patches import Rectangle
# Turn pyplot interactive mode off; the GUI below embeds each figure's canvas in
# an ipywidgets layout instead.
plt.ioff()

# Shared log panel: every method decorated with ``@out0.capture()`` writes its output here.
out0 = widgets.Output(layout=dict(border='1px solid black'))

class InteractiveSSB:
    """Interactive single-sideband (SSB) reconstruction GUI built from ipywidgets.

    ``show()`` returns the top-level layout, which has two tabs.

    * Tab 0 (``'fft'``) holds three things:
      - a clickable overview image for choosing the scan window;
      - the power spectrum of that window;
      - the STEM-rotation search result, with a rotation slider.
    * Tab 1 holds the aberration sliders, the probe and disk-overlap previews, and
      the reconstructed phase. Selecting this tab (re)builds the SSB input for the
      current window.

    Progress messages go to ``self.out``. Output and tracebacks of the decorated
    callbacks go to the module-level ``out0`` widget.

    Parameters
    ----------
    sparse_data : Sparse4DData
        Sparse 4D data. ``scan_dimensions``, ``frame_dimensions``, ``slice()`` and
        ``to_dense()`` are used.
    slice_image : 2D array
        Image shown in the window-selection figure (e.g. a virtual BF image).
    radius : float
        Bright-field disk radius in detector pixels. Together with
        ``meta.alpha_rad`` it sets the angle covered by the detector (``alpha_max``).
    meta : Metadata4D
        Metadata providing ``scan_step``, ``alpha_rad``, ``wavelength`` and ``E_ev``.
    window_size : int, optional
        Unused. The initial window always covers the whole scan; the parameter is
        kept for backwards compatibility.
    """

    def __init__(self, sparse_data, slice_image, radius, meta, window_size=128):
        self.out = widgets.Output(layout={'border': '1px solid black'})
        self.meta = meta
        self.data = sparse_data
        self.ssb_size = 24
        scan_ny, scan_nx = sparse_data.scan_dimensions[0], sparse_data.scan_dimensions[1]
        self.slic = np.s_[0:scan_ny, 0:scan_nx]
        self.rmax = sparse_data.frame_dimensions[-1] // 2
        # Largest angle on the detector, scaled from the BF-disk radius.
        self.alpha_max = self.rmax / radius * meta.alpha_rad
        self.slice_image = slice_image
        self.window_center = sparse_data.scan_dimensions // 2
        # Stored as [height, width]. update_slice() indexes and updates both
        # entries, so this must be a mutable pair rather than a single int.
        self.window_size = [int(scan_ny), int(scan_nx)]

        r_min = meta.wavelength / (2 * self.alpha_max)
        self.r_min = np.array([r_min, r_min])
        self.k_max = [self.alpha_max / meta.wavelength, self.alpha_max / meta.wavelength]
        self.dxy = np.array(meta.scan_step)

        self.margin = sparse_data.frame_dimensions // 2
        self.rotation_deg = 0.0
        # True once G, the coordinate grids, the Psi_* buffers and the strongest
        # frequencies all belong to the current window (see selected_index_changed).
        self._ssb_initialized = False

        self.out.append_stdout(f"alpha_max       = {self.alpha_max * 1e3:2.2f} mrad\n")
        self.out.append_stdout(f"E               = {meta.E_ev/1e3} keV\n")
        self.out.append_stdout(f"λ               = {meta.wavelength * 1e2:2.2}   pm\n")
        self.out.append_stdout(f"dR              = {self.dxy[0]:2.2f}             Å\n")
        self.out.append_stdout(f"dK              = {self.k_max[0]:2.2f}          Å^-1\n")
        self.out.append_stdout(f"scan       size = {sparse_data.scan_dimensions}\n")
        self.out.append_stdout(f"detector   size = {sparse_data.frame_dimensions}\n")

        self.rotation_slider = widgets.FloatSlider(
            value=0,
            min=-180,
            max=180,
            step=0.1,
            description='STEM rotation angle:',
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='.1f',
        )
        self.rotation_slider.observe(self.rotation_changed, 'value')

        self.window_size_slider = widgets.IntSlider(
            value=sparse_data.scan_dimensions[0],
            min=2,
            max=np.max(sparse_data.scan_dimensions),
            step=2,
            description='Window size',
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=True
        )
        self.window_size_slider.observe(self.window_size_slider_changes, 'value')

        self.aberration_text = widgets.HTML(
            value="1",
            placeholder='',
            description='',
        )

        self.gs = GridspecLayout(4, 9)
        # VBox has no ``width`` trait, so a ``width=`` kwarg is not applied;
        # use ``layout=widgets.Layout(width=...)`` if a width is ever needed.
        self.Cslider_box = VBox()
        self.scale_slider_box = VBox()

        self.probe_output = widgets.Output()
        self.overlaps_output = widgets.Output()

        # Probe model on a fixed 128 x 128 Fourier grid.
        qn = fourier_coordinates_2D([128, 128], self.r_min / 1.6)
        q = th.as_tensor(qn).cuda()
        self.probe_gen = ZernikeProbeSingle(q, meta.wavelength, fft_shifted=False)
        # Aperture |q| * wavelength < alpha_rad, with its edge smoothed by a Gaussian (sigma = 1 px).
        self.A = np.linalg.norm(qn, axis=0) * meta.wavelength < meta.alpha_rad
        self.A = gaussian(self.A, 1)
        # Constant apertures as GPU tensors, built once rather than on every update.
        self._A_t = th.tensor(self.A).cuda()
        self._ones_t = th.tensor(np.ones_like(self.A)).cuda()

        self.C = cp.zeros((12,))
        self.C_names = ['C1', 'C12a', 'C12b', 'C21a', 'C21b', 'C23a', 'C23b', 'C3', 'C32a', 'C32b', 'C34a', 'C34b']
        self.C_min = [-120, -20, -20, -50, -50, -50, -50, -20, -20, -20, -20, -20]
        self.C_max = [120, 20, 20, 50, 50, 50, 50, 20, 20, 20, 20, 20]
        # Slider value -> C factor, one per coefficient. With fewer entries than
        # names, zip() below silently drops the trailing sliders. Both a/b
        # components of an aberration share one factor: x10 for orders 1-2 and
        # x1e4 for order 3 (presumably sliders in nm / um and C in Å; TODO confirm).
        self.C_multiplier = [1e1] * 7 + [1e4] * 5
        assert len(self.C_names) == len(self.C_min) == len(self.C_max) == len(self.C_multiplier) == self.C.shape[0]

        # The initial phase panel shows the phase over the whole grid (unit aperture).
        self._update_probe(phase_aperture=self._ones_t)

        with self.probe_output:
            self.probe_figure = plt.figure(constrained_layout=True, figsize=(9, 3))
            gs1 = self.probe_figure.add_gridspec(1, 3, wspace=0.05, hspace=0.05)
            self.f1ax0 = self.probe_figure.add_subplot(gs1[0])
            self.f1ax1 = self.probe_figure.add_subplot(gs1[1])
            self.f1ax2 = self.probe_figure.add_subplot(gs1[2])

            self.f1ax0.set_title('Probe  (real space)')
            self.f1ax1.set_title('Probe  (Fourier space)')
            self.f1ax2.set_title('Phase profile (Fourier space)')

            self.probe_realspace_imax = self.f1ax0.imshow(imsave(self.psi.cpu().numpy()))
            self.probe_fourier_imax = self.f1ax1.imshow(imsave(self.Psi_shifted.cpu().numpy()))
            self.probe_phases_imax = self.f1ax2.imshow(self.phases.cpu().numpy())

            for probe_ax in (self.f1ax0, self.f1ax1, self.f1ax2):
                probe_ax.set_xticks([])
                probe_ax.set_yticks([])

        # Despite the name, this holds the AxesImage of each of the 3 x 3 overlap panels.
        self.overlap_figure_axes = []
        with self.overlaps_output:
            self.overlap_figure = plt.figure(constrained_layout=True, figsize=(9, 9))
            gs1 = self.overlap_figure.add_gridspec(3, 3, wspace=0.05, hspace=0.05)
            for ggs in gs1:
                f3_ax1 = self.overlap_figure.add_subplot(ggs)
                imax2 = f3_ax1.imshow(np.zeros((40, 40)))
                f3_ax1.set_xticks([])
                f3_ax1.set_yticks([])
                self.overlap_figure_axes.append(imax2)

        self.plot_box = VBox(children=[self.probe_figure.canvas, self.overlap_figure.canvas])
        self.recon_fig, self.recon_axes = plt.subplots(figsize=(9, 9))
        self.recon_img = self.recon_axes.imshow(np.zeros(sparse_data.scan_dimensions), cmap=plt.get_cmap('bone'))
        self.recon_axes.set_xticks([])
        self.recon_axes.set_yticks([])
        # scan_step / 10 nm per pixel (scan_step presumably in Å).
        scalebar = ScaleBar(meta.scan_step[0] / 10, 'nm')
        self.recon_axes.add_artist(scalebar)

        self.C_sliders = []
        for i, (name, c_min, c_max, multiplier) in enumerate(zip(self.C_names, self.C_min, self.C_max, self.C_multiplier)):
            s = FloatSlider(description=name, min=c_min, max=c_max)
            s.observe(self.create_function(f'slider_changed_{i}', i, multiplier), names='value')
            self.C_sliders.append(s)

        self.Cslider_box.children = self.C_sliders + [self.aberration_text]

        self.gs[:2, 0] = self.Cslider_box
        self.gs[2:, 0] = self.scale_slider_box
        self.gs[:, 1:5] = self.plot_box
        self.gs[:, 5:9] = self.recon_fig.canvas

        self.first_time_calc = True

    def __del__(self):
        # Close the figures, not just drop the references, so pyplot releases them.
        # pop() tolerates attributes that were never set, e.g. if __init__ failed
        # part-way or show() was never called.
        for name in ('recon_fig', 'overlap_figure', 'probe_figure',
                     'fig_select_slice', 'fig_power_spectrum', 'fig_rotation'):
            fig = self.__dict__.pop(name, None)
            if fig is not None:
                plt.close(fig)

    @out0.capture()
    def rotation_changed(self, change):
        """Store the new STEM rotation angle and refresh the reconstruction if it exists."""
        self.rotation_deg = change['new']
        if self._ssb_initialized:
            self.update_variables()
            self.update_gui()

    @staticmethod
    def _power_spectrum_mask(shape, radius):
        """Boolean mask for an unshifted power spectrum of ``shape``.

        ``sector_mask`` is assumed to be True inside the full-circle sector of
        ``radius`` around the array centre (TODO confirm). After fftshift and
        negation, the result is False around the zero frequency. Columns 0, 1
        and -1 are also set to False.
        """
        sh = np.array(shape)
        mask = ~np.array(fftshift(sector_mask(sh, sh // 2, radius, (0, 360))))
        mask[:, -1] = 0
        mask[:, 0] = 0
        mask[:, 1] = 0
        return mask

    def _masked_log_spectrum(self, mask):
        """Return fftshift(log10(Gabs + 1)) for display, with entries outside ``mask`` set to its mean."""
        spectrum = np.log10(self.Gabs.get() + 1)
        spectrum[~mask] = spectrum.mean()
        return fftshift(spectrum)

    @out0.capture()
    def update_slice(self):
        """Clamp the window around ``window_center`` to the scan, then refresh the
        slice rectangle, ``G``/``Gabs`` and the power-spectrum figure."""
        scan_ny, scan_nx = self.data.scan_dimensions[0], self.data.scan_dimensions[1]
        y0 = self.window_center[0] - self.window_size[0] // 2
        y1 = self.window_center[0] + self.window_size[0] // 2
        x0 = self.window_center[1] - self.window_size[1] // 2
        x1 = self.window_center[1] + self.window_size[1] // 2

        # Clip to the scanned area.
        y0 = int(max(y0, 0))
        y1 = int(min(y1, scan_ny))
        x0 = int(max(x0, 0))
        x1 = int(min(x1, scan_nx))

        # Snap bounds clipped at an edge to even indices and track the resulting size.
        if y0 == 0:
            y1 = (y1 // 2) * 2
            self.window_size[0] = y1 - y0
        if x0 == 0:
            x1 = (x1 // 2) * 2
            self.window_size[1] = x1 - x0
        if y1 == scan_ny and y0 % 2 > 0:
            y0 -= 1
            self.window_size[0] = self.window_size[0] + 1
        # x1 is a column bound, so it is compared with the scan width.
        if x1 == scan_nx and x0 % 2 > 0:
            x0 -= 1
            self.window_size[1] = self.window_size[1] + 1

        self.slic = np.s_[y0:y1, x0:x1]
        self.out.append_stdout(f"slice {y0},{y1},{x0},{x1}\n")

        self.slice_rect.set_xy((x0, y0))
        self.slice_rect.set_height(y1 - y0)
        self.slice_rect.set_width(x1 - x0)

        # The SSB state belongs to the previous window until tab 1 is selected again.
        self._ssb_initialized = False
        self.G = self._get_G(self.ssb_size)
        self.Gabs = cp.sum(cp.abs(self.G), (2, 3))

        mask = self._power_spectrum_mask(self.Gabs.shape, self.Gabs.shape[0] / 30)
        spectrum = self._masked_log_spectrum(mask)
        self.imax_power_spectrum.set_data(spectrum)
        self.imax_power_spectrum.set_clim(spectrum.min(), spectrum.max())

        self.fig_select_slice.canvas.draw()
        self.fig_select_slice.canvas.flush_events()
        self.fig_power_spectrum.canvas.draw()
        self.fig_power_spectrum.canvas.flush_events()

    @out0.capture()
    def window_size_slider_changes(self, change):
        """Use a square window of the slider's size and recompute the slice."""
        self.window_size = [change['new'], change['new']]
        self.update_slice()

    @out0.capture()
    def fig_select_slice_onclick(self, event):
        """Centre the window on the clicked pixel (row = ydata, column = xdata)."""
        # Clicks outside the image axes carry xdata/ydata == None.
        if event.inaxes is not self.ax_select_slice or event.xdata is None or event.ydata is None:
            return
        self.window_center = [event.ydata, event.xdata]
        self.update_slice()

    @out0.capture()
    def _get_G(self, size):
        """FFT over the two scan axes of the current window of the overview cube.

        ``size`` is unused (the overview cube is binned once in ``show()``); it is
        kept for symmetry with ``_get_G_full``. The result is divided by
        sqrt(number of scan positions).
        """
        M = cp.array(self.dc_overview[self.slic], dtype=cp.complex64)
        G = fft.fft2(M, axes=(0, 1), overwrite_x=True)
        G /= cp.sqrt(np.prod(G.shape[:2]))
        return G

    @out0.capture()
    def _get_G_full(self, size):
        """Densify the current window for SSB and return its scan-axis FFT.

        The detector is binned so that it has roughly ``size`` pixels per axis.
        The method also sets the coordinate grids (Qy1d/Qx1d, Ky/Kx) and
        allocates the Psi_* reconstruction buffers for the window.
        """
        bin_factor = int(np.min(np.floor(self.data.frame_dimensions / size)))
        start = time.perf_counter()
        data = self.data.slice(self.slic)
        self.dc = data.to_dense(bin_factor)
        self.out.append_stdout(f"Bin by {bin_factor} for ssb took {time.perf_counter() - start:2.2g}s\n")
        self.out.append_stdout(f"Data shape is {self.dc.shape}\n")

        self.Qy1d, self.Qx1d = get_qx_qy_1D(data.scan_dimensions, self.dxy, cp.float32, fft_shifted=False)
        self.Ky, self.Kx = get_qx_qy_1D(self.dc.shape[-2:], self.r_min, cp.float32, fft_shifted=True)

        def new_buffer():
            return cp.zeros(data.scan_dimensions, dtype=np.complex64)

        self.Psi_Qp, self.Psi_Qp_left_sb, self.Psi_Qp_right_sb = new_buffer(), new_buffer(), new_buffer()
        self.Psi_Rp, self.Psi_Rp_left_sb, self.Psi_Rp_right_sb = new_buffer(), new_buffer(), new_buffer()

        M = cp.array(self.dc, dtype=cp.complex64)
        start = time.perf_counter()
        G = fft.fft2(M, axes=(0, 1), overwrite_x=True)
        G /= cp.sqrt(np.prod(G.shape[:2]))

        self.out.append_stdout(f"FFT along scan coordinate took {time.perf_counter() - start:2.2g}s\n")
        return G

    def _run_ssb(self, eps=1e-3):
        """Run the SSB reconstruction for the current C and rotation.

        The Fourier-space results go into the Psi_Qp* buffers, which are zeroed
        first (the routine presumably accumulates into them; TODO confirm). Their
        orthonormal inverse FFTs go into the Psi_Rp* buffers.
        """
        self.Psi_Qp[:] = 0
        self.Psi_Qp_left_sb[:] = 0
        self.Psi_Qp_right_sb[:] = 0
        single_sideband_reconstruction(
            self.G,
            self.Qx1d,
            self.Qy1d,
            self.Kx,
            self.Ky,
            self.C,
            np.deg2rad(self.rotation_deg),
            self.meta.alpha_rad,
            self.Psi_Qp,
            self.Psi_Qp_left_sb,
            self.Psi_Qp_right_sb,
            eps,
            self.meta.wavelength,
        )
        self.Psi_Rp[:] = fft.ifft2(self.Psi_Qp, norm="ortho")
        self.Psi_Rp_left_sb[:] = fft.ifft2(self.Psi_Qp_left_sb, norm="ortho")
        self.Psi_Rp_right_sb[:] = fft.ifft2(self.Psi_Qp_right_sb, norm="ortho")

    def _update_probe(self, phase_aperture=None):
        """Recompute probe previews (psi, Psi_shifted, phases) for the current C.

        ``phases`` is the angle of the probe generated with ``phase_aperture``.
        When ``phase_aperture`` is None, it is taken from the aperture-limited
        probe itself.
        """
        C_t = th.tensor(self.C.get()).cuda()
        Psi = self.probe_gen(C_t, self._A_t)
        self.Psi_shifted = th.fft.fftshift(Psi)
        if phase_aperture is None:
            self.phases = th.angle(self.Psi_shifted)
        else:
            self.phases = th.angle(th.fft.fftshift(self.probe_gen(C_t, phase_aperture)))
        self.psi = th.fft.fftshift(th.fft.ifft2(Psi))

    @out0.capture()
    def update_variables(self):
        """Recompute the reconstruction, the disk-overlap model and the probe for the current C/rotation."""
        self._run_ssb()
        self.Gamma = disk_overlap_function(self.Qx_max1d, self.Qy_max1d, self.Kx, self.Ky, self.C,
                                           np.deg2rad(self.rotation_deg), self.meta.alpha_rad, self.meta.wavelength)
        # NOTE(review): __init__ draws the phase with a unit aperture, while updates
        # use the aperture-limited probe (zero phase outside). Confirm which is intended.
        self._update_probe()

    @out0.capture()
    def update_gui(self):
        """Push the reconstruction, disk overlaps and probe into their figures and redraw."""
        overlaps = self.Gamma * self.G_max

        # Crop a 10-pixel border unless the window is too small for it; an empty
        # image would make min()/max() raise.
        border = 10
        phase = np.angle(self.Psi_Rp_left_sb.get())
        if min(phase.shape) > 2 * border:
            phase = phase[border:-border, border:-border]
        self.recon_img.set_data(phase)
        self.recon_img.set_clim(phase.min(), phase.max())

        for axes_image, overlap in zip(self.overlap_figure_axes, overlaps):
            axes_image.set_data(imsave(overlap.get()))

        phases = self.phases.cpu().numpy()
        self.probe_realspace_imax.set_data(imsave(self.psi.cpu().numpy()))
        self.probe_fourier_imax.set_data(imsave(self.Psi_shifted.cpu().numpy()))
        self.probe_phases_imax.set_data(phases)
        self.probe_phases_imax.set_clim(phases.min(), phases.max())

        figures = (self.recon_fig, self.overlap_figure, self.probe_figure)
        for fig in figures:
            fig.canvas.draw()
        for fig in figures:
            fig.canvas.flush_events()

    @out0.capture()
    def create_function(self, name, i, multiplier):
        """Return a slider callback that sets ``C[i] = value * multiplier`` and refreshes the GUI."""
        def func1(change):
            self.C[i] = change['new'] * multiplier
            # Before tab 1 has been opened for the current window there is nothing
            # to reconstruct yet; the new value is used on the next refresh.
            if not self._ssb_initialized:
                return
            self.update_variables()
            self.update_gui()

        func1.__name__ = name
        return func1

    @out0.capture()
    def selected_index_changed(self, change):
        """Tab callback: index 1 (re)initialises SSB for the current window; index 2 runs a defocus sweep."""
        w = change['new']
        if w == 1:
            # One strongest frequency per overlap panel (3 x 3).
            n_fit = len(self.overlap_figure_axes)
            self._ssb_initialized = False
            self.G = self._get_G_full(self.ssb_size)
            self.Gabs = cp.sum(cp.abs(self.G), (2, 3))
            self.out.append_stdout(f"self.G.shape {self.G.shape}\n")

            mask = self._power_spectrum_mask(self.Gabs.shape, 15)
            gg = self.Gabs.get()
            gg[~mask] = gg.mean()

            # The n_fit strongest remaining frequencies, skipping the very strongest.
            inds = np.argsort(gg.ravel())
            strongest_object_frequencies = np.unravel_index(inds[-1 - n_fit:-1], self.G.shape[:2])

            self.out.append_stdout(f"strongest_object_frequencies: {strongest_object_frequencies}\n")

            self.G_max = self.G[strongest_object_frequencies]
            self.Qy_max1d = self.Qy1d[strongest_object_frequencies[0]]
            self.Qx_max1d = self.Qx1d[strongest_object_frequencies[1]]
            self._ssb_initialized = True

            self.update_variables()
            self.update_gui()

        elif w == 2:
            # Only reachable if a third tab is added in show(); the sweep can also
            # be run directly via defocus_sweep().
            self.defocus_sweep()

    @out0.capture()
    def defocus_sweep(self, defocus_list_nm=None, crop=2):
        """Reconstruct over a defocus series (C1 only, other coefficients zero).

        Parameters
        ----------
        defocus_list_nm : sequence of float, optional
            Defocus values in C1-slider units. They are scaled by
            ``C_multiplier[0]`` like the slider (the name suggests nm; TODO
            confirm). Defaults to 60 values in [-300, 300].
        crop : int, optional
            Border in pixels excluded from the variance.

        Returns
        -------
        The defocus whose full reconstruction has the largest variance. Returns
        None if SSB has not been initialised for the current window or the list
        is empty. The stacked real-space results are stored in ``ssb_defocal``,
        ``ssb_defocal_right`` and ``ssb_defocal_left``. ``self.C`` is restored
        afterwards.
        """
        if not self._ssb_initialized:
            self.out.append_stdout("defocus sweep: open the reconstruction tab first\n")
            return None
        if defocus_list_nm is None:
            defocus_list_nm = np.linspace(-300, 300, 60, endpoint=True)
        defocus_list_nm = np.asarray(defocus_list_nm)
        if defocus_list_nm.size == 0:
            return None
        print(f'defocus_list_nm: {defocus_list_nm}')

        saved_C = self.C.copy()
        defocal, defocal_right, defocal_left = [], [], []
        try:
            for df in defocus_list_nm:
                print(df)
                self.C[:] = 0
                self.C[0] = df * self.C_multiplier[0]
                self._run_ssb()
                defocal.append(self.Psi_Rp.get())
                defocal_right.append(self.Psi_Rp_right_sb.get())
                defocal_left.append(self.Psi_Rp_left_sb.get())
        finally:
            # Leave the slider-controlled aberrations as they were.
            self.C[:] = saved_C

        self.ssb_defocal = np.array(defocal)
        self.ssb_defocal_right = np.array(defocal_right)
        self.ssb_defocal_left = np.array(defocal_left)

        interior = np.s_[..., crop:-crop, crop:-crop] if crop > 0 else np.s_[...]
        var1 = np.var(self.ssb_defocal[interior], (1, 2))
        var2 = np.var(self.ssb_defocal_right[interior], (1, 2))
        mxs = int(np.argmax(var1))
        best_df = defocus_list_nm[mxs]
        mxs2 = int(np.argmax(var2))
        best_df2 = defocus_list_nm[mxs2]

        self.out.append_stdout(f'best_df ssb_defocal: {mxs}:{best_df}\n')
        self.out.append_stdout(f'best_df ssb_defocal_right: {mxs2}:{best_df2}\n')
        return best_df

    def show(self):
        """Build the window/power-spectrum/rotation figures, run the rotation search
        and return the top-level GridspecLayout."""
        n_fit = 49
        ranges = [360]
        partitions = [360]
        manual_frequencies = None

        self._ssb_initialized = False
        scan_ny, scan_nx = self.data.scan_dimensions[0], self.data.scan_dimensions[1]
        bin_factor = int(np.min(np.floor(self.data.frame_dimensions / 12)))
        # Dense, binned cube of the whole scan. _get_G() cuts the selected window
        # out of it with absolute scan indices, so it is kept apart from the
        # window-sized self.dc built by _get_G_full().
        self.dc_overview = self.data.slice(np.s_[0:scan_ny, 0:scan_nx]).to_dense(bin_factor)
        self.dc = self.dc_overview
        self.G = self._get_G(12)
        self.Gabs = cp.sum(cp.abs(self.G), (2, 3))

        mask = self._power_spectrum_mask(self.Gabs.shape, 15)
        spectrum = self._masked_log_spectrum(mask)

        self.fig_select_slice, self.ax_select_slice = plt.subplots(1, 1, figsize=(8, 8))
        self.fig_select_slice.canvas.mpl_connect('button_press_event', self.fig_select_slice_onclick)
        self.imax_select_slice = self.ax_select_slice.imshow(self.slice_image, cmap=plt.get_cmap('magma'))
        # Rectangle is anchored at (x, y) = (column, row), with width along columns.
        self.slice_rect = Rectangle((self.slic[1].start, self.slic[0].start), self.window_size[1], self.window_size[0],
                                    fill=False, lw=3, ls='--')
        self.ax_select_slice.add_patch(self.slice_rect)

        self.fig_power_spectrum, self.ax_power_spectrum = plt.subplots(1, 1, figsize=(8, 8))
        self.imax_power_spectrum = self.ax_power_spectrum.imshow(spectrum, cmap=plt.get_cmap('magma'))

        best, thetas, intensities = find_rotation_angle_with_double_disk_overlap(
            self.G, self.meta.wavelength, self.r_min, self.dxy, self.meta.alpha_rad, mask=cp.array(mask),
            ranges=ranges, partitions=partitions, n_fit=n_fit, verbose=False,
            manual_frequencies=manual_frequencies, aberrations=self.C)

        self.out.append_stdout(f"Best rotation angle: {np.rad2deg(thetas[best])}\n")

        self.fig_rotation, ax = plt.subplots()
        ax.scatter(np.rad2deg(thetas), intensities)
        ax.set_xlabel('STEM rotation [degrees]')
        ax.set_ylabel('Integrated G amplitude over double overlap')

        rot_box = VBox([self.fig_rotation.canvas, self.rotation_slider])
        window_size_box = VBox([self.window_size_slider, self.fig_select_slice.canvas])
        canvas_box = HBox([window_size_box, self.fig_power_spectrum.canvas, rot_box])
        gsl1 = VBox([canvas_box])

        tab = widgets.Tab()
        tab.children = [gsl1, self.gs]
        tab.titles = ['fft', '00', '00']
        tab.observe(self.selected_index_changed, 'selected_index')

        gsl00 = GridspecLayout(20, 20)
        gsl00[3:, :3] = self.out
        gsl00[3:, 3:] = tab
        gsl00[:3, :] = out0
        return gsl00

# Build the SSB explorer from the loaded data, using the virtual BF image for window selection.
t = InteractiveSSB(sparse_data=data, slice_image=bf, radius=radius, meta=meta)
t.show()

radius_data_int : 12 
radius_max_int  : 12 
Dense frame size: 12x 12
batch 0: dtypes: torch.float32 torch.int16 torch.float32 int64
batch 1: dtypes: torch.float32 torch.int16 torch.float32 int64
batch 2: dtypes: torch.float32 torch.int16 torch.float32 int64
batch 3: dtypes: torch.float32 torch.int16 torch.float32 int64


GridspecLayout(children=(Output(layout=Layout(border='1px solid black', grid_area='widget001'), outputs=({'out…