# In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
from matplotlib.patches import Rectangle, Circle
import os
from PIL import Image
from scipy.special import erf
#import tifffile

# Raw data file to browse; its first two characters are parsed as the dataset
# number used in output file names.
fname = '09_2DWrinkAu_640kx_770mm_70um_0p64mrad_ss8_50x_50y_100z_216step_x256_y256.raw'
fnum = int(fname[0:2])
# Stored frame shape (rows x columns) of each diffraction pattern in the file.
# Patterns are cropped to xcols x xcols when read (see dp_slice).
yrows = 130
xcols = 128

# Radius and edge sharpness of the erf-based soft bright-disk filter below.
bright_disk_radius = 5
erf_sharpness = 5

# 2D Hann window: outer product of two 1D Hann windows of length xcols.
hann2d = np.hanning(xcols)
hann2d = np.outer(hann2d, hann2d)

# Centered integer pixel coordinates, -xcols/2 .. xcols/2 - 1, on an
# xcols x xcols grid.
kx = np.arange(-xcols, xcols, 2)/2
kx,ky = np.meshgrid(kx, kx)

dist = np.hypot(kx, ky)

# Annular mask: 1 for pixels at radius >= 30 from the center, 0 inside.
# The builtin `int` replaces `np.int`, which was only an alias for it and was
# removed in NumPy 1.24 (AttributeError on current NumPy).
haadf_mask = np.array(dist >= 30, int)

# Same centered coordinates again (recomputed for the bright-disk filter).
x = np.arange(-xcols/2,xcols/2,1)
x,y = np.meshgrid(x,x)
dist = np.hypot(x,y)
# Smooth mask that is ~0 inside radius bright_disk_radius and ~1 outside it.
# The transition width is set by erf_sharpness.
bdisk_filter = erf((dist-bright_disk_radius)*erf_sharpness)/2 - erf((dist+bright_disk_radius)*erf_sharpness)/2 + 1

# Identical to hann2d; kept as a separate name because browser() refers to it.
hann_filter = np.hanning(xcols)
hann_filter = np.outer(hann_filter, hann_filter)

#%%
def dp_slice(filename, first_index = 0, second_index = 0, yrows = yrows, 
             xcols = xcols, crop = True, dtype = np.float32, min_count_to_zero = 20, min_offset = 1e-6):
    """Read one diffraction pattern from a raw file of stacked frames.

    The file is read as a C-ordered sequence of ``yrows x xcols`` frames of
    ``dtype``, laid out on a square ``num_pixels x num_pixels`` scan grid.
    The grid side is inferred from the file size. The pattern at scan
    position ``[first_index, second_index]`` is frame number
    ``first_index * num_pixels + second_index``.

    Parameters
    ----------
    filename : str
        Path to the raw file.
    first_index, second_index : int
        Scan-grid row and column. Each must lie in ``[0, num_pixels)``.
    yrows, xcols : int
        Frame shape as stored on disk.
    crop : bool
        If True, keep only the first ``xcols`` rows, which gives an
        ``xcols x xcols`` pattern.
    dtype : numpy dtype
        Element type of the file.
    min_count_to_zero : float
        Values ``<=`` this threshold are replaced by ``min_offset``.
    min_offset : float
        Small positive replacement value. It avoids exact zeros, which would
        break a later logarithm.

    Returns
    -------
    numpy.ndarray
        An in-memory copy of the pattern. Shape is ``(xcols, xcols)`` when
        ``crop`` is True, otherwise ``(yrows, xcols)``.

    Raises
    ------
    IndexError
        If either index is outside ``[0, num_pixels)``. Without this check an
        out-of-range ``second_index`` silently wraps into the next scan row.
    """
    dsize = np.dtype(dtype).itemsize

    # Side length of the (assumed square) scan grid.
    num_pixels = int((os.stat(filename).st_size/yrows/xcols/dsize)**0.5)
    if not (0 <= first_index < num_pixels and 0 <= second_index < num_pixels):
        raise IndexError('scan index [%d, %d] out of range for a %d x %d scan'
                         % (first_index, second_index, num_pixels, num_pixels))

    offset = int((num_pixels*first_index + second_index)*yrows*xcols*dsize)
    frame = np.memmap(filename, dtype = dtype, mode = 'r', shape = (yrows, xcols),
                      order = 'C', offset = offset)
    # Copy out of the read-only memory map so the result is writable.
    frame = np.array(frame)
    if crop:
        frame = frame[:xcols, :xcols]

    # Counts at or below min_count_to_zero are set to min_offset. This keeps
    # them very close to zero but not exactly zero, so logarithms don't fail.
    frame[frame <= min_count_to_zero] = min_offset

    return frame

