In [1]:
%load_ext autoreload
%autoreload 2
import tensorflow as tf
from video_prediction import datasets

## Example: bair

In [None]:
# Settings for the BAIR example: batching / hparams first, then which dataset
# to load and where it lives on disk.
batch_size = 16
dataset_hparams = 'use_state=True'
dataset = 'bair'
input_dir = '/data/vision/phillipi/gen-models/video_prediction/data/bair'

In [None]:
# Resolve the dataset class for `dataset` and open its training split.
VideoDataset = datasets.get_dataset_class(dataset)
train_dataset = VideoDataset(input_dir, mode='train', hparams=dataset_hparams)

# Batched dataset -> one-shot iterator -> string handle.
train_tf_dataset = train_dataset.make_dataset(batch_size)
train_iterator = train_tf_dataset.make_one_shot_iterator()
train_handle = train_iterator.string_handle()

# Rebuild an iterator from the handle; `inputs` is the symbolic next batch.
iterator = tf.data.Iterator.from_string_handle(
    string_handle=train_handle,
    output_types=train_tf_dataset.output_types,
    output_shapes=train_tf_dataset.output_shapes)
inputs = iterator.get_next()

In [None]:
print(inputs.keys())
# Evaluate one batch in a fresh session and report the shape of each array.
sess = tf.Session()
x = sess.run(inputs)
print(*(x[key].shape for key in ('images', 'states', 'actions')))

---

## Helper functions for reading/writing tfrecords

In [2]:
def _bytes_feature(value):
    """Wrap a single bytes value in a `tf.train.Feature` holding a one-element BytesList."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)


def _bytes_list_feature(values):
    """Wrap a sequence of bytes values in a `tf.train.Feature` holding a BytesList."""
    bytes_list = tf.train.BytesList(value=values)
    return tf.train.Feature(bytes_list=bytes_list)


def _int64_feature(value):
    """Wrap a single integer in a `tf.train.Feature` holding a one-element Int64List."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)


def _floats_list_feature(value):
    """Wrap a sequence of floats in a `tf.train.Feature` holding a FloatList."""
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)


def save_tf_record(output_fname, sequences):
    """Serialize push sequences into a single TFRecord file.

    Each sequence becomes one `tf.train.Example` with the features
    'images/encoded', 'actions', 'states', 'sequence_length' and 'push_name'.

    Args:
        output_fname: Path of the TFRecord file to write.
        sequences: Iterable of `(images, metadata, push_name)` tuples where
            `images` is a sequence of arrays (one per frame), `metadata` is a
            2-D array-like with one row per frame and at least three columns,
            and `push_name` is a bytes identifier for the sequence.

    Raises:
        AssertionError: If a sequence has a different number of images and
            metadata rows.
    """
    print('saving sequences to %s' % output_fname)
    with tf.python_io.TFRecordWriter(output_fname) as writer:
        for images, metadata, push_name in sequences:
            metadata = np.asarray(metadata)
            assert len(images) == len(metadata), (
                'got %d images but %d metadata rows' % (len(images), len(metadata)))
            num_frames = len(images)
            # `tobytes()` replaces the deprecated/removed `tostring()`; the
            # produced bytes are identical.
            encoded_images = [image.tobytes() for image in images]
            # States: the first two metadata columns padded with a zero third
            # column, flattened frame by frame.
            zero_column = np.zeros((num_frames, 1))
            states = np.concatenate((metadata[:, :2], zero_column), axis=1).reshape(-1)
            # Actions: third metadata column for every frame but the last,
            # each value repeated four times.
            actions = np.repeat(metadata[:-1, 2], 4)
            features = tf.train.Features(feature={
                'images/encoded': _bytes_list_feature(encoded_images),
                'actions': _floats_list_feature(actions),
                'states': _floats_list_feature(states),
                'sequence_length': _int64_feature(num_frames),
                'push_name': _bytes_feature(push_name)
            })
            example = tf.train.Example(features=features)
            writer.write(example.SerializeToString())

---

## Example: omnipush_stitch

In [3]:
import glob
import os
import numpy as np
from skimage.io import imread
# Name of the omnipush variant; it keys both the raw-data and the output directory.
dataset = 'omnipush_1_weight_stitch'
dataset_dir = '/data/vision/phillipi/gen-models/svg/dataset/{}/'.format(dataset)
output_dir = '/data/vision/phillipi/gen-models/video_prediction/data/{}/'.format(dataset)

# Make sure one output directory exists per split. `exist_ok=True` avoids the
# check-then-create race and makes the cell safe to re-run.
splits = ['train', 'val']
for split in splits:
    split_path = os.path.join(output_dir, split)
    os.makedirs(split_path, exist_ok=True)

## Write tfrecord

