In [None]:
import numpy as np
import pandas as pd
from scipy.cluster.hierarchy import dendrogram, linkage
import matplotlib.pyplot as plt

## VISIUM

In [None]:
# Fitted-value table for the Visium data; the first CSV column becomes the row index
# (the row labels are used as gene identifiers further down).
visium_smoothers_path = '/lustre/scratch126/cellgen/team292/vl6/VISIUM/femalereproductiveaxis_mese_cytassist_downsampled_fitted_values_tradeseq.csv'
visium_smoothers_df = pd.read_csv(visium_smoothers_path, index_col=0)
print(visium_smoothers_df.shape)
visium_smoothers_df.head()

## scRNA-seq

In [None]:
# Fitted-value table for the scRNA-seq data; the first CSV column becomes the row index
# (the row labels are used as gene identifiers further down).
scrnaseq_smoothers_path = '/lustre/scratch126/cellgen/team292/vl6/VISIUM/mese_femalereproductiveaxis_scrnaseq_downsampled_fitted_values_tradeseq.csv'
scrnaseq_smoothers_df = pd.read_csv(scrnaseq_smoothers_path, index_col=0)
print(scrnaseq_smoothers_df.shape)
scrnaseq_smoothers_df.head()

## Comparison of smoothers

In [None]:
import scipy.stats

### How many genes are in common between the two? 

In [None]:
# Gene identifiers are the row labels of the two fitted-value tables.
scrnaseq_genes = scrnaseq_smoothers_df.index.tolist()
visium_genes = visium_smoothers_df.index.tolist()

# Sets are used only for O(1) membership tests. The result lists are built by
# walking the tables in their own row order, so they come out identical after
# every kernel restart (iterating a set of strings depends on the per-process
# hash seed, which would otherwise reshuffle every downstream `.loc[...]`).
_scrnaseq_gene_set = set(scrnaseq_genes)
_visium_gene_set = set(visium_genes)

# Genes present in both tables, de-duplicated, in scRNA-seq row order.
common_genes = [g for g in dict.fromkeys(scrnaseq_genes) if g in _visium_gene_set]
# Genes present in only one of the two tables.
scrnaseq_unique = [g for g in scrnaseq_genes if g not in _visium_gene_set]
visium_unique = [g for g in visium_genes if g not in _scrnaseq_gene_set]

In [None]:
len(common_genes), len(scrnaseq_unique), len(visium_unique)

In [None]:
'PNOC' in visium_unique

In [None]:
'HOXA13' in common_genes

In [None]:
import matplotlib_venn

In [None]:
import matplotlib as mpl
mpl.rcParams['pdf.fonttype'] = 42

In [None]:
import matplotlib.pyplot as plt
from matplotlib_venn import venn2

# Region sizes are derived from the gene lists computed earlier in this notebook
# instead of being typed in by hand, so the figure cannot drift out of sync with
# the data.
set_visium = len(visium_unique)      # genes found only in the Visium table
set_scrnaseq = len(scrnaseq_unique)  # genes found only in the scRNA-seq table
intersection = len(common_genes)     # genes found in both tables

# Create the Venn diagram
venn = venn2(subsets=(set_visium, set_scrnaseq, intersection), set_labels=('Visium', 'scRNA-seq'))

# Colour and transparency per region id ('10' = first set only, '01' = second
# set only, '11' = intersection).
region_style = {
    '10': ('blue', 0.5),
    '01': ('orange', 0.5),
    '11': ('yellowgreen', 0.7),
}
for region_id, (region_colour, region_alpha) in region_style.items():
    patch = venn.get_patch_by_id(region_id)
    if patch is None:
        # presumably returned for a region of size zero — nothing to style then
        continue
    patch.set_color(region_colour)
    patch.set_alpha(region_alpha)

# Save the plot as a PDF
pdf_filename = 'venn_diagram.pdf'
plt.savefig(pdf_filename, format='pdf')

# Display the plot
plt.show()

In [None]:
scrnaseq_smoothers_df_common = scrnaseq_smoothers_df.loc[common_genes]

In [None]:
visium_smoothers_df_common = visium_smoothers_df.loc[common_genes]

In [None]:
scrnaseq_smoothers_mtx_common = scrnaseq_smoothers_df_common.to_numpy()
visium_smoothers_mtx_common = visium_smoothers_df_common.to_numpy()

### 1. Non-parametric correlation between common genes (Spearman's rank correlation test)

In [None]:
# Both matrices were built from tables subset with `common_genes`, so row k of
# each one refers to the same gene; one Spearman coefficient is kept per gene.
spearman_correlations = [
    scipy.stats.spearmanr(scrnaseq_row, visium_row)[0]
    for scrnaseq_row, visium_row in zip(scrnaseq_smoothers_mtx_common, visium_smoothers_mtx_common)
]


In [None]:
len(spearman_correlations)

In [None]:
import matplotlib.pyplot as plt

plt.hist(spearman_correlations, bins=30, color='skyblue', edgecolor='black')
plt.title('Distribution of Spearman Correlation Coefficients')
plt.xlabel('Correlation Coefficient')
plt.ylabel('Frequency')
plt.show()


In [None]:
len(np.asarray(np.array(spearman_correlations) > 0.7).nonzero()[0].tolist())

### 2. Cosine similarity 

In [None]:
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
# Row-wise cosine similarity between the scRNA-seq and the Visium fitted values.
# Both matrices have one row per gene in `common_genes` order.

num_genes = scrnaseq_smoothers_mtx_common.shape[0]
gene_cosine_similarities = np.zeros(num_genes)

