In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pathlib
import pyprismatic
import py4DSTEM
import random
from matplotlib import cm
from matplotlib.lines import Line2D
from mpl_toolkits.axes_grid1 import make_axes_locatable

## Simulations

In [None]:
# Input: stem of the per-temperature WS2 moire structure files
# (the simulation cells append "_T<j>.xyz").
fp_ws2 = "structure_files/WS2_moire_7_3_degree"

# Output: root folder holding one sub-folder per temperature index and
# potential type (T<j>_2D/, T<j>_3D/).
base_folder = "outputs/temp_series/"

# File-name stems for each frozen-phonon output; the simulation cells append a
# zero-padded frozen-phonon index and a "_T<j>.h5" suffix.
hrtem_2D_base, hrtem_3D_base = "hrtem_2D_fp", "hrtem_3D_fp"
stem_2D_base, stem_3D_base = "stem_2D_fp", "stem_3D_fp"

In [None]:
# Create one output folder per temperature index and potential type,
# e.g. <base_folder>T0_2D/ and <base_folder>T0_3D/.
# (The loop variable must be `j`: the folder name is formatted with it.)
for j in range(5):
    for dim in ("2D", "3D"):
        ofolder = base_folder + "T%i" % j + "_" + dim + "/"
        pathlib.Path(ofolder).mkdir(parents=True, exist_ok=True)

### HRTEM 2D

In [None]:
# Shared HRTEM simulation settings. The HRTEM sweep cells below set the
# structure file, output path and random seed for each individual run.
pot_hrtem_meta = pyprismatic.Metadata()
pot_hrtem_meta.algorithm = "hrtem"

# Optics
pot_hrtem_meta.E0 = 80
pot_hrtem_meta.probeDefocus = 33.0  # white atom peak

# Potential sampling (2D potential; the 3D cell further down switches this)
pot_hrtem_meta.realspacePixelSize = 0.075
pot_hrtem_meta.potential3D = False

# Thermal effects: one frozen-phonon configuration per output file
pot_hrtem_meta.includeThermalEffects = True
pot_hrtem_meta.numFP = 1

In [None]:
# Run 128 single-frozen-phonon HRTEM simulations (2D potential) for each of the
# five temperature structures T0..T4; skipped once the final output exists.
# NOTE: pot_hrtem_meta is shared with the HRTEM 3D cells below, which set
# potential3D = True -- re-running this cell after them would use a 3D potential.
if not pathlib.Path(base_folder+"T4_2D/hrtem_2D_fp0127_T4.h5").exists():
    for j in range(5):
        print("On T%i " % j)
        pot_hrtem_meta.filenameAtoms = "%s_T%i.xyz" % (fp_ws2, j)
        ofolder = "%sT%i_2D/" % (base_folder, j)
        for i in range(128):
            fp_str = "%s%04i_T%i.h5" % (hrtem_2D_base, i, j)
            pot_hrtem_meta.filenameOutput = ofolder + fp_str
            pot_hrtem_meta.randomSeed = i * 1000  # distinct seed per frozen phonon
            pot_hrtem_meta.go()

### HRTEM 3D

In [None]:
# Reconfigure the shared HRTEM settings for the 3D-potential sweep:
# zSampling = 40 and a 3D potential (these values persist on pot_hrtem_meta).
pot_hrtem_meta.zSampling = 40
pot_hrtem_meta.potential3D = True

In [None]:
# Same 5 x 128 HRTEM sweep with the 3D-potential settings from the cell above;
# outputs go to the T<j>_3D/ folders. Skipped once the final output exists.
if not pathlib.Path(base_folder+"T4_3D/hrtem_3D_fp0127_T4.h5").exists():
    for j in range(5):
        print("On T%i " % j)
        pot_hrtem_meta.filenameAtoms = "%s_T%i.xyz" % (fp_ws2, j)
        ofolder = "%sT%i_3D/" % (base_folder, j)
        for i in range(128):
            fp_str = "%s%04i_T%i.h5" % (hrtem_3D_base, i, j)
            pot_hrtem_meta.filenameOutput = ofolder + fp_str
            pot_hrtem_meta.randomSeed = i * 1000  # distinct seed per frozen phonon
            pot_hrtem_meta.go()

### STEM 2D

In [None]:
# Shared STEM simulation settings. The STEM sweep cells below set the
# structure file, output path and random seed for each individual run.
pot_stem_meta = pyprismatic.Metadata()

