# SMorph: Group Analysis

---
## Setup
Please execute the cell(s) below to initialize the notebook environment.

In [None]:
# @title Install dependencies
# !pip install poetry

In [None]:
# @title Install SMorph Python module
# !pip install https://github.com/swanandlab/SMorph/releases/download/0.1.0/SMorph-0.1.0.tar.gz

In [None]:
import warnings
from types import SimpleNamespace

warnings.filterwarnings('ignore')  # Suppress annoying warnings

import smorph as sm
import pandas as pd
import ipywidgets as widgets

try:
  import winsound  # standard library, but only available on Windows
except ImportError:
  # On other platforms there is no winsound module. Substitute a silent
  # stand-in so the completion beep used later in the notebook becomes a
  # no-op instead of the whole notebook failing in its first code cell.
  winsound = SimpleNamespace(MessageBeep=lambda *args, **kwargs: None)
# pd.set_option('display.max_rows', None)  # remove upper limit on display of rows

---
## Step 1: Define groups and image preprocessing parameters

- Set `LABELS` as a list of strings containing the name label of each group
- Set `GROUP_FOLDERS` as a list of strings containing the path to each group folder
- Set `IMG_TYPE` as a string selecting the image acquisition method
- Set `CROP_TECH` as a string selecting the method used to crop cells out of the tissue image

In [None]:
# input the path to individual group folders (place them in SMorph)
# NOTE(review): these are absolute paths on one particular Windows machine;
# edit them to point at your own group folders before running.
GROUP_FOLDERS = ['D:/Kushaan/SMorph/ARNAB/CTRL_CA1_Desmin', 'D:/Kushaan/SMorph/ARNAB/MS_CA1_Desmin']  #@param
# Forwarded to sm.Groups as `scale=`.
# NOTE(review): presumably the pixel size along each image axis — confirm the
# expected units and axis order against the SMorph documentation.
SCALE = (0.6918881978764917, 0.6918881978764917)
# group labels
# None is forwarded unchanged to sm.Groups(labels=...).
LABELS = None  #@param

In [None]:
# @title Interactive image parameter selection
# Initial settings, in effect until the widget callback overrides them.
SEGMENTED = False
IMG_TYPE = 'confocal'

def select_image_params(img_type, segmented=SEGMENTED):
  """Record the chosen image parameters in the notebook-level settings.

  Args:
    img_type: value stored in the global IMG_TYPE.
    segmented: value stored in the global SEGMENTED; defaults to the value
      SEGMENTED had when this function was defined.
  """
  global IMG_TYPE, SEGMENTED
  IMG_TYPE, SEGMENTED = img_type, segmented

_ = widgets.interact(select_image_params, img_type=['confocal', 'DAB'], segmented=[True, False])

Cell image preprocessing parameters:
* `min_ptile` and `max_ptile`: minimum and maximum contrast percentiles to stretch the image to
* `threshold_method`: method for single intensity auto-thresholding the cell image

The overlaid mask shows the thresholding result.

In [None]:
# @title Interactive preprocessing parameter selection

CONTRAST_PTILES = (0, 100)
THRESHOLD_METHOD = sm.util.THRESHOLD_METHODS[6]

import matplotlib.pyplot as plt
import skimage.io as io
from random import choice
from os import listdir
from skimage.measure import find_contours
from copy import copy

# Private copy of the gray colormap so that changing how it draws 'bad'
# (masked/invalid) pixels does not affect the globally registered colormap.
# plt.get_cmap is used instead of plt.cm.get_cmap, which is deprecated and
# no longer available in recent Matplotlib releases.
my_cmap = copy(plt.get_cmap('gray'))
my_cmap.set_bad(alpha=0) # draw 'bad' values fully transparent

# Pick one image at random from a randomly chosen group folder to preview
# the preprocessing parameters on.
rand_group_path = choice(GROUP_FOLDERS)
rand_img = choice(listdir(rand_group_path))
rand_img = f'{rand_group_path}/{rand_img}'
cell_image = io.imread(rand_img)
if cell_image.ndim == 3:
  # Collapse the last axis with a maximum projection to get a 2-D image.
  # NOTE(review): assumes the third axis is the one to project over — confirm
  # for multi-plane stacks.
  cell_image = cell_image.max(2)
