In [None]:
from skimage import filters
import pickle
import os
from metadata import Metadata
import time
from skimage.morphology import selem
from skimage.morphology import dilation
from scipy.ndimage import gaussian_filter
import numpy as np
import pandas as pd
import itertools as it
from skimage.feature import peak_local_max
from skimage.measure import label,regionprops
from skimage.morphology import watershed


def pfunc_zstk_filt(pname, threshold, md_local, acq=['tnfaip3'],
                    chan='FarRed', blur=(0.25,0.25,0.25), 
                   rel_thresh=0.2, npixels_thresh=6):
    """Detect spots in one position's z-stack and tabulate their intensities.

    The stack is loaded through ``md_local.stkread``, background-subtracted
    with ``diffGauss``, and local maxima above ``threshold`` are used as seeds
    for a per-slice watershed segmentation. Each segmented region becomes one
    row of the returned DataFrame.

    Parameters
    ----------
    pname : position identifier passed to ``stkread`` as ``Position``.
    threshold : minimum filtered intensity for a local maximum
        (``threshold_abs`` of ``peak_local_max``).
    md_local : object providing ``stkread(Position=, Channel=, acq=)``
        (presumably a ``Metadata`` instance -- confirm against callers).
    acq : acquisition selector forwarded to ``stkread``. The default list is
        shared between calls; it is only read here, never mutated.
    chan : channel name forwarded to ``stkread``.
    blur : Gaussian sigmas forwarded to ``diffGauss``.
    rel_thresh, npixels_thresh : currently unused; kept so existing callers
        that pass them keep working.

    Returns
    -------
    None
        If ``stkread`` returns a dict, which this code treats as a failed load.
    tuple ``(pname, max_projection, df, spot_labels)`` otherwise, where
        ``max_projection`` is ``fstk.max(axis=2)``, ``df`` has the columns
        ``val`` (mean region intensity), ``y``/``x`` (first/second coordinate
        of the intensity-weighted centroid), ``z`` (slice index) and
        ``pixel_values`` (list of the region's pixel intensities), and
        ``spot_labels`` is the watershed label image of the LAST processed
        (highest) z-slice only. When no peak is found ``df`` is empty and
        ``spot_labels`` is an all-zero 2-D array.
    """
    print(pname, '\n')
    stk = md_local.stkread(Position=pname, Channel=chan, acq=acq)  # load images

    # A dict result means the files could not be loaded: report and bail out.
    if isinstance(stk, dict):
        print('stk empty')
        print('Pos: ',pname)
        print('Channel: ',chan)
        print('Acq: ',acq)
        return None

    # Background subtraction.
    fstk = diffGauss(stk, blur)

    # Coordinates of local maxima in the filtered stack.
    peaks_coords = peak_local_max(fstk, threshold_abs=threshold,
                                  min_distance=3)

    columns = ['val', 'y', 'x', 'z']
    if len(peaks_coords) == 0:
        # Nothing detected: return empty results instead of failing in
        # pd.concat / on an unbound spot_labels.
        empty_df = pd.DataFrame(columns=columns + ['pixel_values'])
        return (pname, fstk.max(axis=2), empty_df,
                np.zeros(fstk.shape[:2], dtype=int))

    # Boolean peak mask built from the coordinates; avoids running the peak
    # finder a second time with the deprecated ``indices=False`` switch.
    peaks = np.zeros(fstk.shape, dtype=bool)
    peaks[tuple(peaks_coords.T)] = True

    footprint = selem.disk(2)
    dfs = []
    # np.unique gives the slices in ascending order, so both the row order of
    # the result and the returned spot_labels slice are deterministic.
    for z in np.unique(peaks_coords[:, 2]):
        img = fstk[:, :, z]
        seeds = label(peaks[:, :, z])
        # Watershed on the inverted, smoothed slice, restricted to a dilated
        # neighbourhood of the peaks.
        spot_labels = watershed(gaussian_filter(img, 1.2)*-1, seeds,
                                mask=dilation(peaks[:, :, z], selem=footprint))
        spots = regionprops(spot_labels, img)  # per-region measurements
        spot_means = np.array([p.mean_intensity for p in spots])
        centroids = np.array([p.weighted_centroid for p in spots])
        # Raw pixel intensities of every region, kept for later statistics.
        spot_values = [[img[y, x] for y, x in p.coords] for p in spots]

        df = pd.DataFrame(np.stack([spot_means, centroids[:, 0],
                                    centroids[:, 1],
                                    np.full(len(spots), z)], axis=1),
                          columns=columns)
        df['pixel_values'] = spot_values
        dfs.append(df)

    return pname, fstk.max(axis=2), pd.concat(dfs, ignore_index=True), spot_labels


