In [None]:
import numpy as np
import pandas as pd
from scipy.cluster.hierarchy import dendrogram, linkage
import matplotlib.pyplot as plt

## Visium smoothers (tradeSeq fitted values along the male reproductive axis)

In [None]:
# tradeSeq fitted values (smoothers) for the downsampled Visium data; rows are genes
VISIUM_SMOOTHERS_PATH = '/lustre/scratch126/cellgen/team292/vl6/VISIUM/malereproductiveaxis_mese_cytassist_downsampled_fitted_values_tradeseq.csv'
visium_smoothers_df = pd.read_csv(VISIUM_SMOOTHERS_PATH, index_col=0)
print(visium_smoothers_df.shape)
visium_smoothers_df.head()

## scRNA-seq AnnData object, used to check which genes are bona fide secretory mesenchymal genes

In [None]:
import scanpy as sc
import anndata

In [None]:
# Load the male scRNA-seq reference and show the number of cells per annotated cell type
scrnaseq = sc.read('/nfs/team292/vl6/FetalReproductiveTract/males_post10pcw.20240326.h5ad')
scrnaseq.obs['celltype'].value_counts()

In [None]:
# Select cell types that might be contaminating the Visium signal 
scrnaseq = scrnaseq[[i in ['Corpus/CaputEpididymis_Mesenchyme', 'CaudaEpididymis_Mesenchyme', 'CaputEpididymis_Epithelium', 
                          'Endothelial_Lymphatic', 'Erythroid', 'Neural', 'Immune', 'Coelomic_Epithelium', 
                          'Epididymis_Ligament', 'Pre-Perivascular', 'SchwannCell', 
                          'Corpus/CaudaEpididymis_Epithelium', 'Endothelial', 'Perivascular', 
                          'CaputEpididymis_Ciliated_Epithelium', 'SkeletalMuscle', 
                          'VasDeferens_Ligament', 'VasDeferens_Mesenchyme', ] for i in scrnaseq.obs['celltype']]]
scrnaseq.shape

## Compare expression of genes across cell types

In [None]:
# Candidate genes = row labels of the tradeSeq smoother table
genes = visium_smoothers_df.index.to_list()
len(genes)

In [None]:
# Keep only candidate genes that also exist in the scRNA-seq object.
# Build the lookup once: calling `scrnaseq.var_names.to_list()` inside the
# comprehension rebuilt and linearly scanned the whole gene list for every gene.
scrnaseq_genes = set(scrnaseq.var_names)
genes = [gene for gene in genes if gene in scrnaseq_genes]

In [None]:
# Number of candidate genes also present in the scRNA-seq var_names
len(genes)

In [None]:
# Cell types of which at least one must rank in the top 3 by mean expression for a gene to be kept (filtering step 1)
cell_types_of_interest = ['VasDeferens_Mesenchyme', 'Corpus/CaputEpididymis_Mesenchyme']

In [None]:
# Calculate average expression per cell type
average_expression = scrnaseq.to_df().groupby(scrnaseq.obs['celltype']).mean()

# Filter the average expression table to include only the genes of interest
average_expression = average_expression.loc[:, genes]

# Create a table to summarize the results
summary_table = average_expression.reset_index()

In [None]:
# Mean expression per cell type (rows) for each candidate gene (columns)
summary_table

In [None]:
# Filtering step 1: keep a gene when at least one cell type of interest is among
# the 3 cell types with the highest mean expression of that gene.
filtered_genes = [
    gene
    for gene in genes
    if not set(cell_types_of_interest).isdisjoint(
        average_expression[gene].sort_values(ascending=False).index[:3]
    )
]

# Restrict the summary table to the cell-type column plus the retained genes
filtered_summary_table = summary_table.loc[:, ['celltype', *filtered_genes]]

In [None]:
# Summary table restricted to genes passing filtering step 1
filtered_summary_table

### Remaining genes (to be saved and their trends plotted with tradeSeq)

In [None]:
# Genes that passed filtering step 1 (the first entry is the 'celltype' column)
print(filtered_summary_table.columns.to_list())

## Intersect prioritised genes with human TFs

