# POC of hyperparameter tuning with Optuna on scVI

In [1]:
import os
import sys
import pandas as pd
import optuna
import jax
import scvi

import anndata as ad
from scib_metrics.benchmark import Benchmarker, BioConservation, BatchCorrection

ImportError: cannot import name 'get_num_classes' from 'torchmetrics.utilities.data' (/home/icb/tim.treis/miniconda/envs/scverse/envs/myenv/lib/python3.11/site-packages/torchmetrics/utilities/data.py)

In [None]:
for pkg in [jax]:
    print(f'{pkg.__name__}: {pkg.__version__}')

jax: 0.4.35


## Load data

In [None]:
data_path = "/home/icb/tim.treis/projects/broad_integrate/2023_Arevalo_BatchCorrection/outputs/scenario_7/mad_int_featselect.parquet"

data = pd.read_parquet(data_path)
metadata_cols = data.filter(regex="Metadata").columns

adata = ad.AnnData(X=data.drop(metadata_cols, axis=1).values, obs=data[metadata_cols])
adata



AnnData object with n_obs × n_vars = 25329 × 1018
    obs: 'Metadata_Source', 'Metadata_Plate', 'Metadata_Well', 'Metadata_JCP2022', 'Metadata_InChIKey', 'Metadata_InChI', 'Metadata_Batch', 'Metadata_PlateType', 'Metadata_PertType', 'Metadata_Row', 'Metadata_Column', 'Metadata_Microscope'

## Define Optuna trials and evaluation logic

In [None]:
def scib_benchmark_embedding(
    adata: ad.AnnData,
    batch_key: str,
    label_key: str,
) -> float:
    adata.obsm["trial"] = adata.X

    # silence output
    sys.stdout = open(os.devnull, "w")

    bm = Benchmarker(
        adata=adata,
        batch_key=batch_key,
        label_key=label_key,
        embedding_obsm_keys=["trial"],
        bio_conservation_metrics=BioConservation(),
        batch_correction_metrics=BatchCorrection(),
    )
    bm.benchmark()
    df = bm.get_results(min_max_scale=False)

    # restore output
    sys.stdout.close()
    sys.stdout = sys.__stdout__

    return df.loc["trial"][["Batch correction", "Bio conservation"]].values


def objective(
    trial,
    adata: ad.AnnData,
    batch_key: str,
    label_key: str,
    smoketest: bool = True,
):
    # Silence output during training and evaluation
    sys.stdout = open(os.devnull, "w")

    # Optimize scVI hyperparameters:
    n_hidden = trial.suggest_int("n_hidden", 64, 256, step=64)
    n_latent = trial.suggest_int("n_latent", 10, 100)
    n_layers = trial.suggest_int("n_layers", 1, 3)
    dropout_rate = trial.suggest_float("dropout_rate", 0.0, 0.5)
    # learning_rate = trial.suggest_float("learning_rate", 1e-4, 1e-2, log=True)
    n_epochs = 2 if smoketest else trial.suggest_int("n_epochs", 50, 200)

    # Preprocess data: subtract the minimum to ensure non-negative values (if needed)
    min_value = adata.X.min()
    adata.X -= min_value

    scvi.model.SCVI.setup_anndata(
        adata,
        batch_key=batch_key,
        labels_key=label_key,
        categorical_covariate_keys=["Metadata_Batch"],
    )
    vae = scvi.model.SCVI(
        adata,
        n_hidden=n_hidden,
        n_latent=n_latent,
        n_layers=n_layers,
        dropout_rate=dropout_rate,
    )

    vae.train(
        max_epochs=n_epochs,
        early_stopping=True,
        early_stopping_monitor="elbo_validation",
    )

    vals = vae.get_latent_representation()
    features = [f"scvi_{i}" for i in range(vals.shape[1])]
    integrated_adata = ad.AnnData(
        X=pd.DataFrame(vals, columns=features, index=adata.obs_names),
        obs=adata.obs.copy()
    )

    batch_score, bio_score = scib_benchmark_embedding(
        adata=integrated_adata,
        batch_key=batch_key,
        label_key=label_key,
    )

    sys.stdout.close()
    sys.stdout = sys.__stdout__

    return batch_score, bio_score

batch_key = "Metadata_Source"
label_key = "Metadata_JCP2022"

study = optuna.create_study(directions=["maximize", "maximize"])
study.optimize(lambda trial: objective(trial, adata.copy(), batch_key, label_key), n_trials=2)

[I 2025-03-24 15:38:33,463] A new study created in memory with name: no-name-7e2f09da-cd03-4adb-a96b-00a0c1d5e332
  self.validate_field(adata)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
SLURM auto-requeueing enabled. Setting signal handlers.
/home/icb/tim.treis/miniconda/envs/scverse/envs/myenv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
/home/icb/tim.treis/miniconda/envs/scverse/envs/myenv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
  

RuntimeError: Unable to initialize backend 'cuda': INVALID_ARGUMENT: Unexpected PJRT_Client_Create_Args size: expected 88, got 72. The plugin is likely built with a later version than the framework. This plugin is built with PJRT API version 0.67. (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)

In [None]:
!pip install --upgrade jax jaxlib


In [None]:
study.best_trials

[]