<a href="https://colab.research.google.com/github/soumik12345/wandb-addons/blob/keras%2Fyolov8/docs/keras/examples/train_retinanet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# !pip install --upgrade -q git+https://github.com/keras-team/keras-cv wandb
# !git clone https://github.com/soumik12345/wandb-addons -b keras/yolov8 && cd wandb-addons && pip install -q -e .

In [2]:
import os
from typing import Union, Optional

import tensorflow as tf
from tensorflow import keras
from tensorflow import data as tf_data
import tensorflow_datasets as tfds

import keras_cv
import numpy as np

import resource
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

import wandb
from wandb.keras import WandbMetricsLogger
from wandb_addons.keras.detection.inference import get_mean_confidence_per_class

Using TensorFlow backend


In [3]:
# Start the W&B run. Project and entity can be overridden through the
# WANDB_PROJECT / WANDB_ENTITY environment variables so the notebook can be run
# under another account; when unset, the original defaults are used.
wandb.init(
    project=os.environ.get("WANDB_PROJECT", "keras-cv-callbacks"),
    entity=os.environ.get("WANDB_ENTITY", "geekyrakshit"),
    job_type="detection",
)

# Hyperparameters are stored on the run config so they are logged with the run.
config = wandb.config
config.batch_size = 4
config.base_lr = 0.005
# NOTE(review): this name is only logged; the model cell builds RetinaNet from
# the "resnet50_imagenet" preset — confirm which name should be recorded.
config.model_name = "retinanet_resnet50_pascalvoc"
config.momentum = 0.9
config.global_clipnorm = 10.0


# Human-readable class names, indexed by the integer label.
# NOTE(review): the first 20 entries are the PASCAL VOC labels; "Total" is not
# a VOC label, but it still counts toward `num_classes` because the model is
# built with `len(config.class_mapping)` — confirm whether it should be dropped.
class_ids = [
    "Aeroplane",
    "Bicycle",
    "Bird",
    "Boat",
    "Bottle",
    "Bus",
    "Car",
    "Cat",
    "Chair",
    "Cow",
    "Dining Table",
    "Dog",
    "Horse",
    "Motorbike",
    "Person",
    "Potted Plant",
    "Sheep",
    "Sofa",
    "Train",
    "Tvmonitor",
    "Total",
]
config.class_mapping = {class_idx: name for class_idx, name in enumerate(class_ids)}

