# 3D Phase Contrast AET Script
Author: David Ren (david.ren@berkeley.edu)

## Load packages
Additional packages required: `contexttimer` (for timing) and `arrayfire` (for GPU computation).

In [1]:
#Specify code location (presumably where TEM_recon is imported from below);
#override it with the TEM_CODE_DIR environment variable
import os
import sys
TEM_CODE_DIR = os.environ.get("TEM_CODE_DIR", '/home/general/TEMcode/tomography_gpu/')
if TEM_CODE_DIR not in sys.path:  # re-running this cell must not duplicate the entry
    sys.path.append(TEM_CODE_DIR)

In [2]:
%reload_ext autoreload
%autoreload 2
import numpy as np
import TEM_recon
import scipy.io as sio
import scipy.linalg as la
import os
import sys
# Print arrays without truncation. Recent NumPy rejects threshold=np.nan with a
# ValueError; sys.maxsize is the supported way to disable summarization.
np.set_printoptions(threshold=sys.maxsize)
import contexttimer
print("hostname:", os.uname()[1])
import arrayfire as af
af.set_device(0)  # run ArrayFire computations on device 0

hostname: f00f9321f438


In [3]:
# Specify datapath: data_path is relative to the notebook's working directory
data_path = "/".join((os.getcwd(), "../data/"))
indir = "/home/general/TEMdata/"   # loadTEMData reads the preprocessed datasets from here
outdir = "/home/general/TEMrec/"   # saveResults writes reconstructions here

In [4]:
#Experiments to be run: keys of exp_all, executed in list order by the main loop.
#Entry -1 is configured with maxitr=1 (a single iteration).
exps = [
    -1,
]

In [5]:
#All experiments available
#
# Every experiment starts from _EXP_DEFAULTS and overrides only what differs.
# dose_per_pixel is written as total_dose / n_tilts_used / n_defocus_used / 4.,
# i.e. the total dose is split evenly over all measurements used (the trailing
# /4. is presumably a pixels-per-A^2 conversion -- TODO confirm).
_EXP_DEFAULTS = {
    "dset": "sim_sio2_full",
    "flag_save_output": True,
    "fista_L": 1e-4,
    "random_order": True,
    "maxitr": 40,
    "maxitr_per_angle": 1,
    "update_mode": "FISTA",
    "flag_reg": True,
}


def _makeExp(**overrides):
    """Return a new experiment dict: a copy of _EXP_DEFAULTS updated with `overrides`.

    Overridden default keys keep their position and new keys are appended in
    argument order, so the options are printed in a stable, readable order.
    """
    exp = dict(_EXP_DEFAULTS)
    exp.update(overrides)
    return exp