def dp_to_cep(dp, window = hann2d):
    """Return the power cepstrum of a diffraction pattern or a stack of them.

    Computes ``|FFT2(log10(dp))| ** 2`` over the last two axes. The result is
    fft-shifted over those axes so that zero quefrency is at the center.

    ``dp`` should be strictly positive, since log10 of zero or of negative
    values is not finite.

    ``window`` is accepted for compatibility but is currently NOT applied;
    the windowing step is disabled.
    """
    log_pattern = np.log10(dp)
    centered_spectrum = np.fft.fftshift(np.fft.fft2(log_pattern), (-1, -2))
    return np.abs(centered_spectrum) ** 2.0


def generate_image(filename, mask, yrows = yrows, xcols = xcols, dtype = np.float32):
    """Build a virtual-detector image by integrating every pattern against ``mask``.

    Parameters
    ----------
    filename : str
        Raw file read pattern by pattern via ``dp_slice``.
    mask : array_like
        Weights multiplied element-wise with each pattern before summing. The
        pattern is cropped by ``dp_slice``'s default, so ``mask`` must
        broadcast against ``(xcols, xcols)``.
    yrows, xcols, dtype :
        Frame geometry and element type. They determine the scan size here
        and are forwarded to ``dp_slice``.

    Returns
    -------
    numpy.ndarray
        A ``(num_pixels, num_pixels)`` float array with
        ``image[i, j] = sum(dp_slice(filename, i, j) * mask)``.
    """
    dsize = np.dtype(dtype).itemsize
    # Side length of the (assumed square) scan grid, inferred from file size.
    num_pixels = int((os.stat(filename).st_size/yrows/xcols/dsize)**0.5)
    image = np.zeros((num_pixels, num_pixels))
    for i in range(num_pixels):
        for j in range(num_pixels):
            # Use the caller's mask rather than the global haadf_mask. Forward
            # the geometry so dp_slice reads the same frame layout that
            # num_pixels was computed from.
            pattern = dp_slice(filename, i, j, yrows = yrows, xcols = xcols, dtype = dtype)
            image[i, j] = np.sum(pattern * mask)
    return image