print(rand_img)

def plot_ptiles (
  min_ptile=CONTRAST_PTILES[0],
  max_ptile=CONTRAST_PTILES[1],
  threshold_method=THRESHOLD_METHOD,
  threshold_value=0
):
  """Preview contrast stretching and thresholding on the sampled cell image.

  Stores the selection in the notebook-level CONTRAST_PTILES and
  THRESHOLD_METHOD, draws the contrast-stretched image, and overlays the
  thresholded foreground on top of it.

  Args:
    min_ptile: lower contrast percentile.
    max_ptile: upper contrast percentile.
    threshold_method: auto-thresholding method; None means "use
      threshold_value instead".
    threshold_value: manual threshold stored in THRESHOLD_METHOD when
      threshold_method is None.
  """
  global CONTRAST_PTILES, THRESHOLD_METHOD
  # Local import: numpy is only imported further down in the notebook.
  import numpy as np

  CONTRAST_PTILES = (min_ptile, max_ptile)
  THRESHOLD_METHOD = threshold_value if threshold_method is None else threshold_method

  plt.imshow(sm.util._image._contrast_stretching(cell_image,
                                                 (min_ptile, max_ptile)),
             cmap='gray')
  mask = sm.util.preprocess_image(
    cell_image,
    IMG_TYPE,
    None,
    SEGMENTED,
    CONTRAST_PTILES,
    THRESHOLD_METHOD
  )[1]
  # contours = find_contours(mask, .9)
  # for contour in contours:
  #   plt.plot(contour[:, 1], contour[:, 0], linewidth=2)

  # Hide the background explicitly with a masked array. Assigning None into
  # a boolean mask does not create 'bad' values (it is coerced to False), so
  # the colormap's transparent 'bad' colour would never be used.
  foreground = np.asarray(mask, dtype=bool)
  overlay = np.ma.masked_where(~foreground, foreground.astype(float))
  # Every visible pixel has the same value, so pin the colour range;
  # otherwise autoscaling would map the foreground to the darkest colour.
  plt.imshow(overlay, alpha=.5, cmap=my_cmap, vmin=0, vmax=1)

_ = widgets.interact(plot_ptiles, min_ptile=(0, 100, 1),
  max_ptile=(0, 100, 1),
  threshold_method=[*sm.util.THRESHOLD_METHODS, None],
  threshold_value=(0., .99, .01))

---
## Step 2: Start group analysis

Sholl analysis parameters:
- Set `SHOLL_STEP_SIZE` (int) to the distance, in pixels, between consecutive concentric Sholl circles
- Set `POLYNOMIAL_DEGREE` (int) to the degree of the polynomial regression model fitted to the Sholl values

In [None]:
SHOLL_STEP_SIZE = 1 #@param
POLYNOMIAL_DEGREE = 3  #@param


In [None]:
# Run the group analysis with the parameters selected in the cells above.
# save_results=True asks SMorph to persist its outputs, fig_format='svg'
# selects the figure format, and the `args` flags switch off saving of the
# per-cell intermediate images.
groups = sm.Groups(GROUP_FOLDERS, image_type=IMG_TYPE, scale=SCALE,
                   segmented=SEGMENTED, labels=LABELS,
                   contrast_ptiles=CONTRAST_PTILES,
                   threshold_method=THRESHOLD_METHOD,
                   sholl_step_size=SHOLL_STEP_SIZE,
                   polynomial_degree=POLYNOMIAL_DEGREE,
                   save_results=True, show_logs=False, fig_format='svg',
                   args={"save_cell": False, "save_bin": False, "save_skel": False, "save_overlay": False, "save_branch_struct": False})

# NOTE(review): the meaning of the positional False is not visible here —
# check the signature of plot_avg_sholl_plot.
groups.plot_avg_sholl_plot(False)
# Audible cue that this cell has finished. winsound is a Windows-only
# standard-library module.
winsound.MessageBeep(0)

In [None]:
groups.group_counts

In [None]:
import numpy as np
import tifffile, json, pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Collect the JSON stored in the ImageDescription tag of the first page of
# every analysed TIFF file.
metadatas = []

