In [1]:
suppressPackageStartupMessages({
    library(splatter)
    
    library(scater)
    library(scran)
    library(bluster)
    
    library(aricode)
    library(kBET)
})

# Global RNG seed; also passed to splatter via setParams(seed = seed) so the
# simulation is reproducible.
seed <- 888
set.seed(seed)

Create the gold-standard simulation. A guide to the simulation parameters can be found [here](https://bioconductor.org/packages/devel/bioc/vignettes/splatter/inst/doc/splat_params.html).

In [2]:
qc.data <- function(sce, nmads = 2, min.detected = 0.01) {
    # Filter low-quality cells (separately within each batch) and rarely
    # detected genes. Procedure taken from PMID: 34949812.
    #
    # Args:
    #   sce: a SingleCellExperiment. If colData(sce) has a `Batch` column,
    #     cell QC thresholds are computed within each batch; otherwise all
    #     cells are treated as one batch.
    #   nmads: number of MADs passed to scater::quickPerCellQC (default 2).
    #   min.detected: genes are kept when rowData(sce)$detected is at least
    #     this value (default 0.01).
    #     NOTE(review): scater reports `detected` as a percentage (0-100) of
    #     cells, so 0.01 means "seen in >= 0.01% of cells" -- confirm this is
    #     the intended threshold rather than 1%.
    #
    # Returns:
    #   `sce` restricted to the cells and genes passing QC. The per-cell flag
    #   is stored in colData(sce)$Discard before cells are removed.

    # QC of cells
    sce <- scater::addPerCellQC(sce) # typical QC as in batch correction paper
    batches <- colData(sce)$Batch
    if (is.null(batches)) {
        # No batch annotation: run the cell QC over all cells at once.
        batches <- rep("all", ncol(sce))
    }
    # Flags are written back by position, so they stay aligned with the cells
    # even when batches are not stored contiguously. `%in%` never yields NA.
    discard <- logical(ncol(sce))
    for (batch in unique(batches)) {
        in_batch <- batches %in% batch
        discard[in_batch] <- scater::quickPerCellQC(
            colData(sce)[in_batch, ], nmads = nmads
        )$discard
    }
    colData(sce)$Discard <- discard
    sce <- sce[, !discard]

    # QC of genes
    sce <- scater::addPerFeatureQC(sce)
    is_exprs <- rowData(sce)$detected >= min.detected
    sce[is_exprs, ]
}

cluster.cells <- function(sce, assay.type, rank = 50,
                          BLUSPARAM = bluster::NNGraphParam(
                              shared = TRUE, cluster.fun = "louvain"
                          )) {
    # Run PCA on all genes, then cluster cells on the PCA coordinates.
    # Clustering approach: http://bioconductor.org/books/3.13/OSCA.basic/clustering.html
    #
    # Args:
    #   sce: a SingleCellExperiment.
    #   assay.type: name of the assay used for the PCA (e.g. "logcounts").
    #   rank: number of principal components to compute (default 50).
    #   BLUSPARAM: bluster clustering parameters; the default is a shared
    #     nearest-neighbour graph clustered with the Louvain algorithm.
    #
    # Returns:
    #   `sce` with a "PCA" reducedDim and the cluster labels stored in
    #   colData(sce)$Cluster.
    sce <- scran::fixedPCA(sce, rank = rank, assay.type = assay.type,
                           subset.row = NULL) # subset.row = NULL: use all genes
    colData(sce)[["Cluster"]] <- scran::clusterCells(
        sce, use.dimred = "PCA", BLUSPARAM = BLUSPARAM
    )
    sce
}

quantify.batch.effect <- function(sce, assay.type) {
    # Score how well clusters recover the simulated groups and how strongly
    # cells separate by batch.
    #
    # Args:
    #   sce: a SingleCellExperiment whose colData has `Group` (simulated cell
    #     type), `Cluster` (e.g. from cluster.cells) and `Batch` columns.
    #   assay.type: name of the assay given to kBET.
    #
    # Returns:
    #   list(clusterability = 1 - NMI(Group, Cluster), i.e. 0 when the
    #          clusters match the groups exactly;
    #        mixability = the "mean" row of kBET's "kBET.observed" summary).
    #   NOTE(review): per the kBET docs, kBET.observed is a rejection rate,
    #   so higher values mean *worse* batch mixing -- confirm the naming.
    required <- c("Group", "Cluster", "Batch")
    missing.cols <- setdiff(required, colnames(colData(sce)))
    if (length(missing.cols) > 0) {
        stop("colData(sce) is missing column(s): ",
             paste(missing.cols, collapse = ", "), call. = FALSE)
    }

    nmi <- aricode::NMI(colData(sce)$Group, colData(sce)$Cluster) # clusterability
    # kBET expects cells in rows and a plain matrix; as.matrix() is a no-op
    # for dense assays and densifies sparse/delayed ones.
    expr <- t(as.matrix(assays(sce)[[assay.type]]))
    kbet <- kBET(df = expr, batch = colData(sce)$Batch, plot = FALSE) # mixability
    kbet.mean <- kbet$summary["mean", "kBET.observed"]
    # sil.cluster <- as.data.frame(silhouette(as.numeric(colData(sce)$Cluster),
    #                               dist(reducedDims(sce)$PCA)))
    # sil.cluster<-mean(aggregate(sil.cluster[, 'sil_width'], list(sil.cluster$cluster), mean)$x)

    list(clusterability = 1 - nmi, mixability = kbet.mean)
}

In [3]:
base_params <- newSplatParams()
n.cells <- 1e3 # 60e3
n.cell.types <- 3 # 10

# Parameter ranges for the planned sweep (not used yet). The batch-count range
# gets its own name so it is not overwritten by the scalar `n.batches` below.
n.batches.range <- seq(2, 5, 1)
batch.locations <- seq(0.1, 0.5, 0.1)
batch.scales <- seq(0.1, 0.5, 0.1)
de.locations <- seq(0.1, 0.5, 0.1)
de.scales <- seq(0.1, 0.5, 0.1)

# Single parameter setting for now; TODO: loop over the ranges above.
n.batches <- 2
batch.scale <- 0.4
batch.location <- 0.4
de.prob <- 0.1
de.scale <- 0.15
de.location <- 0.2

sim_params <- setParams(
    base_params,
    seed = seed,

    # batches
    # Integer batch sizes that always sum to n.cells, even when n.cells is not
    # divisible by n.batches (gives 500, 500 for the current settings).
    batchCells = diff(round(seq(0, n.cells, length.out = n.batches + 1))),
    batch.facLoc = batch.location, # higher values increase batch severity
    batch.facScale = batch.scale, # higher values increase batch severity
    batch.rmEffect = FALSE, # create the gold standard when set to TRUE

    # cell types
    group.prob = rep(1 / n.cell.types, n.cell.types),
    de.prob = de.prob, # previously defined but never passed on
    de.facLoc = de.location, # increase separation between cell types
    de.facScale = de.scale # increase separation between cell types
)

In [4]:
run.all <- function(sim_params) {
    # Simulate a batch-affected dataset and its batch-free gold standard, then
    # QC, normalise, cluster and score the batch effect of both.
    #
    # Args:
    #   sim_params: splatter simulation parameters (e.g. from setParams()).
    #
    # Returns:
    #   A named list of SingleCellExperiments:
    #   - `sim.log`: the simulation with batch effects;
    #   - `gold.standard`: the same simulation with batch.rmEffect = TRUE.
    #   Both are restricted to their shared genes/cells, log-normalised and
    #   clustered. The quantify.batch.effect() scores of both are attached as
    #   attr(result, "batch.effect"), a list with elements `gold.standard`
    #   and `sim.log`.

    # generate the simulated dataset and its respective gold standard
    sim <- splatSimulateGroups(sim_params, verbose = FALSE)
    gold.standard <- splatSimulateGroups(sim_params, verbose = FALSE,
                                         batch.rmEffect = TRUE)
    expr.datasets <- list(sim.log = sim, gold.standard = gold.standard)

    # process
    expr.datasets <- lapply(expr.datasets, qc.data) # do qc

    # retain only genes/cells in common b/w the 2 datasets
    genes.intersect <- intersect(rownames(expr.datasets$sim.log),
                                 rownames(expr.datasets$gold.standard))
    cells.intersect <- intersect(colnames(expr.datasets$sim.log),
                                 colnames(expr.datasets$gold.standard))
    expr.datasets <- lapply(expr.datasets, function(sce) {
        sce <- sce[genes.intersect, cells.intersect]
        sce <- scater::logNormCounts(sce)
        cluster.cells(sce, assay.type = "logcounts") # run PCA and SNN clustering
    })

    # batch correct
    # TODO: not implemented yet; also score the corrected data below.

    # quantify batch severity (gold standard first, as before, so kBET's
    # random draws happen in the same order)
    batch.effect <- list(
        gold.standard = quantify.batch.effect(expr.datasets$gold.standard,
                                              assay.type = "logcounts"),
        sim.log = quantify.batch.effect(expr.datasets$sim.log,
                                        assay.type = "logcounts")
    )

    # TODO: score the communication in each context (here, context == batch)
    # TODO: run the tensor decomposition
    # TODO: calculate the CorrIndex with the gold-standard

    attr(expr.datasets, "batch.effect") <- batch.effect
    expr.datasets
}

In [5]:
# TODO:
# - add multiple contexts to the simulation
# - implement the placeholder steps in run.all() (batch correction,
#   communication scoring, tensor decomposition, CorrIndex)

In [80]:
# batch correct
# This cell needs `expr.datasets` from run.all(); compute it if the cell is
# run on its own.
if (!exists("expr.datasets")) {
    expr.datasets <- run.all(sim_params)
}
# Use the full element name: `$sim` only worked through partial matching.
sce <- expr.datasets$sim.log

# Remove the results that cluster.cells() recomputes below, as a sanity check
# that they really are regenerated.
reducedDims(sce) <- NULL
colData(sce) <- colData(sce)[, colnames(colData(sce)) != "Cluster", drop = FALSE]

# do the batch correction
# TODO: not implemented. Once it exists, set `corrected.assay` to the name of
# the assay that holds the batch-corrected values.
corrected.assay <- "logcounts"

# run PCA and SNN clustering
sce <- cluster.cells(sce, assay.type = corrected.assay)