In [1]:
import numpy as np

from matplotlib.widgets import Button, Slider
import matplotlib.pyplot as plt
%matplotlib

from skimage.feature import register_translation
from skimage.morphology import binary_dilation
from skimage import transform as tf
from skimage import img_as_uint

import tifffile as tif

from PIL import Image, ImageDraw


from sklearn.neural_network import MLPClassifier
import pickle

import copy

import time

Using matplotlib backend: Qt5Agg


In [2]:
def radial_intensity_profile(image, cell_centre, w=15, numedges=36):
    """Sample pixel values along evenly spaced rays cast out from a cell centre.

    ``numedges`` rays are spaced ``360/numedges`` degrees apart, starting at 0
    degrees. Along each ray the image is sampled at integer radii ``0..w-1``.
    Offsets from the centre are truncated toward zero with ``int()``.

    Parameters
    ----------
    image : 2-D array
        Image to sample, indexed as ``image[first, second]``.
    cell_centre : pair of ints
        ``(first-axis index, second-axis index)`` of the centre, i.e. (row, col)
        for an ordinary image array.
    w : int
        Number of radial samples per ray.
    numedges : int
        Number of rays.

    Returns
    -------
    radprofile : ndarray, shape (w, numedges)
        ``radprofile[r, k] == image[radcoords_x[r, k], radcoords_y[r, k]]``.
    radcoords_x, radcoords_y : ndarray of int, shape (w, numedges)
        First- and second-axis indices of every sample.

    Note: there is no bounds checking. A centre within about ``w`` pixels of the
    border produces negative indices (which silently wrap to the opposite side)
    or an IndexError.
    """
    # Per-ray direction components, computed once and shared by every radius.
    ray_cos = [np.cos(np.deg2rad((360 / numedges) * k)) for k in range(numedges)]
    ray_sin = [np.sin(np.deg2rad((360 / numedges) * k)) for k in range(numedges)]

    first_axis = [[int(r * c) + cell_centre[0] for c in ray_cos] for r in range(w)]
    second_axis = [[int(r * s) + cell_centre[1] for s in ray_sin] for r in range(w)]

    samples = [
        [image[p, q] for p, q in zip(row_p, row_q)]
        for row_p, row_q in zip(first_axis, second_axis)
    ]

    return np.array(samples), np.array(first_axis), np.array(second_axis)


def smooth_radial_edge_profile(raw_edges, window):
    """Circularly smooth a per-ray edge profile and return it as ints.

    The profile is periodic (the last ray neighbours the first), so three copies
    are concatenated, smoothed with ``smoothen`` and the middle copy is kept. This
    lets the moving average wrap around instead of being pulled toward zero at
    the ends. Each smoothed value is converted with ``int()`` (truncation).

    Parameters
    ----------
    raw_edges : sequence of numbers
        One edge radius per ray.
    window : int
        Moving-average window length passed to ``smoothen``.

    Returns
    -------
    list of int, same length as ``raw_edges``.
    """
    n = len(raw_edges)
    tiled = np.concatenate((raw_edges, raw_edges, raw_edges))
    middle = smoothen(tiled, window)[n:2 * n]
    return list(map(int, middle))


def smoothen(data, window):
    """Moving-average (boxcar) filter of length ``window``.

    Uses ``np.convolve`` in 'same' mode, so the output has the same length as
    ``data`` and samples near either end are averaged against implicit zeros.
    """
    kernel = np.ones(window) / window
    return np.convolve(data, kernel, mode='same')


def create_mask_from_polygon(width, height, polygon_xs, polygon_ys):
    """Rasterise a polygon into a 0/1 mask of shape (height, width).

    Parameters
    ----------
    width, height : int
        Size of the output canvas in pixels (PIL ``(width, height)`` order).
    polygon_xs, polygon_ys : sequences
        Vertex coordinates in PIL convention: x is the column (horizontal) and
        y the row (vertical). Vertex ``k`` is ``(polygon_xs[k], polygon_ys[k])``.

    Returns
    -------
    ndarray of uint8, shape (height, width)
        1 on and inside the polygon outline, 0 elsewhere.
    """
    vertices = [(polygon_xs[k], polygon_ys[k]) for k in range(len(polygon_xs))]

    canvas = Image.new('L', (width, height), 0)
    painter = ImageDraw.Draw(canvas)
    painter.polygon(vertices, outline=1, fill=1)

    return np.array(canvas)


