In [None]:
%load_ext cython
import os
import cbclib as cbc
import numpy as np
import pandas as pd
import pyximport
from importlib import reload
import sys
import hdf5plugin
import h5py
import pygmo
import matplotlib.animation as animation
from tqdm.auto import tqdm
from scipy import ndimage
from scipy.optimize import minimize, differential_evolution
from scipy.interpolate import interpn

import plotly.graph_objects as go
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.widgets import Slider
from pickle import load, dump

# Compile .pyx sources on import (presumably for the local `dev` module used
# below); build artefacts go to ./.pyxbld and reload_support lets
# importlib.reload pick up recompiled extensions.
pyximport.install(reload_support=True, build_in_temp=False,
                  build_dir='.pyxbld')

In [None]:
# Hot-reload the local `dev` module: import it on first use, otherwise reload
# the already-imported module so edits are picked up without a kernel restart.
if sys.modules.get('dev') is None:
    import dev
else:
    dev = reload(sys.modules['dev'])
print(dir(dev))

In [None]:
# Inputs for the setup refinement of scan 232: geometry, modelled and detected
# streak tables, detector crop, lattice basis and per-frame samples (files
# written by earlier stages of this workflow).
setup = cbc.ScanSetup.import_ini('results/exp_geom_232_ref_2.ini')
table = cbc.CBCTable.import_hdf('results/scan_232_modeled_part.h5', 'data', setup)
table_det = cbc.CBCTable.import_hdf('results/scan_232_detected.h5', 'data', setup)
crop = cbc.Crop(roi=(1100, 3260, 1040, 3108))
basis = cbc.Basis.import_ini('results/scan_232_basis_ref_2.ini')
samples = cbc.ScanSamples.import_dataframe(pd.read_hdf('results/scan_232_samples_log_2.h5', 'data'))

In [None]:
# List the raw frame files of scan 232 (sorted by file name) and wrap them in
# a CrystData container that applies the detector crop.
scan_num = 232
dir_path = '/asap3/petra3/gpfs/p11/2021/data/11010570/raw'
h5_dir = os.path.join(dir_path, f'scan_frames/Scan_{scan_num:d}')
h5_files = sorted([os.path.join(h5_dir, path) for path in os.listdir(h5_dir)
                   if path.endswith(('LambdaFar.nxs', '.h5'))])
data = cbc.CrystData(cbc.CXIStore(h5_files), transform=crop)

# Whitefield saved from scan 250, loaded through the same crop.
# NOTE(review): the June 2021 section writes this file from already-cropped
# data and reloads it with transform=None; confirm the crop is not applied twice.
wf_data = cbc.CrystData(cbc.CXIStore(f'results/scan_250_whitefield.h5'),
                        transform=crop).load('whitefield')

In [None]:
# Preprocess frames 30..59: 'range-bad' pixel mask (vmax=1e7), pupil mask,
# scan-250 whitefield, then blur of the pupil region.
idxs = np.arange(30, 60)

data = data.clear().load(idxs=idxs)
data = data.update_mask(method='range-bad', vmax=10000000)
data = data.mask_pupil(setup, padding=60)
data = data.import_whitefield(wf_data.whitefield)
data = data.blur_pupil(setup, padding=80, blur=20)

In [None]:
# Import the modelled streaks as patterns, then re-estimate the background.
data = data.import_patterns(table.table)
data = data.update_background()

In [None]:
# Streak detector built from the preprocessed frames.
det_obj = data.lsd_detector()
det_obj = det_obj.generate_patterns(vmin=0.9, vmax=5.0, size=(1, 3, 3))

In [None]:
det_res = det_obj.detect(cutoff=90.0, filter_threshold=12.0, group_threshold=0.65, dilation=1.5)

In [None]:
# NOTE(review): the meaning of this 3-tuple is defined by cbclib; document it.
det_res = det_res.update_patterns((1.5, 2.5, 8.5))

In [None]:
%matplotlib widget
# Detected-streak pattern of the first loaded frame.
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(det_res.patterns[0], vmin=0.0, vmax=0.5, cmap='gray_r')
# for line in det_res.streaks[0].to_numpy():
#     ax.plot(line[:4:2], line[1:4:2], c='r', alpha=0.1)
ax.set_xlim(0, det_obj.shape[2])
ax.set_ylim(det_obj.shape[1], 0)
fig.tight_layout()
plt.show()

In [None]:
# Detector that predicts streaks from the crystal model (basis, samples, setup).
mdl_det = data.model_detector(basis, samples, setup)

In [None]:
# Keep reflections whose outlier count exceeds `alpha` times their total count
# and detect only those. Note: `idxs` (frame indices above) is reused here for
# reflection labels.
alpha = 0.1
hkl = basis.generate_hkl(0.3)
counts = mdl_det.count_outliers(hkl, width=5.0, alpha=alpha)
idxs = counts.index[counts['outliers'] > 1.0 * alpha * counts['counts']]
mdl_det = mdl_det.detect(hkl=hkl[idxs], width=5.0)

In [None]:
# Refine the model streaks, re-render their patterns and export one table.
# This overwrites `table` (the modelled table loaded at the top).
mdl_det = mdl_det.refine_streaks(2.5)
mdl_det = mdl_det.update_patterns((1.5, 2.5, 7.0))
table = mdl_det.export_table(concatenate=True)

In [None]:
%matplotlib widget
# Overlay the model streaks (red) on corrected frame `index`, a position within
# the loaded frames 30..59.
index = 4

fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(data.cor_data[index], vmin=0.9, vmax=5.0, cmap='gray_r')
# ax.imshow(mdl_det.streaks[index].pattern_image(shape=mdl_det.shape[1:]))
for line in mdl_det.streaks[index].to_lines():
    ax.plot(line[:4:2], line[1:4:2], c='r', alpha=0.5)
# for line in mdl_det.models[index].generate_streaks(hkl, 5.0).to_lines():
#     ax.plot(line[:4:2], line[1:4:2], c='r', alpha=0.5)

ax.set_xlim(0, det_obj.shape[2])
ax.set_ylim(det_obj.shape[1], 0)
fig.tight_layout()
plt.show()

In [None]:
# Parameter bounds for the setup refinement (lattice, focus, rotation, sample z
# and tilt tolerances) over every detected frame.
bounds = cbc.SetupRefiner.generate_bounds(lat_tol=(0.015, 0.035), foc_tol=0.015,
                                          rot_tol=0.035, z_tol=0.015, tilt_tol=0.0,
                                          frames=table_det.frames)

# `tilts` was previously only defined much further down (CBC Model section), so
# this cell failed on a fresh kernel. Compute it here from the loaded samples,
# the same way that section does.
# NOTE(review): assumes samples.values() is ordered consistently with the frames
# refine_setup expects — confirm.
tilts = np.array([sample.rotation.to_tilt() for sample in samples.values()])

refiner = table_det.refine_setup(bounds, basis, hkl, tilts, width=4.0)

# The optimisation and plotting cells below refer to the refinement problem as
# `problem`, which was never bound in this section.
# NOTE(review): assumes the object returned by refine_setup is the problem to
# hand to pygmo (it must provide fitness/get_bounds/x0) — confirm.
problem = refiner

In [None]:
# Differential evolution (300 generations) on an archipelago of 32 islands with
# 10 individuals each; initial populations are evaluated with pygmo's batch
# fitness evaluator. Expects `problem` to be the setup-refinement problem.
uda = pygmo.de(gen=300)
algo = pygmo.algorithm(uda)
prob = pygmo.problem(problem)
pops = [pygmo.population(size=10, prob=prob, b=pygmo.bfe()) for _ in range(32)]
archi = pygmo.archipelago()
for pop in pops:
    archi.push_back(algo=algo, pop=pop)

