In [None]:
import torch

import numpy as np
import matplotlib.pyplot as plt
import os
# Run everything from the ptyrad project root so the relative `params/...` and `output/...`
# paths used further down resolve.
# A raw string keeps the Windows backslashes literal: "\w" and "\p" are not valid escape
# sequences, and newer Python versions emit a warning for them in a normal string literal.
# NOTE(review): this is a machine-specific absolute path; adjust it for your own checkout.
work_dir = r"H:\workspace\ptyrad"
os.chdir(work_dir)
print("Current working dir: ", os.getcwd())

In [None]:
from ptyrad.data_io import load_params
from ptyrad.initialization import Initializer
from ptyrad.models import PtychoAD
from ptyrad.optimization import CombinedConstraint, CombinedLoss, create_optimizer
from ptyrad.reconstruction import recon_step
from ptyrad.utils import (
    copy_params_to_dir,
    CustomLogger,
    get_blob_size,
    make_batches,
    make_output_folder,
    parse_sec_to_time_str,
    print_system_info,
    save_results,
    select_scan_indices,
    set_gpu_device,
    test_loss_fn,
    time_sync,
    vprint,
)
from ptyrad.visualization import (
    plot_forward_pass,
    plot_pos_grouping,
    plot_scan_positions,
    plot_summary,
)

In [None]:
# Call ptyrad's system-info helper, then ask ptyrad for GPU 0 and keep whatever it returns
# in `device`.
# NOTE(review): the cells visible below pass device='cuda' literally instead of this
# `device` value — confirm which one is intended.
print_system_info()
device = set_gpu_device(gpuid=0)

In [None]:
# Load the parameter file and split it into the sections the rest of the notebook uses.
params_path = "params/paper/ptyrad_convergence_base.yml"
params = load_params(params_path)

exp_params, source_params, model_params, loss_params, recon_params = (
    params[section]
    for section in ('exp_params', 'source_params', 'model_params', 'loss_params', 'recon_params')
)

# Build the Initializer from the experiment and source sections. The value returned by
# init_all() is what later cells use as `init` (they read `init.init_variables`).
init = Initializer(exp_params, source_params).init_all()

In [None]:
# Build the forward model and the loss function from the loaded parameters.
# NOTE(review): 'cuda' is hard-coded here although `device` was obtained from
# set_gpu_device() in an earlier cell — confirm whether `device` should be passed instead.
model = PtychoAD(init.init_variables, model_params, device='cuda', verbose=False)
loss_fn = CombinedLoss(loss_params, device='cuda')

# Fix NumPy's global seed so reruns are repeatable (presumably consumed by the index
# selection / batch grouping below — verify against ptyrad).
np.random.seed(43)

INDICES_MODE      = recon_params['INDICES_MODE']
batch_size        = 512 #recon_params['BATCH_SIZE'].get("size")
GROUP_MODE        = recon_params['GROUP_MODE']

indices = select_scan_indices(
    init.init_variables['N_scan_slow'],
    init.init_variables['N_scan_fast'],
    subscan_slow=INDICES_MODE['subscan_slow'],
    subscan_fast=INDICES_MODE['subscan_fast'],
    mode=INDICES_MODE['mode'],
)

# Probe positions used for grouping: crop positions plus the optimizable sub-pixel shifts.
# The explicit float32 cast is a preventive measure because .numpy() does not support bf16;
# it is a no-op when the tensors are already float32.
pos = (model.crop_pos + model.opt_probe_pos_shifts).detach().to(torch.float32).cpu().numpy()

batches = make_batches(indices, pos, batch_size, mode=GROUP_MODE)

In [None]:
# Baseline: evaluate the summed loss terms on every batch (no gradients needed) and
# print the mean over batches.
with torch.no_grad():
    batch_losses = []
    for batch in batches:
        model_CBEDs, objp_patches = model(batch)
        measured_CBEDs = model.get_measurements(batch)
        _, losses = loss_fn(model_CBEDs, measured_CBEDs, objp_patches, model.omode_occu)
        # Add up the individual loss terms and move the result to host memory.
        batch_losses.append(sum(losses).cpu().numpy())
    avg_loss = sum(batch_losses) / len(batches)
    print(avg_loss)
    

In [None]:
def getFilename(path, extension):
    '''
    List the entries of a folder that carry a given extension.

    Args:
        path: Folder to scan (non-recursive; every entry returned by os.listdir is considered).
        extension: Extension to match exactly, including the leading dot (e.g. '.pt').

    Returns:
        A list of matching entry names with the extension stripped, in os.listdir order.
    '''
    split_names = (os.path.splitext(entry) for entry in os.listdir(path))
    return [stem for stem, ext in split_names if ext == extension]

