In [None]:
import pyrost as rst
import pyrost.simulation as st_sim
import numpy as np
import matplotlib.pyplot as plt
import os

# Performing a multislice beam propagation

In [None]:
import pyrost.multislice as ms_sim
params = ms_sim.MSParams.import_default()
params = params.replace(x_step=5e-5, z_step=5e-3, n_min=100, n_max=5000,
                        focus=1.5e3, mll_sigma=5e-5, mll_wl=6.2e-5, wl=6.2e-5,
                        x_max=30.0, mll_depth=5.0)

In [None]:
mll = ms_sim.MLL.import_params(params)

In [None]:
ms_prgt = ms_sim.MSPropagator(params, mll)
ms_prgt.beam_propagate()

In [None]:
z_arr = np.linspace(0.2 * params.focus, 2.0 * params.focus, 300)
ds_beam, x_arr = ms_prgt.beam_downstream(z_arr, step=4.0 * params.x_step)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(12, 6))
im1 = ax.imshow(np.abs(ds_beam[::10]), vmax=3., cmap='gist_heat_r',
                extent=[z_arr.min(), z_arr.max(), x_arr.min(), x_arr.max()])
cbar = fig.colorbar(im1, ax=ax, shrink=0.7)
cbar.ax.set_ylabel('Normalized intensity, a.u.', fontsize=20)
ax.set_ylabel(r'x coordinate, $\mu m$', fontsize=20)
ax.set_aspect(10)
ax.tick_params(labelsize=15)
ax.set_xlabel(r'$z_1, \mu m$', fontsize=20)
ax.set_title('Beam profile', fontsize=25)
plt.show()

# Speckle tracking reconstruction of a 2D dataset

In [None]:
# OS X
!mkdir -p results/exp
!curl https://www.cxidb.org/data/134/diatom.cxi -o results/exp/diatom.cxi

In [None]:
# Linux
!wget -P results/exp "https://www.cxidb.org/data/134/diatom.cxi"

In [None]:
# Recursively list the groups and datasets in the downloaded CXI (HDF5) file.
!h5ls -r results/exp/diatom.cxi

In [None]:
# Default pyrost CXI protocol; shared by the input and output stores below.
protocol = rst.CXIProtocol.import_default()

In [None]:
# Input store: the raw diatom dataset.
inp_file = rst.CXIStore('results/exp/diatom.cxi', protocol=protocol)

In [None]:
# Output store for processed results, kept separate from the raw file.
# mode='a' presumably means append/create, as in h5py — confirm.
out_file = rst.CXIStore('results/exp/diatom_proc.cxi', mode='a',
                        protocol=protocol)

In [None]:
# Speckle-tracking data container reading from inp_file and writing to out_file.
data = rst.STData(input_file=inp_file, output_file=out_file)

In [None]:
# Keys available in the input store.
inp_file.keys()

In [None]:
# Load the datasets from the input store; processes=4 presumably
# parallelises the read — confirm.
data = data.load(processes=4)

In [None]:
# Restrict frames to a region of interest. The roi order is presumably
# (y0, y1, x0, x1), consistent with how the Mirror shape is computed from
# crop.roi in the Sigray section.
crop = rst.Crop(roi=[80, 420, 60, 450])
data = data.update_transform(transform=crop)

In [None]:
# Mask pixels with the 'range-bad' method (vmax=2e3), then pass frames
# 1..120 (np.arange(1, 121)) to mask_frames.
# NOTE(review): confirm whether mask_frames keeps or drops the listed frames.
data = data.update_mask(method='range-bad', vmax=2e3)
data = data.mask_frames(frames=np.arange(1, 121))

In [None]:
defoci = np.linspace(2e-3, 3e-3, 50)
sweep_scan = data.defocus_sweep(defoci, size=5, hval=1.5)
defocus = defoci[np.argmax(sweep_scan)]
print(defocus)

In [None]:
%matplotlib widget
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(defoci * 1e3, sweep_scan)
ax.set_xlabel('Defocus distance, [mm]', fontsize=15)
ax.set_title('Average gradient magnitude squared', fontsize=20)
ax.grid(True)
ax.tick_params(labelsize=10)
plt.tight_layout()
plt.show()
# plt.savefig('docs/figures/sweep_scan.png', dpi=300)

In [None]:
data = data.update_defocus(defocus)
# To skip the sweep, use a fixed value instead. 0.002204081632653061 lies on
# the sweep grid above (index 10), presumably a previously found optimum:
# data = data.update_defocus(0.002204081632653061)