In [None]:
# Evolve all islands, wait for completion, and keep the best champion.
archi.evolve()
%time archi.wait()
x = archi.get_champions_x()[np.argmin(archi.get_champions_f())]

In [None]:
# Fitness at the starting point vs at the optimum.
problem.fitness(problem.x0), problem.fitness(x)

In [None]:
%matplotlib widget
# Streaks predicted at the start point (red) and at the optimum (blue) on
# frame `idx`.
idx = 0

fig, ax = plt.subplots(1, 1, figsize=(8, 8))
ax.imshow(data.cor_data[idx], cmap='gray_r', vmin=0.5, vmax=5.0)
# ax.imshow(table.pattern_image(frame_idx), vmin=0.0, vmax=0.5, cmap='gray_r')
# for line in mdl_det.streaks[idx].to_numpy():
#     ax.plot(line[:4:2], line[1:4:2], c='r', alpha=0.4)
for line in refiner.generate_streaks(problem.x0)[idx].to_numpy():
    ax.plot(line[:4:2], line[1:4:2], c='r', alpha=0.4)
for line in refiner.generate_streaks(x)[idx].to_numpy():
    ax.plot(line[:4:2], line[1:4:2], c='b', alpha=0.4)

ax.set_xlim(0, refiner.shape[2])
ax.set_ylim(refiner.shape[1], 0)
plt.tight_layout()
plt.show()

# Intensity merging

In [None]:
# Hot-reload the local `dev` module: import it on first use, otherwise reload
# the already-imported module so edits are picked up without a kernel restart.
if sys.modules.get('dev') is None:
    import dev
else:
    dev = reload(sys.modules['dev'])
print(dir(dev))

In [None]:
# Inputs for intensity merging of scan 232 (extended modelled streak table).
setup = cbc.ScanSetup.import_ini('results/exp_geom_232_ext_2.ini')
table = cbc.CBCTable.import_hdf('results/scan_232_modeled_part_ext.h5', 'data', setup)
crop = cbc.Crop(roi=(1100, 3260, 1040, 3108))
basis = cbc.Basis.import_ini('results/scan_232_basis_ref_2.ini')
samples = cbc.ScanSamples.import_dataframe(pd.read_hdf('results/scan_232_samples_log_2.h5', 'data'))

In [None]:
# Raw frames of scan 232 (cropped) and the scan-250 whitefield.
# NOTE(review): the whitefield file is written from cropped data in the June
# 2021 section; confirm that loading it with transform=crop is intended.
scan_num = 232
dir_path = '/asap3/petra3/gpfs/p11/2021/data/11010570/raw'
h5_dir = os.path.join(dir_path, f'scan_frames/Scan_{scan_num:d}')
h5_files = sorted([os.path.join(h5_dir, path) for path in os.listdir(h5_dir)
                   if path.endswith(('LambdaFar.nxs', '.h5'))])
data = cbc.CrystData(cbc.CXIStore(h5_files), transform=crop)

wf_data = cbc.CrystData(cbc.CXIStore(f'results/scan_250_whitefield.h5'),
                        transform=crop).load('whitefield')

In [None]:
# Preprocess a single frame and render the modelled streak pattern ('rp' key)
# of the same frame for the overlay below.
index = 0

data = data.clear().load(idxs=index)
data = data.update_mask(method='range-bad', vmax=10000000)
data = data.mask_pupil(setup, padding=60)
data = data.import_whitefield(wf_data.whitefield)
data = data.blur_pupil(setup, padding=80, blur=20)

image = table.pattern_image(index, key='rp')

In [None]:
%matplotlib widget
# Corrected frame with a red RGBA overlay whose alpha channel is `image`
# (assumes `image` values lie in [0, 1] — TODO confirm).
fig, ax = plt.subplots(1, 1, figsize=(6, 6))
ax.imshow(data.cor_data[0], cmap='gray_r', vmin=0.5, vmax=5.0)
ax.imshow(np.array([1.0, 0.0, 0.0, 0.0]) + image[..., None] * np.array([0.0, 0.0, 0.0, 1.0]))

ax.set_xlim(0, data.shape[2])
ax.set_ylim(data.shape[1], 0)
plt.tight_layout()
plt.show()

In [None]:
# Build the intensity scaler from the modelled table; argument meanings are
# defined by cbclib's CBCTable.scale (64 is presumably a thread count — confirm).
%time scaler_init = table.scale((80, 80), basis, samples, 64)

In [None]:
# Merge reflections with '4mm' symmetry and train the scaler.
scaler = scaler_init.merge_hkl('4mm')
scaler, x = scaler.train(bandwidth=4.5e-4, n_iter=30, fit_intercept=True, max_iter=20)

In [None]:
# Continue training from the previous solution `x`; every re-run refines further.
scaler, x = scaler.train(bandwidth=4.5e-4, n_iter=30, fit_intercept=True, x0=x, max_iter=20)

In [None]:
# Baseline for comparison plots: scaler state with all parameters set to zero.
# NOTE(review): built from `scaler_init` (before merge_hkl('4mm')) while
# `scaler` is the merged one — confirm this is the intended comparison.
scaler0 = scaler_init.update_xtal(bandwidth=4.5e-4, x=np.zeros(2 * (scaler_init.iidxs.size - 1)))[0]

In [None]:
# Point matplotlib's animation writer at the ffmpeg binary used for the .mp4
# export. The FFMPEG_PATH environment variable overrides the hard-coded,
# user-specific default, so the notebook does not have to be edited to run
# elsewhere. (plt.rcParams is the same object as matplotlib.rcParams, so the
# extra `import matplotlib` is not needed.)
plt.rcParams['animation.ffmpeg_path'] = os.environ.get(
    'FFMPEG_PATH', '/gpfs/cfel/user/nivanov/.conda/envs/pyrost/bin/ffmpeg')

In [None]:
%matplotlib widget
fig, axes = plt.subplots(1, 2, figsize=(8, 4))

im0 = axes[0].imshow(scaler0.xtal.val[0], vmin=0.0, vmax=4.6)
im1 = axes[1].imshow(scaler.xtal.val[0], vmin=0.0, vmax=4.6)

for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])

fig.tight_layout()

# ims is a list of lists, each row is a list of artists to draw in the
# current frame; here we are just animating one artist, the image, in
# each frame
ims = []
for idx in np.arange(0, scaler.xtal.shape[0]):
    im0 = axes[0].imshow(scaler0.xtal.val[idx], vmin=0.0, vmax=5.0, cmap='gray_r', animated=True)
    im1 = axes[1].imshow(scaler.xtal.val[idx], vmin=0.0, vmax=5.0, cmap='gray_r', animated=True)
    if idx == 0:
        axes[0].imshow(scaler0.xtal.val[0], vmin=0.0, vmax=5.0, cmap='gray_r')
        axes[1].imshow(scaler.xtal.val[0], vmin=0.0, vmax=5.0, cmap='gray_r')
    ims.append([im0, im1])

ani = animation.ArtistAnimation(fig, ims, interval=100, blit=True)
writer = animation.FFMpegWriter(fps=10)
ani.save('figures/lysozyme_pupils.mp4', writer=writer, dpi=100)
plt.show()