for name in groups.file_names:
    # The context manager closes the file handle as soon as the tag has been
    # read; previously one open handle per image was leaked.
    with tifffile.TiffFile(name) as image:
        metadata = image.pages[0].tags['ImageDescription'].value
    metadata = json.loads(metadata)
    metadatas.append(metadata)

In [None]:
parent_images = [m['parent_image'] for m in metadatas]
parent_images

In [None]:
# Split every parent-image path on backslashes and keep the component at
# index 7.
# NOTE(review): this hard-codes both the Windows path separator and the folder
# depth of the original dataset; paths with another layout will select the
# wrong component or raise a KeyError.
parent_df = pd.Series(parent_images).str.split('\\', expand=True)[7]
# Keep only the first 25 characters of that component.
# NOTE(review): presumably this prefix identifies the animal — confirm against
# the dataset's file naming scheme.
parent_mice = parent_df.str.slice(0, 25)
parent_mice.unique()

In [None]:
[parent_mice == 'CONTROL_MSP2.1MB_4_LONG M']

In [None]:
parent_mice.where(parent_mice != 'CONTROL_MSP2.1MB_4_LONG M')

In [None]:
# Per-cell Sholl profiles and group bookkeeping taken from the analysis.
sholl_polynomial_plots = groups.sholl_polynomial_plots
group_cnts = groups.group_counts
labels = groups.labels

# Right-pad every profile with zeros so that all share the longest length.
polynomial_plots = [list(profile) for profile in sholl_polynomial_plots]
len_polynomial_plots = max(len(profile) for profile in polynomial_plots)
polynomial_plots = np.array([
    profile + [0] * (len_polynomial_plots - len(profile))
    for profile in polynomial_plots
])

# One radius per profile position, in multiples of SHOLL_STEP_SIZE.
x = np.arange(SHOLL_STEP_SIZE,
              SHOLL_STEP_SIZE * (len_polynomial_plots + 1),
              SHOLL_STEP_SIZE)

# JASP-friendly data: one row per (cell, radius) pair in long format.
jasp_friendly_cols = ['label', 'radius', 'nintersections']
csum_group_cnts = np.cumsum(group_cnts)

long_rows = []
for cell_idx in range(len(sholl_polynomial_plots)):
    profile = sholl_polynomial_plots[cell_idx]
    for radius_idx, radius in enumerate(x):
        # Radii beyond the end of this cell's profile count as zero.
        if radius_idx < len(profile):
            nintersections = profile[radius_idx]
        else:
            nintersections = 0
        long_rows.append([
            parent_mice[cell_idx],  # labels[np.digitize(cell_idx, csum_group_cnts)],
            radius,
            nintersections,
        ])

jasp_friendly = pd.DataFrame(long_rows, columns=jasp_friendly_cols)

jasp_friendly

In [None]:
jasp_friendly.copy()

In [None]:
# Work on a copy so that jasp_friendly itself is not modified.
tmp = jasp_friendly.copy()
# Coarse group code = first 3 characters of the label. Rows whose label equals
# the excluded value (note its trailing space) are filtered out of the
# right-hand side, so after index alignment their 'grp' is left as NaN.
tmp['grp'] = jasp_friendly[jasp_friendly['label'] != 'CONTROL_MSP2.2M_1_SINGLE ']['label'].str.slice(0,3)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# fig, ax = plt.subplots(figsize=(12,8))
# Intersections against radius, one line per 'grp' code, with the spread
# drawn as error bars.
sns.lineplot(
    x='radius',
    y='nintersections',
    hue='grp',
    data=tmp,
    err_style='bars',
)

plt.savefig('removed.svg')

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Intersections against radius, one line per label, on an explicit 12x8 figure.
fig, ax = plt.subplots(figsize=(12, 8))
sns.lineplot(
    x='radius',
    y='nintersections',
    hue='label',
    data=jasp_friendly,
    err_style='bars',
    ax=ax,
)

fig.savefig('decoupled.svg')

In [None]:
groups.features.describe()

In [None]:
import matplotlib.pyplot as plt
# Share of cells contributed by each group.
# The wedges are labelled from groups.labels, which the other plots in this
# notebook also use. LABELS is None unless it was filled in by hand, and
# passing None would leave the wedges unlabelled.
n = plt.pie(groups.group_counts, labels=groups.labels, autopct='%1.1f%%')
# n[0][0].set_alpha(.5)
# n[0][1].set_alpha(.5)
#plt.show()
plt.savefig('Results/Pie.svg')

