# Description
This notebook loads and displays training images with bounding box labels using PyTorch utility functions.

Optionally, one or more albumentations transformations can be enabled, and the same images are then displayed again with those transformations applied.

In [None]:
import os
import pprint

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
from torchvision.utils import draw_bounding_boxes

from wheat.config import load_config
from wheat.data_module import WheatDataModule

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# change to the top level directory for this repository, i.e. the directory
# from which the relative path 'wheat/config/config.ini' (loaded below) resolves.
# Guarded so that re-running this cell does not keep climbing the directory tree.
if not os.path.isfile(os.path.join('wheat', 'config', 'config.ini')):
    os.chdir('..')

In [None]:
# The dataset yields image tensors as floats (presumably scaled to [0, 1]:
# albumentations works on uint8 images, but the model expects floats).
# torchvision's draw_bounding_boxes, however, expects uint8 images (required
# in older torchvision releases), so this transform converts the float images
# back to uint8 for display only. ConvertImageDtype rescales [0, 1] floats to
# [0, 255], and it leaves images that are already uint8 unchanged.
image_float_to_int_transform = T.ConvertImageDtype(torch.uint8)

In [None]:
# from https://pytorch.org/vision/stable/auto_examples/plot_visualization_utils.html#visualizing-bounding-boxes
def show(imgs):
    """Display one or more image tensors side by side in a single matplotlib figure.

    Args:
        imgs: a single image tensor, or a list/tuple of image tensors. Each one is
            converted with ``F.to_pil_image`` (so presumably shaped (C, H, W)) and
            drawn in its own column with axis ticks and tick labels hidden.
            An empty list/tuple draws nothing.
    """
    images = list(imgs) if isinstance(imgs, (list, tuple)) else [imgs]
    if not images:
        # plt.subplots(ncols=0) raises, and there is nothing to draw anyway
        return
    # squeeze=False keeps `axs` 2-D (1 x ncols) even when there is only one image
    _fig, axs = plt.subplots(ncols=len(images), squeeze=False)
    for col, img in enumerate(images):
        pil_image = F.to_pil_image(img.detach())
        axs[0, col].imshow(np.asarray(pil_image))
        axs[0, col].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

In [None]:
def display_image_grid(dataset, images_per_row=3, num_rows=3, first_image_index=0,
                       box_color='blue', box_width=5):
    """Display a grid of dataset images with their bounding boxes drawn on top.

    Each row of the grid is rendered as its own matplotlib figure via ``show``.

    Args:
        dataset: indexable dataset whose items are ``(image, target)`` pairs, where
            ``target['boxes']`` holds the bounding boxes of that image.
        images_per_row: number of images in each row.
        num_rows: number of rows to display.
        first_image_index: dataset index of the top-left image; images are taken
            consecutively from there.
        box_color: color passed to ``draw_bounding_boxes`` for the boxes.
        box_width: line width passed to ``draw_bounding_boxes`` for the boxes.
    """
    for row in range(num_rows):
        row_start = first_image_index + row * images_per_row
        annotated_images = []
        for index in range(row_start, row_start + images_per_row):
            # index the dataset that was passed in (previously this used the
            # global `train_dataset`, silently ignoring the `dataset` argument)
            image, target = dataset[index]
            annotated_images.append(draw_bounding_boxes(
                image_float_to_int_transform(image), target['boxes'],
                colors=box_color, width=box_width))
        show(annotated_images)

In [None]:
# load the default configuration
config = load_config('wheat/config/config.ini')
pp = pprint.PrettyPrinter(indent=2)

In [None]:
# Make sure data augmentation is off for the first look at the training images.
# Each entry under config['train']['transforms'] is a transform probability,
# so zeroing every entry (in place) disables augmentation.
train_transform_probs = config['train']['transforms']
for transform_name in train_transform_probs:
    train_transform_probs[transform_name] = 0

In [None]:
# print the transform settings to confirm every probability is now 0
pp.pprint(config['train']['transforms'])

In [None]:
# initialize the data module and grab its training dataset.
# setup(stage='fit') is called before `train_dataset` is read; presumably
# setup() is what builds it -- TODO confirm in WheatDataModule.
wheat_data_module = WheatDataModule(config)
wheat_data_module.setup(stage='fit')
train_dataset = wheat_data_module.train_dataset

In [None]:
# display a few images; use a large default figure size so the grid is readable.
# This rc setting persists for every figure drawn later in the notebook.
plt.rc('figure', figsize=[15, 10])
display_image_grid(train_dataset)

In [None]:
# turn on one or more transforms by giving each one a probability;
# add more entries here to enable additional transforms
enabled_transform_probs = {'color_jitter_prob': 1}
for transform_name, prob in enabled_transform_probs.items():
    config['train']['transforms'][transform_name] = prob

In [None]:
# update the transforms associated with the dataset so the new probabilities take effect.
# NOTE(review): `config` was mutated in place above, so if the data module kept a
# reference to it this reassignment is a no-op; it matters only if the module copied it.
wheat_data_module.config = config
# NOTE(review): assumes get_transforms() rebuilds the transforms from the module's
# config and returns a sequence whose first element is the training transform --
# confirm against WheatDataModule.get_transforms.
train_dataset.transform = wheat_data_module.get_transforms()[0]

In [None]:
# display the same grid of images as before (same dataset indices), now with
# the enabled transforms applied, for comparison with the plain images above
display_image_grid(train_dataset, images_per_row=3, num_rows=3, first_image_index=0)