In [None]:
# Human transcription factor database extract (filtered on 'Is TF?' and 'TF assessment' below)
tfs = pd.read_csv('/nfs/team292/vl6/FetalReproductiveTract/humanTFs/DatabaseExtract_v_1.01.csv')

In [None]:
# Keep entries curated as TFs with a known motif, then reduce to their HGNC symbols
is_tf_with_motif = (tfs['Is TF?'] == 'Yes') & (tfs['TF assessment'] == 'Known motif')
tfs = tfs.loc[is_tf_with_motif, 'HGNC symbol'].tolist()
len(tfs)

In [None]:
# Prioritised genes (columns of the filtered summary table) that are also human TFs
tf_symbols = set(tfs)
tfs_prioritised = [column for column in filtered_summary_table.columns if column in tf_symbols]

In [None]:
# Prioritised genes that are TFs with a known motif
print(len(tfs_prioritised))
print(tfs_prioritised)

In [None]:
# Drop TFs whose symbol starts with "HOX"
# NOTE(review): the rationale for excluding HOX genes is not stated here — document it
tfs_prioritised = [i for i in tfs_prioritised if not i.startswith("HOX")]

In [None]:
# Number of prioritised TFs after removing HOX genes
len(tfs_prioritised)

In [None]:
# Final prioritised TF list
print(tfs_prioritised)

## Intersect prioritised genes with ligands and receptors from CellPhoneDB v5

In [None]:
# Ligand-receptor table (CellPhoneDB-derived, formatted for COMMOT according to the path)
df_cellphone = pd.read_csv('./COMMOT_database/minimal_cellphonedb_commot.csv', index_col = 0)
print(df_cellphone.shape)
df_cellphone.head()

In [None]:
# Distribution of interaction directionality, including missing values
df_cellphone['directionality'].value_counts(dropna = False)

In [None]:
# Partner A is treated as the ligand and partner B as the receptor;
# 'classification' and 'directionality' already have the desired names.
df_cellphone = df_cellphone.rename(columns={'gene_name_a': 'ligand', 'gene_name_b': 'receptor'})

In [None]:
# Preview after renaming the partner columns
df_cellphone.head()

In [None]:
# Start every interaction with both "spatially variable" flags switched off (0)
df_cellphone = df_cellphone.assign(
    ligand_spatially_variable=0,
    receptor_spatially_variable=0,
)
df_cellphone.head()

In [None]:
def is_partner_variable(partner, spatially_variable_genes, verbose=True):
    """Return 1 if an interaction partner is among the spatially variable genes, else 0.

    Heteromeric partners are encoded as subunit symbols joined by '_' (e.g. 'ITGA1_ITGB1').
    Such a partner counts as variable when at least one subunit is in
    ``spatially_variable_genes``.

    Parameters
    ----------
    partner : str
        Gene symbol, or '_'-joined subunit symbols.
    spatially_variable_genes : list of str
        Genes considered spatially variable.
    verbose : bool, default True
        Print the diagnostic messages (the original behaviour).

    Returns
    -------
    int
        1 if the partner (or any of its subunits) is variable, otherwise 0.
    """
    def log(message):
        if verbose:
            print(message)

    log(partner)
    if '_' in partner:
        log('partner is heteromeric')
        # Bug fix: iterate over the subunits; the previous loop iterated over the
        # characters of `partner`, so heteromeric partners were never matched.
        partners_variable = [p for p in partner.split('_') if p in spatially_variable_genes]
        log(partners_variable)
        if partners_variable:
            log('at least one subunit of heteromeric partner are spatially variable')
            return 1
        return 0

    log('partner is monomeric')
    if partner in spatially_variable_genes:
        log('partner is spatially variable')
        return 1
    return 0

In [None]:
# All genes that survived filtering step 1: every column except the cell-type label.
# Excluding 'celltype' by name avoids relying on it being the first column.
tot_genes = [column for column in filtered_summary_table.columns if column != 'celltype']

In [None]:
# Number of genes that passed filtering step 1
len(tot_genes)