paired_rows = zip(scrnaseq_smoothers_mtx_common, visium_smoothers_mtx_common)
for row_idx, (scrnaseq_row, visium_row) in enumerate(paired_rows):
    # cosine_similarity works on 2D inputs, so each row is passed as a (1, n) array
    pairwise = cosine_similarity(scrnaseq_row.reshape(1, -1), visium_row.reshape(1, -1))
    gene_cosine_similarities[row_idx] = pairwise[0, 0]

# gene_cosine_similarities[k] is the similarity for the k-th gene of the common-gene tables



In [None]:
len(gene_cosine_similarities)

In [None]:
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(5, 4))
plt.hist(gene_cosine_similarities, bins=20, color='gainsboro', edgecolor='black')
plt.title('Distribution of Cosine Similarities')
plt.xlabel('Cosine similarity')
plt.ylabel('Frequency')

# Save the plot as a PDF
histogram_pdf = 'cosine_similarities_histogram.pdf'
plt.savefig(histogram_pdf, format='pdf')

plt.show()


In [None]:
len(np.asarray(np.array(gene_cosine_similarities) > 0.9).nonzero()[0].tolist())

In [None]:
# `gene_cosine_similarities` has one entry per row of the common-gene tables, so
# the > 0.9 mask has to be applied to the common-gene table. Applying those
# positions to the full scRNA-seq table would pick unrelated genes, because its
# rows are in a different order and include scRNA-seq-only genes.
common_pattern_genes = scrnaseq_smoothers_df_common.index[np.asarray(gene_cosine_similarities) > 0.9].tolist()

In [None]:
104/111

## Select common genes + scRNA-seq specific genes 

In [None]:
# New list: the common genes first, followed by the scRNA-seq-only genes
tot_genes = [*common_genes, *scrnaseq_unique]

In [None]:
len(tot_genes)

## Intersect prioritised genes with human TFs

In [None]:
tfs = pd.read_csv('/nfs/team292/vl6/FetalReproductiveTract/humanTFs/DatabaseExtract_v_1.01.csv')

In [None]:
tfs['Is TF?'].value_counts()

In [None]:
tfs['TF assessment'].value_counts()

In [None]:
# Keep only rows annotated as a TF *and* assessed as having a known motif
is_tf = tfs['Is TF?'] == 'Yes'
has_known_motif = tfs['TF assessment'] == 'Known motif'
tfs = tfs[is_tf & has_known_motif]

In [None]:
tfs = tfs['HGNC symbol'].tolist()

In [None]:
len(tfs)

In [None]:
'CD36' in tfs

In [None]:
tfs_prioritised = [i for i in tot_genes if i in tfs]

In [None]:
len(tfs_prioritised)

In [None]:
tfs_prioritised = [i for i in tfs_prioritised if not i.startswith("HOX")]

In [None]:
len(tfs_prioritised)

In [None]:
print(tfs_prioritised)

## Remove ubiquitously expressed TFs

In [None]:
# TFs dropped at this step (see the section heading above)
ubiquitous_tfs = {
    'ZBTB20', 'NR2F2', 'JUN', 'NR4A1',
    'NR2F1', 'ATF3', 'SMAD5', 'OSR2', 'RARG', 'TBX3', 'EGR1', 'KLF2',
    'KLF4', 'SOX4', 'PBX1', 'MEIS1', 'TBX2',
}
tfs_prioritised = [tf for tf in tfs_prioritised if tf not in ubiquitous_tfs]

In [None]:
len(tfs_prioritised)

In [None]:
scrnaseq_smoothers_df_tfs = scrnaseq_smoothers_df.loc[tfs_prioritised]

In [None]:
scrnaseq_smoothers_mtx_tfs = scrnaseq_smoothers_df_tfs.to_numpy()

## Cluster TFs by spatial expression pattern

In [None]:
from sklearn.preprocessing import StandardScaler

In [None]:
# Standardise the TF fitted-value matrix before hierarchical clustering.
# NOTE(review): StandardScaler standardises each column independently. The rows of
# this matrix are TFs (it comes from `scrnaseq_smoothers_df.loc[tfs_prioritised]`),
# so every fitted-value column is centred/scaled across TFs rather than each TF's
# own curve being z-scored — confirm this is the normalisation intended for
# clustering TFs by expression pattern.
scaler = StandardScaler()
scrnaseq_smoothers_mtx_tfs_scaled = scaler.fit_transform(scrnaseq_smoothers_mtx_tfs)

In [None]:
scrnaseq_smoothers_mtx_tfs_scaled.shape

In [None]:
# Perform hierarchical clustering
scrnaseq_smoothers_mtx_tfs_scaled_Z = linkage(scrnaseq_smoothers_mtx_tfs_scaled, method='ward', 
                                             optimal_ordering = True)

In [None]:
common_tfs = [i for i in common_genes if i in tfs_prioritised]

In [None]:
print(tfs_prioritised)

In [None]:
import matplotlib as mpl
mpl.rcParams['pdf.fonttype'] = 42

In [None]:
# Dendrogram of the linkage computed on the scaled TF fitted values
plt.figure(figsize=(9, 3))
plt.title("Hierarchical clustering dendrogram of spatially-variable TFs")

tf_leaf_labels = scrnaseq_smoothers_df_tfs.index.to_list()
dendro = dendrogram(
    scrnaseq_smoothers_mtx_tfs_scaled_Z,
    labels=tf_leaf_labels,
    leaf_rotation=45,
    leaf_font_size=10,
)