In [None]:
%matplotlib widget
# Interactive comparison: a vertical frame slider (left) selects which
# projection map of the baseline `scaler0` (middle) and of the trained
# `scaler` (right) is shown.
fig, ax = plt.subplots(1, 3, figsize=(8, 4), gridspec_kw={'width_ratios': [0.1, 1.0, 1.0]})

# An array valstep restricts the slider to the listed integer frame indices.
slider = Slider(ax=ax[0], label='Frame', valmin=0, valmax=scaler.xtal.shape[0] - 1,
                valstep=np.arange(0, scaler.xtal.shape[0]),
                valinit=0, orientation="vertical")

im0 = ax[1].imshow(scaler0.xtal.val[0], vmin=0.1, vmax=4.4)
im1 = ax[2].imshow(scaler.xtal.val[0], vmin=0.1, vmax=4.4)

def update(val):
    """Slider callback: show frame ``slider.val`` in both image panels.

    ``val`` (the value matplotlib passes in) is not used; the frame index is
    read back from the slider instead.
    """
    frame = slider.val
    for image, source in ((im0, scaler0), (im1, scaler)):
        image.set_data(source.xtal.val[frame])
    fig.canvas.draw_idle()

# Redraw both panels whenever the slider moves.
slider.on_changed(update)

fig.tight_layout()
plt.show()

In [None]:
from scripts.cbc_simulation import unmerge_hkl_list

# Merged structure factors (and errors) of the trained scaler under '4mm'
# symmetry, and the reference structure factors unmerged from the 1AZF hkl file.
# NOTE(review): `sfac` / `sfac_err` are not used by the cells that follow.
sf_data = scaler.merge_sfac(x, '4mm')
sfac, sfac_err = sf_data['sfac'], sf_data['serr']
hkl_full, sfac_full = unmerge_hkl_list(basis, 'results/1azf.pdb.hkl', 0.3)

In [None]:
# Pair each merged reflection with the identical hkl in the reference list:
# `idxs` indexes scaler.hkl and `idxs_sim` indexes hkl_full / sfac_full. The
# (n_merged, n_reference) equality mask is now built once instead of twice
# (np.nonzero gives both index arrays, exactly like np.where(cond)[0] / [1]).
idxs, idxs_sim = np.nonzero(np.all(scaler.hkl[:, None] == hkl_full[None, :], axis=-1))
sfac_sim, hkl_sim = sfac_full[idxs_sim], hkl_full[idxs_sim]
# Euclidean norm of hkl_sim @ basis.mat for every matched reflection.
rec_abs = np.sqrt(np.sum(hkl_sim.dot(basis.mat)**2, axis=1))

In [None]:
# Merged values in scaler.hkl order, restricted to the reflections matched with
# the reference list (`idxs`). The values come from the second output of
# scaler.model(x), taken at iidxs[:-1] and scattered to hkl_idxs. Any position
# not covered by hkl_idxs is NaN rather than uninitialised memory (np.empty), so
# a missing reflection cannot silently enter the statistics.
sfac_mrg = np.full(scaler.hkl.shape[0], np.nan)
sfac_mrg[scaler.hkl_idxs] = scaler.model(x)[1][scaler.iidxs[:-1]]
sfac_mrg = sfac_mrg[idxs]

In [None]:
def r_split(x, sfac_sim, sfac_mrg, mask):
    """R_split-type disagreement between reference and scaled merged values.

    Computes ``sqrt(2) * sum|F_sim - s*F_mrg| / sum|F_sim + s*F_mrg|`` over the
    entries selected by ``mask``, with scale ``s = exp(x)``.

    Parameters
    ----------
    x : float or ndarray of shape (1,)
        Log of the scale applied to ``sfac_mrg``. A 1-element array is what
        ``scipy.optimize.minimize`` passes in.
    sfac_sim, sfac_mrg : ndarray
        Reference and merged values, same shape.
    mask : ndarray of bool
        Entries to include.

    Returns
    -------
    float
    """
    sfac_mrg = np.exp(x) * sfac_mrg
    return np.sqrt(2.0) * np.sum(np.abs(sfac_sim[mask] - sfac_mrg[mask])) / \
           np.sum(np.abs(sfac_sim[mask] + sfac_mrg[mask]))

def cc_star(x, sfac_sim, sfac_mrg, mask):
    """Pearson correlation between reference and scaled merged values.

    Despite the name, this returns the plain Pearson coefficient CC over the
    ``mask`` entries, not CC* = sqrt(2 CC / (1 + CC)). A positive scale
    ``exp(x)`` does not change CC; ``x`` is accepted so the signature matches
    ``r_split``.

    Parameters are the same as for ``r_split``. Returns a float.
    """
    sfac_mrg = np.exp(x) * sfac_mrg
    # Centre both sides on the mean of the *masked* entries. The reference used
    # to be centred on the mean of the whole array, which biased CC whenever the
    # mask excluded anything.
    a = sfac_sim[mask] - np.mean(sfac_sim[mask])
    b = sfac_mrg[mask] - np.mean(sfac_mrg[mask])
    return np.sum(a * b) / np.sqrt(np.sum(a * a) * np.sum(b * b))

# Reflections with a scaler gain above 100 enter the statistics. The mask was
# previously only passed inline, so the printouts referenced an undefined
# `mask`. It gets its own name to avoid clashing with the detector `mask` used
# in the geometry section.
gain_mask = scaler.gain(x) > 100
# Fit the log-scale between merged and reference values by minimising R_split,
# starting from exp(10).
res = minimize(r_split, 10, args=(sfac_sim, sfac_mrg, gain_mask))
print(r_split(res.x, sfac_sim, sfac_mrg, gain_mask))
print(cc_star(res.x, sfac_sim, sfac_mrg, gain_mask))

In [None]:
# Number of matched reflections. `sf_sim` (the log of sfac_sim) is only defined
# in the figure cell below, so it cannot be used here on a fresh kernel.
sfac_sim.size

In [None]:
%matplotlib widget
# Summary figure. Left: projection maps of the zero-parameter baseline
# (`scaler0`, top row) and the trained scaler (bottom row) at two frames.
# Right: log values vs |hkl @ basis.mat| for reference and reconstructed
# reflections.
desy_orange = np.array([0.95, 0.55, 0.00])
desy_cyan = np.array([0.00, 0.65, 0.92])
frames = [0, 180]
vmin, vmax = 0.0, 4.5
cmap = 'gray_r'
fontsize = 15

fig = plt.figure(figsize=(8, 4.0))
fig.patch.set_alpha(0.0)
subfigs = fig.subfigures(1, 2, width_ratios=[0.44, 0.55])

subfigs[0].suptitle('Projection maps', fontsize=fontsize)

axes = subfigs[0].subplots(2, 2, gridspec_kw=dict(wspace=0.05, hspace=0.05))
axes[0, 0].imshow(scaler0.xtal.val[frames[0]], vmin=vmin, vmax=vmax, cmap=cmap)
axes[0, 1].imshow(scaler0.xtal.val[frames[1]], vmin=vmin, vmax=vmax, cmap=cmap)
axes[0, 0].set_ylabel('1st iteration', fontsize=fontsize)

axes[1, 0].imshow(scaler.xtal.val[frames[0]], vmin=vmin, vmax=vmax, cmap=cmap)
axes[1, 1].imshow(scaler.xtal.val[frames[1]], vmin=vmin, vmax=vmax, cmap=cmap)
# NOTE(review): hard-coded label; the training cells run n_iter=30 per call —
# confirm the iteration count shown here.
axes[1, 0].set_ylabel('15th iteration', fontsize=fontsize)

for ax in axes.ravel():
    ax.set_xticks([])
    ax.set_yticks([])

