In [None]:
import imars3d

from imars3d.backend.dataio.data import load_data, _get_filelist_by_dir
import tomopy
import os

import numpy as np

import ipywidgets as widgets
from ipywidgets import interactive
from IPython.display import display

import matplotlib.pyplot as plt
%matplotlib notebook

In [None]:
# Number of CPU workers passed to the parallel processing steps further down
# (e.g. tomopy's `ncore` argument).
ncore: int = 10

# Input locations

In [None]:
# Raw-data directories for this experiment (IPTS-29298, HFIR CG1D).
ct_dir = "/HFIR/CG1D/IPTS-29298/raw/ct_scans/2022_09_29_sample3"  # CT projections
ob_dir = "/HFIR/CG1D/IPTS-29298/raw/ob/2022_09_30"  # open-beam (OB) images
dc_dir = "/HFIR/CG1D/IPTS-29298/raw/df/2022_09_29"  # dark (DC) images; folder is named "df"

# Fail fast, naming the offending path, if any input directory is missing.
# os.path.isdir (rather than os.path.exists) also rejects a plain file at that path.
for _label, _path in (("CT", ct_dir), ("OB", ob_dir), ("DC", dc_dir)):
    assert os.path.isdir(_path), f"{_label} directory not found: {_path}"

In [None]:
# Collect the CT / OB / DC file lists from the three input directories.
# NOTE(review): `_get_filelist_by_dir` is a private imars3d helper; its API may change.
list_ct_files, list_ob_files, list_dc_files = _get_filelist_by_dir(
    ct_dir=ct_dir,
    ob_dir=ob_dir,
    dc_dir=dc_dir,
)

## Loading the data

In [None]:
%%time
ct, ob, dc, rot_angles = load_data(ct_dir=ct_dir,
                                   ob_dir=ob_dir,
                                   dc_dir=dc_dir,
                                   ct_fnmatch="*.tiff",
                                   ob_fnmatch="*.tiff",
                                   dc_fnmatch="*.tiff")

In [None]:
def _report(label, value):
    """Print ``label=repr(value)``, mirroring the f-string ``{expr=}`` format."""
    print(f"{label}={value!r}")

# Quick sanity check of the raw CT stack (each value is evaluated and printed in turn).
_report("len(ct)", len(ct))
_report("type(ct)", type(ct))
_report("ct.dtype", ct.dtype)
_report("ct[0,0,0]", ct[0, 0, 0])

## Visualize imported data (CT projections)

In [None]:
plt.figure(1)

def plot_ct(index):
    """Display CT projection ``index`` titled with its rotation angle.

    Parameters
    ----------
    index : int
        Position in the CT stack; the slider below keeps it in ``[0, len(ct) - 1]``.
    """
    # NOTE(review): the label assumes `rot_angles` from load_data is in degrees — confirm.
    plt.title(f"Angle: {rot_angles[index]:.2f} degrees")
    plt.imshow(ct[index])
    plt.show()

# IntSlider's `max` is inclusive, so the last valid index is len(ct) - 1;
# max=len(ct) would let the slider reach an out-of-range index (IndexError).
ct_plot_ui = interactive(plot_ct,
                         index=widgets.IntSlider(min=0,
                                                 max=len(ct) - 1,
                                                 value=0))
display(ct_plot_ui)

## Open beam (OB)

In [None]:
plt.figure(2)

def plot_ob(index):
    """Display open-beam image ``index`` from the OB stack.

    Parameters
    ----------
    index : int
        Position in the OB stack; the slider below keeps it in ``[0, len(ob) - 1]``.
    """
    # `rot_angles` belongs to the CT stack (it is indexed alongside `ct`), so it
    # cannot label an OB image: it would show a wrong angle, or raise IndexError
    # when there are more OB images than CT angles. Label by index instead.
    plt.title(f"OB image: {index}")
    plt.imshow(ob[index])
    plt.show()

# IntSlider's `max` is inclusive, so the last valid index is len(ob) - 1.
ob_plot_ui = interactive(plot_ob,
                         index=widgets.IntSlider(min=0,
                                                 max=len(ob) - 1,
                                                 value=0))
display(ob_plot_ui)

# Crop

In [None]:
from imars3d.backend.morph.crop import crop, detect_bounds

In [None]:
#%%time
#bounds = detect_bounds(arrays=ob)

In [None]:
#print(bounds)

In [None]:
# Manual crop window in pixels, ordered [left, right, top, bottom].
crop_region = [
    600,   # left
    1350,  # right
    100,   # top
    1950,  # bottom
]

In [None]:
%%time
ct_crop = crop(arrays=ct,
         crop_limit=crop_region)
ob_crop = crop(arrays=ob,
         crop_limit=crop_region)
dc_crop = crop(arrays=dc,
         crop_limit=crop_region)

In [None]:
# Shape of the cropped CT stack, presumably (projections, rows, cols).
np.asarray(ct_crop).shape

In [None]:
# First cropped CT projection, to check the crop window visually.
fig, ax = plt.subplots()
image = ax.imshow(ct_crop[0])
fig.colorbar(image, ax=ax)
plt.show()

In [None]:
# Row-wise mean over the first 300 columns of the first cropped projection.
vertical_profile = ct_crop[0][:, 0:300].mean(axis=1)
fig, ax = plt.subplots()
ax.plot(vertical_profile)
plt.show()

## Gamma filtering 

In [None]:
from imars3d.backend.corrections.gamma_filter import gamma_filter

In [None]:
%%time
ct_gamma = gamma_filter(arrays=ct_crop, 
                        selective_median_filter=False, 
                        diff_tomopy=20, 
                        max_workers=48, 
                        median_kernel=3)

In [None]:
%%time
#ct_gamma = ct_gamma.astype(np.ushort)
#ob_gamma = ob_gamma.astype(np.ushort)
#ob_gamma = ob_crop.astype(np.ushort)
#dc_gamma = dc_gamma.astype(np.ushort)
#dc_gamma = dc_crop.astype(np.ushort)
ob_gamma = ob_crop
dc_gamma = dc_crop

In [None]:
# First gamma-filtered CT projection.
# A fresh figure is created each run (no fixed figure number), so re-running the
# cell never draws into an existing figure or stacks an extra colorbar; this
# matches the other plotting cells.
fig, ax = plt.subplots()
image = ax.imshow(ct_gamma[0])
fig.colorbar(image, ax=ax)
ax.set_title("Gamma-filtered CT projection 0")
plt.show()

In [None]:
# Row-wise mean over the first 300 columns of the first gamma-filtered projection.
vertical_profile = ct_gamma[0][:, 0:300].mean(axis=1)
fig, ax = plt.subplots()
ax.plot(vertical_profile)
plt.show()

# Normalization

In [None]:
# from imars3d.backend.preparation.normalization import normalization

In [None]:
# %%time
# ct_normalized = normalization(arrays=ct_gamma,
#                               flats=ob_gamma,
#                               darks=dc_gamma)

In [None]:
# print(np.shape(ob_gamma))

In [None]:
# plt.figure(3)
# plt.imshow(ct_normalized[0], vmin=0, vmax=1)
# plt.colorbar()
# plt.show()

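### Manual normalization (workaround until the iMars3D `normalization` step is fixed)

Each projection is normalized with the median OB and DC images: $T = \dfrac{CT - \widetilde{DC}}{\widetilde{OB} - \widetilde{DC}}$.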

In [None]:
%%time
my_ob = np.median(ob_gamma, axis=0)
my_dc = np.median(dc_gamma, axis=0)

ct_norm = []
for ct in ct_gamma:
    ct_norm.append(np.true_divide(ct-my_dc, my_ob-my_dc))


In [None]:
# Row-wise mean over the first 300 columns of the first normalized projection.
first_normalized = ct_norm[0]
vertical_profile = np.mean(first_normalized[:, 0:300], axis=1)
fig, ax = plt.subplots()
ax.plot(vertical_profile)
plt.show()

# Beam-fluctuation correction using tomopy's `normalize_roi`

In [None]:
# Background window, in pixels of the cropped frame, ordered
# [left, right, top, bottom] (same convention as `crop_region`).
# NOTE(review): presumably chosen to contain no sample — confirm.
bg_region = [5, 100, 250, 1100]  #  [left, right, top, bottom]
bg_left, bg_right, bg_top, bg_bottom = bg_region

# Slicing silently clips a window that extends past the frame, and an empty
# window averages to NaN, which would poison every projection. Check bounds up front.
n_rows, n_cols = np.shape(ct_norm[0])
assert 0 <= bg_top < bg_bottom <= n_rows, \
    f"bg_region rows {bg_top}:{bg_bottom} not inside 0:{n_rows}"
assert 0 <= bg_left < bg_right <= n_cols, \
    f"bg_region cols {bg_left}:{bg_right} not inside 0:{n_cols}"

# Reorder to [top, left, bottom, right], i.e. [row_start, col_start, row_end, col_end].
# NOTE(review): this matches tomopy slicing proj[roi[0]:roi[2], roi[1]:roi[3]];
# confirm against the installed tomopy version.
roi = [bg_top, bg_left, bg_bottom, bg_right]

proj_norm_beam_fluctuation = tomopy.prep.normalize.normalize_roi(ct_norm,
                                                                 roi=roi,
                                                                 ncore=ncore)