# Example of a hyperparameter sweep using Weights and Biases (wandb)

Often, there are a whole bunch of hyperparameters, such as the learning rate, the number of layers, the width of an MLP, or the batch size, that need to be chosen without it being clear in advance which values will work well.

A good way of finding good hyperparameter settings is to simply try a bunch of them and see which works best according to some pre-defined metric.

Ideally, you should try to do this in an automated way. Using [wandb](https://wandb.ai/) helps with this. 

Additionally, the `common_dl_utils` and `common_jax_utils` packages can help with setting up such a wandb sweep (or, if you prefer not to use wandb, they can help you automate things in other ways).
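To make the idea concrete, here is a minimal, self-contained sketch of an automated random search in plain Python. The objective `toy_validation_loss` is a made-up stand-in for "train a model and evaluate the metric"; in a real sweep, wandb samples the hyperparameters and each run reports its own metric. The sampling ranges mirror the `w0` and `learning_rate` ranges used in the config further down.

```python
import math
import random


def sample_log_uniform(rng: random.Random, low: float, high: float) -> float:
    """Sample a value whose logarithm is uniformly distributed between log(low) and log(high)."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


def toy_validation_loss(learning_rate: float, w0: float) -> float:
    """Hypothetical stand-in for training and evaluating a model; lower is better.

    Its minimum lies at learning_rate=1e-4, w0=30; a real objective is unknown in advance.
    """
    return (math.log10(learning_rate) + 4.0) ** 2 + ((w0 - 30.0) / 10.0) ** 2


def random_search(num_trials: int, seed: int = 0):
    """Try `num_trials` random settings and keep the best one according to the metric."""
    rng = random.Random(seed)
    best_params, best_loss = None, math.inf
    for _ in range(num_trials):
        params = {
            "learning_rate": sample_log_uniform(rng, 1e-5, 1e-3),
            "w0": rng.uniform(10.0, 40.0),
        }
        loss = toy_validation_loss(**params)
        if loss < best_loss:
            best_params, best_loss = params, loss
    return best_params, best_loss


best_params, best_loss = random_search(num_trials=50)
print(best_params, best_loss)
```

Sampling the learning rate log-uniformly spreads the trials evenly over orders of magnitude, which is what the `log_uniform_values` distribution in the config below does as well.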

## In this notebook
In this notebook, we set up a hyperparameter sweep for the same model we saw in `inr_example.ipynb`. In short, this works as follows:
* We create a config detailing all hyperparameters (both fixed and varying)
* We start a wandb sweep using this config. This results in a sweep id.
* We run experiments using this sweep id. 

This last step is done by calling `run_from_inr_sweep.py` with the command-line argument `--sweep_id=` followed by the correct id. Locally, you can simply do this from the command line. On Snellius, you'll have to create a job script that loads the correct environment, specifies which resources are needed and for how long, and calls `run_from_inr_sweep.py` with the correct sweep id.

`run_from_inr_sweep.py` launches a wandb "agent" with a function for running the experiment. This agent receives a config from wandb with picks for the hyperparameters. From that point on, you basically have a config specifying a single run, and things work very similarly to what is done in `inr_example.ipynb` once you have your config there.
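For orientation, here is a minimal sketch of what an agent script along these lines could look like. It is not the actual `run_from_inr_sweep.py`: `run_single_experiment` is a hypothetical placeholder for the real experiment code (which goes through `common_jax_utils.wandb_utils.run_from_wandb`), and the `--count` flag is an assumption added for illustration. `wandb.agent(sweep_id, function=..., count=...)` and `wandb.init()` are real wandb calls; depending on your wandb defaults you may also need to pass `entity=` and `project=` to `wandb.agent`.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Run experiments for a wandb sweep.")
    parser.add_argument("--sweep_id", type=str, required=True, help="id returned by wandb.sweep(...)")
    # Assumption: number of runs this agent performs; keep it at 1 (see "Type of sweep").
    parser.add_argument("--count", type=int, default=1)
    return parser


def run_single_experiment() -> None:
    """Hypothetical placeholder for the real experiment code."""
    import wandb  # imported lazily so the parser can be used without wandb installed

    with wandb.init() as run:
        # run.config holds this run's picks for the hyperparameters: a single-run config
        flat_config = dict(run.config)
        print(f"would build the model and train with {len(flat_config)} config entries here")


def main(argv=None) -> None:
    args = build_parser().parse_args(argv)
    import wandb

    # The agent asks the sweep for a config, calls the function with it, and repeats `count` times.
    wandb.agent(args.sweep_id, function=run_single_experiment, count=args.count)
```

A script like this would end with the usual `if __name__ == "__main__": main()` guard, so that `python run_from_inr_sweep.py --sweep_id=...` starts one agent.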

## Type of sweep
Weights and Biases provides three options for these sweeps: grid search, random search, and Bayesian search. Keep in mind that your computational budget is limited, so a grid search easily becomes infeasible.
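A quick back-of-the-envelope calculation shows why. The option counts and the run time below are hypothetical, chosen only for illustration: a grid over a handful of values per hyperparameter multiplies out quickly, whereas a random search lets you fix the number of runs up front. In a wandb sweep config, the three options correspond to `'method': 'grid'`, `'method': 'random'` and `'method': 'bayes'`.

```python
import math

# Hypothetical numbers of candidate values per hyperparameter for a grid search.
grid_options = {
    "learning_rate": 5,
    "weight_decay": 5,
    "w0": 4,
    "num_layers": 3,
    "hidden_size": 3,
}

num_grid_runs = math.prod(grid_options.values())  # every combination is one run
minutes_per_run = 20  # hypothetical cost of a single run
gpu_hours = num_grid_runs * minutes_per_run / 60

print(f"grid search: {num_grid_runs} runs, about {gpu_hours:.0f} GPU hours")
# prints: grid search: 900 runs, about 300 GPU hours
```

Adding one more hyperparameter with 4 candidate values would quadruple this, while a random search simply uses as many runs as you can afford.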

In most cases, a random search will likely give you the best experience (and the fewest headaches). When running experiments from sweeps, you have the option to let a single agent perform multiple runs in sequence. However, there seems to be a bug in wandb that causes GPU memory to not always be freed after each run, which can lead to out-of-memory (OOM) errors (both when using JAX and when using PyTorch). The only way I've found to reliably circumvent this is to keep `count` set to 1 and create a bunch of agents.

However, when doing Bayesian search, each agent still seems to try to create a new run after its first and only run, which just creates a lot of runs in your sweep that didn't really do anything. So my advice would be: don't waste your compute budget on grid search and don't waste your good mood on Bayesian search, unless you really need to.
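One way to follow the advice of keeping `count` at 1 and creating many agents on a single machine is to start one process per agent, one after the other. A minimal sketch, assuming `run_from_inr_sweep.py` is invoked as shown at the end of this notebook and performs a single run; `agent_command`, `launch_agents`, and pinning the GPU through `CUDA_VISIBLE_DEVICES` are illustrative choices, not part of the original tooling. Because each run lives in a fresh process, its GPU memory is released when that process exits.

```python
import os
import subprocess
import sys
from typing import List


def agent_command(sweep_id: str) -> List[str]:
    """Command line for one agent process (assumes the script performs a single run)."""
    return [sys.executable, "run_from_inr_sweep.py", f"--sweep_id={sweep_id}"]


def launch_agents(sweep_id: str, num_agents: int, gpu: str = "0") -> None:
    """Start `num_agents` agents one after another, each in its own process.

    Every process exits after its run, so its GPU memory is freed before the next agent starts.
    """
    env = {**os.environ, "CUDA_VISIBLE_DEVICES": gpu}  # pin all agents to one GPU
    for _ in range(num_agents):
        # check=False: one failed run (e.g. a NaN loss) should not stop the remaining agents
        subprocess.run(agent_command(sweep_id), env=env, check=False)
```

For example, `launch_agents("04jhl7r9", num_agents=10)` would perform ten runs of the sweep created below; on a machine with several GPUs you could start one such loop per GPU.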

### Grid search
If for some reason you really do need to do grid search, there might be better ways than doing this through a wandb sweep. The tools in `common_dl_utils.config_creation` allow for creating individual run configs (just like wandb does) in a way that some variables can be linked together. 

For example, you might want to vary the latent size while keeping the hidden size at twice the latent size. Or you might want to vary which activation function your INR uses, and also vary some of its hyperparameters, such as `w0` for Siren, but only when the corresponding layer is actually used.
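To make these two examples concrete, here is a plain-Python sketch that enumerates a grid in which `hidden_size` is tied to `latent_size`, and `w0` only appears for Siren layers. It deliberately does not use `common_dl_utils.config_creation`, whose API is not shown here; the key names and the short layer names `"SirenLayer"` and `"ReLULayer"` are hypothetical.

```python
import itertools
from typing import Dict, List


def linked_grid() -> List[Dict]:
    """Enumerate single-run configs where some hyperparameters depend on others."""
    configs = []
    for latent_size, layer_type in itertools.product([32, 64, 128], ["SirenLayer", "ReLULayer"]):
        base = {
            "latent_size": latent_size,
            "hidden_size": 2 * latent_size,  # linked: always twice the latent size
            "layer_type": layer_type,
        }
        if layer_type == "SirenLayer":
            # w0 only matters for Siren layers, so it is only varied there
            configs.extend({**base, "w0": w0} for w0 in [10.0, 30.0])
        else:
            configs.append(base)
    return configs


for cfg in linked_grid():
    print(cfg)
```

This yields 9 runs. A naive grid over the same values, with `hidden_size` varied independently and `w0` varied for every layer type, would contain 3 × 3 × 2 × 2 = 36 runs, most of them redundant or meaningless.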

In such cases, you might want to use the tools from `common_dl_utils.config_creation` to just create a folder full of config files for individual runs, and create a script that loops over those configs and runs the corresponding experiment (or ideally do something smarter than this so you can have a bunch of scripts do experiments in parallel on Snellius).
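A minimal sketch of the "folder full of config files" approach, using JSON files and the standard Slurm environment variable `SLURM_ARRAY_TASK_ID`, so that each task of a Slurm job array on Snellius picks its own config and the runs proceed in parallel. The file layout and the `run_experiment` callable are hypothetical; `common_dl_utils` may offer its own way to write and read configs.

```python
import json
import os
from pathlib import Path
from typing import Callable, Dict, List


def write_configs(configs: List[Dict], folder: Path) -> List[Path]:
    """Write one JSON file per run config; zero-padded names keep them in a fixed order."""
    folder.mkdir(parents=True, exist_ok=True)
    paths = []
    for index, config in enumerate(configs):
        path = folder / f"run_{index:04d}.json"
        path.write_text(json.dumps(config, indent=2))
        paths.append(path)
    return paths


def select_config(folder: Path, task_id: int) -> Dict:
    """Load the config belonging to one array task (configs sorted by file name)."""
    paths = sorted(folder.glob("run_*.json"))
    return json.loads(paths[task_id].read_text())


def run_this_task(folder: Path, run_experiment: Callable[[Dict], None]) -> None:
    # Each task of e.g. `sbatch --array=0-8 job.sh` gets a different SLURM_ARRAY_TASK_ID.
    task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", "0"))
    run_experiment(select_config(folder, task_id))
```

Compared with one script looping over all configs, a job array lets the scheduler run as many of these tasks at the same time as your allocation permits.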

### Random seeds
If you want to vary only the PRNG seed (used for random number generation) instead of, or together with, any hyperparameters, you can specify a `prng_seed` in the config (this should be an integer).

If you don't specify a `prng_seed` in the config, the tools used for running the experiments will create a random PRNG seed and log it to wandb, which makes reproducing results easier.
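As a sketch of both options: to vary only the seed, wandb's own sweep-parameter format accepts a list of `values` for `prng_seed`; and a run without a seed can draw one and record it in its config, as described above. The helper `resolve_prng_seed` is hypothetical and only illustrates the idea; the actual tools handle this internally.

```python
import random
from typing import Dict, Tuple

# wandb sweep-parameter format for varying only the seed (five repetitions of the same setup)
seed_only_parameters = {"prng_seed": {"values": [0, 1, 2, 3, 4]}}


def resolve_prng_seed(config: Dict) -> Tuple[int, Dict]:
    """Hypothetical helper: use the configured seed, or draw a fresh one and record it."""
    seed = config.get("prng_seed")
    if seed is None:
        seed = random.SystemRandom().randrange(2**31)  # fresh, unpredictable seed
        config = {**config, "prng_seed": seed}  # store it so the run can be reproduced
    return int(seed), config
```

With JAX, the resulting integer would typically be turned into a key with `jax.random.PRNGKey(seed)`, and recording it via `wandb.config.update({"prng_seed": seed})` makes it visible in the run's config on wandb.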


```python
from pprint import pprint

import wandb

import common_dl_utils as cdu
from common_dl_utils.config_creation import Config, VariableCollector

wandb.login()
```

```
wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: simon-martinus-koop (nld). Use `wandb login --relogin` to force relogin
True
```

```python
config = Config()
variable = VariableCollector()  # we'll use this to keep track of all varying hyperparameters
# when working with wandb, this is basically just syntactic sugar,
# but if you want to do hyperparameter optimization without wandb, it can help set things up in other ways too.

# first we specify what the model should look like
config.architecture = './model_components'  # module containing all relevant classes for architectures
# NB: if the classes relevant for creating the model are spread over multiple modules, this is no problem.
# Let config.architecture be the module that contains the "main" model class, and for all other components just specify the module,
# or specify the other modules as default modules to the tools in common_jax_utils.run_utils
config.model_type = 'inr_modules.CombinedINR'

config.model_config = Config()
config.model_config.in_size = 3
config.model_config.out_size = 1
config.model_config.terms = [
    # CombinedINR uses multiple MLPs and returns the sum of their outputs. These 'terms' are the MLPs
    ('inr_modules.MLPINR.from_config', {
        'hidden_size': 256,  # fixed value, not varied in this sweep
        'num_layers': 5,  # fixed value, not varied in this sweep
        'layer_type': 'inr_layers.SirenLayer',
        'num_splits': 1,
        'use_complex': False,
        # varied: sampled from a distribution, see https://docs.wandb.ai/guides/sweeps/sweep-config-keys#distribution-options-for-random-and-bayesian-search
        'activation_kwargs': {'w0': variable(distribution="uniform", min=10., max=40.)},
        'initialization_scheme': 'initialization_schemes.siren_scheme',
        #'positional_encoding_layer': ('inr_layers.ClassicalPositionalEncoding.from_config', {'num_frequencies': 10}),
    }),
]
config.trainer_module = './inr_utils/'  # similarly to config.architecture above, here we just specify in what module to look for objects by default
config.trainer_type = 'training.train_with_dataloader_scan'

config.dataloader = 'sdf.SDFDataLoader'
config.dataloader_config = {
    "sdf_name": "Armadillo",
    "batch_size": 200000,
    "keep_aspect_ratio": True,
}

config.num_cycles = 100
config.steps_per_cycle = 200

config.loss_evaluator = "losses.SDFLossEvaluator"

config.target_function = 'sdf.SDFDataLoader'  # see config.loss_evaluator
config.target_function_config = {
    "sdf_name": "Armadillo",
    "batch_size": 200000,
    "keep_aspect_ratio": True,
}

#config.state_update_function = ('state_test_objects.py', 'counter_updater')

config.optimizer = 'training.OptimizerFactory.single_optimizer'
config.optimizer_type = 'adamw'
config.optimizer_config = {
    'learning_rate': variable(distribution='log_uniform_values', min=1e-5, max=1e-3),  # 1.5e-4
    'weight_decay': variable(distribution='log_uniform_values', min=1e-5, max=1e-3),
}
config.optimizer_mask = 'masking.array_mask'

config.use_wandb = True

# now we want some extra things, like logging, to happen during training
# the training functions in inr_utils.training allow for this through callbacks.
# The callbacks we want to use can be found in inr_utils.callbacks
config.after_cycle_callback = 'callbacks.ComposedCallback'
config.after_cycle_callback_config = {
    'callbacks': [
        ('callbacks.print_loss', {'after_every': 1}),  # print the loss every time this callback is called, i.e. after every cycle
        'callbacks.report_loss',  # and log the loss to wandb after every cycle
        ('callbacks.MetricCollectingCallback',  # this will help us collect metrics and log images to wandb
             {'metric_collector': 'metrics.MetricCollector'}
        ),
        'callbacks.raise_error_on_nan'  # stop training if the loss becomes NaN
    ],
    'show_logs': False
}
# config.after_step_callback = 'callbacks.ComposedCallback'
# config.after_step_callback_config = {
#     'callbacks':[
#         ('callbacks.print_loss', {'after_every':400}),  # only print the loss every 400th step
#         'callbacks.report_loss',  # but log the loss to wandb after every step
#         ('callbacks.MetricCollectingCallback', # this thing will help us collect metrics and log images to wandb
#              {'metric_collector':'metrics.MetricCollector'}
#         ),
#         'callbacks.raise_error_on_nan'  # stop training if the loss becomes NaN
#     ],
#     'show_logs': False
# }

#config.after_training_callback = ('state_test_objects.py', 'after_training_callback')
config.after_training_callback = None  # not needed here, but you could use one e.g. to store loss plots if you're not using wandb
config.metric_collector_config = {  # the metrics for MetricCollectingCallback / metrics.MetricCollector
    'metrics': [
        ('metrics.JaccardIndexSDF', {
            'frequency': 'every_n_batches',
            'grid_resolution': 100,
            'num_dims': 3,
            'batch_size': 10000
        }),
        ("metrics.SDFReconstructor", {
            'frequency': 'every_n_batches',
            'grid_resolution': 100,
            'batch_size': 10000,
        }),
        # TODO: add view rendering here
    ],
    'batch_frequency': 4,  # compute all of these metrics every 4 batches
    'epoch_frequency': 1  # not actually used
}

config.optimizer_state = None  # we're starting from scratch
```

Originally, wandb didn't deal very well with nested configurations like the one above. Nowadays the situation might be better, but I don't care to give myself a headache finding out what the exact caveats are.

So instead, we'll make a flat config out of the nested config above, and the `common_jax_utils.wandb_utils.run_from_wandb` function that we use to actually run the experiments will unflatten it when needed.
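The idea behind flattening can be sketched in a few lines of plain Python. This is not how `make_flat_config` is implemented, and its key format may differ: the sketch assumes dotted keys, wraps fixed values in wandb's `{'value': ...}` form, and passes through entries that already describe a sweep variable (a `distribution` or a list of `values`). It only handles nested dicts, not the lists of `(name, config)` tuples used for `terms` above.

```python
from typing import Any, Dict

VARIABLE_KEYS = ("distribution", "values")  # keys marking an entry as a sweep variable


def flatten_for_sweep(nested: Dict[str, Any], prefix: str = "") -> Dict[str, Dict[str, Any]]:
    """Turn a nested dict into wandb sweep parameters with dotted keys."""
    flat = {}
    for key, value in nested.items():
        full_key = f"{prefix}{key}"
        if isinstance(value, dict) and any(k in value for k in VARIABLE_KEYS):
            flat[full_key] = value  # already a sweep variable, e.g. a distribution
        elif isinstance(value, dict):
            flat.update(flatten_for_sweep(value, prefix=f"{full_key}."))
        else:
            flat[full_key] = {"value": value}  # fixed hyperparameter
    return flat


def unflatten(flat: Dict[str, Any]) -> Dict[str, Any]:
    """Inverse step for a single run: dotted keys with concrete values back into a nested dict."""
    nested: Dict[str, Any] = {}
    for dotted_key, value in flat.items():
        *parents, leaf = dotted_key.split(".")
        node = nested
        for parent in parents:
            node = node.setdefault(parent, {})
        node[leaf] = value
    return nested
```

For example, `flatten_for_sweep({'optimizer_config': {'weight_decay': 0.0}})` gives `{'optimizer_config.weight_decay': {'value': 0.0}}`, and `unflatten({'optimizer_config.weight_decay': 0.0})` gives back `{'optimizer_config': {'weight_decay': 0.0}}`.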

```python
flat_parameter_config = cdu.config_creation.make_flat_config(config)
#pprint(flat_parameter_config)  # uncomment to see what the flattened config looks like
```

Finally, we set up a sweep config specifying the method of the sweep (random in this case), the metric to be tracked, and the `flat_parameter_config` from above.

```python
sweep_config = {
    'name': 'sdf_example_sweep',
    'method': 'random',
    'metric': {'name': 'jaccard_index', 'goal': 'maximize'},
    'parameters': flat_parameter_config,
    'description': 'An example of a hyperparameter sweep for training an INR'
}
```
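Before calling `wandb.sweep`, a cheap sanity check can catch typos in the sweep config. The helper below is a hypothetical convenience, not part of wandb; it only checks the top-level keys used here against the values wandb accepts (`grid`, `random`, `bayes` for `method`; `minimize`, `maximize` for the metric goal) and the requirement that Bayesian search has a metric to optimize. It cannot check that the metric name matches a key your runs actually log, so that is worth verifying on a single run.

```python
from typing import Any, Dict, List

VALID_METHODS = {"grid", "random", "bayes"}
VALID_GOALS = {"minimize", "maximize"}


def sweep_config_problems(sweep_config: Dict[str, Any]) -> List[str]:
    """Return human-readable problems; an empty list means none were found."""
    problems = []
    method = sweep_config.get("method")
    if method not in VALID_METHODS:
        problems.append(f"method must be one of {sorted(VALID_METHODS)}")
    metric = sweep_config.get("metric", {})
    if method == "bayes" and "name" not in metric:
        problems.append("bayes search needs a metric name to optimize")
    if "goal" in metric and metric["goal"] not in VALID_GOALS:
        problems.append(f"metric goal must be one of {sorted(VALID_GOALS)}")
    for name, spec in sweep_config.get("parameters", {}).items():
        if not isinstance(spec, dict):
            problems.append(f"parameter {name!r} should be a dict such as {{'value': ...}}")
    return problems
```

In the notebook, `assert not sweep_config_problems(sweep_config)` right before creating the sweep would be enough.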

```python
sweep_id = wandb.sweep(sweep_config, entity="nld", project="inr_edu_24")
print(f"nohup python run_from_inr_sweep.py --sweep_id={sweep_id} > sdf_sweep_example.out")
```

```
Create sweep with ID: 04jhl7r9
Sweep URL: https://wandb.ai/nld/inr_edu_24/sweeps/04jhl7r9
nohup python run_from_inr_sweep.py --sweep_id=04jhl7r9 > sdf_sweep_example.out
```


Next up, perform a single run for this sweep locally by running the command printed above: `nohup python run_from_inr_sweep.py --sweep_id=04jhl7r9 > sdf_sweep_example.out`

Or run it on Snellius by submitting a job script that logs into wandb and runs `run_from_inr_sweep.py` with the correct sweep id.
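A minimal sketch of such a job script, written from Python so the sweep id can be filled in directly. The `#SBATCH` options are standard Slurm; the partition name, GPU count, time limit, module setup and virtual-environment path are placeholders to replace with whatever your Snellius project actually uses. wandb can pick up the API key from the `WANDB_API_KEY` environment variable or from a previous `wandb login`, so the job does not need an interactive prompt.

```python
from pathlib import Path


def write_sweep_job_script(sweep_id: str, path: Path, hours: int = 2) -> Path:
    """Write a Slurm job script that runs one agent for the given sweep."""
    script = f"""#!/bin/bash
#SBATCH --job-name=inr_sweep
#SBATCH --partition=gpu
#SBATCH --gpus=1
#SBATCH --time={hours:02d}:00:00
#SBATCH --output=sweep_{sweep_id}_%j.out

# placeholders: load the modules and activate the environment your project uses
module purge
# module load <modules your environment needs>
source ~/venvs/inr/bin/activate

python run_from_inr_sweep.py --sweep_id={sweep_id}
"""
    path.write_text(script)
    return path
```

After `write_sweep_job_script(sweep_id, Path("sweep_job.sh"))`, submitting it with `sbatch --array=0-9 sweep_job.sh` would start ten independent agents, each doing a single run, in line with the `count = 1` advice above.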