# Log values; merged values are rescaled by the fitted exp(res.x[0]). Only every
# 3rd reflection with gain > 100 is plotted, to thin the scatter.
sf_sim = np.log(sfac_sim)
sf_mrg = np.log(np.exp(res.x[0]) * sfac_mrg)
ii = np.where(scaler.gain(x) > 100)[0][::3]
    
subfigs[1].suptitle('Structure factors', fontsize=fontsize)

# x axis: rec_abs / wavelength * 1e-10 (unit conversion assumed — confirm).
ax = subfigs[1].add_subplot(111)
ax.scatter(rec_abs[ii] / setup.wavelength * 1e-10, sf_sim[ii], s=4,
           c=np.tile(desy_cyan[None, :], (ii.size, 1)), label='ground truth')
ax.scatter(rec_abs[ii] / setup.wavelength * 1e-10, sf_mrg[ii], s=4,
           c=np.tile(desy_orange[None, :], (ii.size, 1)), label='reconstructed')
ax.set_ylim(9, 17)
ax.tick_params(labelsize=12)
ax.grid(True)
ax.legend(fontsize=12)

fig.tight_layout(pad=0.0, rect=(0.12, 0.0, 1.0, 0.9))
plt.show()
# plt.savefig('../figures/exp_scaling.pdf')

# Streak detection, June 2021

In [None]:
# Shared settings for the June 2021 cells; data_dict maps scan number -> data.
dir_path = '/asap3/petra3/gpfs/p11/2021/data/11010570/raw'
crop = cbc.Crop(roi=(1100, 3260, 1040, 3108))
scan_setup = cbc.ScanSetup.import_ini('results/exp_geom_232.ini')
data_dict = {}

In [None]:
# Whitefield scan: load every frame, apply the 'range-bad' mask and average.
scan_num = 209 # whitefield, 2s exposure

wf_data = cbc.converter_petra(dir_path, scan_num, transform=crop)
wf_data = wf_data.load('data', processes=16)

wf_data = wf_data.update_mask(method='range-bad', vmax=10000000)
wf_data = wf_data.update_whitefield(method='mean')

In [None]:
# Save the processed whitefield scan, overwriting any existing file.
wf_data = wf_data.update_output_file(cbc.CXIStore(f'results/scan_{scan_num:d}_whitefield.h5', 'w'))
wf_data.save(attributes=['mask', 'whitefield', 'data', 'translations', 'tilts'], mode='overwrite')

In [None]:
# Reload the saved scan-209 file (no crop applied on load).
scan_num = 209 # whitefield, 2s exposure

file = cbc.CXIStore(f'results/scan_{scan_num:d}_whitefield.h5', 'r')
wf_data = cbc.CrystData(file, transform=None)
wf_data = wf_data.load()

In [None]:
# NOTE(review): `mask` is not defined at this point (it is only assigned in the
# geometry-calibration section further down) and `pupil` is unused. The cell
# also uses the older converter_petra(mask=...) signature — re-check it before
# running.
scan_num = 213 # B12 protein crystal, 1s exposure, 50 frames
pupil = (1050, 1200,  860,  990)

data = cbc.converter_petra(dir_path, scan_num, mask=mask)

In [None]:
# Lysozyme crystal, 1s exposure, 101 x 4 x 4 frames
# Diffraction starts from frame 505
# (`pupil` below is not used by the visible cells.)
scan_num = 250
pupil = (1100, 1270,  860, 1070)

h5_dir = os.path.join(dir_path, f'scan_frames/Scan_{scan_num:d}')
h5_files = sorted([os.path.join(h5_dir, path) for path in os.listdir(h5_dir)
                   if path.endswith(('LambdaFar.nxs', '.h5'))])
data_dict[scan_num] = cbc.CrystData(cbc.CXIStore(h5_files, 'r'), crop)

In [None]:
# Only the first 101 frames are loaded, well before frame 505 where diffraction
# starts, and they are averaged into a whitefield.
data_dict[scan_num] = data_dict[scan_num].load(idxs=np.arange(101), processes=8)
data_dict[scan_num] = data_dict[scan_num].update_mask(method='range-bad', vmax=10000000)
data_dict[scan_num] = data_dict[scan_num].update_whitefield('mean')

In [None]:
# Save the cropped scan-250 data (including the whitefield), overwriting.
output_file = cbc.CXIStore(f'results/scan_{scan_num:d}_whitefield.h5', mode='w')
data_dict[scan_num] = data_dict[scan_num].update_output_file(output_file)
data_dict[scan_num].save(mode='overwrite')

In [None]:
# Reload only the whitefield, without applying a crop.
scan_num = 250

file = cbc.CXIStore(f'results/scan_{scan_num:d}_whitefield.h5')
data_dict[scan_num] = cbc.CrystData(file, transform=None)
data_dict[scan_num] = data_dict[scan_num].load('whitefield')

In [None]:
# Lysozyme crystal, 2s exposure, 721 frames, full rotation tilt series
scan_num = 232

h5_dir = os.path.join(dir_path, f'scan_frames/Scan_{scan_num:d}')
h5_files = sorted([os.path.join(h5_dir, path) for path in os.listdir(h5_dir)
                   if path.endswith(('LambdaFar.nxs', '.h5'))])
data_dict[scan_num] = cbc.CrystData(cbc.CXIStore(h5_files, 'r'), crop)

In [None]:
# Preprocess the first 100 frames using the scan-250 whitefield loaded above.
scan_num = 232

data_dict[scan_num] = data_dict[scan_num].clear().load(idxs=np.arange(100), processes=12)
data_dict[scan_num] = data_dict[scan_num].update_mask(method='range-bad', vmax=10000000)
data_dict[scan_num] = data_dict[scan_num].mask_pupil(scan_setup, padding=60)
data_dict[scan_num] = data_dict[scan_num].import_whitefield(data_dict[250].whitefield)
data_dict[scan_num] = data_dict[scan_num].blur_pupil(scan_setup, padding=80, blur=20)

In [None]:
%matplotlib widget
# Background (left) and corrected data (right) of loaded frame `idx`.
idx = 4

fig, ax = plt.subplots(1, 2, figsize=(8, 4))
ax[0].imshow(data_dict[scan_num].background[idx], vmax=50)
ax[1].imshow(data_dict[scan_num].cor_data[idx], vmin=0.5, vmax=5)
plt.tight_layout()
plt.show()

In [None]:
# Streak detector for the preprocessed scan-232 frames.
# NOTE(review): `generate_pattern` here vs `generate_patterns` at the top of the
# notebook — the cbclib API appears to have changed between these cells.
scan_num = 232

det_obj = data_dict[scan_num].lsd_detector()
det_obj = det_obj.generate_pattern(vmin=0.5, vmax=5.0, size=(1, 3, 3))
det_obj = det_obj.update_lsd(quant=2.0e-2, ang_th=45.0)

In [None]:
%time det_res = det_obj.detect(cutoff=60.0, filter_threshold=0.1, group_threshold=0.5, \
                               dilation=0.5)

In [None]:
%time det_res = det_res.draw_streaks()

In [None]:
%time det_res = det_res.draw_background()

In [None]:
# %time det_res = det_res.update_pattern()

In [None]:
# Detected streaks (red) over loaded frame `frame_idx`.
frame_idx = 80

fig, ax = plt.subplots(figsize=(6, 6))
# ax.imshow(det_obj.pattern[frame_idx], vmin=0.0, vmax=1.0, cmap='gray_r')
ax.imshow(det_res.data[frame_idx], vmin=0.5, vmax=5.0, cmap='gray_r')
for line in det_res.streaks[frame_idx].to_numpy():
    ax.plot(line[:4:2], line[1:4:2], c='r', alpha=0.6)
