In [None]:
# Install pinned versions of scanpy and matplotlib. The reason for these exact
# pins is not stated here (presumably compatibility with the plotting helpers /
# reproducible figures — confirm before changing them).
!pip install scanpy==1.9.1
!pip install matplotlib==3.6

In [None]:
import os
from pathlib import Path
import sys
import scanpy as sc

# When running on Colab, add the modules directory on Google Drive to sys.path
# so that the saved helper modules imported below can be found.
sys.path.insert(0,'/content/drive/MyDrive/modules/')
import plottingData
from datasets_dict import datasets


# Plot every integrated dataset produced by every integration method.
#
# For each method in `methods`, integrated datasets are read from
# `<path>/<method>/<dataset_name>_integrated.h5ad`. Figures are written by
# plottingData.plot_datasets into `<save_path>/<method>/`. Plotted categories
# are the batch and label keys from `datasets`, plus dataset-specific extras
# ('species' for Immune_ALL_hum_mou, a derived 'location' for Lung_atlas_public).

# Integration methods whose outputs are plotted (one sub-directory each).
methods = [
    'ABC',
    'scgen',
    'AutoClass',
    'CLEAR_embed',
    'scanvi',
    'scvi',
    'scanorama',
    'harmony_embed',
    'combat',
    'scDREAMER_embed',
    'Seurat',
]

# Methods plotted with use_emb=True (presumably from an embedding in
# .obsm['X_emb'] rather than from .X — confirm in plottingData.plot_datasets).
embedding_methods = ['ABC', 'harmony_embed', 'scDREAMER_embed', 'CLEAR_embed']

# Lung_atlas_public batches labelled 'Airways'; all other batches of that
# dataset are labelled 'Lung Parenchyma'.
lung_biopsies_batchs = ['A1', 'A2', 'A3', 'A4', 'A5', 'A6']

path = '/content/drive/MyDrive/Colab Notebooks/integrationDatasets/integratedDatasets/'
save_path = '/content/drive/MyDrive/Colab Notebooks/integrationDatasets/final_plots/'
dpi = 300

for method in methods:

    # Input/output directories depend only on the method, so build them (and
    # create the output directory) once per method instead of once per dataset.
    inPath = os.path.join(path, method)
    outPath = os.path.join(save_path, method)
    Path(outPath).mkdir(parents=True, exist_ok=True)

    # The two plotting calls differed only in this flag.
    use_emb = method in embedding_methods

    for dataset_name, dataset_params in datasets.items():

        # get dataset parameters from the dictionary
        label_key = dataset_params['label_key']
        batch_key = dataset_params['batch_key']
        categories = [batch_key, label_key]

        # Skip (instead of aborting the whole run) when this method has no
        # integrated output for this dataset.
        in_file = os.path.join(inPath, f"{dataset_name}_integrated.h5ad")
        if not os.path.exists(in_file):
            print(f"Skipping {method} / {dataset_name}: {in_file} not found")
            continue

        # read integrated dataset
        integrated_data = sc.read(in_file)

        if dataset_name == 'Immune_ALL_hum_mou':
            categories.append('species')

        if dataset_name == 'Lung_atlas_public':
            # Derive a tissue-location annotation from each cell's batch.
            integrated_data.obs['location'] = integrated_data.obs[batch_key].map(
                lambda x: 'Airways' if x in lung_biopsies_batchs else 'Lung Parenchyma'
            )
            categories.append('location')

        # method specific data preparations: ABC's integrated result is stored
        # in .X, so expose a copy of it as the embedding used for plotting.
        if method == 'ABC':
            integrated_data.obsm['X_emb'] = integrated_data.X.copy()

        # plot integrated dataset
        plot_name = f"{method}_{dataset_name}_integrated_NEW"
        print(f"Plotting integrated dataset - {plot_name}")
        plottingData.plot_datasets(integrated_data, categories, outPath, dpi,
                                   plot_name, use_emb=use_emb)

