# Analysis of puncta size and intensity in individual cell nuclei

Detect and segment puncta and analyze their size and intensity. Calculate puncta statistics per cell nucleus.

## Requirements
- A folder with the images to be analyzed. All z-layers of a sample must be combined into a single file; to combine z-layers and channels, run [images_to_stack.ipynb](images_to_stack.ipynb).
- A folder with segmented cell nuclei. To segment cell nuclei, run [segment_cells.ipynb](segment_cells.ipynb).

## Config

### The following code imports and declares functions used for the processing:

In [None]:
#################################
#  Don't modify the code below  #
#################################

import intake_io
import os
import numpy as np
import pylab as plt
import seaborn as sns
from skimage import io
import pandas as pd
import xarray as xr
from scipy import ndimage
from tqdm import tqdm
from scipy.stats import entropy
from skimage.measure import regionprops_table

from am_utils.utils import walk_dir, imsave
from am_utils.parallel import run_parallel
from lib import mutual_information_2d, segment_puncta


## Data & parameters

### Data
`input_dir`: folder with images to be analyzed

`segm_dir`: folder with segmented cell nuclei

`output_dir`: folder to save results

### Channel parameters

`channel_names`: list of channel names, e.g. `['GFP', 'DNA']`

`puncta_channels`: list of channels to use for puncta segmentation (e.g. `["GFP"]`)


### Puncta detection parameters

`minsize_um`: minimal sigma for the Laplacian of Gaussian detection (microns); default is 0.2

`maxsize_um`: maximal sigma for the Laplacian of Gaussian detection (microns); default is 2

`num_sigma`: number of sigma values for the Laplacian of Gaussian detection; default is 5

`overlap`: a value between 0 and 1; if two blobs overlap by a fraction greater than this value, the smaller blob is eliminated; default is 1 (blobs are removed only if they overlap completely)

`threshold_detection`: threshold for detecting LoG blobs, i.e. the absolute lower bound for scale-space maxima. Local maxima below this threshold are ignored. Reduce it to detect dimmer blobs.

`threshold_segmentation`: threshold for puncta segmentation. How the threshold is applied is determined by `segmentation_mode`: for mode 0, choose values on the order of 0.001; for mode 1, on the order of 50; for mode 2, on the order of 3. Reduce it to detect more/larger puncta; increase it to detect fewer/smaller puncta.

`threshold_detection` and `threshold_segmentation` for mode 0 should be close to 0, and can be both positive and negative

`threshold_background`: threshold used to post-filter puncta in cells with diffuse signal. It is given relative to the median GFP intensity inside cells (e.g., `threshold_background` = 2 removes all puncta whose intensity is lower than twice the median GFP (background) intensity). Set to 0 to keep all puncta.

`segmentation_mode`: determines how `threshold_segmentation` is applied: 0 — absolute threshold in LoG space; 1 — threshold relative to the background in LoG space; 2 — threshold relative to the background in image-intensity space.

### Other parameters

`max_threads`: number of processes to run in parallel; default is 30

## Specify data paths and analysis parameters

In [None]:
# --- data locations -------------------------------------------------------
input_dir = "/research/sharedresources/cbi/common/Anna/test/input"
segm_dir = "/research/sharedresources/cbi/common/Anna/test/Analysis/cell_segmentation"
output_dir = "/research/sharedresources/cbi/common/Anna/test/Analysis"

# sub-folders of output_dir that receive the results
cell_stats_dir = 'quantification_cells'
puncta_stats_dir = 'quantification_puncta'
puncta_segm_dir = 'puncta_segmentation_with_raw'

# --- channels ---------------------------------------------------------------
channel_names = ["DNA", "GFP"]
puncta_channels = ["GFP"]

# --- Laplacian-of-Gaussian blob detection ----------------------------------
minsize_um = 0.2   # smallest LoG sigma, in microns
maxsize_um = 2     # largest LoG sigma, in microns
num_sigma = 5      # number of LoG sigma values
overlap = 1        # 0..1: if two blobs overlap by more than this fraction, the smaller one is dropped
threshold_detection = 0.001  # absolute lower bound for scale-space maxima; lower it to detect dimmer blobs

# --- puncta segmentation ----------------------------------------------------
threshold_segmentation = 50  # lower -> more/larger puncta, higher -> fewer/smaller puncta
threshold_background = 3     # relative to the median in-cell intensity; 0 keeps all puncta