In [None]:
# Speckle-tracking object for the cross-validation scan. ds_x/ds_y
# presumably set the reference-image sampling — confirm in pyrost docs.
st_obj = data.get_st(ds_x=3.0, ds_y=3.0)

In [None]:
# Cross-validation values for 25 kernel bandwidths in [0.5, 3.0].
h_vals = np.linspace(0.5, 3.0, 25)
cv_vals = st_obj.CV_curve(h_vals)

In [None]:
# Cross-validation value as a function of kernel bandwidth.
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(h_vals, cv_vals)
ax.set_title('Cross-validation', fontsize=20)
ax.set_xlabel('Kernel bandwidth', fontsize=15)
ax.grid(True)
ax.tick_params(labelsize=10)
plt.tight_layout()
plt.show()
# plt.savefig('docs/figures/cv_curve.png', dpi=300)

In [None]:
# Fresh speckle-tracking object (replacing the one used for the CV curve):
# create the initial state and find the optimal kernel bandwidth h0.
st_obj = data.get_st(ds_x=3.0, ds_y=3.0)
st_obj = st_obj.create_initial()
h0 = st_obj.find_hopt(verbose=True)
print(h0)

In [None]:
# NOTE(review): the reference is updated with a fixed hval=0.7 rather than
# h0 from the previous cell (h0 is only passed to train_adapt) — confirm.
st_obj = st_obj.update_reference(hval=0.7)

In [None]:
# Iterative reconstruction: 10 iterations, 'rsearch' pixel-map method with
# 50 trials and strides (4, 4), momentum 0.3.
st_res = st_obj.train_adapt(search_window=(2.0, 2.0, 0.1), h0=h0, blur=16.0, n_iter=10,
                            pm_method='rsearch', pm_args={'n_trials': 50, 'strides': (4, 4)},
                            options={'momentum': 0.3})

In [None]:
# Zoom into rows 233..399 and columns 33..232 of the reference image. The
# extent keeps the original pixel coordinates on the axes (top row = row0).
row0, row1, col0, col1 = 233, 400, 33, 233
fig, ax = plt.subplots(figsize=(8, 6))
ax.imshow(st_res.reference_image[row0:row1, col0:col1], vmin=0.7, vmax=1.3,
          extent=[col0, col1, row1, row0])
ax.set_title('Reference image', fontsize=20)
ax.set_xlabel('horizontal axis', fontsize=15)
ax.set_ylabel('vertical axis', fontsize=15)
ax.tick_params(labelsize=15)
plt.tight_layout()
plt.show()
# plt.savefig('docs/figures/diatom_image.png', dpi=150)

In [None]:
# Import the reconstruction into `data`, then fit polynomial models (max_order=3)
# separately along axis 0 (`_ss`) and axis 1 (`_fs`). The suffixes presumably
# mean slow/fast scan — confirm.
data.import_st(st_res)
fit_obj_ss = data.get_fit(axis=0)
fit_ss = fit_obj_ss.fit(max_order=3)
fit_obj_fs = data.get_fit(axis=1)
fit_fs = fit_obj_fs.fit(max_order=3)

In [None]:
# 2D phase map stored in `data` after import_st.
phase_map = data.get('phase')
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(phase_map)
ax.set_title('Phase', fontsize=20)
ax.set_xlabel('horizontal axis', fontsize=15)
ax.set_ylabel('vertical axis', fontsize=15)
ax.tick_params(labelsize=15)
plt.tight_layout()
plt.show()
# plt.savefig('docs/figures/diatom_phase.png', dpi=150)

In [None]:
# Reconstructed phase profiles against their polynomial fits.
# Left panel: axis-1 fit (`_fs`), shown against the horizontal axis.
# Right panel: axis-0 fit (`_ss`), shown against the vertical axis, matching
# the horizontal/vertical convention of the 2D plots above.
fig, axes = plt.subplots(1, 2, figsize=(8, 3))
panels = [(axes[0], fit_obj_fs, fit_fs, 'Horizontal axis'),
          (axes[1], fit_obj_ss, fit_ss, 'Vertical axis')]
for ax, fit_obj_ax, fit_ax, xlabel in panels:
    ax.plot(fit_obj_ax.pixels, fit_obj_ax.phase, label='Reconstructed profile')
    ax.plot(fit_obj_ax.pixels, fit_obj_ax.model(fit_ax['ph_fit']), linestyle='dashed',
            label='Polynomial fit')
    ax.set_xlabel(xlabel, fontsize=15)
    ax.set_title('Phase', fontsize=15)
    ax.tick_params(labelsize=10)
    ax.legend(loc='lower left', fontsize=10)
    ax.grid(True)
plt.tight_layout()
plt.show()
# plt.savefig('docs/figures/phase_fit.png', dpi=300)

