In [None]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
# import squidpy as sq

# import muon as mu
import mudata as mu

In [None]:
import liana as li

In [None]:
## select the slide to load
exp = "V11L12-109"  # other ids noted by the author: V11T17-102, V11T16-085
slide = "B1"
# Output directory of the chosen experiment/slide combination
slide_path = "data/VicariEtAl/sma/{0}/{0}_{1}/output_data".format(exp, slide)

NOTE: Combining the two modalities requires their spatial coordinates to be in the same coordinate system (e.g. pixel positions on the same image).

In [None]:
# Read the multimodal object and keep a handle on each modality
h5mu_file = os.path.join(slide_path, "sma.h5mu")
mdata = mu.read_h5mu(h5mu_file)
rna, msi = (mdata.mod[modality] for modality in ("rna", "msi"))

In [None]:
# De-duplicate the feature names of both modalities (done in place)
rna.var_names_make_unique()
msi.var_names_make_unique()

In [None]:
# Display a summary of the loaded multimodal object
mdata

## Process RNA

In [None]:
# Quality filtering, applied in place on `rna`.
# scanpy accepts a single threshold per `filter_cells` call, hence two calls.
sc.pp.filter_cells(rna, min_genes=200)
sc.pp.filter_cells(rna, min_counts=100)
# Drop genes detected in fewer than 3 observations
sc.pp.filter_genes(rna, min_cells=3)

In [None]:
# Scale every observation to a total of 10,000 counts, then log1p-transform
sc.pp.normalize_total(rna, target_sum=1e4)
sc.pp.log1p(rna)

In [None]:

# Flag the 200 most variable genes; the flag is read from rna.var['highly_variable'] afterwards
sc.pp.highly_variable_genes(rna, flavor='cell_ranger', n_top_genes=200)

In [None]:
# Keep only the highly variable genes.
# `.copy()` turns the sliced view into an independent object, so that later
# in-place writes to `rna` (e.g. by li.ut.spatial_neighbors) do not have to
# implicitly convert a view.
rna = rna[:, rna.var['highly_variable']].copy()

In [None]:
# Display the RNA object after subsetting to the highly variable genes
rna

## Process metabolites

In [None]:
# The feature labels are parsed as numbers and stored, rounded to
# 2 decimals, in the `mz` column of the feature table
mz_values = msi.var.index.astype(float)
msi.var['mz'] = [round(mz, 2) for mz in mz_values]

In [None]:
# Summed intensity of every feature over all observations.
# A sparse `X.sum(axis=0)` returns a (1, n_vars) matrix; flattening it makes
# the column assignment valid for dense and sparse `X` alike.
# NOTE(review): the column is named 'max_intensity' but holds a sum - confirm
# which of the two is intended.
msi.var['max_intensity'] = np.asarray(msi.X.sum(axis=0)).ravel()

In [None]:
# Reference table: metabolite name -> expected value (2 decimals) that is
# matched against the rounded `msi.var['mz']` column in the following cells.
# TODO: record the source of these values.
molecule_weights = {
    'GABA': 371.18,
    'GABA-H2O': 353.16,
    'Taurine': 393.13,
    'Serotonin': 444.21,
    'Histidine': 423.18,
    '3-MT': 435.21,
    'Dopamine (single)': 421.19,
    'Dopamine (double)': 674.28,
    'DOPAC': 689.24,
    'NE (Norepinephrine)': 690.28,
    'Tocopherol': 698.49
}

In [None]:
# boolean mask of the features whose rounded m/z equals one of the reference values
msk = msi.var['mz'].isin(molecule_weights.values())

In [None]:
# Keep only the matched features.
# `.copy()` turns the sliced view into an independent object; the following
# cells write to `msi.var`, which should not happen on a view.
msi = msi[:, msk].copy()