def extract_dff(imagestack, masks, motion_correct=False, template_frame=0, bsframes=0, tavg=1):
    """Extract a dF/F trace for every ROI mask from an image stack.

    Parameters
    ----------
    imagestack : array, indexable as ``imagestack[k]`` -> 2-D frame
        Raw movie; ``imagestack.shape[0]`` is the number of frames. It is never
        modified by this function.
    masks : ndarray, shape (n_rois, rows, cols)
        0/1 ROI masks with the same spatial shape as each frame.
    motion_correct : bool
        If True, every frame is registered to ``imagestack[template_frame]``
        (``register_translation`` with upsample factor 10) and translated back
        before traces are extracted.
    template_frame : int
        Index of the reference frame used for motion correction.
    bsframes : int
        0 -> baseline F is the 10th percentile of the whole trace.
        >0 -> running baseline from ``get_baseline`` with a window of
        ``bsframes`` frames.
    tavg : int
        Temporal downsampling factor. Non-overlapping blocks of ``tavg``
        samples are averaged. Every complete block is kept, including the last
        one; an incomplete trailing block is dropped.

    Returns
    -------
    traces : ndarray, shape (n_rois, n_samples)
        dF/F per ROI. An ROI whose mask is empty yields NaNs (0/0).
    shifts : list
        ``[x_offsets, y_offsets]`` with one value per *raw* frame (not
        downsampled) when ``motion_correct`` is True, otherwise ``[]``.

    NOTE(review): ``skimage.feature.register_translation`` is deprecated in newer
    scikit-image releases in favour of
    ``skimage.registration.phase_cross_correlation``.
    """
    numRois = masks.shape[0]
    nframes = imagestack.shape[0]

    shifts = []

    if motion_correct:
        reference = imagestack[template_frame]

        frames = []
        x_off, y_off = [], []
        for i in range(nframes):
            offset_image = imagestack[i]
            shift, error, diffphase = register_translation(reference, offset_image, 10)

            tform = tf.SimilarityTransform(scale=1, rotation=0,
                                           translation=(-shift[1], -shift[0]))
            frames.append(img_as_uint(tf.warp(offset_image, tform)))

            x_off.append(shift[1])
            y_off.append(shift[0])

        shifts.append(x_off)
        shifts.append(y_off)
    else:
        frames = imagestack

    # Mean intensity inside each ROI, frame by frame. NaNs are zeroed on a
    # per-frame copy, so the caller's stack is never modified. Each frame is
    # cleaned once instead of once per ROI.
    mask_areas = [np.sum(m) for m in masks]
    raw = np.empty((numRois, nframes))
    for ii in range(nframes):
        frame = frames[ii]
        frame = np.where(np.isnan(frame), 0, frame)
        for i in range(numRois):
            raw[i, ii] = np.sum(masks[i] * frame) / mask_areas[i]

    dff = []
    for i in range(numRois):
        trace = raw[i]
        if bsframes == 0:
            # Global baseline: 10th percentile of the whole trace.
            F = np.percentile(trace, 10)
        else:
            # Running baseline, one value per frame.
            F = np.asarray(get_baseline(trace, bsframes))
        dff.append((trace - F) / F)

    if tavg > 1:
        # Average non-overlapping blocks of `tavg` samples. `nframes // tavg`
        # counts every complete block, including the final one.
        nblocks = nframes // tavg
        t = [[np.mean(trace[b * tavg:(b + 1) * tavg]) for b in range(nblocks)]
             for trace in dff]
    else:
        t = dff

    return np.array(t), shifts