In [None]:
# Write results to the output store (results/exp/diatom_proc.cxi).
data.save(mode='overwrite')

In [None]:
# Inspect the processed file.
!h5ls -r results/exp/diatom_proc.cxi

# Generating a speckle tracking dataset

In [None]:
params = st_sim.STParams.import_default()
params = params.replace(bar_size=0.7, bar_sigma=0.12, bar_atn=0.18,
                        bulk_atn=0.2, p0=5e4, th_s=8e-5, n_frames=100,
                        offset=2.0, step_size=0.1, defocus=150, alpha=0.05,
                        ab_cnt=0.7, bar_rnd=0.8)

In [None]:
protocol = rst.CXIProtocol.import_default()

In [None]:
# Run the simulation, build the ptychograph and write both to results/sim.cxi
# (mode='overwrite').
sim_obj = st_sim.STSim(params)
ptych = sim_obj.ptychograph()
st_conv = st_sim.STConverter(sim_obj, ptych)
st_conv.save('results/sim.cxi', mode='overwrite')

In [None]:
!rm -rf results/sim.cxi

In [None]:
!h5ls -r results/sim.cxi

In [None]:
# Alternative route: export the simulation straight to an STData container
# backed by results/sim.cxi, then load it.
# NOTE(review): STSim is re-run here, so the data may differ from the earlier
# run if the simulation is stochastic — confirm.
sim_obj = st_sim.STSim(params)
ptych = sim_obj.ptychograph()
st_conv = st_sim.STConverter(sim_obj, ptych)
data = st_conv.export_data('results/sim.cxi')
data = data.load()

In [None]:
%%bash
source /software/anaconda3/5.2/bin/activate pyrost
python -m pyrost.simulation --help

In [None]:
%%bash
source /software/anaconda3/5.2/bin/activate pyrost
python -m pyrost.simulation results/sim.cxi --bar_size 0.7 --bar_sigma 0.12 \
    --bar_atn 0.18 --bulk_atn 0.2 --p0 5e4 --th_s 8e-5 --n_frames 200 --offset 2 \
    --step_size 0.1 --defocus 150 --alpha 0.05 --ab_cnt 0.7 --bar_rnd 0.8 -p

In [None]:
!h5ls -r results/sim.cxi

# Speckle tracking reconstruction of a simulated dataset

In [None]:
# Input and output stores point at the same file, so results are written
# back into the simulated dataset (results/sim.cxi).
protocol = rst.CXIProtocol.import_default()
inp_file = rst.CXIStore('results/sim.cxi', protocol=protocol)
out_file = rst.CXIStore('results/sim.cxi', mode='a', protocol=protocol)
data = rst.STData(input_file=inp_file, output_file=out_file)
data = data.load()

In [None]:
data.contents()

In [None]:
# Speckle tracking with default settings: optimal bandwidth, then iterative update.
st_obj = data.get_st()
h0 = st_obj.find_hopt()
st_res = st_obj.train_adapt(search_window=(0.0, 10.0, 0.1), h0=h0, blur=8.0)

In [None]:
# Left: first row of the reference image, with the x axis shifted by the
# reference origin. Right: change of the pixel map (component 1, row 0)
# produced by training.
ref_img = st_res.reference_image
pm_delta = st_res.pixel_map - st_obj.pixel_map
fig, axes = plt.subplots(1, 2, figsize=(8, 3))
ref_ax, pm_ax = axes
ref_ax.plot(np.arange(ref_img.shape[1]) - st_res.ref_orig[1], ref_img[0])
ref_ax.set_title('Reference image', fontsize=20)
pm_ax.plot(pm_delta[1, 0])
pm_ax.set_title('Pixel mapping', fontsize=20)
for ax in axes:
    ax.tick_params(labelsize=10)
    ax.set_xlabel('Fast axis, pixels', fontsize=15)
    ax.grid(True)
plt.tight_layout()
plt.show()
# plt.savefig('docs/figures/1d_sim_res.png', dpi=300)

In [None]:
# Store the reconstruction in `data`.
data.import_st(st_res)

In [None]:
# Fit along axis 1 (center=20) after removing the linear term.
# max_order=2 here, yet the coefficient read is 'c_3'.
fit_obj = data.get_fit(axis=1, center=20)
fit_obj = fit_obj.remove_linear_term()
fit = fit_obj.fit(max_order=2)
print(fit['c_3'])

In [None]:
# Pixel aberrations (left) and phase (right) along axis 1 with their
# polynomial fits. fit['c_3'] is flattened to a plain float before formatting,
# so the label works whether the coefficient is a scalar or a 1-element array.
c_3 = float(np.ravel(fit['c_3'])[0])