# --- execution --------------------------------------------------------------
max_threads = 20   # number of processes to run in parallel
segmentation_mode = 1  # 0: absolute LoG threshold; 1: LoG threshold relative to background; 2: intensity threshold relative to background

### The following code lists all datasets in the input directory:

In [None]:
#################################
#  Don't modify the code below  #
#################################
# Every file found below input_dir is treated as one sample.
samples = walk_dir(input_dir)
summary = f'{len(samples)} images were found:'
print(summary)
print(np.array(samples))

### The following code segments and quantifies puncta in all input images:

In [None]:
#################################
#  Don't modify the code below  #
#################################

def _relative_stem(sample, input_dir):
    """Return the path of `sample` relative to `input_dir`, without its last extension.

    E.g. ``/in/cond1/img.ome.tif`` with ``input_dir='/in'`` gives ``cond1/img.ome``.
    """
    return os.path.splitext(os.path.relpath(sample, input_dir))[0]


def _output_path(base_dir, rel_stem, ext):
    """Return ``base_dir/rel_stem.ext`` and make sure its parent directory exists."""
    path = os.path.join(base_dir, rel_stem + '.' + ext)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    return path


def _lookup(keys, values, labels):
    """Re-order `values` (indexed by `keys`) to follow `labels`; labels missing from `keys` give NaN.

    Aligns regionprops_table output with an existing table by label instead of
    relying on both being in the same row order.
    """
    series = pd.Series(np.asarray(values), index=np.asarray(keys))
    return series.reindex(np.asarray(labels)).to_numpy()


def _broadcast_thresholds(values, n_channels, name):
    """Expand a length-1 threshold array to one value per puncta channel.

    Raises:
        ValueError: if several values are given but not exactly one per channel.
    """
    if len(values) == 1:
        return np.ones(n_channels) * values[0]
    if len(values) != n_channels:
        raise ValueError(f'{name} needs 1 or {n_channels} values, got {len(values)}')
    return values


def _add_channel_stats(cell_stats, cells, images):
    """Add per-nucleus intensity, background, entropy, mutual-information and Pearson columns.

    `images` maps channel name -> image array of the same shape as `cells`.
    Each channel is compared only with the channels that follow it.
    """
    labels = cell_stats['cell label'].to_numpy()
    boxes = ndimage.find_objects(cells)
    # (bounding box, in-box mask) per nucleus: avoids one full-image scan per cell and channel.
    # Pixels are visited in the same (C) order as np.where(cells == label) would give.
    regions = []
    for lab in labels:
        box = boxes[int(lab) - 1]
        regions.append((box, cells[box] == lab))
    background = cells == 0
    names = list(images)

    for cind, channel in enumerate(names):
        data = images[channel]
        props = regionprops_table(label_image=cells, intensity_image=data,
                                  properties=['label', 'mean_intensity'])
        mean_col = channel + ' mean intensity per nucleus'
        cell_stats[mean_col] = _lookup(props['label'], props['mean_intensity'], labels)
        cell_stats[channel + ' integrated intensity per nucleus'] = cell_stats[mean_col] * cell_stats['cell volume pix']
        cell_stats[channel + ' mean background intensity'] = np.mean(data[background])
        cell_stats[channel + ' integrated background intensity'] = np.sum(data[background])

        # the bin count depends only on the whole-image maximum; keep at least one bin
        n_bins = max(int(data.max()), 1)
        cell_stats[channel + ' entropy'] = [entropy(np.histogram(data[box][mask], bins=n_bins)[0])
                                            for box, mask in regions]

        for channel2 in names[cind + 1:]:
            data2 = images[channel2]
            mutual_info, pearson = [], []
            for box, mask in regions:
                a = data[box][mask]
                b = data2[box][mask]
                mutual_info.append(mutual_information_2d(a, b, bins=max(int(max(a.max(), b.max())), 1)))
                pearson.append(np.corrcoef(a * 1., b * 1.)[0, 1])
            cell_stats['Mutual information ' + channel + ' vs ' + channel2] = mutual_info
            cell_stats['Pearson correlation ' + channel + ' vs ' + channel2] = pearson


def _filter_puncta(puncta, puncta_stats, cells, maxvol):
    """Drop puncta larger than `maxvol` and, if any nuclei exist, puncta centred outside nuclei.

    Dropped puncta are zeroed in the `puncta` label image (in place) and removed
    from the returned copy of `puncta_stats`.
    """
    keep = (puncta_stats['volume_um'] <= maxvol).to_numpy()
    if cells.max() > 0:
        keep &= (puncta_stats['cell_label'] > 0).to_numpy()
    dropped = puncta_stats['label'].to_numpy()[~keep]
    if dropped.size:
        puncta[np.isin(puncta, dropped)] = 0
    return puncta_stats[keep].reset_index(drop=True)