In [4]:
n_trajs = 0
splits = ['train', 'test']
for split in splits:
    # Source 'test' pushes are stored under the 'val' output split.
    output_split = 'val' if split == 'test' else 'train'
    shape_names = os.listdir(os.path.join(dataset_dir, '{}/'.format(split)))
    for shape_name in shape_names:
        # One tfrecord per shape, holding every push directory found for it.
        pattern = os.path.join(dataset_dir, '{}/{}/**/'.format(split, shape_name))
        sequences = []
        for dname in glob.glob(pattern):
            # Frames are loaded by index as 0.png, 1.png, ... up to the PNG count.
            n_images = len(glob.glob(os.path.join(dname, '*.png')))
            images = [imread(os.path.join(dname, '{}.png'.format(i))) for i in range(n_images)]
            metadata = np.load(os.path.join(dname, 'actions.npy'))
            # dname ends in '/', so the push directory name is the second-to-last component.
            push_name = dname.split('/')[-2].encode()
            sequences.append((images, metadata, push_name))
            n_trajs += 1
        record_path = '/data/vision/phillipi/gen-models/video_prediction/data/{}/{}/{}.tfrecord'.format(dataset, output_split, shape_name)
        save_tf_record(record_path, sequences)

print(n_trajs)

saving sequences to /data/vision/phillipi/gen-models/video_prediction/data/omnipush_1_weight_stitch/train/3a3a3a3c.tfrecord
saving sequences to /data/vision/phillipi/gen-models/video_prediction/data/omnipush_1_weight_stitch/train/3a4a3c4a.tfrecord
saving sequences to /data/vision/phillipi/gen-models/video_prediction/data/omnipush_1_weight_stitch/train/1a3c4a4a.tfrecord
saving sequences to /data/vision/phillipi/gen-models/video_prediction/data/omnipush_1_weight_stitch/train/2C4a4a4a.tfrecord
saving sequences to /data/vision/phillipi/gen-models/video_prediction/data/omnipush_1_weight_stitch/train/1a2a4a3c.tfrecord
saving sequences to /data/vision/phillipi/gen-models/video_prediction/data/omnipush_1_weight_stitch/train/2C4a3a2a.tfrecord
saving sequences to /data/vision/phillipi/gen-models/video_prediction/data/omnipush_1_weight_stitch/train/1a3a4a2c.tfrecord
saving sequences to /data/vision/phillipi/gen-models/video_prediction/data/omnipush_1_weight_stitch/train/1a4a1b4a.tfrecord
saving s

saving sequences to /data/vision/phillipi/gen-models/video_prediction/data/omnipush_1_weight_stitch/val/1a3a3a2B.tfrecord
saving sequences to /data/vision/phillipi/gen-models/video_prediction/data/omnipush_1_weight_stitch/val/3a4a3b4a.tfrecord
saving sequences to /data/vision/phillipi/gen-models/video_prediction/data/omnipush_1_weight_stitch/val/1c4a2a3a.tfrecord
saving sequences to /data/vision/phillipi/gen-models/video_prediction/data/omnipush_1_weight_stitch/val/3a4a3c4a.tfrecord
saving sequences to /data/vision/phillipi/gen-models/video_prediction/data/omnipush_1_weight_stitch/val/3a3a3a3c.tfrecord
saving sequences to /data/vision/phillipi/gen-models/video_prediction/data/omnipush_1_weight_stitch/val/1a3c4a4a.tfrecord
saving sequences to /data/vision/phillipi/gen-models/video_prediction/data/omnipush_1_weight_stitch/val/2C4a4a4a.tfrecord
saving sequences to /data/vision/phillipi/gen-models/video_prediction/data/omnipush_1_weight_stitch/val/1a2a4a3c.tfrecord
saving sequences to /dat

## Load tfrecord

In [5]:
# Loader settings: dataset class name, hparams and batch size.
dataset_class = 'omnipush'
dataset_hparams = 'use_state=True'
batch_size = 16
# The tfrecord directory is keyed by the `dataset` name set earlier.
input_dir = '/data/vision/phillipi/gen-models/video_prediction/data/{}'.format(dataset)
print(input_dir)

/data/vision/phillipi/gen-models/video_prediction/data/omnipush_1_weight_stitch


In [8]:
# Resolve the dataset class for `dataset_class` and open its training split.
VideoDataset = datasets.get_dataset_class(dataset_class)
train_dataset = VideoDataset(input_dir, mode='train', hparams=dataset_hparams)

# Batched dataset -> one-shot iterator -> string handle.
train_tf_dataset = train_dataset.make_dataset(batch_size)
train_iterator = train_tf_dataset.make_one_shot_iterator()
train_handle = train_iterator.string_handle()

# Rebuild an iterator from the handle; `inputs` is the symbolic next batch.
iterator = tf.data.Iterator.from_string_handle(
    string_handle=train_handle,
    output_types=train_tf_dataset.output_types,
    output_shapes=train_tf_dataset.output_shapes)
inputs = iterator.get_next()

In [9]:
print(inputs.keys())

