# Model hyperparameter tuning with scVI

<div class="alert alert-warning">

Warning

`scvi.autotune` development is still in progress. The API is subject to change.

</div>

Finding an effective set of model hyperparameters (e.g. learning rate, number of hidden layers) is an important part of training a model, because its performance can depend heavily on these non-trainable parameters. Manually tuning a model usually means picking a set of hyperparameters to search over and then evaluating different configurations on a validation set for a desired metric. This process can be time-consuming and often requires prior intuition about a given model and dataset pair, which is not always available.

In this tutorial, we show how to use `scvi`'s [`autotune`](https://docs.scvi-tools.org/en/latest/api/user.html#model-hyperparameter-autotuning) module, which allows us to automatically find a good set of model hyperparameters using [Ray Tune](https://docs.ray.io/en/latest/tune/index.html). We will use `SCVI` and a subsample of the [heart cell atlas](https://www.heartcellatlas.org/#DataSources) for the task of batch correction, but the principles outlined here can be applied to any model and dataset. In particular, we will go through the following steps:

1. Installing required packages
1. Loading and preprocessing the dataset
1. Defining the tuner and discovering hyperparameters
1. Running the tuner
1. Visualizing latent spaces
1. Optional: Monitoring progress with TensorBoard
1. Optional: Tuning over integration metrics with `scib-metrics`

## 1. Installing required packages

In [None]:
!pip install --quiet hyperopt
!pip install --quiet "ray[tune]"
!pip install --quiet scvi-colab
from scvi_colab import install

install()

In [None]:
import scanpy as sc
import scvi
from ray import tune
from scvi import autotune

## 2. Loading and preprocessing the dataset

In [None]:
adata = scvi.data.heart_cell_atlas_subsampled()
adata

The only preprocessing step we perform here is to subset the dataset to the 2,000 most highly variable genes using `scanpy`.

In [None]:
sc.pp.highly_variable_genes(adata, n_top_genes=2000, flavor="seurat_v3", subset=True)
adata

## 3. Defining the tuner and discovering hyperparameters

The first part of our workflow is the same as in a regular `scvi-tools` analysis: we start with our desired model class and register our dataset with it using `setup_anndata`. All datasets must be registered prior to hyperparameter tuning.

In [None]:
model_cls = scvi.model.SCVI
model_cls.setup_anndata(adata)

Next, our main entry point in the `autotune` module is `ModelTuner`, which is a wrapper around [`ray.tune.Tuner`](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tuner) with additional functionality specific to `scvi-tools`. We can define a new `ModelTuner` by providing it with our model class.

In [None]:
scvi_tuner = autotune.ModelTuner(model_cls)

`ModelTuner` will scan `SCVI` for all hyperparameters that can be tuned; these can be viewed by calling `info()`. By default, this method will display three tables:

1. **Tunable hyperparameters**: The names of hyperparameters that can be tuned, the type of parameter they are, their default values, and the internal classes they are defined in.
1. **Available metrics**: The metrics that can be used to evaluate the performance of the model. One of these must be provided when running the tuner.
1. **Default search space**: The default search space for the model class, which will be used if no search space is provided by the user.

In [None]:
scvi_tuner.info()

## 4. Running the tuner

Now that we know which hyperparameters are available to us, we can define a search space using the [search space API](https://docs.ray.io/en/latest/tune/api_docs/search_space.html) in `ray.tune`. For this tutorial, we choose a simple search space with two model hyperparameters (`n_hidden` and `n_layers`) and one training plan hyperparameter (`lr`). These can all be combined into a single dictionary that we pass into the `fit` method.

In [None]:
search_space = {
    "n_hidden": tune.choice([64, 128, 256]),
    "n_layers": tune.choice([1, 2, 3]),
    "lr": tune.loguniform(1e-4, 1e-2),
}
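One way to read the search space above is as a recipe for drawing configurations: each categorical entry picks one of its options, and the learning rate is drawn so that its logarithm is uniform between the two bounds. The following standard-library sketch mimics that behavior for illustration only. It does not call Ray Tune, and the helper names `sample_choice`, `sample_loguniform`, and `sample_config` are hypothetical.

```python
import math
import random


def sample_choice(options, rng):
    """Pick one option uniformly at random (the idea behind a categorical choice)."""
    return rng.choice(options)


def sample_loguniform(low, high, rng):
    """Draw a value whose logarithm is uniform on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


def sample_config(rng):
    """Draw one configuration mirroring the tutorial's search space."""
    return {
        "n_hidden": sample_choice([64, 128, 256], rng),
        "n_layers": sample_choice([1, 2, 3], rng),
        "lr": sample_loguniform(1e-4, 1e-2, rng),
    }


rng = random.Random(0)
configs = [sample_config(rng) for _ in range(1000)]

# On a log scale, 1e-3 is the midpoint of [1e-4, 1e-2], so roughly half
# of the sampled learning rates should fall below it.
share_below_1e3 = sum(c["lr"] < 1e-3 for c in configs) / len(configs)
print(configs[0])
print(f"share of learning rates below 1e-3: {share_below_1e3:.2f}")
```

A log-uniform range is a common choice for learning rates because it gives each order of magnitude the same sampling weight, whereas a plain uniform range over the same bounds would place most samples near the upper bound.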

There are a few more arguments we should be aware of before fitting the tuner:

- `num_samples`: The total number of hyperparameter configurations to sample from `search_space`. This is the total number of models that will be trained.
- `max_epochs`: The maximum number of epochs to train each model for.
- `resources`: A dictionary of maximum resources to allocate for the whole experiment. This allows us to run concurrent trials on limited hardware.
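To make the roles of `num_samples` and `max_epochs` concrete, here is a toy random search written with the standard library only. It is a simplified sketch and not how the tuner works internally: the real tuner may use a smarter search algorithm and stop unpromising trials early. The function `toy_validation_loss` is a hypothetical stand-in for training a model, and the loss values it returns are made up.

```python
import math
import random


def toy_validation_loss(config, max_epochs):
    """Hypothetical stand-in for training; returns a made-up loss, not a real metric."""
    # Pretend the loss is lowest near lr = 1e-3 and with two layers,
    # and that a larger epoch budget helps slightly.
    lr_penalty = (math.log10(config["lr"]) + 3.0) ** 2
    size_penalty = 0.1 * abs(config["n_layers"] - 2)
    return lr_penalty + size_penalty + 1.0 / max_epochs


def random_search(num_samples, max_epochs, seed=0):
    """Sample `num_samples` configurations and score each one exactly once."""
    rng = random.Random(seed)
    trials = []
    for _ in range(num_samples):
        config = {
            "n_layers": rng.choice([1, 2, 3]),
            "lr": 10 ** rng.uniform(-4, -2),  # log-uniform between 1e-4 and 1e-2
        }
        trials.append((toy_validation_loss(config, max_epochs), config))
    # A lower validation loss is better, so the best trial is the minimum.
    best_loss, best_config = min(trials, key=lambda trial: trial[0])
    return trials, best_loss, best_config


trials, best_loss, best_config = random_search(num_samples=5, max_epochs=100)
print(f"scored {len(trials)} toy models; best loss {best_loss:.3f} with {best_config}")
```

In this sketch the cost of the search grows linearly with `num_samples`, and `max_epochs` bounds the cost of each individual trial, which matches how the two arguments are described above.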

In [None]:
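results = scvi_tuner.fit(
    adata,
    metric="validation_loss",  # metric used to compare trials
    search_space=search_space,
    num_samples=5,  # number of configurations (models) to train
    max_epochs=100,  # upper bound on training epochs per model
    resources={"cpu": 20, "gpu": 1},  # adjust to the hardware available
)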

## 5. Visualizing latent spaces

In [None]:
# work in progress :)

## 6. Optional: Monitoring progress with TensorBoard

In [1]:
# work in progress :)

## 7. Optional: Tuning over integration metrics with `scib-metrics`

In [2]:
# work in progress :)