In [None]:
# Collect the index labels of interactions whose ligand and/or receptor is among the
# retained genes. Ligand and receptor are checked row by row, ligand first.
spatially_variable_ligands_indices = []
spatially_variable_receptors_indices = []
for row_label, ligand_name, receptor_name in zip(
    df_cellphone.index, df_cellphone['ligand'], df_cellphone['receptor']
):
    ligand_hit = is_partner_variable(ligand_name, tot_genes) == 1
    receptor_hit = is_partner_variable(receptor_name, tot_genes) == 1
    if ligand_hit:
        spatially_variable_ligands_indices.append(row_label)
    if receptor_hit:
        spatially_variable_receptors_indices.append(row_label)

In [None]:
# Number of interactions with a variable ligand, and with a variable receptor
print(len(spatially_variable_ligands_indices), len(spatially_variable_receptors_indices))

In [None]:
# Interactions (index labels) flagged on the ligand side, the receptor side, or both
union_set = set(spatially_variable_ligands_indices) | set(spatially_variable_receptors_indices)
union_list = [*union_set]


In [None]:
# Number of interactions with a spatially variable ligand and/or receptor
len(union_list)

In [None]:
# Keep only interactions with a spatially variable ligand and/or receptor.
# `union_list` holds index *labels* (collected from the interaction table's index), so
# select with a label-based mask: `.iloc` treats them as positions and picks the wrong
# rows whenever the index is not exactly 0..n-1. The mask also preserves the original
# row order, and `.copy()` yields an independent frame for the flag assignments that follow.
df_cellphone = df_cellphone.loc[df_cellphone.index.isin(union_list)].copy()

In [None]:
# Mark which side(s) of each retained interaction drove its selection
for flagged_labels, flag_column in (
    (spatially_variable_ligands_indices, 'ligand_spatially_variable'),
    (spatially_variable_receptors_indices, 'receptor_spatially_variable'),
):
    df_cellphone.loc[flagged_labels, flag_column] = 1

In [None]:
# Preview the retained interactions with their ligand/receptor flags
df_cellphone.head()

In [None]:
def bin_axis(adata, n_bins = 10, axis_name = 'FemaleReproductiveAxis',
             lower_quantile=0.0001, upper_quantile=0.9999):
    """Discretise a continuous ``adata.obs`` column into ``n_bins`` equal-width bins.

    The range is taken between two quantiles of the column (defaults trim the extreme
    0.01% at each end). ``pd.cut`` uses right-closed intervals, so values outside
    ``(q_low, q_high]`` become NaN. Each spot is labelled with the left edge of its bin.

    Parameters
    ----------
    adata : AnnData (or any object with a pandas ``obs`` DataFrame)
    n_bins : int, default 10
    axis_name : str
        Column of ``adata.obs`` holding the continuous axis.
    lower_quantile, upper_quantile : float
        Quantiles defining the binned range (previously hard-coded).

    Returns
    -------
    The same ``adata``, with a new column ``'binned_<axis_name>_<n_bins>_bins'``.

    Raises
    ------
    ValueError
        If the quantile range is empty or not finite (e.g. constant or all-NaN axis).
    """
    values = adata.obs[axis_name]
    min_val = np.nanquantile(values, lower_quantile)
    max_val = np.nanquantile(values, upper_quantile)
    if not (np.isfinite(min_val) and np.isfinite(max_val)) or max_val <= min_val:
        raise ValueError(
            'Cannot bin {!r}: quantile range [{}, {}] is empty or not finite'.format(axis_name, min_val, max_val)
        )

    # n_bins + 1 evenly spaced edges; the last edge is exactly max_val
    edges = np.linspace(min_val, max_val, n_bins + 1)
    column = 'binned_' + axis_name + '_' + str(n_bins) + '_bins'
    adata.obs[column] = pd.cut(values, bins=edges, labels=edges[:-1])
    return adata