def diffGauss(stk, blur = (0.5,0.5,0.)):
    """Subtract a Gaussian-blurred background from ``stk``.

    The stack is converted to float64, its Gaussian low-pass (sigmas ``blur``)
    is subtracted, the result is lightly smoothed (sigma 0.25 on every axis)
    and values are clamped to the representable int16 range before casting.

    Parameters
    ----------
    stk : array-like image stack; it is not modified.
    blur : sigma (scalar or one value per axis) of the background low-pass,
        passed to ``scipy.ndimage.gaussian_filter``.

    Returns
    -------
    numpy.ndarray of dtype int16 with the same shape as ``stk``; negative
    values are set to 0 and values above 32767 saturate at 32767.
    """
    # astype already returns a fresh float64 copy, so the input is untouched.
    img = np.asarray(stk).astype('float64')
    img = img - gaussian_filter(img, blur)
    # A scalar sigma applies to every axis (same as (0.25, 0.25, 0.25) in 3-D).
    img = gaussian_filter(img, 0.25)
    # Clamp before the cast: without the upper bound, bright voxels above the
    # int16 maximum would wrap around instead of saturating.
    np.clip(img, 0, np.iinfo(np.int16).max, out=img)
    return img.astype('int16')

def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``."""
    # NOTE: pickle must only be used on files this pipeline wrote itself.
    with open(path, 'rb') as fh:
        return pickle.load(fh)


def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path``, closing the file afterwards."""
    with open(path, 'wb') as fh:
        pickle.dump(obj, fh)


def _add_spot_statistics(df):
    """Add the per-spot columns 'npixels', 'ssum' and 'stdpixels' to ``df``.

    'npixels' counts the pixels of a spot brighter than the Otsu threshold of
    all spot mean intensities (``df.val``); 'ssum' and 'stdpixels' are the sum
    and standard deviation of the spot's pixel values.
    """
    if df.empty:
        # Otsu is undefined without data; just create the (empty) columns.
        for column in ('npixels', 'ssum', 'stdpixels'):
            df[column] = []
        return df
    bright_thresh = filters.threshold_otsu(np.array(df.val))
    npixels = []
    ssum = []
    stdpixels = []
    for pixel_values in df.pixel_values:
        pixel_array = np.array(pixel_values)
        npixels.append(int(np.count_nonzero(pixel_array > bright_thresh)))
        ssum.append(np.sum(pixel_values))
        stdpixels.append(np.std(pixel_values))
    df['npixels'] = npixels
    df['ssum'] = ssum
    df['stdpixels'] = stdpixels
    return df