fig, axes = plt.subplots(1, 2, figsize=(8, 3))
axes[0].plot(fit_obj.pixels, fit_obj.pixel_aberrations)
axes[0].plot(fit_obj.pixels, fit_obj.model(fit['fit']))
axes[0].set_title('Pixel aberrations', fontsize=20)
axes[1].plot(fit_obj.pixels, fit_obj.phase)
axes[1].plot(fit_obj.pixels, fit_obj.model(fit['ph_fit']),
             label=fr"$\alpha$ = {c_3:.5f} rad/mrad^3")
axes[1].set_title('Phase', fontsize=20)
axes[1].legend(fontsize=10)
for ax in axes:
    ax.tick_params(axis='both', which='major', labelsize=10)
    ax.set_xlabel('horizontal axis', fontsize=15)
    ax.grid(True)
plt.tight_layout()
plt.show()
# plt.savefig('docs/figures/1d_sim_fits.png', dpi=300)

In [None]:
# Write results back into results/sim.cxi (input and output are the same file).
data.save(mode='overwrite')

In [None]:
! h5ls -r results/sim.cxi

# Processing a wavefront metrology experiment

In [None]:
scan_num = 2989
# Site-specific locations (CFEL GPFS) of the scan log and frame directory.
log_path = f'/gpfs/cfel/group/cxi/labs/MLL-Sigray/scan-logs/Scan_{scan_num:d}.log'
data_dir = f'/gpfs/cfel/group/cxi/labs/MLL-Sigray/scan-frames/Scan_{scan_num:d}'
# Every file in the scan directory ending in 'Lambda.nxs', as full sorted paths.
data_files = sorted(os.path.join(data_dir, name)
                    for name in os.listdir(data_dir)
                    if name.endswith('Lambda.nxs'))
# Wavelengths in metres (passed as `wavelength=` below), keyed by element;
# presumably the X-ray source anode lines — confirm.
wl_dict = {'Mo': 7.092917530503447e-11,
           'Cu': 1.5498024804150033e-10,
           'Rh': 6.137831605603974e-11}

In [None]:
# Parse the scan log into a converter.
converter = rst.KamzikConverter()
converter = converter.read_logs(log_path)

In [None]:
# NOTE(review): this displays a fresh, unconfigured converter, not `converter`
# from the previous cell, and nothing below uses it. It looks like a leftover;
# consider removing it or displaying `converter` instead.
rst.KamzikConverter()

In [None]:
converter.cxi_keys()

In [None]:
# Basis vectors and translations extracted from the log, passed to STData below.
log_data = converter.cxi_get(['basis_vectors', 'log_translations'])

In [None]:
# Input store over all Lambda frame files. STData receives the log-derived
# attributes, distance=2.0 (presumably metres) and the Mo wavelength.
input_file = rst.CXIStore(data_files)
data = rst.STData(input_file, **log_data, distance=2.0, wavelength=wl_dict['Mo'])

In [None]:
data.contents()

In [None]:
# Load only the 'data' attribute (presumably the detector frames) from the input store.
data = data.load('data')

In [None]:
%matplotlib widget
fig, ax = plt.subplots(figsize=(8, 3))
ax.imshow(data.data[0], vmax=100)
ax.set_title('Frame 0', fontsize=20)
ax.tick_params(labelsize=15)
plt.tight_layout()
plt.show()
# plt.savefig('docs/figures/sigray_frame.png', dpi=150)

In [None]:
# Crop to the region of interest, then mirror along axis 1. The Mirror needs
# the cropped frame shape, derived from the crop bounds (roi[0]:roi[1] and
# roi[2]:roi[3]).
crop = rst.Crop([270, 300, 200, 1240])
roi = crop.roi
cropped_shape = (roi[1] - roi[0], roi[3] - roi[2])
mirror = rst.Mirror(axis=1, shape=cropped_shape)
transform = rst.ComposeTransforms([crop, mirror])
data = data.update_transform(transform=transform)
data = data.update_mask(vmax=100000)

In [None]:
%matplotlib widget
fig, ax = plt.subplots(figsize=(8, 1))
ax.imshow(data.data[0], vmax=100)
ax.set_title('Frame 0', fontsize=20)
ax.tick_params(labelsize=15)
plt.tight_layout()
plt.show()
# plt.savefig('docs/figures/sigray_crop.png', dpi=150)

In [None]:
# Integrate the frames; presumably along the short axis, which yields the
# ptychograph plotted below — confirm.
data = data.integrate_data()

