# Acknowledgement!

### Some of the code is adapted from https://www.tensorflow.org/tutorials/load_data/tfrecord


# Get the GCS bucket path!

### Using TFrecords often requires gsbucket address. You could get the gsbucket address by using kaggle API as below:

In [None]:
from kaggle_datasets import KaggleDatasets
import tensorflow as tf
import os

# Resolve the 'tpu-getting-started' Kaggle dataset to its Google Cloud Storage
# path; the TFRecord files are globbed relative to this path further down.
kaggle_ds = KaggleDatasets()
gcsPATH = kaggle_ds.get_gcs_path('tpu-getting-started')
gcsPATH



# Get the list of tfrecs from PATH! tf.io.gfile.glob 

### Next, we get the list of TFRecord files to feed into [TFRecordDataset in the tf.data API](https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset). Use [tf.io.gfile.glob](https://www.tensorflow.org/api_docs/python/tf/io/gfile/glob) to find files matching a general pattern. Accepted patterns are: 

pattern: { term }

term:

    '*': matches any sequence of non-'/' characters
    '?': matches a single non-'/' character
    '[' [ '^' ] { match-list } ']': matches any single character (not) on the list
    c: matches character c where c != '*', '?', '\\', '['
    '\\' c: matches character c

character range:

    c: matches character c while c != '\\', '-', ']'
    '\\' c: matches character c
    lo '-' hi: matches character c for lo <= c <= hi



In [None]:
config = {'tfrec_shape': 512}
# Every "train/*tfrec" file inside a directory whose name contains the
# configured shape (e.g. "512").
tfrec_pattern = "{}/*{}*/train/*tfrec".format(gcsPATH, config['tfrec_shape'])
ls_tfrecs = tf.io.gfile.glob(tfrec_pattern)
ls_tfrecs


# Show the format of TFREC! tf.train.Example() and ParseFromString()


### You need to know the feature description — essentially each key and its datatype — to parse your TFRecords into a dataset.
### Inspect a record with tf.train.Example() and ParseFromString(), which give access to the underlying [protobuf message](https://developers.google.com/protocol-buffers/).
###  Link to tutorial: [Protocol Buffer Basics: Python](https://developers.google.com/protocol-buffers/docs/pythontutorial)

In [None]:
ds_raw = tf.data.TFRecordDataset(ls_tfrecs)

# Decode only the first serialized record into a tf.train.Example protobuf so
# its feature keys and storage kinds can be inspected.
example = None  # reset so a stale `example` from an earlier run is never reused
for raw_record in ds_raw.take(1):
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())

# An empty glob result (or empty files) means the loop body never ran; fail
# here with a clear message instead of a NameError further down.
if example is None:
    raise ValueError(
        "ds_raw yielded no records; check that ls_tfrecs is non-empty "
        "(got {} files) and that the files are readable.".format(len(ls_tfrecs))
    )

# WhichOneof('kind') names the populated list field of each tf.train.Feature:
# "bytes_list", "float_list" or "int64_list" (None if nothing is set).
tfrec_dtype = {
    key: feature.WhichOneof('kind')
    for key, feature in example.features.feature.items()
}

tfrec_dtype

### Great! We now have all the information that goes into the feature description dictionary. Next, define a function that builds the feature description:

In [None]:
def get_tfrec_format(dictionary_obj):
    """Build a feature description for ``tf.io.parse_single_example``.

    Args:
        dictionary_obj: mapping of feature key -> name of the populated
            ``tf.train.Feature`` list field, i.e. one of "bytes_list",
            "int64_list" or "float_list" (as produced for ``tfrec_dtype``).

    Returns:
        dict mapping each key to a scalar ``tf.io.FixedLenFeature``.

    Raises:
        ValueError: if a value is not one of the three known kind names.
    """
    kind_to_dtype = {
        "bytes_list": tf.string,
        "int64_list": tf.int64,
        "float_list": tf.float32,
    }
    tfrec_format = dict()
    for key, value in dictionary_obj.items():
        try:
            tf_dtype = kind_to_dtype[value]
        except KeyError:
            # Previously an unknown kind either raised UnboundLocalError or
            # silently reused the previous key's dtype.
            raise ValueError(
                "Unsupported feature kind {!r} for key {!r}".format(value, key)
            ) from None
        # Shape [] = exactly one value per example; assumes no feature in these
        # records holds several values (those would need a non-empty shape or
        # tf.io.VarLenFeature).
        tfrec_format[key] = tf.io.FixedLenFeature([], tf_dtype)
    return tfrec_format

get_tfrec_format(tfrec_dtype)

# Parse the raw data using the above key - datatype info!

### Then parse the raw data with this feature description:

In [None]:
# Build the key -> FixedLenFeature spec once; the parse function below reads it
# when tf.data traces it during the map() call.
feature_description = get_tfrec_format(tfrec_dtype)


def _parse_record(raw):
    """Parse one serialized tf.train.Example into a dict of tensors."""
    return tf.io.parse_single_example(raw, feature_description)


ds_parsed = ds_raw.map(_parse_record)
ds_parsed

In [None]:
# batch_size: how many parsed examples the preview grid in the next cell draws.
# iter_ds: one shared Python iterator over ds_parsed; each next() call returns
# one parsed example dict, continuing from where the previous call stopped.
batch_size, iter_ds = 64, iter(ds_parsed)

# Decode bytes to jpeg and display sample image!

Now we want to show the data to see if everything is as expected:

In [None]:
import matplotlib.pyplot as plt
from tqdm import tqdm

n_col = 8
# Round up so all batch_size examples get a cell even when batch_size is not a
# multiple of n_col (plain floor division silently dropped the remainder).
n_row = (batch_size + n_col - 1) // n_col
FIGSIZE = (25,25)

# squeeze=False keeps `axs` 2-D even when n_row or n_col is 1; otherwise
# plt.subplots returns a 1-D array and axs[row, col] raises IndexError.
figs, axs = plt.subplots(n_row, n_col, figsize=FIGSIZE, squeeze=False)
for row in tqdm(range(n_row)):
    for col in range(n_col):
        ax = axs[row, col]
        ax.set_xticks([])
        ax.set_yticks([])
        # The last row may contain padding cells beyond batch_size.
        if row * n_col + col >= batch_size:
            ax.axis('off')
            continue
        next_item = next(iter_ds)
        ax.imshow(tf.io.decode_jpeg(next_item['image']).numpy())
        # .numpy() turns the scalar tensors into plain values so the title shows
        # the value itself rather than the full tf.Tensor(...) repr.
        ax.title.set_text("id: {}, class: {}".format(
            next_item['id'].numpy().decode('utf-8'),
            next_item['class'].numpy(),
        ))


# Batch data and display sample! 

In [None]:
def get_labels(item, list_labels):
    """Pick values out of a parsed example, in the requested key order.

    Args:
        item: mapping from feature name to value (e.g. one element of
            ``ds_parsed``).
        list_labels: keys to extract; order is preserved and repeats are kept.

    Returns:
        list containing ``item[key]`` for each key in ``list_labels``.

    Raises:
        KeyError: if a requested key is missing from ``item``.
    """
    picked = []
    for label_key in list_labels:
        picked.append(item[label_key])
    return picked

# Keep only the 'class' and 'id' features of each parsed example, in that order.
label_keys = ['class', 'id']
tr_ds = ds_parsed.map(lambda item: get_labels(item, label_keys))

# Show the first batch of 32 (class, id) pairs.
preview_batch = tr_ds.batch(32)
next(iter(preview_batch))