# Initialization

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib notebook
import numpy as np
from skimage import io
import matplotlib.pyplot as plt
import pickle
import pyfftw
from displayTools import Plot_2D
import displayTools as dt
import opticsTools as ot
# wisdom = pickle.load(open('../wisdoms/wisdom_gallery.pickle','rb'))
# pyfftw.import_wisdom(wisdom)

# DPC example

## Load Raw Data

In [None]:
imgDir = '/Users/zfphil/Dropbox/Datasets/AQLM/2017_05_05 - Multi-Contrast/'
fileName = '2017_05_06_dpc+gfp_30sec_df.tif'

# Create object for dataset
dataStack = dt.MultiTiff(imgDir + fileName)

# Explicit check instead of `assert`: assertions are stripped when Python
# runs with -O, which would silently skip this validation.
if dataStack.nPages <= 0:
    raise IOError('Could not load data!')

## Define Dataset Parameters

In [None]:
# Optical parameters of the acquisition, consumed by the DPC solver below.
# NOTE(review): units are not stated here; wavelength = 0.514 is presumably
# in microns, and `rotation` has one entry per DPC pattern (4 here),
# presumably in degrees -- confirm against opticsTools.metadata.
dataParams = ot.metadata(
    dataStack.imgSize,
    wavelength=0.514,
    mag=40,
    NA=0.4,
    NA_in=0.0,
    RI=1.33,
    rotation=[90, 270, 180, 0],
)

In [None]:
from opticsTools import DPC_solver
DPCObj = DPC_solver(dataParams)

# Pass the first frame of the dataset loaded above to setRoi. Read it from
# `dataStack`: `mt` is only created in the later testing cells, so using it
# here raised NameError on a fresh kernel.
firstFrame = dataStack.read(channel = 0, time = 0, z = 0)
DPCObj.setRoi(firstFrame)

## Show Sources

In [None]:
%matplotlib inline
# Plot the source patterns stored in DPCObj.source.
# Axes are in normalized NA units: spatial frequency * wavelength / NA.
# minNAx/maxNAx/minNAy/maxNAy stay as notebook globals; later cells reuse them.
naX = dataParams.fxlin.real * dataParams.wavelength / dataParams.NA
naY = dataParams.fylin.real * dataParams.wavelength / dataParams.NA
maxNAx, minNAx = max(naX), min(naX)
maxNAy, minNAy = max(naY), min(naY)

f,ax = plt.subplots(1,4,sharex=True,sharey=True,figsize=(12,3))
# Iterate the patterns directly; zip stops at the 4 available axes.
for axis, s in zip(ax, DPCObj.source):
    # fftshift moves the zero-frequency sample to the centre for display.
    axis.imshow(np.fft.fftshift(s), cmap='gray', clim=(0,1),
                extent=[minNAx,maxNAx,minNAy,maxNAy])
    axis.set_xlim(-1.2,1.2)
    axis.set_ylim(-1.2,1.2)
    axis.set_aspect(1)

## Show Transfer Functions

In [None]:
# Transfer functions per DPC pattern (one column each):
#   top row    -> real part of DPCObj.Hu
#   bottom row -> imaginary part of DPCObj.Hp
f,ax = plt.subplots(2,4,sharex=True,sharey=True,figsize = (12,6))
naExtent = [minNAx,maxNAx,minNAy,maxNAy]
panelRows = [
    (DPCObj.Hu, np.real, [-1.6, 1.6]),
    (DPCObj.Hp, np.imag, [-.8, .8]),
]
for row, (transferFns, takePart, limits) in enumerate(panelRows):
    for col in range(4):
        panel = ax[row, col]
        panel.imshow(np.fft.fftshift(takePart(transferFns[col])), cmap='jet',
                     extent=naExtent, clim=limits)
        panel.set_xlim(-2.2, 2.2)
        panel.set_ylim(-2.2, 2.2)
        panel.set_aspect(1)

# Process Video Frames

In [None]:
# How to divide up channels: 0-3 are DPC, 4 is darkfield, 5 is GFP
dpcIdx = range(0,4)
dfIdx = 4
gfpIdx = 5

# Regularization
DPCObj.reg_u = 1e-2
DPCObj.reg_p = 1e-3
DPCObj.reg_TV = (1e-2,1e-2)
DPCObj.rho = 5e-1

# Number of time points to process. Capped for quick testing; set
# maxTimePoints = dataStack.nTimePoints to process the whole series.
# min() avoids reading past the end of short datasets.
maxTimePoints = 10
nTimeToProcess = min(maxTimePoints, dataStack.nTimePoints)