In [None]:
# Pass frames 5..99 to mask_frames.
# NOTE(review): confirm whether these frames are kept or dropped.
data = data.mask_frames(np.arange(5, 100))

In [None]:
%matplotlib widget
fig, ax = plt.subplots(figsize=(8, 3))
ax.imshow(data.data[:, 0])
ax.set_title('Ptychograph', fontsize=20)
ax.set_xlabel('horizontal axis', fontsize=15)
ax.set_ylabel('frames', fontsize=15)
ax.tick_params(labelsize=15)
plt.show()
# plt.savefig('docs/figures/sigray_ptychograph.png', dpi=150)

In [None]:
defoci = np.linspace(50e-6, 300e-6, 50)
sweep_scan = data.defocus_sweep(defoci, size=50)
defocus = defoci[np.argmax(sweep_scan)]
print(defocus)

In [None]:
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(defoci * 1e3, sweep_scan)
ax.set_xlabel('Defocus distance, [mm]', fontsize=15)
ax.set_title('Average gradient magnitude squared', fontsize=20)
ax.tick_params(labelsize=15)
ax.grid(True)
plt.tight_layout()
plt.show()
# plt.savefig('docs/figures/sweep_scan_sigray.png', dpi=300)

In [None]:
# Apply the defocus selected by the sweep.
data = data.update_defocus(defocus)

In [None]:
# Speckle tracking: optimal bandwidth, then iterative update.
st_obj = data.get_st()
h0 = st_obj.find_hopt()
st_res = st_obj.train_adapt(search_window=(0.0, 10.0, 0.1), h0=h0, blur=8.0)

In [None]:
# Left: first row of the reference image, with the x axis shifted by the
# reference origin. Right: change of the pixel map (component 1, row 0)
# produced by training.
ref_img = st_res.reference_image
pm_delta = st_res.pixel_map - st_obj.pixel_map
fig, axes = plt.subplots(1, 2, figsize=(8, 3))
ref_ax, pm_ax = axes
ref_ax.plot(np.arange(ref_img.shape[1]) - st_res.ref_orig[1], ref_img[0])
ref_ax.set_title('Reference image', fontsize=20)
pm_ax.plot(pm_delta[1, 0])
pm_ax.set_title('Pixel mapping', fontsize=20)
for ax in axes:
    ax.tick_params(labelsize=15)
    ax.set_xlabel('Fast axis, pixels', fontsize=15)
    ax.grid(True)
plt.tight_layout()
plt.show()
# plt.savefig('docs/figures/sigray_res.png', dpi=300)

In [None]:
# Store the reconstruction in `data`.
data.import_st(st_res)

In [None]:
# Fit along axis 1 after removing the linear term (max_order=3).
fit_obj = data.get_fit(axis=1)
fit_obj = fit_obj.remove_linear_term()
fit = fit_obj.fit(max_order=3)

In [None]:
fit['c_4']

In [None]:
# Angular displacements (nrad) and phase (rad) against scattering angle, each
# with its polynomial fit. fit['c_4'] is flattened to a plain float once:
# formatting a non-0-d numpy array with ':.4f' raises TypeError, and the
# simulated-data cell indexes fit['c_3'][0], which suggests these
# coefficients may be arrays. One shared label keeps both legends identical.
c_4 = float(np.ravel(fit['c_4'])[0])
c4_label = fr"R-PXST $c_4$ = {c_4:.4f} rad/mrad^4"

fig, axes = plt.subplots(1, 2, figsize=(8, 3))
axes[0].plot(fit_obj.thetas, fit_obj.theta_ab * 1e9, 'b')
axes[0].plot(fit_obj.thetas, fit_obj.model(fit['fit']) * fit_obj.ref_ap * 1e9,
             'b--', label=c4_label)
axes[0].set_title('Angular displacements, nrad', fontsize=15)

axes[1].plot(fit_obj.thetas, fit_obj.phase, 'b')
axes[1].plot(fit_obj.thetas, fit_obj.model(fit['ph_fit']), 'b--',
             label=c4_label)
axes[1].set_title('Phase, rad', fontsize=15)
for ax in axes:
    ax.legend(fontsize=10)
    ax.tick_params(labelsize=10)
    ax.set_xlabel('Scattering angles, rad', fontsize=15)
    ax.grid(True)
plt.tight_layout()
plt.show()
# plt.savefig('docs/figures/sigray_fits.png', dpi=300)

In [None]:
# Attach a new output store (results/sigray.cxi, mode='a') and save the
# processed Sigray results into it (mode='overwrite').
out_file = rst.CXIStore('results/sigray.cxi', mode='a')
data = data.update_output_file(out_file)
data.save(mode='overwrite')