def natural_sort(lst):
    '''
    Sort strings in "natural" order, so that embedded numbers compare by value
    (e.g. 'iter2' sorts before 'iter10'). Text parts are compared case-insensitively.

    Args:
        lst: Iterable of strings.

    Returns:
        A new sorted list; the input is not modified.
    '''
    import re

    digit_run = re.compile('([0-9]+)')

    def natural_sort_key(s):
        # Splitting on a capturing group always alternates text, digits, text, ...,
        # so odd positions are exactly the ASCII digit runs. Deciding by position
        # (rather than str.isdigit) keeps non-ASCII digits such as '²' as text: they are
        # never split out by [0-9]+, and converting them would raise or mix int with str.
        return [
            int(segment) if position % 2 else segment.lower()
            for position, segment in enumerate(digit_run.split(s))
        ]

    return sorted(lst, key=natural_sort_key)

Open question: why does the error curve computed here not match the recorded one, even for PtyRAD?

- It seems the recorded `batch_loss` is calculated before the constraints are applied, whereas the loss plotted here is evaluated after the constraints, so some difference is expected.

In [None]:
from ptyrad.visualization import plot_forward_pass

In [None]:
# Reconstruction outputs to evaluate. Each key is passed to the Initializer as the
# obj/pos/probe source type; 'path' is the checkpoint folder and 'extension' the checkpoint
# file type to look for in it.
package_dict = {
    # 'PtyRAD':   {'path': 'output/paper/tBL_WSe2/20241211_ptyrad_convergence/full_N16384_dp128_flipT100_random16_p12_1obj_6slice_dz2_Adam_plr1e-4_oalr5e-4_oplr5e-4_slr5e-4_orblur0.2_ozblur1_oathr0.98_opos_sng1.0_spr0.1_aff1_0_-3_0/', 'extension': '.pt'},
    'py4DSTEM': {'path': 'output/paper/tBL_WSe2/20241211_py4DSTEM_convergence/N16384_dp128_flipT100_random16_p12_6slice_dz2_update0.02_kzf1/', 'extension': '.hdf5'},
    # 'PtyShv':   {'path': 'data/paper/tBL_WSe2/Panel_g-h_Themis/10/roi10_Ndp128_step128\MLs_ptyrad_p12_g16_pc0_noModel_updW100_mm_Ns6_dz2_reg1_dpFlip_ud_T/', 'extension': '.mat'}
}

# One row per package, one column per checkpoint (20 slots).
# Slots start as NaN rather than 0 so a checkpoint that was never evaluated cannot be
# mistaken for a genuine zero error. More than 20 checkpoints in a folder still raises
# IndexError on assignment.
data_errors = np.full((len(package_dict), 20), np.nan)

for i, (key, package) in enumerate(package_dict.items()):
    path      = package['path']
    extension = package['extension']
    # Checkpoint base names in natural order, so column j follows the numeric order of the names.
    path_list = natural_sort(getFilename(path, extension))
    for j, ckpt in enumerate(path_list):
        ckpt_file = os.path.join(path, ckpt + extension)

        # Point object, positions and probe at this checkpoint. `source_params` is the same
        # dict that was handed to the Initializer, so these writes are what the init_* calls see.
        source_params['obj_source'] = key
        source_params['pos_source'] = key
        source_params['probe_source'] = key
        source_params['obj_params'] = ckpt_file
        source_params['pos_params'] = ckpt_file
        source_params['probe_params'] = ckpt_file
        init.verbose=False
        # Re-run the initialization steps in the original order before rebuilding the model.
        init.init_cache()
        init.init_probe()
        init.init_pos()
        init.init_obj()

        model = PtychoAD(init.init_variables, model_params, device='cuda', verbose=False)
        # model.opt_obja.data = torch.ones_like(model.opt_objp)
        if j == 2:
            # Visual sanity check of the forward pass for the third checkpoint only.
            plot_forward_pass(model, [8224, 8288], dp_power = 0.5)

        with torch.no_grad():
            avg_loss = 0
            for batch in [batches[0]]: # Just evaluate on 1 batch seems to be fine as well
                model_CBEDs, objp_patches = model(batch)
                measured_CBEDs = model.get_measurements(batch)
                _, losses = loss_fn(model_CBEDs, measured_CBEDs, objp_patches, model.omode_occu)
                avg_loss += sum(losses).cpu().numpy()
            # avg_loss /= len(batches)
        print(f"{key}, {ckpt+extension}, data error: {avg_loss}")
        data_errors[i,j] = avg_loss

In [None]:
# x positions for the 20 checkpoint slots of `data_errors`: 10, 20, ..., 200.
iter_idx = np.arange(10,210,10)

# Plot one labelled series per evaluated package; rows of `data_errors` follow the
# order of the keys in `package_dict`, so enabling another package needs no edit here.
fig, ax = plt.subplots()
for name, errors in zip(package_dict, data_errors):
    ax.scatter(iter_idx, errors, label=name)
ax.set_xlabel("Iteration")
ax.set_ylabel("Data error")
ax.legend()

plt.show()