# Variables to Populate
frameSz = dataStack.imgSize
imgs_gfp   = np.zeros((dataStack.nTimePoints, frameSz[0], frameSz[1]))
imgs_df    = np.zeros((dataStack.nTimePoints, frameSz[0], frameSz[1]))
# TODO: imgs_phase is not filled yet -- confirm the shape of dpc_phase[0]
# after DPCObj.setRoi (it may be cropped) before storing it here.
imgs_phase = np.zeros((dataStack.nTimePoints, frameSz[0], frameSz[1]))

for timeIdx in range(nTimeToProcess):
    # All channels for this time point at z = 0 (indexed channel-first below)
    frameList = dataStack.read(z = 0, time = timeIdx)

    # Store Darkfield and GFP channels
    imgs_df[timeIdx,:,:] = frameList[dfIdx, :, :]
    imgs_gfp[timeIdx,:,:] = frameList[gfpIdx, :, :]

    # Collect DPC Images
    imgs_dpc = frameList[dpcIdx,:, :]

    # Set Data For This Frame
    DPCObj.setRawData(imgs_dpc)

    # Solve for Object using higher-order TV
    dpc_phase = DPCObj.solve(method='TVDeconv', order = 3, maxIter = 2)

    # TODO - write (append) page to multi-page tiff

    print('Finished Time Point %d of %d' % (timeIdx+1, nTimeToProcess))

# Display Results for the last processed time point. solve() returns a stack
# (other cells index it as dpc_phase[0]); imshow needs the 2-D slice.
f, ax = plt.subplots(1,2,figsize = (12,6))
ax[0].imshow(np.real(dpc_phase[0]), cmap = 'gray'); ax[0].set_title('Recovered Absorption')
ax[1].imshow(np.imag(dpc_phase[0]), cmap = 'gray'); ax[1].set_title('Recovered Phase')

In [None]:
# Re-solve the data last passed to DPCObj.setRawData with new regularization
# weights, then show the negated imaginary part of the first slice.
DPCObj.reg_u, DPCObj.reg_p = 1e-2, 7e-3
DPCObj.reg_TV = (5e-3, 5e-3)
DPCObj.rho = 0.1
dpc_phase = DPCObj.solve(method='TVDeconv', maxIter=2, order=3)

plt.figure(figsize=(12, 12))
plt.imshow(-np.imag(dpc_phase[0]), cmap='gray')
plt.colorbar()


# Testing Code Snippets

## Dataset IO

In [None]:
import displayTools

# Alternative dataset:
# imgDir = '/Users/zfphil/Dropbox/Datasets/AQLM/2017_05_04 - 3D-DPC/'
# fileName= '2017_05_04_3d-dpc-test001.tif'

imgDir = '/Users/zfphil/Dropbox/Datasets/AQLM/2017_05_05 - Multi-Contrast/'
fileName = '2017_05_06_dpc+gfp_30sec_df.tif'
tiffPath = imgDir + fileName

# Consistency check: the channel 0 / time 0 / z 0 frame from MultiTiff must
# match the corresponding page read directly with skimage.
mt = displayTools.MultiTiff(tiffPath)
frame = mt.read(channel=0, time=0, z=0,
                returnDimOrder='CZT', debugFlag=True, squeezeResult=False)

imgStack = io.imread(tiffPath)
pixelsMatch = (imgStack[0, 0, :, :] == frame[0, 0, 0, :, :]).all()
assert pixelsMatch


### Time Stack

In [None]:
import displayTools

imgDir = '/Users/zfphil/Dropbox/Datasets/AQLM/2017_05_05 - Multi-Contrast/'
fileName = '2017_05_06_dpc+gfp_30sec_df.tif'

# Read channel 0, z 0 for the first three time points; the unsqueezed result
# is indexed below as frame[channel, z, time, :, :].
mt = dt.MultiTiff(imgDir + fileName)
frame = mt.read(channel=0, time=range(0, 3), z=0,
                returnDimOrder='CZT', debugFlag=False, squeezeResult=False)

plt.figure(figsize=(8, 3))
for tIdx in range(3):
    plt.subplot(1, 3, tIdx + 1)
    plt.imshow(frame[0, 0, tIdx, :, :])

## Wavelength (DPC) Stack

In [None]:
import displayTools

