# POC of hyperparameter tuning with Optuna on scPoli
- using `envs/scpoli.yaml` as the base (annoyingly with Python 3.9)
- adding `Optuna` and `scib-metrics` on top of that


Highly specific installs:
```
pip install scib-metrics
pip install optuna
pip install scvi-tools==1.1.6.post2 numpyro==0.15.3 scarches==0.6.1 

# to speed up benchmarker
pip install jax-cuda12-pjrt jax-cuda12-plugin 

# to visualise trials
pip install plotly scikit-learn nbformat 
```

In [2]:
from scarches.models.scpoli import scPoli

 captum (see https://github.com/pytorch/captum).


In [77]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scanpy as sc
import anndata as ad
import optuna
import pickle
from scib_metrics.benchmark import Benchmarker


Okay, that finally worked. Saving the environment so it can be recreated:

```
conda env export --name scpolituna --file ./scpolituna_conda.yaml --no-build
pip freeze > scpolituna_pip.txt
```

## Load data

In [56]:
# Parquet file holding the profiles to integrate (hard-coded absolute path).
data_path = "/home/icb/tim.treis/projects/broad_integrate/2023_Arevalo_BatchCorrection/outputs/scenario_7/mad_int_featselect.parquet"

data = pd.read_parquet(data_path)
# Columns whose name matches "Metadata" are annotations, not features.
metadata_cols = data.filter(regex="Metadata").columns

# Non-metadata columns become the matrix X; the metadata columns become obs.
adata = ad.AnnData(X=data.drop(metadata_cols, axis=1).values, obs=data[metadata_cols])
adata



AnnData object with n_obs × n_vars = 25329 × 1018
    obs: 'Metadata_Source', 'Metadata_Plate', 'Metadata_Well', 'Metadata_JCP2022', 'Metadata_InChIKey', 'Metadata_InChI', 'Metadata_Batch', 'Metadata_PlateType', 'Metadata_PertType', 'Metadata_Row', 'Metadata_Column', 'Metadata_Microscope'

## Define Optuna trials and evaluation logic

In [74]:
def scib_benchmark_embedding(
    adata: ad.AnnData,
    batch_key: str,
    label_key: str,
) -> np.ndarray:
    """Benchmark the embedding held in ``adata.X`` with scib-metrics.

    ``adata.X`` is registered as the ``"trial"`` entry of ``adata.obsm`` (this
    mutates the passed object) and handed to ``Benchmarker`` as the only
    embedding to evaluate. Everything written to ``sys.stdout`` while the
    benchmark runs is discarded.

    Parameters
    ----------
    adata
        AnnData whose ``X`` is the embedding to score.
    batch_key
        Name of the ``obs`` column passed to ``Benchmarker`` as ``batch_key``.
    label_key
        Name of the ``obs`` column passed to ``Benchmarker`` as ``label_key``.

    Returns
    -------
    Array with two entries, in this order: the "Batch correction" and
    "Bio conservation" values of the unscaled results row for ``"trial"``.
    """
    # Function-scope import so this cell does not depend on the import cell.
    from contextlib import redirect_stdout

    adata.obsm["trial"] = adata.X

    # redirect_stdout restores whatever stream was active before (not
    # sys.__stdout__), also when the benchmark raises, so an enclosing
    # redirect or the notebook's own stream is never clobbered. The `with`
    # on the file guarantees the devnull handle is closed.
    with open(os.devnull, "w") as devnull, redirect_stdout(devnull):
        bm = Benchmarker(
            adata=adata,
            batch_key=batch_key,
            label_key=label_key,
            embedding_obsm_keys=["trial"],
        )
        bm.benchmark()
        df = bm.get_results(min_max_scale=False)

    return df.loc["trial"][["Batch correction", "Bio conservation"]].values


def objective(trial, adata, batch_key, label_key):
    """Optuna objective: train one scPoli model and score its latent space.

    Samples the architecture and training hyperparameters from ``trial``,
    trains scPoli on ``adata``, extracts the latent representation and
    benchmarks it with ``scib_benchmark_embedding``. Everything written to
    ``sys.stdout`` during the trial is discarded.

    Parameters
    ----------
    trial
        Optuna trial used to sample hyperparameters.
    adata
        AnnData the model is trained on and embedded from.
    batch_key
        ``obs`` column used as scPoli's ``condition_keys`` and as the
        benchmark's batch key.
    label_key
        ``obs`` column used as scPoli's ``cell_type_keys`` and as the
        benchmark's label key.

    Returns
    -------
    Tuple ``(batch, bio)`` as returned by ``scib_benchmark_embedding``.
    """
    # Function-scope import so this cell does not depend on the import cell.
    from contextlib import redirect_stdout

    # redirect_stdout puts back the stream that was active on entry, also
    # when training is interrupted or fails, and regardless of what nested
    # calls did to sys.stdout. The previous manual swap could end up closing
    # sys.__stdout__ and left stdout silenced after an exception.
    with open(os.devnull, "w") as devnull, redirect_stdout(devnull):
        # Optimize hidden layer sizes
        num_layers = trial.suggest_int("num_layers", 1, 3)  # 1 to 3 layers
        hidden_layer_sizes = [
            trial.suggest_int(f"layer_{i}_size", 32, 512, step=32)
            for i in range(num_layers)
        ]

        # Optimize latent dimensions and embedding size
        latent_dim = trial.suggest_int("latent_dim", 16, 128, step=16)
        embedding_dims = trial.suggest_int("embedding_dims", 2, 20, step=1)

        # Optimize pretraining to training epoch ratio
        total_epochs = 100
        pretrain_to_train_ratio = trial.suggest_float(
            "pretrain_to_train_ratio", 0.1, 0.9, step=0.1
        )
        # round() instead of truncation: a stepped float such as
        # 0.7999999999999999 must still map to 80 epochs, not 79.
        n_pretrain_epochs = int(round(total_epochs * pretrain_to_train_ratio))
        n_train_epochs = total_epochs - n_pretrain_epochs

        # Optimize other parameters
        alpha_epoch_anneal = trial.suggest_int("alpha_epoch_anneal", 100, 1000, step=100)
        eta = trial.suggest_float("eta", 0.1, 1.0, step=0.1)

        model = scPoli(
            adata=adata,
            condition_keys=batch_key,
            cell_type_keys=label_key,
            hidden_layer_sizes=hidden_layer_sizes,
            latent_dim=latent_dim,
            embedding_dims=embedding_dims,
            recon_loss="mse",
        )

        # NOTE(review): scPoli's `n_epochs` looks like the TOTAL epoch count
        # with `pretraining_epochs` being a part of it; if so, `total_epochs`
        # should be passed here instead of `n_train_epochs` — confirm against
        # the scArches documentation before changing.
        model.train(
            n_epochs=n_train_epochs,
            pretraining_epochs=n_pretrain_epochs,
            use_early_stopping=True,
            alpha_epoch_anneal=alpha_epoch_anneal,
            eta=eta,
        )

        model.model.eval()
        vals = model.get_latent(adata, mean=True)
        features = [f"scpoli_{i}" for i in range(vals.shape[1])]

        # Wrap the latent space in a fresh AnnData so the benchmark scores
        # the embedding (as X) against the original obs annotations.
        integrated_adata = ad.AnnData(
            X=pd.DataFrame(vals, columns=features, index=adata.obs_names),
            obs=adata.obs,
        )

        batch, bio = scib_benchmark_embedding(
            adata=integrated_adata,
            batch_key=batch_key,
            label_key=label_key,
        )

    return batch, bio

In [75]:
# obs columns used as the batch covariate and as the label, respectively.
batch_key = "Metadata_Source"
label_key = "Metadata_JCP2022"

# Two-objective study: both values returned by `objective` (batch, bio) are maximized.
study = optuna.create_study(directions=["maximize", "maximize"])
# Every trial gets its own copy of `adata`.
study.optimize(lambda trial: objective(trial, adata.copy(), batch_key, label_key), n_trials=50)

[I 2024-12-02 20:05:51,398] A new study created in memory with name: no-name-e5131b3b-9827-41dc-afc5-dafc0b316028
INFO:scarches.trainers.scpoli.trainer:GPU available: True, GPU used: True
[W 2024-12-02 20:06:31,885] Trial 0 failed with parameters: {'num_layers': 1, 'layer_0_size': 352, 'latent_dim': 48, 'embedding_dims': 20, 'pretrain_to_train_ratio': 0.1, 'alpha_epoch_anneal': 300, 'eta': 0.2} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "/home/icb/tim.treis/miniconda/envs/scverse/envs/scpolituna/lib/python3.9/site-packages/optuna/study/_optimize.py", line 197, in _run_trial
    value_or_values = func(trial)
  File "/localscratch/tim.treis/ipykernel_3126830/4054591727.py", line 5, in <lambda>
    study.optimize(lambda trial: objective(trial, adata.copy(), batch_key, label_key), n_trials=50)
  File "/localscratch/tim.treis/ipykernel_3126830/3460070615.py", line 60, in objective
    model.train(
  File "/home/icb/tim.treis/miniconda/envs

KeyboardInterrupt: 

In [72]:
# Axis names follow the order of the values returned by `objective`: (batch, bio).
optuna.visualization.plot_pareto_front(study, target_names=["Batch", "Bio"])

In [78]:
# Display the study object.
study

<optuna.study.study.Study at 0x7fd0311f10d0>

In [79]:
# Persist the study object to disk as a pickle.
with open("./optuna_study.pkl", "wb") as output_file:
    pickle.dump(study, output_file)