In [None]:
# Exercise 3 - Create tf records

## Objective
The goal of this exercise is to make you familiar with the tf record format. In particular, 
your job is to convert the data from the Waymo Open Dataset into a new tf record format that we 
will use for the final project, as there is a difference between the format used for the 
Waymo Open Dataset and that used by the TensorFlow Object Detection API.

## Details

You can read more about the Waymo Open Dataset data format [here](https://waymo.com/open/data/perception/). 
Each tf record file contains the data for an entire trip made by the car, meaning that 
it contains images from the different cameras as well as LIDAR data. Because we want 
to keep our dataset small, we are implementing the `create_tf_example` function to 
create cleaned tf record files.

We are using the Waymo Open Dataset github repository to parse the raw tf record files. 
We recommend following [this tutorial](https://github.com/waymo-research/waymo-open-dataset) 
to better understand the data format before diving into this exercise. 

## Tips

This [document](https://github.com/Jossome/Waymo-open-dataset-document) provides
an overview of the dataset structure.

Later on, we will leverage the Tensorflow Object Detection API to train Object Detection models.
In the API tutorial, you can find an example of `create_tf_example` [here](https://tensorflow-object-detection-api-tutorial.readthedocs.io/en/latest/training.html#create-tensorflow-records).

Note that running the code will require the use of the example `.tfrecord` included in the `/home/workspace` directory. You will need the GPU enabled (see bottom left of workspace) for the appropriate libraries to be available in the workspace as well.

In [None]:
import io
import os
import argparse
import logging

import tensorflow.compat.v1 as tf
from PIL import Image
from waymo_open_dataset import dataset_pb2 as open_dataset

from utils import parse_frame, int64_feature, int64_list_feature, bytes_feature
from utils import bytes_list_feature, float_list_feature


def create_tf_example(filename, encoded_jpeg, annotations):
    """
    convert to tensorflow object detection API format

    Each annotation's box is given as a center point plus extents; it is
    converted to corner coordinates and normalised by the image size.
    An image without annotations yields an example with empty box and
    class lists.

    args:
    - filename [str]: name of the image
    - encoded_jpeg [bytes-like]: encoded image
    - annotations [list]: bboxes and classes; each item is read through
      `box.center_x`, `box.center_y`, `box.length`, `box.width` and `type`
    returns:
    - tf_example [tf.Example]
    raises:
    - KeyError: if an annotation's `type` is not one of the mapped class ids
    """
    # The image is decoded only to read its dimensions; the original encoded
    # bytes are what gets stored in the example.
    encoded_jpg_io = io.BytesIO(encoded_jpeg)
    image = Image.open(encoded_jpg_io)
    width, height = image.size

    # class id -> class name written to 'image/object/class/text'
    mapping = {
        1: 'vehicle',
        2: 'pedestrian',
        4: 'cyclist'
    }

    image_format = b'jpg'

    xmins = []
    xmaxs = []
    ymins = []
    ymaxs = []

    classes_text = []
    classes = []

    filename = filename.encode('utf8')

    for ann in annotations:
        # center/extent -> corners: `length` is used as the extent along x
        # and `width` as the extent along y.
        xmin = ann.box.center_x - 0.5 * ann.box.length
        xmax = ann.box.center_x + 0.5 * ann.box.length
        ymin = ann.box.center_y - 0.5 * ann.box.width
        ymax = ann.box.center_y + 0.5 * ann.box.width

        # store coordinates relative to the image size
        xmins.append(xmin / width)
        xmaxs.append(xmax / width)
        ymins.append(ymin / height)
        ymaxs.append(ymax / height)

        classes.append(ann.type)
        classes_text.append(mapping[ann.type].encode('utf8'))

    # Build the example once, after all annotations have been collected.
    # Doing this outside the loop also guarantees a result for images that
    # have no annotations at all.
    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
        'image/filename': bytes_feature(filename),
        'image/source_id': bytes_feature(filename),
        'image/encoded': bytes_feature(encoded_jpeg),
        'image/format': bytes_feature(image_format),
        'image/object/bbox/xmin': float_list_feature(xmins),
        'image/object/bbox/xmax': float_list_feature(xmaxs),
        'image/object/bbox/ymin': float_list_feature(ymins),
        'image/object/bbox/ymax': float_list_feature(ymaxs),
        'image/object/class/text': bytes_list_feature(classes_text),
        'image/object/class/label': int64_list_feature(classes)
    }))

    return tf_example


def process_tfr(path):
    """
    process a waymo tf record into a tf api tf record

    Every record in `path` is parsed as a Waymo `Frame`, converted with
    `create_tf_example`, and written to `output/<basename of path>`.

    args:
    - path [str]: path to the Waymo Open Dataset tf record file
    """
    file_name = os.path.basename(path)

    # create processed data dir; the writer cannot create a file inside a
    # directory that does not exist yet
    os.makedirs('output', exist_ok=True)

    logging.info(f'\nProcessing {path}\n')

    dataset = tf.data.TFRecordDataset(path, compression_type='')

    # The context manager closes the writer even if a frame fails to parse
    # or convert part-way through the file.
    with tf.python_io.TFRecordWriter(f'output/{file_name}') as writer:
        for idx, data in enumerate(dataset):
            frame = open_dataset.Frame()
            frame.ParseFromString(bytearray(data.numpy()))
            encoded_jpeg, annotations = parse_frame(frame)

            # per-frame name derived from the record name and the frame index
            filename = file_name.replace('.tfrecord', f'_{idx}.tfrecord')
            tf_example = create_tf_example(filename, encoded_jpeg, annotations)

            writer.write(tf_example.SerializeToString())

    logging.info(f'\nFinish {path}\n')


if __name__ == "__main__":
    # The root logger defaults to WARNING, which would silently drop the
    # logging.info progress messages emitted by process_tfr.
    logging.basicConfig(level=logging.INFO)

    # Command-line entry point: convert the single record given by --path.
    parser = argparse.ArgumentParser()
    parser.add_argument('-p', '--path', required=True, type=str,
                        help='Waymo Open dataset tf record')
    args = parser.parse_args()
    process_tfr(args.path)