For use with Tensor-cell2cell, we want a dataset that represents more than two contexts. We also want a dataset that contains [replicates](https://www.nature.com/articles/nmeth.3091). Replicates allow us to check that the output factors are not simply due to technical effects (e.g., a factor with high loadings for just one replicate in the context dimension). We will use a [BALF COVID dataset](https://doi.org/10.1038/s41591-020-0901-9), which contains 12 samples associated with "Healthy Control", "Moderate", or "Severe" COVID contexts. This dataset does not contain technical replicates, since each sample was taken from a different patient, but the samples within each context are biological replicates of one another.

[Batch correction](https://www.nature.com/articles/s41592-018-0254-1) removes technical variation while preserving biological variation between samples. After appropriate batch correction, we can reasonably assume that the biological variation between contexts will be greater than the variation among samples within a context. Thus, we expect Tensor-cell2cell to capture overall communication trends that differ between contexts. We can then check that the output factors are not simply due to technical effects by confirming that biological replicates have similar loadings and that no single sample has a disproportionately high loading in the context dimension.
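As a rough sketch of that replicate check, the function below takes one factor's loadings in the context dimension (one value per sample) together with each sample's context label, and reports what share of each context's total loading comes from its single largest sample. The function name `replicate_dominance` and the toy loadings are hypothetical; this is an illustrative diagnostic, not a Tensor-cell2cell function.

```python
import numpy as np


def replicate_dominance(loadings, contexts):
    """Share of each context's total loading carried by its single largest sample.

    loadings : 1-D sequence of non-negative loadings, one per sample
    contexts : sequence of context labels, aligned with `loadings`
    Returns {context: share}. A share near 1 / n_replicates means the replicates
    load evenly; a share near 1 means one sample dominates (possible technical effect).
    """
    loadings = np.asarray(loadings, dtype=float)
    contexts = np.asarray(contexts)
    shares = {}
    for ctx in np.unique(contexts):
        ctx_loadings = loadings[contexts == ctx]
        total = ctx_loadings.sum()
        # Guard against an all-zero context to avoid dividing by zero
        shares[str(ctx)] = float(ctx_loadings.max() / total) if total > 0 else 0.0
    return shares


# Toy loadings (made up): three replicates per context
toy_loadings = [0.30, 0.28, 0.32, 0.02, 0.01, 0.90]
toy_contexts = ["Healthy Control"] * 3 + ["Severe"] * 3
print(replicate_dominance(toy_loadings, toy_contexts))
# Healthy Control: about 0.36 (replicates load evenly); Severe: about 0.97 (one sample dominates)
```

A share close to 1 / (number of replicates) is what we hope to see; a share close to 1 suggests the factor may be tracking one sample rather than the context.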

Finally, we apply batch correction. The goal here is to account for sample-to-sample technical variability. Here, we show ComBat, since it is built into Scanpy.

Note that the final input matrices to Tensor-cell2cell must be non-negative. We will demonstrate workarounds for negative counts in the tensor-building tutorial.

See [10.1186/s13619-020-00041-9](https://doi.org/10.1186/s13619-020-00041-9) for a benchmark of Scanpy's batch correction methods.

In [None]:
batch_var = 'Sample_ID' # the batch variable in the metadata

Batch correction using ComBat:

In [None]:
# merge the balf_samples
balf_corrected = sc.concat(balf_samples.values())
balf_corrected.obs_names_make_unique()

# store log(1+CPM) values in "raw" attribute
balf_corrected.raw = balf_corrected

# do the batch correction (ComBat overwrites .X in place; corrected values can be negative)
sc.pp.combat(balf_corrected, key=batch_var)
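ComBat adjusts each gene's per-batch mean (location) and variance (scale) toward pooled values, shrinking the per-batch estimates with empirical Bayes. The numpy sketch below keeps only the location/scale step, with no empirical Bayes shrinkage and no biological covariates, so it is a simplification of the idea rather than what `sc.pp.combat` computes; `location_scale_adjust` is a hypothetical name. The second toy example shows how such an adjustment can map zeros to negative values.

```python
import numpy as np


def location_scale_adjust(X, batches, eps=1e-8):
    """Align each gene's per-batch mean and SD with its grand mean and pooled SD.

    X       : cells x genes array (e.g. log-normalized expression)
    batches : length-n_cells array of batch labels
    Simplified ComBat-like step: no empirical Bayes shrinkage, no covariates.
    """
    X = np.asarray(X, dtype=float)
    batches = np.asarray(batches)
    levels = np.unique(batches)
    grand_mean = X.mean(axis=0)

    # Residuals after removing each batch's own per-gene mean
    resid = np.empty_like(X)
    for b in levels:
        idx = batches == b
        resid[idx] = X[idx] - X[idx].mean(axis=0)
    pooled_sd = np.sqrt((resid ** 2).mean(axis=0))

    adjusted = np.empty_like(X)
    for b in levels:
        idx = batches == b
        batch_sd = X[idx].std(axis=0)
        # Standardize within the batch, then rescale to the pooled SD and shift to the grand mean
        adjusted[idx] = resid[idx] / (batch_sd + eps) * pooled_sd + grand_mean
    return adjusted


# Toy example 1: batch "b" is batch "a" shifted up by 2; after adjustment both batches match
toy_X = np.array([[0.0], [1.0], [2.0], [3.0]])
toy_batches = np.array(["a", "a", "b", "b"])
print(location_scale_adjust(toy_X, toy_batches).ravel())  # roughly [1. 2. 1. 2.]

# Toy example 2: zeros can be mapped below zero after adjustment
toy_X2 = np.array([[0.0], [0.0], [0.0], [4.0], [0.0], [0.2]])
toy_batches2 = np.array(["a"] * 4 + ["b"] * 2)
print(location_scale_adjust(toy_X2, toy_batches2).min())  # negative
```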

At some point in the pipeline, we must account for batch effects. Batch correction is important because Tensor-cell2cell compares multiple samples to extract context-dependent patterns, and we want to make sure we capture true biological signals rather than sample-specific differences due to technical variability.

Ideally, we would use single-cell RNA-seq batch correction methods. However, this approach has a few potential problems:

1) Batch correction methods often return a matrix in a reduced space and thus do not retain the original gene features, which are needed for LR scoring (see [Table 1](https://academic.oup.com/nargab/article/4/1/lqac022/6548822)).

2) Some cell-cell communication tools expect data in other formats, such as log(1+CPM).

3) Batch correction methods that do return gene-level values often return negative values, which can result in negative LR scores. Negative values in the tensor can bias non-negative TCD, the main algorithm used in Tensor-cell2cell.
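To make Problem 3 concrete, the toy sketch below scores one LR pair as the product of the mean ligand expression in the sender cells and the mean receptor expression in the receiver cells. This product score is one common choice and is assumed here for illustration (tools differ in their scoring functions); the expression values are made up.

```python
import numpy as np


def product_lr_score(ligand_expr, receptor_expr):
    """Product of mean ligand and mean receptor expression (one common LR score)."""
    return float(np.mean(ligand_expr) * np.mean(receptor_expr))


# Hypothetical batch-corrected values: ligand in sender cells, receptor in receiver cells
ligand_in_sender = np.array([0.8, 1.2, 1.0])
receptor_in_receiver = np.array([-0.4, -0.1, 0.2])  # correction pushed some values below zero

score = product_lr_score(ligand_in_sender, receptor_in_receiver)
print(score)  # about -0.1: negative because the receptor's mean expression is negative
```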

In this tutorial and its companion, 01B, for R users, we show pre-processing from raw counts to batch-corrected counts. Problem 1 can be dealt with simply by using only batch correction methods that return the original gene features. Problems 2 and 3 will be discussed further in Tutorials XXX. In short, both can be dealt with by instead introducing a technical covariate for batch directly into the decomposition. Problem 3 can also be dealt with either by masking negative values or by using a TCD approach without a non-negative constraint.
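A minimal sketch of the masking idea, assuming the communication tensor is available as a plain numpy array: negative entries are set to 0 so the array is non-negative, and a separate mask records which entries should be ignored (1 = keep, 0 = ignore). The helper `mask_negative_entries` and the 1/0 convention are assumptions for illustration; how a mask is passed to a decomposition depends on the tool and is not shown here.

```python
import numpy as np


def mask_negative_entries(tensor):
    """Return (filled, mask) for a tensor that may contain negative values.

    mask is 1.0 where an entry is kept and 0.0 where it was negative; negative
    entries are replaced by 0.0 in `filled` so the array is non-negative. A masked
    decomposition should ignore mask-0 entries rather than fit them as true zeros.
    """
    tensor = np.asarray(tensor, dtype=float)
    keep = tensor >= 0
    mask = keep.astype(float)
    filled = np.where(keep, tensor, 0.0)
    return filled, mask


# Toy 3-way tensor (a real communication tensor has more dimensions)
toy = np.array([[[0.5, -0.2], [1.0, 0.3]],
                [[-0.1, 0.0], [0.7, 0.4]]])
filled, mask = mask_negative_entries(toy)
print(int(mask.size - mask.sum()), "of", mask.size, "entries masked")  # 2 of 8 entries masked
```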

The next two cells, which are commented out and not run, show examples of other batch correction methods. See https://nbisweden.github.io/workshop-scRNAseq/labs/compiled/scanpy/scanpy_03_integration.html for more tutorials on batch correction.

Batch correction with Scanorama:

In [None]:
# import scanorama

# # merge all the balf_samples into a single object
# balf_log = sc.concat(balf_samples.values())
# balf_log.obs_names_make_unique()

# # correct with scanorama
# balf_corrected = scanorama.correct_scanpy(adatas=list(balf_samples.values()), return_dimred=False)

# # aggregate into one object
# balf_corrected = sc.concat(balf_corrected) 
# balf_corrected.obs_names_make_unique()

# # store log(1+CPM) values in "raw" attribute
# balf_corrected.raw = balf_log

Batch correction using a simple linear regression:

In [None]:
# # merge the balf_samples
# balf_corrected = sc.concat(balf_samples.values())
# balf_corrected.obs_names_make_unique()

# # store log(1+CPM) values in "raw" attribute
# balf_corrected.raw = balf_corrected

# # do the batch correction
# sc.pp.regress_out(balf_corrected, keys = batch_var)
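Regressing out a categorical batch variable can be pictured as fitting, for each gene, a least-squares model on one-hot batch indicators and keeping the residuals. The numpy sketch below does exactly that with `np.linalg.lstsq`; it illustrates the idea and is not a reimplementation of `sc.pp.regress_out` (`regress_out_batch` is a hypothetical name). With a one-hot design the fitted values are the per-batch gene means, so every cell below its batch mean ends up negative, another route to Problem 3 above.

```python
import numpy as np


def regress_out_batch(X, batches):
    """Residuals of each gene after least-squares regression on one-hot batch labels.

    X       : cells x genes array
    batches : length-n_cells array of batch labels
    """
    X = np.asarray(X, dtype=float)
    batches = np.asarray(batches)
    levels = np.unique(batches)
    # Design matrix: one indicator column per batch (acts as a per-batch intercept)
    design = (batches[:, None] == levels[None, :]).astype(float)
    coef, *_ = np.linalg.lstsq(design, X, rcond=None)
    return X - design @ coef


toy_X = np.array([[1.0, 5.0],
                  [3.0, 7.0],
                  [10.0, 5.0],
                  [12.0, 9.0]])
toy_batches = np.array(["s1", "s1", "s2", "s2"])
print(regress_out_batch(toy_X, toy_batches))
# expected residuals: [[-1, -1], [1, 1], [-1, -2], [1, 2]] (up to floating-point error),
# i.e. each gene minus its per-batch mean
```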

Compute a PCA manifold on the batch-corrected counts:

In [None]:
# get the top 2000 highly variable genes
sc.pp.highly_variable_genes(balf_corrected, n_top_genes = 2000)

# get PCA to 100 PCs
sc.tl.pca(balf_corrected, use_highly_variable = True, svd_solver='arpack', random_state = seed, 
         n_comps = 100)

The final "balf_corrected" AnnData object has the following attributes:
1) X: batch-corrected counts matrix (preferably non-negative) <br>
2) obs: cell metadata that includes the cell group (cluster or type), Sample ID, and Context <br>
3) raw: log(1+CPM)-normalized counts <br>
4) obsm['X_pca']: the cell manifold (PCA embedding)

Regardless of the preprocessing pipeline used, these four pieces of information will be necessary for some parts of the Tensor-cell2cell analyses. 
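A quick sanity check of these four pieces might look like the hypothetical helper below. It relies only on attribute access (`X`, `obs`, `raw`, `obsm`) and membership tests, so it should also accept an AnnData object, but the function name `check_tensor_inputs` and the obs column names you pass to it are assumptions to adapt to your own metadata. The usage line runs it on a small stand-in object rather than on `balf_corrected`.

```python
from types import SimpleNamespace

import numpy as np


def check_tensor_inputs(adata, required_obs_cols, pca_key="X_pca"):
    """Return a list of problems with an AnnData-like object (empty list = all good)."""
    problems = []
    X = adata.X
    # scipy sparse matrices expose .tocsr(); check only their stored values
    values = X.tocsr().data if hasattr(X, "tocsr") else np.asarray(X)
    if values.size and values.min() < 0:
        problems.append("X contains negative values")
    for col in required_obs_cols:
        if col not in adata.obs:
            problems.append(f"obs is missing column {col!r}")
    if adata.raw is None:
        problems.append("raw is not set")
    if pca_key not in adata.obsm:
        problems.append(f"obsm is missing {pca_key!r}")
    return problems


# Stand-in object; on real data, pass balf_corrected and your own obs column names
demo = SimpleNamespace(X=np.array([[0.0, 1.0]]), obs={"Sample_ID": []}, raw=None, obsm={})
print(check_tensor_inputs(demo, ["Sample_ID"]))  # reports the missing raw and X_pca
```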

In [None]:
# from typing import Dict
# from anndata import AnnData

# def split_adata(adata: AnnData, sample_col: str = 'Sample_ID') -> Dict[str, AnnData]:
#     """Split a merged AnnData object with corrected counts back into one object per sample.

#     Parameters
#     ----------
#     adata : AnnData
#         merged AnnData object across samples (see sc.concat)
#     sample_col : str, optional
#         the metadata (adata.obs) column specifying the samples, by default 'Sample_ID'

#     Returns
#     -------
#     balf_samples : Dict[str, AnnData]
#         the set of AnnData objects corresponding to each sample
#     """
#     # subset cells with a boolean mask per sample; .copy() returns independent objects instead of views
#     balf_samples = {sample: adata[adata.obs[sample_col] == sample].copy()
#                     for sample in adata.obs[sample_col].unique()}
#     return balf_samples


# balf_corrected_split = split_adata(adata=balf_corrected)