### Preamble

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import warnings
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    import scanpy as sc


In [None]:
# Local paths: change these to match your setup
DATASET_DIR = "/scratch1/rsingh/work/schema/data/tasic-nature"
import sys
# Make fast_tsne and schema importable from their local checkouts
sys.path.extend(['/scratch1/rsingh/tools', '/afs/csail.mit.edu/u/r/rsingh/work/schema/'])


#### Import Schema and t-SNE
We use fast-tsne here, but any t-SNE implementation will work.

In [None]:
from fast_tsne import fast_tsne
from schema import SchemaQP

### Get example data
  * This data is from Tasic et al. (Nature 2018, DOI: 10.1038/s41586-018-0654-5)
  * Shell commands to get our copy of the data:
    * wget http://schema.csail.mit.edu/datasets/Schema_demo_Tasic2018.h5ad.gz
    * gunzip Schema_demo_Tasic2018.h5ad.gz
  * The raw data were processed broadly following the steps in Kobak & Berens, https://www.biorxiv.org/content/10.1101/453449v1
  * The gene expression data has been count-normalized and log-transformed. 
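The notebook describes the demo file only as "count-normalized and log-transformed". As an illustration, here is a minimal numpy sketch of one common variant, assuming counts-per-million scaling followed by log2(x + 1). The function name `cpm_log_normalize` and these constants are our assumptions; they are not necessarily what was applied to this file.

```python
import numpy as np

def cpm_log_normalize(counts: np.ndarray, target_sum: float = 1e6) -> np.ndarray:
    """Scale each cell (row) to `target_sum` total counts, then apply log2(x + 1).

    Assumption: CPM scaling plus a log2 transform; the demo file's exact
    preprocessing constants may differ.
    """
    counts = np.asarray(counts, dtype=float)
    totals = counts.sum(axis=1, keepdims=True)
    totals[totals == 0] = 1.0  # avoid dividing by zero for cells with no counts
    return np.log2(counts / totals * target_sum + 1.0)

raw = np.array([[10, 0, 90],
                [1, 1, 2]])      # toy cells x genes count matrix
X = cpm_log_normalize(raw)
print(X.round(2))
```

Undoing the log (`2 ** X - 1`) recovers rows that each sum to the target total, which is a quick sanity check on any normalized matrix of this kind.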


In [None]:
adata = sc.read(DATASET_DIR + "/" + "Schema_demo_Tasic2018.h5ad")

### Schema examples
  * In all the examples that follow, the primary dataset is gene expression. The secondary datasets are 1) cluster IDs and 2) cell-type "class" labels, which correspond to superclusters (i.e., higher-level clusters) in the Tasic et al. paper.
#### Recommendations for parameter settings
  * min_desired_corr and w_max_to_avg are the names for the hyperparameters $s_1$ and $\bar{w}$, respectively, from our paper.
  * *min_desired_corr*: start by trying a wide range of values (e.g., 0.99, 0.90, 0.50) to get a sense of what works for your data, then progressively narrow the range. In typical use cases, high values (> 0.80) work best.
  * *w_max_to_avg*: start by keeping this constraint very loose, so that min_desired_corr remains the binding constraint. Once you have a better sense of good min_desired_corr values, you can experiment with this one too. A value of 100 is loose enough to be a good starting point.
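To build intuition for the two hyperparameters, the sketch below computes the quantities we understand them to constrain. It assumes (our reading, not Schema's code) that Schema rescales features such as PCA or NMF components by nonnegative weights `w`, that min_desired_corr lower-bounds the Pearson correlation between pairwise squared distances before and after rescaling, and that w_max_to_avg upper-bounds `max(w) / mean(w)`. The helper names are hypothetical and the data is synthetic.

```python
import numpy as np

def pairwise_sq_dists(X: np.ndarray) -> np.ndarray:
    """Squared Euclidean distances over all pairs i < j (condensed form)."""
    X = np.asarray(X, dtype=float)
    sq = (X ** 2).sum(axis=1)
    D = sq[:, None] + sq[None, :] - 2.0 * X @ X.T
    iu = np.triu_indices(X.shape[0], k=1)
    return np.maximum(D[iu], 0.0)  # clip tiny negatives caused by round-off

def distance_corr_after_scaling(X: np.ndarray, w: np.ndarray) -> float:
    """Hypothetical: corr. of pairwise squared distances before vs. after
    scaling feature k by sqrt(w[k]), i.e. after w-weighting squared distances."""
    before = pairwise_sq_dists(X)
    after = pairwise_sq_dists(X * np.sqrt(w))
    return float(np.corrcoef(before, after)[0, 1])

def max_to_avg(w: np.ndarray) -> float:
    """Hypothetical: ratio of the largest feature weight to the mean weight."""
    w = np.asarray(w, dtype=float)
    return float(w.max() / w.mean())

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 10))            # stand-in for 10 PCA/NMF components
w_uniform = np.ones(10)                   # leaves all distances unchanged
w_skewed = np.array([5.0] + [1.0] * 9)    # up-weights a single component

print(distance_corr_after_scaling(X, w_uniform))   # ~1.0
print(distance_corr_after_scaling(X, w_skewed), max_to_avg(w_skewed))
```

Under this reading, a high min_desired_corr only admits weightings whose distance correlation stays near 1, and a loose w_max_to_avg (such as 100) rarely blocks a weighting like `w_skewed`.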


#### With PCA as change-of-basis, min_desired_corr=0.75, positive correlation with secondary datasets

In [None]:
afx = SchemaQP(0.75) # min_desired_corr is the only required argument.

dx_pca = afx.fit_transform(adata.X, # primary dataset
                           [adata.obs["class"].values], # one secondary dataset
                           ['categorical'] # it has labels, i.e., it is a categorical datatype
                          )
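The `'categorical'` datatype tells Schema that the secondary dataset holds labels. One way to picture a "positive correlation" with labels is to compare each pair of cells' expression distance with whether their labels differ. The numpy sketch below uses hypothetical helpers and synthetic data and is not Schema's code. It shows that up-weighting a label-informative feature raises this correlation, which is the direction we assume a positive weight pushes the transform.

```python
import numpy as np

def pairwise_sq_dists(X):
    """Squared Euclidean distances over all pairs i < j (condensed form)."""
    X = np.asarray(X, dtype=float)
    sq = (X ** 2).sum(axis=1)
    D = sq[:, None] + sq[None, :] - 2.0 * X @ X.T
    iu = np.triu_indices(X.shape[0], k=1)
    return np.maximum(D[iu], 0.0)

def label_disagreement(labels):
    """1.0 where a pair of cells has different labels, 0.0 otherwise."""
    labels = np.asarray(labels)
    iu = np.triu_indices(len(labels), k=1)
    return (labels[:, None] != labels[None, :])[iu].astype(float)

def agreement_with_labels(X, labels):
    """Hypothetical score: corr(pairwise distances, label disagreement)."""
    return float(np.corrcoef(pairwise_sq_dists(X), label_disagreement(labels))[0, 1])

rng = np.random.default_rng(1)
labels = rng.integers(0, 3, size=150)     # toy "class" labels
X = rng.normal(size=(150, 5))
X[:, 0] += 3.0 * labels                   # feature 0 carries the class signal

w_flat = np.ones(5)
w_focus = np.array([4.0, 1, 1, 1, 1])     # emphasize the informative feature
flat_score = agreement_with_labels(X * np.sqrt(w_flat), labels)
focus_score = agreement_with_labels(X * np.sqrt(w_focus), labels)
print(flat_score, focus_score)            # focus_score is the larger one
```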

#### Similar to above, with NMF as change-of-basis and a different min_desired_corr

In [None]:
afx = SchemaQP(0.6, params= {"decomposition_model": "nmf", "num_top_components": 50})

dx_nmf = afx.fit_transform(adata.X,
                           [adata.obs["class"].values, adata.obs.cluster_id.values], # two secondary datasets
                           ['categorical', 'categorical'], # both are labels
                           [10, 1] # relative weights
                          )
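With several secondary datasets, the last argument sets their relative weights. Assuming these act as linear coefficients on each secondary dataset's correlation term (our reading, not confirmed in this notebook), `[10, 1]` makes agreement with the class labels count ten times as much as agreement with cluster IDs. Below is a toy sketch of such a weighted objective; `weighted_objective` and the synthetic labels are hypothetical.

```python
import numpy as np

def pairwise_sq_dists(X):
    """Squared Euclidean distances over all pairs i < j (condensed form)."""
    X = np.asarray(X, dtype=float)
    sq = (X ** 2).sum(axis=1)
    D = sq[:, None] + sq[None, :] - 2.0 * X @ X.T
    iu = np.triu_indices(X.shape[0], k=1)
    return np.maximum(D[iu], 0.0)

def label_disagreement(labels):
    """1.0 where a pair of cells has different labels, 0.0 otherwise."""
    labels = np.asarray(labels)
    iu = np.triu_indices(len(labels), k=1)
    return (labels[:, None] != labels[None, :])[iu].astype(float)

def weighted_objective(X, secondaries, wts):
    """Hypothetical: sum_j wts[j] * corr(primary distances, disagreement_j)."""
    d = pairwise_sq_dists(X)
    return float(sum(wt * np.corrcoef(d, label_disagreement(lab))[0, 1]
                     for lab, wt in zip(secondaries, wts)))

rng = np.random.default_rng(2)
cluster_id = rng.integers(0, 8, size=200)   # fine-grained toy clusters
cls = cluster_id // 4                       # coarse toy "class": 2 superclusters
X = rng.normal(size=(200, 4))
X[:, 0] += 3.0 * cls                        # class signal
X[:, 1] += 1.0 * cluster_id                 # cluster signal

obj = weighted_objective(X, [cls, cluster_id], [10, 1])
c_cls = weighted_objective(X, [cls], [1])
c_clu = weighted_objective(X, [cluster_id], [1])
print(obj, 10 * c_cls + c_clu)              # the terms combine linearly
```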

#### Now let's do something unusual: perturb the data so that it *disagrees* with the cluster IDs

In [None]:
afx = SchemaQP(0.97, # Notice that we bumped up the min_desired_corr so the perturbation is limited 
               params = {"decomposition_model": "nmf", "num_top_components": 50})

dx_perturb = afx.fit_transform(adata.X,
                               [adata.obs.cluster_id.values], # could have used both secondary datasets, but one is enough here
                               ['categorical'],
                               [-1] # This is key: we put a negative weight on the correlation
                              )
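Under the same reading as above, a negative weight flips the sign of that correlation term, so the optimizer favors transforms whose pairwise distances agree *less* with the cluster IDs, while the high min_desired_corr keeps the overall distance structure close to the original. The toy sketch below (hypothetical helpers, synthetic data, not Schema's code) shows the preferred feature weighting reversing between weights +1 and -1.

```python
import numpy as np

def pairwise_sq_dists(X):
    """Squared Euclidean distances over all pairs i < j (condensed form)."""
    X = np.asarray(X, dtype=float)
    sq = (X ** 2).sum(axis=1)
    D = sq[:, None] + sq[None, :] - 2.0 * X @ X.T
    iu = np.triu_indices(X.shape[0], k=1)
    return np.maximum(D[iu], 0.0)

def label_disagreement(labels):
    """1.0 where a pair of cells has different labels, 0.0 otherwise."""
    labels = np.asarray(labels)
    iu = np.triu_indices(len(labels), k=1)
    return (labels[:, None] != labels[None, :])[iu].astype(float)

def signed_agreement(X, labels, wt):
    """Hypothetical objective term: wt * corr(distances, label disagreement)."""
    r = np.corrcoef(pairwise_sq_dists(X), label_disagreement(labels))[0, 1]
    return wt * float(r)

rng = np.random.default_rng(3)
labels = rng.integers(0, 4, size=150)    # toy cluster IDs
X = rng.normal(size=(150, 5))
X[:, 0] += 2.0 * labels                  # feature 0 tracks the cluster IDs

w_up = np.array([4.0, 1, 1, 1, 1])       # emphasize the label-informative feature
w_down = np.array([0.25, 1, 1, 1, 1])    # de-emphasize it

preference = {}
for wt in (+1, -1):
    up = signed_agreement(X * np.sqrt(w_up), labels, wt)
    down = signed_agreement(X * np.sqrt(w_down), labels, wt)
    preference[wt] = "w_up" if up > down else "w_down"
print(preference)  # the preferred weighting flips with the sign of wt
```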

### tSNE plots of the baseline and Schema transforms 

In [None]:
import scipy.sparse as sp

fig = plt.figure(constrained_layout=True, figsize=(8,2), dpi=300)
tmps = {}  # cache the 2-D embeddings, keyed by panel title
for i, (titlestr, dx1) in enumerate([("Original", adata.X),
                                     ("PCA (pos corr)", dx_pca),
                                     ("NMF (pos corr)", dx_nmf),
                                     ("Perturb (neg corr)", dx_perturb)
                                    ]):
    if sp.issparse(dx1):  # fast_tsne works on dense arrays
        dx1 = dx1.toarray()
    ax = fig.add_subplot(1, 4, i+1, frameon=False)
    tmps[titlestr] = dy = fast_tsne(dx1, seed=42)
    ax.set_aspect('equal', adjustable='datalim')
    ax.scatter(dy[:,0], dy[:,1], s=1, color=adata.obs['cluster_color'].values)
    ax.set_title(titlestr)
    ax.axis("off")
plt.show()