def get_baseline(data, w):
    """Running baseline: 10th percentile of a sliding window around each sample.

    Parameters
    ----------
    data : sequence of numbers (list or 1-D array)
    w : int
        Window length in samples. Odd values are rounded up to the next even
        number.

    Returns
    -------
    list of float
        One baseline value per sample of ``data``.

    For sample ``i`` the window is ``data[i - w/2 : i + w/2]``, shifted so it
    never runs past either end of ``data``. The first ``w/2`` samples share the
    first window and the last samples share the final window. If ``data`` is
    shorter than ``w``, every sample uses the whole of ``data``.

    Raises
    ------
    ValueError
        If ``w`` < 1.
    """
    if w < 1:
        raise ValueError("baseline window w must be >= 1, got {}".format(w))

    # make sure w is even
    w = w + w % 2
    hw = w // 2
    # Largest valid window start; 0 when data is shorter than the window.
    last_start = max(len(data) - w, 0)

    bl = []
    for i in range(len(data)):
        # Window centred on i, clamped into [0, last_start]. Every index is
        # covered, so no value is carried over from a previous iteration.
        start = min(max(i - hw, 0), last_start)
        bl.append(np.percentile(data[start:start + w], 10))

    return bl


def adjust_gamma(image, gamma):
    """Gamma-correct an image while preserving its maximum intensity.

    Intensities are normalised by the image maximum, raised to ``gamma`` and
    rescaled. ``gamma < 1`` therefore brightens dim pixels, and the brightest
    pixel is unchanged.

    Parameters
    ----------
    image : ndarray
        Assumed non-negative. Negative values with a fractional gamma give NaN.
    gamma : float

    Returns
    -------
    ndarray of float, same shape as ``image``.
        If the image maximum is 0 (e.g. an all-zero image), the image is
        returned unchanged as float instead of producing 0/0 NaNs.
    """
    peak = np.max(image)
    if peak == 0:
        return np.asarray(image, dtype=float).copy()
    return ((image / peak) ** gamma) * peak


In [68]:
#fname = '/home/sriram/ncbs/analysis_codes/Fictor_Analysis_Codes/Calcium_Imaging_Analysis/Motion_Correction/0000.tif'
# NOTE(review): absolute, machine-specific path -- edit for your setup.
fname = "/mnt/sriramn@shares/Data/2PImaging/181218/F2_T1/0000.tif"

numtoavg = 100  # number of frames averaged into the image on which cells are picked
template = 1020 #index of frame to use as template

im = tif.imread(fname)

if not 0 <= template < im.shape[0]:
    raise ValueError("template index {} is outside the movie ({} frames)".format(template, im.shape[0]))

roinn = "trained_roi_nn_9.sav"  # pickled ROI-edge classifier

save = True
#saveas = 'testdata.npz'
# rsplit on the *last* '.tif' so a '.tif' earlier in the path cannot truncate it
saveas = fname.rsplit('.tif', 1)[0] + '_results_roinn.npz'

w = 15          # radial samples per ray (max ROI radius in pixels)
numedges = 36   # number of rays cast from each cell centre

padwidth = 25 # pad with zeros to let the neural network draw freely around the edges. will be cropped later

imaging_fps = 7.8125
temporal_downsample_factor = 1

baseline_frames = 468 # window to estimate baseline (~60 s at imaging_fps)

motion_correct = True

# SECURITY: pickle.load can execute arbitrary code -- only load model files you trust.
with open(roinn, 'rb') as fh:
    clf = pickle.load(fh)

In [69]:
# Registration reference. This must be the same frame that extract_dff later aligns
# the movie to (template_frame=template). Otherwise ROIs drawn on `avg` are offset
# from the motion-corrected frames by any drift between the two reference frames.
image = im[template]

# Motion-corrected average of `numtoavg` frames starting at `template` (clamped to the
# end of the movie). Cell centres are picked on this image.
avg = []
for i in range(template, min(template + numtoavg, im.shape[0])):

    offset_image = im[i]
    shift, error, diffphase = register_translation(image, offset_image, 10)

    tform = tf.SimilarityTransform(scale=1, rotation=0,
                                   translation=(-shift[1], -shift[0]))

    a = tf.warp(offset_image, tform)
    avg.append(img_as_uint(a))

avg = np.mean(np.array(avg), 0)
# Zero-pad so ROI outlines can extend past the image border; the padding is cropped
# off the masks again before trace extraction.
avg = np.pad(avg, (padwidth, padwidth), mode='constant')


xpixels = avg.shape[1]  # columns of the padded image
ypixels = avg.shape[0]  # rows of the padded image


fig, ax = plt.subplots(figsize=[12,6])
plt.tight_layout()

ax.imshow(avg, cmap='gray')

coords = []

