In [None]:
# Install the notebook's dependencies: the ABC package straight from its
# GitHub repository, scanpy and matplotlib at pinned versions, and scib and
# louvain without a version pin.
!pip install git+https://github.com/reutd/ABC.git
!pip install scanpy==1.9.1
!pip install matplotlib==3.6
!pip install scib
!pip install louvain

In [None]:
# The LISI scores rely on the knn_graph.o file that is built from the C++
# source knn_graph.cpp. To use these metrics, the source has to be recompiled
# in the current environment and the file shipped inside the scib package
# replaced by the freshly compiled one; the commands and code below do that.

# Download the C++ source from the scib repository (main branch).
!wget https://raw.githubusercontent.com/theislab/scib/main/scib/knn_graph/knn_graph.cpp
# Build it with -O3. Note that `-o` without `-c` makes g++ compile AND link,
# so knn_graph.o is a complete executable despite its ".o" suffix.
!g++ -O3 -o knn_graph.o knn_graph.cpp

import shutil
import pathlib
import scib
import os

# Locate the installed scib package; the compiled helper is expected in its
# knn_graph/ sub-directory.
root = pathlib.Path(scib.__file__).parent
print(root)

cpp_file_path = (root / "knn_graph/")

# Only touch the packaged binary once the fresh build is known to exist: if
# the g++ step failed, deleting the old file first would leave scib without
# any knn_graph.o at all.
compiled_file = pathlib.Path("knn_graph.o")
if not compiled_file.is_file():
  raise FileNotFoundError(
      "knn_graph.o was not found in the working directory; "
      "the g++ compilation step probably failed")

# Make sure the destination folder exists, remove the old binary only if it
# is actually there (its absence is not an error), and move the new build to
# an explicit file path rather than "into a directory".
cpp_file_path.mkdir(parents=True, exist_ok=True)
target_file = cpp_file_path / "knn_graph.o"
if target_file.exists():
  os.remove(str(target_file))
shutil.move(str(compiled_file), str(target_file))

In [None]:
import os
from pathlib import Path
import scanpy as sc
import itertools as it
import time
import traceback
import sys
import warnings
# Blanket filter: every warning category is silenced from this point on.
warnings.filterwarnings('ignore')

from ABC import ABC

# Extend the module search path so that evaluate_integration can be imported
# from the folder below.
# NOTE(review): the path lives under /content/drive, so this presumably needs
# Google Drive to be mounted first — no mount step is visible in this notebook.
sys.path.append('/content/drive/MyDrive/modules/')
from evaluate_integration import evaluate_integration



def integrate_and_score(orig_dataset, eval_params, dataset_params,
                        training_hParams, verbos=True, dataset_name=None):
  """Integrate a dataset with ABC and evaluate the integration.

  Args:
    orig_dataset: the un-integrated dataset; handed unchanged to ``ABC`` and
      to ``evaluate_integration``.
    eval_params: forwarded unchanged to ``evaluate_integration``.
    dataset_params: dict providing 'label_key', 'batch_key', 'organism' and
      'ATAC' (a truthy 'ATAC' selects data type 'ATAC', otherwise 'RNA').
    training_hParams: dict providing 'LR', 'Rec_Loss_W' and 'L_Dim'.
    verbos: when truthy, print the metrics table.
    dataset_name: name passed to ``batch_correction`` as ``data_name`` and
      used in the progress message. When omitted, the module-level variable
      ``dataset_name`` is used, which is how existing callers supply it.

  Returns:
    A tuple ``(final_score, metrics, model, integrated)``: the first value of
    the 'Final_Score' row, the metrics dataframe (row per metric, with an
    added 'Time' row holding the elapsed seconds), the trained model, and the
    integrated dataset returned by ``batch_correction``.

  Raises:
    ValueError: if no dataset name is passed and no module-level
      ``dataset_name`` exists.
  """
  # Older call sites rely on a global `dataset_name`; keep that working, but
  # fail with a clear message instead of a bare NameError when it is missing.
  if dataset_name is None:
    dataset_name = globals().get('dataset_name')
  if dataset_name is None:
    raise ValueError(
        "dataset_name was not passed and no module-level 'dataset_name' "
        "is defined")

  label_key = dataset_params['label_key']
  batch_key = dataset_params['batch_key']
  organism = dataset_params['organism']

  # --- Data Integration ---
  lr = training_hParams['LR']
  rec_l = training_hParams['Rec_Loss_W']
  l_dim = training_hParams['L_Dim']

  # The timer covers both model construction and training/integration.
  start_time = time.time()
  model = ABC(orig_dataset, batch_key, label_key, latent_dim=l_dim)

  # train the model and integrate the dataset
  integrated = model.batch_correction(data_name=dataset_name,
                                      base_LR=lr,
                                      recon_loss_w=rec_l,
                                      )
  elapsed_time = time.time() - start_time
  minutes, seconds = divmod(elapsed_time, 60)

  print("Integrated: ", dataset_name)
  print(f"Duration: {minutes} minutes and {seconds} seconds")

  # --- Integration Evaluation ---
  data_type = 'ATAC' if dataset_params['ATAC'] else 'RNA'

  metrics = evaluate_integration(orig_dataset, integrated, eval_params,
                                 batch_key=batch_key,
                                 label_key=label_key,
                                 data_type=data_type,
                                 organism=organism,
                                 )

  # store execution time (seconds) as an extra row
  metrics.loc['Time'] = elapsed_time

  if verbos:
    print("integration scores:")
    print(metrics)

  # Use explicit positional access: `series[0]` is only positional as a
  # deprecated fallback when the column labels are not integers.
  final_score = metrics.loc['Final_Score'].iloc[0]

  return final_score, metrics, model, integrated


