### Discussion

Here, we discuss the intricacies of batch correction with regard to cell-cell communication (CCC) scoring and Tensor-cell2cell.

For use with Tensor-cell2cell, we want a dataset that represents multiple (two or more) contexts. Multiple contexts are typically derived from different samples and/or measurements, and thus introduce batch effects.

#### Replicates
Our dataset should contain [replicates](https://www.nature.com/articles/nmeth.3091). Replicates allow us to check that the output factors are not simply due to technical effects (e.g., a factor with high loadings for just one replicate in the context dimension). We will use a [BALF COVID dataset](https://doi.org/10.1038/s41591-020-0901-9), which contains 12 samples associated with "Healthy Control", "Moderate", or "Severe" COVID contexts. This dataset does not contain technical replicates, since each sample was taken from a different patient, but the samples associated with each context serve as biological replicates.


#### Batch correction
[Batch correction](https://www.nature.com/articles/s41592-018-0254-1) removes technical variation while preserving biological variation between samples. Once appropriate batch correction has removed technical variation, we can reasonably assume that the biological variation between contexts will be greater than the variation among samples within a context. Thus, we expect Tensor-cell2cell to capture the overall communication trends that differ between contexts. We can then verify that the output factors aren't simply due to technical effects by checking that they have similar loadings across biological replicates and do not have high loadings for just one sample in the context dimension.
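A minimal sketch of the "not just one sample" check, assuming the context-dimension loadings are available as a samples x factors array of non-negative values (e.g., the context factor matrix from the decomposition output). The function name and the 0.5 cutoff are hypothetical choices for illustration, not part of Tensor-cell2cell:

```python
import numpy as np

def flag_single_sample_factors(context_loadings, sample_names, max_share=0.5):
    """Flag factors whose context loadings are dominated by a single sample.

    context_loadings: array of shape (n_samples, n_factors), non-negative.
    max_share: hypothetical cutoff on the fraction of a factor's total
        context loading that a single sample may carry.
    Returns {factor_index: dominant_sample_name} for each flagged factor.
    """
    loadings = np.asarray(context_loadings, dtype=float)
    totals = loadings.sum(axis=0)
    # Fraction of each factor's total loading contributed by each sample
    shares = np.divide(loadings, totals, out=np.zeros_like(loadings), where=totals > 0)
    flagged = {}
    for f in range(loadings.shape[1]):
        top = int(np.argmax(shares[:, f]))
        if shares[top, f] > max_share:
            flagged[f] = sample_names[top]
    return flagged

# Made-up loadings: factor 0 is shared across samples, factor 1 is driven by S3 alone
loadings = np.array([[0.5, 0.05],
                     [0.6, 0.05],
                     [0.1, 0.90],
                     [0.1, 0.00]])
print(flag_single_sample_factors(loadings, ["S1", "S2", "S3", "S4"]))
```

A flagged factor is not necessarily technical, but it deserves a closer look, for instance by checking whether the other replicates of the same context carry comparable loadings.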

#### Benchmarking Batch Effects

Using simulated data and metrics of batch severity (kBET and NMI), we saw that Tensor-cell2cell is robust to batch effects approaching **XXX** severity (see benchmarking/batch_correction for details). Applying these metrics to your own dataset should help you determine whether batch correction is necessary prior to running Tensor-cell2cell.

***add png here**
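As one way to apply such a metric to your own data, NMI can be computed between batch labels (e.g., `Sample_ID`) and cluster labels: values near 1 suggest clusters track batches (a strong batch effect), while values near 0 suggest batches are well mixed. The sketch below is a plain NumPy NMI using arithmetic-mean normalization (scikit-learn's default); it is an illustration and may not match exactly how the benchmarking defines the metric. kBET is not reimplemented here.

```python
import numpy as np

def nmi(labels_a, labels_b):
    """Normalized mutual information, normalized by the mean of the two entropies."""
    a_vals, a_idx = np.unique(labels_a, return_inverse=True)
    b_vals, b_idx = np.unique(labels_b, return_inverse=True)
    n = len(a_idx)
    # Contingency table of joint label counts
    cont = np.zeros((len(a_vals), len(b_vals)))
    np.add.at(cont, (a_idx, b_idx), 1)
    p_ab = cont / n
    p_a = p_ab.sum(axis=1)
    p_b = p_ab.sum(axis=0)
    nz = p_ab > 0
    mi = np.sum(p_ab[nz] * np.log(p_ab[nz] / np.outer(p_a, p_b)[nz]))
    h_a = -np.sum(p_a * np.log(p_a))
    h_b = -np.sum(p_b * np.log(p_b))
    denom = (h_a + h_b) / 2
    # Degenerate case (both labelings trivial): report 0 rather than divide by zero
    return 0.0 if denom == 0 else mi / denom

batch = [0, 0, 0, 1, 1, 1]
clusters_by_batch = [5, 5, 5, 7, 7, 7]   # clusters coincide with batches
print(nmi(batch, clusters_by_batch))     # strong batch effect
print(nmi([0, 0, 1, 1], [0, 1, 0, 1]))   # batches evenly mixed across clusters
```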

#### Introduction of Negative Counts

Since CCC inference uses gene expression values to infer communication, a prerequisite for choosing a batch correction method is that it returns a corrected counts matrix rather than a latent or reduced-space representation (see [Table 1](https://doi.org/10.1093/nargab/lqac022) for examples).

Second, most batch correction methods that return a counts matrix introduce negative counts. Below, we show a simple example using a batch correction method that 1) returns a corrected counts matrix and 2) returns only non-negative counts. First, we discuss in more detail the problems that negative counts pose for CCC.

* Problem 1: Negative expression values can distort scoring functions that multiply ligand and receptor expression. For example, if both a ligand and its receptor have negative counts, their product is a positive communication score that may be interpreted as a strong interaction.
* Problem 2: Negative expression values can yield negative communication scores, which the non-negative tensor decomposition algorithm used by Tensor-cell2cell disregards in its optimization. 
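Both problems can be seen with a toy example. The values below are made up to represent batch-corrected expression of one ligand-receptor pair in three cell-type pairs; the product and the mean stand in for generic multiplicative and additive scoring functions, not any specific tool's implementation:

```python
import numpy as np

# Hypothetical batch-corrected expression for one ligand-receptor pair
ligand = np.array([-1.5, 0.2, 3.0])
receptor = np.array([-2.0, 0.1, 2.5])

# Multiplicative score: two negatives give a positive score (Problem 1),
# ranking the first pair well above the second
product_score = ligand * receptor        # approx. [3.0, 0.02, 7.5]

# Additive (mean) score: the first pair stays low, but the score is
# negative (Problem 2)
mean_score = (ligand + receptor) / 2     # approx. [-1.75, 0.15, 2.75]

print(product_score)
print(mean_score)
```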

Regardless, we show that Tensor-cell2cell can robustly identify communication patterns even in the presence of negative counts introduced during batch correction (see benchmarking/batch_correction for details). This is likely because negative counts and communication scores represent lower-strength interactions that have little influence on the overall communication. If your preferred batch correction method introduces negative counts, we recommend the following to address the problems above:
* Recommendation 1: Use scoring methods that combine ligand and receptor expression additively rather than multiplicatively.
* Recommendation 2a: If the scoring method cannot handle negative values, replace them with NaN. These genes are lowly expressed anyway, so disregarding their communication scores is acceptable.
* Recommendation 2b: If the scoring method can handle negative values, the final tensor may contain negative values. Use a mask so that Tensor-cell2cell disregards these values when running the decomposition. Assuming an additive scoring function was used, these are lower-strength communication scores anyway, so disregarding them is acceptable.
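Recommendations 2a and 2b can be sketched with NumPy alone. The helper names are hypothetical; the mask follows the convention of 1 = use the entry, 0 = ignore it, and masked entries are also zeroed so the tensor itself is non-negative and NaN-free:

```python
import numpy as np

def negatives_to_nan(expr):
    """Recommendation 2a: set negative batch-corrected values to NaN."""
    expr = np.array(expr, dtype=float)  # copy so the input is not modified
    expr[expr < 0] = np.nan
    return expr

def mask_negative_scores(tensor):
    """Recommendation 2b: mask negative (or NaN) communication scores.

    Returns (cleaned_tensor, mask), where mask is 1.0 for entries to use
    and 0.0 for entries the decomposition should ignore.
    """
    tensor = np.asarray(tensor, dtype=float)
    keep = ~np.isnan(tensor) & (np.nan_to_num(tensor, nan=0.0) >= 0)
    # Ignored entries are set to 0; the mask, not the 0, tells the
    # decomposition not to fit them
    cleaned = np.where(keep, tensor, 0.0)
    return cleaned, keep.astype(float)

expr = np.array([[1.0, -0.5], [0.0, 2.0]])
print(negatives_to_nan(expr))

scores = np.array([[[0.5, -0.2], [np.nan, 1.0]]])  # toy 1 x 2 x 2 tensor
cleaned, mask = mask_negative_scores(scores)
print(mask)
```

To our understanding, tensorly's decompositions, which Tensor-cell2cell builds on, use this same mask convention; check the cell2cell documentation for where to pass the mask when building the tensor.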



### Application 

In [1]:
import os
import pickle

import scanpy as sc

seed = 888

data_path = '/data3/hratch/ccc_protocols/'

First, let's load our normalized expression data from Tutorial 1:

In [2]:
with open(os.path.join(data_path, 'interim/covid_balf_norm.pickle'), 'rb') as handle:
    balf_samples = pickle.load(handle)

## Simple Example: Batch corrected counts with only non-negative values

[scVI](https://doi.org/10.1038/s41592-018-0229-2) implements a batch correction method that can return non-negative corrected counts, and it has been [benchmarked](https://doi.org/10.1038/s41592-021-01336-8) to perform well.

We format our dictionary of samples by [concatenating](https://anndata.readthedocs.io/en/latest/concatenation.html) it into a single AnnData object:

In [3]:
import anndata as ad
import pandas as pd
import numpy as np

import warnings
warnings.simplefilter('ignore')
import scvi

Global seed set to 0


In [4]:
balf_combined = [bs.raw.to_adata() for bs in list(balf_samples.values())] # raw counts used in scVI
balf_combined = balf_combined[0].concatenate(balf_combined[1:], join = 'outer')

Since CCC inference tools only consider a subset of genes (those present in ligand-receptor databases), we do not filter for highly variable genes, as this would exclude too many ligands and receptors and decrease the power of communication inference.

However, if scVI runtime is a concern, we can take the following optional step prior to batch correction: keep only the genes present in the LR database that you will use for communication scoring. Here, we use the [CellChat](https://doi.org/10.1038/s41467-021-21246-9) database as an example.

**ToDO**: can we change this to be from LIANA directly?

In [5]:
# optional step: 

# get the CellChatDB
hl = 'https://raw.githubusercontent.com/LewisLabUCSD/Ligand-Receptor-Pairs/master/Human/Human-2020-Jin-LR-pairs.csv'
lr_pairs = pd.read_csv(hl)

# separate complexes and join LRs
receptors = lr_pairs.receptor_symbol.apply(lambda x: x.split('&')).tolist() 
receptors = [item for sublist in receptors for item in sublist]
ligands = lr_pairs.ligand_symbol.apply(lambda x: x.split('&')).tolist() 
ligands = [item for sublist in ligands for item in sublist]
lrs = set(ligands + receptors)

# subset to present lrs                            
balf_combined = balf_combined[:, sorted(lrs.intersection(balf_combined.var.index.tolist()))]

In [6]:
# # name barcodes same as R to have cell x gene matrix be in the same order

# barcode_id = pd.Series(balf_combined.obs.index).apply(lambda x: '-'.join(x.split('-')[:-1]))
# barcode_id = balf_combined.obs.Sample_ID.reset_index(drop = True).str.cat(barcode_id, sep = '-')
# md = balf_combined.obs.copy()
# md['R_barcode_id'] = barcode_id.tolist()
# # barcode_order = md.sort_values(by = 'R_barcode_id').index.tolist()
# # balf_combined = balf_combined[barcode_order, :]

In [7]:
balf_combined.to_df().T.head()

Unnamed: 0,AAACCCACAGCTACAT-1-0,AAACCCATCCACGGGT-1-0,AAACCCATCCCATTCG-1-0,AAACGAACAAACAGGC-1-0,AAACGAAGTCGCACAC-1-0,AAACGAAGTCTATGAC-1-0,AAACGAAGTGTAGTGG-1-0,AAACGCTGTCACGTGC-1-0,AAACGCTGTTGGAGGT-1-0,AAAGAACTCTAGAACC-1-0,...,TTTGTCAGTGTCAATC-1-11,TTTGTCAGTGTGAAAT-1-11,TTTGTCATCAGTTAGC-1-11,TTTGTCATCCAGTATG-1-11,TTTGTCATCCCTAATT-1-11,TTTGTCATCGATAGAA-1-11,TTTGTCATCGGAAATA-1-11,TTTGTCATCGGTCCGA-1-11,TTTGTCATCTCACATT-1-11,TTTGTCATCTCCAACC-1-11
ACKR2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
ACKR3,0.0,0.0,0.0,5.0,0.0,0.0,0.0,0.0,0.0,6.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
ACKR4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
ACVR1,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,2.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
ACVR1B,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


Next, we can run scVI according to the [tutorial](https://docs.scvi-tools.org/en/stable/tutorials/notebooks/harmonization.html):

In [8]:
balf_combined = balf_combined.copy()
scvi.model.SCVI.setup_anndata(balf_combined, batch_key="Sample_ID")
model = scvi.model.SCVI(balf_combined, n_layers=2, n_latent=30, gene_likelihood="nb")
model.train()

with open(os.path.join(data_path, 'interim/balf_scvi_model.pickle'), 'wb') as handle:
    pickle.dump(model, handle, protocol=pickle.HIGHEST_PROTOCOL)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Epoch 127/127: 100%|██████████████████████████████| 127/127 [9:01:32<00:00, 264.96s/it, loss=169, v_num=1]

`Trainer.fit` stopped: `max_epochs=127` reached.




scVI's batch-corrected output has the added benefit of being depth-normalized: with `library_size = 1e6`, each cell's values are scaled like counts per million (CPM). Transforming this with log1p therefore puts it on a scale similar to log(1+CPM).
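For reference, a minimal sketch of the log(1+CPM) normalization being mimicked, computed from a cells x genes matrix of raw counts (the function name is hypothetical). One caveat, to our understanding of scVI: its normalized expression sums to `library_size` over the genes in the model, so if the optional LR-gene subsetting above was applied, the "per million" is relative to those genes only rather than the full transcriptome.

```python
import numpy as np

def log1p_cpm(counts):
    """log(1 + CPM) for a cells x genes matrix of raw counts."""
    counts = np.asarray(counts, dtype=float)
    lib_size = counts.sum(axis=1, keepdims=True)  # total counts per cell
    # Scale each cell to one million total counts (cells with 0 counts stay 0)
    cpm = np.divide(counts, lib_size, out=np.zeros_like(counts), where=lib_size > 0) * 1e6
    return np.log1p(cpm)

print(log1p_cpm([[1, 3], [0, 2]]))
```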

In [9]:
# library size and log1p make it similar to log(1+CPM) normalization, but with batch correction
# batch corrected counts: https://discourse.scverse.org/t/how-to-extract-batch-corrected-expression-matrix-from-trained-scvi-vae-model/151

corrected_data = model.get_normalized_expression(transform_batch = sorted(balf_combined.obs.Sample_ID.unique()), 
                               library_size = 1e6) # depth normalization
corrected_data = np.log1p(corrected_data) # log1p transformation
corrected_data.to_csv(os.path.join(data_path, 'interim/py_scvi_corrected_counts.csv'))
corrected_data.T.head()

Unnamed: 0,AAACCCACAGCTACAT-1-0,AAACCCATCCACGGGT-1-0,AAACCCATCCCATTCG-1-0,AAACGAACAAACAGGC-1-0,AAACGAAGTCGCACAC-1-0,AAACGAAGTCTATGAC-1-0,AAACGAAGTGTAGTGG-1-0,AAACGCTGTCACGTGC-1-0,AAACGCTGTTGGAGGT-1-0,AAAGAACTCTAGAACC-1-0,...,TTTGTCAGTGTCAATC-1-11,TTTGTCAGTGTGAAAT-1-11,TTTGTCATCAGTTAGC-1-11,TTTGTCATCCAGTATG-1-11,TTTGTCATCCCTAATT-1-11,TTTGTCATCGATAGAA-1-11,TTTGTCATCGGAAATA-1-11,TTTGTCATCGGTCCGA-1-11,TTTGTCATCTCACATT-1-11,TTTGTCATCTCCAACC-1-11
ACKR2,1.053409,1.330357,1.654605,0.352915,0.332205,2.329128,0.492785,1.31685,0.809256,0.309724,...,1.0179,0.716844,1.552116,1.333449,0.981065,1.650308,1.829056,0.72436,1.179444,0.337692
ACKR3,6.21444,6.006417,3.825843,7.156668,5.369172,2.602306,4.721803,3.581478,6.35328,6.967628,...,4.615073,4.014686,4.81514,4.407703,4.71646,4.216978,3.977207,4.010178,5.503017,7.041326
ACKR4,1.551312,2.505382,1.524521,0.733223,1.685193,3.388128,1.803644,2.298422,1.831141,1.997048,...,1.840316,1.697146,3.236091,1.23253,2.152939,3.605989,1.743316,1.285501,1.243938,1.047431
ACVR1,5.998411,6.390562,6.126544,5.326526,5.551836,5.246313,5.709505,5.446943,5.670623,5.923387,...,5.286893,5.073652,6.160941,5.395304,5.538884,5.374123,5.565694,5.365861,5.745325,5.255951
ACVR1B,4.188776,6.716805,5.052667,5.592592,6.306942,4.974235,6.457505,4.683265,6.646193,5.999748,...,5.508474,5.077518,6.080957,6.293207,5.738174,5.569398,5.847088,5.473363,5.286615,5.654563


This corrected data matrix can replace the log(1+CPM) matrix used in Tutorial 02 onward for downstream analyses, if desired. Note that the outputs won't be identical to those of the companion R tutorial, due to stochastic steps in scVI.
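One possible way to store the corrected matrix alongside the original data is as an AnnData layer. Here, `corrected_data` already shares the cell and gene order of `balf_combined`, but reordering by label before assigning is a cheap safeguard. The helper below is a hypothetical sketch using pandas only:

```python
import pandas as pd

def aligned_layer(corrected, obs_names, var_names):
    """Return a cells x genes DataFrame as a NumPy array in the given obs/var order."""
    obs_names, var_names = pd.Index(obs_names), pd.Index(var_names)
    missing_cells = obs_names.difference(corrected.index)
    missing_genes = var_names.difference(corrected.columns)
    if len(missing_cells) or len(missing_genes):
        raise KeyError(
            f"{len(missing_cells)} cells and {len(missing_genes)} genes "
            "are missing from the corrected matrix"
        )
    # .loc reorders rows and columns by label, so positional order does not matter
    return corrected.loc[obs_names, var_names].to_numpy()

toy = pd.DataFrame([[1.0, 2.0], [3.0, 4.0]], index=["c1", "c2"], columns=["g1", "g2"])
print(aligned_layer(toy, ["c2", "c1"], ["g2", "g1"]))
```

With the objects above, this might be used as `balf_combined.layers['scvi_log1p_cpm'] = aligned_layer(corrected_data, balf_combined.obs_names, balf_combined.var_names)`, where the layer name is an arbitrary choice.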

**To do: should I show how to replace this in the actual anndata object?**

## Complex Example: Batch corrected counts containing negative values