quitax = plt.axes([0.88, 0.025, 0.095, 0.08])
quit_button = Button(quitax, 'Done '+'\U0001f604', color='lightgreen', hovercolor='yellow')

clearax = plt.axes([0.78, 0.025, 0.095, 0.08])
clear_button = Button(clearax, 'Clear', color='lightgreen', hovercolor='yellow')

holdax = plt.axes([0.68, 0.025, 0.095, 0.08])
hold_button = Button(holdax, 'Hold', color='lightgreen', hovercolor='pink')


def exit(event):
    """'Done' button callback: close the cell-picking figure.

    Closes ``fig`` explicitly. ``plt.close()`` with no argument would close the
    *current* figure, which can be a different window once another figure has
    been created or activated.
    NOTE: this name shadows the builtin ``exit`` in the notebook namespace.
    """
    plt.close(fig)

quit_button.on_clicked(exit)


# Capture toggle: while True, onclick ignores clicks on the image.
hld = False

def hold(event):
    """'Hold' button callback: pause or resume recording of clicked points.

    The button's idle colour is set to orange while holding and back to light
    green when released.
    """
    global hld
    hld = not hld
    hold_button.color = 'orange' if hld else 'lightgreen'

hold_button.on_clicked(hold)


# Picked cell centres in padded-image pixel coordinates (x = column, y = row).
xcoords, ycoords = [], []

def clear(event):
    """'Clear' button callback: undo the most recently picked cell centre.

    Only the last point is removed (press repeatedly to remove more). The image
    is redrawn with the current gamma and the remaining points.
    """
    if not xcoords or not ycoords:
        return

    xcoords.pop()
    ycoords.pop()

    ax.cla()
    # Same fixed 0..max display range as the gamma slider callback, so contrast
    # does not jump when a point is removed.
    ax.imshow(adjust_gamma(avg, gamma), cmap='gray', vmin=0, vmax=np.max(avg))
    ax.scatter(xcoords, ycoords, color='yellow', s=10)

    fig.canvas.draw()
    fig.canvas.flush_events()

clear_button.on_clicked(clear)


def onclick(event):
    """Mouse-press callback: record a cell centre where the user clicks the image.

    The click is ignored when:
    - 'Hold' is active;
    - it lands outside the image axes (buttons and the gamma slider have their
      own axes);
    - it has no data coordinates.

    Otherwise the click is marked with a yellow dot and appended to ``xcoords``
    (column) and ``ycoords`` (row). The pixel is also stored in the globals
    ``ix``, ``iy``.
    """
    global ix, iy

    if hld:
        return
    if event.inaxes is not ax or event.xdata is None or event.ydata is None:
        return

    # imshow puts pixel centres at integer coordinates (pixel k spans k-0.5..k+0.5),
    # so round to the nearest pixel rather than truncating.
    ix, iy = int(round(event.xdata)), int(round(event.ydata))

    ax.scatter(ix, iy, color='yellow', s=10)
    fig.canvas.draw()
    fig.canvas.flush_events()

    xcoords.append(ix)
    ycoords.append(iy)

cid = fig.canvas.mpl_connect('button_press_event', onclick)


axgamma = plt.axes([0.1, 0.025, 0.4, 0.08])
sgamma = Slider(axgamma, 'Gamma', 0, 1, valinit=1, valstep=0.05)


# Current display gamma; also read when redrawing after an undo and when plotting ROIs.
gamma = 1.0

def set_gamma(val):
    """Slider callback: store the new gamma and redraw the image and picked points.

    The image is shown over a fixed 0..max(avg) range so that only the gamma
    curve, not the colour scaling, changes between slider positions.
    """
    global gamma
    gamma = val

    ax.clear()
    ax.imshow(adjust_gamma(avg, gamma), cmap='gray', vmin=0, vmax=np.max(avg))
    ax.scatter(xcoords, ycoords, color='yellow', s=10)

    fig.canvas.draw()
    fig.canvas.flush_events()

sgamma.on_changed(set_gamma)

  .format(dtypeobj_in, dtypeobj_out))


0

In [70]:
ncells = len(xcoords)