ax.set_xlim(0, det_obj.data.shape[2])
ax.set_ylim(det_obj.data.shape[1], 0)
plt.tight_layout()
plt.show()

In [None]:
# Detected streaks of all loaded frames as a single table.
df = det_res.export_table(concatenate=True)
# df.to_hdf(f'results/scan_232_test.h5', 'data')

In [None]:
# Reload the scan-250 whitefield (uncropped load of the already-cropped file).
scan_num = 250

file = cbc.CXIStore(f'results/scan_{scan_num:d}_whitefield.h5')
data_dict[scan_num] = cbc.CrystData(file, transform=None)
data_dict[scan_num] = data_dict[scan_num].load('whitefield', processes=2)

In [None]:
# Streak detection over all 721 frames of scan 232, processed in 10 chunks to
# bound memory. Each chunk goes through the same preprocessing as above, then
# detection, and its table is collected.
# NOTE(review): uses the older get_detector / generate_streak_data API.
scan_num = 232
data = cbc.converter_petra(dir_path, scan_num, transform=crop,
                           idxs=[])

tables = []
for idxs in tqdm(np.array_split(np.arange(721), 10), total=10, desc=f'Processing scan {scan_num:d}'):
    data = data.clear().load(idxs=idxs, processes=12)
    data = data.update_mask(method='range-bad', vmax=10000000)
    data = data.mask_pupil(scan_setup, padding=60)
    data = data.import_whitefield(data_dict[250].whitefield)
    data = data.blur_pupil(scan_setup, padding=80, blur=20)
    det_obj = data.get_detector()
    det_obj = det_obj.generate_streak_data(vmin=0.5, vmax=5., size=(1, 3, 3))
    det_obj = det_obj.update_lsd(quant=2.0e-2)
    det_res = det_obj.detect(cutoff=70.0, filter_threshold=32.5, group_threshold=0.7)
    det_res = det_res.generate_bgd_mask()
    det_res = det_res.update_streak_data()
    tables.append(det_res.export_table(concatenate=True))

In [None]:
# NOTE(review): pd.concat keeps each chunk's index labels (possible duplicates);
# pass ignore_index=True if a unique index is needed downstream.
table = pd.concat(tables)
table.to_hdf('results/scan_232_indexing.h5', 'data')

# Streak detection, May 2022

In [None]:
# Shared settings for the May 2022 cells; `whitefield` is the cropped scan-206
# whitefield.
dir_path = '/asap3/petra3/gpfs/p11/2022/data/11012881/raw'
crop = cbc.Crop(roi=(800, 3900, 800, 3600))
scan_setup = cbc.ScanSetup.import_ini('results/exp_geom_206.ini')
whitefield = crop.forward(np.load('results/scan_206_whitefield.npy'))
data_dict = {}

In [None]:
# Scan 173: open the file index and load it uncropped with the 'range-bad' mask.
scan_num = 173

data_dict[scan_num] = cbc.converter_petra(dir_path, scan_num,
                                          idxs=[])
data_dict[scan_num] = data_dict[scan_num].load(processes=12)
data_dict[scan_num] = data_dict[scan_num].update_mask(method='range-bad', vmax=10000000)

In [None]:
# Lysozyme crystal, 0.5s exposure, 18000 frames, full rotation tilt series
scan_num = 207

data_dict[scan_num] = cbc.converter_petra(dir_path, scan_num, transform=crop, idxs=[])

In [None]:
# Replaces the cropped scan-206 whitefield above with the uncropped one of this
# scan; it is cropped via transform.forward when imported below.
whitefield = np.load(f'results/scan_{scan_num:d}_whitefield.npy')

In [None]:
data_dict[scan_num].input_file.indices()

In [None]:
# Last of 100 equal chunks of the scan's frame indices.
idxs = np.array_split(data_dict[scan_num].input_file.indices(), 100)[99]

In [None]:
# Preprocess the selected chunk.
scan_num = 207

data_dict[scan_num] = data_dict[scan_num].clear().load(idxs=idxs, processes=16)
data_dict[scan_num] = data_dict[scan_num].update_mask(method='range-bad', vmax=40)
data_dict[scan_num] = data_dict[scan_num].mask_pupil(scan_setup, padding=60)
data_dict[scan_num] = data_dict[scan_num].import_whitefield(data_dict[scan_num].transform.forward(whitefield))
data_dict[scan_num] = data_dict[scan_num].blur_pupil(scan_setup, padding=80, blur=20)

In [None]:
%matplotlib widget
# Corrected data of position `idx` within the loaded chunk.
idx = 104

fig, ax = plt.subplots(1, 1, figsize=(8, 8))
ax.imshow(data_dict[scan_num].cor_data[idx], vmax=3)
plt.tight_layout()
plt.show()

In [None]:
# Streak detector for the loaded chunk (older get_detector API).
det_obj = data_dict[scan_num].get_detector()
det_obj = det_obj.generate_streak_data(vmin=0.3, vmax=2.0, size=(1, 3, 3))
det_obj = det_obj.update_lsd(quant=1.8e-2)

In [None]:
%time det_res = det_obj.detect(cutoff=70.0, filter_threshold=6.0, group_threshold=0.4)

In [None]:
%time det_res = det_res.generate_bgd_mask()

In [None]:
%time det_res = det_res.update_streak_data()

In [None]:
# Streak-detection input image of the first loaded frame.
frame_idx = 0

fig, ax = plt.subplots(figsize=(6, 6))
# ax.imshow(det_res.data[frame_idx], vmin=0.2, vmax=2.0, cmap='gray_r')
ax.imshow(det_obj.streak_data[frame_idx], vmax=1.0, cmap='gray_r')
# for line in det_res.streaks[det_res.frames[frame_idx]].to_numpy():
#     ax.plot(line[:4:2], line[1:4:2], c='r', alpha=0.4)
ax.set_xlim(0, det_obj.data.shape[2])
ax.set_ylim(det_obj.data.shape[1], 0)
plt.tight_layout()
plt.show()

In [None]:
# Compare the indexed streak pattern of scan 206 with a fresh detection.
scan_setup = cbc.ScanSetup.import_ini('results/exp_geom_206.ini')
streaks = cbc.ScanStreaks.import_hdf('results/scan_206_indexing.h5', 'data', scan_setup)
data = cbc.CrystData(cbc.CXIStore('results/scan_206_data.h5'))

In [None]:
# NOTE(review): the indexed pattern is rendered for frame 460, but frame 160 is
# loaded, and `det_res.streaks[frame_idx]` below uses 460 — confirm the
# frame-number mapping between the two files.
frame_idx = 460

pattern = streaks.pattern_image(frame_idx)
data = data.clear().load(idxs=np.array([160]))

In [None]:
det_obj = data.get_detector()
det_obj = det_obj.generate_streak_data(vmin=0.3, vmax=2.0, size=(1, 3, 3))
det_obj = det_obj.update_lsd(quant=1.8e-2)

In [None]:
%time det_res = det_obj.detect(cutoff=70.0, filter_threshold=6.0, group_threshold=0.4)
%time det_res = det_res.generate_bgd_mask()
%time det_res = det_res.update_streak_data()

In [None]:
%matplotlib widget
# Left: indexed pattern; right: corrected data with detected streaks (red).
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].imshow(pattern, vmax=0.3, cmap='gray_r')
axes[1].imshow(data['cor_data'][0], vmax=2.0, cmap='gray_r')
for line in det_res.streaks[frame_idx].to_numpy():
    axes[1].plot(line[:4:2], line[1:4:2], c='r', alpha=0.4)