In [None]:
# Invert the reference table (value -> metabolite name) and label every feature
mz_to_name = {weight: name for name, weight in molecule_weights.items()}
msi.var['name'] = msi.var['mz'].map(mz_to_name).astype('category')
# Use the metabolite names as feature index.
# NOTE(review): the index becomes categorical and shares its name with the
# 'name' column - confirm downstream tools accept that.
msi.var.index = msi.var['name']

In [None]:
from matplotlib import pyplot as plt

In [None]:
# Summed intensity per observation (row sums of X)
msi.X.sum(axis=1)

In [None]:
# Summed intensity per observation over the retained features, flattened to
# 1-D so the histogram always receives a single sample (a sparse `X.sum`
# returns an (n_obs, 1) matrix).
x = np.asarray(msi.X.sum(axis=1)).ravel()

# histogram of the summed intensities
plt.hist(x, bins=25)
plt.xlabel("Summed intensity per observation")
plt.ylabel("Number of observations")

Test

In [None]:
# get reference coordinates (spatial positions stored with the RNA modality)
reference = mdata.mod["rna"].obsm["spatial"]

# spatial weights between the metabolite observations and the RNA reference
# coordinates; read back from msi.obsm['spatial_connectivities'] further down
li.ut.spatial_neighbors(msi, bandwidth=500, cutoff=0.1, spatial_key="spatial", reference=reference, set_diag=False, standardize=False)

In [None]:
# get reference coordinates (spatial positions stored with the MSI modality)
reference = mdata.mod["msi"].obsm["spatial"]

# spatial weights between the RNA observations and the metabolite reference
# coordinates; read back from rna.obsm['spatial_connectivities'] further down
li.ut.spatial_neighbors(rna, bandwidth=500, cutoff=0.1, spatial_key="spatial", reference=reference, set_diag=False, standardize=False)

In [None]:
import squidpy as sq

In [None]:
# Spatial plot of one metabolite on its original (MSI) coordinates
sq.pl.spatial_scatter(msi, color="Dopamine (double)")

In [None]:
# spatial weight all as 1
# w = np.zeros((msi.obsm['spatial_connectivities'].T.shape))

In [None]:
# Transposed spatial weights computed for `msi` against the RNA coordinates
w = msi.obsm['spatial_connectivities'].T

In [None]:
# Apply Spatial Smoothing with RNA as reference:
# every row of the result is a weighted sum of metabolite profiles
msi_smooth = w @ msi.X

In [None]:
# Wrap the smoothed matrix in an AnnData that reuses the RNA observation
# annotations/coordinates together with the metabolite feature table.
# NOTE(review): `msi_smooth` is rebound from a matrix to an AnnData here, so
# re-running this cell on its own would pass an AnnData as `X`.
msi_smooth = sc.AnnData(X=msi_smooth, obs=rna.obs, obsm=rna.obsm, uns=rna.uns, var=msi.var)

In [None]:
# Same metabolite as before, now plotted from the smoothed object
sq.pl.spatial_scatter(msi_smooth, color="Dopamine (double)")

In [None]:
# Transposed spatial weights computed for `rna` against the MSI coordinates
w2 = rna.obsm['spatial_connectivities'].T

In [None]:
# Feature-by-observation matrices: every retained metabolite ...
x_mat = msi.X.T
# ... and the first 8 RNA features only
y_mat = rna.X[:, :8].T


In [None]:
# One weight matrix per modality: `weight` is applied to the metabolite
# matrix and `weight2` to the RNA matrix in the next cell.
# (The second line previously re-assigned `weight`, discarding `w` and
# leaving `weight2` undefined.)
weight = w
weight2 = w2

In [None]:
# Multiply each feature matrix with its transposed spatial weights.
# NOTE(review): relies on `weight` and `weight2` both being defined by an
# earlier cell, and overwrites its own inputs - re-running this cell applies
# the weights a second time.
x_mat = x_mat @ weight.T
y_mat = y_mat @ weight2.T

MISTy

In [None]:
# Refresh the MuData-level obs table; `mdata.obs` is passed to MistyData below
mdata.update_obs()