imgDir = '/Users/zfphil/Dropbox/Datasets/AQLM/2017_05_05 - Multi-Contrast/'
fileName = '2017_05_06_dpc+gfp_30sec_df.tif'

# Read the four DPC channels at time 0, z 0; indexed as frame[channel, z, time, :, :].
mt = dt.MultiTiff(imgDir + fileName)
frame = mt.read(channel=range(0, 4), time=0, z=0,
                returnDimOrder='CZT', debugFlag=False, squeezeResult=False)

plt.figure(figsize=(10, 3))
for chanIdx in range(4):
    plt.subplot(1, 4, chanIdx + 1)
    plt.imshow(frame[chanIdx, 0, 0, :, :])

## 3D-DPC Stack

In [None]:
imgDir = '/Users/zfphil/Dropbox/Datasets/Brain-MIC 2017/Testing/'
fileName = '3d-dpc.tif'

# Read the four DPC channels at the two requested planes z = 0 and z = 9.
mt = dt.MultiTiff(imgDir + fileName)
frame = mt.read(channel=range(0, 4), z=[0, 9],
                returnDimOrder='CZT', debugFlag=False, squeezeResult=False)

# Grid layout: row = requested plane (0 -> z=0, 1 -> z=9), column = channel.
plt.figure(figsize=(10, 3))
for zIdx in range(2):
    for chanIdx in range(4):
        plt.subplot(2, 4, 4 * zIdx + chanIdx + 1)
        plt.imshow(frame[chanIdx, zIdx, 0, :, :])

In [None]:
%matplotlib inline
# Plot the first four measured DPC intensities held by the solver,
# in real-space coordinates from dataParams.xlin / ylin.
xyExtent = [min(dataParams.xlin.real), max(dataParams.xlin.real),
            min(dataParams.ylin.real), max(dataParams.ylin.real)]
f, ax = plt.subplots(1, 4, sharex=True, sharey=True, figsize=(12, 3))
for idx in range(4):
    panel = ax[idx]
    panel.imshow(DPCObj.intensity[idx], cmap='gray', clim=(-.2, .2), extent=xyExtent)
    panel.set_xlim(-5, 5)
    panel.set_ylim(-5, 5)
    panel.set_aspect(1)

In [None]:
# Baseline reconstruction with the solver's default method (stored as "_l2").
DPCObj.reg_u, DPCObj.reg_p = 1, 1e-5
dpc_phase_l2 = DPCObj.solve()

In [None]:
# TV-deconvolution reconstruction using the solver's default order
# (stored as "_1_iso"; presumably order 1 -- confirm in DPC_solver.solve).
DPCObj.reg_u, DPCObj.reg_p = 1, 1e-5
DPCObj.reg_TV = (1e-2, 1e-2)
DPCObj.rho = 5e-1
dpc_phase_1_iso = DPCObj.solve(method='TVDeconv')

In [None]:
# TV-deconvolution reconstruction with order=2.
DPCObj.reg_u, DPCObj.reg_p = 1, 1e-5
DPCObj.reg_TV = (1e-2, 1e-2)
DPCObj.rho = 5e-1
dpc_phase_2_iso = DPCObj.solve(method='TVDeconv', order=2)

In [None]:
import contexttimer

# Configure the solver outside the timed region so only solve() is measured.
DPCObj.reg_u=1
DPCObj.reg_p=1e-5
DPCObj.reg_TV = (1e-2,1e-2)
DPCObj.rho = 5e-1
with contexttimer.Timer() as t:
    dpc_phase_3_iso = DPCObj.solve(method='TVDeconv', order = 3)
# Report after the context exits, once the timer has stopped.
print(t.elapsed)

In [None]:
from opticsTools import pupilGen
from algorithms import Fourier
Fobj = Fourier(dataParams.dim,(-1,-2))
F = lambda x: Fobj.FourierTransform(x)
IF = lambda x: Fobj.InverseFourierTransform(x)
P2NA = pupilGen(dataParams,dataParams.NA*2.0)

# Reference phantom: sqrt(R^2 - r^2) / R inside a disc of radius R centred at
# phantomCenter, zero outside. Clipping before the sqrt avoids taking the root
# of negative values outside the disc (NaN plus a RuntimeWarning).
phantomRadius = 1.5
phantomCenter = (0.1, 0.0)
r2 = ((dataParams.xlin.real[np.newaxis,:] - phantomCenter[0])**2
      + (dataParams.ylin.real[:,np.newaxis] - phantomCenter[1])**2)