plt.tight_layout()
plt.show()

# CBC indexing

In [None]:
# Inputs for indexing scan 232 (detected table, initial basis, trial samples).
scan_setup = cbc.ScanSetup.import_ini('results/exp_geom_232.ini')
table = cbc.CBCTable.import_hdf('results/scan_232_detected.h5', 'data', scan_setup)
crop = cbc.Crop(roi=(1100, 3260, 1040, 3108))
basis = cbc.Basis.import_ini('results/scan_232_basis.ini')
samples = cbc.ScanSamples.import_dataframe(pd.read_hdf('results/scan_232_samples_trial_4.h5', 'data'))

In [None]:
# Index every frame at its starting sample (problem.x0) and collect the tables.
dataframes = []
for frame, sample in tqdm(samples.items(), total=len(samples)):
    problem = table.refine_sample(frame, (0.0, 0.0, 0.0), basis, sample, 0.5, 4.0)
    dataframes.append(problem.index_frame(problem.x0, frame, num_threads=64))

In [None]:
table_idx = pd.concat(dataframes)
table_idx.to_hdf('results/scan_232_indexed.h5', 'data')

In [None]:
%matplotlib widget
# Detected-streak pattern image of frame 62.
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(table.pattern_image(62))
plt.show()

In [None]:
# Regular q-space grid: axis i has qsize[i] points spanning
# [-q_bounds[i], q_bounds[i]], endpoints included.
qsize = (512, 512, 512)
q_bounds = (0.25, 0.25, 0.25)

qx_arr, qy_arr, qz_arr = (np.linspace(-bound, bound, size, endpoint=True)
                          for bound, size in zip(q_bounds, qsize))
# Map the detected streaks onto the (qx_arr, qy_arr, qz_arr) grid.
qmap = table.create_qmap(samples, qx_arr, qy_arr, qz_arr)

