# Setup

## Import packages

In [None]:
import pathlib
import numpy as np
import xarray as xr
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.pyplot import subplots
from scipy.signal import convolve
from scipy.ndimage import gaussian_filter
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import LogNorm
from numpy.fft import fftn, fftshift
from tqdm.auto import tqdm
import os

import dask
import dask.array as da
from dask.diagnostics import ProgressBar
from dask_image.ndfourier import fourier_gaussian, _utils
import flox
import flox.xarray
from flox.xarray import xarray_reduce
import xrft

%matplotlib widget

In [None]:
# Restart the dask workers to clear state from a previous run. Guarded because, on a
# fresh kernel, `client` does not exist until the "Setup dask client" cell further down
# has been run; the unguarded call raised NameError.
if 'client' in globals():
    client.restart()

## Setup dask client
Optional: without a client, the arrays are still chunked, but the computations may not parallelize as expected.

To install dask extension (only if you have jupyter version 4.x):
- pip install dask-labextension
- restart jupyter lab

To set up with JupyterLab dask extension (if installed):
- Click on dask tab on the left side, just under the kernel tab and above the table of contents tab
- Start a new cluster at the bottom, then click '<>' to insert a cell into the notebook with the code to initialize the client
- Click Launch dashboard in output from inserted cell to see standard dashboard windows on right side
- In the dask tab, launch any other progress windows of interest (the task graph view is a useful one)

In [None]:
from dask.distributed import Client

# Connect to an already-running dask scheduler (e.g. the one started from the
# JupyterLab dask tab). The address belongs to whichever cluster was started, so it
# can be overridden with the DASK_SCHEDULER_ADDRESS environment variable instead of
# editing this cell; the previous hard-coded address remains the default.
scheduler_address = os.environ.get('DASK_SCHEDULER_ADDRESS', 'tcp://127.0.0.1:50391')
client = Client(scheduler_address)
client

## Import local files

In [None]:
from ptable_dict import ptable, atomic_masses
from utilities import write_xyz, load_xyz, rotation_matrix, gaussian_kernel, load_array_from_npy_stack
from meshgrids import generate_density_grid, convert_grid_qspace, plot_3D_grid, generate_electron_grid_npys_fixed, xrft_fft
from detector import make_detector, rotate_about_normal, rotate_about_horizontal, rotate_about_vertical, intersect_detector

# Generate and plot real-space voxel map for xyz file

## Fixed voxel segmentations along x
- generates a real-space voxel map whose electron-density values are set from the atomic positions and atomic numbers (Z)
- no smearing is done here: each atom's electron density is placed in a single voxel

In [None]:
# Input structure and the output directory that receives the voxel-grid slabs
basePath = pathlib.Path.cwd()
xyzPath = basePath / 'test_xyz_files' / 'graphite_medium.xyz'
npySavePath = basePath / 'output_files'
npySavePath.mkdir(exist_ok=True)

# Grid parameters
voxel_size = 0.1
min_ax_size = 1024
num_chunks = 16  # number of slabs along x, each saved as its own .npy file; keep each slab small enough for RAM

# Fill the real-space voxel grid slab by slab, writing every slab to npySavePath.
# The slabs are loaded back lazily as a dask array in the next cell.
shape_info = generate_electron_grid_npys_fixed(
    xyzPath, voxel_size, num_chunks, npySavePath, min_ax_size=min_ax_size
)
(x_axis, y_axis, z_axis,
 grid_vox_x, grid_vox_y, grid_vox_z) = shape_info

In [None]:
# Load the real-space density slabs written by generate_electron_grid_npys_fixed into
# one lazy dask array (one delayed np.load per slab), then wrap it in an xarray DataArray.

# Slab files ordered by the integer index in their names: the 4th '_'-separated field
# of each file name, after its last '-'.
npy_paths = sorted(npySavePath.glob('*.npy'), key=lambda x: int(x.name.split('_')[3].split('-')[-1]))
if not npy_paths:
    raise FileNotFoundError(f'No .npy slabs found in {npySavePath}; run the grid-generation cell first')

density_grid_segments = []
for npy_path in npy_paths:
    # Read only the .npy header (memory-mapped, no data loaded) to get each slab's true
    # shape and dtype. Assuming grid_vox_x/num_chunks columns per slab breaks whenever
    # grid_vox_x is not divisible by num_chunks, and float64 may not be the saved dtype.
    slab_header = np.load(npy_path, mmap_mode='r')
    slab_shape, slab_dtype = slab_header.shape, slab_header.dtype
    del slab_header
    # Delayed loading conserves memory: the slab is only read when dask needs it
    density_grid_segment = da.from_delayed(dask.delayed(np.load)(npy_path),
                                           shape=slab_shape,
                                           dtype=slab_dtype)
    density_grid_segments.append(density_grid_segment)

# Concatenate the slabs along x (axis 1 of the (y, x, z) layout used below)
density_grid = da.concatenate(density_grid_segments, axis=1)

density_grid = density_grid.persist()

# Put the loaded dask array into an xarray DataArray
dens_grid_DA = xr.DataArray(data=density_grid,
                            dims=['y', 'x', 'z'],
                            coords={'y': y_axis,
                                    'x': x_axis,
                                    'z': z_axis})
dens_grid_DA

# Generate and plot reciprocal space voxel map for xyz file

## Forward FFT with gaussian

In [None]:
# Define fft parallelized function with xrft + dask + xarray!
# NOTE: this definition shadows the xrft_fft imported from meshgrids in the setup cell.
def xrft_fft(DA, num_chunks):
    """Lazily take a shifted 3D DFT of ``DA`` with xrft, in two passes.

    The first pass transforms y and z. The result is then rechunked so that x is a
    single chunk, and the second pass transforms x.

    Parameters
    ----------
    DA : xarray.DataArray
        Real-space grid with dims 'y', 'x' and 'z' (dask-backed for parallelism).
    num_chunks : int
        Number of chunks to split the freq_y axis into before the x pass.

    Returns
    -------
    xarray.DataArray
        Complex, zero-frequency-centered (``shift=True``) transform with dims
        freq_y / freq_x / freq_z.
    """
    # DFT along y and z; these dims come back renamed freq_y / freq_z
    fft_yz = xrft.fft(DA, dim=['y', 'z'], shift=True)
    # dask needs integer chunk sizes. The old `int(len)/num_chunks` produced a float
    # (rejected unless integral) and could be 0, so floor-divide and keep at least 1.
    y_chunk = max(1, len(DA.y) // num_chunks)
    # Split along freq_y and keep x whole, presumably because xrft needs the
    # transformed axis unchunked
    fft_yz_rechunked = fft_yz.chunk({'freq_y': y_chunk, 'x': len(DA.x)})
    # DFT along x
    return xrft.fft(fft_yz_rechunked, dim=['x'], shift=True)

# Take the FFT lazily (num_chunks was set when the density slabs were loaded)
fft_DA = xrft_fft(dens_grid_DA, num_chunks)

# Add q = 2*pi*freq coordinates for xrft's freq_* axes and make the q_* names the
# dimensions, so later cells can select directly in q (e.g. .sel(qx=...))
freq_to_q = {f'freq_{axis}': f'q{axis}' for axis in ('x', 'y', 'z')}
fft_DA = fft_DA.assign_coords(
    {q_name: (freq_name, fft_DA[freq_name].data * 2 * np.pi) for freq_name, q_name in freq_to_q.items()}
).swap_dims(freq_to_q)

# Optional: multiply by the analytic Fourier transform of a real-space Gaussian of width sigma.
# NOTE(review): np/da.meshgrid defaults to indexing='xy', which swaps the first two axes,
# and the data dims are presumably ordered (qy, qx, qz). Check the axis ordering
# (indexing='ij'?) and the 1/(2*pi) sigma conversion before enabling this.
# def fft_gaussian_kernel(DA, sigma):
#     sigma *= 1/(2*np.pi)  # converts sigma to q-space units
#     qx, qy, qz = da.meshgrid(DA.qy.data, DA.qx.data, DA.qz.data,)
#     g_fft = np.exp(-1/2 * (sigma**2) * (qx**2 + qy**2 + qz**2))
#     return g_fft
#
# sigma = 1
# fft_DA = fft_DA * fft_gaussian_kernel(fft_DA, sigma)

# Execute the computation graph with dask and persist the result into memory
fft_DA = fft_DA.persist()
fft_DA

## Visualize 3D reciprocal space

In [None]:
# fft_DA.data.visualize()

In [None]:
plt.close('all')

extent = 10  # show only -extent <= qx, qy <= extent

# Scattering intensity |F(q)|^2 (lazy)
iq_DA = np.abs(fft_DA)**2

# Project along qz, bring the 2D result into memory, and crop to the window
iq_DA_sum = iq_DA.sum('qz').compute()
sel_DA = iq_DA_sum.sel(qy=slice(-extent, extent), qx=slice(-extent, extent))

# Linear color scale clipped to the 10th-99.9th percentile of the cropped image
cmin, cmax = sel_DA.quantile([0.1, 0.999])
ax = sel_DA.plot.imshow(norm=plt.Normalize(cmin, cmax))
plt.show()

In [None]:
# 3D plotter; currently needs to compute & hold the whole selected array in memory!
plt.close('all')

extent = 7.5  # plot the cube -extent <= qx, qy, qz <= extent
threshold = 98
num_levels = 10
cmap = 'plasma'

iq_DA = np.abs(fft_DA)**2
q_window = slice(-extent, extent)
sel_DA = iq_DA.sel(qx=q_window, qy=q_window, qz=q_window)

fig, ax = plot_3D_grid(sel_DA.data.compute(),
                       sel_DA.qx.data, sel_DA.qy.data, sel_DA.qz.data,
                       cmap, threshold, num_levels)
plt.show()

## Optional Inverse FFT check  

## Work in progress: code to downsample the 3D output (for plotting)

In [None]:
# Lazily downsample the real-space grid: average into n_bins bins along x, then y, then z
n_bins = 128
lazy_binned_DA = dens_grid_DA
for bin_dim in ('x', 'y', 'z'):
    lazy_binned_DA = lazy_binned_DA.groupby_bins(bin_dim, n_bins).mean()

In [None]:
# Render the dask task graph behind the lazily binned array (nothing is computed here)
lazy_binned_DA.data.visualize()

In [None]:
# Execute the binning graph and keep the binned result in memory
binned_DA = lazy_binned_DA.persist()

display(binned_DA)

In [None]:
# Replace the interval-valued x_bins / y_bins / z_bins dims produced by groupby_bins
# with their bin midpoints as x / y / z coordinates. Only dims still named *_bins are
# converted, so re-running this cell (which rebinds binned_DA) no longer fails on an
# already-converted array.
bins_to_dim = {f'{axis}_bins': axis for axis in ('x', 'y', 'z') if f'{axis}_bins' in binned_DA.dims}
if bins_to_dim:
    binned_DA = binned_DA.assign_coords({
        axis: (bins, np.array([interval.mid for interval in binned_DA[bins].data]))
        for bins, axis in bins_to_dim.items()
    }).swap_dims(bins_to_dim)
binned_DA

In [None]:
plt.close('all')

# Real-space density plot settings
cmap = 'plasma'
num_levels = 10
threshold = 99.9

# NOTE(review): this computes and plots the full-resolution density_grid; the
# commented line is the downsampled binned_DA alternative.
fig, ax = plot_3D_grid(density_grid.compute(), x_axis, y_axis, z_axis,
                       cmap, threshold, num_levels, log=True)
# fig, ax = plot_3D_grid(binned_DA.data.compute(), binned_DA.x.data, binned_DA.y.data, binned_DA.z.data, cmap, threshold, num_levels, log=True)

plt.show()

# find q-resolutions
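### The frequency resolution (q-bin size) is the sampling rate (1/voxel_size) divided by the number of voxels along that axis, i.e. 1/(box length), where the box length is the real-space extent of the voxel grid along that axis. Because the q coordinates here are angular (q = 2πf), the q resolution is 2π/(box length).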

In [None]:
# Expose the q axes (already angular, i.e. 2*pi*frequency) and the lazy intensity grid
# under the plain names qx / qy / qz / iq for code that expects them. Previously this
# cell read qx/qy/qz without any cell defining them and raised NameError on a fresh kernel.
qx = iq_DA.qx.data
qy = iq_DA.qy.data
qz = iq_DA.qz.data
iq = iq_DA.data

# Spacing between the first two coordinate values (assumes a uniform FFT frequency grid)
qx_res = qx[1] - qx[0]
qy_res = qy[1] - qy[0]
qz_res = qz[1] - qz[0]
print(f'Resolutions are [qx={qx_res:.4f}, qy={qy_res:.4f}, qz={qz_res:.4f}]')

In [None]:
# q-bin size along each axis: spacing between the first two coordinate values
qx_res, qy_res, qz_res = (iq_DA[q_name].data[1] - iq_DA[q_name].data[0]
                          for q_name in ('qx', 'qy', 'qz'))
print(f'Resolutions are [qx={qx_res:.4f}, qy={qy_res:.4f}, qz={qz_res:.4f}]')

# Set up Detector

In [None]:
# Detector geometry
det_pixels = (200, 200)  # (horizontal, vertical) pixel counts
det_qs = (8, 8)          # (horizontal, vertical) absolute maximum q; detector is centered at 0
det_x_grid, det_y_grid, det_z_grid, det_h, det_v = make_detector(det_qs[0], det_pixels[0], det_qs[1], det_pixels[1])

# Detector rotations in degrees, applied in order: normal, vertical, horizontal axis
psi = 0    # about the detector normal axis
phi = 0    # about the detector vertical axis
theta = 0  # about the detector horizontal axis
for rotate, angle in ((rotate_about_normal, psi),
                      (rotate_about_vertical, phi),
                      (rotate_about_horizontal, theta)):
    det_x_grid, det_y_grid, det_z_grid = rotate(det_x_grid, det_y_grid, det_z_grid, angle)

# plot single detector

In [None]:
# Intersect the detector defined above with the 3D intensity grid. This uses iq_DA and
# its q coordinates directly rather than relying on separately defined globals
# iq/qx/qy/qz. The q coordinates already include the 2*pi factor (assigned in the FFT
# cell), so no rescaling is applied.
det_ints = intersect_detector(iq_DA.data, iq_DA.qx.data, iq_DA.qy.data, iq_DA.qz.data,
                              det_x_grid, det_y_grid, det_z_grid, det_h, det_v)
# np.percentile / imshow need a concrete array (no-op if intersect_detector returns NumPy)
det_ints = np.asarray(det_ints)

# plot
fig, ax1 = subplots()
ax1.imshow(det_ints,
           norm=matplotlib.colors.Normalize(vmin=np.percentile(det_ints, 10), vmax=np.percentile(det_ints, 99)),
           extent=(np.min(det_h), np.max(det_h), np.min(det_v), np.max(det_v)),
           cmap='turbo',
           origin='lower')
ax1.set_xlabel('q horizontal')
ax1.set_ylabel('q vertical')

In [None]:
# Intersect the detector with the dask-backed intensity grid. The qx/qy/qz coordinates
# were already multiplied by 2*pi when assigned in the FFT cell; multiplying them again
# here scaled the q axes by an extra 2*pi relative to the detector's q range.
det_ints = intersect_detector(iq_DA.data, iq_DA.qx.data, iq_DA.qy.data, iq_DA.qz.data,
                              det_x_grid, det_y_grid, det_z_grid, det_h, det_v)
# np.percentile / imshow need a concrete array (no-op if intersect_detector returns NumPy)
det_ints = np.asarray(det_ints)

# plot
fig, ax1 = subplots()
ax1.imshow(det_ints,
           norm=matplotlib.colors.Normalize(vmin=np.percentile(det_ints, 10), vmax=np.percentile(det_ints, 99)),
           extent=(np.min(det_h), np.max(det_h), np.min(det_v), np.max(det_v)),
           cmap='turbo',
           origin='lower')
ax1.set_xlabel('q horizontal')
ax1.set_ylabel('q vertical')

# Generate and sum multiple plots across selected angles

In [None]:
def generate_detector_ints(det_pixels, det_qs, psi, phi, theta, iq=None, qx=None, qy=None, qz=None):
    """Build a rotated detector and intersect it with the 3D scattering intensity.

    Parameters
    ----------
    det_pixels : tuple
        (horizontal, vertical) pixel counts passed to ``make_detector``.
    det_qs : tuple
        (horizontal, vertical) absolute maximum q; the detector is centered at 0.
    psi, phi, theta : float
        Rotations in degrees about the detector normal, vertical and horizontal
        axes, applied in that order.
    iq, qx, qy, qz : optional
        Intensity grid and its q axes. Any left as None fall back to the notebook's
        ``iq_DA`` (``iq_DA.data`` and its ``qx``/``qy``/``qz`` coordinates), so the
        function no longer depends on globals named iq/qx/qy/qz existing.

    Returns
    -------
    The per-pixel intensities returned by ``intersect_detector``.
    """
    if iq is None:
        iq = iq_DA.data
    if qx is None:
        qx = iq_DA.qx.data
    if qy is None:
        qy = iq_DA.qy.data
    if qz is None:
        qz = iq_DA.qz.data

    det_x_grid, det_y_grid, det_z_grid, det_h, det_v = make_detector(det_qs[0], det_pixels[0], det_qs[1], det_pixels[1])
    det_x_grid, det_y_grid, det_z_grid = rotate_about_normal(det_x_grid, det_y_grid, det_z_grid, psi)
    det_x_grid, det_y_grid, det_z_grid = rotate_about_vertical(det_x_grid, det_y_grid, det_z_grid, phi)
    det_x_grid, det_y_grid, det_z_grid = rotate_about_horizontal(det_x_grid, det_y_grid, det_z_grid, theta)
    return intersect_detector(iq, qx, qy, qz, det_x_grid, det_y_grid, det_z_grid, det_h, det_v)

In [None]:
# Detector sweep setup
det_pixels = (150, 150)  # horizontal, vertical
det_qs = (6.5, 6.5)      # horizontal, vertical (absolute maximums; detector centered at 0)
psi = 0                  # rotation in degrees about the detector normal axis
theta = 0                # rotation in degrees about the detector horizontal axis
# Rotations in degrees about the detector vertical axis: 60 steps of 3 degrees over a
# half turn. endpoint=False drops 180 degrees, which maps the detector (centered at 0)
# back onto the 0-degree plane, so that orientation is not counted twice in det_sum.
phis = np.linspace(0, 180, num=60, endpoint=False)

# Unrotated detector; det_h / det_v provide the plot extents
det_x_grid, det_y_grid, det_z_grid, det_h, det_v = make_detector(det_qs[0], det_pixels[0], det_qs[1], det_pixels[1])

# One detector image per phi, stacked into an array indexed [phi, ...]
det_ints = np.stack([np.asarray(generate_detector_ints(det_pixels, det_qs, psi, phi, theta))
                     for phi in tqdm(phis, desc='phi')])
# Sum into a new array. The previous `det_sum = det_int; det_sum += ...` aliased the
# first image, so det_ints[0] was silently overwritten with the running total.
det_sum = det_ints.sum(axis=0)

In [None]:
%matplotlib widget
fig, ax1 = subplots()
cax = ax1.imshow(det_sum,
           norm=matplotlib.colors.LogNorm(vmin=np.percentile(det_sum, 30), vmax=np.percentile(det_sum, 99)),
           extent=(np.min(det_h),np.max(det_h),np.min(det_v),np.max(det_v)),
           cmap='turbo',
           origin = 'lower')
ax1.set_xlabel('q horizontal')
ax1.set_ylabel('q vertical')
ax1.set_xlim(left=0)
ax1.set_ylim(bottom=0)
cbar = fig.colorbar(cax, ax=ax1)

# Visualize each individual detector across angles

In [None]:
# Switch to the static inline backend so each figure in the per-angle loop is rendered as an image
%matplotlib inline

In [None]:
# One figure per detector orientation, zoomed to 0 <= q <= 3.
# np.asarray accepts det_ints either as a list of 2D images (the old `det_ints[:,0,0]`
# indexing failed on a list) or as an already-stacked 3D array. The title uses the
# actual phi value instead of assuming 3-degree steps.
for phi_deg, det_int in zip(phis, np.asarray(det_ints)):
    fig, ax1 = subplots()
    cax = ax1.imshow(det_int,
                     norm=matplotlib.colors.LogNorm(vmin=np.percentile(det_int, 10), vmax=np.percentile(det_int, 99)),
                     extent=(np.min(det_h), np.max(det_h), np.min(det_v), np.max(det_v)),
                     cmap='turbo',
                     origin='lower')
    ax1.set_xlabel('q horizontal')
    ax1.set_ylabel('q vertical')
    ax1.set_xlim(0, 3)
    ax1.set_ylim(0, 3)
    cbar = fig.colorbar(cax, ax=ax1)
    ax1.set_title(f'Phi = {phi_deg:.1f} degrees')
    plt.show()
    plt.close('all')
    

In [None]:
# Restart the dask cluster's workers through the client, e.g. to drop the arrays
# persisted on it (density_grid, fft_DA) once the analysis is finished
client.restart()