<img src="http://wandb.me/logo-im-png" width="400" alt="Weights & Biases"/> <br>

<!--- @wandbcode{keras-wandbcallback-demo, v=examples} -->

<img src="http://wandb.me/mini-diagram" width="600" alt="Weights & Biases"/>

<a href="https://colab.research.google.com/github/wandb/examples/blob/master/keras/Keras_pipeline_with_Weights_and_Biases.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🍀 Integrate Weights and Biases in your TensorFlow/Keras workflow

In this example, we will train an image classifier on the **bloodMNIST** dataset. However, the primary focus is on Weights and Biases and how easily it can be added to your TensorFlow/Keras workflow.

Consider **[Weights and Biases](https://wandb.ai/site)** (W&B) to be the GitHub for machine learning. Use W&B for machine learning experiment tracking, dataset and model versioning, project collaboration, hyperparameter optimization, dataset exploration, model evaluation and so much more.

W&B comes with a lightweight **[integration for Keras](https://docs.wandb.ai/guides/integrations/keras)** (`WandbCallback`): with just a few lines of code you can log your metrics, save your model and training configuration, evaluate the model, and more. W&B also integrates with most of your favourite machine learning frameworks.

This notebook also introduces **[W&B Tables](https://docs.wandb.ai/guides/data-vis)**. Tables accelerate the ML development lifecycle by letting users rapidly extract meaningful insights from their data. The W&B Table Visualizer provides an interactive interface for powerful analytics operations such as grouping, joining, and creating custom fields, while also supporting rich media annotations such as bounding boxes and segmentation masks.

# 🌴 In this Notebook

In this colab we'll cover:

- training an image classifier on the BloodMNIST dataset from MedMNIST,
- use of W&B Tables for dataset exploration and evaluation,
- `WandbCallback` for experiment tracking and model evaluation.

In addition, we will cover some best practices for using Weights and Biases to get the most out of your data and model.

We will start by installing the dependencies and importing required libraries.

In [None]:
# For Weights and Biases
!pip install -qq wandb
# To download the dataset
!pip install -qq medmnist

In [None]:
# General Dependencies
import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm
%matplotlib inline

# For Deep Learning
import tensorflow as tf
print("TF: ", tf.__version__)
from tensorflow.keras import layers
from tensorflow.keras import models

# For MLOps
import wandb
print("W&B: ", wandb.__version__)
from wandb.keras import WandbCallback

# For medMNIST dataset
import medmnist
print("medMNIST: ", medmnist.__version__)
from medmnist import INFO

If this is your first time using W&B or you are not logged in, the link that appears after running `wandb.login()` will take you to the sign-up/login page. Signing up for a [free account](https://wandb.ai/signup) takes just a few clicks.

In [None]:
# Login to W&B
wandb.login()

# 🎋 Configs

Configuration files in `.yaml` or `.json` format are an integral part of most mature machine learning systems. Keeping track of the hyperparameters used to train/evaluate your model is essential for reproducing experiments.

W&B can keep track of your configs. Here we will first define all the hyperparameters needed for training our classifier. 

In [None]:
configs = dict(
    data_flag = 'bloodmnist',
    image_width = 32,
    image_height = 32,
    batch_size = 128,
    model_name = 'vgg16',
    pretrain_weights = 'imagenet',
    epochs = 100,
    init_learning_rate = 0.001,
    lr_decay_rate = 0.1,
    optimizer = 'adam',
    loss_fn = 'sparse_categorical_crossentropy',
    metrics = ['acc'],
    earlystopping_patience = 5
)
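
W&B records this dictionary once it is passed to `wandb.init(config=...)`, as the training cells below do. If you also want the hyperparameters in a standalone `.json` file, here is a minimal standard-library sketch; the file name and the subset of keys are arbitrary choices for illustration, and the dict is given a new name so the notebook's `configs` is not overwritten.

```python
import json
import os
import tempfile

# A few of the hyperparameters above, under a new name so that running this
# cell does not overwrite the notebook's `configs`
example_configs = dict(
    data_flag='bloodmnist',
    batch_size=128,
    epochs=100,
    init_learning_rate=0.001,
    metrics=['acc'],
)

# Write the config to a JSON file in a temporary directory
config_path = os.path.join(tempfile.mkdtemp(), 'config.json')
with open(config_path, 'w') as f:
    json.dump(example_configs, f, indent=2)

# Reading it back gives an identical dict
with open(config_path) as f:
    restored_configs = json.load(f)

print(restored_configs == example_configs)  # True
```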

# 🍁 Prepare Dataset

[MedMNIST](https://medmnist.com/) is a large-scale MNIST-like collection of standardized biomedical images, including 12 datasets of 2D images and 6 datasets of 3D images. All images are pre-processed to a size of `28 × 28` and do not require any prior domain knowledge to get started.

In this tutorial, we will use the `BloodMNIST` dataset. From the dataset description:

> The BloodMNIST is based on a dataset of individual normal cells, captured from individuals without infection, hematologic or oncologic disease and free of any pharmacologic treatment at the moment of blood collection. It contains a total of 17,092 images and is organized into 8 classes. We split the source dataset with a ratio of 7:1:2 into training, validation and test set. The source images with resolution 3×360×363 pixels are center-cropped into 3×200×200, and then resized into 3×28×28.

In [None]:
info = INFO[configs['data_flag']]
configs['class_names'] = info['label']
configs['image_channels'] = info['n_channels']

info

The `download_and_prepare_dataset` function below downloads a MedMNIST dataset, which is stored in the `.npz` format.

Each dataset file (e.g., `bloodmnist.npz`) contains 6 keys: `train_images`, `train_labels`, `val_images`, `val_labels`, `test_images` and `test_labels`.
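
To see this layout without downloading anything, the sketch below writes a small synthetic archive with the same six keys to memory and loads it back. The array sizes are made up, and the labels are stored as `(N, 1)` columns, which is the shape the loader's `.flatten()` calls assume.

```python
import io
import numpy as np

# Synthetic stand-in for a MedMNIST .npz file: uint8 images and
# (N, 1) label columns (all sizes here are made up)
rng = np.random.default_rng(0)
demo_arrays = {}
for split, n in [('train', 6), ('val', 2), ('test', 3)]:
    demo_arrays[f'{split}_images'] = rng.integers(0, 256, size=(n, 28, 28, 3), dtype=np.uint8)
    demo_arrays[f'{split}_labels'] = rng.integers(0, 8, size=(n, 1))

# Write the archive to memory instead of disk
buffer = io.BytesIO()
np.savez(buffer, **demo_arrays)
buffer.seek(0)

with np.load(buffer) as demo_data:
    demo_keys = sorted(demo_data.files)
    # .flatten() turns the (N, 1) label column into a 1-D array of class ids
    demo_train_labels = demo_data['train_labels'].flatten()

print(demo_keys)
print(demo_train_labels.shape)  # (6,)
```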

In [None]:
#@title
def download_and_prepare_dataset(data_info: dict):
    """
    Utility function to download the dataset and return train/valid/test images/labels.

    Arguments:
        data_info (dict): Dataset metadata
    """
    data_path = tf.keras.utils.get_file(origin=data_info['url'], md5_hash=data_info['MD5'])

    with np.load(data_path) as data:
        # Get images
        train_images = data['train_images']
        valid_images = data['val_images']
        test_images = data['test_images']

        # Get labels
        train_labels = data['train_labels'].flatten()
        valid_labels = data['val_labels'].flatten()
        test_labels = data['test_labels'].flatten()

    return train_images, train_labels, valid_images, valid_labels, test_images, test_labels

In [None]:
train_images, train_labels, valid_images, valid_labels, test_images, test_labels = download_and_prepare_dataset(info)

# 🌳 Explore the Dataset using W&B Tables

As TensorFlow/Keras users, you might be familiar with a `show_batch`-style helper, or you might have written some `matplotlib` code to visualize a few batches of the dataset. This is good for a quick inspection, but in most real-life scenarios it's not enough.

Here we will use W&B Tables (`wandb.Table`) to log the training data, and then visualize and query it interactively in W&B. As the name suggests, a W&B Table is a table of data that you specify. Check out more on Tables [here](https://docs.wandb.ai/guides/data-vis).

You can log data to W&B Tables row-wise or column-wise. In the section below, **we create the table column-wise**: `add_column` defines the name of a column and takes the array of data for that column. Raw image arrays will not render as images in the W&B Tables UI; each image array must be wrapped with `wandb.Image`, which we do with `add_computed_columns`. You can learn about these methods [here](https://docs.wandb.ai/ref/python/data-types/table).
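
As a mental model only (plain Python, not the `wandb` API; all `demo_*` names and class names are hypothetical), the column-wise construction and the computed column amount to this: columns are parallel lists, a computed column is a function of `(ndx, row)` applied to every row, and because the class-name dictionary from MedMNIST's `INFO` is keyed by strings, the label id has to be converted with `str(...)` before the lookup.

```python
# Plain-Python picture of a column-wise table plus a computed column.
# This is NOT the wandb API; it only mirrors the idea behind
# add_column / add_computed_columns.

# Hypothetical class names; like MedMNIST's INFO labels, the keys are strings
demo_class_names = {'0': 'class_a', '1': 'class_b', '2': 'class_c'}

# Column-wise data: one list per column (strings stand in for image arrays)
demo_columns = {
    'image': ['img_0', 'img_1', 'img_2'],
    'label_id': [2, 0, 1],
}

# The same data viewed row-wise: one dict per row
demo_rows = [dict(zip(demo_columns, values)) for values in zip(*demo_columns.values())]

def demo_computed(ndx, row):
    # Label ids are ints but the lookup keys are strings, hence str(...)
    return {'class_name': demo_class_names[str(row['label_id'])]}

# A computed column is a function applied to every (index, row) pair
for ndx, row in enumerate(demo_rows):
    row.update(demo_computed(ndx, row))

print([row['class_name'] for row in demo_rows])  # ['class_c', 'class_a', 'class_b']
```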

Finally, note that W&B Tables are built on top of **[W&B Artifacts](https://docs.wandb.ai/guides/artifacts)**, W&B's versioned storage for files (usually datasets and models). In this section, we explicitly initialize an artifact using `wandb.Artifact` and add both `train_table` and `valid_table` to it. Alternatively, we could have prepared the table and logged it using `wandb.log`. Here's a quick [example](https://docs.wandb.ai/guides/data-vis#quickly-log-your-first-table) if you are interested.

To log the entire training set, tick the `log_full` checkbox; otherwise, only the first 1,000 training images are logged. Note that the entire validation set is always logged.

In [None]:
# For demonstration purposes
log_full = False #@param {type:"boolean"}

if log_full:
    log_train_samples = len(train_images)
else:
    log_train_samples = 1000 

print(f'Number of train images : {log_train_samples} to be logged')

In [None]:
%%time

# Initialize a new W&B run
run = wandb.init(project='medmnist-bloodmnist', group='viz_data')

# Initialize a W&B Artifact
ds = wandb.Artifact("medmnist_bloodmnist_dataset", "dataset")

# Initialize an empty table
train_table = wandb.Table(columns=[], data=[])
# Add training data
train_table.add_column('image', train_images[:log_train_samples])
# Add training label_id
train_table.add_column('label_id', train_labels[:log_train_samples])
# Add training class names
train_table.add_computed_columns(lambda ndx, row:{
    "images": wandb.Image(row["image"]),
    "class_names": configs['class_names'][str(row["label_id"])]
    })

# Add the table to the Artifact
ds['train_data'] = train_table

# Let's do the same for the validation data
valid_table = wandb.Table(columns=[], data=[])
valid_table.add_column('image', valid_images)
valid_table.add_column('label_id', valid_labels)
valid_table.add_computed_columns(lambda ndx, row:{
    "images": wandb.Image(row["image"]),
    "class_name": configs['class_names'][str(row["label_id"])]
    })
ds['valid_data'] = valid_table

# Save the dataset as an Artifact
ds.save()

# Finish the run
wandb.finish()

Here are examples of the logged **[train table](https://wandb.ai/ayush-thakur/medmnist-bloodmnist/artifacts/dataset/medmnist_bloodmnist_dataset/61dc68d3fb90fa00168a/files/train_data.table.json)** and **[validation table](https://wandb.ai/ayush-thakur/medmnist-bloodmnist/artifacts/dataset/medmnist_bloodmnist_dataset/61dc68d3fb90fa00168a/files/valid_data.table.json)**. You can group, filter and interactively query the dataset $→$

![1st_gif.gif](https://s10.gifyu.com/images/1st_gif.gif)

# 🌲 Data Pipeline

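We build the data pipeline with `tf.data.Dataset`: `preprocess` converts the images to `float32` values in `[0, 1]` and casts the labels to `float32`, the training split is shuffled, and every split is batched and prefetched.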
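
One detail of `preprocess`: for integer inputs, `tf.image.convert_image_dtype` rescales as well as casts, so `uint8` pixels in `[0, 255]` become `float32` values in `[0, 1]`. A NumPy approximation of that step, for intuition only (the helper name `to_unit_float` is hypothetical):

```python
import numpy as np

def to_unit_float(images: np.ndarray) -> np.ndarray:
    """Approximate convert_image_dtype for uint8 -> float32: scale [0, 255] to [0, 1]."""
    if images.dtype != np.uint8:
        raise TypeError("expected uint8 images")
    return images.astype(np.float32) / 255.0

demo_pixels = np.array([[0, 128, 255]], dtype=np.uint8)
demo_scaled = to_unit_float(demo_pixels)
print(demo_scaled.dtype, demo_scaled.min(), demo_scaled.max())
```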

In [None]:
#@title
@tf.function
def preprocess(image: tf.Tensor, label: tf.Tensor):
    """
    Preprocess the image tensors and parse the labels
    """
    # Preprocess images
    image = tf.image.convert_image_dtype(image, tf.float32)
    
    # Parse label
    label = tf.cast(label, tf.float32)
    
    return image, label


def prepare_dataloader(images: np.ndarray,
                       labels: np.ndarray,
                       loader_type: str='train',
                       batch_size: int=128):
    """
    Utility function to prepare dataloader.
    """
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))

    if loader_type=='train':
        dataset = dataset.shuffle(1024)

    dataloader = (
        dataset
        .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )

    return dataloader

In [None]:
trainloader = prepare_dataloader(train_images, train_labels, 'train', configs.get('batch_size', 64))
validloader = prepare_dataloader(valid_images, valid_labels, 'valid', configs.get('batch_size', 64))
testloader = prepare_dataloader(test_images, test_labels, 'test', configs.get('batch_size', 64))

# 🪴 Data Augmentation

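We apply simple image augmentation policies with Keras preprocessing layers: `RandomRotation(factor=0.15)` rotates each image by a random angle of up to 15% of a full turn in either direction, and `RandomFlip()` flips images at random horizontally and/or vertically (its default mode is `"horizontal_and_vertical"`). Inside the model, these layers only transform images during training; at inference time they pass images through unchanged.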
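
`RandomRotation`'s `factor` is a fraction of a full turn ($2\pi$ radians), and a single value $f$ stands for the range $[-f, f]$. The hypothetical helper below converts it to degrees, so `factor=0.15` means rotations of up to about ±54°.

```python
def rotation_range_degrees(factor):
    """Turn a RandomRotation-style `factor` into a (min, max) angle range in degrees.

    Hypothetical helper for illustration: the factor is a fraction of a full
    turn (2*pi radians, i.e. 360 degrees); a single number f means (-f, f),
    and a (low, high) pair is used as given.
    """
    if isinstance(factor, (int, float)):
        low, high = -factor, factor
    else:
        low, high = factor
    return low * 360.0, high * 360.0

print(rotation_range_degrees(0.15))         # approximately (-54.0, 54.0)
print(rotation_range_degrees((-0.2, 0.3)))  # approximately (-72.0, 108.0)
```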

In [None]:
img_augmentation = models.Sequential(
    [
        layers.RandomRotation(factor=0.15),
        layers.RandomFlip()],
    name="img_augmentation",
)

# 🌿 Visualize Different Augmented Views

Here, let's use W&B Tables to visualize augmented images of a subset of training images. 

Augmentation policies should make sense for the given classification task, and W&B Tables let us check how the original images are transformed. For simplicity, we visualize a random sample of 100 training images.

In [None]:
#@title
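def augment_5_times(img):
    """Return five randomly augmented views of `img` (shape (1, H, W, C)) as wandb.Image objects."""
    augmented_imgs = []
    for _ in range(5):
        # Explicitly request training mode so that the random transforms are applied
        aug_img = tf.squeeze(img_augmentation(img, training=True), axis=0)
        # Wrap each image with wandb.Image so it renders in W&B Tables
        wandb_image = wandb.Image(aug_img.numpy())
        augmented_imgs.append(wandb_image)

    return augmented_imgs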

We can download the dataset we logged as W&B Tables, as shown in the code cell below. Since Tables are saved as W&B Artifacts, we first pass the artifact's full name (`entity/project/name:alias`, as shown in the UI) to `use_artifact`. You can find it by heading over to the [artifact tab](https://docs.wandb.ai/ref/app/pages/project-page#artifacts-tab) on the W&B dashboard and clicking on the [API panel](https://docs.wandb.ai/ref/app/pages/project-page#api-panel). Replace `ayush-thakur` with your own W&B username or team to use the artifact you logged above.

Get a table with the artifact's `get` method by passing the table's name, and use the `get_column` method to get the data in a column. Here, `augment_table` is initialized with its column names, and rows are added iteratively with `add_data` (i.e., **row-wise**).
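
The cell below samples rows by shuffling the index list in place and slicing it. `random.sample` draws the same kind of sample (without replacement) without modifying the list. A small standard-library comparison, using a dedicated `random.Random` instance so Python's global random state is left alone; the list of 1,000 integers is only a stand-in for the output of `get_index()`:

```python
import random

demo_rng = random.Random(0)          # seeded, private RNG for reproducibility
demo_ids = list(range(1000))         # stand-in for the table's row indices
demo_k = 100

# Approach used in the notebook: shuffle a copy in place, then slice
demo_shuffled = list(demo_ids)
demo_rng.shuffle(demo_shuffled)
demo_sample_a = demo_shuffled[:demo_k]

# Equivalent draw without replacement that leaves demo_ids untouched
demo_sample_b = demo_rng.sample(demo_ids, demo_k)

print(len(demo_sample_a), len(set(demo_sample_b)))  # 100 100
```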

In [None]:
%%time

viz_augment_samples = 100

# Initialize a W&B run
run = wandb.init(project='medmnist-bloodmnist', group='viz_augmentation')

# Use the already logged dataset
train_art = run.use_artifact('ayush-thakur/medmnist-bloodmnist/medmnist_bloodmnist_dataset:latest', type='dataset')

# Get the train_table to access the data
train_table = train_art.get("train_data")

# Get the images, ground truth label, and row index
images = train_table.get_column("images", convert_to="numpy")
labels = train_table.get_column("label_id", convert_to="numpy")
ids = train_table.get_index()
# Shuffle the ids and slice
random.shuffle(ids)
sample_ids = ids[0:viz_augment_samples]

# Create augmentation table
augment_table = wandb.Table(columns=['image', 'truth', 'label_id', 'aug1', 'aug2', 'aug3', 'aug4', 'aug5'])

# Get augmented images and log it onto the table
for sample_id in sample_ids:
    img = images[sample_id]
    label = labels[sample_id]
    augmented_imgs = augment_5_times(tf.expand_dims(img, axis=0))
    # Column order: image, truth (class name), label_id (integer id), aug1..aug5
    augment_table.add_data(wandb.Image(img),
                           configs['class_names'][str(label)],
                           int(label),
                           augmented_imgs[0],
                           augmented_imgs[1],
                           augmented_imgs[2],
                           augmented_imgs[3],
                           augmented_imgs[4])

# Log the table
wandb.log({'augmented data': augment_table})

# Finish the run
wandb.finish()

Here's an **[example](https://wandb.ai/ayush-thakur/medmnist-bloodmnist/runs/52pannnd?workspace=user-ayush-thakur)** of the logged table $→$
![3rd_gif_2.gif](https://s10.gifyu.com/images/3rd_gif_2.gif)

# 🎄 Model

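We will use [VGG16](https://www.tensorflow.org/api_docs/python/tf/keras/applications/vgg16/VGG16) as the backbone CNN. Because Keras' VGG16 requires inputs of at least 32×32 pixels, the model first resizes the 28×28 images to 32×32.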

In [None]:
def get_model(input_shape: tuple=(28, 28, 3),
              resize: tuple=(32, 32, 3),
              dropout_rate: float=0.5,
              num_classes: int=8,
              output_activation: str='softmax'):
    """
    Build a classifier with a pretrained VGG16 backbone.

    Inputs are resized to `resize` (Keras' VGG16 needs at least 32x32),
    augmented, and passed through VGG16 and a global-average-pooled head.
    """
    inputs = layers.Input(input_shape)
    resize_img = layers.Resizing(resize[0], resize[1], interpolation='bilinear')(inputs)
    # The augmentation layers transform images only during training (model.fit)
    augment_img = img_augmentation(resize_img)

    base_model = tf.keras.applications.VGG16(include_top=False,
                                             weights=configs['pretrain_weights'],
                                             input_shape=resize,
                                             input_tensor=augment_img)
    # Fine-tune the whole backbone
    base_model.trainable = True

    x = base_model.output
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(dropout_rate)(x)
    outputs = layers.Dense(num_classes, activation=output_activation)(x)

    return models.Model(inputs, outputs)

tf.keras.backend.clear_session()
model = get_model()
model.summary()

# ☘️ Callback

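Here we define an early-stopping callback: training stops once `val_loss` has not improved for `earlystopping_patience` epochs, and the weights from the best epoch are restored. We will define the `WandbCallback` later.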

In [None]:
earlystopper = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=configs['earlystopping_patience'], verbose=0, mode='auto',
    restore_best_weights=True
)

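You can use `wandb.log` to log any useful metric or parameter that `WandbCallback` does not log. Here we use a learning rate scheduler that keeps the initial learning rate for the first 7 epochs (epochs 0 to 6) and then decays it exponentially, multiplying it by `exp(-lr_decay_rate)` at the start of every later epoch. Notice the use of `wandb.log` to capture the learning rate, and `commit=False` in particular: it attaches the value to the current step without advancing it, so the learning rate is recorded together with the epoch's metrics logged by `WandbCallback`.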

You can learn more about `wandb.log` [here](https://docs.wandb.ai/guides/track/log).
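
The schedule can also be written in closed form. Assuming nothing else changes the optimizer's learning rate, the rate used in epoch $t$ (0-indexed) is $\eta_t = \eta_0 \exp(-r \max(0, t - 6))$, where $\eta_0$ is `init_learning_rate` and $r$ is `lr_decay_rate`; with $\eta_0 = 0.001$ and $r = 0.1$, epoch 16 uses $0.001 \cdot e^{-1} \approx 3.68 \times 10^{-4}$. The plain-Python sketch below (hypothetical function names) checks the closed form against a step-by-step replay of the rule:

```python
import math

def lr_closed_form(epoch, init_lr=0.001, decay_rate=0.1, constant_epochs=7):
    """Learning rate used in `epoch` (0-indexed) under the schedule described above."""
    return init_lr * math.exp(-decay_rate * max(0, epoch - constant_epochs + 1))

def lr_step_by_step(num_epochs, init_lr=0.001, decay_rate=0.1, constant_epochs=7):
    """Replay the per-epoch rule: keep lr for the first epochs, then multiply by exp(-decay_rate)."""
    lr, history = init_lr, []
    for epoch in range(num_epochs):
        if epoch >= constant_epochs:
            lr = lr * math.exp(-decay_rate)
        history.append(lr)
    return history

demo_history = lr_step_by_step(20)
for epoch in (0, 6, 7, 16):
    print(epoch, lr_closed_form(epoch), demo_history[epoch])
```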

In [None]:
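def lr_scheduler(epoch, lr):
    # The scheduler logs to W&B, so a run must already be active
    if wandb.run is None:
        raise wandb.Error("You must call wandb.init() before using lr_scheduler")

    # Keep the initial learning rate for epochs 0-6,
    # then decay it exponentially at the start of every later epoch
    if epoch < 7:
        new_lr = lr
    else:
        new_lr = lr * tf.math.exp(-configs['lr_decay_rate'])

    # Log the learning rate used in this epoch; commit=False attaches it to the
    # same step as the metrics WandbCallback logs at the end of the epoch
    wandb.log({'learning_rate': float(new_lr)}, commit=False)

    return new_lr

lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_scheduler)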

# 🌻 Train

In [None]:
def train(config: dict, 
          callbacks: list,
          verbose: int=0):
    """
    Utility function to train the model.

    Arguments:
        config (dict): Dictionary of hyperparameters.
        callbacks (list): List of callbacks passed to `model.fit`.
        verbose (int): 0 for silent and 1 for progress bar.
    """

    # Initialize model
    tf.keras.backend.clear_session()
    model = get_model(resize=(config.image_width, config.image_height, config.image_channels))

    # Compile the model
    opt = tf.keras.optimizers.Adam(learning_rate=config.init_learning_rate)
    model.compile(opt,
                  config.loss_fn,
                  metrics=config.metrics)

    # Train the model
    _ = model.fit(trainloader,
                  epochs=config.epochs,
                  validation_data=validloader,
                  callbacks=callbacks,
                  verbose=verbose)

    return model

# 🍄 Train using `WandbCallback`

In this section, we train the classifier with `WandbCallback`, which by default logs all training and validation metrics to the W&B dashboard.

`WandbCallback` lets you keep track of your experiments, saves the best model, and helps visualize model performance with just one line of code.

We use the following arguments:

* `monitor='val_loss'` selects the metric used to decide which model is the best one to save. Note that `'val_loss'` is the default value for `monitor`.
* `log_weights=True` saves histograms of each layer's weights.
* `log_evaluation=True` creates a W&B Table of validation data and model predictions. If a generator is passed as validation data to `model.fit`, the number of validation samples is controlled by `validation_steps`.

Check out the documentation [here](https://docs.wandb.ai/ref/python/integrations/keras/wandbcallback) to learn more about `WandbCallback`.


In [None]:
# Initialize the W&B run
run = wandb.init(project='medmnist-bloodmnist', config=configs, job_type='train')
config = wandb.config

# Define WandbCallback for experiment tracking
wandb_callback = WandbCallback(monitor='val_loss',
                               log_weights=True,
                               log_evaluation=True,
                               validation_steps=5)

# callbacks
callbacks = [earlystopper, wandb_callback, lr_callback]

# Train
model = train(config, callbacks=callbacks, verbose=1)

# Evaluate the trained model
loss, acc = model.evaluate(validloader)
wandb.log({'evaluate/accuracy': acc})

# Close the W&B run.
wandb.finish()

Check out the example **[W&B run page](https://wandb.ai/ayush-thakur/medmnist-bloodmnist/runs/3ceo33mx)** after training the model $→$
![2nd_gif.gif](https://s10.gifyu.com/images/2nd_gif.gif)

# 🌱 Advanced Usage

In this section, we look at a more advanced use of `WandbCallback`.

We will use `WandbCallback` to log the GradCAM heatmap for each validation example, along with its ground-truth label and the model's prediction.

The GradCAM code is adapted from this [tutorial](https://keras.io/examples/vision/grad_cam/) by François Chollet.
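
In formula form, Grad-CAM weights each channel $k$ of the last conv layer's activations $A^k$ by its spatially averaged gradient, $\alpha_k^c = \frac{1}{Z}\sum_{i,j} \frac{\partial y^c}{\partial A^k_{ij}}$, and the heatmap is $\mathrm{ReLU}\big(\sum_k \alpha_k^c A^k\big)$, which `make_gradcam_heatmap` then divides by its maximum. The NumPy sketch below reproduces those steps on plain arrays for a single image (hence the mean over axes `(0, 1)` instead of TensorFlow's `(0, 1, 2)`, which includes the batch axis). The function name and the `eps` guard against an all-zero map are additions for illustration, not part of the tutorial's code.

```python
import numpy as np

def gradcam_from_arrays(activations, grads, eps=1e-8):
    """Grad-CAM heatmap from last-conv activations and gradients.

    activations, grads: arrays of shape (H, W, K), i.e. A^k and dy^c/dA^k for
    one image (no batch axis). Returns an (H, W) heatmap scaled to [0, 1].
    """
    # alpha_k: gradient averaged over spatial positions, one weight per channel
    weights = grads.mean(axis=(0, 1))          # shape (K,)
    # Weighted sum of the activation maps over channels
    cam = activations @ weights                # shape (H, W)
    # ReLU keeps only positive evidence for the class
    cam = np.maximum(cam, 0.0)
    # Normalize by the maximum; eps avoids dividing by zero for an all-zero map
    return cam / max(cam.max(), eps)

# Random arrays shaped like a 4x4x512 VGG16 block4 output
demo_rng = np.random.default_rng(0)
demo_A = demo_rng.random((4, 4, 512))
demo_G = demo_rng.normal(size=(4, 4, 512))
demo_heatmap = gradcam_from_arrays(demo_A, demo_G)
print(demo_heatmap.shape)  # (4, 4)
```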


In [None]:
#@title
def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
    # First, we create a model that maps the input image to the activations
    # of the last conv layer as well as the output predictions
    grad_model = tf.keras.models.Model(
        [model.inputs], [model.get_layer(last_conv_layer_name).output, model.output]
    )

    # Then, we compute the gradient of the top predicted class for our input image
    # with respect to the activations of the last conv layer
    with tf.GradientTape() as tape:
        last_conv_layer_output, preds = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(preds[0])
        class_channel = preds[:, pred_index]

    # This is the gradient of the output neuron (top predicted or chosen)
    # with regard to the output feature map of the last conv layer
    grads = tape.gradient(class_channel, last_conv_layer_output)

    # This is a vector where each entry is the mean intensity of the gradient
    # over a specific feature map channel
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

    # We multiply each channel in the feature map array
    # by "how important this channel is" with regard to the top predicted class
    # then sum all the channels to obtain the heatmap class activation
    last_conv_layer_output = last_conv_layer_output[0]
    heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)

    # For visualization purpose, we will also normalize the heatmap between 0 & 1
    heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
    return heatmap.numpy()

def create_gradcam(image, model, last_conv_layer_name, pred_index=None):
    # Preprocess the image array
    image, _ = preprocess(tf.expand_dims(image, axis=0), 0)
    # Get GradCAM
    heatmap = make_gradcam_heatmap(image, model, last_conv_layer_name, pred_index)
    heatmap = np.uint8(255 * heatmap)

    # Use jet colormap to colorize heatmap
    # (plt.get_cmap is used because matplotlib.cm.get_cmap was removed in Matplotlib 3.9)
    jet = plt.get_cmap("jet")

    # Use RGB values of the colormap
    jet_colors = jet(np.arange(256))[:, :3]
    jet_heatmap = jet_colors[heatmap]
    jet_heatmap = tf.image.resize(jet_heatmap, size=(28,28))

    # Overlay
    superimposed_img = jet_heatmap * 0.4 + tf.squeeze(image, axis=0)
    superimposed_img = tf.clip_by_value(superimposed_img, 0.0, 1.0)

    return superimposed_img

In [None]:
last_conv_layer_name = 'block4_conv3'
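
For a sense of scale: with the 32×32 inputs used here, `block4_conv3` produces a 4×4 feature map (the heatmap that `create_gradcam` upsamples to 28×28), while `block5_conv3` would give only 2×2. The hypothetical helper below checks this arithmetic, assuming the standard VGG16 layout of 'same'-padded 3×3 convolutions with a 2×2, stride-2 max-pool after each block.

```python
def vgg16_conv_size(input_size: int, block: int) -> int:
    """Side length of the conv feature maps in VGG16 block `block` (1-5).

    Assumes 3x3 'same'-padded convolutions (size unchanged) and a 2x2,
    stride-2 max-pool after each block, so block b sees the input halved
    b - 1 times.
    """
    if not 1 <= block <= 5:
        raise ValueError("VGG16 has blocks 1 to 5")
    return input_size // 2 ** (block - 1)

for block in range(1, 6):
    side = vgg16_conv_size(32, block)
    print(f"block{block}: {side}x{side}")
```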

In the cell below, we use `WandbCallback`'s `validation_row_processor` and `prediction_row_processor` arguments to log the images, ground-truth labels, model predictions and GradCAM heatmaps for model interpretability.

Each processor is a callable that receives an `ndx` (row index) and a `row` (dict of data). The `validation_processor` function below receives a `row` containing the input image array (`row["input"]`) and the target label (`row["target"]`). The `prediction_processor` function receives a `row` containing the model's output (`row["output"]`) and a reference to the corresponding validation row (`row["val_row"]`).

The `validation_row_processor` is executed before model training starts, while the `prediction_row_processor` is called once training is over. The `validation_row_processor` creates a table with two columns, `input:image` and `target:class`. Notice that in the `prediction_processor` function we can retrieve the logged image for a given `val_row` using its `get_row` method.
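
To make the data flow of `prediction_processor` concrete, here is a plain-Python/NumPy sketch with a hypothetical three-class output (the class names and `demo_*` names are made up). Since the model ends in a softmax, the values stored under `output:logits` are class probabilities. The dictionary comprehension pairs class names with scores (it relies on the class-name dict listing classes in index order), and `np.argmax` gives the index that `class_table.index_ref` points to.

```python
import numpy as np

# Hypothetical class names keyed by string ids, like MedMNIST's INFO labels
demo_class_names = {'0': 'class_a', '1': 'class_b', '2': 'class_c'}

# One model output row from a softmax head: one probability per class
demo_output = np.array([0.1, 0.7, 0.2])

# What prediction_processor builds for "output:logits": class name -> score
demo_scores = {name: value for name, value in zip(demo_class_names.values(), demo_output.tolist())}

# The predicted class index used for "output:class"
demo_pred_index = int(np.argmax(demo_output))

print(demo_scores)      # {'class_a': 0.1, 'class_b': 0.7, 'class_c': 0.2}
print(demo_pred_index)  # 1
```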

In [None]:
def validation_processor(ndx, row):
    return {
        "input:image": wandb.Image(row["input"]),
        "target:class": class_table.index_ref(row["target"])
    }

def prediction_processor(ndx, row):
    # Get the validation image
    valid_image = np.array(row["val_row"].get_row()["input:image"].image)

    # Use the model being trained (the callback's model) rather than the global
    # `model`, which is only reassigned after train() returns
    trained_model = wandb_callback.model

    return {
        "output:class": class_table.index_ref(np.argmax(row["output"])),
        "gradcam": wandb.Image(create_gradcam(valid_image, trained_model, last_conv_layer_name)),
        "output:logits": {class_name: value for (class_name, value) in zip(list(config.class_names.values()), row["output"].tolist())}
    }

In [None]:
# Initialize the W&B run
run = wandb.init(project='medmnist-bloodmnist', config=configs, job_type='train')
config = wandb.config

# Get validation table
data_art = run.use_artifact('ayush-thakur/medmnist-bloodmnist/medmnist_bloodmnist_dataset:latest', type='dataset')
valid_table = data_art.get("valid_data")

# Create a class table
class_table = wandb.Table(columns=[], data=[])
class_table.add_column("class_name", list(config.class_names.values()))

# Define WandbCallback for experiment tracking
wandb_callback = WandbCallback(
                    log_evaluation=True,
                    validation_row_processor=lambda ndx, row: validation_processor(ndx, row),
                    prediction_row_processor=lambda ndx, row: prediction_processor(ndx, row),
                    validation_steps=4,
                    save_model=False
                )

# callbacks
callbacks = [earlystopper, wandb_callback, lr_callback]

# Train
model = train(config, callbacks=callbacks, verbose=1)

# Evaluate the trained model
loss, acc = model.evaluate(validloader)
wandb.log({'evaluate/accuracy': acc})

# Close the W&B run.
wandb.finish()

Here's an example **[W&B run page](https://wandb.ai/ayush-thakur/medmnist-bloodmnist/runs/3befja9v?workspace=user-ayush-thakur)** with logged GradCAM and logits of the model $→$

![3rd_gif.gif](https://s10.gifyu.com/images/3rd_gif.gif)

## 🌾 Conclusion

The Weights and Biases Keras integration enables experiment tracking and much more with just a few lines of code. In this notebook, we have seen some advanced usage of the Keras `WandbCallback` and different ways of using W&B Tables for data exploration and evaluation.

To sum up, all you need is a free W&B account: import `WandbCallback` and pass it to `model.fit(callbacks=[...])` like any other callback. There are a few more arguments that you can learn about in the documentation [here](https://docs.wandb.ai/ref/python/integrations/keras/wandbcallback). In particular:

* you can log metrics for each batch by setting `log_batch_frequency=1`;
* you can log the gradients of each layer to debug vanishing or exploding gradients by setting `log_gradients=True` (you will also have to provide `training_data` in the form `(X, y)`);
* if your task is semantic segmentation, you can set `input_type="segmentation_mask"`.

If a use case is not covered by `WandbCallback`, you can easily write a [custom Keras callback](https://keras.io/guides/writing_your_own_callbacks/) and use `wandb.log` to log the required data to the W&B dashboard.