rois = []
for n in range(ncells):

    # cell_centre is (row, col): radial_intensity_profile indexes image[first, second],
    # so rows_grid holds row indices and cols_grid column indices, each (w, numedges).
    rprof, rows_grid, cols_grid = radial_intensity_profile(
        avg, cell_centre=[ycoords[n], xcoords[n]], w=w, numedges=numedges)

    # One classification per ray, batched: row k of rprof.T is the radial profile of
    # ray k. Map the argmax column back through classes_ so the result is the class
    # *label* (an edge radius), not merely its column position in predict_proba.
    # The two differ if any radius was absent from the training labels.
    proba = clf.predict_proba(rprof.T)
    edges = clf.classes_[np.argmax(proba, axis=1)]

    edges = smooth_radial_edge_profile(edges, 2)

    roi_rows = [rows_grid[edges[j], j] for j in range(len(edges))]
    roi_cols = [cols_grid[edges[j], j] for j in range(len(edges))]

    # close the polygon
    roi_rows.append(roi_rows[0])
    roi_cols.append(roi_cols[0])

    # stored as [x (column) coords, y (row) coords] for plotting and PIL drawing
    rois.append([roi_cols, roi_rows])


fig_rois, ax_rois = plt.subplots(figsize=[12, 3])
ax_rois.imshow(adjust_gamma(avg, gamma), cmap='gray')
for roi_x, roi_y in rois:
    ax_rois.plot(roi_x, roi_y)
ax_rois.set_title('Detected ROIs')
fig_rois.tight_layout()

In [71]:
# Rasterise every ROI polygon on the padded canvas (same size as `avg`).
masks = np.array([create_mask_from_polygon(xpixels, ypixels, rois[i][0], rois[i][1])
                  for i in range(ncells)])

# Number of ROIs covering each pixel. A pixel stays in an ROI's refined mask only if
# that ROI is its sole owner, so overlapping pixels never feed two traces.
coverage = np.sum(masks, axis=0)

masks_refined = []
for i in range(ncells):
    exclusive = ((masks[i] == 1) & (coverage == 1)).astype(masks.dtype)

    # Crop off the zero padding. Explicit end indices keep this correct when
    # padwidth == 0 (a slice ending at -0 would be empty).
    n_rows, n_cols = exclusive.shape
    masks_refined.append(exclusive[padwidth:n_rows - padwidth, padwidth:n_cols - padwidth])

masks_refined = np.array(masks_refined)

s = time.perf_counter()

traces, shifts = extract_dff(im, masks_refined, motion_correct=motion_correct, template_frame=template,
                             bsframes=baseline_frames, tavg=temporal_downsample_factor)

print("Elapsed time : {}s".format(round(time.perf_counter() - s, 0)))

  .format(dtypeobj_in, dtypeobj_out))


Elapsed time : 397.0s


In [72]:
fps = imaging_fps/temporal_downsample_factor
# Time axis from sample indices (seconds, taking imaging_fps as frames per second).
# np.arange with a float stop/step can return one element too many because of
# rounding, which would leave tstamp and the traces with different lengths.
tstamp = np.arange(traces.shape[-1]) / fps

# Motion-correction offsets exist only when motion_correct was on. They hold one
# value per raw frame (not temporally downsampled), so they get their own time axis.
if shifts:
    shift_t = np.arange(len(shifts[0])) / imaging_fps
    fig_shifts, ax_shifts = plt.subplots()
    ax_shifts.plot(shift_t, shifts[0], label='x shift')
    ax_shifts.plot(shift_t, shifts[1], label='y shift')
    ax_shifts.set(xlabel='Time (s)', ylabel='Shift (pixels)', title='Motion-correction offsets')
    ax_shifts.legend()

offset = 5  # vertical spacing between stacked traces, in dF/F units

fig_traces, ax_traces = plt.subplots()
for i in range(ncells):
    ax_traces.plot(tstamp, traces[i] + offset*i)
ax_traces.set(xlabel='Time (s)', ylabel='dF/F + offset', title='ROI dF/F traces')

In [None]:
if save:
    # `centres` and `rois` are in padded-image pixel coordinates (x = column,
    # y = row); `padwidth` is stored so they can be mapped back to the raw frames.
    results = dict(
        traces=traces,
        shifts=shifts,
        padwidth=padwidth,
        centres=[xcoords, ycoords],
        rois=rois,
        masks_refined=masks_refined,
        temporal_downsample_factor=temporal_downsample_factor,
    )
    np.savez(saveas, **results)