# Bold the tick labels of the TFs that are listed in `common_tfs`
ax = plt.gca()
x_labels = ax.get_xmajorticklabels()
common_tf_names = set(common_tfs)
for tick_label in x_labels:
    if tick_label.get_text() in common_tf_names:
        tick_label.set_fontweight('bold')

plt.xlabel("TFs")
plt.ylabel("Distance")

# bbox_inches='tight' keeps the rotated leaf labels inside the saved PDF
dendrogram_pdf_path = 'hierarchical_clustering_dendrogram.pdf'
plt.savefig(dendrogram_pdf_path, format='pdf', bbox_inches='tight')

plt.show()

In [None]:
from scipy.cluster.hierarchy import fcluster

In [None]:
# Choosing a distance cutoff (or setting a specific number of clusters)
distance_cutoff = 10  # example value, adjust based on your dendrogram
clusters = fcluster(scrnaseq_smoothers_mtx_tfs_scaled_Z, distance_cutoff, criterion='distance')

# clusters now contains the cluster ID for each gene

In [None]:
len(np.unique(clusters))

In [None]:
cluster_number = 6

In [None]:
gene_indices_in_cluster = np.asarray(clusters == cluster_number).nonzero()[0].tolist()

In [None]:
scrnaseq_smoothers_df_tfs.iloc[gene_indices_in_cluster]

In [None]:
cluster_genes = scrnaseq_smoothers_df_tfs.iloc[gene_indices_in_cluster].index.to_list()

In [None]:
fitted_values_cluster = scrnaseq_smoothers_mtx_tfs[gene_indices_in_cluster, :]

In [None]:
fitted_values_cluster.shape

In [None]:
cluster_genes[0]

In [None]:
plt.figure(figsize=(8, 7))

# x positions for the fitted values: the same -2..4 range as before, with one
# point per fitted-value column so the grid always matches the matrix width.
pseudospace = np.linspace(-2, 4, fitted_values_cluster.shape[1])

# `cluster_genes` and the rows of `fitted_values_cluster` were both selected with
# `gene_indices_in_cluster`, so they pair up one-to-one.
for gene_name, gene_fitted_values in zip(cluster_genes, fitted_values_cluster):
    plt.plot(pseudospace, gene_fitted_values, alpha=0.5, label=gene_name)  # one spline per gene

plt.title(f"Splines for Genes in Cluster {cluster_number}")
plt.xlabel("Müllerian longitudinal axis")
plt.ylabel("Fitted Values")
# Display the legend outside the axes
plt.legend(loc='upper left', bbox_to_anchor=(1, 1))

# Save the plot as a PDF. The file name follows `cluster_number`, so plotting a
# different cluster no longer writes to a file named after cluster 11.
cluster3_tfs = f'cluster{cluster_number}_tfs.pdf'
plt.savefig(cluster3_tfs, format='pdf', bbox_inches='tight')  # bbox_inches='tight' ensures that labels are not cut off

plt.show()


## Intersect prioritised genes with ligands and receptors from CellPhoneDB v5

In [None]:
df_cellphone = pd.read_csv('./COMMOT_database/minimal_cellphonedb_commot.csv', index_col = 0)
print(df_cellphone.shape)
df_cellphone.head()

In [None]:
df_cellphone['directionality'].value_counts(dropna = False)

In [None]:
# Only the two partner columns change name; all other columns keep theirs
partner_column_names = {'gene_name_a': 'ligand', 'gene_name_b': 'receptor'}
df_cellphone = df_cellphone.rename(columns=partner_column_names)

In [None]:
df_cellphone.head()

In [None]:
df_cellphone['ligand_spatially_variable'] = 0

In [None]:
df_cellphone['receptor_spatially_variable'] = 0

In [None]:
df_cellphone.head()

In [None]:
def is_partner_variable(partner, spatially_variable_genes):
    """Return 1 if an interaction partner is spatially variable, else 0.

    Parameters
    ----------
    partner : str
        Gene name of the partner. A name containing '_' is treated as a
        heteromeric partner whose subunits are separated by '_'.
    spatially_variable_genes : container of str
        Gene names considered spatially variable.

    Returns
    -------
    int
        1 when the monomeric partner, or at least one subunit of a heteromeric
        partner, is in `spatially_variable_genes`; 0 otherwise.

    Progress messages are printed for every call.
    """
    print(partner)
    if '_' in partner:
        print('partner is heteromeric')
        # Test each subunit *name*: looping over `partner` itself would walk the
        # string character by character and never match a gene name.
        partners_variable = [p for p in partner.split('_') if p in spatially_variable_genes]
        print(partners_variable)
        if partners_variable:
            print('at least one subunit of heteromeric partner are spatially variable')
            return 1
        return 0

    print('partner is monomeric')
    if partner in spatially_variable_genes:
        print('partner is spatially variable')
        return 1
    return 0

In [None]:
len(tot_genes)

In [None]:
# Collect the index label of every interaction whose ligand / receptor is
# flagged by `is_partner_variable` against `tot_genes`.
spatially_variable_ligands_indices = []
spatially_variable_receptors_indices = []
interaction_rows = zip(df_cellphone.index, df_cellphone['ligand'], df_cellphone['receptor'])
for index, ligand, receptor in interaction_rows:
    # both partners are evaluated for every row, ligand first
    ligand_spatially_variable = is_partner_variable(ligand, tot_genes)
    receptor_spatially_variable = is_partner_variable(receptor, tot_genes)
    if ligand_spatially_variable == 1:
        spatially_variable_ligands_indices.append(index)
    if receptor_spatially_variable == 1:
        spatially_variable_receptors_indices.append(index)
    