[34m[1mwandb[0m: Currently logged in as: [33mgeekyrakshit[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
# Training data: VOC2007 and VOC2012 train+validation splits, back to back.
# Evaluation data: the VOC2007 test split.
# NOTE(review): `shuffle_files` only shuffles file order and no `Dataset.shuffle`
# is applied later, so every VOC2007 example precedes every VOC2012 one.
train_ds = tfds.load(
    "voc/2007", split="train+validation", with_info=False, shuffle_files=True
).concatenate(
    tfds.load(
        "voc/2012", split="train+validation", with_info=False, shuffle_files=True
    )
)
eval_ds = tfds.load("voc/2007", split="test", with_info=False)

In [5]:
def unpackage_tfds_inputs(inputs, bounding_box_format):
    """Convert one TFDS VOC record into KerasCV's input dictionary.

    Args:
        inputs: TFDS record with an "image" entry and an "objects" dict that
            holds "bbox" (relative yxyx boxes) and "label".
        bounding_box_format: KerasCV box format to convert the boxes to.

    Returns:
        ``{"images": float32 image, "bounding_boxes": {"classes", "boxes"}}``
        with classes and boxes cast to float32.
    """
    objects = inputs["objects"]
    raw_image = inputs["image"]
    # The image is needed to scale relative coordinates to pixels.
    converted_boxes = keras_cv.bounding_box.convert_format(
        objects["bbox"],
        images=raw_image,
        source="rel_yxyx",
        target=bounding_box_format,
    )
    return {
        "images": tf.cast(raw_image, tf.float32),
        "bounding_boxes": {
            "classes": tf.cast(objects["label"], dtype=tf.float32),
            "boxes": tf.cast(converted_boxes, dtype=tf.float32),
        },
    }


# Apply the same record -> KerasCV-dict conversion (boxes in "xywh") to both splits.
train_ds, eval_ds = [
    split.map(
        lambda inputs: unpackage_tfds_inputs(inputs, bounding_box_format="xywh"),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    for split in (train_ds, eval_ds)
]

In [6]:
# Images carry a variable number of boxes, so batch into ragged tensors.
# drop_remainder=True keeps every batch at exactly `config.batch_size` images
# (this also drops the last partial batch of the test split).
train_ds, eval_ds = [
    split.ragged_batch(config.batch_size, drop_remainder=True)
    for split in (train_ds, eval_ds)
]

In [7]:
# Training-time augmentation, applied to images and their "xywh" boxes together:
# random horizontal flips, then scale jitter (0.75x-1.3x) into a 640x640 output.
augmenter = keras.Sequential(
    [
        keras_cv.layers.RandomFlip(bounding_box_format="xywh", mode="horizontal"),
        keras_cv.layers.JitteredResize(
            bounding_box_format="xywh",
            scale_factor=(0.75, 1.3),
            target_size=(640, 640),
        ),
    ]
)

train_ds = train_ds.map(augmenter, num_parallel_calls=tf.data.AUTOTUNE)



In [8]:
# Deterministic 640x640 resize for evaluation; pad_to_aspect_ratio=True pads
# instead of stretching, and boxes are adjusted to match.
inference_resizing = keras_cv.layers.Resizing(
    height=640,
    width=640,
    pad_to_aspect_ratio=True,
    bounding_box_format="xywh",
)
eval_ds = eval_ds.map(inference_resizing, num_parallel_calls=tf.data.AUTOTUNE)

In [9]:
def dict_to_tuple(inputs):
    """Split a KerasCV input dict into the ``(images, boxes)`` pair `fit` expects.

    The ragged bounding boxes are converted to dense tensors with room for
    32 boxes per image.
    NOTE(review): VOC images with more than 32 objects may lose boxes here —
    confirm `to_dense` truncation behavior if that matters.
    """
    dense_bounding_boxes = keras_cv.bounding_box.to_dense(
        inputs["bounding_boxes"], max_boxes=32
    )
    return inputs["images"], dense_bounding_boxes


# Final pipeline stage for both splits: (images, dense boxes) tuples, prefetched.
train_ds, eval_ds = [
    split.map(dict_to_tuple, num_parallel_calls=tf.data.AUTOTUNE).prefetch(
        tf.data.AUTOTUNE
    )
    for split in (train_ds, eval_ds)
]

In [10]:
# SGD with momentum and global gradient-norm clipping, driven by the run config.
optimizer = keras.optimizers.SGD(
    learning_rate=config.base_lr,
    momentum=config.momentum,
    global_clipnorm=config.global_clipnorm,
)

# RetinaNet built from the "resnet50_imagenet" backbone preset; one output class
# per entry of the class mapping.
model = keras_cv.models.RetinaNet.from_preset(
    "resnet50_imagenet",
    bounding_box_format="xywh",
    num_classes=len(config.class_mapping),
)
# Show the default prediction decoder; the visualization callback temporarily
# replaces it with its own NMS layer.
print(model.prediction_decoder)

model.compile(
    optimizer=optimizer,
    classification_loss="focal",
    box_loss="smoothl1",
    metrics=None,
)

<keras_cv.layers.object_detection.non_max_suppression.NonMaxSuppression object at 0x7910e9fcdcf0>


In [11]:
class EvaluateCOCOMetricsCallback(keras.callbacks.Callback):
    """Compute KerasCV COCO box metrics on a dataset at the end of every epoch.

    The resulting metric values are merged into the epoch `logs` dict, so that
    callbacks later in the `callbacks` list, which receive the same dict, can
    log them.

    Args:
        data: batched dataset yielding `(images, y_true)` pairs with boxes in
            "xywh" format.
    """

    def __init__(self, data, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.data = data
        self.metrics = keras_cv.metrics.BoxCOCOMetrics(
            bounding_box_format="xywh",
            # Huge value so the expensive COCO computation effectively only
            # runs when forced in `on_epoch_end`.
            evaluate_freq=1e9,
        )

    def on_epoch_end(self, epoch, logs=None):
        """Run prediction over `self.data` and add the COCO metrics to `logs`.

        Returns:
            The (possibly `None`) `logs` dict, updated in place when present.
        """
        self.metrics.reset_state()
        for images, y_true in tqdm(self.data):
            y_pred = self.model.predict(images, verbose=0)
            self.metrics.update_state(y_true, y_pred)

        metrics = self.metrics.result(force=True)
        # Keras' Callback API allows `logs` to be None; only merge when present.
        if logs is not None:
            logs.update(metrics)
        return logs

In [12]:
class WandBDetectionVisualizationCallback(keras.callbacks.Callback):
    """Log ground-truth vs. predicted boxes to a W&B table after every epoch.

    At the end of each epoch, the model's prediction decoder is temporarily
    replaced by a `MultiClassNonMaxSuppression` layer built with the thresholds
    given here. The model then predicts on up to `max_batches_for_vis` batches
    of `dataset`, and one table row is appended per image. The accumulated
    table is logged once, when training ends.

    Args:
        dataset: batched dataset yielding `(images, y_true)` pairs, where
            `y_true` is a dense box dict with "boxes" and "classes" (classes
            padded with -1).
        class_mapping: maps integer class ids to display names.
        max_batches_for_vis: number of batches to visualize per epoch; `None`
            visualizes every batch of `dataset`.
        iou_threshold: IoU threshold of the visualization-time NMS.
        confidence_threshold: minimum confidence of a visualized prediction.
        source_bounding_box_format: box format of both labels and predictions.
        title: key under which the table is logged to W&B.
    """

    def __init__(
        self,
        dataset: tf.data.Dataset,
        class_mapping: dict,
        max_batches_for_vis: Optional[int] = 1,
        iou_threshold: float = 0.01,
        confidence_threshold: float = 0.01,
        source_bounding_box_format: str = "xywh",
        title: str = "Evaluation-Table",
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        # `Dataset.take` needs an integer count; `None` means "whole dataset".
        self.dataset = (
            dataset
            if max_batches_for_vis is None
            else dataset.take(max_batches_for_vis)
        )
        self.class_mapping = class_mapping
        self.max_batches_for_vis = max_batches_for_vis
        self.iou_threshold = iou_threshold
        self.confidence_threshold = confidence_threshold
        self.source_bounding_box_format = source_bounding_box_format
        self.title = title
        self.prediction_decoder = keras_cv.layers.MultiClassNonMaxSuppression(
            bounding_box_format=self.source_bounding_box_format,
            from_logits=True,
            iou_threshold=self.iou_threshold,
            confidence_threshold=self.confidence_threshold,
        )
        self.table = wandb.Table(
            columns=[
                "Epoch",
                "Image",
                "Number-of-Ground-Truth-Boxes",
                "Mean-Confidence",
            ]
        )

    def _prediction_box_data(
        self, boxes, classes, confidences, image_height, image_width
    ):
        """Build W&B box dicts for predicted xyxy boxes, relative to image size."""
        box_data = []
        for (x_min, y_min, x_max, y_max), class_id, score in zip(
            boxes, classes, confidences
        ):
            class_id = int(class_id)
            box_data.append(
                {
                    "position": {
                        # Images are (height, width, channels): x coordinates
                        # scale by the width, y coordinates by the height.
                        "minX": float(x_min) / image_width,
                        "minY": float(y_min) / image_height,
                        "maxX": float(x_max) / image_width,
                        "maxY": float(y_max) / image_height,
                    },
                    "class_id": class_id,
                    "box_caption": self.class_mapping[class_id],
                    "scores": {"confidence": float(score)},
                }
            )
        return box_data

    def _ground_truth_box_data(self, boxes, classes):
        """Build W&B box dicts for ground-truth xyxy boxes, in pixel coordinates."""
        return [
            {
                "position": {
                    "minX": int(x_min),
                    "minY": int(y_min),
                    "maxX": int(x_max),
                    "maxY": int(y_max),
                },
                "class_id": class_id,
                "box_caption": self.class_mapping[int(class_id)],
                "domain": "pixel",
            }
            # zip stops at len(classes), i.e. at the first padding entry.
            for (x_min, y_min, x_max, y_max), class_id in zip(boxes, classes)
        ]

    def plot_prediction(self, epoch, image_batch, y_true_batch):
        """Predict on one batch and append one table row per image.

        Args:
            epoch: epoch index stored in the table's "Epoch" column.
            image_batch: batch of images to run the model on.
            y_true_batch: dense ground-truth dict with "boxes" and "classes".
        """
        # verbose=0 avoids a Keras progress bar for every visualized batch.
        y_pred = keras_cv.bounding_box.to_ragged(
            self.model.predict(image_batch, verbose=0)
        )
        images = keras_cv.utils.to_numpy(image_batch).astype(np.uint8)
        ground_truth_boxes = keras_cv.utils.to_numpy(
            keras_cv.bounding_box.convert_format(
                y_true_batch["boxes"],
                source=self.source_bounding_box_format,
                target="xyxy",
                images=images,
            )
        )
        ground_truth_classes = keras_cv.utils.to_numpy(y_true_batch["classes"])
        predicted_boxes = keras_cv.utils.to_numpy(
            keras_cv.bounding_box.convert_format(
                y_pred["boxes"],
                source=self.source_bounding_box_format,
                target="xyxy",
                images=images,
            )
        )

        for idx, image in enumerate(images):
            num_detections = y_pred["num_detections"][idx].item()
            image_pred_boxes = predicted_boxes[idx][:num_detections]
            confidences = keras_cv.utils.to_numpy(
                y_pred["confidence"][idx][:num_detections]
            )
            predicted_classes = keras_cv.utils.to_numpy(
                y_pred["classes"][idx][:num_detections]
            )

            gt_classes = [
                int(class_idx) for class_idx in ground_truth_classes[idx].tolist()
            ]
            # Dense labels are padded with -1; keep only the real boxes.
            if -1 in gt_classes:
                gt_classes = gt_classes[: gt_classes.index(-1)]

            image_height, image_width = image.shape[:2]
            wandb_image = wandb.Image(
                image,
                boxes={
                    "ground-truth": {
                        "box_data": self._ground_truth_box_data(
                            ground_truth_boxes[idx], gt_classes
                        ),
                        "class_labels": self.class_mapping,
                    },
                    "predictions": {
                        "box_data": self._prediction_box_data(
                            image_pred_boxes,
                            predicted_classes,
                            confidences,
                            image_height,
                            image_width,
                        ),
                        "class_labels": self.class_mapping,
                    },
                },
            )
            mean_confidence_dict = get_mean_confidence_per_class(
                confidences, predicted_classes, self.class_mapping
            )
            self.table.add_data(
                epoch, wandb_image, len(gt_classes), mean_confidence_dict
            )

    def on_epoch_end(self, epoch, logs=None):
        """Visualize predictions on the configured batches with this callback's NMS."""
        original_prediction_decoder = self.model.prediction_decoder
        self.model.prediction_decoder = self.prediction_decoder
        try:
            # Iterate the `take`-limited dataset once so each visualized batch
            # is a different batch.
            for image_batch, y_true_batch in tqdm(
                self.dataset, total=self.max_batches_for_vis
            ):
                self.plot_prediction(epoch, image_batch, y_true_batch)
        finally:
            # Always restore the model's own decoder, even if plotting fails.
            self.model.prediction_decoder = original_prediction_decoder

    def on_train_end(self, logs=None):
        """Log the accumulated visualization table to W&B."""
        wandb.log({self.title: self.table})

In [13]:
# Short smoke-test run: only a few batches per epoch so each epoch is quick.
NUM_BATCHES = 20
NUM_EPOCHS = 3

model.fit(
    train_ds.take(NUM_BATCHES),
    # Validate on the held-out VOC2007 test split, not on training batches.
    validation_data=eval_ds.take(NUM_BATCHES),
    epochs=NUM_EPOCHS,
    callbacks=[
        # Optional (slow) COCO metrics. Callbacks run in list order, so keeping
        # it before the W&B logger puts its metrics into `logs` in time.
        # EvaluateCOCOMetricsCallback(eval_ds.take(NUM_BATCHES)),
        WandbMetricsLogger(log_freq="batch"),
        # The eval split is not shuffled or augmented, so the same images are
        # visualized every epoch, which makes progress comparable.
        WandBDetectionVisualizationCallback(
            dataset=eval_ds.take(NUM_BATCHES), class_mapping=config.class_mapping
        ),
    ],
)

Epoch 1/3

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch 2/3

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch 3/3

  0%|          | 0/1 [00:00<?, ?it/s]



<keras.callbacks.History at 0x7910e246f970>

In [14]:
# Mark the W&B run as finished so the logged metrics and the visualization
# table are uploaded before the notebook ends.
wandb.finish()

0,1
batch/batch_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
batch/box_loss,▆██▇▇▇▇▆▆▇▆▆▆▆▃▃▃▃▃▃▃▃▃▃▃▃▃▁▂▂▁▁▁▁▂▁▁▁▂▂
batch/classification_loss,█▇▇▇▇▇▇▇▆▆▆▆▆▆▅▅▆▆▅▅▅▅▄▄▄▄▄▃▁▁▁▁▂▂▂▂▂▁▁▁
batch/learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
batch/loss,███▇▇▇▇▇▇▇▇▇▇▇▅▅▅▅▅▅▅▄▄▄▄▄▄▂▁▁▁▁▁▁▁▂▁▁▁▁
batch/percent_boxes_matched_with_anchor,▇▇▄▃▁▁▂▂▁▂▃▃▂▂▇█▇▇▄▄▅▄▃▃▄▄▄▄▃▄▅▆▅▆▄▄▃▃▄▄
epoch/box_loss,█▃▁
epoch/classification_loss,█▅▁
epoch/epoch,▁▅█
epoch/learning_rate,▁▁▁

0,1
batch/batch_step,59.0
batch/box_loss,0.63295
batch/classification_loss,0.8724
batch/learning_rate,0.005
batch/loss,1.50535
batch/percent_boxes_matched_with_anchor,0.92344
epoch/box_loss,0.63295
epoch/classification_loss,0.8724
epoch/epoch,2.0
epoch/learning_rate,0.005