In [None]:
def grouped_obs_percent(adata, group_key, gene):
    """Fraction of observations in each group whose expression is > 0, per gene.

    Parameters
    ----------
    adata : AnnData
    group_key : str
        Column of ``adata.obs`` defining the groups (e.g. a binned axis).
    gene : str or list of str
        Variable name(s) to evaluate.

    Returns
    -------
    pandas.DataFrame
        Index = selected ``var_names``; columns = group labels; values are fractions
        rounded to 2 decimals. Groups without observations keep 0.
    """
    adata = adata[:, gene]
    grouped = adata.obs.groupby(group_key)
    out = pd.DataFrame(
        np.zeros((adata.shape[1], len(grouped)), dtype=np.float64),
        columns=list(grouped.groups.keys()),
        index=adata.var_names,
    )

    for group, idx in grouped.indices.items():
        X = adata[idx].X
        # `X > 0` works for both dense arrays and scipy.sparse matrices and never writes
        # into the matrix (the previous `X.data = X.data > 0` mutated sparse data and
        # does not work for dense ndarrays).
        n_expressing = np.asarray((X > 0).sum(axis=0)).reshape(-1)
        out[group] = np.round(n_expressing / X.shape[0], 2)
    return out


def grouped_obs_mean(adata, group_key, gene):
    """Per-group mean expression, min-max scaled to [0, 1] across groups for each gene.

    Parameters
    ----------
    adata : AnnData
    group_key : str
        Column of ``adata.obs`` defining the groups (e.g. a binned axis).
    gene : str or list of str
        Variable name(s) to evaluate. Each gene (row) is scaled independently.

    Returns
    -------
    pandas.DataFrame
        Index = selected ``var_names``; columns = group labels. A gene whose mean is
        identical in every group has no range and is reported as 0 instead of NaN (0/0).
    """
    adata = adata[:, gene]
    grouped = adata.obs.groupby(group_key)
    out = pd.DataFrame(
        np.zeros((adata.shape[1], len(grouped)), dtype=np.float64),
        columns=list(grouped.groups.keys()),
        index=adata.var_names,
    )

    for group, idx in grouped.indices.items():
        X = adata[idx].X
        out[group] = np.ravel(X.mean(axis=0, dtype=np.float64))

    # Scale each gene (row) across groups. The previous `out.loc[gene]` form only
    # worked for a single string gene; with a list it scaled per group instead.
    row_min = out.min(axis=1)
    row_range = out.max(axis=1) - row_min
    safe_range = row_range.where(row_range > 0, 1.0)
    return out.sub(row_min, axis=0).div(safe_range, axis=0)

In [None]:
def have_common_elements(list1, list2, n_elements):
    """Return True when ``list1`` and ``list2`` share at least ``n_elements`` distinct values."""
    shared_values = {*list1} & {*list2}
    return len(shared_values) >= n_elements

In [None]:
def has_consecutive_numbers(ordered_list):
    """Return True if any element is followed directly by its value + 1.

    Only neighbouring positions are compared. For sorted input (such as the output of
    ``np.where``) this means "contains at least one run of consecutive integers".
    """
    return any(
        later == earlier + 1
        for earlier, later in zip(ordered_list[:-1], ordered_list[1:])
    )


In [None]:
import scanpy as sc
import anndata as ad

In [None]:
# Downsampled Visium (CytAssist) spots for the mesenchymal compartment, per the file name
wolffian_mese = sc.read('/lustre/scratch126/cellgen/team292/vl6/VISIUM/malereproductiveaxis_mese_downsampled_cytassist.h5ad')
wolffian_mese

In [None]:
# Peek at a small block of the (sparse) expression matrix to check the value scale
wolffian_mese.X[20:30, 20:30].toarray()

In [None]:
# Downsampled Visium (CytAssist) spots for the epithelial compartment, per the file name
wolffian_epi = sc.read('/lustre/scratch126/cellgen/team292/vl6/VISIUM/malereproductiveaxis_epi_downsampled_cytassist.h5ad')
wolffian_epi

In [None]:
import seaborn as sns