In [None]:
print(len(spatially_variable_ligands_indices), len(spatially_variable_receptors_indices))

In [None]:
# Index labels of interactions where the ligand and/or the receptor was flagged
union_set = set(spatially_variable_ligands_indices) | set(spatially_variable_receptors_indices)

# The same labels as a list
union_list = [*union_set]


In [None]:
len(union_list)

In [None]:
# `union_list` holds index *labels* (collected from `iterrows()`), so the rows
# have to be selected by label with `.loc`; `.iloc` would treat them as
# positions. The explicit copy makes the later column assignments act on an
# independent frame.
df_cellphone = df_cellphone.loc[union_list, :].copy()

In [None]:
df_cellphone.loc[spatially_variable_ligands_indices, 'ligand_spatially_variable'] = 1
df_cellphone.loc[spatially_variable_receptors_indices, 'receptor_spatially_variable'] = 1

In [None]:
df_cellphone.head()

In [None]:
def bin_axis(adata, n_bins = 10, axis_name = 'FemaleReproductiveAxis'):
    """Cut `adata.obs[axis_name]` into `n_bins` equal-width bins.

    The bin edges run from the 0.0001 to the 0.9999 quantile of the axis values
    (NaNs ignored). The result is written to
    `adata.obs['binned_<axis_name>_<n_bins>_bins']`, with each bin labelled by
    its left edge. `adata` is modified in place and also returned.
    """
    axis_values = adata.obs[axis_name]
    upper = np.nanquantile(axis_values, 0.9999)
    lower = np.nanquantile(axis_values, 0.0001)

    step = (upper - lower) / n_bins
    # n_bins + 1 edges: lower, lower + step, ..., lower + n_bins * step
    edges = [lower + (k * step) for k in range(n_bins + 1)]

    binned_column = 'binned_{}_{}_bins'.format(axis_name, n_bins)
    adata.obs[binned_column] = pd.cut(axis_values, bins=edges, labels=edges[:-1])
    return adata

In [None]:
len(tot_genes)

In [None]:
'LGR5' in df_cellphone['receptor'].values

In [None]:
def grouped_obs_percent(adata, group_key, gene):
    """Fraction of observations per group with a value > 0 for `gene`.

    Parameters
    ----------
    adata : AnnData-like
        Object supporting `adata[:, gene]`, `adata[row_positions]`, `.obs`,
        `.X`, `.shape` and `.var_names`.
    group_key : str
        Column of `adata.obs` used to group the observations.
    gene : str
        Variable name used to subset `adata` before computing the fractions.

    Returns
    -------
    pandas.DataFrame
        Rows are the selected variable names, columns are the groups; each value
        is the fraction of observations in the group whose value is > 0, rounded
        to 2 decimals. Groups without any observation keep the initial 0.

    `.X` may be a scipy sparse matrix or a dense array.
    """
    adata = adata[:, gene]
    grouped = adata.obs.groupby(group_key)
    out = pd.DataFrame(
        np.zeros((adata.shape[1], len(grouped)), dtype=np.float64),
        columns=list(grouped.groups.keys()),
        index=adata.var_names
    )

    for group, idx in grouped.indices.items():
        # A boolean "is expressed" matrix built with a comparison works for both
        # sparse and dense storage and leaves the matrix data untouched.
        expressed = adata[idx].X > 0
        perc = np.asarray(expressed.sum(axis=0) / expressed.shape[0]).reshape(-1)
        out[group] = [round(i, 2) for i in perc]
    return out


def grouped_obs_mean(adata, group_key, gene):
    """Per-group mean of `gene`, rescaled with the min and max of the `gene` row.

    `adata` is subset to `gene`, its observations are grouped by
    `adata.obs[group_key]`, and the mean of `.X` is taken within each group.
    The returned DataFrame (rows: selected variable names, columns: groups)
    holds `(mean - row_min) / (row_max - row_min)`, where `row_min` / `row_max`
    are the minimum / maximum of the `gene` row across groups.
    """
    gene_adata = adata[:, gene]
    groups = gene_adata.obs.groupby(group_key)
    means = pd.DataFrame(
        np.zeros((gene_adata.shape[1], len(groups)), dtype=np.float64),
        columns=list(groups.groups.keys()),
        index=gene_adata.var_names,
    )

    for label, positions in groups.indices.items():
        group_matrix = gene_adata[positions].X
        means[label] = np.ravel(group_matrix.mean(axis=0, dtype=np.float64))

    # min-max rescaling based on the `gene` row
    gene_row = means.loc[gene]
    lowest = gene_row.min()
    highest = gene_row.max()
    return (means - lowest) / (highest - lowest)

In [None]:
def have_common_elements(list1, list2, n_elements):
    """Return True when the two collections share at least `n_elements` distinct values."""
    shared = set(list1) & set(list2)
    return len(shared) >= n_elements

In [None]:
def has_consecutive_numbers(ordered_list):
    """Return True if any element is immediately followed by that element plus one.

    Only neighbouring positions are compared, in the order given.
    """
    neighbours = zip(ordered_list[:-1], ordered_list[1:])
    return any(current + 1 == following for current, following in neighbours)

In [None]:
import scanpy as sc
import anndata as ad

In [None]:
scRNAseq_mullerian_mese_epi = sc.read('/nfs/team292/vl6/FetalReproductiveTract/mese_epi_mullerian_withpseudospace.h5ad')
scRNAseq_mullerian_mese_epi

In [None]:
mese = sc.read('/nfs/team292/vl6/FetalReproductiveTract/mese_mullerian_withpseudospace.h5ad')