exp_all = {
    #Quick run: the 5106 settings with a single iteration
    -1: _makeExp(maxitr=1,
                 reg_type=["positivity_and_real", "tv"],
                 reg_params=dict(reg_tv=1.0, max_iter_tv=15),
                 flag_add_noise=True,
                 dose_per_pixel=50000./60./3./4.,
                 slice_binning_factor=10),
    #Results - DOSE
    #Figure 5(a) -- infinite dose (no noise added)
    5004: _makeExp(reg_type=["positivity_and_real", "tv"],
                   reg_params=dict(reg_tv=1.0, max_iter_tv=15),
                   slice_binning_factor=10),
    #Figures 5(b), 6(b), 7(a), 8(c), 9(b) -- 50 000 e/A^2
    5106: _makeExp(reg_type=["positivity_and_real", "tv"],
                   reg_params=dict(reg_tv=1.0, max_iter_tv=15),
                   flag_add_noise=True,
                   dose_per_pixel=50000./60./3./4.,
                   slice_binning_factor=10),
    #Figure 5(c) -- 7 000 e/A^2
    5101: _makeExp(reg_type=["positivity_and_real", "tv"],
                   reg_params=dict(reg_tv=2.5, max_iter_tv=15),
                   flag_add_noise=True,
                   dose_per_pixel=7000./60./3./4.,
                   slice_binning_factor=10),
    #TILT ANGLES vs. DEFOCUS
    #Figure 6(a) -- 20 tilt angles, 9 defocus
    5200: _makeExp(dset="sim_sio2_more_measure",
                   fista_L=3e-4,
                   reg_type=["positivity_and_real", "tv"],
                   reg_params=dict(reg_tv=0.5, max_iter_tv=15),
                   flag_add_noise=True,
                   dose_per_pixel=50000./20./9./4.,
                   slice_binning_factor=10,
                   rotation_use=np.array(range(0, 180, 9))),
    #Figure 6(c) -- 180 tilt angles, 1 defocus
    5202: _makeExp(dset="sim_sio2_more_measure",
                   fista_L=3e-5,
                   reg_type=["positivity_and_real", "tv"],
                   reg_params=dict(reg_tv=3.0, max_iter_tv=15),
                   flag_add_noise=True,
                   dose_per_pixel=50000./180./1./4.,
                   slice_binning_factor=10,
                   defocus_use=np.array([2])),
    #MISSING WEDGE
    #Figure 7(b) -- -75 to +75, 30 degrees of missing wedge
    #(tilt indices 5..54 of the -90..+87 degree, 3-degree-step series)
    5301: _makeExp(reg_type=["positivity_and_real", "tv"],
                   reg_params=dict(reg_tv=1.0, max_iter_tv=15),
                   flag_add_noise=True,
                   dose_per_pixel=50000./50./3./4.,
                   slice_binning_factor=10,
                   rotation_use=np.arange(5, 55)),
    #Figure 7(c) -- -60 to +60, 60 degrees of missing wedge
    #(tilt indices 10..49 of the -90..+87 degree, 3-degree-step series)
    5303: _makeExp(reg_type=["positivity_and_real", "tv"],
                   reg_params=dict(reg_tv=0.5, max_iter_tv=15),
                   flag_add_noise=True,
                   dose_per_pixel=50000./40./3./4.,
                   slice_binning_factor=10,
                   rotation_use=np.arange(10, 50)),
    #REGULARIZATION
    #Figure 8(a) -- positivity and real constraints only
    5110: _makeExp(maxitr=60,
                   reg_type=["positivity_and_real"],
                   flag_add_noise=True,
                   dose_per_pixel=50000./60./3./4.,
                   slice_binning_factor=10),
    #Figure 8(b) -- Lasso regularization
    5109: _makeExp(reg_type=["lasso"],
                   reg_params=dict(reg_lasso=3.0),
                   flag_add_noise=True,
                   dose_per_pixel=50000./60./3./4.,
                   slice_binning_factor=10),
    #Figure 9(c) -- SiO2 with vacancy (sim_sio2_full_vacant), TV regularization
    5500: _makeExp(dset="sim_sio2_full_vacant",
                   reg_type=["positivity_and_real", "tv"],
                   reg_params=dict(reg_tv=1.0, max_iter_tv=15),
                   flag_add_noise=True,
                   dose_per_pixel=50000./60./3./4.,
                   slice_binning_factor=10),
    #APPENDIX -- slice binning (no noise, positivity and real constraints only)
    #Slice binning, no binning (slice_binning_factor defaults to 1 in constructOpts)
    5400: _makeExp(maxitr=80, reg_type=["positivity_and_real"]),
    #Slice binning, 4x
    5401: _makeExp(maxitr=80, reg_type=["positivity_and_real"], slice_binning_factor=4),
    #Slice binning, 8x
    5402: _makeExp(maxitr=80, reg_type=["positivity_and_real"], slice_binning_factor=8),
    #Slice binning, 16x
    5403: _makeExp(maxitr=80, reg_type=["positivity_and_real"], slice_binning_factor=16),
    #Slice binning, 32x
    5404: _makeExp(maxitr=80, reg_type=["positivity_and_real"], slice_binning_factor=32),
    #Slice binning, 64x
    5405: _makeExp(maxitr=80, reg_type=["positivity_and_real"], slice_binning_factor=64),
}

In [6]:
# All datasets available (preprocessed .mat files inside indir),
# keyed by the "dset" name that experiments refer to
datasets = dict(
    sim_sio2_full="TEM_simulation_480_SiO2_py.mat",
    sim_sio2_more_measure="TEM_simulation_480_SiO2_full_py.mat",
    sim_sio2_full_vacant="TEM_simulation_480_SiO2_vacancy_py.mat",
)