phantom = np.sqrt(np.clip(phantomRadius**2 - r2, 0.0, None)) / phantomRadius
# phantom = np.angle(IF(F(np.exp(1j*phantom))*P2NA))

# Compare one horizontal line of each reconstruction's imaginary part
# (coloured lines) with the phantom (black circles).
lineIdx = 260
color_label = ['r','g','b','c']
for plotIdx,img in enumerate([dpc_phase_l2[0],dpc_phase_1_iso[0],dpc_phase_2_iso[0],dpc_phase_3_iso[0]]):
    plt.plot(img[lineIdx,:].imag,color=color_label[plotIdx])
plt.xlim(160,350)
plt.scatter(np.arange(phantom.shape[1]),phantom[lineIdx,:],facecolors='none',edgecolors='k')

In [None]:
# First derivative (np.gradient) of the row-260 profile of the imaginary part,
# for each reconstruction and the phantom (black).
color_label = ['r','g','b','c','k']
profiles = [dpc_phase_l2[0], dpc_phase_1_iso[0], dpc_phase_2_iso[0], dpc_phase_3_iso[0], 1j*phantom]
for img, lineColor in zip(profiles, color_label):
    plt.plot(np.gradient(img[260,:].imag), color=lineColor)
plt.xlim(160,350)

In [None]:
# Second derivative (np.gradient applied twice) of the row-260 profile of the
# imaginary part, for each reconstruction and the phantom (black).
color_label = ['r','g','b','c','k']
profiles = [dpc_phase_l2[0], dpc_phase_1_iso[0], dpc_phase_2_iso[0], dpc_phase_3_iso[0], 1j*phantom]
for img, lineColor in zip(profiles, color_label):
    firstDeriv = np.gradient(img[260,:].imag)
    plt.plot(np.gradient(firstDeriv), color=lineColor)
plt.xlim(160,350)

In [None]:
# Third derivative (np.gradient applied three times) of the row-260 profile of
# the imaginary part, for each reconstruction and the phantom (black).
color_label = ['r','g','b','c','k']
profiles = [dpc_phase_l2[0], dpc_phase_1_iso[0], dpc_phase_2_iso[0], dpc_phase_3_iso[0], 1j*phantom]
for img, lineColor in zip(profiles, color_label):
    deriv = img[260,:].imag
    for _ in range(3):
        deriv = np.gradient(deriv)
    plt.plot(deriv, color=lineColor)
plt.xlim(160,350)

In [None]:
# 2 x 4 grid: top row shows exp(real part) as amplitude, bottom row shows the
# imaginary part as phase, one column per reconstruction.
xyExtent = [min(dataParams.xlin.real), max(dataParams.xlin.real),
            min(dataParams.ylin.real), max(dataParams.ylin.real)]
plt.figure(figsize=(10,6))
for plotIdx,img in enumerate([dpc_phase_l2.real,dpc_phase_1_iso.real,dpc_phase_2_iso.real,dpc_phase_3_iso.real,\
                              dpc_phase_l2.imag,dpc_phase_1_iso.imag,dpc_phase_2_iso.imag,dpc_phase_3_iso.imag]):
    # Integer subplot arguments: string specs such as '241' are rejected by
    # newer matplotlib releases.
    plt.subplot(2, 4, plotIdx+1)
    isAmplitude = plotIdx < 4
    val_limit = (0.95,1.05) if isAmplitude else (-.5,1.1)
    data_display = np.exp(img[0]) if isAmplitude else img[0]
    plt.imshow(data_display,clim=val_limit,cmap='gray',extent=xyExtent)
    # Raw strings: '\m' is an invalid escape sequence in a plain literal.
    if not isAmplitude: plt.xlabel(r'x, $\mu m$')
    if plotIdx % 4 == 0: plt.ylabel(r'y, $\mu m$')
    plt.xlim(-5,5)
    plt.ylim(-5,5)
    plt.title('amplitude' if isAmplitude else 'phase')

In [None]:
# Gradient along axis 0 of each reconstruction's imaginary part and of the
# phantom. The phantom is wrapped in a list so img[0] yields a 2-D array for
# every entry.
xyExtent = [min(dataParams.xlin.real), max(dataParams.xlin.real),
            min(dataParams.ylin.real), max(dataParams.ylin.real)]
