# TEM Generate Measurements needed for reconstruction
Author: David Ren (david.ren@berkeley.edu)

7/16/2018

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt
import scipy.io as sio
import os
import contexttimer
import TEM_misc
import sys
import h5py
sys.path.append("%s/%s"   % (os.getcwd(), "tomography_gpu/"))
from opticaltomography.opticsutil import compare3DStack, show3DStack
from opticaltomography.opticsalg import PhaseObject3D, TomographySolver, AlgorithmConfigs

### Specify Parameters

#### Datapath

In [None]:
# Input/output locations, both expressed relative to the current working
# directory. The measurement folder is a sub-folder of the data folder.
data_path = "/".join((os.getcwd(), "../data/"))
outdir = "/".join((data_path, "measurement/"))

# Base name (no extension) of the .mat file holding the potentials to load
fn_load = "TEM_simulation_480_SiO2"

# Base name (no extension) of the output file; it is only written when
# flag_save is switched on.
fn = "TEM_simulation_480_SiO2" + "_py"
flag_save = False

#### Dataset parameters
Everything is measured in Angstroms

In [None]:
# Same settings as the paper
slice_binning_factor = 1             # forwarded to the solver unchanged
flag_plot = True                     # show diagnostic figures of the potentials
flag_rotation_pad = True             # forwarded to the solver as `rotation_pad`
flag_angle_defocus_override = False  # replace loaded tilt/defocus lists when True

# Padding is only applied when the field of view is cropped; otherwise the
# pad size collapses to zero in both lateral directions.
flag_crop_fov = True
pad_size = (60, 60) if flag_crop_fov else (0, 0)

### Load Data

In [None]:
# Load the simulated potentials and acquisition parameters from the .mat file.
# scipy is tried first; files it cannot parse are read through h5py instead.
print("Loading data...")
mat_path = outdir + fn_load + ".mat"
try:
    print("Using sio")
    data = sio.loadmat(mat_path)
except (NotImplementedError, ValueError):
    # loadmat raises NotImplementedError for MATLAB v7.3 (HDF5) files and
    # ValueError for headers it does not recognise. Other failures, such as a
    # missing file, are left to propagate with their original message.
    print("Using h5py")
    with h5py.File(mat_path, "r") as data_file:
        # Materialise every dataset before the file is closed
        data = {k: np.array(v) for k, v in data_file.items()}
    # Axis order is reversed only on this path -- presumably to undo the
    # reversed dimension order of MATLAB arrays stored in HDF5 (TODO confirm).
    data["pot_stack"] = data["pot_stack"].transpose(2, 1, 0)

# Scalars arrive as single-element arrays; .item() unwraps them to Python
# numbers (np.asscalar did the same but no longer exists in current NumPy).
dz = data["dz"].item()
na = data["na"].item()
# na = 1
sigma = data["sigma"].item()
wavelength = data["wavelength"].item()
pixel_size = data["pixel_size"].item()

# Flatten the angle and defocus lists to 1-D
tilt_angles = data["tilt_angles"].ravel()
defocus_stack = data["defocus_stack"].ravel()
print("Done loading data...")
print(data.keys())

### Override tilt angles and defocus stack

In [None]:
# Optionally discard the tilt angles and defocus values read from file and
# use fixed ranges instead, echoing both arrays for reference.
if flag_angle_defocus_override:
    tilt_angles = np.arange(-90, 90)
    defocus_stack = np.arange(200, 1100, 100)
    for overridden in (tilt_angles, defocus_stack):
        print(overridden)

### Generate potentials
This part generates the coordinates of the atoms if necessary
Only works for a single atom

In [None]:
# Pull the 3-D potential out of the loaded data
pot_stack = data["pot_stack"]

if flag_plot:
    # Show the potential summed along its third axis as a 2-D image
    plt.figure(figsize=(5, 4))
    projection_image = plt.imshow(np.sum(pot_stack, axis=2))
    plt.colorbar(projection_image)

### Generate measurements

#### Generate Low pass filtered version (Ground truth)