In [7]:
def constructOpts(exp_i, exp_table=None):
    """
    Build the complete options dictionary for one TEM reconstruction experiment.

    Each option below is taken from the experiment's entry in `exp_table` when
    present, otherwise the default shown here is used. An unknown `exp_i`
    therefore yields all-default options. The experiment's own entries (plus
    "exp_i") are printed first.

    Parameters
    ----------
    exp_i : int
        Experiment ID (key of `exp_table`); stored in the result as "exp_i".
    exp_table : dict, optional
        Mapping experiment ID -> option overrides. Defaults to the notebook's
        `exp_all`. The entry is copied, never modified.

    Returns
    -------
    dict
        Full options dict (passed on to TEM_recon and saveResults). "reg_type"
        is always a 1-D numpy array, so a single string becomes a 1-element array.
    """
    if exp_table is None:
        exp_table = exp_all
    # Copy so that adding "exp_i" does not leak back into the shared experiment table.
    opt_args = dict(exp_table.get(exp_i, {}))
    opt_args["exp_i"] = exp_i
    for key, value in opt_args.items():
        print(key, ':', value)

    return {
        #Experiment params
        "exp_i" :                   opt_args.get("exp_i",                   0),
        "flag_save_output" :        opt_args.get("flag_save_output",        False),

        #Dataset
        "dset" :                    opt_args.get("dset",                    "sim_sio2_full"),
        "flag_add_noise" :          opt_args.get("flag_add_noise",          False),
        "dose_per_pixel" :          opt_args.get("dose_per_pixel",          200),
        "defocus_use":              opt_args.get("defocus_use",             "all defocus"),
        "rotation_use":             opt_args.get("rotation_use",            "all rotation"),

        #Reconstruction params
        "maxitr" :                  opt_args.get("maxitr",                  10),
        "maxitr_per_angle" :        opt_args.get("maxitr_per_angle",        1),
        "update_mode":              opt_args.get("update_mode",             "FISTA"),
        "random_order":             opt_args.get("random_order",            False),
        "gradient_batch_size":      opt_args.get("gradient_batch_size",     1),
        # 1/fista_L is the step size
        "fista_L":                  opt_args.get("fista_L",                 None),

        #rotation
        "flag_rotation_pad":        opt_args.get("flag_rotation_pad",       True),

        #multislice propagation binning
        "slice_binning_factor":     opt_args.get("slice_binning_factor",    1),

        #Regularization params
        "flag_reg" :                opt_args.get("flag_reg",                False),
        "reg_type" :                np.array(opt_args.get("reg_type",       [])).ravel(),
        "reg_params":               opt_args.get("reg_params",              {}),
    }

In [8]:
def _selectsSubset(selection, keep_all):
    """Return True if `selection` is an index (array/list/int/slice), not the keep-all sentinel or None."""
    # Compare with the sentinel only when `selection` is a string: for an index
    # array, `selection != "all rotation"` is evaluated elementwise by recent
    # NumPy, and `if <multi-element array>` raises "truth value ... is ambiguous".
    if selection is None:
        return False
    return not (isinstance(selection, str) and selection == keep_all)


def selectMeasurements(data, defocus_use = "all defocus", rotation_use = "all rotation"):
    """
    Select a subset of tilt angles and/or defocus values from the measurements.

    Parameters
    ----------
    data : dict
        Must contain "intensity_measure" (4-D, indexed here as
        [:, :, defocus, rotation]), "tilt_angles" and "defocus_stack".
        Modified in place and returned.
    defocus_use : index array / list / int, or "all defocus"
        Indices into the defocus axis and into data["defocus_stack"].
        "all defocus" (or None) keeps every defocus.
    rotation_use : index array / list / int, or "all rotation"
        Indices into the rotation axis and into data["tilt_angles"].
        "all rotation" (or None) keeps every tilt.

    Returns
    -------
    dict
        `data` with the selected measurements. When any selection is made,
        singleton axes of "intensity_measure" are squeezed out.
    """
    intensity = data["intensity_measure"]
    subset_taken = False
    if _selectsSubset(rotation_use, "all rotation"):
        intensity = intensity[:, :, :, rotation_use]
        data["tilt_angles"] = data["tilt_angles"].ravel()[rotation_use]
        subset_taken = True
    if _selectsSubset(defocus_use, "all defocus"):
        # `...` keeps axis 2 addressable even if a scalar rotation index removed axis 3.
        intensity = intensity[:, :, defocus_use, ...]
        data["defocus_stack"] = data["defocus_stack"].ravel()[defocus_use]
        subset_taken = True
    if subset_taken:
        # Squeeze once, after both selections, so dropping a singleton axis can
        # never shift the axis the defocus index refers to.
        data["intensity_measure"] = np.squeeze(intensity)
    print("Tilts used (Degrees):", data["tilt_angles"])
    print("Defocus used (Angstrom):", data["defocus_stack"])
    print("Data shape:", data["intensity_measure"].shape)
    return data
        
