# Explore the dataset


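In this notebook, we will perform an EDA (Exploratory Data Analysis) on the processed Waymo dataset (data in the `processed` folder). In the first part, you will create a function to display an image together with its bounding boxes; you will then use it to display random images from the dataset, and finally perform additional analysis of the data.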

In [None]:
import os

# TensorFlow's native (C++) logging honours TF_CPP_MIN_LOG_LEVEL only if it is set
# before `import tensorflow`; setting it afterwards does not silence the messages
# emitted while the library loads.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import random

import matplotlib.pyplot as pyplot
import matplotlib.patches as patches
import seaborn
import tensorflow as tf

from utils import get_dataset

# Also silence Python-side TensorFlow logging below ERROR.
tf.get_logger().setLevel("ERROR")

In [None]:
# One processed Waymo segment, used by the display cells below.
dataset = get_dataset(
    "/data/processed/segment-10023947602400723454_1120_000_1140_000_with_camera_labels.tfrecord"
)

## Write a function to display an image and the bounding boxes

Implement the `display_instances` function below. This function takes a batch as input and displays each image with its corresponding bounding boxes. The only requirement is that the classes should be color coded (e.g., vehicles in red, pedestrians in blue, cyclists in green).

In [None]:
def display_instances(batch):
    """
    Display every image in ``batch`` with its ground-truth bounding boxes.

    Each element yielded by ``batch.as_numpy_iterator()`` must provide an
    ``'image'`` array of shape (height, width, channels), a
    ``'groundtruth_boxes'`` array of normalized ``[ymin, xmin, ymax, xmax]``
    boxes and a matching ``'groundtruth_classes'`` array. One figure, sized to
    the image in pixels, is drawn per element.

    Boxes are color coded by class id: 1 (vehicles) red, 2 (pedestrians) blue,
    4 (cyclists) green. Any other id is drawn in yellow rather than raising.
    """
    class_colors = {1: 'r', 2: 'b', 4: 'g'}
    px = 1 / pyplot.rcParams['figure.dpi']

    for elem in batch.as_numpy_iterator():
        image = elem['image']
        # Arrays are row-major: shape[0] is the height (rows), shape[1] the width (columns).
        img_height, img_width = image.shape[:2]

        fig, ax = pyplot.subplots(figsize=(img_width * px, img_height * px))
        # extent is (left, right, bottom, top): origin at the top-left so that
        # data coordinates coincide with pixel coordinates.
        ax.imshow(image, interpolation='nearest', extent=(0, img_width, img_height, 0))

        for box, boxclass in zip(elem['groundtruth_boxes'], elem['groundtruth_classes']):
            ymin, xmin, ymax, xmax = box
            # Scale normalized coordinates to pixels.
            anchor = (xmin * img_width, ymin * img_height)
            width = (xmax - xmin) * img_width
            height = (ymax - ymin) * img_height
            color = class_colors.get(boxclass, 'y')
            patch = patches.Rectangle(anchor, width, height, linewidth=1, edgecolor=color, facecolor='none')
            ax.add_patch(patch)

        ax.axis('off')
        pyplot.show()

        

## Display 10 images 

Using the dataset created in the second cell and the function you just coded, display 10 random images with the associated bounding boxes. You can use the methods `take` and `shuffle` on the dataset.

In [None]:
%matplotlib inline
display_instances(dataset.shuffle(10).take(10))

## Additional EDA

In this last part, you are free to perform any additional analysis of the dataset. What else would you like to know about the data?
For example, think about data distribution. So far, you have only looked at a single file...

In [None]:
def display_random_images(paths, count=10, columns=5):
    """
    Display one random image from each of ``count`` randomly chosen files.

    Args:
        paths: list of dataset file paths accepted by ``get_dataset``.
        count: number of files to sample (default 10); clamped to
            ``len(paths)`` so short lists no longer raise ``ValueError``.
        columns: number of images per grid row (default 5).

    The images are laid out row by row in a grid of ``columns`` columns, each
    cell sized to the first image in pixels. Nothing is drawn if no image
    could be read.
    """
    images = []
    for path in random.sample(paths, k=min(count, len(paths))):
        for elem in get_dataset(path).shuffle(10).take(1).as_numpy_iterator():
            images.append(elem['image'])

    if not images:
        print('No images found in the sampled files.')
        return

    rows = (len(images) + columns - 1) // columns
    # Arrays are row-major: shape[0] is the height, shape[1] the width.
    img_height, img_width = images[0].shape[:2]
    px = 1 / pyplot.rcParams['figure.dpi']
    fig, axes = pyplot.subplots(
        rows, columns,
        figsize=(img_width * px * columns, img_height * px * rows),
        squeeze=False,  # always a 2-D array of axes, even for a single row
    )

    # Hide every cell first so unused trailing cells stay blank.
    for ax in axes.flat:
        ax.axis('off')
    for ax, image in zip(axes.flat, images):
        # Use each image's own size so differently sized images are not stretched.
        height, width = image.shape[:2]
        ax.imshow(image, interpolation='nearest', extent=(0, width, height, 0))

    pyplot.show()


In [None]:
# Full paths of every file in the processed dataset directory.
paths = [f'/data/processed/{filename}' for filename in os.listdir('/data/processed')]

display_random_images(paths)

In [None]:
def sample_from(batch, samples):
    """
    Accumulate EDA statistics for every element of ``batch`` into ``samples``.

    ``samples`` is updated in place (missing keys are created) and also
    returned, so it can be threaded through successive calls:

    - ``'classes'``: dict mapping class id -> number of ground-truth boxes.
    - ``'boxes'``: list with the number of ground-truth boxes of each image.
    - ``'sizes'``: list with the area, in pixels, of every ground-truth box.
    - ``'brightness'``: list with the mean pixel value of each image
      (averaged over all pixels and channels).

    Each element yielded by ``batch.as_numpy_iterator()`` must provide an
    ``'image'`` array of shape (height, width, channels) and matching
    ``'groundtruth_boxes'`` (normalized ``[ymin, xmin, ymax, xmax]``) and
    ``'groundtruth_classes'`` arrays.
    """
    classes = samples.setdefault('classes', {})
    box_counts = samples.setdefault('boxes', [])
    sizes = samples.setdefault('sizes', [])
    brightness = samples.setdefault('brightness', [])

    for elem in batch.as_numpy_iterator():
        image = elem['image']
        boxes = elem['groundtruth_boxes']
        # Arrays are row-major: shape[0] is the height, shape[1] the width.
        img_height, img_width = image.shape[:2]

        box_counts.append(len(boxes))

        # Vectorized mean with a float64 accumulator. A per-pixel Python loop is
        # O(H*W*C) interpreter steps per image, and summing uint8 scalars into
        # a Python int can wrap around under NumPy 2 promotion rules.
        brightness.append(float(image.mean(dtype='float64')))

        for ymin, xmin, ymax, xmax in boxes:
            width = (xmax - xmin) * img_width
            height = (ymax - ymin) * img_height
            sizes.append(float(width * height))

        for class_id in elem['groundtruth_classes']:
            class_id = int(class_id)  # plain int keys instead of numpy scalars
            classes[class_id] = classes.get(class_id, 0) + 1

    return samples


In [None]:
def display_class_distribution(classes):
    """
    Display a pie chart of the share of ground-truth boxes per class.

    Args:
        classes: dict mapping class id -> box count, e.g. the
            ``samples['classes']`` dict accumulated by ``sample_from``.

    Ids 1, 2 and 4 are labelled Cars, Pedestrians and Cyclists; every other
    id is grouped under 'Unknown'. Nothing is drawn when there are no boxes.
    """
    names = {1: 'Cars', 2: 'Pedestrians', 4: 'Cyclists'}

    # Use the argument, not the global `samples`, so the function has no hidden state.
    total = sum(classes.values())
    if total == 0:
        print('No ground-truth classes to plot.')
        return

    dist = {}
    for klass, count in classes.items():
        label = names.get(klass, 'Unknown')
        # Several ids share the 'Unknown' label: accumulate rather than overwrite.
        dist[label] = dist.get(label, 0.0) + count / total * 100.0

    # Pass plain lists: np.asarray does not treat dict views as sequences.
    pyplot.pie(list(dist.values()), labels=list(dist.keys()), autopct='%1.1f%%')
    pyplot.title('Class Distribution')
    pyplot.axis('equal')
    pyplot.show()

def display_box_count_distribution(boxes):
    """
    Display a histogram of the number of bounding boxes per image.

    Args:
        boxes: sequence of per-image box counts, e.g. ``samples['boxes']``
            from ``sample_from``.

    One bin is used per integer count (edges at half-integers), so each bar
    corresponds to an exact number of boxes. Previously the bin count equalled
    the number of images, which is unrelated to the data range and fails on
    empty input.
    """
    if len(boxes) == 0:
        print('No box counts to plot.')
        return

    edges = [k - 0.5 for k in range(int(min(boxes)), int(max(boxes)) + 2)]
    pyplot.hist(boxes, bins=edges)
    pyplot.title('Box Count Distribution')
    pyplot.xlabel('Boxes per image')
    pyplot.ylabel('Images')
    pyplot.show()
    
def display_box_size_distribution(sizes, bins=4):
    """
    Display a histogram of ground-truth bounding-box areas.

    Args:
        sizes: sequence of box areas, e.g. ``samples['sizes']`` from
            ``sample_from`` (areas in pixels).
        bins: forwarded to ``pyplot.hist``. The default of 4 gives only a
            coarse view; pass a larger number or 'auto' for more detail.
    """
    pyplot.hist(sizes, bins=bins)
    pyplot.title('Box Size Distribution')
    pyplot.xlabel('Box area (pixels)')
    pyplot.ylabel('Boxes')
    pyplot.show()
    
def display_image_brightness_distribution(brightness, bins=256):
    """
    Display a histogram of per-image average brightness.

    Args:
        brightness: sequence of per-image mean pixel values, e.g.
            ``samples['brightness']`` from ``sample_from``.
        bins: forwarded to ``pyplot.hist`` (default 256).
    """
    pyplot.hist(brightness, bins=bins)
    pyplot.title('Image Brightness Distribution')
    pyplot.xlabel('Mean pixel value')
    pyplot.ylabel('Images')
    pyplot.show()

In [None]:
# Accumulate statistics from up to 10 shuffled elements of every processed file.
# Distinct names are used so this cell does not overwrite the global `dataset`
# read by the "Display 10 images" cell, nor reuse `paths` for bare filenames.
samples = {}
filenames = os.listdir('/data/processed')
for filename in filenames:
    segment_dataset = get_dataset(f'/data/processed/{filename}')
    samples = sample_from(segment_dataset.shuffle(10).take(10), samples)

display_class_distribution(samples['classes'])
display_box_count_distribution(samples['boxes'])
display_box_size_distribution(samples['sizes'])
display_image_brightness_distribution(samples['brightness'])