# SMorph: Group Analysis

---
## Setup
Please execute the cell(s) below to initialize the notebook environment.

In [None]:
# @title Install dependencies
!pip install poetry

In [None]:
# @title Install SMorph Python module
!pip install https://github.com/swanandlab/SMorph/releases/download/0.1.0/SMorph-0.1.0.tar.gz

In [None]:
# Notebook-wide imports and pandas display configuration.
import warnings
warnings.filterwarnings('ignore')  # Suppress annoying warnings
# NOTE(review): this silences *every* warning for the whole session (including
# pandas SettingWithCopy and library deprecation warnings). It runs before
# `import smorph` so import-time warnings are hidden too; consider narrowing it
# with `category=` once the noisy warnings are identified.

import smorph as sm
import pandas as pd
import ipywidgets as widgets
pd.set_option('display.max_rows', None)  # remove upper limit on display of rows

---
## Step 1: Define groups and image preprocessing parameters

- Set `LABELS` as a list of strings naming each group
- Set `GROUP_FOLDERS` as a list of strings with the path to each group's folder, in the same order as `LABELS`
- Set `SCALE` as the per-axis image scale passed to `sm.Groups`
- Set `IMG_TYPE` as a string selecting the image acquisition method (chosen from the dropdown below)
- Set `CROP_TECH` as a string selecting the method used to crop cells out of the tissue image (chosen from the dropdown below)

In [None]:
# group labels
LABELS = ['SAL_28_MLUB_3D', 'DMI_28_MLUB_3D', 'FLX_28_MLUB_3D']  #@param

# input the path to individual group folders (place them in SMorph)
# GROUP_FOLDERS[i] must contain the images of the group named LABELS[i].
GROUP_FOLDERS = ['Autocropped/SAL_28_MLUB_3D', 'Autocropped/DMI_28_MLUB_3D', 'Autocropped/FLX_28_MLUB_3D']  #@param

# Per-axis image scale handed to `sm.Groups(scale=...)`.
# NOTE(review): presumably the (z, y, x) voxel size of the 3D stacks in
# physical units — confirm against the acquisition metadata.
SCALE = (1.0785801681301463, 0.6918881978764917, 0.6918881978764917)

# Fail fast, before the long-running analysis, if the configuration is inconsistent:
# labels are paired with folders by position and used as unique plot categories.
assert len(LABELS) == len(GROUP_FOLDERS), (
    f'{len(LABELS)} labels but {len(GROUP_FOLDERS)} group folders')
assert len(set(LABELS)) == len(LABELS), 'group labels must be unique'

In [None]:
# @title Interactive image parameter selection
IMG_TYPE = 'confocal'
CROP_TECH = 'auto'

def select_image_params(img_type=IMG_TYPE, crop_tech=CROP_TECH):
  """Record the dropdown choices in the notebook-level IMG_TYPE / CROP_TECH.

  `widgets.interact` calls this every time a dropdown changes; later cells
  read the two globals, not the widgets themselves.
  """
  global IMG_TYPE, CROP_TECH
  IMG_TYPE, CROP_TECH = img_type, crop_tech

_ = widgets.interact(
  select_image_params,
  img_type=['confocal', 'DAB'],
  crop_tech=['manual', 'auto'],
)

Cell image preprocessing parameters:
* `min_ptile` and `max_ptile`: lower and upper intensity percentiles to which the image contrast is stretched
* `threshold_method`: method used to auto-threshold the cell image at a single intensity

The overlaid contours show the thresholding result on an image picked at random from one of the group folders.

In [None]:
# @title Interactive preprocessing parameter selection

CONTRAST_PTILES = (0, 100)
THRESHOLD_METHOD = None  # or one of sm.util.THRESHOLD_METHODS, e.g. sm.util.THRESHOLD_METHODS[4]
PREVIEW_SEED = None  # set to an int to preview the same random image on every run

import matplotlib.pyplot as plt
import skimage.io as io
from random import Random
from os import listdir
from os.path import isfile, join
from skimage.measure import find_contours

_rng = Random(PREVIEW_SEED)
rand_group_path = _rng.choice(GROUP_FOLDERS)
# Only regular, non-hidden files are candidates (skips e.g. .DS_Store and
# sub-directories, which imread cannot load). Sorting makes a seeded pick
# independent of the filesystem's listing order.
_candidates = sorted(
  name for name in listdir(rand_group_path)
  if not name.startswith('.') and isfile(join(rand_group_path, name))
)
if not _candidates:
  raise FileNotFoundError(f'No image files found in {rand_group_path!r}')
rand_img = _rng.choice(_candidates)
cell_image = io.imread(join(rand_group_path, rand_img))

