# Persistent Homology Analysis of Cell Morphology

Pipeline from the paper [Persistence diagrams as morphological signatures of cells: A method to measure and compare cells within a population](https://files.yossi.eu/manuscripts/2310.20644.pdf) by Yossi Bokor Bleile, Pooja Yadav, Patrice Koehl, and Florian Rehfeldt.

The following packages need to be installed: 
- [Correa](https://correa.yossi.eu)
- plotly
- pandas
- scikit-learn (imported as `sklearn`)
- numpy
- matplotlib
- tifffile

## Overview

We begin the analysis by obtaining a persistence diagram for each cell in the population, as a summary of the morphology of the cell. This notebook contains three main analysis sections:

1. **X1 Dataset**: Analysis of hMSC cells
2. **Y1 Dataset**: Analysis of HeLa cells  
3. **X1Y1 Combined Dataset**: Comparative analysis of both cell types together


## Import modules

In [None]:
import os
import correa
import pandas
import tifffile
import plotly.express as px
import plotly.figure_factory as ff		
import plotly.io as pio
import plotly.graph_objects as go
from sklearn import manifold, cluster, decomposition, metrics, preprocessing
import numpy
import scipy.cluster
from scipy.cluster.hierarchy import dendrogram
from scipy.cluster.hierarchy import ClusterWarning
from matplotlib import pyplot as plt
import matplotlib
from scipy.spatial.distance import squareform
from warnings import simplefilter
simplefilter("ignore", ClusterWarning)

pio.renderers.default = "browser" #set the renderer to browser

## Custom functions for the analysis

We next define some custom functions to make creating the dendrograms easier.


In [None]:
#set colours
def rgb_to_hex(r, g, b):
    """Format three integer colour channels as a lowercase ``#rrggbb`` string."""
    return f'#{r:02x}{g:02x}{b:02x}'
# Hex versions of Plotly's qualitative Set1 palette; each "rgb(r,g,b)" entry
# is stripped of its wrapper and split into channels once.
hex_list = []
for c in px.colors.qualitative.Set1:
	red, green, blue = (int(part) for part in c.replace("rgb(","").replace(")","").split(",")[:3])
	hex_list.append(rgb_to_hex(red, green, blue))
 
 
# def plot_dendrogram(model, **kwargs):
# 	# Create linkage matrix and then plot the dendrogram
# 	# create the counts of samples under each node
# 	counts = numpy.zeros(model.children_.shape[0])
# 	n_samples = len(model.labels_)
# 	for i, merge in enumerate(model.children_):
# 		current_count = 0
# 		for child_idx in merge:
# 			if child_idx < n_samples:
# 				current_count += 1  # leaf node
# 			else:
# 				current_count += counts[child_idx - n_samples]
# 		counts[i] = current_count
# 	linkage_matrix = numpy.column_stack([model.children_, model.distances_, counts]).astype(float)
# 	# Plot the corresponding dendrogram
# 	plt.figure(figsize=(2000,2000))
# 	plt.tick_params(
#     axis='x',          # changes apply to the x-axis
#     which='both',      # both major and minor ticks are affected
#     bottom=False,      # ticks along the bottom edge are off
#     top=False,         # ticks along the top edge are off
#     labelbottom=False) # labels along the bottom edge are off
# 	plt.tick_params(
#     axis='y',          # changes apply to the x-axis
#     which='both',      # both major and minor ticks are affected
#     bottom=False,      # ticks along the bottom edge are off
#     top=False,         # ticks along the top edge are off
#     labelbottom=False) # labels along the bottom edge are off
# 	dendrogram(linkage_matrix, **kwargs)

def generate_rand_index(cluster_df : pandas.DataFrame, to_compare : list):
	"""Build the table of pairwise Rand indices between clusterings.

	Parameters
	----------
	cluster_df : DataFrame holding one column of cluster labels per clustering.
	to_compare : names of the columns of ``cluster_df`` to compare.

	Returns
	-------
	DataFrame indexed and labelled by ``to_compare`` whose entry at row ``j``,
	column ``i`` is ``metrics.rand_score(cluster_df[i], cluster_df[j])``.
	"""
	r_ind = pandas.DataFrame(columns=to_compare, index=to_compare)
	for i in to_compare:
		for j in to_compare:
			# Write through a single .loc[row, column] call: the chained form
			# ``r_ind[i].loc[j] = ...`` assigns into a temporary column object
			# and leaves r_ind untouched when pandas copy-on-write is active.
			r_ind.loc[j, i] = metrics.rand_score(cluster_df[i], cluster_df[j])
	return r_ind

def population_percentages(df : pandas.DataFrame, clustering : str):
	"""Return the relative and absolute size of every cluster in one clustering.

	The labels in ``df[clustering]`` are converted to ``int``; position ``c`` of
	both returned lists refers to cluster ``c``.

	Returns
	-------
	(fractions, sizes) : two lists of length ``max(label) + 1``; ``fractions``
	holds each cluster's size divided by the number of rows.
	"""
	cluster_ids = list(map(int, df[clustering]))
	sizes = [0] * (max(cluster_ids) + 1)
	for cluster_id in cluster_ids:
		sizes[cluster_id] += 1
	total = len(cluster_ids)
	fractions = [size / total for size in sizes]
	return fractions, sizes

def _linkage_matrix(model):
	"""Return the SciPy-style linkage matrix of a fitted AgglomerativeClustering model.

	Each row is ``[child_a, child_b, merge_distance, n_leaves_under_node]``.
	"""
	n_samples = len(model.labels_)
	counts = numpy.zeros(model.children_.shape[0])
	for i, merge in enumerate(model.children_):
		current_count = 0
		for child_idx in merge:
			if child_idx < n_samples:
				current_count += 1  # leaf node
			else:
				current_count += counts[child_idx - n_samples]
		counts[i] = current_count
	return numpy.column_stack([model.children_, model.distances_, counts]).astype(float)


def _link_colours(linkage_matrix, labels, leaf_colours, default):
	"""Choose a colour for every merge node of a dendrogram.

	A merge inherits the colour shared by its two children; when they differ it
	takes the colour of whichever child is a leaf (first child preferred) and
	``default`` when both children are merges. Keys are node ids
	``n_samples + row``.
	"""
	n_samples = len(linkage_matrix) + 1
	link_cols = {}
	for i, (a, b) in enumerate(linkage_matrix[:, :2].astype(int)):
		c1, c2 = (link_cols[x] if x >= n_samples else leaf_colours[labels[x]] for x in (a, b))
		if c1 == c2:
			colour = c1
		elif a < n_samples:
			colour = leaf_colours[labels[a]]
		elif b < n_samples:
			colour = leaf_colours[labels[b]]
		else:
			colour = default
		link_cols[i + n_samples] = colour
	return link_cols


def analysis(dists : pandas.DataFrame, cluster_numbers : list,  name : str,  dir : str, group = False, exclude : list = None, show2d = False, showElbow=False, colour_list=hex_list):
	"""Cluster a pairwise distance matrix hierarchically and summarise the result.

	Parameters
	----------
	dists : square DataFrame of pairwise distances, same labels on both axes.
	cluster_numbers : numbers of clusters ``k`` to cut each hierarchy into.
	name : run name; spaces become underscores in the output file name.
	dir : directory receiving ``<name>_df.csv`` (created if missing).
	group : if not ``False``, keep only labels whose string form contains it.
	exclude : labels to drop before clustering (default: none).
	show2d, showElbow : currently unused; kept for backward compatibility.
	colour_list : colours indexed by the labels of the 4-cluster average-linkage
		reference clustering.

	Returns
	-------
	tuple ``(df, silhouette_score, silhouette_samples, rand_ind_tables_latex,
	percentages_dict, dists, linkage_matrices)`` where ``df`` holds the 3-D MDS
	coordinates plus one ``<linkage><k>`` label column per clustering, the two
	silhouette lists are ordered linkage-major then by ``k``, ``dists`` is the
	(possibly filtered) distance matrix, and ``linkage_matrices`` has one
	SciPy-style matrix per linkage in the order average, complete, single, ward.

	NOTE(review): every AgglomerativeClustering model is fit on the distance
	matrix without ``metric="precomputed"``, so scikit-learn treats its rows as
	feature vectors, whereas MDS and the silhouette scores treat it as
	precomputed distances. Left as is to keep results reproducible - confirm
	this is intended.
	"""
	link_methods = ["average", "complete", "single", "ward"]
	exclude = [] if exclude is None else exclude
	if group == False and len(exclude) == 0:
		inds = dists.index
	else:
		if group == False:
			inds = [f for f in dists.index if f not in exclude]
		else:
			# str(f) keeps the substring filter usable with non-string labels.
			inds = [f for f in dists.index if group in str(f) and f not in exclude]
		dists = dists[inds].loc[inds]
	df = pandas.DataFrame(index=inds)
	dist_array = dists.to_numpy()
	embed = manifold.MDS(3, dissimilarity='precomputed', random_state=1, normalized_stress="auto").fit_transform(dist_array)
	df["x"] = embed[:,0]
	df["y"] = embed[:,1]
	df["z"] = embed[:,2]
	silhouette_samples = []
	silhouette_score = []
	linkage_matrices = []
	# Reference clustering (average linkage, 4 clusters) that fixes leaf colours.
	#colour order is red, purple, blue, green
	avg = cluster.AgglomerativeClustering(distance_threshold=None, n_clusters=4, linkage="average").fit(dist_array)
	dflt = "#000000"
	D_leaf_colors = {label: colour_list[c] for label, c in zip(dists.index, avg.labels_)}
	for link in link_methods:
		# distance_threshold=0 builds the full tree and exposes distances_.
		full_tree = cluster.AgglomerativeClustering(distance_threshold=0, n_clusters=None, linkage=link).fit(dist_array)
		linkage_matrix = _linkage_matrix(full_tree)
		linkage_matrices.append(linkage_matrix)
		# Only consumed by the (currently disabled) dendrogram plot, e.g.
		# dendrogram(Z=linkage_matrix, labels=dists.index, color_threshold=None,
		#            no_labels=True, link_color_func=lambda x: link_cols[x])
		link_cols = _link_colours(linkage_matrix, dists.index, D_leaf_colors, dflt)
	for link in link_methods:
		for k in cluster_numbers:
			model = cluster.AgglomerativeClustering(distance_threshold=None, n_clusters=k, linkage=link).fit(dist_array)
			column = link+str(k)
			df[column] = [str(c) for c in model.labels_]
			silhouette_score.append(metrics.silhouette_score(dists, df[column], metric="precomputed"))
			silhouette_samples.append([metrics.silhouette_samples(dists, df[column], metric="precomputed")])
	# Create the output folder first so the run does not fail at the very end.
	if dir:
		os.makedirs(dir, exist_ok=True)
	df.to_csv(dir+"/"+name.replace(" ","_")+"_df.csv")
	rand_ind_tables_latex = []
	percentages_dict = {}
	for k in cluster_numbers:
		columns = [link+str(k) for link in link_methods]
		rand_ind_tables_latex.append(generate_rand_index(df, columns).to_latex())
		print("cluster sizes and percentages are:")
		for column in columns:
			percentages_dict[column], counts = population_percentages(df, column)
			print(column, counts, percentages_dict[column])
	return df, silhouette_score, silhouette_samples, rand_ind_tables_latex, percentages_dict, dists, linkage_matrices

def get_main_population(analysis, cluster : str, dists : pandas.DataFrame):
	"""Restrict a distance matrix to the largest cluster of one clustering.

	``analysis`` is indexed like the tuple returned by ``analysis()``: item 0 is
	the label DataFrame and item 4 maps clustering names to per-cluster
	fractions. Returns the rows and columns of ``dists`` whose label in column
	``cluster`` is the cluster with the highest fraction (first one on ties).
	"""
	assignments = analysis[0]
	shares = analysis[4][cluster]
	largest = shares.index(max(shares))
	members = assignments.index[assignments[cluster] == str(largest)]
	return dists[members].loc[members]

def purity(merges, clus : set, n_objs):
	"""Find the smallest dendrogram cluster containing ``clus`` and score its purity.

	Parameters
	----------
	merges : iterable of rows whose first two entries are the ids of the nodes
		being merged (leaves are ``0..n_objs-1``; the ``r``-th merge creates
		node ``n_objs + r``). Extra columns are ignored.
	clus : non-empty set of leaf ids.
	n_objs : number of leaves.

	Returns
	-------
	(members, score) : ``members`` is the first cluster formed that contains
	``clus`` and ``score = (n_objs - len(members)) / (n_objs - len(clus))``,
	i.e. 1.0 when ``members == clus`` and 0.0 when only the root contains it.
	When ``clus`` is the whole population the score is defined as 1.0.

	Raises
	------
	ValueError : if ``clus`` is empty or is never contained in any cluster.
	"""
	if not clus:
		raise ValueError("clus must contain at least one leaf id")
	outside = n_objs - len(clus)

	def scored(members):
		# outside == 0 means clus is everything; the root matches it exactly.
		return members, ((n_objs - len(members)) / outside if outside else 1.0)

	# A single leaf is already its own cluster before any merge happens.
	if len(clus) == 1:
		(leaf,) = clus
		if 0 <= leaf < n_objs:
			return scored({leaf})
	live = {i: {i} for i in range(n_objs)}
	for node, row in enumerate(merges, n_objs):
		# Linkage matrices store node ids as floats; normalise them to int keys.
		merged = live.pop(int(row[0])) | live.pop(int(row[1]))
		live[node] = merged
		# Live clusters are disjoint, so only the new one can newly contain clus.
		if clus.issubset(merged):
			return scored(merged)
	raise ValueError("clus is not contained in any cluster of the dendrogram")

# X1 Dataset Analysis

The X1 dataset contains 140 hMSCs (human mesenchymal stem cells). We will:
1. Load the cell contours and nucleus centers
2. Compute persistence diagrams for each cell
3. Calculate Wasserstein distances between all pairs
4. Perform hierarchical clustering analysis

In [None]:
## X1 Dataset Setup

# Parameters of the X1 dataset
dataset_X1 = "X1"
width_X1 = 1500  # Largest width of an image in the X1 dataset (FilamentSensor2 pads all images to same size)
height_X1 = 1692  # Largest height of an image in the X1 dataset
conversion_factor_X1 = 0.3155  # passed to correa as convert_to_microns_factor
# Cell ids are the first three characters of each .tif file name
raw_image_dir_X1 = "../" + dataset_X1 + "/cell/raw_images"
files_X1 = os.listdir(raw_image_dir_X1)
cell_names_X1 = sorted(int(file[:3]) for file in files_X1 if file.endswith('.tif'))
n_cells_X1 = len(cell_names_X1)
print(f"X1 dataset contains {n_cells_X1} cells")


## X1: Load Nucleus Centers

First, we will load the file containing the center of the nucleus for each cell.

Next, for each cell we will calculate the persistence diagram using a radial function based at the center of the nucleus. We need to translate the center of the nucleus into the same frame as the contour is in. This is due to the way FilamentSensor extracts the contour of each cell in a directory. The appropriate `height` and `width` values can be found in the [FilamentSensor2](https://filament-sensor.de/) log file.

In [None]:
# Read the X1 nucleus-centre table, using its "filename" column as the index.
centers_X1 = pandas.read_csv("../" + dataset_X1 + "/Nuc_Cm_" + dataset_X1 + ".csv", index_col="filename")
print(f"Loaded nucleus centers for {len(centers_X1)} cells")

## X1: Compute Persistence Diagrams

Now we compute the persistence diagram for each cell using the radial function based at the nucleus center.


In [None]:
contours_X1 = []
for n_done, i in enumerate(cell_names_X1, start=1):
	cell_id = "%03d" % i
	# The image is read only for its shape: the difference to the dataset-wide
	# frame size gives the offset applied to the nucleus centre below.
	actin = tifffile.imread("../" + dataset_X1 + "/cell/raw_images/" + cell_id + ".tif")
	height_diff = height_X1 - actin.shape[0]
	width_diff = width_X1 - actin.shape[1]

	# Create polygon with focal point at nucleus center (shifted by half the size difference)
	c_i = correa.create_polygon_focal_point(
		"../" + dataset_X1 + "/cell/contours/" + cell_id + "_contour.csv",
		[centers_X1.loc[i, "X_m"] + width_diff/2, centers_X1.loc[i, "Y_m"] + height_diff/2], convert_to_microns_factor=conversion_factor_X1
	)
	c_i.persistence_diagram()
	contours_X1.append(c_i)

	# Count processed cells rather than relying on cell ids being 0..n-1.
	if n_done % 20 == 0:
		print(f"Processed {n_done}/{n_cells_X1} cells")

print(f"Completed persistence diagrams for all {n_cells_X1} cells")
 

## X1: Visualize Sample Contours

We can visualize some sample contours and mark the center of the nucleus (shown in red).

In [None]:
# Visualize first 5 cells
for i in cell_names_X1[0:5]:
	cell_id = "%03d" % i
	# Same "../"-relative locations as the cells that load the data above.
	contour = pandas.read_csv("../" + dataset_X1 + "/cell/contours/" + cell_id + "_contour.csv", header=None)
	actin = tifffile.imread("../" + dataset_X1 + "/cell/raw_images/" + cell_id + ".tif")
	height_diff = height_X1 - actin.shape[0]
	width_diff = width_X1 - actin.shape[1]
	# Nucleus centre shifted by half the size difference, as for the persistence diagrams.
	center = pandas.DataFrame(
		[[centers_X1.loc[i, "X_m"] + width_diff/2, centers_X1.loc[i, "Y_m"] + height_diff/2]],
		columns=["x", "y"]
	)

	# Contour points (columns 0 and 1) plus the nucleus centre as a larger red marker.
	fig_data = px.scatter(contour, x=0, y=1, width=800, height=600).data
	fig_data = fig_data + px.scatter(center, x="x", y="y").update_traces(marker={'size': 10, 'color': 'Red'}).data
	fig = go.Figure(fig_data)
	fig.update_layout(title=f"X1 Cell {i:03d}", xaxis_title="X", yaxis_title="Y")
	fig.show()


## X1: Compute Wasserstein Distances

Once we have a persistence diagram for each cell summarising its morphology, we compute the Wasserstein distance between each pair of persistence diagrams as a (dis)similarity score for each pair of cells.

In [None]:
w_distances_X1 = numpy.zeros((n_cells_X1, n_cells_X1))
for row in range(n_cells_X1):
	# Upper triangle (diagonal included); each value is mirrored below it.
	for col in range(row, n_cells_X1):
		# Average both argument orders so the stored matrix is symmetric.
		forward = correa.wasserstein_distance(contours_X1[row], contours_X1[col], q=2)
		backward = correa.wasserstein_distance(contours_X1[col], contours_X1[row], q=2)
		w_distances_X1[row, col] = w_distances_X1[col, row] = (forward + backward) / 2

	rows_done = row + 1
	if rows_done % 10 == 0:
		print(f"Computed distances for {rows_done}/{n_cells_X1} cells")

print(f"Completed Wasserstein distance computation for all {n_cells_X1} cells")

## X1: Distance Heatmap

Next we display a heatmap of the Wasserstein distances.

In [None]:
# Heatmap of the pairwise X1 distance matrix.
fig = px.imshow(w_distances_X1, width=800, height=800).update_layout(
	title="X1 Wasserstein Distance Heatmap"
)
fig.show()

In [None]:
# Label rows and columns of the X1 distance matrix with the (integer) cell ids.
dists_X1 = pandas.DataFrame(w_distances_X1, columns=cell_names_X1, index=cell_names_X1)

## X1: Hierarchical Clustering Analysis

Perform hierarchical clustering with different linkage methods (average, complete, single, ward) and different numbers of clusters (3, 4, 5).


In [None]:
# Cluster all X1 cells into 3, 4 and 5 clusters; the label table is written to X1/X1_df.csv.
A_X1 = analysis(dists_X1, [3,4,5], "X1", "X1")

## X1: Compute purity scores based on `average4`


In [None]:
# Number of clustered cells; defined here so the cell runs in notebook order
# (it was previously first assigned only in the outlier-exclusion cell below).
n_c = A_X1[0].shape[0]
for k in range(4):
	# Positional indices of the cells assigned to cluster k by average4.
	clus = {i for i, label in enumerate(A_X1[0]["average4"]) if label == str(k)}
	print(f"\nCluster {k}:")
	# A_X1[6] holds the linkage matrices in the order average, complete, single, ward.
	for link_name, linkage_matrix in zip(["Average", "Complete", "Single", "Ward"], A_X1[6]):
		print(f"  {link_name} linkage purity: {purity(linkage_matrix, clus, n_c)[1]:.4f}")

## X1: Identify Outliers

In the X1 dataset, the heatmap and dendrograms indicate there is an outlier. Let's identify which cell this is. Using `average4`, the outlier has cluster number 3.

In [None]:
# List every cell that average4 placed in cluster 3.
average4_X1 = A_X1[0]["average4"]
for i in A_X1[0].index[average4_X1.astype(int) == 3]:
	print(f"Outlier cell: {i}")

## X1: Analysis Excluding Outlier

If desired, we can use the `analysis` command with the `exclude` parameter to exclude the outlier cell (015) from our analysis and compute purity scores.

In [None]:
# The index of dists_X1 holds integer cell ids, so the outlier must be named
# as 15; the string "X1-15" matches no label and excluded nothing.
# NOTE: this run uses the same name and directory as the full X1 run, so it
# rewrites X1/X1_df.csv.
A_X1_main = analysis(dists_X1, [3,4,5], "X1", "X1", exclude=[15])

n_c = A_X1_main[0].shape[0]
print(f"Number of cells (excluding outlier): {n_c}")

for k in range(4):
	# Positional indices of the cells assigned to cluster k by average4.
	clus = {i for i, label in enumerate(A_X1_main[0]["average4"]) if label == str(k)}
	print(f"\nCluster {k}:")
	# A_X1_main[6] holds the linkage matrices in the order average, complete, single, ward.
	for link_name, linkage_matrix in zip(["Average", "Complete", "Single", "Ward"], A_X1_main[6]):
		print(f"  {link_name} linkage purity: {purity(linkage_matrix, clus, n_c)[1]:.4f}")

# Y1 Dataset Analysis

The Y1 dataset contains 100 HeLa cells. We will follow the same analysis pipeline as for X1:
1. Load the cell contours and nucleus centers
2. Compute persistence diagrams for each cell
3. Calculate Wasserstein distances between all pairs
4. Perform hierarchical clustering analysis


## Y1 Dataset Setup


In [None]:
# Parameters of the Y1 dataset
dataset_Y1 = "Y1"
width_Y1 = 1226  # Largest width of an image in the Y1 dataset
height_Y1 = 1088  # Largest height of an image in the Y1 dataset
conversion_factor_Y1 = 0.1639  # passed to correa as convert_to_microns_factor
# Cell ids are the first three characters of each .tif file name
raw_image_dir_Y1 = "../" + dataset_Y1 + "/cell/raw_images"
files_Y1 = os.listdir(raw_image_dir_Y1)
cell_names_Y1 = sorted(int(file[:3]) for file in files_Y1 if file.endswith('.tif'))
n_cells_Y1 = len(cell_names_Y1)
print(f"Y1 dataset contains {n_cells_Y1} cells")


## Y1: Load Nucleus Centers


In [None]:
# Read the Y1 nucleus-centre table, using its "filename" column as the index.
centers_Y1 = pandas.read_csv("../" + dataset_Y1 + "/Nuc_Cm_" + dataset_Y1 + ".csv", index_col="filename")
print(f"Loaded nucleus centers for {len(centers_Y1)} cells")


## Y1: Compute Persistence Diagrams

Compute the persistence diagram for each HeLa cell using the radial function based at the nucleus center.


In [None]:
contours_Y1 = []
for n_done, i in enumerate(cell_names_Y1, start=1):
	cell_id = "%03d" % i
	# Read the image from the same "../"-relative location as the file listing
	# and the contour files; only its shape is used, to offset the nucleus centre.
	actin = tifffile.imread("../" + dataset_Y1 + "/cell/raw_images/" + cell_id + ".tif")
	height_diff = height_Y1 - actin.shape[0]
	width_diff = width_Y1 - actin.shape[1]

	# Create polygon with focal point at nucleus center (shifted by half the size difference)
	c_i = correa.create_polygon_focal_point(
		"../" + dataset_Y1 + "/cell/contours/" + cell_id + "_contour.csv",
		[centers_Y1.loc[i, "X_m"] + width_diff/2, centers_Y1.loc[i, "Y_m"] + height_diff/2], convert_to_microns_factor=conversion_factor_Y1
	)
	c_i.persistence_diagram()
	contours_Y1.append(c_i)

	# Count processed cells rather than relying on cell ids being 0..n-1.
	if n_done % 20 == 0:
		print(f"Processed {n_done}/{n_cells_Y1} cells")

print(f"Completed persistence diagrams for all {n_cells_Y1} cells")


## Y1: Compute Wasserstein Distances

Compute the Wasserstein distance between each pair of HeLa cells.


In [None]:
w_distances_Y1 = numpy.zeros((n_cells_Y1, n_cells_Y1))
for row in range(n_cells_Y1):
	# Upper triangle (diagonal included); each value is mirrored below it.
	for col in range(row, n_cells_Y1):
		# Average both argument orders so the stored matrix is symmetric.
		forward = correa.wasserstein_distance(contours_Y1[row], contours_Y1[col], q=2)
		backward = correa.wasserstein_distance(contours_Y1[col], contours_Y1[row], q=2)
		w_distances_Y1[row, col] = w_distances_Y1[col, row] = (forward + backward) / 2

	rows_done = row + 1
	if rows_done % 10 == 0:
		print(f"Computed distances for {rows_done}/{n_cells_Y1} cells")

print(f"Completed Wasserstein distance computation for all {n_cells_Y1} cells")


## Y1: Distance Heatmap

Display a heatmap of the Wasserstein distances for the Y1 dataset.


In [None]:
# Heatmap of the pairwise Y1 distance matrix.
fig = px.imshow(w_distances_Y1, width=800, height=800).update_layout(
	title="Y1 Wasserstein Distance Heatmap"
)
fig.show()


In [None]:
# Label rows and columns of the Y1 distance matrix with the (integer) cell ids.
dists_Y1 = pandas.DataFrame(w_distances_Y1, columns=cell_names_Y1, index=cell_names_Y1)


## Y1: Hierarchical Clustering Analysis

Perform hierarchical clustering for the Y1 dataset with different linkage methods and numbers of clusters.


In [None]:
# Cluster all Y1 cells into 3, 4 and 5 clusters; the label table is written to Y1/Y1_df.csv.
A_Y1 = analysis(dists_Y1, [3,4,5], "Y1", "Y1")


## Y1: Compute purity scores based on `average4`

In [None]:
# Use the size of the Y1 population; the n_c left over from the X1 cells
# refers to a different number of cells and would distort every score.
n_c = A_Y1[0].shape[0]
for k in range(4):
	# Positional indices of the cells assigned to cluster k by average4.
	clus = {i for i, label in enumerate(A_Y1[0]["average4"]) if label == str(k)}
	print(f"\nCluster {k}:")
	# A_Y1[6] holds the linkage matrices in the order average, complete, single, ward.
	for link_name, linkage_matrix in zip(["Average", "Complete", "Single", "Ward"], A_Y1[6]):
		print(f"  {link_name} linkage purity: {purity(linkage_matrix, clus, n_c)[1]:.4f}")

# X1Y1 Combined Dataset Analysis

The X1Y1 combined dataset contains 240 cells (140 hMSC from X1 and 100 HeLa from Y1). This analysis will:
1. Combine the contours and persistence diagrams from both datasets
2. Compute Wasserstein distances between all pairs of cells across both types
3. Perform hierarchical clustering to see if the two cell types separate
4. Compare results with ground truth labels


## X1Y1: Combine Datasets

Combine the contours from X1 and Y1 datasets and create a combined cell name list.


In [None]:
# X1 contours first, then Y1 contours
contours_X1Y1 = [*contours_X1, *contours_Y1]

# Prefix every cell id with its dataset so the combined labels are unique
cell_names_X1Y1 = [f"X1_{i:03d}" for i in cell_names_X1]
cell_names_X1Y1 += [f"Y1_{i:03d}" for i in cell_names_Y1]
n_cells_X1Y1 = len(cell_names_X1Y1)

print(f"Combined X1Y1 dataset contains {n_cells_X1Y1} cells")
print(f"  - X1 (hMSC): {n_cells_X1} cells")
print(f"  - Y1 (HeLa): {n_cells_Y1} cells")

# Ground truth in the same order as the contours: 1 for X1 cells, 2 for Y1 cells
ground_truth_2clusters = [1] * n_cells_X1 + [2] * n_cells_Y1
print(f"\nGround truth: {n_cells_X1} cells in cluster 1 (X1), {n_cells_Y1} cells in cluster 2 (Y1)")


## X1Y1: Compute Combined Distance Matrix

Compute the Wasserstein distance matrix for the combined dataset. We can reuse the within-dataset distances and only compute the cross-dataset distances.


In [None]:
# Initialize the combined distance matrix
w_distances_X1Y1 = numpy.zeros((n_cells_X1Y1, n_cells_X1Y1))

# Copy X1 within-dataset distances
w_distances_X1Y1[:n_cells_X1, :n_cells_X1] = w_distances_X1
print("Copied X1 within-dataset distances")

# Copy Y1 within-dataset distances
w_distances_X1Y1[n_cells_X1:, n_cells_X1:] = w_distances_Y1
print("Copied Y1 within-dataset distances")

# Compute cross-dataset distances (X1 vs Y1)
print("Computing cross-dataset distances (X1 vs Y1)...")
for i in range(n_cells_X1):
	for j in range(n_cells_Y1):
		dist_ij = correa.wasserstein_distance(contours_X1[i], contours_Y1[j], q=2)
		dist_ji = correa.wasserstein_distance(contours_Y1[j], contours_X1[i], q=2)
		dist = (dist_ij + dist_ji) / 2
		w_distances_X1Y1[i, n_cells_X1 + j] = dist
		w_distances_X1Y1[n_cells_X1 + j, i] = dist
	
	if (i+1) % 20 == 0:
		print(f"  Computed cross-distances for X1 cell {i+1}/{n_cells_X1}")

print(f"Completed combined distance matrix for all {n_cells_X1Y1} cells")


## X1Y1: Distance Heatmap

Display a heatmap of the combined distance matrix. The block structure should show lower distances within each cell type and higher distances between cell types.


In [None]:
# Heatmap of the combined distance matrix.
fig = px.imshow(w_distances_X1Y1, width=900, height=900)
fig.update_layout(title="X1Y1 Combined Wasserstein Distance Heatmap", xaxis_title="Cell Index", yaxis_title="Cell Index")
# Dashed red lines between the last X1 cell and the first Y1 cell
dataset_boundary = n_cells_X1-0.5
boundary_style = {"line_width": 2, "line_dash": "dash", "line_color": "red"}
fig.add_vline(x=dataset_boundary, **boundary_style)
fig.add_hline(y=dataset_boundary, **boundary_style)
fig.show()


In [None]:
# Label rows and columns of the combined matrix with the prefixed cell names.
dists_X1Y1 = pandas.DataFrame(w_distances_X1Y1, columns=cell_names_X1Y1, index=cell_names_X1Y1)


## X1Y1: Hierarchical Clustering Analysis

Perform hierarchical clustering on the combined dataset to see if the two cell types separate naturally.


In [None]:
# Cluster the combined population into 2 to 5 clusters; the label table is written to X1Y1/X1Y1_df.csv.
A_X1Y1 = analysis(dists_X1Y1, [2,3,4,5], "X1Y1", "X1Y1")


## X1Y1: Evaluate Clustering Quality

Perform the analysis again on the combined dataset, this time over a wider range of cluster numbers (2–10).


In [None]:
# Repeat the combined analysis for 2 to 10 clusters. This rebinds A_X1Y1 and,
# being called with the same name and directory, rewrites X1Y1/X1Y1_df.csv.
A_X1Y1 = analysis(dists_X1Y1, [2,3,4,5,6,7,8,9,10], "X1Y1", "X1Y1")
