In [None]:
import torch
import torchvision
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F

import os
import time
import numpy as np
import pickle
import matplotlib.pyplot as plt
import skimage.io as io
from skimage.transform import resize

import tensorflow as tf
from tfrecord.torch.dataset import TFRecordDataset

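# Write TFRecord files

Each `di_dump<k>.p` pickle in `load_path` is converted to `save_path/data<k>.tfrecords`; every image is written together with a horizontally flipped copy.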

In [None]:
# Root folder holding the raw RL dumps and the TFRecord output folder.
saves_folders = "../../../rl_data"

# Pickle stack folder to convert, and destination folder for the .tfrecords shards.
load_path = os.path.join(saves_folders, "saves_1")
save_path = os.path.join(saves_folders, "tfrecord")

# Every dump folder under the root. Plain files and the TFRecord output folder
# (which lives in the same root) are excluded; the list is sorted so that its
# order does not depend on the file system.
load_paths = sorted(
    path
    for path in (os.path.join(saves_folders, name) for name in os.listdir(saves_folders))
    if os.path.isdir(path) and os.path.abspath(path) != os.path.abspath(save_path)
)

In [None]:
def _bytes_feature(value):
    """Wrap a single bytes value in a tf.train.Feature holding a BytesList.

    An eager tensor (e.g. the output of tf.io.serialize_tensor) is first
    converted to its NumPy payload, because BytesList cannot unpack a string
    directly from an EagerTensor.
    """
    eager_tensor_type = type(tf.constant(0))
    payload = value.numpy() if isinstance(value, eager_tensor_type) else value
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[payload]))

def _float_feature(value):
    """Wrap a single float / double in a tf.train.Feature holding a FloatList."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
    """Wrap a single bool / enum / int / uint in a tf.train.Feature holding an Int64List."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)

def serialize_example(image, image_shape):
    """Encode one image and its declared shape as a serialized tf.train.Example.

    Args:
        image: value accepted by `_bytes_feature`; in this notebook it is the
            output of `tf.io.serialize_tensor`.
        image_shape: sequence whose first three entries are stored as the
            `height`, `width` and `depth` int64 features.

    Returns:
        The Example protobuf serialized to a byte string.
    """
    height, width, depth = image_shape[0], image_shape[1], image_shape[2]
    features = tf.train.Features(feature={
        'image': _bytes_feature(image),
        'height': _int64_feature(height),
        'width': _int64_feature(width),
        'depth': _int64_feature(depth),
    })
    return tf.train.Example(features=features).SerializeToString()

In [None]:
DI_SHAPE = (270, 480, 1)  # (height, width, depth) written into every example's header


def _di_stack_indices(folder):
    """Return the sorted indices k for which a regular file `di_dump<k>.p` exists in `folder`."""
    prefix, suffix = "di_dump", ".p"
    indices = []
    for name in os.listdir(folder):
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        stem = name[len(prefix):-len(suffix)]
        if stem.isdecimal() and os.path.isfile(os.path.join(folder, name)):
            indices.append(int(stem))
    return sorted(indices)


def _load_pickle(path):
    """Unpickle and return the object stored at `path`, closing the file afterwards.

    pickle can execute arbitrary code while loading: only use it on trusted,
    locally produced dumps.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


def _write_di_stack(di_load, filename, shape=DI_SHAPE):
    """Write each image of each episode in `di_load`, plus its horizontal flip, to `filename`.

    `di_load[i][j]` is the j-th image of episode i. Index access is kept on
    purpose, in case the pickled containers are dicts keyed by position.
    `shape` is recorded as-is in every example header.
    """
    with tf.io.TFRecordWriter(filename) as writer:
        for i in range(len(di_load)):
            di_episode = di_load[i]
            for j in range(len(di_episode)):
                di_current = di_episode[j]
                writer.write(serialize_example(tf.io.serialize_tensor(di_current), shape))

                # Augmentation: horizontally flipped copy (axis 1 is the width
                # axis, assuming the arrays are laid out like DI_SHAPE).
                # TODO: flip omega_z / other per-step states once they are exported too.
                di_flip = np.flip(di_current, 1)
                writer.write(serialize_example(tf.io.serialize_tensor(di_flip), shape))


os.makedirs(save_path, exist_ok=True)
stack_indices = _di_stack_indices(load_path)
nb_files = len(stack_indices)
print("NUMBER OF PICKLE STACKS", nb_files)
for k in stack_indices:
    # Only the `di` dump is exported. The obs/action/action_index/collision dumps
    # of the same stack are not written to the TFRecord yet, so they are not loaded.
    di_load = _load_pickle(os.path.join(load_path, "di_dump" + str(k) + ".p"))
    _write_di_stack(di_load, os.path.join(save_path, "data" + str(k) + ".tfrecords"))


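# Read TFRecord files

Sanity-check the written files: peek at one raw record with `tf.data`, then load a batch through the `tfrecord` PyTorch reader.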

In [None]:
# Folder holding the exported shards (the same folder as `save_path`).
tfrecord_path = "../../../rl_data/tfrecord"
# Keep only the TFRecord shards, sorted so that tf_files_full[0] is reproducible.
tf_files = sorted(f for f in os.listdir(tfrecord_path) if f.endswith(".tfrecords"))
tf_files_full = [os.path.join(tfrecord_path, file) for file in tf_files]

In [None]:
# Stream the still-serialized tf.train.Example records from every listed shard.
raw_dataset = tf.data.TFRecordDataset(filenames=tf_files_full)
raw_dataset

In [None]:
# Print a single raw (serialized) record to check that the writer produced output.
for first_record in raw_dataset.take(1):
    print(f"{first_record!r}")

In [None]:
tfrecord_path = tf_files_full[0]
index_path = None  # no tfrecord index file is used
import cv2

BATCH_SIZE = 32

# Feature names/types to read back; must match the keys written by serialize_example.
description = {
    "image": "byte",
    "height": "int",
    "width": "int",
    "depth": "int"
}

def decode_image(features, dtype=tf.uint8):
    """Decode `features["image"]` back into a NumPy array, replacing it in place.

    The stored bytes are the output of `tf.io.serialize_tensor` (a serialized
    TensorProto), not an encoded PNG/JPEG, so they must be parsed with
    `tf.io.parse_tensor` rather than `cv2.imdecode`. The original array shape
    is restored from the proto.

    Args:
        features: dict produced by the tfrecord reader; "image" holds the raw bytes.
        dtype: dtype the arrays were serialized with. NOTE(review): uint8 is
            assumed; confirm it matches the dtype of the pickled `di` arrays.

    Returns:
        The same dict, with "image" decoded.
    """
    # bytes() copies the raw buffer whether the reader yields bytes or a uint8 array.
    serialized = bytes(features["image"])
    features["image"] = tf.io.parse_tensor(serialized, out_type=dtype).numpy()
    return features

dataset = TFRecordDataset(tfrecord_path, index_path=index_path, description=description)
loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE)

data = next(iter(loader))
# Iterate over the actual batch: the first batch is shorter than BATCH_SIZE
# when the file holds fewer records.
for image_bytes in data["image"]:
    print(image_bytes.shape)