# import packages

In [None]:
import pathlib
import numpy as np
import xarray as xr
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.pyplot import subplots
from scipy.signal import convolve
from scipy.ndimage import gaussian_filter
from mpl_toolkits.mplot3d import Axes3D
from numpy.fft import fftn, fftshift
from tqdm.auto import tqdm
import os

import dask
import dask.array as da
from dask_image.ndfilters import convolve
from dask.diagnostics import ProgressBar
import flox
import flox.xarray
from flox.xarray import xarray_reduce
import xrft

%matplotlib widget

In [None]:
from dask.distributed import Client

# Connect to an already-running dask scheduler at a fixed local address.
# NOTE(review): the port (46749) is hard-coded; it must match the scheduler that
# is running for the current session (presumably started outside this notebook).
client = Client("tcp://127.0.0.1:46749")
client

# import local files

In [None]:
from ptable_dict import ptable, atomic_masses
from utilities import write_xyz, load_xyz, rotation_matrix, gaussian_kernel
from meshgrids import generate_density_grid, convert_grid_qspace, plot_3D_grid, generate_electron_grid_npys, load_npy_files_to_dask
from detector import make_detector, rotate_about_normal, rotate_about_horizontal, rotate_about_vertical, intersect_detector

# Generate and plot real-space voxel map for xyz file

In [None]:
# Input structure and output directory for the density-grid segment files.
# Portable alternative: build the xyz path from the working directory, e.g.
# f'{os.getcwd()}/test_xyz_files/mercury_line.xyz'.
basePath = pathlib.Path('/nsls2/users/alevin/repos/giwaxs_forward_sim')
xyzPath = basePath / 'test_xyz_files/graphite_small.xyz'
npySavePath = pathlib.Path('/nsls2/users/alevin/misc_data/density_grid_segments/graphite_small')

# Grid-generation parameters.
sigma = 0.2
voxel_size = 0.05
min_ax_size = 1024
segments = 8  # number of segments along x

# Presumably writes the segmented grid as .npy files into npySavePath (they are
# globbed back from there below); returns the axes and voxel counts.
(x_axis, y_axis, z_axis,
 grid_vox_x, grid_vox_y, grid_vox_z) = generate_electron_grid_npys(
    xyzPath, voxel_size, segments, npySavePath, sigma, min_ax_size=min_ax_size)

# Below loads the numpy array stacks into a dask array.
# This loader lives in the notebook rather than in the separate Python script
# because moving it there previously caused module-import errors.
def load_array_from_npy_stack(npy_paths, axis=1):
    """Load a sequence of ``.npy`` segment files and join them into one array.

    Parameters
    ----------
    npy_paths : iterable of str or pathlib.Path
        Paths to the segment files, in the order they should be joined.
    axis : int, optional
        Axis along which the segments are concatenated. Defaults to 1, the
        ``x`` axis of the ``(y, x, z)`` density grid built in this notebook.

    Returns
    -------
    numpy.ndarray
        The segments concatenated along ``axis``.

    Raises
    ------
    ValueError
        If ``npy_paths`` is empty (nothing to concatenate).
    """
    arrs = [np.load(npy_path) for npy_path in npy_paths]
    if not arrs:
        # np.concatenate would also raise ValueError, but with a less helpful message.
        raise ValueError('load_array_from_npy_stack: no .npy paths were given')
    return np.concatenate(arrs, axis=axis)

# Discover the saved segment files.
# NOTE(review): sorted() orders the files lexicographically by name; this only
# matches segment order if the names sort correctly (e.g. zero-padded indices)
# — confirm against generate_electron_grid_npys.
npy_paths = sorted(npySavePath.glob('*.npy'))
# Wrap the whole load in one delayed task. Dask cannot inspect the delayed
# result, so the final shape and dtype must be declared up front.
# NOTE(review): dtype=float assumes the saved segments are float64 — confirm.
density_grid = dask.delayed(load_array_from_npy_stack)(npy_paths)
density_grid = dask.array.from_delayed(density_grid, shape=(grid_vox_y, grid_vox_x, grid_vox_z), dtype=float)
# Split x into one chunk per segment, using the `segments` setting rather than
# a hard-coded count; y and z stay whole so FFTs over them need no rechunk.
density_grid = density_grid.rechunk((grid_vox_y, max(1, grid_vox_x // segments), grid_vox_z))

density_grid = density_grid.persist()
density_grid

In [None]:
# Disabled exploratory cell: optional Gaussian smoothing of the density grid via
# a dask-image convolution (kept for reference).
# # Create a Gaussian kernel
# sigma_voxel = sigma/voxel_size
# # kernel_size = 6 * sigma_voxel + 1  # Ensure the kernel size covers enough of the Gaussian
# kernel_size = density_grid.shape[0]
# gaussian_kernel_3d = da.from_array(gaussian_kernel(kernel_size, sigma_voxel))
# # convolve gaussian with 
# # density_grid = convolve(density_grid, gaussian_kernel_3d, mode='reflect')
    
# gaussian_kernel_3d

In [None]:
# Disabled: density_grid is already persisted in the loading cell.
# # with ProgressBar():
# density_grid = density_grid.persist()

In [None]:
# Display the density grid dask array.
density_grid

In [None]:
%%time

# Wrap the dask-backed density grid in an xarray DataArray with labelled
# y / x / z dimensions and their real-space coordinate axes.
dens_grid_DA = xr.DataArray(data=density_grid,
                            dims=['y', 'x', 'z'],
                            coords={'y':y_axis,
                                    'x':x_axis,
                                    'z':z_axis})

# Disabled: density_grid is already chunked along x when it is loaded.
# # Dask-ify it
# num_chunks = 8
# dens_grid_DA = dens_grid_DA.chunk({'x':int(len(dens_grid_DA.x)/num_chunks)})  # chunk along just one dimension, for slab fftw(?)
dens_grid_DA

In [None]:
# Project the density onto the x-y plane by summing over z, and pull the result
# into memory.
summed_z_arr = dens_grid_DA.sum('z').compute()

In [None]:
# Keep only the underlying array (drop the xarray coordinates) for plotting.
summed_z_arr = summed_z_arr.data
summed_z_arr

In [None]:
plt.close('all')
# Smooth the projection with a Gaussian (sigma = 5 array elements) before display.
smeared_z_arr = gaussian_filter(summed_z_arr, sigma=5)
# Rows are y and columns are x (dens_grid_DA dims are ['y', 'x', 'z']).
plt.imshow(smeared_z_arr, origin='lower')
plt.show()

In [None]:
# Disabled exploratory cell: percentile-based colour limits.
# %%time
# dens_grid_DA.reduce(np.nanpercentile, q=0.1)

In [None]:
# Disabled exploratory cell.
# NOTE(review): da.percentile uses a 0-100 scale, so 0.99 here is the 0.99th
# percentile — 99 was probably intended.
# %%time
# lazy_clims = da.percentile(dens_grid_DA.data.ravel(), [0.1, 0.99])
# lazy_clims

In [None]:
%%time
# Lazily downsample to 128 bins per axis (mean within each bin); the binned
# grid is what the 3-D real-space plot uses.
lazy_binned_DA = dens_grid_DA.groupby_bins('x', 128).mean().groupby_bins('y',128).mean().groupby_bins('z',128).mean()

In [None]:
# Render the dask task graph of the binning.
lazy_binned_DA.data.visualize()

In [None]:
%%time
# Compute the binning now and keep the result in memory (on the workers when a
# distributed Client is active).
binned_DA = lazy_binned_DA.persist()

display(binned_DA)

In [None]:
# Replace the Interval bin labels produced by groupby_bins with their midpoints,
# giving each binned axis a plain numeric coordinate named x / y / z.
bin_midpoints = {
    dim: (f'{dim}_bins', np.array([iv.mid for iv in binned_DA[f'{dim}_bins'].data]))
    for dim in ('x', 'y', 'z')
}
binned_DA = (binned_DA
             .assign_coords(bin_midpoints)
             .swap_dims({f'{dim}_bins': dim for dim in ('x', 'y', 'z')}))
binned_DA

In [None]:
plt.close('all')
# Isosurface settings for the binned real-space density plot.
threshold, num_levels, cmap = 99.9, 10, 'plasma'
# Full-resolution (unbinned) alternative:
# fig, ax = plot_3D_grid(density_grid.compute(), x_axis, y_axis, z_axis, cmap, threshold, num_levels, log=True)
fig, ax = plot_3D_grid(
    binned_DA.data.compute(),
    binned_DA.x.data, binned_DA.y.data, binned_DA.z.data,
    cmap, threshold, num_levels,
    log=True,
)

plt.show()

# Generate and plot reciprocal space voxel map for xyz file

In [None]:
def xrft_iq(DA, num_chunks):
    """Return the scattering intensity ``|FFT(DA)|**2`` of a real-space density.

    The 3-D DFT is done in two passes so that each pass only transforms
    dimensions that are a single chunk:
      1. FFT over ``y`` and ``z`` (in this notebook the input is chunked only
         along ``x``).
      2. Rechunk so ``x`` is whole and ``freq_y`` is split into ``num_chunks``
         pieces, then FFT over ``x``.

    Parameters
    ----------
    DA : xarray.DataArray
        Real-space density with dims ``y``, ``x`` and ``z``.
    num_chunks : int
        Number of chunks ``freq_y`` is split into for the second pass.

    Returns
    -------
    xarray.DataArray
        ``|FFT(DA)|**2`` on xrft's frequency dims (``freq_x``, ``freq_y``,
        ``freq_z``).
    """
    fft_yz = xrft.fft(DA, dim=['y', 'z'])  # take dft in y & z direction
    # Chunk sizes must be integers (true division yields a float, which dask
    # rejects when not integral); clamp to >= 1 so num_chunks > len(y) still works.
    freq_y_chunk = max(1, len(DA.y) // num_chunks)
    fft_yz_rechunked = fft_yz.chunk({'freq_y': freq_y_chunk, 'x': len(DA.x)})  # rechunk along y direction
    fft_all = xrft.fft(fft_yz_rechunked, dim=['x'])  # take dft in x direction
    return np.abs(fft_all) ** 2

In [None]:
%%time
# Number of freq_y chunks used for the second FFT pass inside xrft_iq.
num_chunks = 8
# Build the lazy |FFT|^2 graph; it is computed by the persist cell below.
iq_DA = xrft_iq(dens_grid_DA, num_chunks)  #.compute()
iq_DA

In [None]:
# Render the dask task graph for the intensity computation.
iq_DA.data.visualize()

In [None]:
%%time 
# Compute the intensity and keep it in (cluster) memory.
iq_DA = iq_DA.persist()
iq_DA

In [None]:
%%time
fft_yz = xrft.fft(dens_grid_DA, dim=['y','z'])  # take dft in y & z direction
fft_yz_rechunked = fft_yz.chunk({'freq_y':int(len(dens_grid_DA.y))/8,'x':int(len(dens_grid_DA.x))})  # rechunk along y direction 
fft_all = xrft.fft(fft_yz_rechunked, dim=['x'])  # take dft in x direction
# with ProgressBar():

# fft_all = fft_all.persist()

fft_all

In [None]:
fft_yz.data.visualize()

In [None]:
fft_all.data.visualize()

In [None]:
fft_all = fft_all.compute()
fft_all.data.visualize()

In [None]:
iq_DA = np.abs(fft_all)**2
persisted_iq_DA = iq_DA.persist()
persisted_iq_DA

In [None]:
# Disabled alternative: segment-averaged FFT that would also rename the frequency
# axes to qx / qy / qz.
# fft_DA = xrft.fft(dens_grid_DA, chunks_to_segments=True).mean(['x_segment', 'y_segment', 'z_segment'])
# fft_DA = fft_DA.rename({'freq_y':'qy', 'freq_x':'qx', 'freq_z':'qz'})
# iq_DA = np.abs(fft_DA)**2

In [None]:
# %%time
# Pull the full intensity array into local memory.
# NOTE(review): dask.diagnostics.ProgressBar reports progress for dask's local
# schedulers; with the distributed Client connected at the top of the notebook it
# likely shows nothing — use the dashboard or distributed's progress() instead.
with ProgressBar():
    iq_DA = iq_DA.compute()

In [None]:
# `dens_grid` is not defined anywhere in this notebook; use the loaded
# `density_grid`, computed to an in-memory array.
# NOTE(review): assumes convert_grid_qspace expects a numpy array in the same
# (y, x, z) layout — confirm. This materializes the full grid in memory.
iq, qx, qy, qz = convert_grid_qspace(density_grid.compute(), x_axis, y_axis, z_axis)

In [None]:
plt.close('all')
# Isosurface settings for the reciprocal-space intensity plot.
threshold, num_levels, cmap = 99.9, 10, 'plasma'
fig, ax = plot_3D_grid(iq, qx, qy, qz, cmap, threshold, num_levels)

# Optional: zoom in on the central region of q-space.
# ax.set_xlim((-3,3))
# ax.set_ylim((-3,3))
# ax.set_zlim((-3,3))
plt.show()

In [None]:
del iq_DA

In [None]:
iq_DA

In [None]:
persited_iq_DA = iq_DA.persist()

In [None]:
plt.close('all')
threshold = 99.9
num_levels = 10
cmap = 'plasma'
# freq_* coordinates are scaled by 2*pi to convert to q (assumes xrft returns
# ordinary, fftfreq-style frequencies — confirm).
# `.values` works whether iq_DA is still dask-backed or already computed; a
# numpy-backed array (e.g. after the ProgressBar compute cell) has no .compute().
fig, ax = plot_3D_grid(iq_DA.values, iq_DA.freq_x.data*2*np.pi, iq_DA.freq_y.data*2*np.pi, iq_DA.freq_z.data*2*np.pi, cmap, threshold, num_levels)

plt.show()

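# find q-resolutions
### The frequency resolution (q-bin size) is the sampling rate (1/voxel_size) divided by the number of voxels along an axis, i.e. 1/(box length); in q = 2π·f units this is 2π/(box length). The box length is the full extent of the real-space grid (including any empty space around the molecule), not just the size of the molecule.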

In [None]:
# q-axis spacing of the grids returned by convert_grid_qspace, measured from the
# first two points of each axis (i.e. assuming uniform spacing).
x_vals, y_vals, z_vals = qx, qy, qz
qx_res, qy_res, qz_res = (vals[1] - vals[0] for vals in (x_vals, y_vals, z_vals))
print(f'Resolutions are [qx={qx_res:.4f}, qy={qy_res:.4f}, qz={qz_res:.4f}]')

In [None]:
# Same resolution check on the xrft frequency axes of iq_DA, scaled by 2*pi to q.
# xrft names the frequency dims freq_x / freq_y / freq_z (as the iq_DA plot cell
# uses); iq_DA has no qx / qy / qz coordinates.
x_vals = iq_DA.freq_x.data*2*np.pi
y_vals = iq_DA.freq_y.data*2*np.pi
z_vals = iq_DA.freq_z.data*2*np.pi
qx_res = x_vals[1]-x_vals[0]
qy_res = y_vals[1]-y_vals[0]
qz_res = z_vals[1]-z_vals[0]
print(f'Resolutions are [qx={qx_res:.4f}, qy={qy_res:.4f}, qz={qz_res:.4f}]')

# Set up Detector

In [None]:
# Detector geometry as (horizontal, vertical): pixel counts and absolute maximum
# q extents; the detector is centered at 0.
det_pixels = (200, 200)
det_qs = (8, 8)
det_x_grid, det_y_grid, det_z_grid, det_h, det_v = make_detector(
    det_qs[0], det_pixels[0], det_qs[1], det_pixels[1])

# Detector rotations, in degrees.
psi = 0    # about the detector normal axis
phi = 0    # about the detector vertical axis
theta = 0  # about the detector horizontal axis

# Apply the rotations in order: normal, then vertical, then horizontal.
for rotate, angle in ((rotate_about_normal, psi),
                      (rotate_about_vertical, phi),
                      (rotate_about_horizontal, theta)):
    det_x_grid, det_y_grid, det_z_grid = rotate(det_x_grid, det_y_grid, det_z_grid, angle)

# plot single detector

In [None]:
det_ints = intersect_detector(iq, qx, qy, qz, det_x_grid, det_y_grid, det_z_grid, det_h, det_v)

# Linear colour scale clipped to the 10th-99th percentiles of the detector image.
vmin, vmax = np.percentile(det_ints, [10, 99])
q_extent = (np.min(det_h), np.max(det_h), np.min(det_v), np.max(det_v))

fig, ax1 = subplots()
ax1.imshow(det_ints,
           norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax),
           extent=q_extent,
           cmap='turbo',
           origin='lower')
ax1.set_xlabel('q horizontal')
ax1.set_ylabel('q vertical')

In [None]:
# Same single-detector view, sampled from the xrft intensity DataArray.
# xrft names the frequency dims freq_x / freq_y / freq_z (iq_DA has no qx / qy /
# qz coordinates); they are scaled by 2*pi to q. `.values` yields a numpy array
# whether or not iq_DA is still dask-backed.
det_ints = intersect_detector(iq_DA.values, iq_DA.freq_x.data*2*np.pi, iq_DA.freq_y.data*2*np.pi, iq_DA.freq_z.data*2*np.pi, det_x_grid, det_y_grid, det_z_grid, det_h, det_v)

# plot
fig, ax1 = subplots()
ax1.imshow(det_ints,
           norm=matplotlib.colors.Normalize(vmin=np.percentile(det_ints, 10), vmax=np.percentile(det_ints, 99)),
           extent=(np.min(det_h),np.max(det_h),np.min(det_v),np.max(det_v)),
           cmap='turbo',
           origin = 'lower')
ax1.set_xlabel('q horizontal')
ax1.set_ylabel('q vertical')

# Generate and sum multiple plots across selected angles

In [None]:
def generate_detector_ints(det_pixels, det_qs, psi, phi, theta):
    """Return the detector intensity image for one detector orientation.

    A fresh detector is built from ``det_qs`` and ``det_pixels`` (both given as
    (horizontal, vertical)). It is rotated by ``psi`` about its normal, then
    ``phi`` about its vertical axis, then ``theta`` about its horizontal axis
    (all in degrees), and intersected with the reciprocal-space intensity grid.

    Note: the intensity grid is read from the notebook-level globals ``iq``,
    ``qx``, ``qy`` and ``qz`` (set by the ``convert_grid_qspace`` cell).
    """
    grid_x, grid_y, grid_z, det_h, det_v = make_detector(det_qs[0], det_pixels[0],
                                                         det_qs[1], det_pixels[1])
    for rotate, angle in ((rotate_about_normal, psi),
                          (rotate_about_vertical, phi),
                          (rotate_about_horizontal, theta)):
        grid_x, grid_y, grid_z = rotate(grid_x, grid_y, grid_z, angle)
    return intersect_detector(iq, qx, qy, qz, grid_x, grid_y, grid_z, det_h, det_v)

In [None]:
#setup detector
det_pixels = (150,150) #horizontal, vertical
det_qs = (6.5,6.5) #horizontal, vertical (these are absolute maximums. detector centered at 0)
psi = 0 #rotation in degrees of detector about detector normal axis
phis = np.linspace(0,180,num=60) #rotation in degrees of detector about detector vertical axis
theta = 0 #rotation in degrees of detector about detector horizontal axis

# det_h / det_v (detector axis coordinates) are used for the plot extents below;
# generate_detector_ints builds and rotates its own detector grids per angle.
det_x_grid, det_y_grid, det_z_grid, det_h, det_v = make_detector(det_qs[0], det_pixels[0], det_qs[1], det_pixels[1])

# One detector image per phi, kept as a list in phi order.
det_ints = [generate_detector_ints(det_pixels, det_qs, psi, phi, theta) for phi in phis]

# Sum into a new array. Accumulating in place into the first image
# (det_sum = det_int; det_sum += ...) aliases it and overwrites det_ints[0]
# with the running total.
det_sum = np.sum(det_ints, axis=0)

In [None]:
%matplotlib widget
# Detector image summed over all phi, on a log colour scale clipped to the
# 30th-99th percentiles; only the non-negative q quadrant is shown.
# NOTE(review): LogNorm needs vmin > 0 — if 30% or more of det_sum is zero, the
# lower percentile is 0 and drawing will likely fail; confirm or clip.
fig, ax1 = subplots()
cax = ax1.imshow(det_sum,
           norm=matplotlib.colors.LogNorm(vmin=np.percentile(det_sum, 30), vmax=np.percentile(det_sum, 99)),
           extent=(np.min(det_h),np.max(det_h),np.min(det_v),np.max(det_v)),
           cmap='turbo',
           origin = 'lower')
ax1.set_xlabel('q horizontal')
ax1.set_ylabel('q vertical')
ax1.set_xlim(left=0)
ax1.set_ylim(bottom=0)
cbar = fig.colorbar(cax, ax=ax1)

# Visualize each individual detector across angles

In [None]:
%matplotlib inline

In [None]:
for i in range(len(det_ints[:,0,0])):
    det_int = det_ints[i,:,:]
    fig, ax1 = subplots()
    cax = ax1.imshow(det_int,
           norm=matplotlib.colors.LogNorm(vmin=np.percentile(det_int, 10), vmax=np.percentile(det_int, 99)),
           extent=(np.min(det_h),np.max(det_h),np.min(det_v),np.max(det_v)),
           cmap='turbo',
           origin = 'lower')
    ax1.set_xlabel('q horizontal')
    ax1.set_ylabel('q vertical')
    ax1.set_xlim(0, 3)
    ax1.set_ylim(0, 3)
    cbar = fig.colorbar(cax, ax=ax1)
    ax1.set_title(f'Phi = {i*3} degrees')
    plt.show()
    plt.close('all')
    