def _add_puncta_per_cell(cell_stats, puncta_stats, puncta_channel):
    """Add per-nucleus puncta count, volume and border-distance columns (0 for nuclei without puncta)."""
    labels = cell_stats['cell label'].to_numpy()
    per_cell = puncta_stats.groupby('cell_label')
    cell_stats[f'number of {puncta_channel} puncta'] = per_cell.size().reindex(labels, fill_value=0).to_numpy()
    aggregations = (
        (f'total {puncta_channel} puncta volume per nucleus um', 'volume_um', 'sum'),
        (f'average {puncta_channel} puncta volume per nucleus um', 'volume_um', 'mean'),
        (f'total {puncta_channel} puncta volume per nucleus pix', 'volume_pix', 'sum'),
        (f'average {puncta_channel} puncta volume per nucleus pix', 'volume_pix', 'mean'),
        (f'average {puncta_channel} puncta distance to nucleus border um', 'distance to nucleus border um', 'mean'),
    )
    for column, source, how in aggregations:
        cell_stats[column] = per_cell[source].agg(how).reindex(labels).fillna(0).to_numpy()


def quantify(item, input_dir, output_dir, segm_dir=None,
             channel_names=None, puncta_channels=None, output_dir_puncta=None, output_dir_puncta_segm=None,
             threshold_detection=None, threshold_segmentation=None, **puncta_kwargs):
    """Segment puncta in one image and write per-nucleus and per-punctum statistics.

    Parameters
    ----------
    item : str
        Path of the image to analyze (a file below `input_dir`).
    input_dir : str
        Root input folder; the sub-path of `item` below it is mirrored in every output folder.
    output_dir : str
        Folder for the per-nucleus statistics CSV.
    segm_dir : str, optional
        Folder with nuclei label images (``<same sub-path>.tif``). If the file is absent,
        the whole image is treated as a single nucleus with label 1.
    channel_names : list of str, optional
        Channel names passed to intake_io as the ``c`` coordinate.
    puncta_channels : list of str, optional
        Channels in which puncta are segmented; None skips the puncta analysis.
    output_dir_puncta : str, optional
        Folder for the per-punctum statistics CSV.
    output_dir_puncta_segm : str, optional
        Folder for a uint16 stack holding the raw channels, puncta labels and nuclei labels.
    threshold_detection, threshold_segmentation : float or sequence of float
        Passed to `segment_puncta`: one value for all puncta channels or one per channel.
    **puncta_kwargs
        Extra arguments for `segment_puncta` (must contain ``maxsize_um`` when puncta are
        segmented); every value is also recorded as a column of the output tables.

    Raises
    ------
    ValueError
        If a threshold has neither 1 value nor one value per puncta channel.
    """
    sample = item
    dataset = intake_io.imload(sample, metadata={"coords": {'c': channel_names}})
    rel_stem = _relative_stem(sample, input_dir)
    condition = sample.split('/')[-2]
    sample_file = sample.split('/')[-1]

    scale = np.array([dataset['z'][1], dataset['y'][1], dataset['x'][1]])
    voxel_volume = np.prod(scale)
    channels = list(dataset['c'].data)
    # load every channel once; all later steps reuse these arrays
    images = {ch: np.array(dataset.loc[dict(c=ch)]['image'].data) for ch in channels}

    # nuclei labels; without a segmentation file the whole image is nucleus 1
    # (integer dtype, as regionprops requires integer label images)
    cells = np.ones(images[channels[0]].shape, dtype=np.int32)
    if segm_dir is not None:
        segm_fn = os.path.join(segm_dir, rel_stem + '.tif')
        if os.path.exists(segm_fn):
            cells = io.imread(segm_fn)

    # distance from each nucleus voxel to the nearest background voxel, in coordinate units
    dist_to_border = ndimage.distance_transform_edt(cells > 0, sampling=scale)

    # per-nucleus stats
    cell_stats = pd.DataFrame(regionprops_table(label_image=cells,
                                                properties=['label', 'area', 'centroid']))
    cell_stats = cell_stats.rename(columns={'area': 'cell volume pix',
                                            'centroid-0': 'z',
                                            'centroid-1': 'y',
                                            'centroid-2': 'x',
                                            'label': 'cell label'})
    cell_stats['cell volume um'] = cell_stats['cell volume pix'] * voxel_volume
    _add_channel_stats(cell_stats, cells, images)
    cell_stats['condition'] = condition
    cell_stats['sample'] = sample_file
    cell_labels = cell_stats['cell label'].to_numpy()

    # thresholds: always raveled (recorded below even without puncta channels),
    # expanded to one value per puncta channel when there are any
    n_puncta_channels = 0 if puncta_channels is None else len(puncta_channels)
    threshold_detection = np.ravel(threshold_detection)
    threshold_segmentation = np.ravel(threshold_segmentation)
    if n_puncta_channels:
        threshold_detection = _broadcast_thresholds(threshold_detection, n_puncta_channels,
                                                    'threshold_detection')
        threshold_segmentation = _broadcast_thresholds(threshold_segmentation, n_puncta_channels,
                                                       'threshold_segmentation')

    # stack layout: raw channels, one label image per puncta channel, nuclei labels last
    output_stack = None
    if output_dir_puncta_segm is not None:
        output_stack = np.zeros((len(channels) + n_puncta_channels + 1,) + cells.shape, dtype=np.uint16)
        for ch_ind, ch in enumerate(channels):
            output_stack[ch_ind] = images[ch]
        output_stack[-1] = cells

    if puncta_channels is not None:
        puncta_tables = []
        for pc_ind, puncta_channel in enumerate(puncta_channels):
            puncta = segment_puncta(images[puncta_channel], cells, scale,
                                    threshold_detection=threshold_detection[pc_ind],
                                    threshold_segmentation=threshold_segmentation[pc_ind],
                                    **puncta_kwargs)

            # per-punctum geometry; mean of dist_to_border = average distance to the nucleus border
            puncta_stats = pd.DataFrame(regionprops_table(label_image=puncta,
                                                          intensity_image=dist_to_border,
                                                          properties=['label', 'centroid', 'area', 'mean_intensity']))
            puncta_stats = puncta_stats.rename(columns={'mean_intensity': 'distance to nucleus border um',
                                                        'centroid-0': 'z',
                                                        'centroid-1': 'y',
                                                        'centroid-2': 'x',
                                                        'area': 'volume_pix'})
            puncta_stats['volume_um'] = puncta_stats['volume_pix'] * voxel_volume
            # nucleus label under each punctum's rounded centroid (0 = background)
            centroid = tuple(np.round(puncta_stats[ax].to_numpy()).astype(int) for ax in ('z', 'y', 'x'))
            puncta_stats['cell_label'] = cells[centroid]

            # puncta larger than a sphere of radius 3 * maxsize_um are discarded
            maxvol = 4. / 3 * np.pi * (puncta_kwargs['maxsize_um'] * 3) ** 3
            puncta_stats = _filter_puncta(puncta, puncta_stats, cells, maxvol)

            if output_stack is not None:
                output_stack[len(channels) + pc_ind] = puncta

            _add_puncta_per_cell(cell_stats, puncta_stats, puncta_channel)

            inside_outside = ((cells * (puncta > 0), f'inside {puncta_channel} puncta'),
                              (cells * (puncta == 0), f'outside {puncta_channel} puncta'))
            for channel in channels:
                data = images[channel]

                # intensity per punctum, aligned by label
                props = regionprops_table(label_image=puncta, intensity_image=data,
                                          properties=['label', 'mean_intensity'])
                mean = _lookup(props['label'], props['mean_intensity'], puncta_stats['label'])
                puncta_stats[channel + ' mean intensity per puncta'] = mean
                puncta_stats[channel + ' integrated intensity per puncta'] = mean * puncta_stats['volume_pix']

                # intensity per nucleus inside / outside puncta (NaN where a nucleus has no such pixels)
                for label_img, location in inside_outside:
                    props = regionprops_table(label_image=label_img, intensity_image=data,
                                              properties=['label', 'area', 'mean_intensity'])
                    cell_stats[channel + ' mean intensity ' + location] = _lookup(
                        props['label'], props['mean_intensity'], cell_labels)
                    cell_stats[channel + ' integrated intensity ' + location] = _lookup(
                        props['label'], np.int_(props['mean_intensity'] * props['area']), cell_labels)

            puncta_stats['channel'] = puncta_channel
            puncta_tables.append(puncta_stats)

        puncta_stats_all = pd.concat(puncta_tables, ignore_index=True) if puncta_tables else pd.DataFrame()
        puncta_stats_all['condition'] = condition
        puncta_stats_all['sample'] = sample_file
        for key, value in puncta_kwargs.items():
            puncta_stats_all[key] = value
        # record the thresholds used for each punctum's own channel
        det_by_channel = dict(zip(puncta_channels, threshold_detection))
        seg_by_channel = dict(zip(puncta_channels, threshold_segmentation))
        row_channels = list(puncta_stats_all.get('channel', []))
        puncta_stats_all['threshold_detection'] = [det_by_channel[ch] for ch in row_channels]
        puncta_stats_all['threshold_segmentation'] = [seg_by_channel[ch] for ch in row_channels]

        if output_dir_puncta is not None:
            puncta_stats_all.to_csv(_output_path(output_dir_puncta, rel_stem, 'csv'), index=False)

    for key, value in puncta_kwargs.items():
        cell_stats[key] = value
    # the per-nucleus table keeps a single threshold column: the first puncta channel's value
    cell_stats['threshold_detection'] = threshold_detection[0]
    cell_stats['threshold_segmentation'] = threshold_segmentation[0]
    cell_stats.to_csv(_output_path(output_dir, rel_stem, 'csv'), index=False)

    if output_stack is not None:
        stack_channels = ([str(ch) for ch in channels]
                          + [cn + ' segmentation' for cn in (puncta_channels or [])]
                          + ['Nuclei segmentation'])
        stack = xr.Dataset(data_vars=dict(image=(dataset['image'].dims, output_stack)),
                           coords=dict(c=stack_channels,
                                       x=dataset.coords['x'], y=dataset.coords['y'], z=dataset.coords['z']),
                           attrs=dataset.attrs)
        stack['image'].attrs = dataset['image'].attrs
        intake_io.imsave(stack, _output_path(output_dir_puncta_segm, rel_stem, 'tif'))
    