plt.figure(figsize=(16,4))
for plotIdx,img in enumerate([dpc_phase_l2.imag,dpc_phase_1_iso.imag,dpc_phase_2_iso.imag,dpc_phase_3_iso.imag,[phantom]]):
    # Integer subplot arguments: string specs such as '151' are rejected by
    # newer matplotlib releases.
    plt.subplot(1, 5, plotIdx+1)
    plt.imshow(np.gradient(img[0],axis=0),clim=(-.1,.1),cmap='gray',extent=xyExtent)
    # Raw strings: '\m' is an invalid escape sequence in a plain literal.
    plt.xlabel(r'x, $\mu m$')
    if plotIdx==0: plt.ylabel(r'y, $\mu m$')
    plt.xlim(-5,5)
    plt.ylim(-5,5)
    plt.title('phase_gradient')

In [None]:
# Gradient-descent solve; returns the reconstruction and its error values.
dpc_phase_gd, error_gd = DPCObj.solve(
    method='gradientDescent',
    maxIter=200,
    convex=True,
    plot_verbose=False,
    verbose=True,
)

In [None]:
# Left: imaginary part of the reconstruction; middle: its row-260 profile;
# right: log10 of the recorded error values.
xyExtent = [min(dataParams.xlin.real), max(dataParams.xlin.real),
            min(dataParams.ylin.real), max(dataParams.ylin.real)]
plt.figure(figsize=(12, 4))

plt.subplot(1, 3, 1)
plt.imshow(dpc_phase_gd[0].imag, cmap='gray', clim=(-0.5, 1.1), extent=xyExtent)
plt.xlim(-5, 5)
plt.ylim(-5, 5)

plt.subplot(1, 3, 2)
plt.plot(dpc_phase_gd[0, 260, :].imag, color='r')
plt.xlim(160, 350)

plt.subplot(1, 3, 3)
errors = error_gd[0]
plt.scatter(np.arange(len(errors)), np.log10(errors), marker='o', color='b')

In [None]:
# FISTA solve; returns the reconstruction and its error values.
dpc_phase_FISTA, error_FISTA = DPCObj.solve(
    method='FISTA',
    maxIter=200,
    convex=True,
    plot_verbose=False,
    verbose=True,
)

In [None]:
# Left: imaginary part of the reconstruction; middle: its row-260 profile;
# right: log10 of the recorded error values.
xyExtent = [min(dataParams.xlin.real), max(dataParams.xlin.real),
            min(dataParams.ylin.real), max(dataParams.ylin.real)]
plt.figure(figsize=(12, 4))

plt.subplot(1, 3, 1)
plt.imshow(dpc_phase_FISTA[0].imag, cmap='gray', clim=(-0.5, 1.1), extent=xyExtent)
plt.xlim(-5, 5)
plt.ylim(-5, 5)

plt.subplot(1, 3, 2)
plt.plot(dpc_phase_FISTA[0, 260, :].imag, color='r')
plt.xlim(160, 350)

plt.subplot(1, 3, 3)
errors = error_FISTA[0]
plt.scatter(np.arange(len(errors)), np.log10(errors), marker='o', color='b')

In [None]:
# Newton solve (3 iterations); returns the reconstruction and its error values.
dpc_phase_Newton, error_Newton = DPCObj.solve(
    method='Newton',
    maxIter=3,
    verbose=True,
)

In [None]:
# Left: imaginary part of the reconstruction; middle: its row-260 profile;
# right: log10 of the recorded error values.
xyExtent = [min(dataParams.xlin.real), max(dataParams.xlin.real),
            min(dataParams.ylin.real), max(dataParams.ylin.real)]
plt.figure(figsize=(12, 4))

plt.subplot(1, 3, 1)
plt.imshow(dpc_phase_Newton[0].imag, cmap='gray', clim=(-0.5, 1.1), extent=xyExtent)
plt.xlim(-5, 5)
plt.ylim(-5, 5)

plt.subplot(1, 3, 2)
plt.plot(dpc_phase_Newton[0, 260, :].imag, color='r')
plt.xlim(160, 350)

plt.subplot(1, 3, 3)
errors = error_Newton[0]
plt.scatter(np.arange(len(errors)), np.log10(errors), marker='o', color='b')

In [None]:
# Persist the pyfftw wisdom to disk (the commented-out lines in the first cell
# load it back). Export first so a failure cannot leave a truncated file, and
# use `with` so the file handle is closed deterministically.
wisdom = pyfftw.export_wisdom()
with open('../wisdoms/wisdom_gallery.pickle','wb') as wisdomFile:
    pickle.dump(wisdom, wisdomFile)