def spot_detection(md_path, data_list,threshold = 200,colors = ['Orange','FarRed'],sigmas = ['low','med','high']):
    """Run spot detection for every colour / sigma / acquisition / position.

    Results are written below ``<md_path>/results/<color>/<sigma>/<acq>/``:
    one sub-directory per position holding ``pname.pkl``, ``fstk_max.pkl``,
    ``df.pkl`` and ``spot_labels.pkl``, plus an acquisition-level ``df.pkl``
    containing the spots of ALL positions of that acquisition.

    Work already on disk is reused: an acquisition whose ``df.pkl`` exists is
    skipped, and a position with all four pickle files is loaded, not recomputed.

    Parameters
    ----------
    md_path : directory handed to ``Metadata``; also the parent of 'results'.
    data_list : iterable of acquisition names to process.
    threshold : peak threshold forwarded to ``pfunc_zstk_filt``.
    colors : channel names to process.
    sigmas : background-blur presets; 'low', 'med' and 'high' map to Gaussian
        sigmas 2.2, 5 and 10 on all three axes. Unknown names are reported
        and skipped.

    Returns
    -------
    None
    """
    sigma_blurs = {'low': (2.2,2.2,2.2), 'med': (5,5,5), 'high': (10,10,10)}
    position_files = ('pname.pkl', 'fstk_max.pkl', 'spot_labels.pkl', 'df.pkl')

    result_path = os.path.join(md_path,'results')
    os.makedirs(result_path, exist_ok=True)

    md = Metadata(md_path)
    total_start = time.time()

    for color in colors:
        color_path = os.path.join(result_path,color)
        os.makedirs(color_path, exist_ok=True)
        for sigma in sigmas:
            sigma_path = os.path.join(color_path,sigma)
            os.makedirs(sigma_path, exist_ok=True)
            blur = sigma_blurs.get(sigma)
            if blur is None:
                print('Unknown sigma: ',sigma)
                continue
            for acq in data_list:
                acq_path = os.path.join(sigma_path,acq)
                acq_df_path = os.path.join(acq_path,'df.pkl')
                if os.path.exists(acq_df_path):
                    continue  # acquisition already finished
                os.makedirs(acq_path, exist_ok=True)
                start = time.time()
                table = md.image_table
                poses = table[(table.Channel==color)&(table.acq==acq)].Position.unique()
                df_list = []
                for pos in poses:
                    pos_path = os.path.join(acq_path,pos)
                    os.makedirs(pos_path, exist_ok=True)

                    # Reuse a position only if every output file is present.
                    if all(os.path.exists(os.path.join(pos_path, name))
                           for name in position_files):
                        df_list.append(_load_pickle(os.path.join(pos_path,'df.pkl')))
                        print('')
                        print(acq,sigma,color)
                        continue

                    result = pfunc_zstk_filt(pos, threshold, md, acq=acq,chan=color,blur=blur)
                    if result is None:
                        # The stack could not be loaded; nothing to store.
                        print('Skipping position, no image stack: ',pos)
                        continue
                    pname,fstk_max,df,spot_labels = result
                    _add_spot_statistics(df)

                    _dump_pickle(pname, os.path.join(pos_path,'pname.pkl'))
                    _dump_pickle(fstk_max, os.path.join(pos_path,'fstk_max.pkl'))
                    _dump_pickle(df, os.path.join(pos_path,'df.pkl'))
                    _dump_pickle(spot_labels, os.path.join(pos_path,'spot_labels.pkl'))
                    df_list.append(df)
                    print('')
                    print(acq,sigma,color,'time: ',time.time()-start)

                if not df_list:
                    # No position produced data; leave the acquisition
                    # unfinished so a later run can retry it.
                    print('No spot data for: ',acq,sigma,color)
                    continue
                # Store the spots of all positions, not just the last one.
                _dump_pickle(pd.concat(df_list, ignore_index=True), acq_df_path)
    print('Finished ')
    print('Total Time: ',time.time()-total_start)
    return

In [None]:
import matplotlib.pyplot as plt
def load_plot_spot_histograms(md_path,colors = ['Orange','FarRed'],sigmas = ['low','med','high'],exposures=['100ms','600ms']):
    """Plot log10 spot-intensity histograms, one figure per colour/sigma/exposure.

    For every acquisition directory under
    ``<md_path>/results/<color>/<sigma>/`` whose name contains the exposure
    string, the acquisition's ``df.pkl`` is loaded and ``log10(df.val)`` is
    drawn as a semi-transparent histogram labelled with the acquisition name.

    Parameters
    ----------
    md_path : directory containing the 'results' tree.
    colors : colour names, or ``True`` to use every entry of 'results'.
    sigmas : sigma names, or ``True`` to use every entry of each colour
        directory.
    exposures : substrings used to group acquisition names into figures.

    Returns
    -------
    None
    """
    # Equivalent of the ``%matplotlib inline`` notebook magic, written so the
    # function is also valid (and a no-op here) outside IPython.
    try:
        get_ipython().run_line_magic('matplotlib', 'inline')
    except NameError:
        pass

    results_path = os.path.join(md_path,'results')
    if colors is True:
        colors = os.listdir(results_path)
    for color in colors:
        print(color)
        color_path = os.path.join(results_path,color)
        # Resolve per colour without overwriting ``sigmas``; otherwise every
        # later colour would reuse the first colour's directory listing.
        color_sigmas = os.listdir(color_path) if sigmas is True else sigmas
        for sigma in color_sigmas:
            print(sigma)
            sigma_path = os.path.join(color_path,sigma)
            acqs = sorted(os.listdir(sigma_path))
            for exposure in exposures:
                plotted = False
                for acq in acqs:
                    if exposure not in acq:
                        continue
                    df_path = os.path.join(sigma_path,acq,'df.pkl')
                    if not os.path.exists(df_path):
                        # e.g. an acquisition that has not finished processing
                        print('Missing df.pkl for: ',acq)
                        continue
                    # NOTE: pickle is only safe on files written by this pipeline.
                    with open(df_path,'rb') as fh:
                        df = pickle.load(fh)
                    # log10 is undefined for values <= 0; -inf breaks plt.hist.
                    vals = df.val[df.val > 0]
                    if vals.empty:
                        continue
                    plt.hist(np.log10(vals),bins=100,label=acq,alpha=0.5)
                    plotted = True
                plt.xlabel('log10(intensity)')
                plt.ylabel('counts')
                plt.title(str(color+' '+sigma+' '+exposure))
                if plotted:
                    plt.legend()
                plt.show()
    return

