In [None]:
## reference: https://www.tensorflow.org/tutorials/load_data/tfrecord
import numpy as np
import pandas as pd
import tensorflow as tf
from tqdm import tqdm

In [None]:
from sklearn.model_selection import StratifiedKFold
# Load the training labels and derive every column the TFRecord writer needs.
N_FOLDS = 5    # number of stratified cross-validation folds
N_SHARDS = 5   # number of TFRecord shards written per fold
SEED = 2021    # shared seed so folds AND shards are reproducible across runs

df = pd.read_csv('/kaggle/input/dog-breed-identification/labels.csv')
df['image_path'] = '/kaggle/input/dog-breed-identification/train/' + df['id'] + '.jpg'
# Integer class id per breed, numbered in sorted breed order.
df['label'] = df.groupby(['breed'], sort=True).ngroup()

# StratifiedKFold yields *positional* indices, so fill a plain integer array
# instead of going through label-based `.loc` (which also created a float/NaN
# column that then had to be cast back to int).
kf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)
fold_ids = np.empty(len(df), dtype=int)
for fold, (train_index, val_index) in enumerate(kf.split(df.values, df["breed"])):
    fold_ids[val_index] = fold
df["fold"] = fold_ids

# Seeded generator: the previous unseeded np.random.randint call produced a
# different shard assignment on every run.
df['shard'] = np.random.default_rng(SEED).integers(0, N_SHARDS, size=len(df))
df.head()

# Write the TFRecord shards

In [None]:
# The following functions can be used to convert a value to a type compatible
# with tf.train.Example.

def _bytes_feature(value):
    '''Wrap a single string / bytes value in a bytes_list tf.train.Feature.

    Eager tensors are converted with ``.numpy()`` first, because BytesList
    won't unpack a string from an EagerTensor.
    '''
    eager_tensor_type = type(tf.constant(0))
    raw = value.numpy() if isinstance(value, eager_tensor_type) else value
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw]))

def _float_feature(value):
    '''Wrap a single float / double in a float_list tf.train.Feature.'''
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
    '''Wrap a single bool / enum / int / uint in an int64_list tf.train.Feature.'''
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)

In [None]:
def serialize_example(image,image_id,label,breed):
    '''Build a tf.train.Example for one sample and return it serialized.

    Args:
        image: Value stored under 'image' via _bytes_feature.
        image_id: Value stored under 'image_id' via _bytes_feature.
        label: Value stored under 'label' via _int64_feature.
        breed: Value stored under 'breed' via _bytes_feature.

    Returns:
        The result of ``SerializeToString()`` on the assembled Example.
    '''
    features = tf.train.Features(feature={
        'image': _bytes_feature(image),
        'image_id': _bytes_feature(image_id),
        'label': _int64_feature(label),
        'breed': _bytes_feature(breed),
    })
    return tf.train.Example(features=features).SerializeToString()

In [None]:
def write_tfrecord(df: pd.DataFrame, filename: str):
    '''Write every row of ``df`` as one example into a GZIP TFRecord file.

    Args:
        df: Frame providing the 'image_path', 'id', 'label' and 'breed'
            columns; each row becomes one serialized example.
        filename: Path of the TFRecord file to create.
    '''
    options = tf.io.TFRecordOptions("GZIP")
    # Walk the four needed columns in lockstep. The previous `df.iloc[i][...]`
    # pattern rebuilt a full row Series four times for every sample.
    rows = zip(df['image_path'], df['id'], df['label'], df['breed'])
    with tf.io.TFRecordWriter(filename, options=options) as writer:
        for image_path, image_id, label, breed in tqdm(rows, total=len(df)):
            image = tf.io.read_file(image_path)  # raw, still-encoded file bytes
            tf_example = serialize_example(
                image,
                str.encode(image_id),
                int(label),  # plain Python int for Int64List, whatever the column dtype yields
                str.encode(breed),
            )
            writer.write(tf_example)

In [None]:
import joblib

# One output file per (fold, shard) pair, written in parallel worker processes.
# The ids are read from the frame itself rather than repeating the fold/shard
# counts as literals, so this cell cannot drift from the assignment step.
fold_ids = sorted(df['fold'].unique())
shard_ids = sorted(df['shard'].unique())
_ = joblib.Parallel(n_jobs=8)(
    joblib.delayed(write_tfrecord)(
        df[(df['fold'] == i) & (df['shard'] == j)],
        f'dogbreed_train_fold{i}_{j}.tfrec',
    )
    for i in fold_ids
    for j in shard_ids
)

# Read the TFRecord shards back

In [None]:
def decode_image(image_data):
    '''Decode JPEG-encoded bytes with tf.image.decode_jpeg, requesting 3 channels.'''
    return tf.image.decode_jpeg(image_data, channels=3)

def read_labeled_tfrecord(example):
    '''Parse one serialized example and return ``(image, label)``.

    All four stored features are declared in the parsing spec, but only the
    decoded 'image' and the 'label' are returned.
    '''
    # Every feature is a single element (shape []); tf.string means bytestring.
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "image_id": tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
        'breed': tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image']), parsed['label']

def preprocess(image, label):
    '''Scale the image to floats in the [0, 1] range and resize it to 256x256.

    The label is passed through unchanged.
    '''
    scaled = tf.cast(image, tf.float32) / 255.0
    resized = tf.image.resize(scaled, [256, 256])
    return resized, label

In [None]:
import glob
import matplotlib.pyplot as plt

# glob returns paths in arbitrary order; sort so the read order is reproducible.
files = sorted(glob.glob('dogbreed_train_fold*.tfrec'))
if not files:
    # Fail loudly instead of silently iterating over an empty dataset.
    raise FileNotFoundError(
        "No 'dogbreed_train_fold*.tfrec' files found - run the write step first."
    )

AUTO = tf.data.experimental.AUTOTUNE
dataset = tf.data.TFRecordDataset(files, num_parallel_reads=AUTO, compression_type="GZIP")
# Parse/decode and preprocess in parallel rather than one element at a time.
dataset = dataset.map(read_labeled_tfrecord, num_parallel_calls=AUTO)
dataset = dataset.map(preprocess, num_parallel_calls=AUTO)

# Sanity check: show the first few decoded images with their integer label.
for image, label in dataset.take(5):
    fig, ax = plt.subplots()
    ax.imshow(image.numpy())
    ax.set_title(f'label: {int(label)}')
    ax.axis('off')
    plt.show()