In [None]:
%matplotlib widget
# Mean of the 10 central slices of the q map along its second axis.
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(qmap.val[:, qmap.shape[1] // 2 - 5:qmap.shape[1] // 2 + 5].mean(axis=1), vmax=10.0, cmap='gray_r')
plt.tight_layout()
plt.show()

In [None]:
# FFT of the q map; peaks are searched in this map further below.
rmap = qmap.fft(num_threads=64)

In [None]:
# Hot-reload the local `dev` module: import it on first use, otherwise reload
# the already-imported module so edits are picked up without a kernel restart.
if sys.modules.get('dev') is None:
    import dev
else:
    dev = reload(sys.modules['dev'])
print(dir(dev))

In [None]:
# NOTE(review): `wdw` is not used by the cells that follow.
%time wdw = dev.filter_direction(rmap.grid, [0.0, 1.0, 0.0], 5.0, 1.0, num_threads=64)

In [None]:
# Recompute the FFT (same call as above), then apply a Gaussian blur (1.0).
%time rmap = qmap.fft(num_threads=64)
%time rmap = rmap.gaussian_blur(1.0, num_threads=64)

In [None]:
# White top-hat with a 9x9x9 block of ones, then clip negative values to 0.
%time rmap = rmap.white_tophat(np.ones((9, 9, 9)), num_threads=64).clip(0.0, np.inf)

In [None]:
# Normalise in place to a maximum of 1; re-running divides by 1 (no-op).
%time rmap.val /= np.max(rmap.val)

In [None]:
# Peaks of the normalised map above 0.1: all of them, and the reduced set
# (reduce=True) used as the lattice basis estimate.
# NOTE(review): this rebinds `basis`, shadowing the cbc.Basis object loaded
# at the top of this section.
peaks = rmap.find_peaks(0.1)
basis = rmap.find_peaks(0.1, reduce=True)

In [None]:
# 3D scatter of all peaks and of the basis vectors (x/y use columns 1/0).
fig = go.Figure()

fig.add_trace(go.Scatter3d(x=peaks[:, 1], y=peaks[:, 0], z=peaks[:, 2], mode='markers',
                           marker=dict(size=6)))
fig.add_trace(go.Scatter3d(x=basis[:, 1], y=basis[:, 0], z=basis[:, 2], mode='markers',
                           marker=dict(size=8)))

fig.update_layout(width=800, height=800, margin=dict(l=20, r=20, b=20, t=20))
fig.show()

In [None]:
# Every 8th voxel along each axis, to make the fit cheaper.
rfit = rmap[::8, ::8, ::8]

In [None]:
# Refine the basis within +/-10 % of the estimate with L-BFGS-B. jac=True means
# `criterion` returns (value, gradient). np.sort keeps lower <= upper for
# negative components. The meaning of `args` is defined by cbclib.
x0 = basis.ravel()
bounds = np.sort(np.stack((0.9 * x0, 1.1 * x0), axis=1), axis=1)
res = minimize(rfit.criterion, x0, jac=True, method='L-BFGS-B', bounds=bounds,
               args=(100.0, 5.0, 1e-12, 64))

In [None]:
# Lattice basis in meters
# NOTE(review): uses the initial `basis` estimate, not the refined `res.x`.
basis * scan_setup.wavelength

In [None]:
# Volume rendering of the fitted map with the refined basis end points.
# NOTE(review): rx_fit, ry_fit, rz_fit and r_fit are not defined anywhere in
# this notebook, so this cell fails on a fresh kernel.
fig = go.Figure()

fig.add_trace(go.Volume(x=rx_fit.flatten(),
                        y=ry_fit.flatten(),
                        z=rz_fit.flatten(),
                        value=r_fit.flatten(),
                        isomin=1e-5, isomax=1e-3,
                        opacity=0.1, # needs to be small to see through all surfaces
                        surface_count=20))
# fig.add_trace(go.Scatter3d(x=axes[:, 1], y=axes[:, 0], z=axes[:, 2], mode='markers'))
fig.add_trace(go.Scatter3d(x=res.x[1::3], y=res.x[::3], z=res.x[2::3], mode='markers'))

fig.update_layout(width=800, height=800, margin=dict(l=20, r=20, b=20, t=20))
fig.show()

# CBC Model

In [None]:
# Read the scan-232 server log and its translations.
scan_num = 232

dir_path = '/asap3/petra3/gpfs/p11/2021/data/11010570/raw'
log_path = os.path.join(dir_path, f'server_log/Scan_logs/Scan_{scan_num:d}.log')
log_data = cbc.LogContainer().read_logs(log_path).read_translations()

In [None]:
# Per-frame samples generated from the log for the given sample position
# (replaced by the samples loaded from file in the next cell).
scan_setup = cbc.ScanSetup.import_ini('results/exp_geom_232.ini')
smp_pos = np.array([0.14343635, 0.16275142, 0.38192596])
samples = log_data.generate_samples(smp_pos, scan_setup)
# samples.to_dataframe().to_hdf('results/scan_232_samples_trial_3.h5', 'data')

In [None]:
scan_setup = cbc.ScanSetup.import_ini('results/exp_geom_232.ini')
table = cbc.CBCTable.import_hdf('results/scan_232_detected.h5', 'data', scan_setup)
crop = cbc.Crop(roi=(1100, 3260, 1040, 3108))
basis = cbc.Basis.import_ini('results/scan_232_basis.ini')
samples = cbc.ScanSamples.import_dataframe(pd.read_hdf('results/scan_232_samples_trial_4.h5', 'data'))

In [None]:
# Sample-refinement problem for a single frame (arguments as defined by
# cbclib's CBCTable.refine_sample).
frame_idx = 0

problem = table.refine_sample(frame_idx, (0.0, 0.0, 0.0), basis, samples[frame_idx], 0.3, 4.0, 0.0)

In [None]:
# Regularised copy of the samples; (5e0, 5e1) is passed to cbclib's `regularise`.
samples_reg = samples.regularise((5e0, 5e1))

In [None]:
# Raw frames of scan 232 and the scan-250 whitefield.
# NOTE(review): unlike the other sections, the whitefield is loaded without
# transform=crop here.
scan_num = 232
dir_path = '/asap3/petra3/gpfs/p11/2021/data/11010570/raw'
h5_dir = os.path.join(dir_path, f'scan_frames/Scan_{scan_num:d}')
h5_files = sorted([os.path.join(h5_dir, path) for path in os.listdir(h5_dir)
                   if path.endswith(('LambdaFar.nxs', '.h5'))])
data = cbc.CrystData(cbc.CXIStore(h5_files), transform=crop)

wf_data = cbc.CrystData(cbc.CXIStore(f'results/scan_250_whitefield.h5')).load('whitefield')

In [None]:
# Tilt representation of each sample rotation (three components per frame).
tilts = np.array([sample.rotation.to_tilt() for sample in samples.values()])

In [None]:
%matplotlib widget
# The three tilt components vs frame.
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(tilts[:, 0])
ax.plot(tilts[:, 1])
ax.plot(tilts[:, 2])
plt.show()

In [None]:
%matplotlib widget
# Position component 1 before and after regularisation.
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(samples.get_positions(1))
ax.plot(samples_reg.get_positions(1))
ax.grid(True)
# ax.plot(samples.get_positions(1) / samples.get_positions(1)[0])
# ax.plot(samples.get_positions(2) / samples.get_positions(2)[0])

plt.show()

In [None]:
# Single-frame problem again, with non-zero values in the second argument and
# different trailing parameters (see cbclib's CBCTable.refine_sample).
problem = table.refine_sample(frame_idx, (2e-2, 2e-3, 0.0), basis, samples[frame_idx],
                              0.3, 3.0, 0.5)

In [None]:
def criterion(x, problem):
    """Scalar objective for SciPy: the first entry of ``problem.fitness(x)``.

    pygmo-style ``fitness`` returns a sequence of objective values, while
    ``differential_evolution`` needs a single float.
    """
    fitness_values = problem.fitness(x)
    return fitness_values[0]

# Global optimisation with SciPy's differential evolution over the problem's
# bounds; np.stack(get_bounds()).T gives one (lower, upper) row per parameter.
# With workers=64, candidates are evaluated in parallel, which needs
# updating='deferred' and a picklable `problem`.
res = differential_evolution(criterion, bounds=np.stack(problem.get_bounds()).T,
                             maxiter=100, popsize=25, workers=64, updating='deferred',
                             args=(problem,))

In [None]:
# Same problem with pygmo: 100 generations of DE on 32 islands of 30
# individuals, initial fitness evaluated with the batch fitness evaluator.
uda = pygmo.de(gen=100)
algo = pygmo.algorithm(uda)
prob = pygmo.problem(problem)
pops = [pygmo.population(size=30, prob=prob, b=pygmo.bfe()) for _ in range(32)]
archi = pygmo.archipelago()
for pop in pops:
    archi.push_back(algo=algo, pop=pop)

In [None]:
# Evolve, wait for all islands, and keep the best champion.
archi.evolve()
%time archi.wait()
x = archi.get_champions_x()[np.argmin(archi.get_champions_f())]

# CBC Geometry calibration

In [None]:
# Experimental geometry

log_prt = cbc.LogProtocol.import_default()

foc_z = -0.012028122 # From PXST
dir_path = '/asap3/petra3/gpfs/p11/2021/data/11010570/raw'
h5_dir = os.path.join(dir_path, 'scan_frames/Scan_253')
log_path = os.path.join(dir_path, 'server_log/Scan_logs/Scan_253.log')
# Data files of scan 253, excluding the master file (not used by later cells
# in this section).
h5_files = sorted([os.path.join(h5_dir, path) for path in os.listdir(h5_dir)
                   if path.endswith(('LambdaFar.nxs', '.h5')) and not path.endswith('master.h5')])
# First *master.h5 in the directory (IndexError if there is none).
h5_master = [os.path.join(h5_dir, path) for path in os.listdir(h5_dir)
             if path.endswith('master.h5')][0]

# Defocus = logged sample z minus the focus z; pixel sizes from the master file.
defocus = log_prt.load_attributes(log_path)['Session logged attributes']['z_sample'] - foc_z
file = cbc.CXIStore(h5_master)
with file:
    x_pixel_size = file.load_attribute('x_pixel_size').item()
    y_pixel_size = file.load_attribute('y_pixel_size').item()
wl = 2.0664032162696132e-11 # 60 keV

In [None]:
# Pupil bounding boxes as (row_min, row_max, col_min, col_max) and beam
# positions `db` as (row, col), in detector pixels, for scans 253 and 254
# (ordering as used by the arithmetic below).
pupil = np.array([2210, 2345, 1957, 2090])
pupil_2 = np.array([2214, 2400, 1974, 2155])
db = np.array([2171, 1913])
db_2 = np.array([2159, 1913])

In [None]:
# The detector is 0.15 further away in scan 254 than in 253, and 0.718 is the
# distance ratio d_253 / d_254, so d_254 = 0.15 / (1 - 0.718).
det_dist_2 = 0.15 / (1.0 - 0.718) # sample-to-detector distance for scan 254
det_dist = det_dist_2 - 0.15 # sample-to-detector distance for scan 253

In [None]:
# Centre = db + per-axis mean of the pupil edges relative to db, scaled by
# defocus / detector distance (row, then column).
center = db + np.mean(((pupil - np.repeat(db, 2)) * defocus / det_dist).reshape((2, 2)), axis=1)
center_2 = db_2 + np.mean(((pupil_2 - np.repeat(db_2, 2)) * defocus / det_dist_2).reshape((2, 2)), axis=1)
tilt_shift = np.arange(4362)**2 / 2159**2 * 12.0 # Y_253 = Y_254 + tilt_shift[Y_254]
# Row of the centre from scan 253 vs scan 254 after the tilt correction.
center[0], center_2[0] + tilt_shift[int(center_2[0])]

In [None]:
# Incoming-beam direction towards the (row_min, col_min) pupil corner: polar
# angle from the in-plane distance to `db` over det_dist, azimuth from
# atan2(dy, dx), expressed as a unit vector (x, y, z).
delta_x = (pupil[2] - db[1]) * x_pixel_size
delta_y = (pupil[0] - db[0]) * y_pixel_size
phis = np.arctan2(delta_y, delta_x)
thetas = np.arctan(np.sqrt(delta_x**2 + delta_y**2) / det_dist)
kin_min = np.stack((np.sin(thetas) * np.cos(phis), np.sin(thetas) * np.sin(phis), np.cos(thetas)))

In [None]:
# Same for the opposite (row_max, col_max) pupil corner.
delta_x = (pupil[3] - db[1]) * x_pixel_size
delta_y = (pupil[1] - db[0]) * y_pixel_size
phis = np.arctan2(delta_y, delta_x)
thetas = np.arctan(np.sqrt(delta_x**2 + delta_y**2) / det_dist)
kin_max = np.stack((np.sin(thetas) * np.cos(phis), np.sin(thetas) * np.sin(phis), np.cos(thetas)))

In [None]:
# Sample at the pupil centre (pixel index x pixel size) at z = det_dist; focus
# above the beam position, `defocus` further along z; rotation about y.
smp_pos = np.array([center[1] * x_pixel_size,
                    center[0] * y_pixel_size, det_dist])
foc_pos = np.array([db[1] * x_pixel_size, db[0] * y_pixel_size, det_dist + defocus])
scan_setup = cbc.ScanSetup(smp_pos=smp_pos, foc_pos=foc_pos, rot_axis=np.array([0.0, 1.0, 0.0]),
                           kin_min=kin_min, kin_max=kin_max, wavelength=wl, x_pixel_size=x_pixel_size,
                           y_pixel_size=y_pixel_size)

In [None]:
with open('results/exp_geom_248_253.ini', 'w') as ini_file:
    scan_setup.export_ini().write(ini_file)

In [None]:
# Same construction for scan 254 (no kin_min / kin_max given here).
smp_pos_2 = np.array([center_2[1] * x_pixel_size,
                      center_2[0] * y_pixel_size, det_dist_2])
foc_pos_2 = np.array([db_2[1] * x_pixel_size, db_2[0] * y_pixel_size, det_dist_2 + defocus])
scan_setup_2 = cbc.ScanSetup(smp_pos=smp_pos_2, foc_pos=foc_pos_2, rot_axis=np.array([0.0, 1.0, 0.0]),
                             wavelength=wl, x_pixel_size=x_pixel_size, y_pixel_size=y_pixel_size)

In [None]:
# Whitefield and mask used by the calibration cells below.
# NOTE(review): `cxi_loader` is never imported in this notebook, so this cell
# fails on a fresh kernel; it looks like a leftover from an older loader API.
with h5py.File('results/scan_250_proc.h5', 'r') as cxi_file:
    whitefield = cxi_loader.read_cxi('whitefield', cxi_file)
    mask = cxi_loader.read_cxi('mask', cxi_file)[0]

In [None]:
# The calibration cells use the older converter_petra(pupil=, roi=, ...) API.
scan_num = 224 # Det dist = 20 cm
data = cbc.converter_petra(dir_path, scan_num, mask=mask,
                           pupil=(2200, 2370, 1900, 2110),
                           roi=(1100, 3260, 1040, 3108),
                           whitefield=whitefield)

In [None]:
scan_num = 225 # Det dist = 40 cm
data = cbc.converter_petra(dir_path, scan_num, mask=mask,
                           pupil=(2200, 2370, 1900, 2110),
                           roi=(1100, 3260, 1040, 3108),
                           whitefield=whitefield)

In [None]:
# Sum all frames of scan 253 into one frame (keepdims keeps a frame axis of
# length 1) and keep only the first good-frame entry. The whitefield is scaled
# by 3, presumably to match the number of summed exposures — confirm.
scan_num = 253 # Det dist = 40 cm

data = cbc.converter_petra(dir_path, scan_num,
                           pupil=(2200, 2370, 1900, 2110),
                           roi=(1100, 3260, 1040, 3108))
data.data = np.sum(data.data, axis=0, keepdims=True)
data.good_frames = data.good_frames[[0]]
data = data.import_mask(mask, update='multiply')
data = data.import_whitefield(3.0 * whitefield)

In [None]:
det_obj = cbc.StreakDetector.import_data(data, 0.2, 3., (1, 3, 3))
det_obj = det_obj.update_lsd(scale=0.9, sigma_scale=0.9, log_eps=0, ang_th=60, density_th=0.5, quant=1.5e-2)
det_obj = det_obj.update_mask(dilation=8)
det_obj = det_obj.update_streak_data(iterations=4)

In [None]:
# Whitefield and mask shifted by (29, 41) pixels; the combined mask
# mask * mask_2 is imported for scan 254.
# NOTE(review): `whitefield_2` is computed but the unshifted `whitefield` is
# imported below — confirm which one is intended.
scan_num = 254 # Det dist = 55 cm
whitefield_2 = np.zeros(whitefield.shape, dtype=np.float64)
whitefield_2[29:, 41:] = whitefield[:-29, :-41]
mask_2 = np.ones(mask.shape, dtype=bool)
mask_2[29:, 41:] = mask[:-29, :-41]

data_2 = cbc.converter_petra(dir_path, scan_num,
                             pupil=(2140, 2450, 1900, 2200),
                             roi=(700, 3660, 640, 3508))

data_2.data = np.sum(data_2.data, axis=0, keepdims=True)
data_2.good_frames = data_2.good_frames[[0]]
data_2 = data_2.import_mask(mask * mask_2, update='multiply')
data_2 = data_2.import_whitefield(3.0 * whitefield)

In [None]:
det_obj_2 = cbc.StreakDetector.import_data(data_2, 0.2, 3., (1, 3, 3))
det_obj_2 = det_obj_2.update_lsd(quant=2.8e-2)
det_obj_2 = det_obj_2.update_mask(dilation=8)
det_obj_2 = det_obj_2.update_streak_data(iterations=4)

In [None]:
%matplotlib notebook
frame_idx = 0
scale = 0.718


fig, ax = plt.subplots(1, 1, figsize=(7, 7))
ax.imshow(data.get('cor_data')[frame_idx], vmin=0.2, vmax=3., cmap='gray_r')
for line in det_obj.streaks[frame_idx]:
    ln = ax.plot(line[:4:2], line[1:4:2], line[4], c='r', alpha=0.6)[0]
ln.set_label('Scan 253')
for line in det_obj_2.streaks[frame_idx]:
    ln = ax.plot((line[:4:2] + data_2.roi[2] - center[1]) * scale + center[1] - data.roi[2],
                 (line[1:4:2] + data_2.roi[0] - center[0]) * scale + center[0] - data.roi[0]
                 + tilt_shift[line[1:4:2].astype(int)], line[4], c='g', alpha=0.6)[0]
ln.set_label('Scan 254, with a tilt')
for line in det_obj_2.streaks[frame_idx]:
    ln = ax.plot((line[:4:2] + data_2.roi[2] - center[1]) * scale + center[1] - data.roi[2],
                 (line[1:4:2] + data_2.roi[0] - center[0]) * scale + center[0] - data.roi[0],
                 line[4], c='b', alpha=0.3)[0]
ln.set_label('Scan 254, no tilt')
ax.scatter(1913 - data.roi[2], 2159 + tilt_shift[2159] - data.roi[0], s=50, c='r')
ax.set_xlim(0, data.roi[3] - data.roi[2])
ax.set_ylim(data.roi[1] - data.roi[0], 0)
ax.legend(fontsize=15)
plt.tight_layout()
plt.savefig('figures/scan_253_254.jpg', dpi=300)

In [None]:
%matplotlib notebook

fig, ax = plt.subplots(1, 2, figsize=(9, 5))
ax[0].imshow(data.data[2, 2140:2450, 1900:2200], vmax=500)
ax[1].imshow(data_2.data[2, 2140:2450, 1900:2200], vmax=500)
plt.show()

smp_pos = [0.15 / (1.0 - 0.718)

In [None]:
%matplotlib notebook

# Zoom on the pupil boxes of the summed scan-253 and scan-254 frames.
# NOTE(review): the first slice ends at row 2344 while `pupil` ends at 2345;
# the second slice matches `pupil_2` exactly.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(data.data[0, 2210:2344, 1957:2090], vmax=300000)
ax[1].imshow(data_2.data[0, 2214:2400, 1974:2155], vmax=150000)
plt.show()