# Detectron2 on SKU-110K dataset

**Index**

1. [Background](#Background)
1. [Setup](#Setup)
1. [Data](#Data)
1. [Training](#Training)
1. [Hyperparameter Tuning Jobs](#HPO)
1. [Deploy: Batch Transform](#Deploy)
1. [Evaluation](#Evaluation)

## Background

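This notebook trains a Detectron2 object detector (Faster R-CNN or RetinaNet) on the SKU-110K dataset with Amazon SageMaker. It prepares the data, launches a training job and a hyperparameter tuning job. All bounding boxes in the dataset belong to a single class, `SKU`.

TODO: present the dataset and the goal of this notebook in more detail.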

## Setup

In [None]:
import sagemaker

In [None]:
bucket = "sagemaker-sku110k-dataset" # "YOUR-BUCKET"
prefix_data = "detectron2/data"
prefix_model = "detectron2/training_artefacts"
local_folder = "cache"

sm_session = sagemaker.Session(default_bucket=bucket)

role = sagemaker.get_execution_role()

## Data

In [None]:
from pathlib import Path
from urllib import request
import tarfile
from typing import Sequence, Mapping, Optional
from tqdm import tqdm
from datetime import datetime
import tempfile
import json

import pandas as pd
import numpy as np
import boto3

### Download SKU-110K dataset

In [None]:
sku_dataset = ("SKU110K_fixed", "http://trax-geometry.s3.amazonaws.com/cvpr_challenge/SKU110K_fixed.tar.gz")

# NOTE: only the existence of the folder is checked, so a partially extracted folder
# (e.g. after an interrupted download) is reused as-is; delete it to download again.
if not (Path(local_folder) / sku_dataset[0]).exists():
    # Stream the gzip archive straight from the HTTP response ("r|gz" reads the tar
    # sequentially, without seeking), and close both the response and the tar handle.
    with request.urlopen(sku_dataset[1]) as response, tarfile.open(
        fileobj=response, mode="r|gz"
    ) as compressed_file:
        # The archive comes from a remote server: when the interpreter supports
        # extraction filters, the "data" filter refuses members that would be written
        # outside `local_folder` (absolute paths, "..", unsafe links).
        if hasattr(tarfile, "data_filter"):
            compressed_file.extractall(path=local_folder, filter="data")
        else:
            compressed_file.extractall(path=local_folder)
else:
    print(f"Using the data in `{local_folder}` folder")

### Reorganize images

Images are moved into three channel folders (training, validation and test) according to the prefix of the image name (`train`, `val` or `test`). The images are then uploaded to the S3 bucket specified in the setup.

:warning: The upload to S3 takes some time.

In [None]:
path_images = Path(local_folder) / sku_dataset[0] / "images"
assert path_images.exists(), f"{path_images} not found"

# Image-name prefix -> channel folder created next to `images/`.
prefix_to_channel = {
    "train": "training",
    "val": "validation",
    "test": "test",
}
for channel_name in prefix_to_channel.values():
    # exist_ok makes the cell safe to re-run.
    (path_images.parent / channel_name).mkdir(exist_ok=True)

# Snapshot the directory listing first: files are moved out of `path_images` inside
# the loop, and iterdir() does not specify what happens when the directory changes
# while it is being iterated.
for path_img in list(path_images.iterdir()):
    for prefix, target_channel in prefix_to_channel.items():
        if path_img.name.startswith(prefix):
            path_img.replace(path_images.parent / target_channel / path_img.name)
            break  # the file has been moved; no other prefix applies

Detectron2 uses Pillow to read images. We found that some images in the SKU-110K dataset are corrupted, which causes the data loader to raise an `IOError` exception. Therefore, we remove them from the dataset.

In [None]:
# Images Pillow fails to read, grouped by channel. They are deleted from the local
# channel folders and their annotations are skipped when building the manifests.
CORRUPTED_IMAGES = dict(
    training=("train_4222.jpg", "train_5822.jpg", "train_882.jpg", "train_924.jpg"),
    validation=(),
    test=("test_274.jpg",),
)

In [None]:
# Delete every known-corrupted image from its channel folder, reporting the outcome.
for channel_name in prefix_to_channel.values():
    channel_dir = path_images.parent / channel_name
    for img_name in CORRUPTED_IMAGES[channel_name]:
        try:
            (channel_dir / img_name).unlink()
        except FileNotFoundError:
            print(f"{img_name} not in channel {channel_name}")
        else:
            print(f"{img_name} removed from channel {channel_name} ")
            

In [None]:
# Print how many .jpg files each channel folder now contains.
for channel_name in prefix_to_channel.values():
    n_images = len(list((path_images.parent / channel_name).glob("*.jpg")))
    print(f"Number of {channel_name} images = {n_images}")

In [None]:
# Upload each channel folder to s3://{bucket}/{prefix_data}/{channel} and keep the
# S3 URI returned for every channel.
channel_to_s3_imgs = {}
for channel_name in prefix_to_channel.values():
    s3_uri = sm_session.upload_data(
        path=str(path_images.parent / channel_name),
        bucket=bucket,
        key_prefix=f"{prefix_data}/{channel_name}",
    )
    print(f"{channel_name} images uploaded to {s3_uri}")
    channel_to_s3_imgs[channel_name] = s3_uri

### Annotations processing

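The annotations are converted from the original SKU-110K CSV files into augmented manifest files (one JSON object per line), the format generated by Amazon SageMaker Ground Truth labeling jobs, and then uploaded to S3.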

In [None]:
def _csv_has_header(path_to_csv: Path) -> bool:
    r"""Tell whether the first line of a CSV annotation file is a header row.

    Parameters
    ----------
    path_to_csv : Path
        path to the CSV file

    Returns
    -------
    bool
        True when the second field of the first line is not a number. Data rows carry
        the numeric ``x1`` coordinate there, whereas a header row carries a column name.
    """
    with open(path_to_csv, encoding="utf-8") as fid:
        first_fields = fid.readline().strip().split(",")
    if len(first_fields) < 2:
        return False
    try:
        float(first_fields[1])
    except ValueError:
        return True
    return False


def create_annotation_channel(
    channel_id: str, path_to_annotation: Path, bucket_name: str, data_prefix: str,
    img_annotation_to_ignore: Optional[Sequence[str]] = None
) -> Sequence[Mapping]:
    r"""Change format from original to augmented manifest files

    Each CSV row is read as
    ``image_name,x1,y1,x2,y2,class,image_width,image_height``. A header row is
    optional and detected automatically. Boxes are converted from corner coordinates
    to ``left``/``top``/``width``/``height`` and grouped per image. The ``class``
    column is not used: every box gets ``class_id`` 0 (``SKU``).

    Parameters
    ----------
    channel_id : str
        name of the channel, i.e. training, validation or test
    path_to_annotation : Path
        path to annotation file
    bucket_name : str
        bucket where the data are uploaded
    data_prefix : str
        bucket prefix
    img_annotation_to_ignore : Optional[Sequence[str]]
        annotations of these images are ignored because the corresponding images are
        corrupted, default to None

    Returns
    -------
    Sequence[Mapping]
        List of json lines; each line contains the annotations for a single image, in
        the order in which images first appear in the CSV file. This recreates the
        format of augmented manifest files that are generated by Amazon SageMaker
        GroundTruth labeling jobs

    Raises
    ------
    ValueError
        if ``channel_id`` is not training, validation or test
    FileNotFoundError
        if ``path_to_annotation`` does not exist
    """
    if channel_id not in ("training", "validation", "test"):
        raise ValueError(
            f"Channel identifier must be training, validation or test. The passed values is {channel_id}"
        )
    if not path_to_annotation.exists():
        raise FileNotFoundError(f"Annotation file {path_to_annotation} not found")

    df_annotation = pd.read_csv(
        path_to_annotation,
        # `header=0` combined with `names` makes pandas treat the first line as a
        # header and replace it, which would silently drop the first box of a
        # header-less file. Only skip the first line when it really is a header.
        header=0 if _csv_has_header(path_to_annotation) else None,
        names=(
            "image_name",
            "x1",
            "y1",
            "x2",
            "y2",
            "class",
            "image_width",
            "image_height",
        ),
    )

    df_annotation["left"] = df_annotation["x1"]
    df_annotation["top"] = df_annotation["y1"]
    df_annotation["width"] = df_annotation["x2"] - df_annotation["x1"]
    df_annotation["height"] = df_annotation["y2"] - df_annotation["y1"]
    # Remaining column order: image_name, class, image_width, image_height,
    # left, top, width, height (relied upon by the tuple unpacking below).
    df_annotation.drop(columns=["x1", "x2", "y1", "y2"], inplace=True)

    ignored_images = set(img_annotation_to_ignore or ())
    # One creation date, truncated to the minute, for the whole channel.
    creation_date = datetime.now().replace(second=0, microsecond=0).isoformat()

    jsonlines = []
    # A single groupby pass (sort=False keeps first-appearance order) replaces
    # re-filtering the whole frame once per image.
    for img_id, img_annotations in df_annotation.groupby("image_name", sort=False):
        if img_id in ignored_images:
            print(f"Annotations for image {img_id} are neglected as the image is corrupted")
            continue
        annotations = []
        for (
            _,
            _,
            img_width,
            img_height,
            bbox_l,
            bbox_t,
            bbox_w,
            bbox_h,
        ) in img_annotations.itertuples(index=False):
            annotations.append(
                {
                    "class_id": 0,
                    "width": bbox_w,
                    "top": bbox_t,
                    "left": bbox_l,
                    "height": bbox_h,
                }
            )
        jsonline = {
            "sku": {
                "annotations": annotations,
                "image_size": [{"width": img_width, "depth": 3, "height": img_height}],
            },
            "sku-metadata": {
                "job_name": f"labeling-job/sku-110k-{channel_id}",
                "class-map": {"0": "SKU"},
                "human-annotated": "yes",
                # One independent dict per box (not N references to the same dict).
                "objects": [{"confidence": 0.0} for _ in annotations],
                "type": "groundtruth/object-detection",
                "creation-date": creation_date,
            },
            "source-ref": f"s3://{bucket_name}/{data_prefix}/{channel_id}/{img_id}",
        }
        jsonlines.append(jsonline)
    return jsonlines

In [None]:
# Convert the CSV annotations of every channel into augmented-manifest json lines.
annotations_dir = Path(local_folder) / sku_dataset[0] / "annotations"
channel_to_annotation_path = {
    "training": annotations_dir / "annotations_train.csv",
    "validation": annotations_dir / "annotations_val.csv",
    "test": annotations_dir / "annotations_test.csv",
}
channel_to_annotation = {}

for channel, annotation_path in channel_to_annotation_path.items():
    annotations = create_annotation_channel(
        channel, annotation_path, bucket, prefix_data, CORRUPTED_IMAGES[channel]
    )
    print(f"Number of {channel} annotations: {len(annotations)}")
    channel_to_annotation[channel] = annotations


In [None]:
def upload_annotations(p_annotations: Sequence[Mapping], p_channel: str):
    r"""Upload the annotations of one channel to S3 as an augmented manifest file.

    Every element of ``p_annotations`` is serialised as one JSON line. The lines are
    joined with ``"\n"`` and written to
    ``s3://{bucket}/{prefix_data}/annotations/{p_channel}.manifest``, where ``bucket``
    and ``prefix_data`` are the notebook-level settings.

    Parameters
    ----------
    p_annotations : Sequence[Mapping]
        JSON-serialisable annotation records, one per image
    p_channel : str
        channel name, used as the manifest file name
    """
    import io

    rsc_bucket = boto3.resource("s3").Bucket(bucket)

    to_write = "\n".join(json.dumps(elem) for elem in p_annotations)

    # Upload straight from memory. Writing to a NamedTemporaryFile and uploading it by
    # name while it is still open sends whatever has been flushed to disk so far, which
    # can be an empty or truncated manifest.
    rsc_bucket.upload_fileobj(
        io.BytesIO(to_write.encode("utf-8")),
        f"{prefix_data}/annotations/{p_channel}.manifest",
    )

In [None]:
# Write one manifest per channel (annotations/<channel>.manifest).
for channel_id in channel_to_annotation:
    upload_annotations(channel_to_annotation[channel_id], channel_id)

## Training

Build the Docker image defined in *Dockerfile.sku110ktraining* and push it to Amazon ECR. The SageMaker Python SDK can then be used to launch Amazon SageMaker training jobs.

In [None]:
import json

import boto3
from sagemaker.estimator import Estimator

# The cells below use SageMaker Python SDK v2 arguments (e.g. `image_uri=`).
sdk_major_version = sagemaker.__version__.split(".")[0]
assert sdk_major_version == "2", "Install Sagemaker Python SDK vs 2"

In [None]:
# S3 URIs of the data channels, all under s3://{bucket}/{prefix_data}.
s3_data_root = f"s3://{bucket}/{prefix_data}"
training_channel = f"{s3_data_root}/training/"
validation_channel = f"{s3_data_root}/validation/"
test_channel = f"{s3_data_root}/test/"

annotation_channel = f"{s3_data_root}/annotations/"

classes = ["SKU"]

In [None]:
# URI of the training image pushed to this account's ECR registry in this region.
caller_identity = boto3.client("sts").get_caller_identity()
account_id = caller_identity.get("Account")
region = boto3.session.Session().region_name
container_name = "sagemaker-d2-train-sku110k"
container_version = "latest"
ecr_registry = f"{account_id}.dkr.ecr.{region}.amazonaws.com"
training_image_uri = f"{ecr_registry}/{container_name}:{container_version}"

The following cells define the metrics extracted from the training logs, the training instance and the hyper-parameters of the training job. Feel free to change them and experiment.

In [None]:
# Regexes SageMaker applies to the training logs to extract metrics.
# The training-loss patterns that have a "val_"-prefixed validation twin are anchored
# with \b. Otherwise e.g. "loss_cls: " would also match inside "val_loss_cls: " and
# report validation values as training values ("_" is a word character, so there is
# no word boundary between "val_" and "loss").
metrics = [
    {"Name": "training:loss", "Regex": "total_loss: ([0-9\\.]+)"},
    {"Name": "training:loss_cls", "Regex": "\\bloss_cls: ([0-9\\.]+)"},
    {"Name": "training:loss_box_reg", "Regex": "\\bloss_box_reg: ([0-9\\.]+)"},
    {"Name": "training:loss_rpn_cls", "Regex": "\\bloss_rpn_cls: ([0-9\\.]+)"},
    {"Name": "training:loss_rpn_loc", "Regex": "\\bloss_rpn_loc: ([0-9\\.]+)"},
    {"Name": "validation:loss", "Regex": "total_val_loss: ([0-9\\.]+)"},
    {"Name": "validation:loss_cls", "Regex": "val_loss_cls: ([0-9\\.]+)"},
    {"Name": "validation:loss_box_reg", "Regex": "val_loss_box_reg: ([0-9\\.]+)"},
    {"Name": "validation:loss_rpn_cls", "Regex": "val_loss_rpn_cls: ([0-9\\.]+)"},
    {"Name": "validation:loss_rpn_loc", "Regex": "val_loss_rpn_loc: ([0-9\\.]+)"},
]

In [None]:
# Instance types starting with "local" run the job in SageMaker local mode.
training_instance = "ml.p3.2xlarge"
use_local_mode = training_instance.startswith("local")
if not use_local_mode:
    training_session = sm_session
else:
    training_session = sagemaker.LocalSession()
    training_session.config = {"local": {"local_code": True}}

In [None]:
od_algorithm = "faster_rcnn"  # one of ("faster_rcnn", "retinanet")

# Hyper-parameters of the training job, assembled from thematic groups.
dataset_hp = {
    "classes": json.dumps(classes),
    "dataset-name": json.dumps("sku110k"),
    "label-name": json.dumps("sku"),
}
algo_hp = {
    "model-type": json.dumps(od_algorithm),
    "backbone": json.dumps("R_101_FPN"),
}
loader_hp = {
    "num-iter": 500,
    "log-period": 500,
    "batch-size": 4,
    "num-workers": 8,
}
optim_hp = {
    "lr": 0.004681578380412093,
    "lr-schedule": 3,
}
faster_rcnn_hp = {
    "num-rpn": 1024,
    "bbox-head-pos-fraction": 0.20143586338550198,
    "bbox-rpn-pos-fraction": 0.21452816271559746,
}
prediction_hp = {
    "nms-thr": 0.2,
    "pred-thr": 0.1,
    "det-per-img": 300,
}
training_job_hp = {
    **dataset_hp,
    **algo_hp,
    **loader_hp,
    **optim_hp,
    **faster_rcnn_hp,
    **prediction_hp,
}

In [None]:
# Estimator running the custom Detectron2 training image; outputs go under prefix_model.
job_name_prefix = "detectron2-" + od_algorithm.replace("_", "-")
d2_estimator = Estimator(
    image_uri=training_image_uri,
    instance_type=training_instance,
    instance_count=1,
    role=role,
    sagemaker_session=training_session,
    hyperparameters=training_job_hp,
    metric_definitions=metrics,
    output_path=f"s3://{bucket}/{prefix_model}",
    base_job_name=job_name_prefix,
)

In [None]:
# Launch training on the training/validation images plus the manifest files.
# Block only in local mode, using the same "starts with local" rule that selected the
# LocalSession (so e.g. "local_gpu" also waits). Remote jobs run asynchronously.
d2_estimator.fit(
    {
        "training": training_channel,
        "validation": validation_channel,
        "annotation": annotation_channel,
    },
    wait=training_instance.startswith("local"),
)

## HPO

In [None]:
from sagemaker.tuner import IntegerParameter, CategoricalParameter, ContinuousParameter, HyperparameterTuner

# Detector to tune: "faster_rcnn" or "retinanet".
od_algorithm = "faster_rcnn"

In [None]:
# Search space of the tuning job: ranges shared by every detector first, then the
# detector-specific ones. The focal-loss parameters apply to RetinaNet only, so they
# are no longer added to the Faster R-CNN search space, where they were unused.
hparams_range = {
    "lr": ContinuousParameter(0.0001, 0.01),
}
if od_algorithm == "faster_rcnn":
    hparams_range.update(
        {
            "bbox-rpn-pos-fraction": ContinuousParameter(0.2, 0.8),
            "bbox-head-pos-fraction": ContinuousParameter(0.2, 0.8),
            "num-rpn": IntegerParameter(500, 2000),
        }
    )
elif od_algorithm == "retinanet":
    hparams_range.update(
        {
            "focal-loss-gamma": ContinuousParameter(2.0, 4.0),
            "focal-loss-alpha": ContinuousParameter(0.1, 1.0),
        }
    )
else:
    # An explicit exception instead of `assert False`, which is stripped under `python -O`.
    raise ValueError(f"{od_algorithm} not supported")

In [None]:
# The tuning job minimises the total validation loss.
obj_metric_name = "validation:loss"
obj_type = "Minimize"
# Training-loss patterns with a "val_"-prefixed twin are anchored with \b so they do
# not also match the validation log entries ("_" is a word character, so there is no
# word boundary between "val_" and "loss").
metric_definitions = [
    {"Name": "training:loss", "Regex": "total_loss: ([0-9\\.]+)"},
    {"Name": "training:loss_cls", "Regex": "\\bloss_cls: ([0-9\\.]+)"},
    {"Name": "training:loss_box_reg", "Regex": "\\bloss_box_reg: ([0-9\\.]+)"},
    {"Name": obj_metric_name, "Regex": "total_val_loss: ([0-9\\.]+)"},
    {"Name": "validation:loss_cls", "Regex": "val_loss_cls: ([0-9\\.]+)"},
    {"Name": "validation:loss_box_reg", "Regex": "val_loss_box_reg: ([0-9\\.]+)"},
]

In [None]:
# Hyper-parameters kept constant across all tuning trials.
fixed_hparams = {
    "classes": json.dumps(classes),
    "dataset-name": json.dumps("sku110k"),
    "label-name": json.dumps("sku"),
    "model-type": json.dumps(od_algorithm),
    "backbone": json.dumps("R_101_FPN"),
    "num-iter": 9000,
    "log-period": 500,
    "batch-size": 16,
    "num-workers": 8,
    "lr-schedule": 3,
    "nms-thr": 0.2,
    "pred-thr": 0.1,
    "det-per-img": 300,
}

# Each trial runs on spot capacity for at most 2 h, waiting at most 3 h overall.
max_run_seconds = 2 * 60 * 60
max_wait_seconds = 3 * 60 * 60
hpo_estimator = Estimator(
    image_uri=training_image_uri,
    instance_type="ml.p3.8xlarge",
    instance_count=1,
    role=role,
    sagemaker_session=sm_session,
    hyperparameters=fixed_hparams,
    output_path=f"s3://{bucket}/{prefix_model}",
    use_spot_instances=True,
    max_run=max_run_seconds,
    max_wait=max_wait_seconds,
)

In [None]:
# Tuning job: 16 trials in total, 2 at a time, optimising `obj_metric_name`.
tuner = HyperparameterTuner(
    estimator=hpo_estimator,
    objective_metric_name=obj_metric_name,
    hyperparameter_ranges=hparams_range,
    metric_definitions=metric_definitions,
    objective_type=obj_type,
    max_jobs=16,
    max_parallel_jobs=2,
    base_tuning_job_name="hpo-detectron2-" + od_algorithm.replace("_", "-"),
)

In [None]:
# Start the tuning job without blocking the notebook.
hpo_inputs = dict(
    training=training_channel,
    validation=validation_channel,
    annotation=annotation_channel,
)
tuner.fit(inputs=hpo_inputs, wait=False)

In [None]:
existing_tuning_job = "the-tuning-job-name"   # Change this
# `attach` is a classmethod that *returns* a tuner bound to the existing tuning job.
# Calling it on an instance and ignoring the result leaves `tuner` unchanged.
tuner = HyperparameterTuner.attach(
    tuning_job_name=existing_tuning_job,
    sagemaker_session=sm_session,
)

In [None]:
# One row per training job of the tuning job, with its final objective value.
bayes_metrics = sagemaker.HyperparameterTuningJobAnalytics(
    existing_tuning_job, sagemaker_session=sm_session
).dataframe()
# Best job first: ascending for a "Minimize" objective, descending otherwise.
bayes_metrics.sort_values(["FinalObjectiveValue"], ascending=(obj_type == "Minimize"))

## Deploy

In [None]:
# TODO

## Evaluation

In [None]:
# TODO