In [None]:
# Install the packages this notebook imports. scanpy, matplotlib and scanorama
# are pinned to exact versions; tensorflow and scib are left unpinned, so pip
# resolves whatever version is current at install time.
!pip install scanpy==1.9.1
!pip install tensorflow
!pip install matplotlib==3.6
!pip install scib
!pip install scanorama==1.7.0

# Quiet installs of the scvi-colab helper and scib-metrics (both unpinned).
!pip install --quiet scvi-colab
!pip install --quiet scib-metrics
# NOTE(review): install() presumably sets up scvi-tools for the Colab runtime;
# confirm against the scvi-colab documentation.
from scvi_colab import install
install()

# ---------------------------------------------------------------------
# WHEN USING COLAB, PLEASE RESTART RUNTIME AFTER RUNNING THIS CELL
# ---------------------------------------------------------------------

# (Colab pre-loads some of the packages used in this notebook at a different
# version; the runtime must be restarted so the versions installed above are
# the ones that get imported.)


In [None]:
# Most of the code was adapted from:
# https://github.com/theislab/scib

import scvi
import os
from pathlib import Path
import scipy
import scanpy as sc
import scib
import numpy as np
from scib import utils
import time
import collections 
import sys
import pandas as pd
# Compatibility shim: on Python 3.10+ the ABC lives only in collections.abc,
# so re-expose it as collections.MutableSet for code that still looks it up
# under the old name. On earlier versions the old import path is used directly.
_py_major, _py_minor = sys.version_info[0], sys.version_info[1]
if _py_major == 3 and _py_minor >= 10:
    from collections.abc import MutableSet
    collections.MutableSet = MutableSet
else:
    from collections import MutableSet

# When using colab, set the path to the modules directory to use saved modules
sys.path.append('/content/drive/MyDrive/modules/')
from datasets_dict import datasets

# The main integration function (calls the appropriate integration method)
def integrate(inPath, method, batchLabel, celltypeLabel=None):
    """Load a dataset, run one integration method on it and time the run.

    Parameters
    ----------
    inPath : str or path-like
        Location of the dataset; passed unchanged to ``sc.read``.
    method : callable
        Integration function. It is invoked positionally as
        ``method(adata, batchLabel, hvg)``, or as
        ``method(adata, batchLabel, hvg, celltypeLabel)`` when a cell-type
        label is supplied. ``hvg`` is always ``None`` here.
    batchLabel : str
        Forwarded to ``method`` as its second argument.
    celltypeLabel : str, optional
        Forwarded to ``method`` as a fourth positional argument only when it
        is not ``None``.

    Returns
    -------
    tuple
        ``(integrated, elapsed_time)``: whatever ``method`` returned and the
        duration of the ``method`` call in seconds. Reading the file is not
        included in the timing.
    """
    # read the original dataset
    adata = sc.read(inPath)
    hvg = None

    # Build the positional arguments once so there is a single timed call site.
    method_args = [adata, batchLabel, hvg]
    if celltypeLabel is not None:
        method_args.append(celltypeLabel)

    # perf_counter is monotonic, so the measured duration cannot go negative
    # or jump if the system clock is adjusted while the method is running.
    start_time = time.perf_counter()
    integrated = method(*method_args)
    elapsed_time = time.perf_counter() - start_time

    return integrated, elapsed_time


def trvae(adata, batch_category, hvg=None):
    """Batch-correct ``adata`` with a trVAE network.

    Parameters
    ----------
    adata : AnnData-like
        Dataset to correct. ``adata.obs[batch_category]``, ``adata.shape`` and
        ``adata.var_names`` are read.
    batch_category : str
        Column of ``adata.obs`` holding the batch (condition) labels.
    hvg : optional
        Unused. Accepted so the function can be called the same way as the
        other integration methods, i.e. ``method(adata, batch, hvg)``.

    Returns
    -------
    The object returned by ``network.predict``, with every cell mapped to the
    most frequent batch as the target condition.

    Raises
    ------
    ImportError
        If the ``trvae`` package is not installed.
    """
    # Import locally under an alias: this function has the same name as the
    # package, so a bare ``trvae.archs`` inside the body would resolve to the
    # function itself and fail with AttributeError.
    import trvae as trvae_pkg

    conditions = adata.obs[batch_category].unique().tolist()
    network = trvae_pkg.archs.trVAEMulti(adata.shape[1], conditions,
                                         z_dimension=10,
                                         gene_names=adata.var_names.tolist(),
                                         architecture=[256, 64],
                                         model_path='./models/trVAE/haber/',
                                         alpha=0.0001,
                                         beta=50,
                                         eta=100,
                                         loss_fn='sse',
                                         output_activation='linear')

    network.train(adata,
                  batch_category,
                  train_size=0.8,
                  n_epochs=50,
                  batch_size=512,
                  early_stop_limit=10,
                  lr_reducer=20,
                  verbose=5,
                  save=True,
                  )

    # value_counts() sorts by frequency, so index[0] is the largest batch.
    target_condition = adata.obs[batch_category].value_counts().index[0]
    corrected = network.predict(adata, batch_category, target_condition=target_condition)

    return corrected