# Datasets used for the hyperparameter optimisation, keyed by dataset name.
# NOTE(review): the per-dataset 'subsample' entry is not read anywhere in the
# visible code.
datasets = dict(
    small_atac_gene_activity=dict(
        label_key='final_cell_label',
        batch_key='batchname',
        organism='mouse',
        subsample=1,
        ATAC=True,
    ),
    human_pancreas_norm_complexBatch=dict(
        label_key='celltype',
        batch_key='tech',
        organism='human',
        ATAC=False,
        subsample=1,
    ),
)

# Evaluation metrics to calculate (all switched on here).
eval_params = dict(
    silhouette_=True,
    nmi_=True,
    ari_=True,
    cell_cycle_=True,       # original note: turns to false for ATAC
    isolated_labels_=True,
    hvg_score_=True,
    graph_conn_=True,
    lisi_graph_=True,
    trajectory_=True,       # original note: turns to false if pseudotime info not present
)

# Hyperparameter grid; the key order fixes the order in which combinations
# are generated.
# NOTE(review): LR=None presumably lets ABC pick its own default — confirm.
param_grid = dict(
    LR=[None, 0.001, 0.0001, 0.0002],
    L_Dim=[32, 64, 128, 256, 512, 1024],
    Rec_Loss_W=[0.2, 0.4, 0.6, 0.8],
)

# Root folder for the input datasets; results and model weights go into two
# sub-folders of it.
base_path = '/content/drive/MyDrive/Colab Notebooks/integrationDatasets/'
metrics_path, weights_path = (
    os.path.join(base_path, sub_dir)
    for sub_dir in ('final_metrics', 'trainedModels')
)

# Create both output folders (including parents) unless they already exist.
os.makedirs(metrics_path, exist_ok=True)
os.makedirs(weights_path, exist_ok=True)

# Per-dataset record of the parameter combinations that failed; every dataset
# starts with its own empty list.
failed_params = {
    name: []
    for name in ('small_atac_gene_activity',
                 'human_pancreas_norm_complexBatch')
}

# Set to True to run on a small random subset of each dataset.
subsample = False
best_score = 0

# The grid is identical for every dataset, so fix the parameter order once.
all_names = list(param_grid.keys())

# Run hyperparameter grid search for each dataset.
# NOTE: `dataset_name` must stay a module-level loop variable, because
# integrate_and_score picks the dataset name up from the global namespace.
for dataset_name in datasets.keys():

  # get dataset parameters
  dataset_params = datasets[dataset_name]

  # The best model is stored per dataset, so the running best score has to
  # start over for each one; otherwise a later dataset would only be saved
  # when it beats the best score of an earlier dataset.
  best_score = 0

  # read data
  inPath = os.path.join(base_path, f"{dataset_name}_hvg.h5ad")
  adata = sc.read(inPath)

  if subsample:
    adata = sc.pp.subsample(adata, n_obs=1001, random_state=1, copy=True)

  # Generate all combinations
  combinations = it.product(*(param_grid[name] for name in all_names))

  # For each combination, train and evaluate the model
  for combination in combinations:

    params = dict(zip(all_names, combination))

    # display current combination and dataset name
    print('-' * 100)
    print(f'Using params: {params} with dataset: {dataset_name}')

    # run integration and score results
    try:
      score, metrics, model, integrated = integrate_and_score(
          adata,
          eval_params=eval_params,
          dataset_params=dataset_params,
          training_hParams=params)
    except Exception:
      # Best effort: log the failure, remember the combination, keep going.
      print('-' * 100)
      print(f'Failed to integrate or evaluate params: {params} with dataset: {dataset_name}')
      traceback.print_exc(limit=None)
      # setdefault keeps the handler itself from raising KeyError when a
      # dataset has no pre-created entry in failed_params.
      failed_params.setdefault(dataset_name, []).append(params)
      print("moving on...")
      print('-' * 100)

      continue

    # display parameters and final score
    print(f'Params: {params} Final Score: {score}')

    # process and save metrics: one CSV row per parameter combination
    metrics_file = os.path.join(metrics_path, f'{dataset_name}_hyperParamOpt.csv')
    metrics = metrics.T
    metrics = metrics.dropna(axis=1, how='all')

    # Rename the index of the transposed DataFrame
    metrics.index = [str(params)]

    # NOTE(review): rows are appended without a header, so a run in which a
    # different set of metrics is NaN (and therefore dropped above) would be
    # written under misaligned columns.
    if os.path.exists(metrics_file):
      metrics.to_csv(metrics_file, mode='a', index=True, header=False)
    else:
      metrics.to_csv(metrics_file, index=True)

    # save trained model weights and update best score
    if score > best_score:
      best_score = score
      print(f'new best score: {score}')
      print(f'for combination: {params}')

      # save model weights
      # Whether save_weights creates missing parent folders is not visible
      # here, so make sure the per-dataset folder exists first.
      Path(weights_path, dataset_name).mkdir(parents=True, exist_ok=True)
      weights_filepath = os.path.join(weights_path, dataset_name,
                                      f'{dataset_name}_bestModel')
      model.save_weights(weights_filepath, overwrite=True)


print("Failed runs:")
print(failed_params)