# Algorithm "p" with interpolation factor 2 (presumably PRISM -- confirm
# against the pyprismatic documentation).
pot_stem_meta.algorithm = "p"
pot_stem_meta.interpolationFactor = 2

# Probe and scan
pot_stem_meta.E0 = 80
pot_stem_meta.probeSemiangle = 21
pot_stem_meta.probeStep = 0.25

# Potential sampling (2D potential; the 3D cell further down switches this)
pot_stem_meta.realspacePixelSize = 0.075
pot_stem_meta.potential3D = False

# Thermal effects: one frozen-phonon configuration per output file
pot_stem_meta.includeThermalEffects = True
pot_stem_meta.numFP = 1

# Keep the 3D output; the analysis cells slice detector bins out of axis 2.
pot_stem_meta.save3DOutput = True

In [None]:
# Run 128 single-frozen-phonon STEM simulations (2D potential) for each of the
# five temperature structures T0..T4; skipped once the final output exists.
# NOTE: pot_stem_meta is shared with the STEM 3D cells below, which set
# potential3D = True -- re-running this cell after them would use a 3D potential.
if not pathlib.Path(base_folder+"T4_2D/stem_2D_fp0127_T4.h5").exists():
    for j in range(5):
        print("On T%i " % j)
        pot_stem_meta.filenameAtoms = "%s_T%i.xyz" % (fp_ws2, j)
        ofolder = "%sT%i_2D/" % (base_folder, j)
        for i in range(128):
            fp_str = "%s%04i_T%i.h5" % (stem_2D_base, i, j)
            pot_stem_meta.filenameOutput = ofolder + fp_str
            pot_stem_meta.randomSeed = i * 1000  # distinct seed per frozen phonon
            pot_stem_meta.go()

### STEM 3D

In [None]:
# Reconfigure the shared STEM settings for the 3D-potential sweep:
# zSampling = 40 and a 3D potential (these values persist on pot_stem_meta).
pot_stem_meta.zSampling = 40
pot_stem_meta.potential3D = True

In [None]:
# Same 5 x 128 STEM sweep with the 3D-potential settings from the cell above;
# outputs go to the T<j>_3D/ folders. Skipped once the final output exists.
if not pathlib.Path(base_folder+"T4_3D/stem_3D_fp0127_T4.h5").exists():
    for j in range(5):
        print("On T%i " % j)
        pot_stem_meta.filenameAtoms = "%s_T%i.xyz" % (fp_ws2, j)
        ofolder = "%sT%i_3D/" % (base_folder, j)
        for i in range(128):
            fp_str = "%s%04i_T%i.h5" % (stem_3D_base, i, j)
            pot_stem_meta.filenameOutput = ofolder + fp_str
            pot_stem_meta.randomSeed = i * 1000  # distinct seed per frozen phonon
            pot_stem_meta.go(display_run_time=False, save_run_time=False)

## Plotting

In [None]:
def get_splits(k, N, shuffle=False):
    """Partition ``range(N)`` into ``k`` disjoint groups of ``N // k`` indices.

    When ``N`` is not a multiple of ``k`` the last ``N % k`` indices (after the
    optional shuffle) are left out.

    Args:
        k: number of groups.
        N: number of indices to partition.
        shuffle: if True, permute the indices with ``random.shuffle`` (Python's
            global RNG) before grouping.

    Returns:
        List of ``k`` lists of ints.
    """
    order = list(range(N))
    if shuffle:
        random.shuffle(order)
    # Guard k == 0 so the (empty) result matches a plain loop over range(k).
    group_size = N // k if k else 0
    return [order[g * group_size:(g + 1) * group_size] for g in range(k)]

def get_average_from_split(data, split):
    """Average ``data`` over the samples selected by ``split`` along the last axis.

    Args:
        data: array whose last axis indexes independent samples (in this
            notebook, frozen-phonon runs). Any number of leading axes is
            allowed; they are kept as-is.
        split: indices into the last axis. They are applied through a boolean
            mask, so duplicate indices count only once.

    Returns:
        The mean of the selected samples: an array shaped like ``data`` with
        the last axis removed.
    """
    test_idx = np.zeros(data.shape[-1], dtype=bool)
    test_idx[split] = True
    # Ellipsis keeps this rank-agnostic (2D images, 3D image stacks, 4D STEM data);
    # the previous data[:, :, mask] only worked for 3-D input.
    return np.mean(data[..., test_idx], axis=-1)