In [None]:
epi = sc.read('/nfs/team292/vl6/FetalReproductiveTract/epi_mullerian_withpseudospace.h5ad')

In [None]:
epi = bin_axis(epi, n_bins = 10, axis_name = 'FemaleReproductiveAxis')
epi.obs['binned_FemaleReproductiveAxis_10_bins'].hist()

In [None]:
mese.X[20:30, 20:30].toarray()

In [None]:
mese = bin_axis(mese, n_bins = 10, axis_name = 'FemaleReproductiveAxis')
mese.obs['binned_FemaleReproductiveAxis_10_bins'].hist()

In [None]:
sc.set_figure_params(scanpy=True, dpi=80, dpi_save=150, 
                         frameon=True, vector_friendly=True, fontsize=14, figsize=[7,7], color_map=None, 
                         format='pdf', facecolor=None, transparent=False)


In [None]:
len(tfs_prioritised)

In [None]:
tfs_to_plot = ['PROX1', 'GATA6', 'NFATC2','LEF1',
               'FOXL2',    'MEIS2','EMX2', 'FOXO1', 'ESR1', 'RORB','HMGA2','MSX1',  'AR',
               'TWIST1',
               'ESRRG', 'RUNX1','PRRX2','TWIST2',  'LBX2', 'PBX3',
              'AHR',  'EVX1', 'EVX2', 'IRF6','NR0B1', 'ISL1', 'HMBOX1',   
 'ASCL2', 'TBX18',
 
  ]

In [None]:
len(tfs_to_plot)

In [None]:
# Create matrix plot using scanpy
sc.pl.matrixplot(mese,var_names=tfs_to_plot,groupby='binned_FemaleReproductiveAxis_10_bins',
                 num_categories = 10,
                 cmap='OrRd',
                 standard_scale='var',dendrogram=False,
                 save='scrnaseq_matrixplot_mese_tfs_spatiallyvariable.pdf'
                )

In [None]:
def correlate_with_axis_and_order(adata, genes, axis_name, n_bins):
    """Order genes by the correlation of their binned expression with the bin labels.

    For every gene, `grouped_obs_mean` is evaluated on the
    `'binned_<axis_name>_<n_bins>_bins'` grouping. The per-gene rows are stacked,
    transposed to a bins x genes table, and each gene column is correlated
    (`np.corrcoef`) with the bin labels taken from the table's index.

    Returns the bins x genes DataFrame with its columns sorted by *ascending*
    correlation.
    """
    bin_column = 'binned_' + axis_name + '_' + str(n_bins) + '_bins'
    gene_list = list(genes)
    per_gene_means = [grouped_obs_mean(adata, bin_column, gene) for gene in gene_list]

    df = pd.concat(per_gene_means)
    # rows = bins, columns = genes
    df_new_index = pd.DataFrame(df, index=gene_list).T

    # The last row/column of the correlation matrix belongs to the bin labels;
    # its off-diagonal entries are the gene-vs-axis correlations.
    correlations = np.corrcoef(df_new_index.T, df_new_index.index.to_list())
    interaction_correlations = correlations[:-1, -1]

    ordered_indices = np.argsort(interaction_correlations)
    return df_new_index.iloc[:, ordered_indices]

In [None]:
mese_downsampled = sc.read('/nfs/team292/vl6/FetalReproductiveTract/mese_mullerian_withpseudospace_downsampled.h5ad')
mese_downsampled

In [None]:
epi_downsampled = sc.read('/nfs/team292/vl6/FetalReproductiveTract/epi_mullerian_withpseudospace_downsampled.h5ad')
epi_downsampled 

In [None]:
import seaborn as sns