def browser(image, filename, cep_max):
    """Open an interactive viewer for a raw diffraction-pattern file.

    The figure has three panels:

    - Left: ``image``. Moving the mouse over it moves a crosshair, and a
      double-click toggles a lock that freezes the crosshair.
    - Middle: the diffraction pattern at the crosshair position, read via
      ``dp_slice``.
    - Right: that pattern's cepstrum, computed via ``dp_to_cep``.

    Min/Max sliders beneath the middle and right panels set their colour
    limits.

    Parameters
    ----------
    image : 2D array
        Scan image. Pixel (row y, column x) is displayed alongside
        ``dp_slice(filename, y, x)``.
    filename : str
        Raw data file passed to ``dp_slice``.
    cep_max : float
        Upper bound of both cepstrum colour-limit sliders. It is also the
        initial value of the cepstrum Max slider.

    Returns
    -------
    tuple
        ``(cursor, dp_min_sl, dp_max_sl, cep_min_sl, cep_max_sl)``. Keep a
        reference to it, because matplotlib widgets only stay responsive
        while referenced.
    """

    class Cursor(object):
        """Crosshair on the scan image selecting which scan point to display.

        ``x`` and ``y`` hold the last selected column and row index; both
        start at 0.
        """
        def __init__(self, ax):
            self.ax = ax
            self.lock = False  # toggled by double-click; True freezes the crosshair
            self.lx = ax.axhline(color = 'k')
            self.ly = ax.axvline(color = 'k')
            self.x = 0
            self.y = 0

        def mouse_move(self, event):
            """Move the crosshair to the pixel under the mouse and refresh the panels."""
            if not event.inaxes == self.ax:
                return
            if self.lock:
                return
            # Round to the nearest pixel, then clamp. At the image edge the
            # data coordinate can reach N - 0.5, which rounds to N (one past
            # the last pixel). dp_slice would then read a neighbouring row's
            # pattern or run past the end of the file.
            n_rows, n_cols = np.shape(image)[:2]
            x = min(max(int(round(event.xdata)), 0), n_cols - 1)
            y = min(max(int(round(event.ydata)), 0), n_rows - 1)
            self.x = x
            self.y = y
            # Line2D data must be a sequence; scalars are deprecated since
            # Matplotlib 3.7. The axhline/axvline lines each have two points.
            self.lx.set_ydata([y, y])
            self.ly.set_xdata([x, x])
            update_dps(y, x)
            plt.draw()

        def click(self, event):
            """Toggle the crosshair lock on a double-click inside the scan image."""
            if not event.inaxes == self.ax:
                return
            if not event.dblclick:
                return
            self.lock = not self.lock

    def update_dps(y, x):
        """Show the pattern and cepstrum at scan position (y, x) and re-range the sliders."""
        dslice = dp_slice(filename, y, x)
        cep = dp_to_cep(dslice, window = hann_filter)
        dpdisp.set_data(dslice)
        dpdisp.set_clim(dp_min_sl.val, dp_max_sl.val)

        # Both pattern sliders span this pattern's value range.
        dpmin = np.min(dslice)
        dpmax = np.max(dslice)

        dp_min_sl.valmin = dpmin
        dp_min_sl.valmax = dpmax
        dp_min_ax.set_xlim(dpmin, dpmax)

        dp_max_sl.valmin = dpmin
        dp_max_sl.valmax = dpmax
        dp_max_ax.set_xlim(dpmin, dpmax)

        cepdisp.set_data(cep)
        cepdisp.set_clim(cep_min_sl.val, cep_max_sl.val)

        # Only the lower end of the cepstrum sliders follows the data; the
        # upper end stays fixed at cep_max.
        cepmin = np.min(cep)

        cep_min_sl.valmin = cepmin
        cep_min_ax.set_xlim(cepmin, cep_max)

        cep_max_sl.valmin = cepmin
        cep_max_ax.set_xlim(cepmin, cep_max)

        plt.draw()

    def update_clim(disp):
        """Apply the current slider values as colour limits of the 'dp' or 'cep' panel."""
        if disp == 'dp':
            dpdisp.set_clim(dp_min_sl.val, dp_max_sl.val)
        elif disp == 'cep':
            cepdisp.set_clim(cep_min_sl.val, cep_max_sl.val)
        plt.draw()

    fig, ax = plt.subplots(1, 3)
    fig.set_size_inches(15, 5)
    plt.subplots_adjust(bottom = 0.20, left = 0.00, right = 0.95)

    # Left panel: scan image.
    ax[0].imshow(image, origin = 'lower', aspect = 'equal')
    ax[0].invert_xaxis()
    ax[0].axis('off')

    # Initial pattern and cepstrum at scan position (0, 0), where the cursor starts.
    dslice = dp_slice(filename, 0, 0)
    cep = dp_to_cep(dslice)

    # Middle panel: diffraction pattern with two dashed reference circles.
    # They are centered at (64, 64), which assumes 128 x 128 patterns
    # (TODO confirm), and are then shrunk to radius 0.
    dpdisp = ax[1].imshow(dslice, origin = 'upper', aspect = 'equal')
    c1 = Circle((64,64), 19.41858, fill = False, linestyle = '--', color = 'red', linewidth = 4)
    c2 = Circle((64,64), 19.41858, fill = False, linestyle = '--', color = 'blue', linewidth = 4)

    c1.set_radius(0)
    c2.set_radius(0) 

    ax[1].add_artist(c1)
    ax[1].add_artist(c2)
    ax[1].axis('off')

    plt.colorbar(dpdisp, ax = ax[1])

    # Right panel: cepstrum with two dashed reference circles. The red one is
    # shrunk to radius 0. The blue radius constant is not explained here.
    cepdisp = ax[2].imshow(cep, origin = 'upper', aspect = 'equal')
    c1 = Circle((64,64), 11.02, fill = False, linestyle = '--', color = 'red', linewidth = 4)
    c2 = Circle((64,64), 9.665, fill = False, linestyle = '--', color = 'blue', linewidth = 4)

    ax[2].add_artist(c1)
    ax[2].add_artist(c2)

    c1.set_radius(0)
    c2.set_radius(29.88/2*(3/11)**0.5)
    ax[2].axis('off')

    plt.colorbar(cepdisp, ax = ax[2])

    cursor = Cursor(ax[0])

    plt.connect('motion_notify_event', cursor.mouse_move)
    plt.connect('button_press_event', cursor.click)

    # Slider axes: pattern limits under the middle panel, cepstrum limits under the right.
    dp_min_ax = plt.axes([0.35, 0.15, 0.20, 0.03])
    dp_max_ax = plt.axes([0.35, 0.10, 0.20, 0.03])

    cep_min_ax = plt.axes([0.70, 0.15, 0.20, 0.03])
    cep_max_ax = plt.axes([0.70, 0.10, 0.20, 0.03])

    dp_min_sl = Slider(dp_min_ax, 'Min', np.min(dslice), np.max(dslice), valinit = 0)
    dp_max_sl = Slider(dp_max_ax, 'Max', np.min(dslice), np.max(dslice), valinit = 0)
    dp_max_sl.set_val(np.max(dslice))

    # Keep Min <= Max for each slider pair.
    dp_min_sl.slidermax = dp_max_sl
    dp_max_sl.slidermin = dp_min_sl

    cep_min_sl = Slider(cep_min_ax, 'Min', np.min(cep), cep_max, valinit = 0)
    cep_max_sl = Slider(cep_max_ax, 'Max', np.min(cep), cep_max, valinit = 0)
    cep_max_sl.set_val(cep_max)

    cep_min_sl.slidermax = cep_max_sl
    cep_max_sl.slidermin = cep_min_sl

    dp_min_sl.on_changed(lambda x: update_clim('dp'))
    dp_max_sl.on_changed(lambda x: update_clim('dp'))

    cep_min_sl.on_changed(lambda x: update_clim('cep'))
    cep_max_sl.on_changed(lambda x: update_clim('cep'))
    return cursor, dp_min_sl, dp_max_sl, cep_min_sl, cep_max_sl