In [None]:
%%time
#################################
#  Don't modify the code below  #
#################################

# Arguments for run_parallel: `items` holds the samples; the other entries are quantify()
# parameters (output sub-folders, channels, puncta settings) plus max_threads.
kwargs = dict(
    items=samples,
    input_dir=input_dir,
    segm_dir=segm_dir,
    output_dir=os.path.join(output_dir, cell_stats_dir),
    output_dir_puncta=os.path.join(output_dir, puncta_stats_dir),
    output_dir_puncta_segm=os.path.join(output_dir, puncta_segm_dir),
    channel_names=channel_names,
    puncta_channels=puncta_channels,
    max_threads=max_threads,
    minsize_um=minsize_um,
    maxsize_um=maxsize_um,
    num_sigma=num_sigma,
    overlap=overlap,
    threshold_detection=threshold_detection,
    threshold_segmentation=threshold_segmentation,
    threshold_background=threshold_background,
    segmentation_mode=segmentation_mode,
)

# process all samples in parallel
run_parallel(process=quantify, **kwargs)

In [None]:
# Combine the per-sample CSVs of each stats folder into one table (<folder>.csv in output_dir).
# `stat` is None if neither folder exists, instead of raising NameError on the last line.
stat = None
for stats_dir in [cell_stats_dir, puncta_stats_dir]:
    stats_path = os.path.join(output_dir, stats_dir)
    if not os.path.exists(stats_path):
        continue
    # read all files first and concatenate once (repeated concat is quadratic)
    tables = [pd.read_csv(fn) for fn in walk_dir(stats_path)]
    stat = pd.concat(tables, ignore_index=True) if tables else pd.DataFrame()
    stat.to_csv(os.path.join(output_dir, stats_dir + '.csv'), index=False)
stat

### The following code plots cell stats over conditions:

In [None]:
#################################
#  Don't modify the code below  #
#################################

# One box plot per numeric per-nucleus statistic, grouped by condition; saved as PNG files.
plots_dir = os.path.join(output_dir, 'plots')
os.makedirs(plots_dir, exist_ok=True)
stats = pd.read_csv(os.path.join(output_dir, cell_stats_dir + '.csv'))
for col in stats.columns:
    # identifier/grouping columns and non-numeric columns cannot be box-plotted
    if col in ('cell label', 'condition', 'sample') or not pd.api.types.is_numeric_dtype(stats[col]):
        continue
    fig = plt.figure(figsize=(15, 6))
    sns.boxplot(x='condition', y=col, data=stats)
    fig.savefig(os.path.join(plots_dir, col.replace(' ', '_') + '.png'))
    plt.show()      # display the plot now ...
    plt.close(fig)  # ... then release it so one-figure-per-column doesn't accumulate open figures