In [None]:
def evaluate_spatially_variable_ligand_interaction(adata_mese, adata_epi, axis_name, n_bins, cellphone_filtered, min_prop, 
                                                  spatially_variable_genes):
    """Keep interactions whose spatially variable ligand and its receptor are expressed in overlapping axis bins.

    Parameters
    ----------
    adata_mese : AnnData
        Data in which the ligand is evaluated.
    adata_epi : AnnData
        Data in which the receptor (or each '_'-separated receptor subunit) is evaluated.
    axis_name : str
    n_bins : int
        Together they select the grouping column 'binned_<axis_name>_<n_bins>_bins'
        of `.obs` in both objects.
    cellphone_filtered : pandas.DataFrame
        Interaction table with 'ligand', 'receptor' and 'ligand_spatially_variable'
        columns. It is modified in place: a 'starred_bins' column is (re)set to
        'none' and filled for the interactions that pass.
    min_prop : float
        Minimum fraction of expressing cells for a bin to count as expressing.
    spatially_variable_genes :
        Not used by this function; kept so existing calls keep working.

    Returns
    -------
    pandas.DataFrame
        The rows of `cellphone_filtered` that passed.

    An interaction passes when
      1. its ligand is flagged spatially variable and is in `adata_mese.var_names`;
      2. at least 2 bins express the ligand in more than `min_prop` of cells;
      3. every receptor subunit is in `adata_epi.var_names`;
      4. the bins where every receptor subunit exceeds `min_prop` contain at least
         one pair of neighbouring bins, and at least 2 of them are also bins from
         step 2 whose min-max normalised mean ligand expression is > 0.4.
    For passing interactions two heatmaps are drawn and the shared bin positions
    are stored, comma separated, in 'starred_bins'. Progress is printed throughout.

    NOTE(review): bins are compared by column position between the tables built
    from `adata_mese` and `adata_epi`, which presumes both have the same bin
    columns in the same order — confirm.
    """
    bin_key = 'binned_' + axis_name + '_' + str(n_bins) + '_bins'

    # iterate over interactions and keep those that satisfy requirements
    passed = []
    cellphone_filtered['starred_bins'] = 'none'
    for index, row in cellphone_filtered.iterrows():
        if row['ligand_spatially_variable'] != 1:
            continue

        print('Looking at interaction between {} and {}'.format(row['ligand'], row['receptor']))
        lig = row['ligand']
        if lig not in adata_mese.var_names:
            continue

        # Bins (column positions) where the ligand is detected in more than min_prop of cells
        lig_frac = grouped_obs_percent(adata_mese, bin_key, lig)
        print(lig_frac)
        lig_frac = np.where((lig_frac > min_prop).all(axis=0))[0]
        if len(lig_frac) < 2:
            continue
        # the threshold reported is the one actually applied
        print('There are {} bins that express {} in more than {:.0%} of cells'.format(len(lig_frac), lig, min_prop))

        # Averaged, 0-1 normalised expression of the ligand in each bin
        out_lig = grouped_obs_mean(adata_mese, bin_key, lig)

        # Bins where the normalised expression is > 0.4 and the detection threshold is met
        lig_high_expr_bins = np.where((out_lig > 0.4).all(axis=0))[0]
        lig_high_expr_bins = [i for i in lig_high_expr_bins if i in lig_frac]
        print(lig_high_expr_bins)

        # A receptor name containing '_' lists the subunits of a heteromeric receptor;
        # a monomeric receptor is handled as a single-subunit list.
        rec = row['receptor']
        heteromeric = '_' in rec
        recs = rec.split('_') if heteromeric else [rec]

        if not all(element in adata_epi.var_names for element in recs):
            if heteromeric:
                print('At least one receptor for {} not in var_names'.format(lig))
            else:
                print('Receptor for {} not in var_names'.format(lig))
            continue
        if heteromeric:
            print('All receptors for {} in var_names'.format(lig))
        else:
            print('Receptor for {} in var_names'.format(lig))

        # Fraction of cells per bin expressing each receptor subunit (one row per subunit)
        out_recs = pd.concat(
            [grouped_obs_percent(adata_epi, bin_key, r) for r in recs],
            ignore_index=False,
        )
        print(out_recs)
        # Bins where every subunit exceeds the detection threshold
        rec_min_percent = np.where((out_recs > min_prop).all(axis=0))[0]
        print(rec_min_percent)
        consec = has_consecutive_numbers(rec_min_percent)

        if consec and have_common_elements(lig_high_expr_bins, rec_min_percent, 2):
            print('There is an interaction')
            # one heatmap for the ligand, one for the receptor(s)
            for frame in (out_lig, out_recs):
                plt.figure(figsize=(3, 5))
                sns.heatmap(frame.T, cmap='OrRd', annot=True, fmt=".2f", linewidths=.5,
                            cbar=True)
            passed.append(index)
            starred_bins = [str(i) for i in lig_high_expr_bins if i in rec_min_percent]
            cellphone_filtered.loc[index, 'starred_bins'] = ','.join(starred_bins)
        else:
            print('There is NO interaction')

    print(passed)
    return cellphone_filtered.loc[passed, :]

            
            
        

In [None]:
filter_1 = evaluate_spatially_variable_ligand_interaction(mese_downsampled, epi_downsampled, 'FemaleReproductiveAxis',
                                                         10, df_cellphone, 0.2, tot_genes)


In [None]:
filter_1_partners = np.unique(filter_1['ligand'].to_list())

In [None]:
len(np.unique(filter_1['ligand'].to_list()))

In [None]:
filter_1_partners

