# POC of hyperparameter tuning with Optuna on scPoli
- using `envs/scpoli.yaml` as the base (annoyingly with Python 3.9)
- adding `Optuna` and `scib-metrics` on top of that


Highly specific installs:
```
pip install scib-metrics
pip install optuna
pip install scvi-tools==1.1.6.post2 numpyro==0.15.3 scarches==0.6.1 

# to speed up benchmarker
pip install jax-cuda12-pjrt jax-cuda12-plugin 

# to visualise trials
pip install plotly scikit-learn nbformat 
```

In [1]:
from scarches.models.scpoli import scPoli

  from .autonotebook import tqdm as notebook_tqdm
 captum (see https://github.com/pytorch/captum).


In [2]:
import os
import sys
import pandas as pd
import anndata as ad
import optuna
import pickle
from scib_metrics.benchmark import Benchmarker

Okay, the imports finally worked. Saving that environment so it can be recreated:

```
conda env export --name scpolituna --file ./scpolituna_conda.yaml --no-build
pip freeze > scpolituna_pip.txt
```

## Load data

In [3]:
# Input parquet file holding both feature columns and "Metadata*" columns.
data_path = "/home/icb/tim.treis/projects/broad_integrate/2023_Arevalo_BatchCorrection/outputs/scenario_7/mad_int_featselect.parquet"

data = pd.read_parquet(data_path)

# Every column whose name matches the regex "Metadata" is an annotation;
# all remaining columns are treated as features.
metadata_cols = data.columns[data.columns.str.contains("Metadata")]

# Features go into X, annotations into obs.
feature_matrix = data.drop(columns=metadata_cols).to_numpy()
adata = ad.AnnData(X=feature_matrix, obs=data[metadata_cols])
adata



AnnData object with n_obs × n_vars = 25329 × 1018
    obs: 'Metadata_Source', 'Metadata_Plate', 'Metadata_Well', 'Metadata_JCP2022', 'Metadata_InChIKey', 'Metadata_InChI', 'Metadata_Batch', 'Metadata_PlateType', 'Metadata_PertType', 'Metadata_Row', 'Metadata_Column', 'Metadata_Microscope'

## Define Optuna trials and evaluation logic

In [5]:
def scib_benchmark_embedding(
    adata: ad.AnnData,
    batch_key: str,
    label_key: str,
) -> "numpy.ndarray":
    """Benchmark ``adata.X`` as an embedding with scib-metrics.

    ``adata.X`` is registered under ``adata.obsm["trial"]`` (the passed object
    is modified in place) and handed to ``Benchmarker`` as the only embedding.

    Parameters
    ----------
    adata
        AnnData whose ``X`` holds the embedding to score.
    batch_key
        Name of the ``obs`` column passed to ``Benchmarker`` as ``batch_key``.
    label_key
        Name of the ``obs`` column passed to ``Benchmarker`` as ``label_key``.

    Returns
    -------
    Array with two entries taken from the unscaled results row ``"trial"``:
    the ``"Batch correction"`` score followed by the ``"Bio conservation"``
    score.
    """
    # Local import: the notebook's import cell does not provide contextlib.
    import contextlib

    adata.obsm["trial"] = adata.X

    # Silence stdout only for the duration of the benchmark. The context
    # managers close the devnull handle and put back whatever stream was
    # active before (in a notebook that is the kernel's stream, not
    # sys.__stdout__), even if the benchmark raises. stderr is left untouched.
    with open(os.devnull, "w") as devnull, contextlib.redirect_stdout(devnull):
        bm = Benchmarker(
            adata=adata,
            batch_key=batch_key,
            label_key=label_key,
            embedding_obsm_keys=["trial"],
        )
        bm.benchmark()
        df = bm.get_results(min_max_scale=False)

    return df.loc["trial"][["Batch correction", "Bio conservation"]].values


def objective(trial, adata, batch_key, label_key):
    """Optuna objective: train one scPoli model and score its latent space.

    Hyperparameters are sampled from ``trial``, a scPoli model is trained on
    ``adata``, its latent representation is wrapped in a new AnnData and
    scored with ``scib_benchmark_embedding``.

    Parameters
    ----------
    trial
        Optuna trial used to sample the hyperparameters.
    adata
        AnnData to train on; also the source of ``obs`` for the scored object.
    batch_key
        ``obs`` column used as scPoli condition key and as benchmark batch key.
    label_key
        ``obs`` column used as scPoli cell-type key and as benchmark label key.

    Returns
    -------
    tuple
        ``(batch, bio)`` as returned by ``scib_benchmark_embedding``.
    """
    # Local import: the notebook's import cell does not provide contextlib.
    import contextlib

    # Silence stdout for the whole trial. The context managers close the
    # devnull handle and restore the previously active stream on exit, also
    # when training or benchmarking raises, so a failed trial cannot leave the
    # notebook without output.
    with open(os.devnull, "w") as devnull, contextlib.redirect_stdout(devnull):
        # Optimize hidden layer sizes
        num_layers = trial.suggest_int("num_layers", 1, 3)  # 1 to 3 layers
        hidden_layer_sizes = [
            trial.suggest_int(f"layer_{i}_size", 32, 512, step=32) for i in range(num_layers)
        ]

        # Optimize latent dimensions and embedding size
        latent_dim = trial.suggest_int("latent_dim", 16, 128, step=16)
        embedding_dims = trial.suggest_int("embedding_dims", 2, 20, step=1)

        # Optimize the split of a fixed epoch budget: despite its name,
        # "pretrain_to_train_ratio" is the fraction of total_epochs spent on
        # pretraining; the remainder is passed as n_epochs.
        total_epochs = 100
        pretrain_to_train_ratio = trial.suggest_float("pretrain_to_train_ratio", 0.1, 0.9, step=0.1)
        # round() before int() so float noise in the sampled ratio can never
        # truncate the epoch count one below the intended value.
        n_pretrain_epochs = int(round(total_epochs * pretrain_to_train_ratio))
        n_train_epochs = total_epochs - n_pretrain_epochs

        # Optimize other parameters
        alpha_epoch_anneal = trial.suggest_int("alpha_epoch_anneal", 100, 1000, step=100)
        eta = trial.suggest_float("eta", 0.1, 1.0, step=0.1)

        model = scPoli(
            adata=adata,
            condition_keys=batch_key,
            cell_type_keys=label_key,
            hidden_layer_sizes=hidden_layer_sizes,
            latent_dim=latent_dim,
            embedding_dims=embedding_dims,
            recon_loss="mse",
        )

        model.train(
            n_epochs=n_train_epochs,
            pretraining_epochs=n_pretrain_epochs,
            use_early_stopping=True,
            alpha_epoch_anneal=alpha_epoch_anneal,
            eta=eta,
        )

        # Switch the underlying network to eval mode before extracting the latent.
        model.model.eval()
        vals = model.get_latent(adata, mean=True)
        features = [f"scpoli_{i}" for i in range(vals.shape[1])]

        # New AnnData whose X is the latent space, keeping the original obs.
        integrated_adata = ad.AnnData(
            X=pd.DataFrame(vals, columns=features, index=adata.obs_names),
            obs=adata.obs,
        )

        batch, bio = scib_benchmark_embedding(
            adata=integrated_adata,
            batch_key=batch_key,
            label_key=label_key,
        )

    return batch, bio

# obs columns used as batch covariate and as biological label.
batch_key = "Metadata_Source"
label_key = "Metadata_JCP2022"


def _run_trial(trial):
    """Evaluate one trial on a fresh copy of ``adata`` so trials do not share state."""
    return objective(trial, adata.copy(), batch_key, label_key)


# Two objectives (batch, bio), both maximised.
study = optuna.create_study(directions=["maximize", "maximize"])
study.optimize(_run_trial, n_trials=1)

[I 2024-12-02 20:16:12,292] A new study created in memory with name: no-name-d68f48d1-2b40-44dc-9769-c2a6cb957c2d
INFO:scarches.trainers.scpoli.trainer:GPU available: True, GPU used: True
  c = torch.tensor(label_tensor, device=device).T
Computing neighbors: 100%|██████████| 1/1 [00:25<00:00, 25.88s/it]
Embeddings: 100%|[32m██████████[0m| 1/1 [00:39<00:00, 39.16s/it]
[I 2024-12-02 20:31:44,845] Trial 0 finished with values: [0.2830856572133941, 0.5097253965108155] and parameters: {'num_layers': 1, 'layer_0_size': 224, 'latent_dim': 80, 'embedding_dims': 11, 'pretrain_to_train_ratio': 0.1, 'alpha_epoch_anneal': 900, 'eta': 0.8}.


## POC works — now run it as a headless script with enough trials

In [3]:
# Load the pickled study (presumably written by the headless tuning run).
# NOTE: pickle executes arbitrary code on load; only open files you trust.
with open("./optuna_study.pkl", mode="rb") as study_file:
    study = pickle.loads(study_file.read())

In [6]:
# Pareto front of the two objectives; axis names follow the order in which
# `objective` returns them (batch first, then bio).
pareto_fig = optuna.visualization.plot_pareto_front(
    study,
    target_names=["Batch", "Bio"],
)
pareto_fig

### Convert to table

In [7]:
# Weights of the combined score: total = BIO_WEIGHT * bio + BATCH_WEIGHT * batch.
BIO_WEIGHT = 0.6
BATCH_WEIGHT = 0.4

trials = study.trials

# One record per trial; collected in a list and turned into a frame once.
rows = []
for trial in trials:
    # `trial.values` is None for trials that did not produce a result
    # (e.g. failed ones). Report NaN scores for those instead of crashing on
    # the tuple unpacking, so the `state` column stays informative.
    if trial.values is None:
        batch, bio = float("nan"), float("nan")
    else:
        batch, bio = trial.values

    rows.append(
        {
            "state": trial.state.name,
            "batch": batch,
            "bio": bio,
            "total": BIO_WEIGHT * bio + BATCH_WEIGHT * batch,
            **trial.params,  # hyperparameters
        }
    )

# NOTE: this rebinds `data`, which the data-loading cell used for the raw table.
data = pd.DataFrame(rows)

# Fixed columns first, then the hyperparameters in alphabetical order.
custom_order = ["state", "batch", "bio", "total"]
remaining_columns = sorted(col for col in data.columns if col not in custom_order)
final_column_order = custom_order + remaining_columns

# reindex (rather than plain column selection) also works for a study without
# any trials, where the frame would otherwise have no "total" column to sort by.
data = data.reindex(columns=final_column_order).sort_values("total", ascending=False)

data

Unnamed: 0,state,batch,bio,total,alpha_epoch_anneal,embedding_dims,eta,latent_dim,layer_0_size,layer_1_size,layer_2_size,num_layers,pretrain_to_train_ratio
26,COMPLETE,0.369717,0.505063,0.450925,600,5,0.6,32,416,96.0,384.0,3,0.5
0,COMPLETE,0.348516,0.51648,0.449294,300,13,0.5,32,480,288.0,,2,0.5
11,COMPLETE,0.331015,0.52205,0.445636,700,4,0.8,128,256,,,1,0.5
30,COMPLETE,0.349754,0.508287,0.444874,800,19,0.8,96,384,256.0,,2,0.4
38,COMPLETE,0.333959,0.518329,0.444581,600,7,0.9,80,160,,,1,0.5
42,COMPLETE,0.350704,0.506698,0.444301,200,3,0.7,32,64,480.0,,2,0.5
45,COMPLETE,0.352857,0.504641,0.443927,300,8,0.8,32,416,512.0,448.0,3,0.1
12,COMPLETE,0.341758,0.511121,0.443376,800,8,1.0,80,320,480.0,,2,0.1
47,COMPLETE,0.337732,0.504015,0.437502,900,10,0.7,32,128,416.0,,2,0.4
28,COMPLETE,0.330649,0.505752,0.435711,400,8,0.7,32,384,480.0,,2,0.1


In [18]:
# Same ranking as the hand-built table, derived from Optuna's own trials
# dataframe: name the two objective columns, add the weighted total and sort
# best-first.
df = (
    study.trials_dataframe()
    .rename(columns={"values_0": "batch", "values_1": "bio"})
    .assign(total=lambda frame: 0.6 * frame["bio"] + 0.4 * frame["batch"])
    .sort_values("total", ascending=False)
)
df

Unnamed: 0,number,batch,bio,datetime_start,datetime_complete,duration,params_alpha_epoch_anneal,params_embedding_dims,params_eta,params_latent_dim,params_layer_0_size,params_layer_1_size,params_layer_2_size,params_num_layers,params_pretrain_to_train_ratio,system_attrs_nsga2:generation,state,total
26,26,0.369717,0.505063,2024-12-02 22:59:16.571127,2024-12-02 23:00:58.440468,0 days 00:01:41.869341,600,5,0.6,32,416,96.0,384.0,3,0.5,0,COMPLETE,0.450925
0,0,0.348516,0.51648,2024-12-02 21:18:13.889310,2024-12-02 21:21:52.457649,0 days 00:03:38.568339,300,13,0.5,32,480,288.0,,2,0.5,0,COMPLETE,0.449294
11,11,0.331015,0.52205,2024-12-02 21:53:36.683358,2024-12-02 21:55:15.499466,0 days 00:01:38.816108,700,4,0.8,128,256,,,1,0.5,0,COMPLETE,0.445636
30,30,0.349754,0.508287,2024-12-02 23:16:30.105069,2024-12-02 23:20:15.210932,0 days 00:03:45.105863,800,19,0.8,96,384,256.0,,2,0.4,0,COMPLETE,0.444874
38,38,0.333959,0.518329,2024-12-02 23:45:51.658629,2024-12-02 23:47:27.366824,0 days 00:01:35.708195,600,7,0.9,80,160,,,1,0.5,0,COMPLETE,0.444581
42,42,0.350704,0.506698,2024-12-03 00:00:56.923245,2024-12-03 00:02:37.577585,0 days 00:01:40.654340,200,3,0.7,32,64,480.0,,2,0.5,0,COMPLETE,0.444301
45,45,0.352857,0.504641,2024-12-03 00:11:33.527766,2024-12-03 00:22:22.479232,0 days 00:10:48.951466,300,8,0.8,32,416,512.0,448.0,3,0.1,0,COMPLETE,0.443927
12,12,0.341758,0.511121,2024-12-02 21:55:15.500483,2024-12-02 22:05:23.076041,0 days 00:10:07.575558,800,8,1.0,80,320,480.0,,2,0.1,0,COMPLETE,0.443376
47,47,0.337732,0.504015,2024-12-03 00:30:51.971004,2024-12-03 00:34:47.506981,0 days 00:03:55.535977,900,10,0.7,32,128,416.0,,2,0.4,0,COMPLETE,0.437502
28,28,0.330649,0.505752,2024-12-02 23:04:29.309385,2024-12-02 23:15:14.648358,0 days 00:10:45.338973,400,8,0.7,32,384,480.0,,2,0.1,0,COMPLETE,0.435711
