# Tutorial: Population-based Training with Tune


### In this tutorial, we show how to use Population-based Training (PBT) with Tune.

<img src="tune-pbt.png" alt="Population-based Training with Tune" width="600"/>

For background, see DeepMind's blog post on Population-based Training: https://deepmind.com/blog/population-based-training-neural-networks

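PBT trains a group of models (or agents) in parallel. Periodically, poorly performing models clone the state of the top performers. Their hyperparameters are then randomly mutated in the hope of outperforming the current top models. Because hyperparameters change during training, PBT effectively discovers a hyperparameter schedule rather than a single fixed value.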
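
The sketch below is a rough, toy illustration of one exploit/explore round over an in-memory population. It is not how Tune implements PBT: Tune works with trial checkpoints and triggers this step every `perturbation_interval` iterations. The function `exploit_and_explore` and the member layout (`score`, `state`, `hparams`) are hypothetical. Members in the bottom quarter copy a member from the top quarter, which loosely mirrors the default `quantile_fraction` of 0.25 in Tune's `PopulationBasedTraining`. The copied lr is then multiplied by 0.8 or 1.2.

```python
import copy
import random


def exploit_and_explore(population, quantile=0.25, rng=random):
    """Run one simplified PBT exploit/explore round, in place.

    Each member is a dict with "score", "state" (a stand-in for model
    weights) and "hparams". Illustrative only; not Tune's implementation.
    """
    ranked = sorted(population, key=lambda m: m["score"])
    n = max(1, int(len(ranked) * quantile))
    bottom, top = ranked[:n], ranked[-n:]
    for member in bottom:
        donor = rng.choice(top)
        # exploit: clone the donor's state, score and hyperparameters
        member["state"] = copy.deepcopy(donor["state"])
        member["score"] = donor["score"]
        member["hparams"] = dict(donor["hparams"])
        # explore: randomly perturb the cloned learning rate
        member["hparams"]["lr"] *= rng.choice([0.8, 1.2])
    return population


# Example: four members; the worst (score 1.0) clones the best (score 9.0).
rng = random.Random(0)
population = [
    {"score": score, "state": {"id": i}, "hparams": {"lr": lr}}
    for i, (score, lr) in enumerate(
        [(1.0, 0.001), (5.0, 0.005), (3.0, 0.002), (9.0, 0.008)]
    )
]
exploit_and_explore(population, rng=rng)
print(population[0])
```

The poorly performing member inherits both the progress and the settings of a better one, then searches nearby by perturbing the copied hyperparameters.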

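In this example, the goal is to maximize the accuracy of a toy trainable
(`PBTBenchmarkExample`, defined below). Its accuracy increases fastest when
the learning rate (lr) is close to an optimal value. That optimal value is
itself a function of the current accuracy.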

The optimal lr schedule for this problem is a triangle wave over accuracy, as shown below.
Many lr schedules for real models have a similar shape:

     best lr
      ^
      |    /\
      |   /  \
      |  /    \
      | /      \
      ------------> accuracy
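
The same schedule can be written as a small function. This sketch mirrors the `optimal_lr` computation in `PBTBenchmarkExample._train` below; the `midpoint`, `lo`, and `hi` parameter names are our own. The clamp to `[0.001, 0.01]` makes the curve flat near accuracy 0 and 200, so it is not a perfect triangle.

```python
def optimal_lr(score, midpoint=100.0, lo=0.001, hi=0.01):
    """Optimal lr for the toy problem: a triangle wave over the score.

    Rises linearly toward `hi` at score == midpoint, falls back afterwards,
    and is clamped to the range [lo, hi].
    """
    if score < midpoint:
        lr = hi * score / midpoint
    else:
        lr = hi - hi * (score - midpoint) / midpoint
    return min(hi, max(lo, lr))


# Sample the schedule at a few accuracy values.
for s in (0, 50, 100, 150, 200):
    print(s, round(optimal_lr(s), 4))
```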

In this problem, a PBT population of 2 to 4 trials is enough to roughly
approximate this lr schedule. Larger populations generally converge faster.

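Note that training will not converge without PBT: every trial would keep its initial lr of 0.0001, and the accuracy would make no systematic progress.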
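
To see why, consider the ratio `q_err` that `_train` below computes between the current lr and the optimal lr. The score only gains when this ratio is below `tolerance`. It is only penalized when the lr is too large. The helper names `lr_ratio` and `makes_progress` in this sketch are hypothetical; the arithmetic follows `_train`.

```python
def lr_ratio(lr, optimal_lr):
    """Ratio of the larger to the smaller lr, as computed for q_err in `_train`."""
    return max(lr, optimal_lr) / min(lr, optimal_lr)


def makes_progress(lr, optimal_lr, tolerance):
    """True if the toy update rule adds a positive, noise-free gain."""
    return lr_ratio(lr, optimal_lr) < tolerance


# The optimal lr is never below 0.001, so a fixed lr of 0.0001 is always off
# by a factor of at least 10.
for best in (0.001, 0.005, 0.01):
    print(best, lr_ratio(0.0001, best), makes_progress(0.0001, best, tolerance=3))
```

With a fixed lr of 0.0001, the ratio is at least 10, which is not below any of the tolerance values 2, 3, or 10 used in this tutorial. The lr is also too small rather than too large. Neither branch fires, so only the noise term moves the score.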


In [1]:
import numpy as np
import random

import ray
from ray import tune
from ray.tune import Trainable
from ray.tune.schedulers import PopulationBasedTraining

In [None]:
class PBTBenchmarkExample(Trainable):
    def _setup(self, config):
        self.lr = config["lr"]
        
        # if lr is within a factor of `tolerance` of the optimal lr, the score
        # improves; if lr is larger than that, the score is penalized
        self.tolerance = config["tolerance"]
        self.score = 0.0  # training is done once score exceeds 200

    def _train(self):
        midpoint = 100  # optimal lr starts decreasing once score > midpoint
        noise_level = 2  # scale of the gaussian noise added to the score
        # triangle wave (as a function of score, not time):
        #  - start at 0.001 @ score=0
        #  - peak at 0.01 @ score=midpoint
        #  - end at 0.001 @ score=midpoint * 2
        if self.score < midpoint:
            optimal_lr = 0.01 * self.score / midpoint
        else:
            optimal_lr = 0.01 - 0.01 * (self.score - midpoint) / midpoint
        optimal_lr = min(0.01, max(0.001, optimal_lr))

        # compute the score change; q_err is the ratio between the current
        # lr and the optimal lr (always >= 1)
        q_err = max(self.lr, optimal_lr) / min(self.lr, optimal_lr)
        if q_err < self.tolerance:
            self.score += (1.0 / q_err) * random.random()
        elif self.lr > optimal_lr:
            self.score -= (q_err - self.tolerance) * random.random()
        self.score += noise_level * np.random.normal()
        self.score = max(0, self.score)

        return {
            "mean_accuracy": self.score,
            "cur_lr": self.lr,
            "optimal_lr": optimal_lr,  # for debugging
            "q_err": q_err,  # for debugging
            "done": self.score > midpoint * 2,
        }

    def _save(self, checkpoint_dir):
        return {"score": self.score}

    def _restore(self, checkpoint):
        self.score = checkpoint["score"]

    def reset_config(self, new_config):
        self.lr = new_config["lr"]
        self.tolerance = new_config["tolerance"]
        return True


In [None]:
pbt = PopulationBasedTraining(
    time_attr="training_iteration",
    metric="mean_accuracy",
    mode="max",
    perturbation_interval=20,
    hyperparam_mutations={
        # distribution for resampling
        "lr": lambda: random.uniform(0.0001, 0.02),
        # allow perturbations within this set of categorical values
        "tolerance": [2, 3, 10],
    }
)

tune.run(
    PBTBenchmarkExample,
    name="pbt_test",
    scheduler=pbt,
    reuse_actors=True,
    stop={"training_iteration": 2000},
    num_samples=4,
    config={
        "lr": 0.0001,
        # tolerance is also perturbed by PBT (see hyperparam_mutations);
        # it sets how far lr may stray from the optimal lr
        "tolerance": 3,
    })
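
PBT applies `hyperparam_mutations` during its explore step. Tune's documentation describes this step roughly as follows:

- With probability `resample_probability` (0.25 by default), a value is resampled from its spec.
- Otherwise, a continuous value is multiplied by a perturbation factor (0.8 or 1.2 by default).
- A categorical value instead moves to an adjacent entry in its list.

The standalone sketch below models that description for the two kinds of spec used here, a callable and a list. The `explore` function is our own simplification, not Tune's code, and details may differ between Ray versions.

```python
import random


def explore(config, mutations, resample_probability=0.25, rng=random):
    """Simplified sketch of a PBT explore step; returns a new config.

    - list spec: resample (or if the current value is not in the list),
      otherwise move to a neighboring entry, clamped at the ends
    - callable spec: resample by calling it, otherwise multiply the
      current value by 0.8 or 1.2
    """
    new_config = dict(config)
    for key, spec in mutations.items():
        if isinstance(spec, list):
            if rng.random() < resample_probability or config[key] not in spec:
                new_config[key] = rng.choice(spec)
            else:
                i = spec.index(config[key])
                step = rng.choice([-1, 1])
                new_config[key] = spec[min(len(spec) - 1, max(0, i + step))]
        else:
            if rng.random() < resample_probability:
                new_config[key] = spec()
            else:
                new_config[key] = config[key] * rng.choice([0.8, 1.2])
    return new_config


# Same mutation spec as the PopulationBasedTraining call above.
mutations = {
    "lr": lambda: random.uniform(0.0001, 0.02),
    "tolerance": [2, 3, 10],
}
rng = random.Random(0)
base = {"lr": 0.0001, "tolerance": 3}
for _ in range(3):
    print(explore(base, mutations, rng=rng))
```

Under this rule, a trial with `tolerance` 3 moves to 2 or 10 unless it is resampled. The `lr` lambda is only called when a value is resampled.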