In [None]:
def evaluate_spatially_variable_receptor_interaction(adata_mese, adata_epi, axis_name, n_bins, cellphone_filtered, min_prop, 
                                                  spatially_variable_genes):
    """Keep interactions whose spatially variable receptor and its ligand are expressed in overlapping axis bins.

    Parameters
    ----------
    adata_mese : AnnData
        Data in which the receptor is evaluated.
    adata_epi : AnnData
        Data in which the ligand (or each '_'-separated ligand subunit) is evaluated.
    axis_name : str
    n_bins : int
        Together they select the grouping column 'binned_<axis_name>_<n_bins>_bins'
        of `.obs` in both objects.
    cellphone_filtered : pandas.DataFrame
        Interaction table with 'ligand', 'receptor' and 'receptor_spatially_variable'
        columns. It is modified in place: a 'starred_bins' column is (re)set to
        'none' and filled for the interactions that pass.
    min_prop : float
        Minimum fraction of expressing cells for a bin to count as expressing.
    spatially_variable_genes :
        Not used by this function; kept so existing calls keep working.

    Returns
    -------
    pandas.DataFrame
        The rows of `cellphone_filtered` that passed.

    An interaction passes when
      1. its receptor is flagged spatially variable and its full name is in
         `adata_mese.var_names` (so a '_'-joined heteromeric receptor is skipped);
      2. at least 2 bins express the receptor in more than `min_prop` of cells;
      3. every ligand subunit is in `adata_epi.var_names`;
      4. the bins where every ligand subunit exceeds `min_prop` contain at least
         one pair of neighbouring bins, and at least 2 of them are also bins from
         step 2 whose min-max normalised mean receptor expression is > 0.4.
    For passing interactions two heatmaps are drawn and the shared bin positions
    are stored, comma separated, in 'starred_bins'. Progress is printed throughout.

    NOTE(review): bins are compared by column position between the tables built
    from `adata_mese` and `adata_epi`, which presumes both have the same bin
    columns in the same order — confirm.
    """
    bin_key = 'binned_' + axis_name + '_' + str(n_bins) + '_bins'

    # iterate over interactions and keep those that satisfy requirements
    passed = []
    cellphone_filtered['starred_bins'] = 'none'
    for index, row in cellphone_filtered.iterrows():
        if row['receptor_spatially_variable'] != 1:
            continue

        print('Looking at interaction between {} and {}'.format(row['ligand'], row['receptor']))
        print('Assuming this is not a heteromeric interaction!')
        rec = row['receptor']
        if rec not in adata_mese.var_names:
            continue

        # Bins (column positions) where the receptor is detected in more than min_prop of cells
        rec_frac = grouped_obs_percent(adata_mese, bin_key, rec)
        rec_frac = np.where((rec_frac > min_prop).all(axis=0))[0]
        # the threshold reported is the one actually applied
        print('There are {} bins that express {} in more than {:.0%} of cells'.format(len(rec_frac), rec, min_prop))
        if len(rec_frac) < 2:
            continue

        # Averaged, 0-1 normalised expression of the receptor in each bin, computed
        # on the object passed in as `adata_mese` (the same one used for the
        # detection fractions above), not on a notebook-level global.
        out_rec = grouped_obs_mean(adata_mese, bin_key, rec)

        # Bins where the normalised expression is > 0.4 and the detection threshold is met
        rec_high_expr_bins = np.where((out_rec > 0.4).all(axis=0))[0]
        rec_high_expr_bins = [i for i in rec_high_expr_bins if i in rec_frac]
        print(rec_high_expr_bins)

        # A ligand name containing '_' lists the subunits of a heteromeric ligand;
        # a monomeric ligand is handled as a single-subunit list.
        lig = row['ligand']
        heteromeric = '_' in lig
        ligs = lig.split('_') if heteromeric else [lig]

        if not all(element in adata_epi.var_names for element in ligs):
            if heteromeric:
                print('At least one ligand for {} not in var_names'.format(rec))
            else:
                print('Ligand for {} not in var_names'.format(rec))
            continue
        if heteromeric:
            print('All ligands for {} in var_names'.format(rec))
        else:
            print('Ligand for {} in var_names'.format(rec))

        # Fraction of cells per bin expressing each ligand subunit (one row per subunit)
        out_ligs = pd.concat(
            [grouped_obs_percent(adata_epi, bin_key, l) for l in ligs],
            ignore_index=False,
        )
        print(out_ligs)
        # Bins where every subunit exceeds the detection threshold
        lig_min_percent = np.where((out_ligs > min_prop).all(axis=0))[0]
        print(lig_min_percent)
        consec = has_consecutive_numbers(lig_min_percent)

        if consec and have_common_elements(rec_high_expr_bins, lig_min_percent, 2):
            print('There is an interaction')
            # one heatmap for the receptor, one for the ligand(s)
            for frame in (out_rec, out_ligs):
                plt.figure(figsize=(3, 5))
                sns.heatmap(frame.T, cmap='OrRd', annot=True, fmt=".2f", linewidths=.5,
                            cbar=True)
            passed.append(index)
            starred_bins = [str(i) for i in rec_high_expr_bins if i in lig_min_percent]
            cellphone_filtered.loc[index, 'starred_bins'] = ','.join(starred_bins)
        else:
            print('There is NO interaction')

    print(passed)
    return cellphone_filtered.loc[passed, :]


In [None]:
filter_2 = evaluate_spatially_variable_receptor_interaction(mese_downsampled, epi_downsampled, 'FemaleReproductiveAxis',
                                                         10, df_cellphone, 0.2, tot_genes)


In [None]:
len(np.unique(filter_2['receptor']))

In [None]:
filter_2_partners = np.unique(filter_2['receptor'].to_list())

In [None]:
print(filter_2_partners)

In [None]:
filter_2_selected = filter_2[filter_2['receptor'].isin(filter_2_partners)]

In [None]:
filter_2_selected

## Hierarchical clustering of spatially variable interacting partners 

In [None]:
spatially_variable_partners = list(filter_1_partners)

In [None]:
spatially_variable_partners.extend(list(filter_2_partners))

In [None]:
spatially_variable_partners = list(np.unique(spatially_variable_partners))

In [None]:
len(spatially_variable_partners)

In [None]:
print(spatially_variable_partners)

### Filter out genes that are expressed everywhere in the manifold 

In [None]:
# Partners dropped at this step (see the section heading above)
partners_expressed_everywhere = {
    'APP', 'AR', 'COL1A2', 'COL21A1',
    'COL5A2', 'EDNRA', 'EFNB2', 'FN1',
    'IGFBP3', 'LRPAP1',
    'NRXN1', 'NRXN3',
    'PDGFRA', 'ROBO2', 'SFRP1',
    'SFRP4', 'TGFBR3', 'WNT11',
    'SCARA5', 'TNFRSF21', 'IGF2',
}
spatially_variable_partners = [
    partner for partner in spatially_variable_partners
    if partner not in partners_expressed_everywhere
]

In [None]:
scrnaseq_smoothers_df.shape

In [None]:
scrnaseq_smoothers_df_ligands = scrnaseq_smoothers_df.loc[spatially_variable_partners]

In [None]:
scrnaseq_smoothers_df_ligands.shape

In [None]:
scrnaseq_smoothers_mtx_ligands = scrnaseq_smoothers_df_ligands.to_numpy()

