In [4]:
import os

import numpy as np
import scipy.misc
import scipy.ndimage
import scipy.optimize

In [5]:
# Section-header string for this cell. It is not the cell's last expression,
# so it is never displayed and has no runtime effect.
'''
Image Class
===========
'''
class Image:
    '''Absorption image with lazily computed, cached analysis products.

    All settings and cached results live in the dict ``self.var``. Derived arrays are
    computed on first access and reused until a relevant setting changes (see ``set``).
    ``self.recalc`` holds one stale flag per computation level:

    - level 0: ``prep_image``      -> cropi, Ii_raw, If_raw (crop / rotate / subsample)
    - level 1: ``border_gradient`` -> alpha_Ii (background correction)
    - level 2: ``optical_density`` -> od
    '''

    # rotate_method name -> spline order for scipy.ndimage.rotate. The names are the ones
    # accepted by the removed scipy.misc.imrotate; 'lanczos' has no spline equivalent and
    # is mapped to the closest smooth option (cubic).
    _ROTATE_ORDER = {'nearest': 0, 'bilinear': 1, 'bicubic': 3, 'cubic': 3, 'lanczos': 3}

    def __init__(self, name=None, path=None, od=None, **kwargs):
        '''Store settings and resolve where the image data comes from.

        Parameters
        ----------
        name : str, optional
            Image name; its path is looked up with ``imageio.imagename2imagepath``.
        path : str, optional
            Path to an existing file; used in preference to ``name``.
        od : array-like, optional
            Precomputed optical density, used only when neither an existing path nor a
            name string is given. All recalculation levels are then disabled.
        **kwargs
            Overrides for the default settings below (unknown keys are stored as well).

        Raises
        ------
        ValueError
            If no existing path, name string, or od is provided.
        '''
        Default_Image_Set = dict(name='Not Provided', path='Not Provided',
                                 center_x=1, center_y=1, width=1000000, height=1000000,
                                 subsample=1, rotate=0, rotate_method='bilinear',
                                 prep_order=['rotate','crop','subsample'],
                                 fudge=1, bg_width=0, bg_order=1, bad_light=0,
                                 Isat=1, time=1, pixel=1e-6, detuning=0,
                                 od_method='log', sigmaf=1, memory_saver=False,
                                 lookup_table_version='v1')

        # Setting keys whose change invalidates each level (0: prep, 1: background, 2: od)
        Level_Selector_Image = [['name','path','center_x','center_y','center',
                                 'width','height','cropset','cropi','subsample',
                                 'rotate','rotate_method','prep_order'],
                                ['bg_width','bg_order'],
                                ['bad_light','Isat','time','od_method']]

        # local storage
        self.var = {**Default_Image_Set, **kwargs}
        self.var['Level_Selector'] = list(Level_Selector_Image)
        self.var['recalc'] = [True]*len(self.var['Level_Selector'])

        # Use path if provided, else use name and find path
        if (type(path) is str) and os.path.exists(path):
            self.var['path'], self.var['name'] = path, os.path.splitext(os.path.split(path)[1])[0]
        elif type(name) is str:
            self.var['name'], self.var['path'] = name, imageio.imagename2imagepath(name)
        elif od is not None:
            # No raw data behind a precomputed od: disable every level so nothing is recomputed
            self.var['od'] = od
            self.var['Level_Selector'] = [[] for _ in Level_Selector_Image]
            self.var['recalc'] = [False]*len(self.var['Level_Selector'])
        else:
            raise ValueError('Please provide at least name, path, or od to the Image constructor.')

    def __str__(self): return 'Image object'
    def __repr__(self): return 'Image object'

    # ---- Computed quantities -------------------------------------------------------

    @property
    def od(self):
        '''Optical density (recomputed when level 2 is stale), scaled by ``fudge``.'''
        if ('od' not in self.var.keys()) or self.recalc[2]:
            self.optical_density()
        return self.var['od'] * self.fudge

    @property
    def n_2d(self):
        '''Column density: od / sigma.'''
        return self.od / self.sigma

    @property
    def app(self):
        '''Atoms per binned pixel: column density times the binned pixel area.'''
        return self.od / self.sigma * self.pixel_binned**2

    @property
    def od_raw(self):
        '''Uncorrected OD of the prepared frames: -ln(If_raw / Ii_raw).'''
        with np.errstate(divide='ignore', invalid='ignore'):
            return - np.log(self.If_raw / self.Ii_raw)

    @property
    def total_atoms(self):
        '''Total atom number: NaN-ignoring sum of ``app``.'''
        return np.nansum(self.app)

    @property
    def rawdata(self):
        '''File contents read from ``path`` on every access (not cached).'''
        return imageio.imagepath2imagedataraw(self.path)

    @property
    def alldata(self):
        '''Frames [If, Ii] decoded from ``rawdata``; cached only when memory_saver is False.'''
        if 'alldata' in self.var.keys(): return self.var.get('alldata')
        alldata = imageio.imagedataraw2imagedataall(self.rawdata)
        if self.memory_saver is False: self.var['alldata'] = alldata
        return alldata

    @property
    def Ii_raw(self):
        '''Prepared (cropped/rotated/subsampled) incident frame; level 0.'''
        if ('Ii_raw' not in self.var.keys()) or self.recalc[0]:
            self.prep_image()
        return self.var['Ii_raw']

    @property
    def If_raw(self):
        '''Prepared (cropped/rotated/subsampled) transmitted frame; level 0.'''
        if ('If_raw' not in self.var.keys()) or self.recalc[0]:
            self.prep_image()
        return self.var['If_raw']

    @property
    def alpha_Ii(self):
        '''Per-pixel background correction factor for Ii_raw; level 1.'''
        if ('alpha_Ii' not in self.var.keys()) or self.recalc[1]:
            self.border_gradient()
        return self.var['alpha_Ii']

    @property
    def Ii(self):
        '''Corrected incident frame, reduced by the ``bad_light`` fraction.'''
        return (self.Ii_raw * self.alpha_Ii) * (1-self.bad_light)

    @property
    def If(self):
        '''Transmitted frame minus the ``bad_light`` fraction of the corrected incident frame.'''
        return self.If_raw - (self.Ii_raw * self.alpha_Ii * self.bad_light)

    @property
    def si(self):
        '''Incident frame in units of ``Nsat``.'''
        return self.Ii / self.Nsat

    @property
    def sf(self):
        '''Transmitted frame in units of ``Nsat``.'''
        return self.If / self.Nsat

    @property
    def Ii_avg(self):
        '''Mean incident level per unbinned pixel.'''
        return np.nanmean(self.Ii) / self.subsample**2

    @property
    def Ii_avg_binned(self):
        '''Mean incident level per binned pixel.'''
        return np.nanmean(self.Ii)

    @property
    def si_avg(self):
        '''Mean of ``si``.'''
        return np.nanmean(self.si)

    # ---- Setting accessors: values stored in self.var ------------------------------

    @property
    def name(self): return self.var.get('name')

    @property
    def path(self): return self.var.get('path')

    @property
    def center_x(self): return self.var.get('center_x')

    @property
    def center_y(self): return self.var.get('center_y')

    @property
    def center(self):
        '''Explicit 'center' setting, else (center_x, center_y).'''
        return self.var.get('center', (self.center_x, self.center_y))

    @property
    def width(self): return self.var.get('width')

    @property
    def height(self): return self.var.get('height')

    @property
    def cropset(self):
        '''Explicit 'cropset' setting, else keyword arguments built from center/width/height.'''
        return self.var.get('cropset', dict(center=self.center, width=self.width, height=self.height))

    @property
    def cropi(self):
        '''Slices of the crop region from the last ``prep_image`` run; level 0.'''
        if ('cropi' not in self.var.keys()) or self.recalc[0]:
            self.prep_image()
        return self.var['cropi']

    @property
    def subsample(self): return self.var.get('subsample')

    @property
    def rotate(self): return self.var.get('rotate')

    @property
    def rotate_method(self): return self.var.get('rotate_method')

    @property
    def prep_order(self): return self.var.get('prep_order')

    @property
    def fudge(self): return self.var.get('fudge')

    @property
    def bg_width(self): return self.var.get('bg_width')

    @property
    def bg_order(self): return self.var.get('bg_order')

    @property
    def bad_light(self): return self.var.get('bad_light')

    @property
    def Isat(self): return self.var.get('Isat')

    @property
    def Nsat(self):
        '''Saturation scale per binned pixel: Isat * time * subsample**2.

        NOTE(review): presumes Isat is per unbinned pixel per unit time — confirm.
        '''
        return self.Isat * self.time * self.subsample**2

    @property
    def time(self): return self.var.get('time')

    @property
    def pixel(self): return self.var.get('pixel')

    @property
    def pixel_binned(self):
        '''Pixel size after subsampling.'''
        return self.pixel * self.subsample

    @property
    def detuning(self): return self.var.get('detuning')

    @property
    def od_method(self): return self.var.get('od_method')

    @property
    def sigmaf(self): return self.var.get('sigmaf')

    @property
    def sigma(self):
        '''Cross-section used to convert od to column density: 'sigma' setting, else cst_.sigma0 * sigmaf.'''
        return self.var.get('sigma', cst_.sigma0 * self.sigmaf)

    @property
    def memory_saver(self): return self.var.get('memory_saver')

    @property
    def lookup_table_version(self): return self.var.get('lookup_table_version')

    # ---- Recalc manager ------------------------------------------------------------

    @property
    def recalc(self):
        '''Stale flags, one per computation level.'''
        return self.var.get('recalc')

    def set(self, **kwargs):
        '''Update settings and mark the affected computation levels as stale.

        A level is flagged when any key listed for it in ``Level_Selector`` receives a
        value different from the stored one. All later levels are flagged as well, and
        flags that were already set stay set. Passing ``refresh=True`` flags every level;
        any other keyword arguments passed with it are still applied. Changing ``path``
        also discards cached raw frames (``alldata``) so the new file is read.
        '''
        refresh = kwargs.pop('refresh', False)

        # Cached frames belong to the previous file; drop them when the path changes
        if ('path' in kwargs) and (kwargs['path'] != self.var.get('path')):
            self.var.pop('alldata', None)

        # A level is stale if any of its keys is given a new value
        stale = [any((key in kwargs) and (kwargs[key] != self.var.get(key, None)) for key in level)
                 for level in self.var['Level_Selector']]
        self.var = {**self.var, **kwargs}
        if refresh:
            stale = [True] * len(self.recalc)

        # Staleness cascades: every level after the first stale one is stale too
        for i, flag in enumerate(stale):
            if flag:
                stale[i:] = [True] * (len(stale) - i)
                break

        # Keep flags that were already set
        self.var['recalc'] = [new or old for new, old in zip(stale, self.recalc)]

    # ---- Level computations --------------------------------------------------------

    def prep_image(self):
        '''Apply the steps in ``prep_order`` to [If, Ii]; store cropi, Ii_raw and If_raw.

        Steps: 'crop' (slices from ``get_cropi``), 'rotate' (skipped when rotate == 0),
        'subsample' (skipped when subsample == 1). Clears the level-0 stale flag.
        '''
        [If, Ii] = self.alldata
        # Full-frame default so cropi is defined even when 'crop' is not in prep_order
        cropi = (slice(0, np.shape(Ii)[0]), slice(0, np.shape(Ii)[1]))
        for task in self.prep_order:
            if task == 'crop':
                cropi = get_cropi(Ii, **self.cropset)  # ~50 ms, mostly np.meshgrid inside get_cropi
                Ii = Ii[cropi]
                If = If[cropi]
            elif (task == 'rotate') and (self.rotate != 0):
                Ii = self._rotate_frame(Ii)
                If = self._rotate_frame(If)
            elif (task == 'subsample') and (self.subsample != 1):
                Ii = subsample2D(Ii, bins=[self.subsample, self.subsample]) # 1 ms
                If = subsample2D(If, bins=[self.subsample, self.subsample]) # 1 ms
        self.var['If_raw'], self.var['Ii_raw'] = If, Ii
        self.var['recalc'][0] = False
        self.var['cropi'] = cropi

    def _rotate_frame(self, frame):
        '''Rotate one frame by ``rotate`` degrees about its center, keeping shape and scale.

        Raises ValueError for a ``rotate_method`` not in ``_ROTATE_ORDER``.
        '''
        try:
            order = self._ROTATE_ORDER[self.rotate_method]
        except KeyError:
            raise ValueError('Unknown rotate_method {!r}; expected one of {}'.format(
                self.rotate_method, sorted(self._ROTATE_ORDER))) from None
        # reshape=False keeps the frame size, as the old PIL-based imrotate did. Working in
        # float avoids imrotate's per-frame uint8 byte-scaling, which distorted If/Ii.
        return scipy.ndimage.rotate(np.asarray(frame, dtype=float), angle=self.rotate,
                                    reshape=False, order=order)

    def border_gradient(self):
        '''Fit the background correction ``alpha_Ii`` and clear the level-1 stale flag.

        Fits If_raw / Ii_raw over a border ``bg_width`` pixels wide. The fit is a plane
        when ``bg_order == 1`` and a constant for any other bg_order. alpha_Ii is that fit
        evaluated on every pixel; with ``bg_width == 0`` it is all ones. Prints a warning
        when the mean correction deviates from 1 by 10% or more.
        '''
        # If width is set to 0
        if self.bg_width == 0:
            self.var['alpha_Ii'] = np.ones_like(self.Ii_raw)
            self.var['recalc'][1] = False
            return None

        # Border mask: ones except the interior more than w pixels from every edge
        with np.errstate(divide='ignore', invalid='ignore'):
            data = self.If_raw / self.Ii_raw
        mask = np.ones_like(data)
        w = self.bg_width
        s = data.shape
        mask[w:s[0]-w, w:s[1]-w] = 0
        using = np.logical_and((mask==1) , (np.isfinite(data)) )

        # Get data for fitting (x = column index, y = row index)
        xx, yy = np.meshgrid(np.arange(s[1]), np.arange(s[0]))
        xx_f, yy_f, zz_f = (xx[using], yy[using], data[using])
        def poly_2d(xy, b, m1=0, m2=0):
            return b + m1*xy[0] + m2*xy[1]

        # Fit: the length of the initial guess selects constant vs. plane
        guess = [1e-1]
        if self.bg_order == 1: guess = [1e-1, 1e-5, 1e-5]
        fitres, _ = scipy.optimize.curve_fit(poly_2d, (xx_f, yy_f), zz_f, p0=guess)
        self.var['alpha_Ii'] = poly_2d((xx, yy), *fitres)
        self.var['recalc'][1] = False

        # Warning for correction larger than 10%
        if abs(np.mean(self.var['alpha_Ii'])-1) >= 0.1:
            print('WARNING! Background correction is larger than 10%. Imagename {}'.format(self.name))

    def optical_density(self):
        '''Compute ``od`` with ``od_method`` and clear the level-2 stale flag.

        'table' / 'dBL': lookup via ``interp_od(si, sf, time)``.
        'sBL': -ln(sf / si) + si - sf (log term plus saturation term).
        anything else: -ln(sf / si).
        '''
        method = self.od_method
        if method in ['table','dBL']:
            self.var['od'] = interp_od(self.si, self.sf, self.time)
        elif method in ['sBL']:
            with np.errstate(divide='ignore', invalid='ignore'):
                self.var['od'] = - np.log(self.sf / self.si) + self.si - self.sf
        else:
            with np.errstate(divide='ignore', invalid='ignore'):
                self.var['od'] = - np.log(self.sf / self.si)
        self.var['recalc'][2] = False

    # ---- Plotting --------------------------------------------------------------------

    def imshow(self, ax=None):
        '''Show ``app`` (atoms per pixel) with a colorbar on ``ax`` (new 4x4 figure if None).'''
        if ax is None: _, ax = plt.subplots(figsize=(4,4))

        divider = make_axes_locatable(ax)
        ax_cb = divider.new_horizontal(size="8%", pad=0.05)
        fig1 = ax.get_figure()
        fig1.add_axes(ax_cb)
        im = ax.imshow(self.app, origin='lower')
        plt.colorbar(im, cax=ax_cb)
        ax.set_axis_off()
        ax.set(title='Atoms/Pixel')

    def plot_crop(self, ax=None):
        '''Plot the full-frame raw OD with the crop box, next to the prepared od_raw.

        ``ax`` must hold at least two axes; otherwise a new 1x2 figure is created.
        The right panel also outlines the ``bg_width`` background border.
        '''
        alldata = self.alldata
        w = self.bg_width
        s = self.Ii_raw.shape
        cropi = self.cropi
        od_raw = self.od_raw

        # Prepare box: rectangle around the crop region, plus its diagonal strokes
        x = [cropi[1].start,cropi[1].start,cropi[1].stop,cropi[1].stop,cropi[1].start]
        y = [cropi[0].start,cropi[0].stop,cropi[0].stop,cropi[0].start,cropi[0].start]
        x.extend([x[2],x[3],x[1]])
        y.extend([y[2],y[3],y[1]])

        try: ax[1]
        except (TypeError, IndexError): ax = plt.subplots(figsize=(10,4), ncols=2)[1]

        # Colour limits from the finite part of od_raw so NaN/inf pixels do not break them
        finite = od_raw[np.isfinite(od_raw)]
        clim = [finite.min(), finite.max()] if finite.size else None

        # Plots
        divider = make_axes_locatable(ax[0])
        ax_cb = divider.new_horizontal(size="8%", pad=0.05)
        fig1 = ax[0].get_figure()
        fig1.add_axes(ax_cb)
        with np.errstate(divide='ignore', invalid='ignore'):
            bare_od = np.log(alldata[1] / alldata[0])
        im = ax[0].imshow(bare_od, clim = clim, origin='lower')
        plt.colorbar(im, cax=ax_cb)
        ax[0].plot(x, y, 'w-', alpha=0.5)
        ax[0].set(title='Bare Image')

        divider = make_axes_locatable(ax[1])
        ax_cb = divider.new_horizontal(size="8%", pad=0.05)
        fig1 = ax[1].get_figure()
        fig1.add_axes(ax_cb)
        im = ax[1].imshow(od_raw, origin='lower')
        plt.colorbar(im, cax=ax_cb)
        ax[1].set(title='Cropped, Rotated, Subsampled')
        ax[1].plot([w, w, s[1] - w, s[1] - w, w], [w, s[0] - w, s[0] - w, w, w], 'w-')
        ax[1].set(xlim=[0,s[1]], ylim=[0,s[0]])
        fig1.tight_layout()

    def plot_border_gradient(self):
        '''Diagnostic figure for the background fit.

        Panels: od_raw with the border outline, (alpha_Ii - 1) * 100, and the fitted
        alpha_Ii against the data ratio averaged over each of the four border strips.
        '''
        with np.errstate(divide='ignore', invalid='ignore'):
            data = self.If_raw / self.Ii_raw
        s = data.shape
        w = self.bg_width
        alpha_Ii = self.alpha_Ii

        fig, ax = plt.subplots(figsize=(8, 5), nrows=2, ncols=3)
        ax[0,0].imshow(self.od_raw, aspect='auto', origin='lower')
        ax[0,0].plot([w, w, s[1] - w, s[1] - w, w], [w, s[0] - w, s[0] - w, w, w], 'w-')
        ax[0,0].set_axis_off()
        ax[0,0].set(title='BG Width Boundary')
        if w != 0:
            ax[0,2].plot(np.nanmean(alpha_Ii[0:w, :], axis=0),'k-')
            ax[0,2].plot(np.nanmean(data[0:w,:], axis=0), '.', markersize=2)
            ax[0,2].set(title='top')
            ax[1,0].plot(np.nanmean(alpha_Ii[:, 0:w], axis=1),'k-')
            ax[1,0].plot(np.nanmean(data[:,0:w], axis=1), '.', markersize=2)
            ax[1,0].set(title='left')
            ax[1,1].plot(np.nanmean(alpha_Ii[:, -w:], axis=1),'k-')
            ax[1,1].plot(np.nanmean(data[:,-w:], axis=1), '.', markersize=2)
            ax[1,1].set(title='right')
            ax[1,2].plot(np.nanmean(alpha_Ii[-w:, :], axis=0),'k-')
            ax[1,2].plot(np.nanmean(data[-w:,:], axis=0), '.', markersize=2)
            ax[1,2].set(title='bottom')

        divider = make_axes_locatable(ax[0,1])
        ax_cb = divider.new_horizontal(size="8%", pad=0.05)
        fig.add_axes(ax_cb)
        im = ax[0,1].imshow((self.alpha_Ii - 1)*100, aspect='auto', origin='lower')
        plt.colorbar(im, cax=ax_cb)
        ax[0,1].set_axis_off()
        ax[0,1].set(title='(alpha_Ii - 1) * 100')

        fig.tight_layout()

In [None]:
# Example analysis settings, passed as keyword overrides to Image.
# The hybrid-trap keys (trap_f, ellipticity, xsec_*, radial_selection, kind, Tfit_*)
# are stored but never read by the Image class; presumably later steps use them — TODO confirm.
default_settings = {
    # Image preparation and background
    'center_x': 1110,
    'center_y': 1314,
    'width': 800,
    'height': 1000,
    'subsample': 3,
    'rotate': 0,
    'bg_width': 20,
    'bg_order': 1,
    'bad_light': 0,
    # Physical parameters
    'Isat': 77,
    'time': 10,
    'pixel': 0.7e-6,
    'sigmaf': 0.5,
    'trap_f': 23.35,
    'od_method': 'table',
    'fudge': 1,
    'ellipticity': 1,
    # Hybrid preparation
    'xsec_extension': 'default',
    'xsec_slice_width': 4,
    'xsec_fit_range': 1.75,
    'radial_selection': 0.5,
    # Hybrid thermometry
    'kind': 'unitary',
    'Tfit_lim': 0.1,
    'Tfit_guess_kT': 3,
    'Tfit_guess_mu0': 0,
    # Other
    'memory_saver': True,
}