def get_std_devs(data, k_list, shuffle=True, seed=None):
    """Cross-validation error of split averages for each split count in ``k_list``.

    For every ``k``:

    * The ``N = data.shape[-1]`` samples on the last axis are partitioned into
      ``k`` disjoint groups of ``N // k`` samples with ``get_splits``.
    * Each group is averaged with ``get_average_from_split``.
    * This is repeated ``max(1, N // (2 * k))`` times; the partition is
      re-drawn each time when ``shuffle`` is True.
    * The per-pixel standard deviation across all group averages is then
      averaged over the remaining axes.

    For ``N == 128`` and ``k`` a power of two up to 64 this yields 64 group
    averages per ``k``.

    Args:
        data: array whose last axis indexes samples (frozen-phonon runs here).
        k_list: iterable of split counts, each with ``0 < k <= data.shape[-1]``.
        shuffle: passed to ``get_splits``; when True the partition is random.
        seed: optional seed for Python's global ``random`` module. It is
            applied once, before any splitting, so results are reproducible.
            ``None`` (default) leaves the global RNG untouched.

    Returns:
        List with one scalar error per entry of ``k_list``.

    Raises:
        ValueError: if some ``k`` is not in ``(0, data.shape[-1]]``.
    """
    n_samples = data.shape[-1]
    if seed is not None:
        random.seed(seed)

    errors = []
    for k in k_list:
        if not 0 < k <= n_samples:
            raise ValueError(
                "k must satisfy 0 < k <= data.shape[-1] (%d), got %r" % (n_samples, k))
        # Repeat the partition so each k contributes roughly N/2 group averages;
        # always do at least one round.
        n_repeats = max(1, n_samples // (2 * k))
        # float64 accumulator, as before, so the std is not computed in float32.
        stack = np.zeros(data.shape[:-1] + (n_repeats * k,))
        count = 0
        for _ in range(n_repeats):
            for split in get_splits(k, n_samples, shuffle=shuffle):
                stack[..., count] = get_average_from_split(data, split)
                count += 1
        errors.append(np.mean(np.std(stack, axis=-1)))

    return errors

In [None]:
stem_2D_T = []
stem_3D_T = []
hrtem_2D_T = []
hrtem_3D_T = []

# Read every frozen-phonon output back into one float32 stack per temperature
# index j (0-4), with the frozen-phonon index on the last axis.
# NOTE(review): the array shapes are hard-coded and presumably match the
# simulated output sizes -- confirm if any simulation setting changes.
# Memory: each STEM stack holds 171*99*139*128 float32 values (~1.2 GB), so the
# ten STEM stacks kept in total need ~12 GB.
for j in range(5):
    stem_2D_data = np.zeros((171,99,139,128), dtype=np.single)
    stem_3D_data = np.zeros((171,99,139,128), dtype=np.single)
    hrtem_2D_data = np.zeros((288,168,128), dtype=np.single)
    hrtem_3D_data = np.zeros((288,168,128), dtype=np.single)

    # Same folder layout the simulation cells wrote to: <base_folder>T<j>_2D/ and _3D/.
    folder_2D = base_folder + "T%i" % j + "_2D/"
    folder_3D = base_folder + "T%i" % j + "_3D/"
    for i in range(128):
        fp_suffix = "%04i" % i + "_T%i" % j + ".h5"
        stem_2D_fp = folder_2D + stem_2D_base + fp_suffix
        stem_3D_fp = folder_3D + stem_3D_base + fp_suffix
        hrtem_2D_fp = folder_2D + hrtem_2D_base + fp_suffix
        hrtem_3D_fp = folder_3D + hrtem_3D_base + fp_suffix
        stem_2D_data[:,:,:,i] = py4DSTEM.io.read(stem_2D_fp, data_id=0).data
        stem_3D_data[:,:,:,i] = py4DSTEM.io.read(stem_3D_fp, data_id=0).data
        # HRTEM reads are squeezed to drop singleton axes before storing.
        hrtem_2D_data[:,:,i] = np.squeeze(py4DSTEM.io.read(hrtem_2D_fp, data_id=0).data)
        hrtem_3D_data[:,:,i] = np.squeeze(py4DSTEM.io.read(hrtem_3D_fp, data_id=0).data)

    stem_2D_T.append(stem_2D_data)
    stem_3D_T.append(stem_3D_data)
    hrtem_2D_T.append(hrtem_2D_data)
    hrtem_3D_T.append(hrtem_3D_data)

In [None]:
# Numbers of disjoint splits of the 128 frozen phonons: 64, 32, ..., 2
# (i.e. 2, 4, ..., 64 frozen phonons per averaged image).
k_list = [2**p for p in range(6, 0, -1)]

In [None]:
# Cross-validation errors for the HRTEM stacks: one list (one value per k in
# k_list) per temperature index. The 2D and 3D calls stay interleaved per
# temperature because shuffled splitting consumes Python's global `random` state.
hrtem_2D_T_devs, hrtem_3D_T_devs = [], []
for j in range(5):
    for devs, stacks in ((hrtem_2D_T_devs, hrtem_2D_T),
                         (hrtem_3D_T_devs, hrtem_3D_T)):
        devs.append(get_std_devs(stacks[j], k_list))

In [None]:
# HAADF virtual detector: integrate STEM output bins 60-119 along axis 2
# (presumably 1 mrad per bin, matching the 60-120 mrad label in the plot titles).
haadf_vd = slice(60,120)
haadf_2D_T_devs = []
haadf_3D_T_devs = []
for j in range(5):
    # 2D before 3D for each temperature, preserving the shared `random` sequence.
    haadf_2D_data = stem_2D_T[j][:, :, haadf_vd, :].sum(axis=2)
    haadf_2D_T_devs.append(get_std_devs(haadf_2D_data, k_list))
    haadf_3D_data = stem_3D_T[j][:, :, haadf_vd, :].sum(axis=2)
    haadf_3D_T_devs.append(get_std_devs(haadf_3D_data, k_list))

In [None]:
# ABF virtual detector: integrate STEM output bins 10-29 along axis 2
# (presumably 1 mrad per bin, matching the 10-30 mrad label in the plot titles).
abf_vd = slice(10,30)
abf_2D_T_devs = []
abf_3D_T_devs = []
for j in range(5):
    # 2D before 3D for each temperature, preserving the shared `random` sequence.
    abf_2D_data = stem_2D_T[j][:, :, abf_vd, :].sum(axis=2)
    abf_2D_T_devs.append(get_std_devs(abf_2D_data, k_list))
    abf_3D_data = stem_3D_T[j][:, :, abf_vd, :].sum(axis=2)
    abf_3D_T_devs.append(get_std_devs(abf_3D_data, k_list))

In [None]:
# Frozen phonons averaged per split for each k (128 runs split k ways), as floats.
N_fp = [128 / k for k in k_list]

# DW values labelling the temperature series T0..T4 (one per row of the tables below).
dw = np.array([0.01, 0.02, 0.04, 0.08, 0.16])

# Error tables: rows = temperature index, columns = entries of k_list.
h2_arr, h3_arr = np.array(haadf_2D_T_devs), np.array(haadf_3D_T_devs)
t2_arr, t3_arr = np.array(hrtem_2D_T_devs), np.array(hrtem_3D_T_devs)
a2_arr, a3_arr = np.array(abf_2D_T_devs), np.array(abf_3D_T_devs)

In [None]:
imap = plt.get_cmap("inferno")
f = plt.figure(figsize=(6.4*2,6.4*2*(2/5)))

# Colormap indices (out of 256) used for temperature indices T0..T4.
colors=[20,80,120,160,200]

gs = f.add_gridspec(7,3)
gs.update(wspace=0.2, hspace=1.2) # set the spacing between axes.
split_idx = 4  # top row spans gridspec rows 0-3, bottom row rows 4-6
ax00 = f.add_subplot(gs[:split_idx,0])
ax01 = f.add_subplot(gs[:split_idx,1])
ax02 = f.add_subplot(gs[:split_idx,2])

ax10 = f.add_subplot(gs[split_idx:,0])
ax11 = f.add_subplot(gs[split_idx:,1])
ax12 = f.add_subplot(gs[split_idx:,2])

# Column layout, shared by both rows: 0 = ABF STEM, 1 = HAADF STEM, 2 = HRTEM.
# (Previously the top row had ABF/HAADF swapped relative to titles and bottom row.)
axes = [[ax00,ax01,ax02],[ax10,ax11,ax12]]

# Top row: error vs. frozen phonons per average, one colour per temperature;
# solid = 2D potential, dashed = 3D potential.
for j in range(5):
    axes[0][2].loglog(N_fp, hrtem_2D_T_devs[j], linewidth=1.25, marker='.', linestyle="-", color=imap(colors[j]), label="2D")
    axes[0][2].loglog(N_fp, hrtem_3D_T_devs[j], linewidth=1.25, marker='.', linestyle="--",color=imap(colors[j]), label="3D")

    axes[0][1].loglog(N_fp, haadf_2D_T_devs[j], linewidth=1.25, marker='.', linestyle="-", color=imap(colors[j]), label="2D")
    axes[0][1].loglog(N_fp, haadf_3D_T_devs[j], linewidth=1.25, marker='.', linestyle="--",color=imap(colors[j]),  label="3D")

    axes[0][0].loglog(N_fp, abf_2D_T_devs[j], linewidth=1.25, marker='.', linestyle="-", color=imap(colors[j]), label="2D T%i" % j)
    axes[0][0].loglog(N_fp, abf_3D_T_devs[j], linewidth=1.25, marker='.', linestyle="--", color=imap(colors[j]), label="3D T%i" % j)

axes[0][0].set_title("ABF STEM (10-30mrad)")
axes[0][1].set_title("HAADF STEM (60-120mrad)")
axes[0][2].set_title("HRTEM")

custom_lines = [Line2D([0], [0], color="black", lw=1.25),
                Line2D([0], [0], color="black", lw=1.25, linestyle="--")]

# DW labels positioned next to the ABF curves in the top-left panel.
angle = -15
axes[0][0].text(8, 2.7e-3, s="DW=0.16", rotation=angle)
axes[0][0].text(8, 1.4e-3, s="DW=0.08", rotation=angle)
axes[0][0].text(8, 7.5e-4, s="DW=0.04", rotation=angle)
axes[0][0].text(8, 4.3e-4, s="DW=0.02", rotation=angle)
axes[0][0].text(8, 1e-4, s="DW=0.01", rotation=angle)
axes[0][2].legend(custom_lines, ["2D", "3D"], bbox_to_anchor=(1, 0.98), loc='upper right', frameon=False)

# Bottom row: error vs. 1/DW at a fixed number of frozen phonons.
# fp_idx indexes k_list: k_list[4] == 4, i.e. N_fp[4] == 32 frozen phonons per average.
fp_idx = 4
axes[1][0].loglog(1/dw, a2_arr[:,fp_idx], linewidth=1.25, marker='.', linestyle='-', color="black")
axes[1][0].loglog(1/dw, a3_arr[:,fp_idx], linewidth=1.25, marker='.', linestyle='--', color="black")
axes[1][1].loglog(1/dw, h2_arr[:,fp_idx], linewidth=1.25, marker='.', linestyle='-', color="black")
axes[1][1].loglog(1/dw, h3_arr[:,fp_idx], linewidth=1.25, marker='.', linestyle='--', color="black")
axes[1][2].loglog(1/dw, t2_arr[:,fp_idx], linewidth=1.25, marker='.', linestyle='-', color="black")
axes[1][2].loglog(1/dw, t3_arr[:,fp_idx], linewidth=1.25, marker='.', linestyle='--', color="black")

axes[1][0].set_ylabel("Cross-validation error (fraction of probe)")

axes[1][2].set_ylabel("Cross-validation error (intensity)")

# tick_params takes booleans; the strings "off"/"on" are not interpreted as
# False/True by current Matplotlib.
for ax in axes[0]:
    ax.set_xticks([2,4,8,16,32,64])
    ax.set_xticklabels([2,4,8,16,32,64])
    ax.grid(True)
    ax.set_xlabel("Number of Frozen Phonons")
    ax.tick_params(axis="y", direction="in", left=False, labelleft=True)
    plt.setp(ax.spines.values(), linewidth=1.25)

for ax in axes[1]:
    # Ticks sit at 1/DW but are labelled with the DW values themselves.
    ax.set_xticks(1/dw)
    ax.set_xticklabels(dw)
    ax.grid(True)
    ax.set_xlabel("1/T")
    ax.tick_params(axis="y", direction="in", left=False, labelleft=True)
    plt.setp(ax.spines.values(), linewidth=1.25)