In [None]:
# All 23 Morphological features' names which will be extracted from the images
sm.ALL_FEATURE_NAMES

In [None]:
groups.plot_feature_histograms()

Select which of the morphological features should be used for Principal Component Analysis.
- Tick or untick the checkbox of each morphological feature; the names of the ticked features are stored in the list `pruned_features`

In [None]:
# @title Prune the Morphological features, if needed.
pruned_features = list(sm.ALL_FEATURE_NAMES)

def prune_features(**args):
  """Keep the names of all features whose checkbox value is truthy.

  Args:
    **args: feature name -> checkbox state.

  The selection replaces the notebook-level list `pruned_features`.
  """
  global pruned_features
  selected = []
  for feature_name, is_checked in args.items():
    if is_checked:
      selected.append(feature_name)
  pruned_features = selected

options = dict(zip(list(sm.ALL_FEATURE_NAMES), [True] * len(sm.ALL_FEATURE_NAMES)))
_ = widgets.interact(prune_features, **options)

In [None]:
import seaborn as sns
from statannotations.Annotator import Annotator
import numpy as np


# One small boxplot per selected feature, laid out on a 6-column grid.
# The grid is sized from the public feature list (sm.ALL_FEATURE_NAMES, the
# same name the rest of this notebook uses) with a ceiling division, so there
# is always at least one axis per feature. squeeze=False keeps `axes`
# two-dimensional even when a single row is enough.
N_BOXPLOT_COLS = 6
n_boxplot_rows = -(-len(sm.ALL_FEATURE_NAMES) // N_BOXPLOT_COLS)
axes = plt.subplots(n_boxplot_rows, N_BOXPLOT_COLS, figsize=(9, 12),
                    squeeze=False)[1]

# Copy the selected columns so the label column is added to an independent
# frame rather than to a selection of groups.features.
data = groups.features[pruned_features].copy()

# Repeat each group's label once per cell in that group.
data['label'] = [label
                 for label, count in zip(groups.labels, groups.group_counts)
                 for _ in range(count)]

ax = axes.ravel()  # flat axes with numpy ravel
x = 'label'
palette = ['#00CCFF','#AA0044'] # (0.7176470588235294, 1.0, 0.9803921568627451), (0.6509803921568628, 0.7843137254901961, 0.4588235294117647), (0.8392156862745098, 0.2823529411764706, 0.8431372549019608)]


for i, feature in enumerate(pruned_features):
    # palette = [np.random.choice(list(sns.xkcd_rgb.keys())) for i in range(3)]
    # palette = sns.palettes.xkcd_palette(palette)
    # print(palette)
    # whis=[0, 100] stretches the whiskers over the full data range; the mean
    # is drawn as a black dot in addition to the black median line.
    sns.boxplot(y=feature, x=x, data=data, ax=ax[i], boxprops=dict(alpha=1),  palette=palette,
                    order=groups.labels, whis=[0,100], width=.75, showmeans=True,
                    medianprops={'color': 'black'}, meanprops={"marker":"o",
                       "markerfacecolor":"black",
                       "markeredgecolor":"white"})#, palette={'CTRL CA3 GFAP': '#67A9CF', 'MS CA3 GFAP': '#EF8A62'})

    # sns.barplot(y=pruned_features[i], x=x, data=data, ax=ax[i],
    #             order=groups.labels, alpha=.3)
    # sns.pointplot(x=x, y=pruned_features[i], data=data, ax=ax[i],
    #                 color="black", linestyles='--', ci=None)

    # if len(groups.group_counts) == 2:
    #     annotator = Annotator(ax[i], [groups.labels], data=data, x=x,
    #                         y=pruned_features[i], verbose=False)
    #     annotator.configure(test='t-test_ind', text_format='star',
    #                         loc='outside')
    #     annotator.apply_and_annotate()

    ax[i].set(xlabel=None)

plt.tight_layout()
plt.savefig('Results/Boxplot_Feature.svg')

In [None]:
groups.plot_feature_bar_swarm(pruned_features)

In [None]:
# groups.plot_feature_scatter_matrix(pruned_features)

In [None]:
feature_significance, covar_matix, var_PCs = groups.pca(n_PC=6, save_results=True,
                                                        on_features=pruned_features, only_ellipse=True)

In [None]:
groups.plot_feature_significance_heatmap()

In [None]:
groups.plot_feature_significance_vectors()

In [None]:
%matplotlib inline
cluster_centers, clustered_data, dist = groups.get_clusters(k=3, use_features=False,
                                                            n_PC=2, plot='scatter')
print('Distribution in clusters (rows represent clusters):')
dist

In [None]:
import numpy as np
# Nested pie: the inner ring shows the size of each group, the outer ring
# splits every group into its clusters.
fig, ax = plt.subplots(figsize=(10, 10), subplot_kw=dict(aspect="equal"))
size = 0.3

# Outer-ring labels run column by column through `dist` (one column per key,
# one row per cluster). Series.iteritems() no longer exists in pandas 2, so
# the row labels are read from the index instead.
ilabels = [f'{k}-Cluster_{c}' for k in dist.keys() for c in dist[k].index]
# Flatten in explicit column-major order so the values always line up with
# ilabels; ravel('A') picks its order from the array's memory layout.
vals = dist.to_numpy().ravel('F')

# Inner ring labelled from groups.labels, as in the other plots of this
# notebook (LABELS may still be None, which would leave it unlabelled).
ax.pie(groups.group_counts, radius=1-size, labeldistance=1-size*2,
       wedgeprops=dict(width=size, edgecolor='w', alpha=.9),
       labels=groups.labels)
wedges, texts, _ = ax.pie(vals, radius=1, autopct='%1.1f%%', pctdistance=1-size/2,
                       wedgeprops=dict(width=size, edgecolor='w'))
bbox_props = dict(boxstyle="square,pad=0.3", fc="w", ec="k", lw=0.72)
kw = dict(arrowprops=dict(arrowstyle="-"),
          bbox=bbox_props, zorder=0, va="center")

for wedge, wedge_label in zip(wedges, ilabels):
    # Point the call-out at the angular middle of the wedge.
    ang = (wedge.theta2 - wedge.theta1)/2. + wedge.theta1
    y = np.sin(np.deg2rad(ang))
    x = np.cos(np.deg2rad(ang))
    # Text on the left half is right-aligned and vice versa; a wedge centred
    # exactly on the vertical axis (x == 0) is treated as right-hand side
    # instead of raising a KeyError.
    horizontalalignment = "right" if x < 0 else "left"
    connectionstyle = "angle,angleA=0,angleB={}".format(ang)
    kw["arrowprops"].update({"connectionstyle": connectionstyle})
    ax.annotate(wedge_label, xy=(x, y), xytext=(1.1*np.sign(x), 1.1*y),
                horizontalalignment=horizontalalignment, **kw)

ax.set_title("Cluster distribution")

In [None]:
from ipyfilechooser import FileChooser

# Create and display a FileChooser widget
fc = FileChooser()
# Path of the selected features file; remains None until
# choose_features_file has been run.
FEATURES_FILE = None
display(fc)
def choose_features_file():
  """Copy the FileChooser's current selection into the global FEATURES_FILE."""
  global FEATURES_FILE
  FEATURES_FILE = fc.selected
# interact_manual runs choose_features_file only when its button is pressed.
_ = widgets.interact_manual(choose_features_file)

In [None]:
import pandas as pd
# FEATURES_FILE stays None until a file has been picked and confirmed in the
# file-chooser cell; fail with an actionable message instead of an opaque
# pandas error about an invalid path.
if not FEATURES_FILE:
  raise ValueError(
    'No features file selected: choose a file in the FileChooser cell and '
    'confirm the selection before running this cell.')
data = pd.read_csv(FEATURES_FILE)
data.describe()

In [None]:
data[data['label'] == 'ADT_CONTROL_28D_HILUS'].describe()

In [None]:
data[data['label'] == 'ADT_DMI_28D_HILUS'].describe()

In [None]:
pd.plotting.boxplot_frame?

In [None]:
pd.plotting.boxplot_frame(data[data['label'] == 'MS CA1 Desmin'])

In [None]:
feature_significance, cov_mat, var_ratios = groups.lda(n_components=2,
                                                       cluster_labels=clustered_data['cluster_label'],
                                                       on_features=pruned_features)

---