In [2]:
suppressPackageStartupMessages({
    library(splatter)
    
    library(scater)
    library(scran)
    library(bluster)
    
    library(aricode)
    library(kBET)
    
    library(reticulate)
    scanorama <- import('scanorama')
})

# Single seed shared by R's RNG and by the splatter simulation parameters below.
seed <- 888
set.seed(seed)

Create the gold-standard simulation. A guide to the simulation parameters can be found [here](https://bioconductor.org/packages/devel/bioc/vignettes/splatter/inst/doc/splat_params.html).

In [3]:
qc.data<-function(sce){
    # Quality-control an expression object: drop outlier cells (flagged
    # separately within each batch), then drop rarely detected genes.
    # Procedure taken from PMID: 34949812.
    #
    # Args:
    #   sce: object whose colData has a `Batch` column (a SingleCellExperiment
    #        in this notebook).
    # Returns:
    #   The filtered object with per-cell and per-feature QC metrics added and a
    #   `Discard` column in colData (FALSE for every retained cell).

    # QC of cells
    sce <- scater::addPerCellQC(sce) # typical QC as in batch correction paper
    batch.labels <- as.character(colData(sce)$Batch)

    # Write each batch's flags back to that batch's own positions. Concatenating
    # the per-batch results is only aligned with the columns of `sce` when the
    # cells happen to be ordered batch by batch.
    discard <- logical(ncol(sce))
    for (batch in unique(batch.labels)) {
        in_batch <- which(batch.labels == batch)
        discard[in_batch] <- scater::quickPerCellQC(colData(sce)[in_batch, ], nmads = 2)$discard
    }
    colData(sce)$Discard <- discard
    sce <- sce[, !discard]

    # QC of genes
    sce <- scater::addPerFeatureQC(sce)
    # NOTE(review): `detected` is presumably a percentage of cells, which would
    # make this cutoff 0.01% rather than 1% - confirm against the scater docs.
    is_exprs <- rowData(sce)$detected >= 0.01
    sce <- sce[is_exprs, ]

    return(sce)
}

cluster.cells<-function(sce, assay.type){
    # Run PCA on one assay and cluster the cells on the resulting embedding.
    #
    # Args:
    #   sce: expression object (a SingleCellExperiment in this notebook).
    #   assay.type: name of the assay the PCA is computed from.
    # Returns:
    #   `sce` with the PCA result attached and the cluster labels stored in the
    #   `Cluster` column of colData.
    sce <- scran::fixedPCA(sce, assay.type = assay.type, subset.row = NULL) # default 50 PCs

    # clustering http://bioconductor.org/books/3.13/OSCA.basic/clustering.html
    # Shared-nearest-neighbour graph, communities detected with Louvain.
    colData(sce)[['Cluster']] <- scran::clusterCells(
        sce,
        use.dimred = "PCA",
        BLUSPARAM = NNGraphParam(shared = TRUE, cluster.fun = "louvain")
    )
    return(sce)
}

scanorama.batch.correct<-function(sce.batch){
    # Batch-correct the `logcounts` assay with scanorama (called through
    # reticulate) and store the result as a new assay.
    #
    # Args:
    #   sce.batch: expression object with a `logcounts` assay, row/column names
    #              and a `Batch` column in colData.
    # Returns:
    #   `sce.batch` with an extra `scanorama.counts` assay, arranged in the same
    #   gene and cell order as `logcounts`.
    # Raises:
    #   An error if the corrected matrix does not have the dimensions of
    #   `logcounts`.

    # prep scanorama inputs - split expression objects by batch
    # Labels are compared as character so a numeric `Batch` column can never be
    # mistaken for a positional list index.
    batch.labels <- as.character(colData(sce.batch)$Batch)
    batches <- lapply(unique(batch.labels), function(batch) {
        sce.singlebatch <- sce.batch[, which(batch.labels == batch)]
        t(assays(sce.singlebatch)$logcounts) # cells x genes
    })
    # every batch carries the same gene list, so build it once
    genes <- rep(list(as.list(rownames(sce.batch))), length(batches))

    # do the batch correction
    corrected.data <- scanorama$correct(batches, genes, return_dense=TRUE)

    # format into the sce
    genes <- corrected.data[[2]]
    corrected.data <- corrected.data[[1]]

    for (i in seq_along(corrected.data)){
        expr <- corrected.data[[i]]
        colnames(expr) <- genes
        rownames(expr) <- rownames(batches[[i]])
        corrected.data[[i]] <- t(expr) # back to genes x cells
    }
    corrected.data <- do.call(cbind, corrected.data)

    if (!identical(dim(corrected.data), dim(assays(sce.batch)$logcounts))){ # sanity check
        stop('Unexpected filters in scanorama batch correction')
    }
    # restore the original gene/cell ordering before attaching the assay
    corrected.data <- corrected.data[rownames(assays(sce.batch)$logcounts), colnames(assays(sce.batch)$logcounts)]
    assays(sce.batch)[['scanorama.counts']] <- corrected.data

    return(sce.batch)
}

quantify.batch.effect<-function(sce, assay.type){
    # Summarise batch severity with two scores.
    #
    # Args:
    #   sce: expression object whose colData has `Group`, `Cluster` and `Batch`
    #        columns.
    #   assay.type: name of the assay handed to kBET.
    # Returns:
    #   list(clusterability, mixability) where
    #     clusterability = 1 - NMI(Group, Cluster), so larger values mean the
    #                      clusters agree less with the `Group` labels;
    #     mixability     = the ['mean', 'kBET.observed'] entry of the kBET
    #                      summary (presumably the mean observed rejection rate
    #                      - confirm against the kBET docs).
    nmi <- aricode::NMI(colData(sce)$Group, colData(sce)$Cluster) # clusterability
    kbet <- kBET(df = t(assays(sce)[[assay.type]]), # cells x genes
                 batch = colData(sce)$Batch, # mixability
                 plot = FALSE)$summary['mean', 'kBET.observed']
    # sil.cluster <- as.data.frame(silhouette(as.numeric(colData(sce)$Cluster), 
    #                               dist(reducedDims(sce)$PCA)))
    # sil.cluster<-mean(aggregate(sil.cluster[, 'sil_width'], list(sil.cluster$cluster), mean)$x)

    return(list(clusterability = 1 - nmi, mixability = kbet))
}

In [4]:
# Simulation settings: dataset size first, then the batch / cell-type
# parameters that control how severe the simulated batch effect is.
base_params <- newSplatParams()
n.cells<-1e3#60e3
n.cell.types<-3#10
n.genes<-3e3
# n.hvgs<-2e3

# Parameter grids intended for a later sweep.
# NOTE: the `n.batches` grid is overwritten by the scalar value below.
n.batches<-seq(2, 5, 1)
batch.locations<-seq(0.1, 0.5, 0.1)
batch.scales<-seq(0.1, 0.5, 0.1)
de.locations<-seq(0.1, 0.5, 0.1)
de.scales<-seq(0.1, 0.5, 0.1)

# for loop here
# Single parameter combination used until the sweep is implemented.
n.batches<-2
batch.scale<-0.4
batch.location<-0.4
de.prob<-0.1
de.scale<-0.15
de.location<-0.2

# Split the cells across batches in whole numbers: plain n.cells/n.batches is
# fractional whenever n.batches does not divide n.cells (e.g. 3 batches), so the
# remainder is handed out one cell at a time to the first batches.
cells.per.batch <- rep(n.cells %/% n.batches, n.batches)
n.leftover <- n.cells %% n.batches
cells.per.batch[seq_len(n.leftover)] <- cells.per.batch[seq_len(n.leftover)] + 1

sim_params <- setParams(
    base_params,
    seed = seed,
    nGenes = n.genes,
    
    # batches
    batchCells = cells.per.batch,
    batch.facLoc = batch.location, # higher values increase batch severity
    batch.facScale = batch.scale, # higher values increase batch severity
    batch.rmEffect = FALSE, # create the gold standard when set to True
    
    # cell types
    group.prob = rep(1/n.cell.types, n.cell.types), 
    de.prob = de.prob, # previously defined but never passed to the simulation
    de.facLoc = de.location, # increase separation between cell types
    de.facScale = de.scale # increase separation between cell types
)

In [5]:
run.all<-function(sim_params){
    # Work-in-progress pipeline for one parameter set: simulate a dataset with
    # batch effects plus its batch-free gold standard, QC / normalise / cluster
    # both, batch-correct the simulated data and quantify batch severity.
    #
    # Args:
    #   sim_params: splatter parameter object passed to splatSimulateGroups.
    # Returns:
    #   NULL for now - the downstream steps are still placeholders.

    # generate the simulated dataset and its respective gold standard
    sim <- splatSimulateGroups(sim_params, verbose = FALSE)
    gold.standard <- splatSimulateGroups(sim_params, verbose = FALSE, batch.rmEffect = TRUE)
    expr.datasets <- list(sim.log = sim, gold.standard = gold.standard)

    # process
    expr.datasets <- lapply(expr.datasets, qc.data) # do qc

    genes.intersect <- intersect(rownames(expr.datasets$sim.log), rownames(expr.datasets$gold.standard))
    cells.intersect <- intersect(colnames(expr.datasets$sim.log), colnames(expr.datasets$gold.standard))
    expr.datasets <- lapply(expr.datasets, FUN = function(sce) {
        sce <- sce[genes.intersect, cells.intersect] # retain only genes/cells in common b/w the 2 datasets
        sce <- scater::logNormCounts(sce)
        cluster.cells(sce, assay.type = 'logcounts') # run PCA and SNN clustering
    })
    
# #     # filter for HVGs after log-normalization
# #     mgv<-scran::modelGeneVar(expr.datasets$gold.standard, assay.type = 'logcounts')
# #     hvgs<-scran::getTopHVGs(mgv, n = n.hvgs)
#     expr.datasets<-lapply(expr.datasets, FUN = function(sce) {
# #         sce <- sce[hvg, ] # filter for HVGs from gold-standard dataset
#         sce <- cluster.cells(sce, assay.type = 'logcounts') # run PCA and SNN clustering 
#     })

    # batch correct
    # The element is selected by its full name: `$sim` only worked through
    # partial matching of `sim.log`.
    # NOTE(review): do.batch.correction is not defined in the visible part of
    # this file (only a commented-out stub exists) - confirm it is in scope.
    sce.batches <- do.batch.correction(expr.datasets[['sim.log']]) # on log-normalized data with batch effects
    sce.scanorama <- sce.batches$scanorama
    sce.scGen <- sce.batches$scGen

    # quantify batch severity
    # maybe map expr.datasets name to assay.type and loop
    gs.batcheffect <- quantify.batch.effect(expr.datasets$gold.standard, assay.type = 'logcounts')
    sim.log.batcheffect <- quantify.batch.effect(expr.datasets$sim.log, assay.type = 'logcounts')
    # TODO: sim.scanorama.batcheffect <- ... # batch corrected data
    # (kept as a comment: the unfinished assignment used to swallow the
    # return() call below as its right-hand side)

    # score the communication in each context (here, context == batch)

    # run the tensor decomposition

    # calculate the CorrIndex with the gold-standard

    # store all information for each of: gold-standard, not batch corrected, batch corrected as separate matrices

    return(NULL)#expr.datasets)
}

In [6]:
# Simulate the dataset with batch effects and its batch-free gold standard
# from the same parameters (same seed), differing only in batch.rmEffect.
sim <- splatSimulateGroups(sim_params, verbose = FALSE)
gold.standard <- splatSimulateGroups(sim_params, verbose = FALSE, batch.rmEffect = TRUE)
expr.datasets <- list(sim.log = sim, gold.standard = gold.standard)

# process
expr.datasets <- lapply(expr.datasets, qc.data) # do qc

# QC is run on each dataset separately, so keep only what survived in both.
genes.intersect <- intersect(rownames(expr.datasets$sim.log), rownames(expr.datasets$gold.standard))
cells.intersect <- intersect(colnames(expr.datasets$sim.log), colnames(expr.datasets$gold.standard))
expr.datasets <- lapply(expr.datasets, FUN = function(sce) {
    sce <- sce[genes.intersect, cells.intersect] # retain only genes/cells in common b/w the 2 datasets
    sce <- scater::logNormCounts(sce)
    cluster.cells(sce, assay.type = 'logcounts') # run PCA and SNN clustering
})

In [7]:
# batch correct
# Select the dataset by its full name; `$sim` only resolved to `sim.log`
# through partial matching.
sce.batch<-expr.datasets[['sim.log']]

#do.batch.correction<-function(sce.batch)
# manually overwrite features that will be re-calculated as a sanity check
reducedDims(sce.batch)<-NULL
# Drop the `Cluster` column from colData. The mask has to index the columns of
# colData itself; applying it to the rows of the expression object left
# `Cluster` in place.
colData(sce.batch)<-colData(sce.batch)[, colnames(colData(sce.batch)) != 'Cluster', drop = FALSE]

# # do the batch correction 
sce.scanorama<-scanorama.batch.correct(sce.batch)
# sce.scGen<-scanorama.batch.correct(sce.batch)
# sce.batches<-list(scanorama = sce.scanorama, scGen = sce.scGen)

# # run PCA and SNN clustering lapply
# cluster.cells(sce, assay.type = ??)

# # sanity checks - may be unnecessary
# if (!identical(assays(sce.batch)$logcounts, assays(sce.scanorama)$logcounts) |
#     !identical(assays(sce.batch)$logcounts, assays(sce.scGen)$logcounts)){
#     stop('Unexpected loss of logcounts assay')
# }
# if (identical(assays(sce.batch)$logcounts, assays(sce.scanorama)$batch.corrected) | 
#     identical(assays(sce.batch)$logcounts, assays(sce.scGen)$batch.corrected) | 
#     identical(assays(sce.scanorama)$batch.corrected, assays(sce.scGen)$batch.corrected)){
#     stop('Unexpected lack of change in batch correction')
# }


# #return(sce.batches)