def pull_current_dp_cep(browser):
    """Return TIFF file names for the pattern under the browser's crosshair.

    ``browser`` is the tuple returned by ``browser()``. Its first element is
    the Cursor whose ``y``/``x`` select the scan position.

    The pattern is read from the module-level ``fname``, not from the file
    the browser was opened on, and its cepstrum is computed. The TIFF writes
    are currently disabled (``tifffile`` is not imported), so nothing is
    saved; only the names are returned.

    Returns
    -------
    tuple of str
        ``(dp_fname, cep_fname)``, formatted as ``'%02d_dp_%d_%d.tif'`` and
        ``'%02d_cep_%d_%d.tif'`` from ``(fnum, y, x)``.
    """
    cursor = browser[0]
    y, x = cursor.y, cursor.x
    name_fields = (fnum, y, x)

    dp = dp_slice(fname, y, x)
    cep = dp_to_cep(dp)

    # To actually save, re-enable `import tifffile` and write dp and cep:
    #     tifffile.imwrite(dp_fname, dp)
    #     tifffile.imwrite(cep_fname, cep)
    return '%02d_dp_%d_%d.tif' % name_fields, '%02d_cep_%d_%d.tif' % name_fields

#%%

# Integrate every pattern over the annular haadf_mask to get a scan image.
# Then close any stale figures and open the interactive browser on it.
im = generate_image(filename = fname, mask = haadf_mask)
plt.close('all')
brow = browser(image = im, filename = fname, cep_max = 1e5)