if cell_image.ndim == 3:
  # Collapse to 2D for the preview. A trailing axis of length 3/4 is treated as
  # RGB(A) colour channels (e.g. DAB images); any other 3D array is treated as a
  # stack and max-projected over its first axis.
  # NOTE(review): assumes stacks load as (z, y, x), skimage's usual order — confirm.
  proj_axis = 2 if cell_image.shape[-1] in (3, 4) else 0
  cell_image = cell_image.max(proj_axis)

def plot_ptiles(
  min_ptile=CONTRAST_PTILES[0],
  max_ptile=CONTRAST_PTILES[1],
  threshold_method=THRESHOLD_METHOD
):
  """Preview contrast stretching + thresholding on `cell_image`.

  The chosen values are stored in the globals CONTRAST_PTILES and
  THRESHOLD_METHOD, which the group-analysis cells read later.

  Args:
    min_ptile: lower intensity percentile (0-100) for contrast stretching.
    max_ptile: upper intensity percentile (0-100); must exceed `min_ptile`.
    threshold_method: one of `sm.util.THRESHOLD_METHODS`, or None.
  """
  global CONTRAST_PTILES, THRESHOLD_METHOD
  if min_ptile >= max_ptile:
    # An empty/inverted percentile range cannot be stretched; keep the previous
    # (valid) selection rather than storing one that would break later cells.
    print(f'min_ptile ({min_ptile}) must be smaller than max_ptile ({max_ptile}).')
    return
  CONTRAST_PTILES = (min_ptile, max_ptile)
  THRESHOLD_METHOD = threshold_method

  # NOTE(review): `_contrast_stretching` is a private smorph helper and may
  # change between releases.
  plt.imshow(sm.util._image._contrast_stretching(cell_image, CONTRAST_PTILES),
             cmap='gray')
  mask = sm.util.preprocess_image(
    cell_image,
    IMG_TYPE,
    None,
    CROP_TECH,
    CONTRAST_PTILES,
    THRESHOLD_METHOD
  )[1]
  # Outline the thresholded foreground on top of the stretched image.
  for contour in find_contours(mask, .9):
    plt.plot(contour[:, 1], contour[:, 0], linewidth=2)

_ = widgets.interact(plot_ptiles, min_ptile=(0, 100, 1), max_ptile=(0, 100, 1), threshold_method=[*sm.util.THRESHOLD_METHODS, None])

---
## Step 2: Start group analysis

Sholl analysis parameters:
- Set `SHOLL_STEP_SIZE` (int) to the spacing, in pixels, between consecutive concentric Sholl circles
- Set `POLYNOMIAL_DEGREE` (int) to the degree of the polynomial regression fitted to the Sholl values

In [None]:
# Sholl analysis parameters: circle spacing (pixels) and the degree of the
# polynomial fitted to the Sholl values. Both are forwarded to sm.Groups.
SHOLL_STEP_SIZE = 3  #@param
POLYNOMIAL_DEGREE = 3  #@param

In [None]:
# Run the SMorph group analysis over every group folder, using the settings
# chosen in the configuration and widget cells above.
groups = sm.Groups(
  GROUP_FOLDERS,
  labels=LABELS,
  image_type=IMG_TYPE,
  groups_crop_tech=CROP_TECH,
  scale=SCALE,
  contrast_ptiles=CONTRAST_PTILES,
  threshold_method=THRESHOLD_METHOD,
  sholl_step_size=SHOLL_STEP_SIZE,
  polynomial_degree=POLYNOMIAL_DEGREE,
  save_results=True,
  show_logs=False,
)

In [None]:
# Average Sholl profile per group.
# NOTE(review): the meaning of the positional `False` is not visible here —
# pass it by keyword once the parameter name is confirmed.
groups.plot_avg_sholl_plot(False)

# Audible "finished" cue. `winsound` exists only on Windows; importing it
# unconditionally raises ImportError on Colab/Linux/macOS.
import sys
if sys.platform == 'win32':
  import winsound
  winsound.MessageBeep(0)  # 0 == winsound.MB_OK, the default system sound

In [None]:
# Number of analysed cells per group, in the same order as LABELS.
groups.group_counts

In [None]:
# Names of all morphological features extracted from the images
# (23 in SMorph 0.1.0; use len(sm.ALL_FEATURE_NAMES) rather than relying on the count).
sm.ALL_FEATURE_NAMES

In [None]:
# Histograms of the extracted morphological features (smorph plotting helper).
groups.plot_feature_histograms()

Select which of the morphological features to use for Principal Component Analysis:
- Tick the checkbox of each feature to keep; the checked feature names are stored in `pruned_features`

In [None]:
# @title Prune the Morphological features, if needed.
pruned_features = list(sm.ALL_FEATURE_NAMES)

