# Using Weights & Biases with Tune

(tune-wandb-ref)=

[Weights & Biases](https://www.wandb.ai/) (Wandb) is a tool for experiment
tracking, model optimization, and dataset versioning. It is very popular
in the machine learning and data science community for its superb visualization
tools.

```{image} /images/wandb_logo_full.png
:align: center
:alt: Weights & Biases
:height: 80px
:target: https://www.wandb.ai/
```

Ray Tune currently offers two lightweight integrations for Weights & Biases.
One is the {ref}`WandbLoggerCallback <tune-wandb-logger>`, which automatically logs
metrics reported to Tune to the Wandb API.

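The other is the `setup_wandb()` function, which you call from inside your training code,
either in a function trainable or in the `setup()` method of a class trainable. It automatically
initializes a Wandb run with Tune's trial information and returns the run object, so you can use the
Wandb API as you normally would, e.g. calling `wandb.log()` to log your training
progress. The older {ref}`@wandb_mixin <tune-wandb-mixin>` decorator provides similar
functionality for the function API.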

```{contents}
:backlinks: none
:local: true
```

## Running A Weights & Biases Example

In the following example, we use both the `WandbLoggerCallback` and
the `setup_wandb` function to log metrics.
Let's start with a few crucial imports:

```python
import numpy as np

import ray
from ray import air, tune
from ray.air import session
from ray.air.integrations.wandb import setup_wandb
from ray.air.integrations.wandb import WandbLoggerCallback
```

Next, let's define a simple `train_function` (a function-based Tune `Trainable`) that reports a random loss to Tune.
The objective itself is not important for this example, since we want to focus on the Weights & Biases
integration.

```python
def train_function(config):
    for _ in range(30):
        # Draw a random loss around config["mean"] with spread config["sd"].
        loss = config["mean"] + config["sd"] * np.random.randn()
        session.report({"loss": loss})
```
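Each call to `session.report()` above reports one fresh draw from a normal distribution with mean `config["mean"]` and standard deviation `config["sd"]`, so the last loss a trial reports is itself a single noisy sample. The standard-library sketch below reproduces that objective outside of Tune to make the noise visible. The helper `simulate_losses` is hypothetical, and it uses `random.gauss` in place of NumPy:

```python
import random
import statistics
from typing import List


def simulate_losses(mean: float, sd: float, steps: int = 30, seed: int = 0) -> List[float]:
    """Hypothetical re-implementation of train_function's loss sequence, without Tune."""
    rng = random.Random(seed)  # seeded so repeated runs give the same numbers
    # Same formula as train_function: mean + sd * N(0, 1), once per iteration.
    return [mean + sd * rng.gauss(0.0, 1.0) for _ in range(steps)]


losses = simulate_losses(mean=1.0, sd=0.6)
print(f"last reported loss: {losses[-1]:.3f}")
print(f"average over {len(losses)} steps: {statistics.fmean(losses):.3f}")
```

Nothing in this objective improves over time, so the value Tune sees at the end of a trial depends only on `mean`, `sd`, and chance.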

Given that you provide an `api_key_file` pointing to your Weights & Biases API key, you can define a
simple grid-search Tune run using the `WandbLoggerCallback` as follows:

```python
def tune_with_callback(api_key_file):
    """Example for using a WandbLoggerCallback with the function API"""
    tuner = tune.Tuner(
        train_function,
        tune_config=tune.TuneConfig(
            metric="loss",
            mode="min",
        ),
        run_config=air.RunConfig(
            callbacks=[
                WandbLoggerCallback(api_key_file=api_key_file, project="Wandb_example")
            ]
        ),
        param_space={
            "mean": tune.grid_search([1, 2, 3, 4, 5]),
            "sd": tune.uniform(0.2, 0.8),
        },
    )
    tuner.fit()
```
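Note that `train_function` contains no Wandb code at all: the callback lives in the `RunConfig`, receives each result a trial reports to Tune, and forwards it to Wandb. The standard-library sketch below models only that forwarding step. `ResultForwarder` and its `on_trial_result(trial_id, result)` method are hypothetical names; they do not reproduce Tune's `Callback` interface, and the real `WandbLoggerCallback` does more than this.

```python
from collections import defaultdict
from typing import Any, Dict, List


class ResultForwarder:
    """Hypothetical stand-in for a logger callback; not Tune's Callback API.

    Each reported result becomes one "log call", stored per trial ID, the way a
    tracking backend would receive it.
    """

    def __init__(self) -> None:
        self.logged: Dict[str, List[Dict[str, float]]] = defaultdict(list)

    def on_trial_result(self, trial_id: str, result: Dict[str, Any]) -> None:
        # Forward only numeric metrics; booleans such as "done" are skipped.
        metrics = {
            key: float(value)
            for key, value in result.items()
            if isinstance(value, (int, float)) and not isinstance(value, bool)
        }
        self.logged[trial_id].append(metrics)


forwarder = ResultForwarder()
for step in range(3):
    # The training code only reports results; forwarding happens outside it.
    forwarder.on_trial_result("trial_00000", {"loss": 1.0 / (step + 1), "done": False})

print(len(forwarder.logged["trial_00000"]))  # 3
print(forwarder.logged["trial_00000"][-1])  # {'loss': 0.3333333333333333}
```

The design choice this illustrates is decoupling: switching to another tracker means changing the callback list, not the objective.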

To use the `setup_wandb` utility, call it at the start of your objective; it returns a Wandb run object.
We then call `wandb.log(...)` on that object to log the `loss` to Weights & Biases as a dictionary.
Otherwise, this version of the objective is identical to the original.

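```python
def train_function_wandb(config):
    # Initialize a Wandb run from the trial's config (including its "wandb" key).
    wandb = setup_wandb(config)

    for _ in range(30):
        loss = config["mean"] + config["sd"] * np.random.randn()
        session.report({"loss": loss})
        # Log the same value to Weights & Biases.
        wandb.log(dict(loss=loss))
```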
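Because the objective only ever calls `.log()` on the object returned by `setup_wandb`, its logging logic can be exercised without Wandb or Tune by handing the loop a stand-in object. The sketch below assumes you factor the loop into a hypothetical helper, `run_training_loop(config, run, report)`, that receives the run and the report function as arguments; `RecordingRun` is likewise a hypothetical test double, not part of Wandb.

```python
import random
from typing import Callable, Dict, List


class RecordingRun:
    """Hypothetical test double exposing only the .log() method used above."""

    def __init__(self) -> None:
        self.history: List[Dict[str, float]] = []

    def log(self, data: Dict[str, float]) -> None:
        self.history.append(dict(data))


def run_training_loop(
    config: Dict[str, float],
    run,  # any object with a .log(dict) method, e.g. a Wandb run or RecordingRun
    report: Callable[[Dict[str, float]], None],
    steps: int = 30,
) -> None:
    """Same loop as train_function_wandb, with its dependencies passed in."""
    for _ in range(steps):
        loss = config["mean"] + config["sd"] * random.gauss(0.0, 1.0)
        report({"loss": loss})  # in Tune this would be session.report
        run.log(dict(loss=loss))  # in Tune this would be the setup_wandb run


reported: List[Dict[str, float]] = []
fake_run = RecordingRun()
run_training_loop({"mean": 2.0, "sd": 0.0}, fake_run, reported.append)
print(len(fake_run.history), fake_run.history[0])  # 30 {'loss': 2.0}
```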

With `train_function_wandb` defined, running a Tune experiment is as simple as providing this objective and
passing the `api_key_file` (along with the project name) under the `wandb` key of the `param_space`, which Tune hands to each trial as part of its `config`:

```python
def tune_with_setup(api_key_file):
    """Example for using the setup_wandb utility with the function API"""
    tuner = tune.Tuner(
        train_function_wandb,
        tune_config=tune.TuneConfig(
            metric="loss",
            mode="min",
        ),
        param_space={
            "mean": tune.grid_search([1, 2, 3, 4, 5]),
            "sd": tune.uniform(0.2, 0.8),
            "wandb": {"api_key_file": api_key_file, "project": "Wandb_example"},
        },
    )
    tuner.fit()
```
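The `wandb` entry travels through `param_space` like any other key, so it becomes part of every trial's `config` (which is also why the best config printed at the end of the run below still contains it), and `setup_wandb` reads its settings from there. The sketch below models separating those settings from the actual hyperparameters. `split_wandb_settings` is a hypothetical helper that illustrates the idea under that assumption; it is not Ray's implementation.

```python
import copy
from typing import Any, Dict, Tuple


def split_wandb_settings(config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Hypothetical: separate Wandb settings from the tunable hyperparameters.

    Works on a deep copy, so the trial's own config is left unchanged.
    """
    remaining = copy.deepcopy(config)
    wandb_settings = remaining.pop("wandb", {})
    return wandb_settings, remaining


trial_config = {
    "mean": 3,
    "sd": 0.45,
    "wandb": {"api_key_file": "~/.wandb_api_key", "project": "Wandb_example"},
}
settings, hyperparams = split_wandb_settings(trial_config)
print(settings["project"])  # Wandb_example
print(hyperparams)  # {'mean': 3, 'sd': 0.45}
print("wandb" in trial_config)  # True: the original config is untouched
```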

Finally, you can also define a class-based Tune `Trainable` by calling `setup_wandb` in the `setup()` method and storing the returned run object as an attribute. Note that with a class trainable, you pass the trial ID, trial name, and group to `setup_wandb` explicitly:

```python
class WandbTrainable(tune.Trainable):
    def setup(self, config):
        # Create the Wandb run once per trial and keep it on the instance.
        self.wandb = setup_wandb(
            config, trial_id=self.trial_id, trial_name=self.trial_name, group="Example"
        )

    def step(self):
        # Run all 30 iterations in a single step, logging each loss to Wandb.
        for _ in range(30):
            loss = self.config["mean"] + self.config["sd"] * np.random.randn()
            self.wandb.log({"loss": loss})
        # Report only the final loss to Tune and mark the trial as finished.
        return {"loss": loss, "done": True}

    def save_checkpoint(self, checkpoint_dir: str):
        pass

    def load_checkpoint(self, checkpoint_dir: str):
        pass
```
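Unlike `train_function_wandb`, this `step()` runs all 30 iterations internally and returns a single result with `done: True`, so Tune records only one training iteration per trial (the `iter` column reads 1 in the third trial table of the output below). The simplified driver below illustrates that control flow with the standard library. `MiniTrainable` and `drive` are hypothetical, and they model neither Tune's scheduling nor its checkpointing.

```python
import random
from typing import Any, Dict, List


class MiniTrainable:
    """Hypothetical class trainable mirroring WandbTrainable's structure."""

    def __init__(self, config: Dict[str, float]) -> None:
        self.config = config
        self.setup(config)  # call setup() once before any step()

    def setup(self, config: Dict[str, float]) -> None:
        self.logged: List[Dict[str, float]] = []  # stands in for the Wandb run

    def step(self) -> Dict[str, Any]:
        for _ in range(30):
            loss = self.config["mean"] + self.config["sd"] * random.gauss(0.0, 1.0)
            self.logged.append({"loss": loss})
        return {"loss": loss, "done": True}


def drive(trainable: MiniTrainable, max_iterations: int = 100) -> int:
    """Hypothetical driver: call step() until a result sets done=True."""
    iterations = 0
    while iterations < max_iterations:
        result = trainable.step()
        iterations += 1
        if result.get("done"):
            break
    return iterations


trainable = MiniTrainable({"mean": 1.0, "sd": 0.5})
print(drive(trainable))  # 1 training iteration
print(len(trainable.logged))  # 30 values logged to the stand-in run
```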

Running Tune with this `WandbTrainable` works exactly the same as with the function API.
The `tune_trainable` function below differs from `tune_with_setup` above only in the first argument we pass to
`Tuner()` and in returning the config of the best trial:

```python
def tune_trainable(api_key_file):
    """Example for using the setup_wandb utility with the class API"""
    tuner = tune.Tuner(
        WandbTrainable,
        tune_config=tune.TuneConfig(
            metric="loss",
            mode="min",
        ),
        param_space={
            "mean": tune.grid_search([1, 2, 3, 4, 5]),
            "sd": tune.uniform(0.2, 0.8),
            "wandb": {"api_key_file": api_key_file, "project": "Wandb_example"},
        },
    )
    results = tuner.fit()

    return results.get_best_result().config
```
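All three tuners share the same `param_space`. With `num_samples` left at its default of 1 in `tune.TuneConfig`, the grid search over five `mean` values produces five trials, and each trial draws its own `sd` from `tune.uniform(0.2, 0.8)`, matching the five trials with distinct `sd` values in each output table below. The sketch below models that expansion with the standard library. `expand_param_space` is a hypothetical helper, not Tune's variant generator, and it handles only this grid-plus-uniform case.

```python
import random
from typing import Dict, List, Sequence, Tuple


def expand_param_space(
    grid_means: Sequence[float],
    sd_range: Tuple[float, float],
    num_samples: int = 1,
    seed: int = 0,
) -> List[Dict[str, float]]:
    """Hypothetical model of expanding {"mean": grid_search, "sd": uniform}."""
    rng = random.Random(seed)
    low, high = sd_range
    configs = []
    # Each repetition (num_samples) covers the full grid once.
    for _ in range(num_samples):
        for mean in grid_means:
            # Randomly sampled parameters are drawn independently for every trial.
            configs.append({"mean": mean, "sd": rng.uniform(low, high)})
    return configs


trials = expand_param_space([1, 2, 3, 4, 5], (0.2, 0.8))
for config in trials:
    print(config)
```

Under the same model, setting `num_samples=2` would repeat the grid and yield ten trials.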

Since you may not have a Wandb API key, we can _mock_ Wandb by disabling it through the `WANDB_MODE` environment variable,
on the driver and (via `runtime_env`) on the Ray workers, and still test all three training setups as follows.
If you do have an API key file, make sure to set `mock_api` to `False` and pass the correct `api_key_file` below.

```python
import os

mock_api = True

api_key_file = "~/.wandb_api_key"

if mock_api:
    # Disable Wandb in this driver process (unless WANDB_MODE is already set)...
    os.environ.setdefault("WANDB_MODE", "disabled")
    # ...and in the Ray worker processes that run the trials.
    ray.init(runtime_env={"env_vars": {"WANDB_MODE": "disabled"}})

tune_with_callback(api_key_file)
tune_with_setup(api_key_file)
tune_trainable(api_key_file)
```

2022-11-02 13:10:07,806	INFO worker.py:1524 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8266


| | |
|---|---|
| Current time: | 2022-11-02 13:10:46 |
| Running for: | 00:00:37.37 |
| Memory: | 10.5/16.0 GiB |

| Trial name | status | loc | mean | sd | iter | total time (s) | loss |
|---|---|---|---|---|---|---|---|
| train_function_58e30_00000 | TERMINATED | 127.0.0.1:5393 | 1 | 0.615867 | 30 | 0.22813 | 1.58301 |
| train_function_58e30_00001 | TERMINATED | 127.0.0.1:5401 | 2 | 0.467716 | 30 | 17.521 | 2.67345 |
| train_function_58e30_00002 | TERMINATED | 127.0.0.1:5403 | 3 | 0.458005 | 30 | 17.4535 | 3.7088 |
| train_function_58e30_00003 | TERMINATED | 127.0.0.1:5404 | 4 | 0.501836 | 30 | 7.45093 | 4.66316 |
| train_function_58e30_00004 | TERMINATED | 127.0.0.1:5405 | 5 | 0.567112 | 30 | 7.443 | 5.37309 |

| Trial name | date | done | episodes_total | experiment_id | experiment_tag | hostname | iterations_since_restore | loss | node_ip | pid | time_since_restore | time_this_iter_s | time_total_s | timestamp | timesteps_since_restore | timesteps_total | training_iteration | trial_id | warmup_time |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| train_function_58e30_00000 | 2022-11-02_13-10-15 | True |  | 9ed8d9e26b5c4a06955fdf798c3a6b07 | 0_mean=1,sd=0.6159 | Kais-MBP.local.meter | 30 | 1.58301 | 127.0.0.1 | 5393 | 0.22813 | 0.00403094 | 0.22813 | 1667419815 | 0 |  | 30 | 58e30_00000 | 0.0032289 |
| train_function_58e30_00001 | 2022-11-02_13-10-36 | True |  | 4a7b5adcdd9a4e919a14a51f61ed20c2 | 1_mean=2,sd=0.4677 | Kais-MBP.local.meter | 30 | 2.67345 | 127.0.0.1 | 5401 | 17.521 | 0.00539875 | 17.521 | 1667419836 | 0 |  | 30 | 58e30_00001 | 0.00279927 |
| train_function_58e30_00002 | 2022-11-02_13-10-36 | True |  | f06041da92d248709e3e495968832509 | 2_mean=3,sd=0.4580 | Kais-MBP.local.meter | 30 | 3.7088 | 127.0.0.1 | 5403 | 17.4535 | 0.00603294 | 17.4535 | 1667419836 | 0 |  | 30 | 58e30_00002 | 0.0028522 |
| train_function_58e30_00003 | 2022-11-02_13-10-26 | True |  | f349852e34a64c1ea185eb08233d9bfd | 3_mean=4,sd=0.5018 | Kais-MBP.local.meter | 30 | 4.66316 | 127.0.0.1 | 5404 | 7.45093 | 0.0273881 | 7.45093 | 1667419826 | 0 |  | 30 | 58e30_00003 | 0.00311399 |
| train_function_58e30_00004 | 2022-11-02_13-10-26 | True |  | 3d90543e07074e71be119558c26cab6f | 4_mean=5,sd=0.5671 | Kais-MBP.local.meter | 30 | 5.37309 | 127.0.0.1 | 5405 | 7.443 | 0.00747919 | 7.443 | 1667419826 | 0 |  | 30 | 58e30_00004 | 0.00292921 |


2022-11-02 13:10:46,589	INFO tune.py:788 -- Total run time: 38.75 seconds (37.36 seconds for the tuning loop).


| | |
|---|---|
| Current time: | 2022-11-02 13:10:56 |
| Running for: | 00:00:10.13 |
| Memory: | 11.8/16.0 GiB |

| Trial name | status | loc | mean | sd | iter | total time (s) | loss |
|---|---|---|---|---|---|---|---|
| train_function_wandb_70004_00000 | TERMINATED | 127.0.0.1:5465 | 1 | 0.742524 | 30 | 3.87935 | 1.20014 |
| train_function_wandb_70004_00001 | TERMINATED | 127.0.0.1:5477 | 2 | 0.721663 | 30 | 2.52346 | 2.11307 |
| train_function_wandb_70004_00002 | TERMINATED | 127.0.0.1:5478 | 3 | 0.731064 | 30 | 2.53241 | 2.09208 |
| train_function_wandb_70004_00003 | TERMINATED | 127.0.0.1:5479 | 4 | 0.510719 | 30 | 2.57405 | 3.71994 |
| train_function_wandb_70004_00004 | TERMINATED | 127.0.0.1:5480 | 5 | 0.203644 | 30 | 2.48766 | 4.99246 |

| Trial name | date | done | episodes_total | experiment_id | experiment_tag | hostname | iterations_since_restore | loss | node_ip | pid | time_since_restore | time_this_iter_s | time_total_s | timestamp | timesteps_since_restore | timesteps_total | training_iteration | trial_id | warmup_time |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| train_function_wandb_70004_00000 | 2022-11-02_13-10-54 | True |  | 9f99f093433d4113a7660c864e126115 | 0_mean=1,sd=0.7425 | Kais-MBP.local.meter | 30 | 1.20014 | 127.0.0.1 | 5465 | 3.87935 | 0.00325608 | 3.87935 | 1667419854 | 0 |  | 30 | 70004_00000 | 0.00281096 |
| train_function_wandb_70004_00001 | 2022-11-02_13-10-56 | True |  | fdaf299974584907b5bf4e007921db75 | 1_mean=2,sd=0.7217 | Kais-MBP.local.meter | 30 | 2.11307 | 127.0.0.1 | 5477 | 2.52346 | 0.015337 | 2.52346 | 1667419856 | 0 |  | 30 | 70004_00001 | 0.00400376 |
| train_function_wandb_70004_00002 | 2022-11-02_13-10-56 | True |  | cb3bfe4e2c194b0ab3b544da0da082db | 2_mean=3,sd=0.7311 | Kais-MBP.local.meter | 30 | 2.09208 | 127.0.0.1 | 5478 | 2.53241 | 0.00435305 | 2.53241 | 1667419856 | 0 |  | 30 | 70004_00002 | 0.00385976 |
| train_function_wandb_70004_00003 | 2022-11-02_13-10-56 | True |  | 4011756369f64cdaba9f628a47f7e91c | 3_mean=4,sd=0.5107 | Kais-MBP.local.meter | 30 | 3.71994 | 127.0.0.1 | 5479 | 2.57405 | 0.00265694 | 2.57405 | 1667419856 | 0 |  | 30 | 70004_00003 | 0.00383115 |
| train_function_wandb_70004_00004 | 2022-11-02_13-10-56 | True |  | e8e6a556f5684af88d20ffd52d166faa | 4_mean=5,sd=0.2036 | Kais-MBP.local.meter | 30 | 4.99246 | 127.0.0.1 | 5480 | 2.48766 | 0.0136211 | 2.48766 | 1667419856 | 0 |  | 30 | 70004_00004 | 0.00364685 |


2022-11-02 13:10:56,869	INFO tune.py:788 -- Total run time: 10.25 seconds (10.11 seconds for the tuning loop).


| | |
|---|---|
| Current time: | 2022-11-02 13:11:14 |
| Running for: | 00:00:17.53 |
| Memory: | 12.0/16.0 GiB |

| Trial name | status | loc | mean | sd | iter | total time (s) | loss |
|---|---|---|---|---|---|---|---|
| WandbTrainable_76213_00000 | TERMINATED | 127.0.0.1:5538 | 1 | 0.796319 | 1 | 0.000174046 | 1.91921 |
| WandbTrainable_76213_00001 | TERMINATED | 127.0.0.1:5554 | 2 | 0.370656 | 1 | 0.000159979 | 1.5508 |
| WandbTrainable_76213_00002 | TERMINATED | 127.0.0.1:5571 | 3 | 0.456557 | 1 | 0.000174046 | 3.77364 |
| WandbTrainable_76213_00003 | TERMINATED | 127.0.0.1:5572 | 4 | 0.532619 | 1 | 0.000191212 | 4.35208 |
| WandbTrainable_76213_00004 | TERMINATED | 127.0.0.1:5573 | 5 | 0.472552 | 1 | 0.000159979 | 4.9846 |

| Trial name | date | done | episodes_total | experiment_id | hostname | iterations_since_restore | loss | node_ip | pid | time_since_restore | time_this_iter_s | time_total_s | timestamp | timesteps_since_restore | timesteps_total | training_iteration | trial_id | warmup_time |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| WandbTrainable_76213_00000 | 2022-11-02_13-11-03 | True |  | be9b77b1b3944a3c9790103a629c90f9 | Kais-MBP.local.meter | 1 | 1.91921 | 127.0.0.1 | 5538 | 0.000174046 | 0.000174046 | 0.000174046 | 1667419863 | 0 |  | 1 | 76213_00000 | 1.5337 |
| WandbTrainable_76213_00001 | 2022-11-02_13-11-07 | True |  | 510e386820594f1a8aa1ed5c47e5d0df | Kais-MBP.local.meter | 1 | 1.5508 | 127.0.0.1 | 5554 | 0.000159979 | 0.000159979 | 0.000159979 | 1667419867 | 0 |  | 1 | 76213_00001 | 1.3163 |
| WandbTrainable_76213_00002 | 2022-11-02_13-11-14 | True |  | 03f14a58c68947d0afa7d7fcc8fff250 | Kais-MBP.local.meter | 1 | 3.77364 | 127.0.0.1 | 5571 | 0.000174046 | 0.000174046 | 0.000174046 | 1667419874 | 0 |  | 1 | 76213_00002 | 1.31631 |
| WandbTrainable_76213_00003 | 2022-11-02_13-11-14 | True |  | e0497ed8982644dab607ee58ff9f0f46 | Kais-MBP.local.meter | 1 | 4.35208 | 127.0.0.1 | 5572 | 0.000191212 | 0.000191212 | 0.000191212 | 1667419874 | 0 |  | 1 | 76213_00003 | 1.31873 |
| WandbTrainable_76213_00004 | 2022-11-02_13-11-14 | True |  | fcdd4e1fab9f48aa98be60f64ac3dbeb | Kais-MBP.local.meter | 1 | 4.9846 | 127.0.0.1 | 5573 | 0.000159979 | 0.000159979 | 0.000159979 | 1667419874 | 0 |  | 1 | 76213_00004 | 1.31499 |


2022-11-02 13:11:14,550	INFO tune.py:788 -- Total run time: 17.65 seconds (17.51 seconds for the tuning loop).


```text
{'mean': 2,
 'sd': 0.3706555346739163,
 'wandb': {'api_key_file': '~/.wandb_api_key', 'project': 'Wandb_example'}}
```
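The best config has `mean=2`, not `mean=1`. Each `WandbTrainable` trial reports exactly one result, and that result's `loss` is the last of 30 noisy draws, so the `mean=1` trial, with its large `sd`, happened to end above the `mean=2` trial. The sketch below re-ranks the final values copied from the `WandbTrainable` trial table above to confirm this; the ranking is a plain `min()`, not Tune's implementation.

```python
from typing import Dict, List

# Final (mean, sd, loss) values copied from the WandbTrainable trial table above.
results: List[Dict[str, float]] = [
    {"mean": 1, "sd": 0.796319, "loss": 1.91921},
    {"mean": 2, "sd": 0.370656, "loss": 1.5508},
    {"mean": 3, "sd": 0.456557, "loss": 3.77364},
    {"mean": 4, "sd": 0.532619, "loss": 4.35208},
    {"mean": 5, "sd": 0.472552, "loss": 4.9846},
]

# metric="loss", mode="min": the trial with the smallest reported loss wins.
best = min(results, key=lambda r: r["loss"])
print(best)  # {'mean': 2, 'sd': 0.370656, 'loss': 1.5508}

# How far each final loss landed from its trial's mean, in units of that trial's sd.
for r in results:
    print(r["mean"], round((r["loss"] - r["mean"]) / r["sd"], 2))
```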

This completes our Tune and Wandb walk-through.
In the following sections you can find more details on the API of the Tune-Wandb integration.
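When you switch `mock_api` to `False`, a quick check of the key file can catch a bad path before any trial starts. The sketch below expands the `~` in a path such as `~/.wandb_api_key`, reads the key, and builds an `env_vars` mapping in the same shape that the mocked branch of the example passes to `ray.init(runtime_env=...)`. The helper `wandb_env_vars` is hypothetical. It relies on two Wandb environment variables, `WANDB_MODE` (the example's `"disabled"` setting) and `WANDB_API_KEY`; exporting the key this way as an alternative to `api_key_file` is an assumption to check against your Ray and Wandb versions.

```python
import os
from typing import Dict, Optional


def wandb_env_vars(mock_api: bool, api_key_file: Optional[str] = None) -> Dict[str, str]:
    """Hypothetical helper: environment variables to ship to Ray workers."""
    if mock_api:
        # Same setting the example uses to turn Wandb calls into no-ops.
        return {"WANDB_MODE": "disabled"}
    if api_key_file is None:
        raise ValueError("api_key_file is required when mock_api is False")
    path = os.path.expanduser(api_key_file)  # "~/.wandb_api_key" -> absolute path
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Wandb API key file not found: {path}")
    with open(path, encoding="utf-8") as handle:
        api_key = handle.read().strip()  # drop a trailing newline, if any
    if not api_key:
        raise ValueError(f"Wandb API key file is empty: {path}")
    return {"WANDB_API_KEY": api_key}


print(wandb_env_vars(mock_api=True))  # {'WANDB_MODE': 'disabled'}
```

For example, `ray.init(runtime_env={"env_vars": wandb_env_vars(mock_api, api_key_file)})` would raise on a missing or empty key file before any trial is launched.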

## Tune Wandb API Reference

### WandbLoggerCallback

(tune-wandb-logger)=

```{eval-rst}
.. autoclass:: ray.air.integrations.wandb.WandbLoggerCallback
   :noindex:
```

### Wandb-Mixin

(tune-wandb-mixin)=

```{eval-rst}
.. autofunction:: ray.tune.integration.wandb.wandb_mixin
   :noindex:
```