In [None]:
import matplotlib.pyplot as plt
def bg_signal_histogram(md_path,colors = ['Orange','FarRed'],sigmas = ['low','med','high'],exposures=['100ms','600ms']):
    """Split each acquisition's spots into 'good' and 'bad' and plot both.

    For every acquisition under ``<md_path>/results/<color>/<sigma>/`` the
    acquisition ``df.pkl`` is loaded and Otsu thresholds are computed for
    ``val``, ``ssum`` and ``stdpixels`` (the ``npixels`` threshold is 0).
    Spots above ALL thresholds are 'good', spots at or below ALL thresholds
    are 'bad'; both subsets are saved as ``good_df.pkl`` / ``bad_df.pkl`` next
    to ``df.pkl`` and their ``log10(val)`` histograms are shown.

    NOTE(review): spots that pass only some criteria end up in neither
    subset -- confirm that this is intended.

    Parameters
    ----------
    md_path : directory containing the 'results' tree.
    colors : colour names, or ``True`` to use every entry of 'results'.
    sigmas : sigma names, or ``True`` to use every entry of each colour
        directory.
    exposures : unused; kept for signature compatibility.

    Returns
    -------
    None
    """
    # Equivalent of the ``%matplotlib inline`` notebook magic, written so the
    # function is also valid (and a no-op here) outside IPython.
    try:
        get_ipython().run_line_magic('matplotlib', 'inline')
    except NameError:
        pass

    results_path = os.path.join(md_path,'results')
    if colors is True:
        colors = os.listdir(results_path)
    for color in colors:
        print(color)
        color_path = os.path.join(results_path,color)
        # Resolve per colour without overwriting ``sigmas``; otherwise every
        # later colour would reuse the first colour's directory listing.
        color_sigmas = os.listdir(color_path) if sigmas is True else sigmas
        for sigma in color_sigmas:
            print(sigma)
            sigma_path = os.path.join(color_path,sigma)
            for acq in os.listdir(sigma_path):
                acq_path = os.path.join(sigma_path,acq)
                # NOTE: pickle is only safe on files written by this pipeline.
                with open(os.path.join(acq_path,'df.pkl'),'rb') as fh:
                    df = pickle.load(fh)

                intensity_thresh = filters.threshold_otsu(np.array(df.val))
                npixels_thresh = 0
                ssum_thresh = filters.threshold_otsu(np.array(df.ssum))
                std_thresh = filters.threshold_otsu(np.array(df.stdpixels))

                val = np.array(df.val)
                npix = np.array(df.npixels)
                ssum = np.array(df.ssum)
                std = np.array(df.stdpixels)
                good_mask = ((val > intensity_thresh) & (npix > npixels_thresh)
                             & (ssum > ssum_thresh) & (std > std_thresh))
                bad_mask = ((val <= intensity_thresh) & (npix <= npixels_thresh)
                            & (ssum <= ssum_thresh) & (std <= std_thresh))
                good_df = df[good_mask]
                bad_df = df[bad_mask]

                with open(os.path.join(acq_path,'good_df.pkl'),'wb') as fh:
                    pickle.dump(good_df,fh)
                with open(os.path.join(acq_path,'bad_df.pkl'),'wb') as fh:
                    pickle.dump(bad_df,fh)

                # log10 is undefined for values <= 0; -inf breaks plt.hist.
                good_vals = good_df.val[good_df.val > 0]
                bad_vals = bad_df.val[bad_df.val > 0]
                plt.hist(np.log10(good_vals),bins=100,color='r',alpha=0.5,label='good')
                plt.hist(np.log10(bad_vals),bins=100,color='b',alpha=0.5,label='bad')
                plt.xlabel('log10(intensity)')
                plt.ylabel('counts')
                plt.title(str(color+' '+sigma+' '+acq))
                plt.legend()
                plt.show()
    return

  

