# Introduction
Pipeline from the paper [Persistence diagrams as morphological signatures of cells: A method to measure and compare cells within a population](https://files.yossi.eu/manuscripts/2310.20644.pdf) by Yossi Bokor Bleile, Pooja Yadav, Patrice Koehl, and Florian Rehfeldt.

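The following packages need to be installed:
- [Correa](https://correa.yossi.eu)
- plotly (plus `kaleido`, which plotly uses to save figures as static images)
- pandas
- scikit-learn (imported as `sklearn`)
- scipy
- numpy
- matplotlib
- tifffile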


We begin the analysis by obtaining a persistence diagram for each cell in the population, as a summary of the morphology of the cell.

Below is the analysis for `X1`. It can be repeated for another dataset (for example `Y1`) by setting `dataset` to that name and then running the relevant cells again.


## Import modules

In [None]:
import os
import correa
import pandas
import tifffile
import plotly.express as px
import plotly.figure_factory as ff		
import plotly.io as pio
import plotly.graph_objects as go
from sklearn import manifold, cluster, decomposition, metrics, preprocessing
import numpy
import scipy.cluster
from scipy.cluster.hierarchy import dendrogram
from scipy.cluster.hierarchy import ClusterWarning
from matplotlib import pyplot as plt
import matplotlib
from scipy.spatial.distance import squareform
from warnings import simplefilter
# Silence SciPy's ClusterWarning for the rest of the session.
simplefilter(action="ignore", category=ClusterWarning)

# Open plotly figures (fig.show()) in the web browser.
pio.renderers.default = "browser"

## Custom functions for the analysis

We next define some custom functions for creating the dendrograms and for summarising and comparing the resulting clusterings.


In [None]:
#set colours
def rgb_to_hex(r, g, b):
    """Format integer red, green and blue components as a ``#rrggbb`` hex colour string.

    Each component is written as (at least) two lower-case hex digits.
    """
    return "#" + "".join(format(channel, "02x") for channel in (r, g, b))
# Convert plotly's qualitative Set1 palette into "#rrggbb" strings (the default colour_list of analysis).
hex_list = []
for rgb_string in px.colors.qualitative.Set1:
	# Entries are expected to look like "rgb(r,g,b)": strip the wrapper once and split the channels.
	channels = rgb_string.replace("rgb(", "").replace(")", "").split(",")
	hex_list.append(rgb_to_hex(int(channels[0]), int(channels[1]), int(channels[2])))
 
 
def plot_dendrogram(model, *, figsize=(20, 20), **kwargs):
	"""Plot the dendrogram of a fitted scikit-learn ``AgglomerativeClustering`` model.

	SciPy's ``dendrogram`` needs a linkage matrix, so one is rebuilt from the model's
	``children_`` (merge order) and ``distances_`` (merge heights), together with the number
	of original samples below each merged node. ``distances_`` is only available when the
	model was fitted with ``distance_threshold`` set or ``compute_distances=True``.

	Parameters
	----------
	model : fitted AgglomerativeClustering
	figsize : figure size in inches. The default (20, 20) is 2000x2000 px at matplotlib's
		default 100 dpi. The previous hard-coded (2000, 2000) inches asked for a
		200000x200000 px canvas.
	**kwargs : forwarded unchanged to ``scipy.cluster.hierarchy.dendrogram``.
	"""
	n_samples = len(model.labels_)
	counts = numpy.zeros(model.children_.shape[0])
	for i, merge in enumerate(model.children_):
		# Child ids below n_samples are leaves; id n_samples + r is the node created by merge row r.
		counts[i] = sum(1 if child_idx < n_samples else counts[child_idx - n_samples] for child_idx in merge)
	linkage_matrix = numpy.column_stack([model.children_, model.distances_, counts]).astype(float)
	# Plot the corresponding dendrogram
	plt.figure(figsize=figsize)
	# Hide ticks and tick labels on both axes; only the shape of the tree is shown.
	plt.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
	plt.tick_params(axis='y', which='both', left=False, right=False, labelleft=False)
	dendrogram(linkage_matrix, **kwargs)

def generate_rand_index(cluster_df : pandas.DataFrame, to_compare : list):
	"""Tabulate the pairwise Rand index between several flat clusterings.

	Parameters
	----------
	cluster_df : DataFrame with one column of cluster labels per clustering.
	to_compare : names of the label columns to compare.

	Returns
	-------
	DataFrame indexed and columned by ``to_compare``; entry (j, i) is
	``metrics.rand_score(cluster_df[i], cluster_df[j])``.
	"""
	r_ind = pandas.DataFrame(columns=to_compare, index=to_compare)
	for i in to_compare:
		for j in to_compare:
			# Write through a single .loc[row, col]: the previous chained form r_ind[i].loc[j] = ...
			# assigns into an intermediate object and is silently dropped under pandas copy-on-write.
			r_ind.loc[j, i] = metrics.rand_score(cluster_df[i], cluster_df[j])
	return r_ind

def population_percentages(df : pandas.DataFrame, clustering : str):
	"""Count how many rows fall in each cluster of one clustering column.

	Labels in ``df[clustering]`` are converted with ``int`` and used as list positions, so
	clusters are numbered 0..max label (unused numbers get a count of 0).

	Returns
	-------
	(percentages, counts): ``counts[c]`` is the number of rows labelled ``c`` and
	``percentages[c]`` is that count divided by the number of rows (a fraction, not x100).
	"""
	label_ids = list(map(int, df[clustering]))
	tally = [0] * (max(label_ids) + 1)
	for label_id in label_ids:
		tally[label_id] += 1
	total = len(label_ids)
	return [n / total for n in tally], tally

def _linkage_matrix_from_model(model):
	"""Return the SciPy linkage matrix for a fitted AgglomerativeClustering model.

	Rows are ``[child_a, child_b, merge_distance, leaves_under_node]`` as floats. Child ids
	below ``n_samples`` are leaves; id ``n_samples + r`` is the node created by merge row ``r``.
	"""
	n_samples = len(model.labels_)
	counts = numpy.zeros(model.children_.shape[0])
	for i, merge in enumerate(model.children_):
		counts[i] = sum(1 if child < n_samples else counts[child - n_samples] for child in merge)
	return numpy.column_stack([model.children_, model.distances_, counts]).astype(float)


def _dendrogram_link_colours(linkage_matrix, leaf_colours, default_colour="#000000"):
	"""Map every internal dendrogram node id to a colour derived from its leaves' colours.

	``leaf_colours[i]`` is the colour of leaf ``i``. A merge whose two children share a colour
	keeps it; otherwise it takes the colour of a leaf child if it has one, and
	``default_colour`` when both children are internal nodes.
	"""
	n_samples = len(linkage_matrix) + 1
	link_cols = {}

	def colour_of(node):
		# Leaves use their flat-cluster colour; internal nodes were coloured when they were created.
		return link_cols[node] if node >= n_samples else leaf_colours[node]

	for row, (left, right) in enumerate(linkage_matrix[:, :2].astype(int)):
		node_id = n_samples + row
		left_colour, right_colour = colour_of(left), colour_of(right)
		if left_colour == right_colour:
			link_cols[node_id] = left_colour
		elif left < n_samples:
			link_cols[node_id] = left_colour
		elif right < n_samples:
			link_cols[node_id] = right_colour
		else:
			link_cols[node_id] = default_colour
	return link_cols


def analysis(dists : pandas.DataFrame, cluster_numbers : list,  name : str,  dir : str, group = False, exclude : list = None, show2d = False, showElbow=False, colour_list=hex_list):
	"""Embed, cluster and summarise a population from a pairwise distance matrix.

	Steps:
	1. Restrict ``dists`` to the requested cells.
	2. Embed them in 3-D with MDS.
	3. Save and show one dendrogram per linkage (average, complete, single, ward). Links are
	   coloured by a fixed 4-cluster average-linkage partition.
	4. For every linkage and every k in ``cluster_numbers``, compute flat clusters, silhouette
	   scores and a 2-D scatter plot.
	5. Optionally draw a k-means elbow plot.
	6. Tabulate Rand indices and cluster sizes.

	Figures and a CSV are written under ``dir``.

	Parameters
	----------
	dists : square DataFrame of pairwise distances, indexed and columned by cell name.
	cluster_numbers : numbers of clusters k to evaluate.
	name : label for plot titles and output file names (spaces become underscores).
	dir : output directory (not created here).
	group : if not False, keep only cells whose name contains ``group``.
	exclude : cell names to drop; they must match the index labels exactly (default: none).
	show2d : also display each 2-D scatter plot (it is always saved).
	showElbow : compute, save and display a k-means elbow plot.
	colour_list : colours indexed by cluster label, cycled if there are more labels than colours.

	Returns
	-------
	A tuple ``(df, silhouette_score, silhouette_samples, rand_ind_tables_latex,
	percentages_dict, dists, linkage_matrices)``:

	- ``df`` holds the MDS coordinates (x, y, z) and one string label column per ``<linkage><k>``.
	- ``dists`` is the restricted distance matrix.
	- ``linkage_matrices`` holds the full-tree linkage matrices in the order
	  average, complete, single, ward.
	"""
	exclude = [] if exclude is None else exclude
	linkages = ["average", "complete", "single", "ward"]
	prefix = dir + "/" + name.replace(" ", "_")

	# Restrict the distance matrix to the requested cells.
	if group == False:
		inds = dists.index
		if len(exclude) != 0:
			inds = [f for f in dists.index if f not in exclude]
			dists = dists[inds].loc[inds]
	else:
		inds = [f for f in dists.index if group in f and f not in exclude]
		dists = dists[inds].loc[inds]
	df = pandas.DataFrame(index=inds)

	# 3-D MDS embedding of the precomputed distances; x and y are used for the scatter plots.
	embed = manifold.MDS(3, dissimilarity='precomputed', random_state=1, normalized_stress="auto").fit_transform(dists.to_numpy())
	df["x"] = embed[:, 0]
	df["y"] = embed[:, 1]
	df["z"] = embed[:, 2]

	silhouette_samples = []
	silhouette_score = []
	linkage_matrices = []

	# NOTE(review): AgglomerativeClustering is fitted on the rows of the distance matrix with its
	# default (Euclidean) metric. Each cell is therefore treated as its vector of distances to all
	# cells. The silhouette scores below use the distances directly (metric="precomputed").
	# Ward linkage only supports Euclidean input. Confirm this is the intended design.

	# Leaf colours for every dendrogram come from one fixed 4-cluster average-linkage partition.
	avg = cluster.AgglomerativeClustering(distance_threshold=None, n_clusters=4, linkage="average").fit(dists.to_numpy())
	leaf_colours = [colour_list[label % len(colour_list)] for label in avg.labels_]

	for link in linkages:
		model = cluster.AgglomerativeClustering(distance_threshold=0, n_clusters=None, linkage=link).fit(dists.to_numpy())
		linkage_matrix = _linkage_matrix_from_model(model)
		linkage_matrices.append(linkage_matrix)
		link_cols = _dendrogram_link_colours(linkage_matrix, leaf_colours)
		dendrogram(Z=linkage_matrix, labels=dists.index, color_threshold=None, no_labels=True, link_color_func=link_cols.__getitem__)
		plt.yticks([])
		plt.savefig(prefix + "_" + link + "_dendrogram.png")
		plt.show()

	for link in linkages:
		for k in cluster_numbers:
			column = link + str(k)
			model = cluster.AgglomerativeClustering(distance_threshold=None, n_clusters=k, linkage=link).fit(dists.to_numpy())
			# Labels are stored as strings so they act as discrete colour categories in the scatter plot.
			df[column] = [str(c) for c in model.labels_]
			silhouette_score.append(metrics.silhouette_score(dists, df[column], metric="precomputed"))
			silhouette_samples.append([metrics.silhouette_samples(dists, df[column], metric="precomputed")])
			# One colour per label taken from colour_list. The old map had a duplicate "2" key,
			# ignored colour_list and had no entry for labels >= 4.
			colour_map = {str(c): colour_list[c % len(colour_list)] for c in range(k)}
			fig = px.scatter(df, x='x', y='y', color=column, title=name+" ("+link+" "+str(k)+")", hover_data=[df.index], width=800, height=600, color_discrete_map=colour_map)
			fig.update_traces(marker={'size': 5})
			if show2d:
				fig.show()
			fig.write_image(prefix + "_" + column + "_2D.png")

	if showElbow:
		# SSE is the k-means inertia; previously it was never computed, so this plot crashed.
		# NOTE(review): k-means is assumed to run on the 3-D MDS embedding (the only Euclidean
		# representation available here). Confirm.
		embedding = df[["x", "y", "z"]].to_numpy()
		sse = [cluster.KMeans(n_clusters=k, init="k-means++", n_init=10, random_state=1).fit(embedding).inertia_ for k in cluster_numbers]
		plt.figure()
		plt.plot(cluster_numbers, sse)
		plt.title("Elbow Method")
		plt.xlabel("Number of Clusters")
		plt.xticks(cluster_numbers)
		plt.ylabel("SSE")
		plt.savefig(prefix + "_kmeans-elbow.png")
		plt.show()

	df.to_csv(prefix + "_df.csv")
	rand_ind_tables_latex = []
	percentages_dict = {}
	for k in cluster_numbers:
		columns = [link + str(k) for link in linkages]
		rand_ind_tables_latex.append(generate_rand_index(df, columns).to_latex())
		print("cluster sizes and percentages are:")
		for column in columns:
			percentages_dict[column], counts = population_percentages(df, column)
			print(column, counts, percentages_dict[column])
	return df, silhouette_score, silhouette_samples, rand_ind_tables_latex, percentages_dict, dists, linkage_matrices

def get_main_population(analysis, cluster : str, dists : pandas.DataFrame):
	"""Return the distance sub-matrix of the largest cluster of one clustering.

	Parameters
	----------
	analysis : tuple returned by ``analysis``; element 0 is the label DataFrame and element 4
		maps each clustering name to its per-label population fractions.
	cluster : clustering column name, e.g. ``"average4"``.
	dists : distance DataFrame indexed and columned by cell name.

	Ties between equally large clusters go to the lowest label.
	"""
	label_frame = analysis[0]
	fractions = analysis[4][cluster]
	largest_label = str(fractions.index(max(fractions)))
	members = label_frame.index[label_frame[cluster] == largest_label]
	return dists[members].loc[members]

def purity(merges, clus : set, n_objs):
	"""Measure how cleanly a hierarchy recovers a given group of leaves.

	Replays the merges of a linkage matrix and finds the first cluster that contains every
	member of ``clus``. Containment is checked after each merge, starting with the first merge.

	Parameters
	----------
	merges : linkage-matrix rows; columns 0 and 1 are the child ids, where leaves are
		0..n_objs-1 and merge row r creates node n_objs + r.
	clus : set of leaf ids to look for.
	n_objs : number of leaves.

	Returns
	-------
	(c, score): ``c`` is the set of leaf ids of the cluster found.
	``score = (n_objs - len(c)) / (n_objs - len(clus))``. It is 1.0 when ``c`` is exactly
	``clus`` and falls as ``c`` picks up extra leaves. When ``clus`` already contains every
	leaf, ``c == clus`` and the score is 1.0 instead of a division by zero.

	Raises
	------
	ValueError : if no cluster ever contains ``clus`` (e.g. it names ids outside the tree).
	"""
	members = {i: {i} for i in range(n_objs)}
	for node_id, row in enumerate(merges, n_objs):
		# Linkage matrices store child ids as floats; normalise them to int node ids.
		left, right = int(row[0]), int(row[1])
		members[node_id] = members.pop(left) | members.pop(right)
		for c in members.values():
			if clus.issubset(c):
				n_outside = n_objs - len(clus)
				score = 1.0 if n_outside == 0 else (n_objs - len(c)) / n_outside
				return c, score
	raise ValueError("no cluster in the hierarchy contains all of {}".format(sorted(clus)))

# Example Dataset

In [None]:
dataset = "X1" # select the dataset to analyse; its padded image size is looked up below
# Largest (width, height) of an image in each dataset. FilamentSensor2 pads all the images of a
# dataset to the same size, and the nucleus centres are later shifted into that padded frame.
image_sizes = {
	"X1": (1500, 1692),
	"X2": (1455, 1584),
	"X3": (1919, 1467),
	"Y1": (1226, 1088),
}
if dataset not in image_sizes:
	raise ValueError("Dataset not supported")
width, height = image_sizes[dataset]

files = os.listdir(dataset+"/cell/raw_images") # get all of the files in the cell/raw_images directory
# Cell numbers come from the first three characters of each .tif name (e.g. "015.tif" -> 15).
# os.listdir returns entries in arbitrary order, so sort for a reproducible cell order.
cell_names = sorted(int(file[0:3]) for file in files if file.endswith('.tif'))
n_cells = len(cell_names)


First, we will load the file containing the center of the nucleus for each cell.

Next, for each cell we will calculate the persistence diagram using a radial function based at the center of the nucleus. We do need to translate the center of the nucleus into the same frame as the contour. This is due to the way FilamentSensor extracts the contour of each cell in a directory: it pads all the images to the same size. The appropriate `height` and `width` values can be found in the [FilamentSensor2](https://filament-sensor.de/) log file.

In [None]:
# Nucleus centre (X_m, Y_m) of every cell, indexed by the "filename" column of the CSV.
centers = pandas.read_csv(f"{dataset}/Nuc_Cm_{dataset}.csv", index_col="filename")

In [None]:
contours = [] # one correa polygon (with its persistence diagram) per cell, in cell_names order
# Every cell is processed: the distance matrix below indexes contours[0..n_cells-1].
for i in cell_names:
	cell_id = "%03d" % i # file names use the zero-padded cell number
	# The original image size is needed to shift the nucleus centre into the padded frame of the
	# contour. This assumes the padding is split evenly on both sides, hence the halved difference.
	actin = tifffile.imread(dataset+"/cell/raw_images/"+cell_id+".tif")
	height_diff = height - actin.shape[0]
	width_diff = width - actin.shape[1]
	focal_point = [centers.loc[i,"X_m"]+width_diff/2, centers.loc[i,"Y_m"]+height_diff/2]
	c_i = correa.create_polygon_focal_point(dataset+"/cell/contours/"+cell_id+"_contour.csv", focal_point)
	c_i.persistence_diagram()
	contours.append(c_i)
 

We can also plot the contours of the first five cells and mark the center of each nucleus.

In [None]:
for i in cell_names[0:5]: # only the first five cells are plotted
	cell_id = "%03d" % i
	contour = pandas.read_csv(dataset+"/cell/contours/"+cell_id+"_contour.csv", header=None)
	# The original image size gives the padding used to shift the nucleus centre into the contour's frame.
	actin = tifffile.imread(dataset+"/cell/raw_images/"+cell_id+".tif")
	height_diff = height - actin.shape[0]
	width_diff = width - actin.shape[1]
	center = pandas.DataFrame([[centers.loc[i,"X_m"]+width_diff/2, centers.loc[i,"Y_m"]+height_diff/2]], columns=["x", "y"])
	fig_data = px.scatter(contour, x=0, y=1).data
	fig_data = fig_data + px.scatter(center, x="x", y="y").update_traces(marker={'size': 5, 'color': 'Red'}).data
	# Only the traces (.data) of the px figures are reused, so the size must be set on the combined figure.
	fig = go.Figure(fig_data, layout={"width": 800, "height": 600})
	fig.show()


Once we have a persistence diagram for each cell summarising its morphology, we compute the Wasserstein distance between each pair of persistence diagrams as a (dis)similarity score for each pair of cells.

In [None]:
# Pairwise Wasserstein (q=2) distances between the cells' persistence diagrams.
if len(contours) != n_cells:
	raise ValueError("expected {} contours (one per cell) but found {}".format(n_cells, len(contours)))
w_distances = numpy.zeros((n_cells,n_cells))
for i in range(n_cells):
	# Self-distances are not computed: the diagonal stays exactly 0, as required by the
	# precomputed-distance MDS and silhouette computations downstream.
	for j in range(i+1, n_cells):
		dist_ij = correa.wasserstein_distance(contours[i], contours[j], q=2)
		dist_ji = correa.wasserstein_distance(contours[j], contours[i], q=2)
		# Both argument orders are computed and averaged, so the matrix is exactly symmetric.
		dist = (dist_ij+dist_ji)/2
		print("for {} and {} dist_ij is {} and dist_ji is {} so we have dist {}".format(i,j,dist_ij,dist_ji,dist))
		w_distances[i,j] = dist
		w_distances[j,i] = dist

Next we display a heatmap of the Wasserstein distances.

In [None]:
px.imshow(img=w_distances, height=500, width=500) # heatmap of the pairwise distance matrix

In [None]:
# Label rows and columns with the cell numbers so cells can be selected/excluded by name.
dists = pandas.DataFrame(data=w_distances, index=cell_names, columns=cell_names)

In [None]:
A = analysis(dists, [3,4,5], dataset, dataset) # X1_dists was never defined; the matrix built above is `dists`

In `X1`, the heatmap and all 4 dendrograms indicate there is an outlier, so let's identify which cell this is. Using `average4`, the outlier has cluster number 3.

In [None]:
outlier_mask = A[0]["average4"].astype(int) == 3 # cells placed in cluster 3 by average4
for cell in A[0].index[outlier_mask]:
	print(cell)

If desired, we can use the `exclude` parameter of the `analysis` function to exclude a cell from our analysis.

In [None]:
if dataset == "X1":
	# Exclude the outlier, cell 015. `dists` is indexed by integer cell numbers (parsed with
	# int(file[0:3])), so it must be given as 15; the string "015" never matches and excluded nothing.
	A_main = analysis(dists, [3,4,5], dataset, dataset, exclude=[15])

	n_c = A_main[0].shape[0]
	print(n_c)
	labels_average4 = A_main[0]["average4"]
	# A_main[6] holds the full linkage matrices in the order average, complete, single, ward.
	for k in range(4):
		# Leaf positions (row order of the restricted distance matrix) in average4 cluster k.
		clus = {pos for pos, label in enumerate(labels_average4) if label == str(k)}
		print("cluster "+str(k))
		for link, linkage_matrix in zip(["average", "complete", "single", "ward"], A_main[6]):
			print(link+": "+str(purity(linkage_matrix, clus, n_c)[1]))

else:
	raise ValueError("Dataset not supported")