In [None]:
#export
from fastai2.basics import *
from fastai2.callback.progress import *
from fastai2.text.data import TensorText
from fastai2.tabular.all import TabularDataLoaders, Tabular
from fastai2.callback.hook import total_params

In [None]:
from nbdev.showdoc import *

In [None]:
#default_exp callback.wandb

# Wandb

> Integration with [Weights & Biases](https://docs.wandb.com/) 

First things first, you need to install wandb with
```
pip install wandb
```
Create a free account then run 
``` 
wandb login
```
in your terminal. Follow the link to get an API token that you will need to paste, then you're all set!

In [None]:
#export
import wandb
from wandb.wandb_config import ConfigError

In [None]:
#export
class WandbCallback(Callback):
    """Saves model topology, losses & metrics

    `log` is forwarded to `wandb.watch`. When `log_preds` is true, a `FetchPredsCallback` is attached
    for the duration of each fit; it runs on `valid_dl` or, when `valid_dl` is not given, on `n_preds`
    items of the validation set drawn with `random.Random(seed)`.
    Raises `ValueError` if `wandb.init()` has not been called before the callback is created.
    """
    toward_end,remove_on_fetch,run_after = True,True,FetchPredsCallback
    # Record if watch has been called previously (even in another instance)
    _wandb_watch_called = False
    # True while a `FetchPredsCallback` attached by `begin_fit` has not been removed yet
    _fetch_preds_added = False

    def __init__(self, log="gradients", log_preds=True, valid_dl=None, n_preds=36, seed=12345):
        # Check if wandb.init has been called
        if wandb.run is None:
            raise ValueError('You must call wandb.init() before WandbCallback()')
        # W&B log step
        self._wandb_step = wandb.run.step - 1  # -1 except if the run has previously logged data (incremented at each batch)
        # Continue to next epoch if the run already holds data. That data may have been logged without
        # an 'epoch' entry (e.g. by a direct `wandb.log` call): start from 0 then instead of raising.
        self._wandb_epoch = 0
        if wandb.run.step:
            try: self._wandb_epoch = math.ceil(wandb.run.summary['epoch'])
            except KeyError: pass
        store_attr(self, 'log,log_preds,valid_dl,n_preds,seed')

    def begin_fit(self):
        "Call watch method to log model topology, gradients & weights"
        # Disabled when the learner has an `lr_finder` attribute, when a `gather_preds` attribute is
        # reachable from this callback, or when `rank_distrib()` is not 0
        self.run = not hasattr(self.learn, 'lr_finder') and not hasattr(self, "gather_preds") and rank_distrib()==0
        if not self.run: return

        # Log config parameters
        log_config = self.learn.gather_args()
        _format_config(log_config)
        try:
            # Log all parameters at once
            wandb.config.update(log_config, allow_val_change=True)
        except Exception as e:
            # Best effort: a config that cannot be logged must not prevent training
            print(f'WandbCallback could not log config parameters -> {e}')

        if not WandbCallback._wandb_watch_called:
            WandbCallback._wandb_watch_called = True
            # Logs model topology and optionally gradients and weights
            wandb.watch(self.learn.model, log=self.log)

        # Ask the `save_model` callback, when one is reachable, for an extra copy in the W&B run directory
        if hasattr(self, 'save_model'): self.save_model.add_save = Path(wandb.run.dir)/'bestmodel.pth'

        if self.log_preds:
            try:
                if not self.valid_dl:
                    #Initializes the batch watched
                    wandbRandom = random.Random(self.seed)  # For repeatability
                    self.n_preds = min(self.n_preds, len(self.dls.valid_ds))
                    idxs = wandbRandom.sample(range(len(self.dls.valid_ds)), self.n_preds)
                    if isinstance(self.dls,  TabularDataLoaders):
                        # Items exposing `iloc` are indexed with the whole list of indices at once
                        test_items = getattr(self.dls.valid_ds.items, 'iloc', self.dls.valid_ds.items)[idxs]
                        self.valid_dl = self.dls.test_dl(test_items, with_labels=True, process=False)
                    else:
                        test_items = [getattr(self.dls.valid_ds.items, 'iloc', self.dls.valid_ds.items)[i] for i in idxs]
                        self.valid_dl = self.dls.test_dl(test_items, with_labels=True)
                self.learn.add_cb(FetchPredsCallback(dl=self.valid_dl, with_input=True, with_decoded=True))
                # Remember that the fetcher is attached so `after_fit` removes it whatever happens to `log_preds`
                self._fetch_preds_added = True
            except Exception as e:
                self.log_preds = False
                print(f'WandbCallback was not able to prepare a DataLoader for logging prediction samples -> {e}')

    def after_batch(self):
        "Log hyper-parameters and training loss"
        if self.training:
            self._wandb_step += 1
            # Fractional epoch: each training batch adds 1/n_iter
            self._wandb_epoch += 1/self.n_iter
            # One entry per hyper-parameter and per index in `self.opt.hypers`, named `<name>_<index>`
            hypers = {f'{k}_{i}':v for i,h in enumerate(self.opt.hypers) for k,v in h.items()}
            wandb.log({'epoch': self._wandb_epoch, 'train_loss': self.smooth_loss, 'raw_loss': self.loss, **hypers}, step=self._wandb_step)

    def after_epoch(self):
        "Log validation loss and custom metrics & log prediction samples"
        # Correct any epoch rounding error and overwrite value
        self._wandb_epoch = round(self._wandb_epoch)
        wandb.log({'epoch': self._wandb_epoch}, step=self._wandb_step)
        # Log sample predictions
        if self.log_preds:
            try:
                inp,preds,targs,out = self.learn.fetch_preds.preds
                b = tuplify(inp) + tuplify(targs)
                x,y,its,outs = self.valid_dl.show_results(b, out, show=False, max_n=self.n_preds)
                wandb.log(wandb_process(x, y, its, outs), step=self._wandb_step)
            except Exception as e:
                # Stop trying to log samples for the rest of training, but keep logging metrics
                self.log_preds = False
                print(f'WandbCallback was not able to get prediction samples -> {e}')
        # 'train_loss' and 'epoch' are already logged per batch, 'time' is not logged
        wandb.log({n:s for n,s in zip(self.recorder.metric_names, self.recorder.log) if n not in ['train_loss', 'epoch', 'time']}, step=self._wandb_step)

    def after_fit(self):
        "Remove the prediction fetcher attached by `begin_fit` and sync the last step"
        self.run = True
        # Keyed on what `begin_fit` attached rather than on `log_preds`: `after_epoch` resets `log_preds`
        # to False when sample logging fails, and the fetcher has to be removed in that case too.
        if self._fetch_preds_added:
            self.remove_cb(FetchPredsCallback)
            self._fetch_preds_added = False
        wandb.log({}) # ensure sync of last step

Optionally logs weights and/or gradients depending on `log` (can be "gradients", "parameters", "all" or None), and sample predictions if `log_preds=True`. The samples come from `valid_dl` or, if it is not provided, from a random sample of the validation set (determined by `seed`) containing `n_preds` items.

If used in combination with `SaveModelCallback`, the best model is saved as well.

## Example of use:

Once you have defined your `Learner`, before you call `fit` or `fit_one_cycle`, you need to initialize wandb:
```
import wandb
wandb.init()
```
To use Weights & Biases without an account, you can call `wandb.init(anonymous='allow')`.

Then you add the callback in your call to fit, potentially with `SaveModelCallback` if you want to save the best model:
```
from fastai2.callback.wandb import *

# To log only during one training phase
learn.fit(..., cbs=WandbCallback())

# To log continuously for all training phases
learn = learner(..., cbs=WandbCallback())
```

In [None]:
#export
def _make_plt(img):
    "Return a frameless figure sized to the first two dimensions of `img`, and an axis filling it"
    # from https://stackoverflow.com/a/13714915
    dpi = 100
    figure = plt.figure(frameon=False, dpi=dpi)
    height, width = img.shape[:2]
    # At `dpi` dots per inch, this makes the figure `width` x `height` pixels
    figure.set_size_inches(width / dpi, height / dpi)
    # The axis covers the whole figure and has its decorations switched off
    axis = plt.Axes(figure, [0., 0., 1., 1.])
    axis.set_axis_off()
    figure.add_axes(axis)
    return figure, axis

In [None]:
#export
def _format_config(log_config):
    "Format config parameters before logging them"
    for k,v in log_config.items():
        if callable(v):
            if hasattr(v,'__qualname__') and hasattr(v,'__module__'): log_config[k] = f'{v.__module__}.{v.__qualname__}'
            else: log_config[k] = str(v)
        if isinstance(v, slice): log_config[k] = dict(slice_start=v.start, slice_step=v.step, slice_stop=v.stop)

In [None]:
#export
@typedispatch
def wandb_process(x:TensorImage, y, samples, outs):
    "Build W&B images of the inputs, and of the predictions and targets drawn over them"
    inputs, predictions, labels = [],[],[]
    for sample,out in zip(samples, outs):
        # Move the first dimension of the input last (presumably channels-first to channels-last)
        img = sample[0].permute(1,2,0)
        inputs.append(wandb.Image(img, caption='Input data'))
        overlays = ((out[0], "Prediction", predictions), (sample[1], "Ground Truth", labels))
        for item, caption, bucket in overlays:
            fig, ax = _make_plt(img)
            # Superimpose label or prediction to input image
            ax = img.show(ctx=ax)
            ax = item.show(ctx=ax)
            bucket.append(wandb.Image(fig, caption=caption))
            # Release the figure once W&B has taken its image
            plt.close(fig)
    return {"Inputs":inputs, "Predictions":predictions, "Ground Truth":labels}

In [None]:
#export
@typedispatch
def wandb_process(x:TensorImage, y:(TensorCategory,TensorMultiCategory), samples, outs):
    "One W&B image per sample, captioned with its target and its prediction"
    images = []
    for sample,out in zip(samples,outs):
        img = sample[0].permute(1,2,0)
        caption = f'Ground Truth: {sample[1]}\nPrediction: {out[0]}'
        images.append(wandb.Image(img, caption=caption))
    return {"Prediction Samples": images}

In [None]:
#export
@typedispatch
def wandb_process(x:TensorImage, y:TensorMask, samples, outs):
    "W&B images of the inputs, each carrying its predicted and target masks"
    codes = y.get_meta('codes')
    # Index -> name mapping, only available when `y` has 'codes' metadata
    class_labels = None if codes is None else {i:f'{c}' for i,c in enumerate(codes)}
    images = []
    for sample,out in zip(samples, outs):
        img = sample[0].permute(1,2,0)
        masks = {}
        for caption, mask in (("Prediction", out[0]), ("Ground Truth", sample[1])):
            entry = {'mask_data': mask.numpy().astype(np.uint8)}
            if class_labels: entry['class_labels'] = class_labels
            masks[caption] = entry
        images.append(wandb.Image(img, masks=masks))
    return {"Prediction Samples":images}

In [None]:
#export
@typedispatch
def wandb_process(x:TensorText, y:(TensorCategory,TensorMultiCategory), samples, outs):
    "W&B table with one row per sample: its text, its target and its prediction"
    rows = []
    for sample,out in zip(samples,outs):
        rows.append([sample[0], sample[1], out[0]])
    table = wandb.Table(data=rows, columns=["Text", "Target", "Prediction"])
    return {"Prediction Samples": table}

In [None]:
#export
@typedispatch
def wandb_process(x:Tabular, y:Tabular, samples, outs):
    "W&B table of `x.all_cols`, with a `<name>_pred` column taken from `y` for each name in `x.y_names`"
    table_df = x.all_cols
    for name in x.y_names:
        table_df[name+'_pred'] = y[name].values
    table = wandb.Table(dataframe=table_df)
    return {"Prediction Samples": table}

In [None]:
#hide
#slow
# Smoke test: train two learners on MNIST_TINY with `WandbCallback`, logging to a single offline W&B run
from fastai2.vision.all import *

# Black & white images labelled by their parent folder and split by their grandparent folder
path = untar_data(URLs.MNIST_TINY)
items = get_image_files(path)
tds = Datasets(items, [PILImageBW.create, [parent_label, Categorize()]], splits=GrandparentSplitter()(items))
dls = tds.dataloaders(after_item=[ToTensor(), IntToFloatTensor()])

# `wandb.init` has to come before `WandbCallback()`, which raises a `ValueError` otherwise
os.environ['WANDB_MODE'] = 'dryrun' # run offline
wandb.init(anonymous='allow')
learn = cnn_learner(dls, resnet18, loss_func=CrossEntropyLossFlat(), cbs=WandbCallback())
learn.fit(1)

# add more data from a new learner on same run
# (the new callback starts from the step and epoch already reached by the run; the slice learning
# rate also goes through the slice branch of `_format_config`)
learn = cnn_learner(dls, resnet18, loss_func=CrossEntropyLossFlat(), cbs=WandbCallback())
learn.fit(1, lr=slice(0.05))

epoch,train_loss,valid_loss,time
0,0.710007,0.229901,00:01


epoch,train_loss,valid_loss,time
0,1.742084,7.40317,00:01


In [None]:
#export
# NOTE(review): `_all_` looks like the nbdev hook for adding names to the generated module's `__all__`;
# presumably needed because `wandb_process` is defined through `@typedispatch` -- confirm against nbdev.
_all_ = ['wandb_process']

## Export -

In [None]:
#hide
# Export the notebooks with nbdev; the cell output lists each converted notebook
from nbdev.export import *
notebook2script()

Converted 00_torch_core.ipynb.
Converted 01_layers.ipynb.
Converted 02_data.load.ipynb.
Converted 03_data.core.ipynb.
Converted 04_data.external.ipynb.
Converted 05_data.transforms.ipynb.
Converted 06_data.block.ipynb.
Converted 07_vision.core.ipynb.
Converted 08_vision.data.ipynb.
Converted 09_vision.augment.ipynb.
Converted 09b_vision.utils.ipynb.
Converted 09c_vision.widgets.ipynb.
Converted 10_tutorial.pets.ipynb.
Converted 11_vision.models.xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_callback.core.ipynb.
Converted 13a_learner.ipynb.
Converted 13b_metrics.ipynb.
Converted 14_callback.schedule.ipynb.
Converted 14a_callback.data.ipynb.
Converted 15_callback.hook.ipynb.
Converted 15a_vision.models.unet.ipynb.
Converted 16_callback.progress.ipynb.
Converted 17_callback.tracker.ipynb.
Converted 18_callback.fp16.ipynb.
Converted 18a_callback.training.ipynb.
Converted 19_callback.mixup.ipynb.
Converted 20_interpret.ipynb.
Converted 20a_distributed.ipynb.
Converted 21_vision.l