def loadTEMData(indir, opts, dataset_table=None):
    """
    Load the preprocessed .mat file for the dataset named by opts["dset"].

    Parameters
    ----------
    indir : str
        Directory holding the preprocessed dataset files (a trailing "/" is optional).
    opts : dict
        Options dict; only opts["dset"] is read.
    dataset_table : dict, optional
        Mapping dataset name -> file name inside `indir`. Defaults to the
        notebook's `datasets`.

    Returns
    -------
    dict or None
        The dict returned by scipy.io.loadmat. Returns None, after printing the
        reason, when the dataset name is unknown or the file cannot be read, so
        the caller can skip the experiment.
    """
    if dataset_table is None:
        dataset_table = datasets
    fname = dataset_table.get(opts["dset"])
    if fname is None:
        print("Unknown dataset:", opts["dset"])
        return None
    fn = os.path.join(indir, fname)
    print(fn)
    try:
        data = sio.loadmat(fn)
    except Exception as err:
        # Best effort: report and skip, but let KeyboardInterrupt/SystemExit propagate.
        print("Data not found!", "(%s: %s)" % (type(err).__name__, err))
        return None

    return data

def saveResults(outdir, TEM_rec, opts):
    """
    Given TEM object, and the opts dictionary,
    the reconstruction result is saved to the ouput directory
    """
    if opts["fista_L"] == None:
        opts["fista_L"] = "line_search"

    rec = {
        "obj_final": TEM_rec.current_rec,
        "obj_init": TEM_rec.obj_init,
        "cost": TEM_rec.cost,
        "opts": opts,
    }
    if hasattr(TEM_rec, 'rec_field'):
        rec["rec_field"] = TEM_rec.rec_field
    
    fn = "%sexp%d_%s_%s_%s%d" % (outdir, opts["exp_i"], opts["dset"], opts["update_mode"], \
                                "batch", opts["gradient_batch_size"])
    
    # Filename
    if opts["flag_reg"]:
        for reg_type in opts["reg_type"]:
            fn = "%s_%s" % (fn, reg_type)
        
    if opts["random_order"]:
        fn = "%s_%s" % (fn, "random_order")
        
    if opts["flag_add_noise"]:
        fn = "%s_%s%d" % (fn, "noise", opts["dose_per_pixel"])
        
    data = sio.savemat(fn, rec)

In [9]:
# Main
# Run each experiment listed in `exps`
for exp_i in exps:

    #Create options structure
    opts = constructOpts(exp_i)

    #Load preprocessed data; loadTEMData returns None when the data cannot be loaded
    data = loadTEMData(indir, opts)
    if data is None:
        continue

    # Measured intensity: |field|^2 when complex fields are stored, otherwise amplitude^2
    if "field_measure" in data:
        intensity_measure = np.abs(data["field_measure"]) ** 2
    else:
        intensity_measure = data["amplitude_measure"] ** 2

    #Add noise if specified: draw Poisson counts with mean intensity * dose_per_pixel,
    #then divide by the dose to return to the noise-free intensity scale.
    # NOTE(review): no np.random seed is set in the cells shown, so noisy runs are not reproducible.
    if opts["flag_add_noise"]:
        dose = opts["dose_per_pixel"]
        intensity_measure = np.random.poisson(intensity_measure * dose).astype(float)
        intensity_measure /= dose
    data["intensity_measure"] = intensity_measure
    data = selectMeasurements(data, opts["defocus_use"], opts["rotation_use"])

    TEM_obj = TEM_recon.TEM_recon_gpu(data, opts)

    TEM_obj.run()

    #Save reconstruction
    if opts["flag_save_output"]:
        saveResults(outdir, TEM_obj, opts)

dset : sim_sio2_full
flag_save_output : True
fista_L : 0.0001
random_order : True
maxitr : 1
maxitr_per_angle : 1
update_mode : FISTA
flag_reg : True
reg_type : ['positivity_and_real', 'tv']
reg_params : {'reg_tv': 1.0, 'max_iter_tv': 15}
flag_add_noise : True
dose_per_pixel : 69.44444444444444
slice_binning_factor : 10
exp_i : -1
/home/general/TEMdata/TEM_simulation_480_SiO2_py.mat
Tilts used (Degrees): [[-90 -87 -84 -81 -78 -75 -72 -69 -66 -63 -60 -57 -54 -51 -48 -45 -42 -39
  -36 -33 -30 -27 -24 -21 -18 -15 -12  -9  -6  -3   0   3   6   9  12  15
   18  21  24  27  30  33  36  39  42  45  48  51  54  57  60  63  66  69
   72  75  78  81  84  87]]
Defocus used (Angstrom): [[ 200  450 1000]]
Data shape: (360, 360, 3, 60)
pad is: True
Regularizer - Pure real
Regularizer - Total Variation+positivity_real
---- Start of the MultiPhaseContrast algorithm ----


KeyboardInterrupt: 