In [None]:
def evaluate_spatially_variable_ligand_interaction(adata_mese, adata_epi, axis_name, n_bins, cellphone_filtered, min_prop,
                                                  spatially_variable_genes, high_expr_threshold=0.4):
    """Screen mesenchymal-ligand -> epithelial-receptor interactions along a binned axis.

    Only rows with ``ligand_spatially_variable == 1`` are evaluated. An interaction passes when:

    1. the ligand is in ``adata_mese.var_names`` and is detected (expression > 0) in more
       than ``min_prop`` of spots in at least 2 bins;
    2. the "high" ligand bins are those detected bins whose per-bin mean expression,
       min-max scaled across bins, exceeds ``high_expr_threshold``;
    3. every receptor subunit (``receptor`` split on '_') is in ``adata_epi.var_names``.
       The receptor bins are those where *all* subunits are detected in more than
       ``min_prop`` of spots;
    4. the receptor bins contain at least one pair of adjacent bin positions, and they
       share at least 2 bins with the high ligand bins.

    For each passing interaction, two heatmaps are drawn (scaled ligand mean; receptor
    detection fractions). The shared bin positions are stored, comma-separated, in
    ``starred_bins``.

    Parameters
    ----------
    adata_mese, adata_epi : AnnData
        Ligand-side and receptor-side spot objects. Both need the obs column
        ``'binned_<axis_name>_<n_bins>_bins'``.
    axis_name : str
    n_bins : int
        Only used to build the binned obs column name.
    cellphone_filtered : pandas.DataFrame
        Interaction table with ``ligand``, ``receptor`` and ``ligand_spatially_variable``
        columns. It is not modified; a copy is annotated instead.
    min_prop : float
        Detection-fraction threshold (strictly greater than).
    spatially_variable_genes : list
        Unused; kept so existing calls keep working.
    high_expr_threshold : float, default 0.4
        Threshold on the min-max scaled mean ligand expression.

    Returns
    -------
    pandas.DataFrame
        Passing rows of ``cellphone_filtered``, plus a ``starred_bins`` column.
    """
    bin_key = 'binned_' + axis_name + '_' + str(n_bins) + '_bins'
    # Work on a copy so the caller's table does not silently gain a column.
    result = cellphone_filtered.copy()
    result['starred_bins'] = 'none'
    # Hoist the var_names lookups out of the loop (they were rebuilt as lists per row).
    mese_genes = set(adata_mese.var_names)
    epi_genes = set(adata_epi.var_names)

    passed = []
    for index, row in result.iterrows():
        if row['ligand_spatially_variable'] != 1:
            continue
        lig = row['ligand']
        rec = row['receptor']
        print('Looking at interaction between {} and {}'.format(lig, rec))
        if lig not in mese_genes:
            # Heteromeric ligands ('A_B') are never a single var_name, so they land here too.
            print('Ligand {} not in var_names'.format(lig))
            continue

        # Bins where the ligand is detected in more than `min_prop` of spots
        lig_frac = grouped_obs_percent(adata_mese, bin_key, lig)
        print(lig_frac)
        lig_frac = np.where((lig_frac > min_prop).all(axis=0))[0]
        if len(lig_frac) < 2:
            continue
        print('There are {} bins that express {} in more than {} of cells'.format(len(lig_frac), lig, min_prop))

        # Detected bins where the scaled mean ligand expression is high
        out_lig = grouped_obs_mean(adata_mese, bin_key, lig)
        lig_high_expr_bins = np.where((out_lig > high_expr_threshold).all(axis=0))[0]
        lig_high_expr_bins = [i for i in lig_high_expr_bins if i in lig_frac]
        print(lig_high_expr_bins)

        # A monomeric receptor is simply the one-subunit case of a heteromeric one.
        subunits = rec.split('_')
        heteromeric = len(subunits) > 1
        if not all(s in epi_genes for s in subunits):
            if heteromeric:
                print('At least one receptor for {} not in var_names'.format(lig))
            else:
                print('Receptor for {} not in var_names'.format(lig))
            continue
        if heteromeric:
            print('All receptors for {} in var_names'.format(lig))
        else:
            print('Receptor for {} in var_names'.format(lig))

        # One row per subunit; a bin qualifies only if every subunit is detected there
        out_rec = pd.concat([grouped_obs_percent(adata_epi, bin_key, s) for s in subunits])
        print(out_rec)
        rec_min_percent = np.where((out_rec > min_prop).all(axis=0))[0]
        print(rec_min_percent)

        if has_consecutive_numbers(rec_min_percent) and have_common_elements(lig_high_expr_bins, rec_min_percent, 2):
            print('There is an interaction')
            for table in (out_lig, out_rec):
                fig, ax = plt.subplots(figsize=(3, 5))
                sns.heatmap(table.T, cmap='OrRd', annot=True, fmt=".2f", linewidths=.5, cbar=True, ax=ax)
            passed.append(index)
            starred_bins = [str(i) for i in lig_high_expr_bins if i in rec_min_percent]
            result.loc[index, 'starred_bins'] = ','.join(starred_bins)
        else:
            print('There is NO interaction')

    print(passed)
    return result.loc[passed, :]

            
            
        