# Evaluate one batch, then report the array shapes followed by the push names.
sess = tf.Session()
x = sess.run(inputs)
print(*(x[key].shape for key in ('images', 'states', 'actions')), x['push_name'])

odict_keys(['images', 'states', 'push_name', 'actions'])
(16, 24, 64, 64, 3) (16, 24, 3) (16, 23, 4) [[b'motion_surface=abs_shape=1c4a2a3a_v=50_rep=0004_push=0000_t=0.849']
 [b'motion_surface=abs_shape=1a1c4a3a_v=50_rep=0011_push=0000_t=-2.155']
 [b'motion_surface=abs_shape=1a1c4a3a_v=50_rep=0014_push=0000_t=-2.472']
 [b'motion_surface=abs_shape=2B3a2a4a_v=50_rep=0050_push=0000_t=-0.421']
 [b'motion_surface=abs_shape=1a1a1c4a_v=50_rep=0019_push=0000_t=-0.882']
 [b'motion_surface=abs_shape=1a3b4a4a_v=50_rep=0039_push=0000_t=1.004']
 [b'motion_surface=abs_shape=1a3a3a2C_v=50_rep=0044_push=0000_t=0.893']
 [b'motion_surface=abs_shape=2B3a2a4a_v=50_rep=0000_push=0000_t=-0.854']
 [b'motion_surface=abs_shape=1a3a2b3a_v=50_rep=0056_push=0000_t=-2.361']
 [b'motion_surface=abs_shape=1a1c2a3a_v=50_rep=0030_push=0000_t=-2.440']
 [b'motion_surface=abs_shape=2C2a2a4a_v=50_rep=0041_push=0000_t=1.391']
 [b'motion_surface=abs_shape=1a3b4a4a_v=50_rep=0053_push=0000_t=0.362']
 [b'motion_surface=abs_shape

## Visualize data from tfrecord

In [None]:
import imageio
from IPython.display import Image
from IPython.display import display
from IPython.display import clear_output
def inspect_seq(seq):
    """Show a frame sequence as an animated GIF in the cell output.

    The frames are written to './tmp.gif' (replacing any existing file), read
    back as bytes, and displayed at 200x200 after clearing the previous output.
    """
    gif_path = './tmp.gif'
    clear_output(wait=True)
    # Remove the GIF left by an earlier call before writing the new one.
    if os.path.exists(gif_path):
        os.remove(gif_path)
    imageio.mimsave(gif_path, seq)
    with open(gif_path, 'rb') as gif_file:
        gif_bytes = gif_file.read()
        display(Image(data=gif_bytes, format='gif', width=200, height=200))

# Pick one sequence of the batch, animate it, and print its first action
# channel scaled by 180/pi (presumably radians -> degrees; confirm units).
idx = 2
inspect_seq(x['images'][idx])
first_action_channel = x['actions'][idx, :, 0]
print(first_action_channel / np.pi * 180)

## Write tfrecord with different actions

In [None]:
# Re-export the 'test' split with synthetic actions in place of the recorded ones.
modify_action = True
n_trajs = 0
splits = ['test']
for split in splits:
    output_split = 'val_actions=S' if split == 'test' else 'train'
    path = '/data/vision/phillipi/gen-models/video_prediction/data/{}/{}'.format(dataset, output_split)
    os.makedirs(path, exist_ok=True)
    shape_names = os.listdir(os.path.join(dataset_dir, '{}/'.format(split)))
    for shape_name in shape_names:
        sequences = []
        dnames = glob.glob(os.path.join(dataset_dir, '{}/{}/**/'.format(split, shape_name)))
        for dname in dnames:
            # Frames are loaded by index as 0.png, 1.png, ... up to the PNG count.
            n_images = len(glob.glob(os.path.join(dname, '*.png')))
            images = [imread(os.path.join(dname, '{}.png'.format(i))) for i in range(n_images)]
            metadata = np.load(os.path.join(dname, 'actions.npy'))
            if modify_action:
                # Note: ALL recorded metadata is discarded (set to zero); only
                # column 2 is refilled, sweeping -pi/2 -> pi/2 over the first
                # half of the push and pi/2 -> -pi/2 over the second half.
                metadata = np.zeros(metadata.shape)
                len_push = metadata.shape[0]
                half = len_push // 2
                metadata[:half, 2] = np.linspace(-0.5 * np.pi, 0.5 * np.pi, num=half)
                # Size the second sweep to the remaining rows so an odd
                # len_push does not cause a shape mismatch.
                metadata[half:, 2] = np.linspace(0.5 * np.pi, -0.5 * np.pi, num=len_push - half)
            # save_tf_record unpacks (images, metadata, push_name) triples, so
            # the push name must be included; dname ends in '/', hence [-2].
            push_name = str.encode(dname.split('/')[-2])
            sequences.append((images, metadata, push_name))
            n_trajs += 1
        save_tf_record(os.path.join(path, '{}.tfrecord'.format(shape_name)), sequences)

print(n_trajs)