In [None]:
# Assemble the MISTy input: metabolites as the "intra" view, RNA as the "inter" view
misty = li.mt.MistyData({"intra": msi, "inter":rna}, enforce_obs=False, obs=mdata.obs)


In [None]:
# Fit the linear model; results are read from `misty.uns` in the cells below
misty(model="linear", verbose=True, bypass_intra=True)

In [None]:
# Plot the view contributions of the fitted model
# NOTE: why is stat being passed here and why does it work?....
li.pl.contributions(misty, return_fig=True)

In [None]:
# Display the per-target metrics table stored by the model run
misty.uns['target_metrics']

In [None]:
# Plot the 'multi_R2' statistic of the target metrics
li.pl.target_metrics(misty, stat='multi_R2', return_fig=True)


In [None]:
# Display the interactions table, ordered by ascending importance
misty.uns['interactions'].sort_values("importances")

In [None]:
from sklearn.neighbors import BallTree
import plotnine as p9

In [None]:
def query_bandwidth(coordinates, start=0, end=500, interval_n=50, reference=None, inplace=False):
    """
    Show how the average number of neighbours grows with the bandwidth (radius).

    A ball tree is built on ``coordinates``. For each of ``interval_n`` evenly
    spaced radii between ``start`` and ``end`` (both included), the number of
    tree points lying within that radius of every query point is counted and
    averaged over the query points.

    Parameters
    ----------
    coordinates : array-like, 2-D
        Points the tree is built on, i.e. the points counted as neighbours.
    start, end : float
        First and last radius that is evaluated.
    interval_n : int
        Number of radii evaluated between ``start`` and ``end``.
    reference : array-like, 2-D, optional
        Points around which neighbours are counted. When ``None`` the
        ``coordinates`` themselves are used; each point then also counts
        itself, because its distance to itself is zero.
    inplace : bool
        Not used by this function; kept so that existing calls stay valid.

    Returns
    -------
    tuple
        ``(p, df)``: a plotnine line/point plot and the underlying DataFrame
        with one row per radius and the columns ``'bandwith'`` (spelling kept
        for backward compatibility) and ``'neighbours'``.
    """
    tree = BallTree(coordinates, metric='euclidean')

    # Points the neighbourhoods are centred on
    query_points = coordinates if reference is None else reference

    # Radii to evaluate
    radii = np.linspace(start, end, interval_n)

    # Average neighbour count per radius; `count_only=True` returns one count
    # per query point without materialising the neighbour indices
    avg_neighbours = np.array(
        [np.mean(tree.query_radius(query_points, r=radius, count_only=True)) for radius in radii],
        dtype=float,
    )

    # Build the table in one go instead of growing it cell by cell
    df = pd.DataFrame({'bandwith': radii, 'neighbours': avg_neighbours})

    p = (p9.ggplot(df, p9.aes(x='bandwith', y='neighbours')) +
         p9.geom_line() +
         p9.geom_point() +
         p9.theme_bw(base_size=16) +
         p9.xlab("Bandwidth") +
         p9.ylab("Average Number of Neighbors")
         )

    return p, df
    

In [None]:
# Spatial positions of the metabolite observations
coordinates = msi.obsm['spatial']

In [None]:
# Neighbour counts of the metabolite observations among themselves, for radii from 0 to 5000
p, _ = query_bandwidth(coordinates, start=0, end=5000, interval_n=50, reference=None)

In [None]:
# Display the bandwidth plot
p

In [None]:
# Query points: spatial positions stored with the MSI modality
reference = mdata.mod["msi"].obsm["spatial"]

In [None]:
# Tree points: spatial positions stored with the RNA modality
coordinates = mdata.mod["rna"].obsm["spatial"]

In [None]:
# Average number of RNA positions found around each MSI position, for radii from 300 to 5000
p, df = query_bandwidth(coordinates, start=300, end=5000, interval_n=50, reference=reference)