def prune_features(**checked):
  """Keep only the features whose checkbox is ticked.

  `widgets.interact` calls this with one boolean keyword argument per feature
  name; the ticked names, in the order received, replace the global
  `pruned_features` list.
  """
  global pruned_features
  kept = []
  for feature_name, is_checked in checked.items():
    if is_checked:
      kept.append(feature_name)
  pruned_features = kept

options = {feature_name: True for feature_name in sm.ALL_FEATURE_NAMES}
_ = widgets.interact(prune_features, **options)

In [None]:
import seaborn as sns
from itertools import combinations
from statannotations.Annotator import Annotator

# One panel per selected feature, laid out N_COLS per row.
N_COLS = 4
n_features = len(pruned_features)
n_rows = max(1, -(-n_features // N_COLS))  # ceil-divide: every feature gets a panel
# squeeze=False keeps `axes` 2-D even for a single row, so ravel() always works;
# 3 inches of height per row keeps panels readable for any feature count.
fig, axes = plt.subplots(n_rows, N_COLS, figsize=(18, 3 * n_rows), squeeze=False)
ax = axes.ravel()  # flat axes with numpy ravel

# Copy so that adding the label column never writes into `groups.features`
# (a column selection may be a view -> SettingWithCopy pitfalls).
data = groups.features[pruned_features].copy()
# One group label per row; rows are ordered group by group, matching group_counts.
data['label'] = [label
                 for label, count in zip(groups.labels, groups.group_counts)
                 for _ in range(count)]
x = 'label'

# Compare every pair of groups; statannotations expects a list of 2-tuples.
pairs = list(combinations(groups.labels, 2))

for i, feature in enumerate(pruned_features):
    sns.violinplot(y=feature, x=x, data=data, ax=ax[i], order=groups.labels)
    sns.barplot(y=feature, x=x, data=data, ax=ax[i], order=groups.labels, alpha=.3)
    # Dashed line through the group means; `ci=None` hides its error bars
    # (pre-0.12 seaborn API, kept for compatibility with statannotations).
    sns.pointplot(x=x, y=feature, data=data, ax=ax[i], order=groups.labels,
                  color="black", linestyles='--', ci=None)

    if pairs:  # a single group has nothing to compare against
        annotator = Annotator(ax[i], pairs, data=data, x=x, y=feature,
                              order=groups.labels)
        # NOTE(review): independent t-tests per pair without multiple-comparison
        # correction; consider `comparisons_correction=` with more than 2 groups.
        annotator.configure(test='t-test_ind', text_format='star',
                            loc='outside')
        annotator.apply_and_annotate()
    ax[i].set(xlabel=None)

# Hide the unused panels at the end of the last row.
for unused_ax in ax[n_features:]:
    unused_ax.set_visible(False)
fig.tight_layout()

# NOTE(review): root-anchored path; confirm how smorph's savefig helper resolves it
# and that it does not collide with plot_feature_bar_swarm's output file.
sm.util._io.savefig(plt, '/Results/feature_bar_swarm.png')

plt.show()

In [None]:
# smorph's built-in bar/swarm plot of the selected (pruned) features.
groups.plot_feature_bar_swarm(pruned_features)

In [None]:
# Pairwise scatter matrix of the selected (pruned) features.
groups.plot_feature_scatter_matrix(pruned_features)

In [None]:
# Principal Component Analysis on the selected (pruned) features only.
N_PCS = 6  # number of principal components to compute
feature_significance, covar_matrix, var_PCs = groups.pca(n_PC=N_PCS, save_results=True,
                                                         on_features=pruned_features)

In [None]:
# Heatmap of feature significance from the PCA run in the previous cell
# (presumably read from state stored on `groups` — confirm).
groups.plot_feature_significance_heatmap()

In [None]:
# Feature-significance vectors from the PCA (smorph plotting helper).
groups.plot_feature_significance_vectors()

In [None]:
%matplotlib inline
# (Inline rendering is already the default in modern Jupyter/Colab.)
# Cluster the cells into k=3 groups; use_features=False with n_PC=2 presumably
# clusters in the space of the first two principal components — confirm.
# NOTE(review): if get_clusters uses k-means with random initialisation the
# assignments can change between runs; check whether it accepts a seed.
cluster_centers, clustered_data, dist = groups.get_clusters(k=3, use_features=False,
                                                            n_PC=2, plot='scatter')
print('Distribution in clusters (rows represent clusters):')
dist

In [None]:
# Linear Discriminant Analysis, using the cluster assignments from the previous
# cell (clustered_data['cluster_label']) as class labels.
# NOTE(review): this rebinds `feature_significance`, which previously held the
# PCA result. on_features=None also differs from the PCA cell's
# `pruned_features` — confirm that is intended. The meaning of the leading `3`
# is not visible here (number of components?) — TODO confirm.
feature_significance, cov_mat, var_ratios = groups.lda(3, clustered_data['cluster_label'],
                                                       on_features=None)