In [None]:
import matplotlib.pyplot as plt
def _good_bad_means(df):
    """Return ``(mean val of good spots, mean val of bad spots)`` for ``df``.

    Thresholds are the Otsu thresholds of ``val``, ``ssum`` and ``stdpixels``
    (and 0 for ``npixels``). 'Good' spots exceed all of them, 'bad' spots are
    at or below all of them; spots in between belong to neither group.
    """
    intensity_thresh = filters.threshold_otsu(np.array(df.val))
    npixels_thresh = 0
    ssum_thresh = filters.threshold_otsu(np.array(df.ssum))
    std_thresh = filters.threshold_otsu(np.array(df.stdpixels))

    val = np.array(df.val)
    npix = np.array(df.npixels)
    ssum = np.array(df.ssum)
    std = np.array(df.stdpixels)
    good_mask = ((val > intensity_thresh) & (npix > npixels_thresh)
                 & (ssum > ssum_thresh) & (std > std_thresh))
    bad_mask = ((val <= intensity_thresh) & (npix <= npixels_thresh)
                & (ssum <= ssum_thresh) & (std <= std_thresh))
    return np.mean(df[good_mask].val), np.mean(df[bad_mask].val)


def bg_to_signal_ratio(md_path,md_path2,colors = ['Orange','FarRed'],sigmas = ['low','med','high'],exposures=['100ms','600ms']):
    """Scatter-plot the good/bad mean-intensity ratio against concentration.

    For each colour and sigma, every acquisition found under
    ``<md_path>/results/<color>/<sigma>/`` and then under
    ``<md_path2>/results/<color>/<sigma>/`` is loaded from its ``df.pkl``. The
    ratio ``mean(val of good spots) / mean(val of bad spots)`` is computed and
    stored under each exposure whose string occurs in the acquisition name.
    The x value is the integer in front of 'nM' in that name, so matching
    acquisition names must start with an integer followed by 'nM'. One
    scatter plot per exposure is shown.

    The ratio is NaN when either group is empty.

    Parameters
    ----------
    md_path, md_path2 : two directories, each containing a 'results' tree.
    colors : colour names, or ``True`` to use every entry of
        ``<md_path>/results``.
    sigmas : sigma names, or ``True`` to use every entry of each colour
        directory under ``md_path``.
    exposures : substrings used to group acquisition names.

    Returns
    -------
    None
    """
    # Equivalent of the ``%matplotlib inline`` notebook magic, written so the
    # function is also valid (and a no-op here) outside IPython.
    try:
        get_ipython().run_line_magic('matplotlib', 'inline')
    except NameError:
        pass

    results_path = os.path.join(md_path,'results')
    if colors is True:
        colors = os.listdir(results_path)
    for color in colors:
        print(color)
        color_path = os.path.join(results_path,color)
        # Resolve per colour without overwriting ``sigmas``; otherwise every
        # later colour would reuse the first colour's directory listing.
        color_sigmas = os.listdir(color_path) if sigmas is True else sigmas
        for sigma in color_sigmas:
            print(sigma)
            acq_dict = {exposure: [] for exposure in exposures}
            s2n_dict = {exposure: [] for exposure in exposures}

            # Both result trees are processed identically, md_path first.
            sigma_paths = (os.path.join(color_path,sigma),
                           os.path.join(md_path2,'results',color,sigma))
            for sigma_path in sigma_paths:
                for acq in os.listdir(sigma_path):
                    print(acq)
                    # NOTE: pickle is only safe on files written by this pipeline.
                    with open(os.path.join(sigma_path,acq,'df.pkl'),'rb') as fh:
                        df = pickle.load(fh)

                    good_df_mean, bad_df_mean = _good_bad_means(df)
                    ratio = good_df_mean/bad_df_mean
                    for exposure in exposures:
                        if exposure in acq:
                            acq_dict[exposure].append(int(acq.split('nM')[0]))
                            s2n_dict[exposure].append(ratio)
                    print('good_df_mean: ',good_df_mean)
                    print('bad_df_mean: ',bad_df_mean)
                    print('signal_to_background ratio: ',ratio)

            for exposure in exposures:
                plt.scatter(acq_dict[exposure], s2n_dict[exposure], label=exposure)
                plt.xlabel('[PER] (nM)')
                plt.ylabel('signal_to_background ratio')
                plt.title(str(color+' '+sigma+' '+exposure))
                plt.legend()
                plt.show()

    return