# Assignment 4

Now that we are finally learning about Bayesian optimization policies, it is a good idea to gain experience with actual use cases of Bayesian/bandit optimization. In this problem, we will tune the hyperparameters of a CNN on the MNIST dataset using Ax (https://ax.dev/). Ax is a platform for optimizing experiments using multi-armed bandits and Bayesian optimization. It is built on BoTorch, and it is the kind of tool you would actually use to apply BayesOpt theory to a real experiment.

Your job will be to read through the code, make sure you understand Ax's syntax and the decisions being made, fill in some missing pieces, and answer some questions at the end.

In [None]:
import torch
import numpy as np
from ax.plot.contour import plot_contour
from ax.plot.trace import optimization_trace_single_method
from ax.service.managed_loop import optimize
from ax.utils.notebook.plotting import render, init_notebook_plotting
from ax.utils.tutorials.cnn_utils import load_mnist, CNN

init_notebook_plotting()

In [8]:
torch.manual_seed(29)
dtype = torch.float
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Import the data

Ax's tutorial utilities include a function for loading the MNIST data directly; calling it downloads the dataset the first time you run it.

The function returns the data already wrapped in PyTorch DataLoader objects, split into train, validation, and test sets.

In [9]:
batch_size = 512
train_dl, valid_dl, test_dl = load_mnist(batch_size=batch_size)

## Define the search space

In Ax, a search space is composed of a set of parameters to be tuned in the experiment (and optionally a set of parameter constraints, but we will not be constraining our parameters in this problem). Parameters are defined by their name, parameter type, domain, value type, and some other optional fields. Ax supports three kinds of parameters:

* Range parameters -- Domain represented by lower and upper bound
* Choice parameters -- Domain is a set of values to choose from
* Fixed parameters -- Domain is a single value

The search space tells Ax's optimization algorithms (Bayesian optimization for continuous objective functions, bandit optimization for problems with a finite set of choices) which arms are valid to suggest for a trial. An arm in Ax is a named set of parameters and their values. In our case, an arm is one hyperparameter configuration (values for lr, momentum, weight_decay, and num_epochs) evaluated during the optimization.
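To make the three parameter kinds concrete, here is a minimal sketch with one parameter of each kind written in the dictionary format used in this notebook, plus a hypothetical helper, `in_domain`, that checks whether a value lies in a parameter's domain. This roughly mirrors the question a search space answers when deciding which arms are valid, but it is an illustration only: the names `dropout`, `activation`, and `batch_size` are made up and are not part of this assignment's search space, and `in_domain` is not an Ax function.

```python
# Hypothetical example parameters, one per Ax parameter kind (dictionary format).
example_parameters = [
    {"name": "dropout", "type": "range", "bounds": [0.0, 0.5], "value_type": "float"},
    {"name": "activation", "type": "choice", "values": ["relu", "tanh"], "value_type": "str"},
    {"name": "batch_size", "type": "fixed", "value": 64, "value_type": "int"},
]


def in_domain(param, value):
    """Hypothetical helper: return True if `value` lies in `param`'s domain."""
    kind = param["type"]
    if kind == "range":
        lower, upper = param["bounds"]
        return lower <= value <= upper
    if kind == "choice":
        return value in param["values"]
    if kind == "fixed":
        return value == param["value"]
    raise ValueError(f"unknown parameter type: {kind}")


# An arm is a named parameterization: one value per parameter.
arm = {"dropout": 0.25, "activation": "relu", "batch_size": 64}
params_by_name = {p["name"]: p for p in example_parameters}
print(all(in_domain(params_by_name[name], value) for name, value in arm.items()))
```

Note how the domain key differs by kind: `bounds` for range parameters, `values` for choice parameters, and `value` for fixed parameters.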

In the cell below, we define the search space. The learning rate parameter has already been specified as a range parameter on \[1e-6, 0.4\] with a log scale. It is your job to define the following parameters (none of which should use a log scale):

* momentum -- range parameter between 0 and 1
* weight_decay -- range parameter between bounds of your choosing
* num_epochs -- fixed parameter with a value of 3
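Conceptually, a log-scale range parameter is searched uniformly in the logarithm of its value rather than in the value itself. The sketch below uses two hypothetical helpers (not Ax functions) to map evenly spaced points in [0, 1] onto the `lr` bounds both ways. With the log-scale mapping, each decade between 1e-6 and 0.4 gets a comparable share of the points; with a linear mapping, about 90% of uniformly spread points would land above 0.04.

```python
import math


def from_unit_log(u, lower, upper):
    """Hypothetical helper: map u in [0, 1] to [lower, upper] uniformly in log10 space."""
    log_lower, log_upper = math.log10(lower), math.log10(upper)
    return 10 ** (log_lower + u * (log_upper - log_lower))


def from_unit_linear(u, lower, upper):
    """Hypothetical helper: map u in [0, 1] to [lower, upper] uniformly in the raw value."""
    return lower + u * (upper - lower)


for u in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(f"u={u:.2f}  log-scale lr={from_unit_log(u, 1e-6, 0.4):.2e}  "
          f"linear lr={from_unit_linear(u, 1e-6, 0.4):.2e}")
```

The midpoint of the log-scale mapping is the geometric mean of the bounds (about 6.3e-4), whereas the linear midpoint is about 0.2.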

Note: in this example, we define each parameter in the search space as a dictionary, but Ax has other ways to define parameters that correspond to its different APIs. Check out the Ax documentation for more information on the three Ax APIs (Loop, Service, and Developer).

In [None]:
parameters_to_optimize = [
    {"name": "lr", "type": "range", "bounds": [1e-6, 0.4], "value_type": "float", "log_scale": True},
    # TODO: momentum
    # TODO: weight_decay
    # TODO: num_epochs
]

## Define the train and evaluation functions for the CNN

The training function is called once per iteration (trial) of the Bayesian/bandit optimizer. It is just like a normal PyTorch training loop, except that it takes in the parameters Ax has generated for this trial (sampled quasi-randomly in the initial trials, later chosen by maximizing an acquisition function) and uses them to build the CNN's SGD optimizer and to set the number of training epochs.

The model evaluation function takes in a trained model, evaluates it on the validation set, and returns its classification accuracy. Ax's optimizer tries to maximize this score as a function of the tuned parameters (lr, momentum, weight_decay).

Although I am not requiring you to define these functions, make sure to read through them and understand what is happening.
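The three tuned hyperparameters enter through `torch.optim.SGD`. With PyTorch's defaults (`dampening=0`, `nesterov=False`), its documentation describes each step as follows:

1. Add `weight_decay * theta` to the gradient.
2. Update the momentum buffer as `b = momentum * b + g`, initializing it to the first gradient.
3. Step `theta -= lr * b`.

The pure-Python sketch below applies that rule to a single scalar parameter. It is meant only to show what each hyperparameter does, not to replace `torch.optim.SGD`, and the numbers in it are illustrative.

```python
def sgd_step(theta, grad, buf, lr, momentum, weight_decay):
    """One SGD step on a scalar, following the update rule PyTorch documents for
    torch.optim.SGD with dampening=0 and nesterov=False.

    Pass buf=None before the first step; returns (new_theta, new_buf).
    """
    g = grad + weight_decay * theta                  # L2 penalty folded into the gradient
    buf = g if buf is None else momentum * buf + g   # momentum buffer (first step: just g)
    return theta - lr * buf, buf                     # move against the buffered direction


# Two steps with a constant loss gradient of 0.5 (illustrative numbers only).
theta, buf = 1.0, None
for _ in range(2):
    theta, buf = sgd_step(theta, grad=0.5, buf=buf, lr=0.1, momentum=0.9, weight_decay=0.01)
    print(round(theta, 6))
```

A larger `momentum` makes the second step bigger than the first because past gradients keep contributing. A larger `weight_decay` pulls `theta` toward zero regardless of the loss gradient.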

In [10]:
def train(net, train_loader, parameters, dtype, device):
    """
    Args:
        net: initialized neural network
        train_loader: DataLoader containing training set
        parameters: dictionary containing parameters to be passed to the optimizer
        dtype: torch dtype
        device: torch device
    """
    net.to(dtype=dtype, device=device)
    net.train()
    criterion = torch.nn.NLLLoss(reduction="sum")
    optimizer = torch.optim.SGD(
        net.parameters(),
        lr=parameters.get('lr'),
        momentum=parameters.get('momentum'),
        weight_decay=parameters.get('weight_decay')
    )
    num_epochs = parameters.get('num_epochs')
    
    for _ in range(num_epochs):
        for inputs, labels in train_loader:
            inputs = inputs.to(dtype=dtype, device=device)
            labels = labels.to(device=device)
            
            optimizer.zero_grad()
            
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
    return net

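def evaluate(net, data_loader, dtype, device):
    """
    Args:
        net: trained neural network
        data_loader: DataLoader containing the evaluation set
        dtype: torch dtype
        device: torch device

    Returns:
        Classification accuracy as a fraction in [0, 1]
    """
    net.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs = inputs.to(dtype=dtype, device=device)
            labels = labels.to(device=device)
            outputs = net(inputs)
            # Predicted class = index of the largest (log-)probability
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
    return correct / total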

## Define the evaluation function for the hyperparameter optimization

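This function is passed into Ax's optimize function as the overall evaluation function. Each call trains a fresh CNN on the training set using the given parameterization and returns its validation accuracy. The higher the returned value, the better the evaluated arm performed.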

In [13]:
def train_evaluate(parameterization):
    net = CNN()
    net = train(
        net=net,
        train_loader=train_dl,
        parameters=parameterization,
        dtype=dtype,
        device=device
    )
    return evaluate(
        net=net,
        data_loader=valid_dl,
        dtype=dtype,
        device=device
    )

## Run the optimization loop

Here we pass our parameters, the evaluation function, and an objective name of our choosing to Ax's optimize function, which runs a pre-defined managed optimization loop. By default, optimize maximizes the objective, which is what we want for accuracy. For more complex problems you may want to define your own generation strategy (the sequence of models used to generate arms), but this example is simple enough that the off-the-shelf loop is sufficient.

This pre-defined loop runs 20 trials. For the first 5 trials, arm values are generated by a Sobol quasi-random sampler to build a baseline dataset. All subsequent trials use Ax's GPEI model, which is a Gaussian process (GP) paired with the Expected Improvement (EI) acquisition function. The GP is fit to the data observed so far, and EI is used to choose the next point to evaluate.
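For maximization, Expected Improvement at a candidate point x is EI(x) = E[max(f(x) - f*, 0)], where f* is the best value observed so far. When the GP posterior at x is Gaussian with mean mu and standard deviation sigma, this has the closed form EI = (mu - f*) * Phi(z) + sigma * phi(z), where z = (mu - f*) / sigma and Phi and phi are the standard normal CDF and PDF. The sketch below evaluates that formula with the standard library. It illustrates the acquisition function only and is not the implementation Ax or BoTorch uses internally; the numbers are made up, not results from this experiment.

```python
import math


def expected_improvement(mu, sigma, best_f):
    """Closed-form Expected Improvement (maximization) for a posterior N(mu, sigma^2)."""
    improvement = mu - best_f
    if sigma <= 0.0:
        # No uncertainty: EI reduces to the plain improvement, if any.
        return max(improvement, 0.0)
    z = improvement / sigma
    cdf = 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))         # standard normal CDF at z
    pdf = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)  # standard normal PDF at z
    return improvement * cdf + sigma * pdf


best_so_far = 0.97  # illustrative incumbent accuracy
# Same predicted mean, different uncertainty: the more uncertain point scores higher.
print(expected_improvement(mu=0.96, sigma=0.005, best_f=best_so_far))
print(expected_improvement(mu=0.96, sigma=0.03, best_f=best_so_far))
```

The first term rewards points whose predicted mean already beats the incumbent (exploitation). The second term rewards uncertain points (exploration). This balance is why the GP-EI trials can still probe regions that the Sobol trials covered sparsely.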

The optimize function returns the following:
   * best_parameters -- A dictionary in the form \<parameter name\>: \<best value\>
   * values -- a tuple of (means, covariances) for the objective at best_parameters
   * experiment -- An Ax Experiment object, which keeps track of the whole optimization process and contains the search space, optimization configuration, and other metadata 
   * model -- An Ax ModelBridge object that can be used to generate new points in the search space

Run the optimization loop. It shouldn't take too long, and if your parameters are set up properly, it should run without problems. You can ignore any InputDataWarning messages.

In [None]:
best_parameters, values, experiment, model = optimize(
    parameters=parameters_to_optimize,
    evaluation_function=train_evaluate,
    objective_name='accuracy'
)

## Plot results

Ax has many useful visualization functions built on Plotly. plot_contour creates a contour plot of the model's predicted classification accuracy (alongside its standard error) as a function of two hyperparameters. Black squares show the points we have actually evaluated; if the optimization worked, they should tend to cluster in the high-accuracy region. If these plots are not displaying for you, message me on Slack and I can send you a screenshot.

In [None]:
render(plot_contour(model=model, param_x='lr', param_y='momentum', metric_name='accuracy'))

In [None]:
render(plot_contour(model=model, param_x='lr', param_y='weight_decay', metric_name='accuracy'))

In [None]:
render(plot_contour(model=model, param_x='momentum', param_y='weight_decay', metric_name='accuracy'))

Next, we plot the best objective value found so far as a function of the iteration, to visualize how the optimization finds better and better hyperparameters over time.
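The trace is a running maximum: each point is the best accuracy seen up to that trial, so the curve can only rise or stay flat. The sketch below shows what `np.maximum.accumulate(..., axis=1)` computes on a made-up array with the same (1, n_trials) shape as the one built in the next cell. The accuracies are illustrative, not results from this experiment.

```python
import numpy as np

# Hypothetical per-trial accuracies (%) in a (1, n_trials) array.
objectives = np.array([[91.0, 85.5, 95.2, 93.1, 97.4, 96.8]])

# Best value seen so far at each trial, accumulated along the trial axis.
best_so_far = np.maximum.accumulate(objectives, axis=1)
print(best_so_far)
```

Trials 2, 4, and 6 did worse than an earlier trial, so the trace stays flat at those points instead of dipping.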

In [None]:
objective_means = np.array([[trial.objective_mean*100 for trial in experiment.trials.values()]])  # per-trial accuracy (%)
best_objective_plot = optimization_trace_single_method(
    y=np.maximum.accumulate(objective_means, axis=1),  # best accuracy seen so far
    title="Model performance vs. # of iterations",
    ylabel="Classification Accuracy, %",
)
render(best_objective_plot)

## Train the final CNN

Finally, we retrieve the best arm from the experiment, combine our training and validation sets into a single DataLoader, and train the CNN using the hyperparameters from the best arm. We then evaluate the trained model on the held-out test set.

In [None]:
data = experiment.fetch_data()
df = data.df
best_arm_name = df.arm_name[df['mean']==df['mean'].max()].values[0]
best_arm = experiment.arms_by_name[best_arm_name]
best_arm

In [20]:
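# train_dl.dataset and valid_dl.dataset are the train and validation splits themselves;
# their .dataset attribute is the underlying dataset they were split from, so
# concatenating that would duplicate examples instead of combining the two splits.
combined_dataset = torch.utils.data.ConcatDataset([
    train_dl.dataset,
    valid_dl.dataset
])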

combined_dl = torch.utils.data.DataLoader(
    combined_dataset,
    batch_size=batch_size,
    shuffle=True
)

In [24]:
final_net = train(
    net=CNN(),
    train_loader=combined_dl,
    parameters=best_arm.parameters,
    dtype=dtype,
    device=device
)

In [25]:
test_accuracy = evaluate(
    net=final_net,
    data_loader=test_dl,
    dtype=dtype,
    device=device
)

In [None]:
print(f"Classification Accuracy (test set): {round(test_accuracy*100, 2)}%")

# Follow-up Questions

## Question 1

Which part of this entire tutorial seemed the most confusing to you?

## Question 2

Do you see anywhere in the optimization policy that could be improved? If so, what would you change?