In [None]:
# Standardise the partner fitted-value matrix before hierarchical clustering.
# NOTE(review): StandardScaler standardises each column independently. The rows of
# this matrix are genes (it comes from `scrnaseq_smoothers_df.loc[spatially_variable_partners]`),
# so every fitted-value column is centred/scaled across genes rather than each
# gene's own curve being z-scored — confirm this is the intended normalisation.
scaler = StandardScaler()
scrnaseq_smoothers_mtx_ligands_scaled = scaler.fit_transform(scrnaseq_smoothers_mtx_ligands)

In [None]:
# Perform hierarchical clustering
scrnaseq_smoothers_mtx_ligands_scaled_Z = linkage(scrnaseq_smoothers_mtx_ligands_scaled, method='ward', 
                                             optimal_ordering = True)

In [None]:
common_ligands = [i for i in common_genes if i in spatially_variable_partners]

In [None]:
common_ligands

In [None]:
print(spatially_variable_partners)

In [None]:
# Dendrogram of the linkage computed on the scaled partner fitted values
plt.figure(figsize=(7, 2.5))
plt.title("Hierarchical clustering dendrogram of spatially-variable interacting partners")

partner_leaf_labels = scrnaseq_smoothers_df_ligands.index.to_list()
dendro = dendrogram(
    scrnaseq_smoothers_mtx_ligands_scaled_Z,
    labels=partner_leaf_labels,
    leaf_rotation=45,
    leaf_font_size=10,
)

# Bold the tick labels of the partners that are listed in `common_ligands`
ax = plt.gca()
x_labels = ax.get_xmajorticklabels()
common_partner_names = set(common_ligands)
for tick_label in x_labels:
    if tick_label.get_text() in common_partner_names:
        tick_label.set_fontweight('bold')

plt.xlabel("Interacting partners")
plt.ylabel("Distance")

# bbox_inches='tight' keeps the rotated leaf labels inside the saved PDF
dendrogram_pdf_path = 'hierarchical_clustering_dendrogram_ligands.pdf'
plt.savefig(dendrogram_pdf_path, format='pdf', bbox_inches='tight')

plt.show()

## Epithelial partners

In [None]:
filter_2_epi = filter_2[filter_2['receptor'].isin(spatially_variable_partners)]

In [None]:
filter_2_epi

In [None]:
print(filter_2_epi['ligand'].to_list())

In [None]:
filter_1_epi = filter_1[filter_1['ligand'].isin(spatially_variable_partners)]

In [None]:
pd.set_option('display.max_rows', 100)

In [None]:
filter_1_epi

In [None]:
print(np.unique(filter_1_epi['receptor'].to_list()))

## Plot heatmap of mesenchymal and epithelial partners along the axis

In [None]:
mese_partners_reordered = correlate_with_axis_and_order(mese, spatially_variable_partners, 'FemaleReproductiveAxis',
                                                         10)

In [None]:

# Create matrix plot using scanpy
sc.pl.matrixplot(mese,var_names=mese_partners_reordered.columns.to_list(),groupby='binned_FemaleReproductiveAxis_10_bins',
                 num_categories = 10,
                 cmap='OrRd',
                 standard_scale='var',dendrogram=False,
                 save='scrnaseq_matrixplot_mese_partners_spatiallyvariable.pdf'
                )

In [None]:
mese_partners_reordered.columns

In [None]:
mese_partners_reordered_manual = ['LGR5',  'NTRK2','CD36', 'CD55', 'ALDH1A2', 'DLK1','NRG1', 
            'WNT4', 'WNT5A', 'FLRT2',  'GRIA4', 'FGF7','TGM2', 'ALDH1A1', 'LRRTM1', 'NRP1',
                                  'NRP2','RORB',
              'GDF7', 'TNC', 'WIF1', 'SFRP5', 
           'IGF1', 'BMP4', 'BMP7',   
              ]

In [None]:
overlap = [i for i in mese_partners_reordered_manual if i not in mese_partners_reordered.columns.to_list() ]

In [None]:
overlap

In [None]:

# Create matrix plot using scanpy, with the genes in the manually chosen order
# (`mese_partners_reordered_manual`).
# NOTE(review): `save` is identical to the value used for the automatically ordered
# matrix plot earlier in this notebook, so this figure presumably replaces that
# file — confirm that overwriting is intended.
sc.pl.matrixplot(mese,var_names=mese_partners_reordered_manual,groupby='binned_FemaleReproductiveAxis_10_bins',
                 num_categories = 10,
                 cmap='OrRd',
                 standard_scale='var',dendrogram=False,
                 save='scrnaseq_matrixplot_mese_partners_spatiallyvariable.pdf'
                )

In [None]:
epi_partners = ['RSPO1', 'THBS1', 'NTF3', 'ADGRE5', 'RXRA', 'RARA', 'RXRB', 'RARB', 'CRABP2', 
                'NOTCH2', 'ERBB4', 'LRP5', 'LRP6',   'FZD2', 'FZD10', 
                'FZD3', 'ADGRL1', 'ADGRL3', 'SLC1A3', 'GLS', 'FGFR2', 'ADGRG1', 'NRXN3', 
                'SEMA3C', 'SEMA3F', 'ALDH1A1', 'ALDH1A3', 'BMPR1A', 'BMPR2', 
                'ITGAV', 'ITGB6',
                 'WNT7A', 'WNT11', 'IGF1R'
               ]

In [None]:
# Create matrix plot using scanpy
sc.pl.matrixplot(epi,var_names=epi_partners,groupby='binned_FemaleReproductiveAxis_10_bins',
                 num_categories = 10,
                 cmap='OrRd',
                 standard_scale='var',dendrogram=False,
                 save='scrnaseq_matrixplot_epi_partners_spatiallyvariable.pdf'
                )