# Wrapper function for scib combat
def combat(adata, batch_category, hvg=None):
    """Delegate to ``scib.integration.combat`` and return its result.

    ``hvg`` is accepted so the function can be called with the same three
    positional arguments as the other integration methods; it is not used.
    """
    run_combat = scib.integration.combat
    corrected = run_combat(adata, batch_category)
    return corrected


# integration procedure from the tutorial "Integration with scVI":
# https://docs.scvi-tools.org/en/stable/tutorials/notebooks/harmonization.html
def scviIntegration(adata, batch_category, hvg):
    """Train an scVI model on ``adata`` and return a corrected copy.

    The model is set up on the ``"counts"`` layer with ``batch_category`` as
    the batch key, trained with default settings, and its normalized
    expression is stored as ``X`` on a copy of ``adata``. ``hvg`` is not used.
    Note that ``setup_anndata`` is called on the ``adata`` passed in, not on
    the copy.
    """
    model_cls = scvi.model.SCVI
    model_cls.setup_anndata(adata, layer="counts", batch_key=batch_category)

    model_kwargs = {"n_layers": 2, "n_latent": 30, "gene_likelihood": "nb"}
    vae = model_cls(adata, **model_kwargs)
    vae.train()

    corrected = adata.copy()
    normalized = vae.get_normalized_expression()
    corrected.X = normalized
    return corrected


def scanvi(adata, batch_category, hvg, label_key):
    """Train scVI, refine it with scANVI, and return a corrected copy.

    An scVI model is first set up on the ``"counts"`` layer with
    ``batch_category`` as the batch key and trained with default settings.
    A scANVI model is then built from it using ``label_key`` as the labels
    column (``"Unknown"`` marks unlabeled cells) and trained for 20 epochs
    with 100 samples per label. The scANVI normalized expression is stored
    as ``X`` on a copy of ``adata``. ``hvg`` is not used.
    """
    # Stage 1: unsupervised scVI model
    scvi_cls = scvi.model.SCVI
    scvi_cls.setup_anndata(adata, layer="counts", batch_key=batch_category)
    vae = scvi_cls(adata, n_layers=2, n_latent=30, gene_likelihood="nb")
    vae.train()

    # Stage 2: label-aware scANVI model initialised from the trained scVI model
    scanvi_kwargs = {
        "adata": adata,
        "labels_key": label_key,
        "unlabeled_category": "Unknown",
    }
    lvae = scvi.model.SCANVI.from_scvi_model(vae, **scanvi_kwargs)
    lvae.train(max_epochs=20, n_samples_per_label=100)

    corrected = adata.copy()
    normalized = lvae.get_normalized_expression()
    corrected.X = normalized
    return corrected



# Driver: run every integration method on every selected dataset, save each
# integrated dataset, and write one execution-time table per method.

# path to the original dataset (after subset to 3000 highly variable genes)
base_path = '/content/drive/MyDrive/Colab Notebooks/integrationDatasets/'
execution_times = {}

# method name -> integration function. The name is also used as the output
# sub-directory. Each function is called by integrate() as
# method(adata, batch_key, hvg[, label_key]).
methods = {    
        'scvi': scviIntegration,
        'scanvi': scanvi,
        'combat': combat,
        'scanorama': scib.integration.scanorama,
    }

# Datasets to process; use list(datasets.keys()) to run all of them.
dataset_names = ['small_atac_windows']


for method in methods.keys():
  print("Integrating using method: ", method)

  # The output directory depends only on the method, so resolve and create it
  # once here. This also guarantees outPath is defined for the CSV write
  # below even when no dataset is processed.
  outPath = os.path.join(base_path, 'integratedDatasets', method)
  Path(outPath).mkdir(parents=True, exist_ok=True)

  # Start from an empty table for every method so this method's CSV can never
  # contain timings left over from a previously run method.
  execution_times = {}

  for dataset_name in dataset_names:

    # get dataset parameters; only scanvi uses the cell-type labels
    batch_key = datasets[dataset_name]['batch_key']
    if method == 'scanvi':
      label_key = datasets[dataset_name]['label_key']
    else:
      label_key = None

    # set input path
    inPath = os.path.join(base_path, f"{dataset_name}_hvg.h5ad")

    # integrate the data
    integrated, elapsed_time = integrate(inPath, methods[method], 
                                         batch_key, label_key)

    # save integration duration time (seconds)
    minutes, seconds = divmod(elapsed_time, 60)
    execution_times[dataset_name] = elapsed_time

    print("Integrated: ", dataset_name)
    print(f"Duration: {minutes} minutes and {seconds} seconds")
    
    # write integrated data
    sc.write(os.path.join(outPath, f"{dataset_name}_integrated.h5ad"), integrated)
    print("Integrated data saved")

  # write execution times data for this method
  df = pd.DataFrame(list(execution_times.items()), 
                    columns=['Dataset', 'Execution Time'])
  df.to_csv(os.path.join(outPath, 'execution_times.csv'), index=False)