In [None]:
# Only filter when na is below 1; otherwise pot_stack_gt is left undefined.
if na < 1:
    # Ground-truth potential after applying the 3-D pupil for this na
    pot_stack_gt = TEM_misc.apply3DPupil(pot_stack, pixel_size, na, wavelength)
    if flag_plot:
        # Show the filtered potential summed along its third axis
        plt.figure(figsize=(5, 4))
        projection_gt_image = plt.imshow(np.sum(pot_stack_gt, axis=2))
        plt.colorbar(projection_gt_image)

#### Apply forward model

In [None]:
# Number of tilt angles and defocus planes in this data set
number_angles = tilt_angles.size
number_defocus = defocus_stack.size

# Voxel dimensions: pixel_size for the first two axes, dz for the third
voxel_size = (pixel_size, pixel_size, dz)
phase_obj_3d = PhaseObject3D(shape=pot_stack.shape, voxel_size=voxel_size)

# A single illumination entry at zero in both fx and fy
fx_illu_list = [0]
fy_illu_list = [0]

# Keyword arguments handed to the solver; padding is enabled together with
# the field-of-view cropping flag.
solver_params = {
    "wavelength": wavelength,
    "na": na,
    "propagation_distance_list": defocus_stack,
    "rotation_angle_list": tilt_angles,
    "RI_measure": 1.0,
    "sigma": sigma * pixel_size,
    "fx_illu_list": fx_illu_list,
    "fy_illu_list": fy_illu_list,
    "pad": flag_crop_fov,
    "pad_size": pad_size,
    "slice_binning_factor": slice_binning_factor,
    "rotation_pad": flag_rotation_pad,
}

solver_obj = TomographySolver(phase_obj_3d, **solver_params)
solver_obj.setScatteringMethod(model="MultiPhaseContrast")
# Load the known potential straight into the solver's object estimate
solver_obj._x = pot_stack.astype("complex64")

In [None]:
# Run the forward model on the loaded potential and report the time it took.
with contexttimer.Timer() as timer:
    # Singleton dimensions of the solver output are dropped before storing it
    amplitude_measure = np.squeeze(solver_obj.forwardPredict())
    # Printed inside the context, i.e. before the timer block has exited
    print(timer.elapsed)

#### Check final measurement shapes

In [None]:
# Display the shape of the measurement magnitude as the cell output. The
# four-slice index fails unless amplitude_measure has at least four dimensions.
np.squeeze(np.abs(amplitude_measure[:,:,:,:])).shape

#### Visualize measurements

In [None]:
# Compare the measurement magnitude at index 0 and index 2 of the third axis
# side by side. NOTE(review): which quantity the third axis enumerates
# (defocus vs. tilt) is not visible here -- confirm against forwardPredict.
# clim is presumably the display colour range -- TODO confirm.
compare3DStack(np.squeeze(np.abs(amplitude_measure[:,:,0,:])), np.squeeze(np.abs(amplitude_measure[:,:,2,:])), clim=(0.5,1.5))

#### Save datasets

In [None]:
# Write the ground truth, the simulated measurements and every parameter
# needed to interpret them into a single .mat file (only when flag_save is on).
if flag_save:
    # NOTE: this rebinds `data`, replacing the dictionary loaded earlier.
    data = {}
    # Ground truth
    data["pot_stack"] = pot_stack
    if na < 1:
        # The low-pass filtered potential only exists when na is below 1
        data["pot_stack_gt"] = pot_stack_gt

    # Measurement
    # Store the slice spacing that was actually used in voxel_size; writing
    # pixel_size here would misreport it whenever the two values differ.
    data["dz"] = dz
    data["pixel_size"] = pixel_size
    data["wavelength"] = wavelength
    data["sigma"] = sigma
    data["amplitude_measure"] = amplitude_measure
    data["tilt_angles"] = tilt_angles
    data["defocus_stack"] = defocus_stack
    data["na"] = na
    data["obj_shape"] = pot_stack.shape
    sio.savemat(outdir+fn+".mat", data)