# `spicemix` tutorial jupyter notebook

This jupyter notebook consists of two parts:
1. Theory behind `spicemix`, and 
2. Scripts for running `spicemix`. 

I hope this notebook gives us some ideas on how to analyze the data, helps us interpret the results, and clarifies the caveats that come with this type of analysis.

## 1. Theory behind `spicemix` 

`spicemix` stands for spatial identification of cells using matrix factorization. It uses **BOTH** the gene expression of single cells and their spatial locations to define cell types. Using matrix factorization to define cell types is not novel in itself, but incorporating spatial coordinates into the analysis makes this algorithm quite powerful, especially in the era of spatial transcriptomics.

To understand how `spicemix` works, I think it is worth discussing:
1. What matrix factorization is, what type of matrix factorization `spicemix` uses, and what matrix factorization does.
2. How `spicemix` incorporates location information into this algorithm.

### 1A. (Non-negative) Matrix factorization

Matrix factorization is a linear algebra approach that decomposes a big matrix into a product of two or more matrices. The decomposition generally reduces the dimensionality of the data and makes it easier to interpret. There are many types of matrix factorization, each suited to different applications. One factorization we may have used before is principal component analysis (PCA). In this type of matrix factorization, we get one matrix that defines a new coordinate system whose axes capture as much of the variance of the original datapoints as possible, and another matrix that gives the locations of the datapoints in this new coordinate system. PCA is highly useful for projecting high-dimensional data onto a 2-dimensional space because it preserves as much of the variance in the dataset as possible, which facilitates visualization and clustering of datapoints.

However, if we were to perform PCA on a cell-by-gene expression matrix, we could summarize the variation in gene expression, **but we would not be able to interpret the results easily.** This is because PCA places no sign constraint on the resulting factors, i.e., the values in the resulting matrices can be negative. But we know that RNA counts are always greater than or equal to 0. Can we do better?
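
To make the sign problem concrete, here is a minimal sketch that runs PCA through NumPy's SVD on a small, made-up count matrix. The matrix values and variable names are assumptions for illustration only and are not data from this notebook.

```python
import numpy as np

# Toy count matrix (hypothetical values): 6 cells (rows) x 4 genes (columns).
rng = np.random.default_rng(0)
counts = rng.poisson(lam=5.0, size=(6, 4)).astype(float)

# PCA amounts to an SVD of the column-centered matrix.
gene_means = counts.mean(axis=0)
centered = counts - gene_means
U, S, Vt = np.linalg.svd(centered, full_matrices=False)

loadings = Vt      # each row is a principal axis expressed as gene weights
scores = U * S     # each row is a cell expressed in the new coordinate space

print("all counts non-negative:", bool((counts >= 0).all()))
print("any negative loading:", bool((loadings < 0).any()))
print("any negative score:", bool((scores < 0).any()))
```

Because the data are centered before the decomposition, every column of `scores` sums to zero, so negative values are unavoidable even though every input count is non-negative.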

Non-negative matrix factorization (NMF) circumvents this problem by constraining the factors to be non-negative. Let $n$ be the number of genes and $m$ be the number of cells in the sample. In NMF, an $(n \times m)$ matrix $X$ is approximated by the product of two smaller matrices, $W$ and $H$, with dimensions $(n \times k)$ and $(k \times m)$, respectively. The variable $k$ is a hyperparameter that we have to choose and experiment with. It is usually much smaller than $n$ and $m$ so that the result is interpretable. Too low a value makes the results less accurate, i.e., the error between the actual matrix $X$ and the approximation $WH$ is high. Too high a value defeats the purpose because the factorization no longer "summarizes" the data.

Mathematically,
$$
X \approx WH
$$

and NMF tries to find non-negative $W$ and $H$ that minimize the reconstruction error, measured here with the squared Frobenius norm. In other words,
$$
\underset{W \geq 0,\, H \geq 0}{\arg \min} \, \lVert X - WH\rVert_F^2
$$

In the transcriptomics world, $X$ is a gene-by-cell expression matrix (genes as rows and cells as columns, matching the $(n \times m)$ convention above). $k$ is then the number of *metagenes*, i.e., groups of genes that together define a cell type. We can then interpret the matrices $W$ and $H$ as follows:
1. $W$ has dimension $n \times k$ (genes by metagenes). Each column $i$ of $W$ tells us the combination of genes that makes up metagene $i$. In other words, which genes belong to each metagene, and with what weight?
2. $H$ has dimension $k \times m$ (metagenes by cells). Each column $j$ of $H$ tells us the combination of metagenes that makes up the transcriptome of cell $j$.
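
The factorization and the reading of $W$ and $H$ can be sketched on toy data. The block below is a minimal NMF using the classic multiplicative update rules, which is one common way to minimize the squared reconstruction error. It is an illustration only and not the optimizer `spicemix` uses. The function name `nmf`, the toy matrices, and the iteration count are all assumptions.

```python
import numpy as np

def nmf(X, k, n_iter=500, eps=1e-9, seed=0):
    """Approximate non-negative X (genes x cells) as W (genes x k) @ H (k x cells)."""
    rng = np.random.default_rng(seed)
    n, m = X.shape
    W = rng.random((n, k)) + eps
    H = rng.random((k, m)) + eps
    for _ in range(n_iter):
        # Multiplicative updates keep every entry non-negative.
        H *= (W.T @ X) / (W.T @ W @ H + eps)
        W *= (X @ H.T) / (W @ H @ H.T + eps)
    return W, H

# Toy data (made up): 6 genes x 8 cells generated from 2 metagenes.
W_true = np.array([[5, 0], [4, 0], [3, 0], [0, 5], [0, 4], [0, 3]], dtype=float)
H_true = np.array([[1, 1, 1, 1, 0, 0, 0, 0.5],
                   [0, 0, 0, 0, 1, 1, 1, 0.5]])
X = W_true @ H_true

# Reconstruction error for a few choices of k.
errors = {}
for k in (1, 2, 3):
    W, H = nmf(X, k)
    errors[k] = np.linalg.norm(X - W @ H) ** 2
    print(f"k={k}: squared reconstruction error = {errors[k]:.4f}")

# Interpret the k=2 fit: columns of W are metagenes, columns of H are cells.
W, H = nmf(X, k=2)
print("dominant metagene of each gene:", W.argmax(axis=1))
print("dominant metagene of each cell:", H.argmax(axis=0))
```

Comparing the error across values of $k$ is one simple way to see the trade-off between accuracy and summarization. The order of the recovered metagenes is arbitrary, so metagene 0 in one run may correspond to metagene 1 in another.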

### 1B. Probabilistic graphical model
It is certainly possible for distant cells to have similar gene expression profiles, but this is quite rare. Most of the time, cells with similar gene expression profiles are located next to one another in a group/tissue. To encourage this spatial consistency, `spicemix` uses a probabilistic graphical model. For each field of view, the cells and the relationships between them are modeled as a set of nodes $\mathcal{V}$ and a set of edges $\mathcal{E}$, respectively, forming a graph structure $\mathcal{G} = (\mathcal{V}, \mathcal{E}).$
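
As a rough illustration of what $\mathcal{G} = (\mathcal{V}, \mathcal{E})$ looks like in practice, the sketch below builds an undirected k-nearest-neighbor graph from made-up 2D coordinates. The k-nearest-neighbor rule, the helper name `knn_edges`, and the coordinates are assumptions for illustration. `spicemix` loads its neighbor relationships from the dataset files, so this is not necessarily how its graph is constructed.

```python
import numpy as np

def knn_edges(coords, k=2):
    """Return the sorted undirected edges linking each cell to its k nearest cells."""
    diff = coords[:, None, :] - coords[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    np.fill_diagonal(dist, np.inf)  # a cell is not its own neighbor
    edges = set()
    for i in range(len(coords)):
        for j in np.argsort(dist[i])[:k]:
            # Store each pair once, smaller index first.
            edges.add((min(i, int(j)), max(i, int(j))))
    return sorted(edges)

# Two well-separated groups of three cells each (hypothetical coordinates).
coords = np.array([[0, 0], [1, 0], [0, 1],
                   [5, 5], [6, 5], [5, 6]], dtype=float)
edges = knn_edges(coords, k=2)
print(edges)
# → [(0, 1), (0, 2), (1, 2), (3, 4), (3, 5), (4, 5)]
```

Each cell is a node, and every edge connects two cells that are spatial neighbors. No edge crosses between the two groups because they are far apart.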

Each node (cell) $i$ carries parameters describing both (1) the metagene composition of that cell, obtained through the NMF described in the previous section, and (2) the similarity between its metagene composition and that of its neighboring cells. The likelihood function combines these two components. Maximizing it yields the parameters that best explain the metagenes in each single cell while maintaining similarity among the metagenes of neighboring cells.
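
To give a feel for how the two components can be combined, here is a heavily simplified sketch of an energy (a negative log-likelihood up to constants) with a reconstruction term and a pairwise neighbor term. This is **not** the likelihood implemented in `spicemix`. The function `toy_energy`, the squared-error form, the bilinear pairwise term, the sign convention, and all the numbers are assumptions chosen for illustration. The names `M` and `Sigma_x_inv` are borrowed from the scripts in part 2 only to suggest the analogy.

```python
import numpy as np

def toy_energy(Y, M, X, edges, Sigma_x_inv, sigma2=1.0):
    """Simplified energy: reconstruction error plus a pairwise term over edges.

    Y: genes x cells, M: genes x K, X: K x cells, edges: list of (i, j) cell pairs.
    Lower energy means a better configuration in this toy setting.
    """
    recon = ((Y - M @ X) ** 2).sum() / (2.0 * sigma2)
    pair = sum(X[:, i] @ Sigma_x_inv @ X[:, j] for i, j in edges)
    return recon + pair

M = np.array([[4.0, 0.0], [0.0, 3.0], [1.0, 1.0]])   # 3 genes x 2 metagenes
edges = [(0, 1), (1, 2)]                              # 3 cells in a row
# Assumed convention: negative entries reward neighbors sharing a metagene.
Sigma_x_inv = np.array([[-1.0, 1.0], [1.0, -1.0]])

X_smooth = np.array([[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]])       # same metagene everywhere
X_alternating = np.array([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])  # neighbors disagree

# Each configuration reconstructs its own data exactly, so only the pairwise term differs.
energy_smooth = toy_energy(M @ X_smooth, M, X_smooth, edges, Sigma_x_inv)
energy_alternating = toy_energy(M @ X_alternating, M, X_alternating, edges, Sigma_x_inv)
print("smooth:", energy_smooth)
print("alternating:", energy_alternating)
```

In this toy setting the spatially smooth configuration gets the lower energy, which is the behavior the neighbor term is meant to encourage.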

## 2. Scripts for test running `spicemix` on Server 2 

In [1]:
# Go to the right directory 
import os 
os.chdir(r"F:\Tee\spicemix\SpiceMix\SpiceMix")

In [None]:
# Import the necessary modules 
import time, os, sys, pickle, h5py, importlib, gc, copy, re, itertools, json, logging
os.environ['OMP_NUM_THREADS'] = '8'
os.environ['MKL_NUM_THREADS'] = '8'
os.environ['NUMEXPR_NUM_THREADS'] = '8'
from tqdm.auto import tqdm, trange
from pathlib import Path
from util import config_logger

import numpy as np, pandas as pd, scipy
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import silhouette_score, adjusted_rand_score

from umap import UMAP

import torch
torch.set_num_threads(16)

from matplotlib import pyplot as plt
import seaborn as sns
sns.set_style("white")
%matplotlib inline
%config InlineBackend.figure_format='retina'

logger = config_logger(logging.getLogger(__name__))

In [None]:
# Work around the removal of the deprecated NumPy aliases (np.float, np.int,
# np.object, np.bool), which older code may still reference

np.float = float    # avoids: module 'numpy' has no attribute 'float'
np.int = int    # avoids: module 'numpy' has no attribute 'int'
np.object = object    # avoids: module 'numpy' has no attribute 'object'
np.bool = bool    # avoids: module 'numpy' has no attribute 'bool'

In [None]:
# specify GPU device
context = dict(device='cuda:0', dtype=torch.float64)

# specify dataset
path2dataset = Path(r"F:\Tee\spicemix\SpiceMix\data\synthetic_cortex")
repli_list = [str(i) for i in range(8)]

For more information about hyperparameters for `spicemix` please check this [link](https://github.com/ma-compbio/SpiceMix?tab=readme-ov-file#step-3-inferring-latent-states-metagenes-and-pairwise-affinity-matrix).

In [None]:
from model import SpiceMix
from helper import evaluate_embedding_maynard2021 # This function is for the optional on-the-fly evaluation. This is not required for SpiceMix.
fn_eval = evaluate_embedding_maynard2021

np.random.seed(0)

K, num_pcs, n_neighbors, res_lo, res_hi = 10, 50, 20, .5, 2.

path2result = path2dataset / 'results' / 'SpiceMix_tutorial_3.h5'
os.makedirs(path2result.parent, exist_ok=True)
if os.path.exists(path2result):
    os.remove(path2result)

# The constructor takes the hyperparameters - it is good to understand what each argument does.
obj = SpiceMix(
    K=K,  # number of metagenes, defined above
    lambda_Sigma_x_inv=1e-6, power_Sigma_x_inv=2,
    repli_list=repli_list,
    context=context,
    context_Y=context,
    path2result=path2result,
)
obj.load_dataset(path2dataset, expression_suffix='', neighbor_suffix='')
obj.meta['cell type'] = pd.Categorical(obj.meta['cell type'])
# --
obj.initialize(
    method='louvain', kwargs=dict(num_pcs=num_pcs, n_neighbors=n_neighbors, resolution_boundaries=(res_lo, res_hi), num_rs=10),
)
for iiter in range(10):
    obj.estimate_weights(iiter=iiter, use_spatial=[False]*obj.num_repli)
    obj.estimate_parameters(iiter=iiter, use_spatial=[False]*obj.num_repli)
obj.initialize_Sigma_x_inv()
for iiter in range(1, 201):
    logger.info(f'Iteration {iiter}')
    obj.estimate_parameters(iiter=iiter, use_spatial=[True]*obj.num_repli)
    obj.estimate_weights(iiter=iiter, use_spatial=[True]*obj.num_repli)
    if iiter % 50 == 0: # Optional
        # We evaluate the learned latent embeddings every 50 iterations
        evaluate_embedding_maynard2021(obj)

The output file is in HDF5 format, which can be indexed much like nested Python dictionaries. For more details on the output, please check out this [GitHub repository](https://github.com/ma-compbio/SpiceMix?tab=readme-ov-file#step-4-locating-results).

In [None]:
# Load the result file (h5py was imported above); `f` is used in the cells below
result_file = Path(r"F:\Tee\spicemix\SpiceMix\data\synthetic_cortex\results\SpiceMix_tutorial_3.h5")
f = h5py.File(result_file, 'r')

In [None]:
# Try plotting data 

# 1. We need to read the coordinate and cell type files 
coordinate_file = path2dataset / 'files' / 'coordinates_0.txt' 
cell_type_file = path2dataset / 'files' / 'celltypes_0.txt'

# 2. load these files in numpy 
coordinate_data = np.genfromtxt(coordinate_file, delimiter=' ', dtype='float')
cell_type_data = np.genfromtxt(cell_type_file, dtype='str')

In [None]:
fig, ax = plt.subplots()

unique_labels = np.unique(cell_type_data)
colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_labels)))
colormap = dict(zip(unique_labels, colors))

for label in unique_labels:
    indices = cell_type_data == label
    ax.scatter(coordinate_data[indices, 0], coordinate_data[indices, 1],
               label=label, color=colormap[label])

plt.xlabel('X coordinate')
plt.ylabel('Y coordinate')
plt.title('Scatter Plot with Categorical Labels')
plt.legend(title='Labels', bbox_to_anchor=(1.05, 1))
plt.show()

The latent state matrix is equivalent to the matrix $H$ I described above. Note that it is stored with cells as rows and metagenes as columns, i.e., as $H^\top$. It tells us the "expression level" of these metagenes in all single cells.

In [None]:
# Latent states correspond to matrix H (stored transposed: cells x metagenes)

latent_state = f['latent_states/XT/0/50'][:]
mins = np.min(latent_state, axis=0)
maxs = np.max(latent_state, axis=0)

# Normalize each column (metagene) across cells so that its values lie between 0 and 1.
latent_state_normalized = (latent_state - mins)/(maxs - mins)
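
One caveat of the min-max scaling above: if a column is constant, `maxs - mins` is zero and the division produces NaN values. A minimal guarded version is sketched below. The helper name `min_max_columns` and the toy array are assumptions, and mapping constant columns to zero is just one reasonable choice.

```python
import numpy as np

def min_max_columns(a):
    """Scale each column of `a` to [0, 1]; constant columns become all zeros."""
    mins = a.min(axis=0)
    span = a.max(axis=0) - mins
    # Replace zero spans with 1 so the division is always defined.
    safe_span = np.where(span > 0, span, 1.0)
    return (a - mins) / safe_span

toy = np.array([[0.0, 2.0, 3.0],
                [5.0, 2.0, 1.0],
                [10.0, 2.0, 2.0]])
scaled = min_max_columns(toy)
print(scaled)
# → [[0.  0.  1. ]
#    [0.5 0.  0. ]
#    [1.  0.  0.5]]
```

The same helper could be applied to `latent_state` in place of the manual computation if any metagene turns out to be constant across cells.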

In [17]:
# Plot the location of these metagenes 
from matplotlib.colors import Normalize

for metagene_id in range(10):
    fig, ax = plt.subplots()
    
    color_variable = latent_state_normalized[:, metagene_id]
    norm = Normalize(vmin=0, vmax=1)
    
    scatter = plt.scatter(coordinate_data[:, 0], coordinate_data[:, 1],
                c=color_variable, cmap='Reds', norm=norm)
    
    cbar = plt.colorbar(scatter)
    cbar.set_label('Normalized Level')

    plt.title('Metagene {}'.format(metagene_id))
    plt.xlabel('X coordinate')
    plt.ylabel('Y coordinate')
    plt.show()

The parameter matrix from `spicemix` is equivalent to the matrix $W$ I described above: it contains the gene composition of each metagene.

In [None]:
parameters = f['parameters/M/50'][:]
mins = np.min(parameters, axis=0)
maxs = np.max(parameters, axis=0)

# Normalize each column such that it is between 0 and 1
parameters_normalized = (parameters - mins)/(maxs - mins)
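
A quick way to read a matrix like this is to list the highest-weight genes of each metagene. The sketch below does that on a made-up matrix. It assumes rows are genes and columns are metagenes (check `parameters.shape` before applying it to the real output), and the helper name `top_genes_per_metagene` and the gene names are hypothetical.

```python
import numpy as np

def top_genes_per_metagene(M, gene_names, n_top=2):
    """Map each metagene (column of M) to its n_top highest-weight genes."""
    top = {}
    for k in range(M.shape[1]):
        order = np.argsort(M[:, k])[::-1][:n_top]  # indices sorted by descending weight
        top[k] = [gene_names[g] for g in order]
    return top

# Toy matrix (made up): 5 genes x 2 metagenes.
toy_M = np.array([[0.9, 0.1],
                  [0.7, 0.0],
                  [0.1, 0.8],
                  [0.0, 0.6],
                  [0.2, 0.3]])
toy_genes = ["geneA", "geneB", "geneC", "geneD", "geneE"]

top = top_genes_per_metagene(toy_M, toy_genes)
print(top)
# → {0: ['geneA', 'geneB'], 1: ['geneC', 'geneD']}
```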

In [None]:
# For fun: perform hierarchical clustering to see the normalized weight of each gene in each metagene

sns.clustermap(parameters_normalized, cmap="Blues")

In [None]:
# For fun #2: z-score each row (z_score=0) to look at differential gene expression across metagenes

sns.clustermap(parameters_normalized, cmap="vlag", center=0, col_cluster=True, z_score=0)
