In [None]:
'''
Purpose:
    1) Convert frame-level TFRecord files into .npy files for use with PyTorch.
    2) Remap segment label IDs to contiguous indices ranging from 0 to 1000 (0 stays 0; the i-th vocabulary row becomes i+1).
'''

In [None]:
import os
import numpy as np
import pandas as pd
import glob
import tensorflow as tf

In [None]:
# Which split to convert: True -> the 'validate' TFRecords, False -> the 'test' TFRecords.
is_validate = True

In [None]:
# Dataset root. Defaults to the original author's local path; set the VU_INPUT_DIR
# environment variable to use another location without editing the notebook.
# os.path.join(..., '') guarantees a trailing separator, which the plain string
# concatenations here and in later cells depend on.
input_dir = os.path.join(
    os.environ.get('VU_INPUT_DIR', '/run/media/hoosiki/WareHouse3/mtb/datasets/VU/'), '')
data_dir = input_dir + 'active_datasets/'
frame_dir = data_dir + 'frame/'

In [None]:
# Split name used both in the TFRecord file pattern and in the output directory name.
datatype = 'validate' if is_validate else 'test'

In [None]:
# glob() returns matches in filesystem-dependent order; sort so the file list (and
# any loop over it) is deterministic across runs and machines.
file_paths = sorted(glob.glob(frame_dir + '{}*.tfrecord'.format(datatype)))
if not file_paths:
    # An empty match usually means a wrong path or split; make it visible instead of
    # letting the conversion cell silently do nothing.
    print('WARNING: no {} TFRecord files found in {}'.format(datatype, frame_dir))

out_dir = data_dir + 'npy_formatted_frame/{}/'.format(datatype)
os.makedirs(out_dir, exist_ok=True)

In [None]:
# Map each raw label ID from the 'Index' column of vocabulary.csv to a contiguous
# index: the row at position p (0-based) becomes p + 1. Label 0 is seeded to 0
# first; a vocabulary row whose Index is 0 would overwrite that entry.
df_vocab = pd.read_csv(input_dir + 'vocabulary.csv')
vocab_label2idx_dict = {0: 0}
vocab_label2idx_dict.update(
    (raw_label, position) for position, raw_label in enumerate(df_vocab['Index'], start=1))

In [None]:
def parser(record):
    """Decode one serialized SequenceExample from a frame-level TFRecord.

    Args:
        record: scalar string tensor holding one serialized example.

    Returns:
        An 8-tuple of tensors, in this order:
            video_id             -- string scalar ('id' context feature)
            video_labels         -- SparseTensor of int64 ('labels')
            segment_start_times  -- SparseTensor of int64
            segment_end_times    -- SparseTensor of int64
            segment_labels       -- SparseTensor of int64
            segment_scores       -- SparseTensor of float32
            frame_rgb            -- float32, reshaped to [-1, 1024]
            frame_audio          -- float32, reshaped to [-1, 128]
        The frame features are the raw uint8 bytes cast to float32; no
        dequantization is applied here.
    """
    int_list_keys = ('labels', 'segment_start_times', 'segment_end_times', 'segment_labels')
    context_spec = {key: tf.VarLenFeature(tf.int64) for key in int_list_keys}
    context_spec['id'] = tf.FixedLenFeature([], tf.string)
    context_spec['segment_scores'] = tf.VarLenFeature(tf.float32)

    sequence_spec = {name: tf.FixedLenSequenceFeature([], tf.string)
                     for name in ('rgb', 'audio')}

    context, sequence = tf.parse_single_sequence_example(
        record, context_features=context_spec, sequence_features=sequence_spec)

    def decode_frames(name, width):
        # Each per-frame entry is a byte string of uint8 values: decode it, widen
        # to float32 and arrange the result as rows of `width` values.
        as_uint8 = tf.decode_raw(sequence[name], tf.uint8)
        return tf.reshape(tf.cast(as_uint8, tf.float32), [-1, width])

    return (context['id'],
            context['labels'],
            context['segment_start_times'],
            context['segment_end_times'],
            context['segment_labels'],
            context['segment_scores'],
            decode_frames('rgb', 1024),
            decode_frames('audio', 128))

In [None]:
# Build the input pipeline once. The file to read is fed through a placeholder and
# the iterator is re-initialized for each file, so the graph no longer gains a new
# dataset + iterator per TFRecord (which made it grow with the number of files).
tfrecord_path = tf.placeholder(tf.string, shape=[])
iterator = tf.data.TFRecordDataset(tfrecord_path).map(parser).make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    # Iterate over the files glob actually found (sorted for a deterministic order)
    # rather than rebuilding '<datatype>%04d.tfrecord' names from a counter, which
    # misses files whose numbering has gaps or uses different zero-padding.
    for frame_lvl_record in sorted(file_paths):
        print(frame_lvl_record)
        sess.run(iterator.initializer, feed_dict={tfrecord_path: frame_lvl_record})

        while True:
            try:
                data_record = sess.run(next_element)
            except tf.errors.OutOfRangeError:
                # Normal end of the current file. Only this error is caught: a bare
                # `except: pass` would also hide real failures (unknown labels,
                # write errors, Ctrl-C) and silently drop the rest of the file.
                break

            dataset = dict()
            dataset['video_id'] = data_record[0].decode()
            # Entries 1-5 come from VarLenFeatures (sparse); keep only their values.
            dataset['video_labels'] = list(data_record[1].values)
            dataset['segment_start_times'] = list(data_record[2].values)
            dataset['segment_end_times'] = list(data_record[3].values)
            dataset['segment_labels'] = list(data_record[4].values)
            dataset['segment_scores'] = list(data_record[5].values)
            dataset['frame_rgb'] = list(data_record[6])
            dataset['frame_audio'] = list(data_record[7])

            # Remap raw segment label IDs to the contiguous indices from vocabulary.csv.
            try:
                dataset['segment_labels'] = [vocab_label2idx_dict[label]
                                             for label in dataset['segment_labels']]
            except KeyError as err:
                raise KeyError('segment label {} of video {} is not in vocabulary.csv'
                               .format(err.args[0], dataset['video_id'])) from err

            # np.array(dict) is a 0-d object array; read it back with
            # np.load(path, allow_pickle=True).item().
            np.save(out_dir + dataset['video_id'] + '.npy', np.array(dataset))