In [None]:
# Mesenchymal-ligand -> epithelial-receptor screen along 'MaleReproductiveAxis' (6 bins, detection threshold 0.2)
filter_1 = evaluate_spatially_variable_ligand_interaction(wolffian_mese, wolffian_epi, 'MaleReproductiveAxis',
                                                         6, df_cellphone, 0.2, tot_genes)

In [None]:
# Interactions passing the ligand-side screen, with their shared ('starred') bins
filter_1

In [None]:
# Unique (sorted) ligands among the passing interactions
filter_1_partners = np.unique(filter_1['ligand'].to_list())

In [None]:
# Number of unique ligands (equal to len(filter_1_partners))
len(np.unique(filter_1['ligand'].to_list()))

In [None]:
# Ligands selected by the ligand-side screen
filter_1_partners

In [None]:
def evaluate_spatially_variable_receptor_interaction(adata_mese, adata_epi, axis_name, n_bins, cellphone_filtered, min_prop,
                                                  spatially_variable_genes, high_expr_threshold=0.4):
    """Screen epithelial-ligand -> mesenchymal-receptor interactions along a binned axis.

    This is the mirror image of the ligand screen. The spatially variable partner is the
    receptor, evaluated in ``adata_mese``; the ligand is evaluated in ``adata_epi``. Only rows
    with ``receptor_spatially_variable == 1`` are evaluated. An interaction passes when:

    1. the receptor is a single var_name of ``adata_mese`` (heteromeric receptors are
       skipped), detected (expression > 0) in more than ``min_prop`` of spots in at
       least 2 bins;
    2. the "high" receptor bins are those detected bins whose per-bin mean expression,
       min-max scaled across bins, exceeds ``high_expr_threshold``;
    3. every ligand subunit (``ligand`` split on '_') is in ``adata_epi.var_names``. The
       ligand bins are those where *all* subunits are detected in more than ``min_prop``
       of spots;
    4. the ligand bins contain at least one pair of adjacent bin positions, and they share
       at least 2 bins with the high receptor bins.

    For each passing interaction, two heatmaps are drawn (scaled receptor mean; ligand
    detection fractions). The shared bin positions are stored, comma-separated, in
    ``starred_bins``.

    Parameters
    ----------
    adata_mese, adata_epi : AnnData
        Receptor-side and ligand-side spot objects. Both need the obs column
        ``'binned_<axis_name>_<n_bins>_bins'``.
    axis_name : str
    n_bins : int
        Only used to build the binned obs column name.
    cellphone_filtered : pandas.DataFrame
        Interaction table with ``ligand``, ``receptor`` and ``receptor_spatially_variable``
        columns. It is not modified; a copy is annotated instead.
    min_prop : float
        Detection-fraction threshold (strictly greater than).
    spatially_variable_genes : list
        Unused; kept so existing calls keep working.
    high_expr_threshold : float, default 0.4
        Threshold on the min-max scaled mean receptor expression.

    Returns
    -------
    pandas.DataFrame
        Passing rows of ``cellphone_filtered``, plus a ``starred_bins`` column.
    """
    bin_key = 'binned_' + axis_name + '_' + str(n_bins) + '_bins'
    # Work on a copy so the caller's table does not silently gain a column.
    result = cellphone_filtered.copy()
    result['starred_bins'] = 'none'
    # Hoist the var_names lookups out of the loop (they were rebuilt as lists per row).
    mese_genes = set(adata_mese.var_names)
    epi_genes = set(adata_epi.var_names)

    passed = []
    for index, row in result.iterrows():
        if row['receptor_spatially_variable'] != 1:
            continue
        lig = row['ligand']
        rec = row['receptor']
        print('Looking at interaction between {} and {}'.format(lig, rec))
        print('Assuming this is not a heteromeric interaction!')
        if rec not in mese_genes:
            # Heteromeric receptors ('A_B') are never a single var_name, so they are skipped here.
            print('Receptor {} not in var_names'.format(rec))
            continue

        # Bins where the mesenchymal receptor is detected in more than `min_prop` of spots
        rec_frac = grouped_obs_percent(adata_mese, bin_key, rec)
        rec_frac = np.where((rec_frac > min_prop).all(axis=0))[0]
        print('There are {} bins that express {} in more than {} of cells'.format(len(rec_frac), rec, min_prop))
        if len(rec_frac) < 2:
            continue

        # Detected bins where the scaled mean receptor expression is high
        out_rec = grouped_obs_mean(adata_mese, bin_key, rec)
        rec_high_expr_bins = np.where((out_rec > high_expr_threshold).all(axis=0))[0]
        rec_high_expr_bins = [i for i in rec_high_expr_bins if i in rec_frac]
        print(rec_high_expr_bins)

        # A monomeric ligand is simply the one-subunit case of a heteromeric one.
        subunits = lig.split('_')
        heteromeric = len(subunits) > 1
        if not all(s in epi_genes for s in subunits):
            if heteromeric:
                print('At least one ligand for {} not in var_names'.format(rec))
            else:
                print('Ligand for {} not in var_names'.format(rec))
            continue
        if heteromeric:
            print('All ligands for {} in var_names'.format(rec))
        else:
            print('Ligand for {} in var_names'.format(rec))

        # One row per epithelial ligand subunit; a bin qualifies only if every subunit is detected
        out_lig = pd.concat([grouped_obs_percent(adata_epi, bin_key, s) for s in subunits])
        print(out_lig)
        lig_min_percent = np.where((out_lig > min_prop).all(axis=0))[0]
        print(lig_min_percent)

        if has_consecutive_numbers(lig_min_percent) and have_common_elements(rec_high_expr_bins, lig_min_percent, 2):
            print('There is an interaction')
            for table in (out_rec, out_lig):
                fig, ax = plt.subplots(figsize=(3, 5))
                sns.heatmap(table.T, cmap='OrRd', annot=True, fmt=".2f", linewidths=.5, cbar=True, ax=ax)
            passed.append(index)
            starred_bins = [str(i) for i in rec_high_expr_bins if i in lig_min_percent]
            result.loc[index, 'starred_bins'] = ','.join(starred_bins)
        else:
            print('There is NO interaction')

    print(passed)
    return result.loc[passed, :]


