# Tutorial: Population-Based Training

### In this tutorial, we'll show you how to leverage Population-based Training.

<img src="pbt.png" alt="PBT" width="600"/>

Tune is a scalable framework for model training and hyperparameter search with a focus on deep learning and deep reinforcement learning.

* **Code**: https://github.com/ray-project/ray/tree/master/python/ray/tune 
* **Examples**: https://github.com/ray-project/ray/tree/master/python/ray/tune/examples
* **Documentation**: http://ray.readthedocs.io/en/latest/tune.html
* **Mailing List**: https://groups.google.com/forum/#!forum/ray-dev

In [None]:
## If you are running on Google Colab, uncomment below to install the necessary dependencies 
## before beginning the exercise.

# print("Setting up colab environment")
# !pip uninstall -y -q pyarrow
# !pip install -q https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev5-cp36-cp36m-manylinux1_x86_64.whl
# !pip install -q ray[debug]

# # A hack to force the runtime to restart, needed to include the above dependencies.
# print("Done installing! Restarting via forced crash (this is not an issue).")
# import os
# os._exit(0)

In [None]:
import tensorflow as tf
try:
    tf.get_logger().setLevel('INFO')
except Exception as exc:
    print(exc)
import warnings
warnings.simplefilter("ignore")

import os
import numpy as np
import torch
import torch.optim as optim
from torchvision import datasets
from ray.tune.examples.mnist_pytorch import train, test, ConvNet, get_data_loaders

import ray
from ray import tune
from ray.tune import track
from ray.tune.schedulers import PopulationBasedTraining
from ray.tune.util import validate_save_restore

%matplotlib inline
import matplotlib.style as style
import matplotlib.pyplot as plt
style.use("ggplot")

datasets.MNIST("~/data", train=True, download=True)

# Setup Trainable


To utilize the PopulationBasedTraining Scheduler, we will have to use Tune's more extensive Class-based API. 

This API will allow Tune to take intermediate actions such as checkpointing and changing the hyperparameters in the middle of training.

``train()`` wraps ``_train()``.

A call to ``train()`` on a trainable executes one logical iteration of training. As a rule of thumb, one ``train()`` call should run long enough to amortize overheads (e.g. more than a few seconds) but short enough to report progress periodically (e.g. at most a few minutes).
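
To make the relationship between ``train()`` and ``_train()`` concrete, here is a minimal sketch in plain Python. It is not Tune's implementation: ``ToyTrainable`` and the accuracy formula are hypothetical stand-ins, and the only assumption is that the wrapper runs the user-defined step and attaches bookkeeping such as the iteration count.

```python
class ToyTrainable:
    """Hypothetical stand-in for a class-based trainable (not Tune's code)."""

    def __init__(self):
        self.iteration = 0

    def _train(self):
        # User-defined step: one logical unit of work that returns metrics.
        # The accuracy here is a made-up number that grows with the iteration.
        return {"mean_accuracy": min(1.0, 0.5 + 0.1 * self.iteration)}

    def train(self):
        # Wrapper: run the user step, then attach bookkeeping fields.
        result = self._train()
        self.iteration += 1
        result["training_iteration"] = self.iteration
        return result


trainable = ToyTrainable()
results = [trainable.train() for _ in range(3)]
for r in results:
    print(r)
```

Because the scheduler only ever sees the dictionaries returned by ``train()``, it can pause, checkpoint, or reconfigure a trainable between any two calls.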

### Instructions:

Add training code under ``_train`` as follows:

```python
    def _train(self):
        train(self.model, self.optimizer, self.train_loader, device=self.device)
        acc = test(self.model, self.test_loader, self.device)
        return {"mean_accuracy": acc}
```

In [None]:
class PytorchTrainble(tune.Trainable):
    def _setup(self, config):
        self.device = torch.device("cpu")
        self.train_loader, self.test_loader = get_data_loaders()
        self.model = ConvNet().to(self.device)
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=config.get("lr", 0.01),
            momentum=config.get("momentum", 0.9))

    def _train(self):
        # TODO: Add training code here.
        return {"mean_accuracy": acc}

    def _save(self, checkpoint_dir):
        checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
        torch.save(self.model.state_dict(), checkpoint_path)
        return checkpoint_path

    def _restore(self, checkpoint_path):
        self.model.load_state_dict(torch.load(checkpoint_path))
        
    def reset_config(self, new_config):
        del self.optimizer
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=new_config.get("lr", 0.01),
            momentum=new_config.get("momentum", 0.9))
        return True


ray.shutdown()  # Restart Ray defensively in case the ray connection is lost. 
ray.init(log_to_driver=False)

validate_save_restore(PytorchTrainble)
validate_save_restore(PytorchTrainble, use_object_store=True)
print("Success!")

# Use population-based training with 4 samples

PBT uses information from the rest of the population to refine the hyperparameters and direct computational resources to models which show promise. 

In PBT, a worker might copy the model parameters from a better-performing worker (exploit). It can also explore new hyperparameters by randomly changing the current values (``hyperparam_mutations``).

As training of the population progresses, this exploit-and-explore process runs periodically, so that all workers in the population keep a good base level of performance and new hyperparameters are consistently explored. This means that PBT can quickly exploit good hyperparameters, dedicate more training time to promising models and, crucially, adapt the hyperparameter values throughout training, leading to automatic learning of the best configurations.
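
The exploit-and-explore step described above can be sketched in a few lines of plain Python. This is an illustrative toy, not Tune's scheduler: the function name ``exploit_and_explore``, the 25% cutoff, the 0.8/1.2 perturbation factors, and the resample probability are all assumptions chosen for the example, and categorical values are simply redrawn from the list.

```python
import copy
import random


def exploit_and_explore(population, mutations, rng, resample_prob=0.25):
    """One toy PBT step over a list of {"config", "weights", "score"} dicts."""
    ranked = sorted(population, key=lambda m: m["score"], reverse=True)
    cutoff = max(1, len(ranked) // 4)  # assumed: top/bottom 25%
    top, bottom = ranked[:cutoff], ranked[-cutoff:]
    for member in bottom:
        donor = rng.choice(top)
        # Exploit: copy weights and hyperparameters from a better member.
        member["weights"] = copy.deepcopy(donor["weights"])
        new_config = dict(donor["config"])
        # Explore: mutate each hyperparameter listed in `mutations`.
        for key, spec in mutations.items():
            if callable(spec):
                if rng.random() < resample_prob:
                    new_config[key] = spec()  # resample from the distribution
                else:
                    new_config[key] *= rng.choice([0.8, 1.2])  # perturb
            else:
                new_config[key] = rng.choice(spec)  # redraw a categorical value
        member["config"] = new_config
    return population


rng = random.Random(0)
population = [
    {"config": {"lr": 0.1, "momentum": 0.9}, "weights": [1.0], "score": 0.95},
    {"config": {"lr": 0.5, "momentum": 0.8}, "weights": [2.0], "score": 0.90},
    {"config": {"lr": 0.9, "momentum": 0.99}, "weights": [3.0], "score": 0.60},
    {"config": {"lr": 0.01, "momentum": 0.8}, "weights": [4.0], "score": 0.20},
]
mutations = {
    "lr": lambda: rng.uniform(0.0001, 1),
    "momentum": [0.8, 0.9, 0.99],
}
exploit_and_explore(population, mutations, rng)
print(population[3])
```

After the step, the lowest-scoring member carries the best member's weights together with a mutated copy of its hyperparameters, while the other members are left untouched.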

In [None]:
scheduler = PopulationBasedTraining(
    time_attr="training_iteration",
    metric="mean_accuracy",
    mode="max",
    perturbation_interval=5,
    hyperparam_mutations={
        # distribution for resampling
        "lr": lambda: np.random.uniform(0.0001, 1),
        # allow perturbations within this set of categorical values
        "momentum": [0.8, 0.9, 0.99],
    }
)

In [None]:
ray.shutdown()  # Restart Ray defensively in case the ray connection is lost. 
ray.init(log_to_driver=False)


analysis = tune.run(
    PytorchTrainble,
    name="pbt_test",
    scheduler=scheduler,
    reuse_actors=True,
    verbose=1,
    stop={
        "training_iteration": 100,
    },
    num_samples=4,
    
    # PBT starts by training many neural networks in parallel with random hyperparameters. 
    config={
        "lr": tune.uniform(0.001, 1),
        "momentum": tune.uniform(0.001, 1),
    })


In [None]:
# You can use this to visualize all mutations of Population-based Training.
! cat ~/ray_results/pbt_test/pbt_global.txt

# Visualizing the results of Population-based Training

In [None]:
# Plot test accuracy against training iteration

dfs = analysis.fetch_trial_dataframes()
# This plots everything on the same plot
ax = None
for d in dfs.values():
    ax = d.plot("training_iteration", "mean_accuracy", ax=ax, legend=False)

plt.xlabel("Training iteration")
plt.ylabel("Test Accuracy")