In [None]:
# Epithelial-ligand -> mesenchymal-receptor screen (receptor evaluated in wolffian_mese), 6 bins, threshold 0.2
filter_2 = evaluate_spatially_variable_receptor_interaction(wolffian_mese, wolffian_epi, 'MaleReproductiveAxis',
                                                         6, df_cellphone, 0.2, tot_genes)

In [None]:
# Number of unique receptors among the passing interactions
len(np.unique(filter_2['receptor']))

In [None]:
# Unique (sorted) receptors among the passing interactions
filter_2_partners = np.unique(filter_2['receptor'].to_list())

In [None]:
# Receptors selected by the receptor-side screen
print(filter_2_partners)

In [None]:
# NOTE(review): filter_2_partners holds every unique receptor of filter_2, so this mask keeps
# all rows; presumably a placeholder for a manually curated receptor shortlist — confirm.
filter_2_selected = filter_2[filter_2['receptor'].isin(filter_2_partners)]

In [None]:
# Selected receptor-side interactions
filter_2_selected

In [None]:
# Start the combined partner list with the ligands from the ligand-side screen
spatially_variable_partners = list(filter_1_partners)

In [None]:
# Append the receptors from the receptor-side screen (mutates the list; run once)
spatially_variable_partners.extend(list(filter_2_partners))

In [None]:
# De-duplicate and sort the combined ligand/receptor list
spatially_variable_partners = list(np.unique(spatially_variable_partners))

In [None]:
# Number of unique spatially variable interaction partners
len(spatially_variable_partners)

In [None]:
# Final list of spatially variable ligands/receptors
print(spatially_variable_partners)