## Creating TFRecords
- https://keras.io/examples/keras_recipes/creating_tfrecords/#define-tfrecords-helper-functions

In [None]:
import glob
import random
import os
import math

import pandas as pd
import numpy as np
import matplotlib.pylab as plt
from tqdm.notebook import tqdm
import tensorflow as tf
from tensorflow.data import AUTOTUNE
import tensorflow_addons as tfa
import tensorflow_datasets as tfds
from sklearn.model_selection import train_test_split, GroupShuffleSplit
from kaggle_datasets import KaggleDatasets

# Global matplotlib look for every figure in this notebook.
plt.style.use('ggplot')
plt.rcParams.update({'figure.figsize': (12, 10)})

In [None]:
!ls ../input/happy-whale-and-dolphin

In [None]:
# Collect every image file of the competition's train and test folders.
train_paths = glob.glob('../input/happy-whale-and-dolphin/train_images/*')
test_paths = glob.glob('../input/happy-whale-and-dolphin/test_images/*')

# Test images have no ground truth, so each one gets the placeholder label -1.
test_labels = np.full(len(test_paths), -1)

print('Number of train images:', len(train_paths))
print('Number of test images:', len(test_paths))

In [None]:
train = pd.read_csv('../input/happy-whale-and-dolphin/train.csv')

# Number the individuals 0..N-1 in order of first appearance in train.csv.
# NOTE: `ind2label` maps individual_id -> integer label and `label2ind` is
# its inverse (integer label -> individual_id).
ind2label = {
    individual: idx
    for idx, individual in enumerate(train['individual_id'].unique())
}
label2ind = {idx: individual for individual, idx in ind2label.items()}
train['label'] = train['individual_id'].map(ind2label)

# Merge misspelled / duplicated species names.
# thanks to https://www.kaggle.com/c/happy-whale-and-dolphin/discussion/305574
species_fixes = {
    "globis": "short_finned_pilot_whale",
    "pilot_whale": "short_finned_pilot_whale",
    "kiler_whale": "killer_whale",
    "bottlenose_dolpin": "bottlenose_dolphin",
}
train['species'] = train['species'].replace(species_fixes)

In [None]:
# Distinct species names left after the spelling fixes.
species_count = len(train['species'].unique())
print('Number of unique species:', species_count)

In [None]:
# Horizontal bar chart of the absolute number of images per species.
species_counts = train['species'].value_counts(ascending=True)
ax = species_counts.plot.barh()
ax.bar_label(ax.containers[0]);

In [None]:
# Same chart, but showing each species' share of all images.
species_shares = train['species'].value_counts(ascending=True, normalize=True)
ax = species_shares.plot.barh()
ax.bar_label(ax.containers[0]);

## Splitting

In [None]:
class CFG:
    """Notebook-wide configuration constants."""
    # Passed as `random_state` to the train/validation split.
    seed = 42
    # Passed as `test_size` to the split, i.e. the validation fraction.
    val_size = 0.25
    # Target (height, width) images are resized to before being written
    # into the TFRecord files.
    img_height = 128
    img_width = 128

In [None]:
# Split image names and labels into train / validation parts, keeping the
# species distribution the same in both parts.
train_imgs, val_imgs, train_labels, val_labels = train_test_split(
    train['image'], train['label'],
    stratify=train['species'],
    test_size=CFG.val_size,
    random_state=CFG.seed,
)

# Alternative kept for reference: a group-based split, so that no label
# appears in both parts.
# https://stackoverflow.com/questions/54797508/how-to-generate-a-train-test-split-based-on-a-group-id
# tr, val = next(
#     GroupShuffleSplit(
#         n_splits=2,
#         test_size=CFG.val_size,
#         random_state=CFG.seed
#     ).split(train['image'], groups=train['label'])
# )
# train_imgs, val_imgs = train['image'].iloc[tr], train['image'].iloc[val]
# train_labels, val_labels = train['label'].iloc[tr], train['label'].iloc[val]

# Turn bare file names into full paths and drop the pandas index.
train_image_dir = '../input/happy-whale-and-dolphin/train_images/'
train_imgs, val_imgs = (
    (train_image_dir + train_imgs).values,
    (train_image_dir + val_imgs).values,
)
train_labels, val_labels = train_labels.values, val_labels.values

train_imgs.shape, val_imgs.shape, train_labels.shape, val_labels.shape

In [None]:
# Number of distinct labels in the train part, then in the validation part.
for split_labels in (train_labels, val_labels):
    print(len(np.unique(split_labels)))

In [None]:
# Number of labels that occur in both the train and the validation part.
len(set(train_labels).intersection(set(val_labels)))

### Defining TFRecords helper functions

In [None]:
num_samples = 4096  # maximum number of examples stored in one TFRecord file


def calculate_num_tf_records(paths):
    """Return how many TFRecord files are needed to hold every item of `paths`.

    Each file holds up to `num_samples` items; a final, partially filled file
    is counted when the items do not divide evenly.
    """
    full_records, leftover = divmod(len(paths), num_samples)
    return full_records + (1 if leftover else 0)

def image_feature(value):
    """Returns a bytes_list holding `value` encoded as JPEG."""
    jpeg_bytes = tf.io.encode_jpeg(value).numpy()
    encoded = tf.train.BytesList(value=[jpeg_bytes])
    return tf.train.Feature(bytes_list=encoded)

def bytes_feature(value):
    """Returns a bytes_list from a string / byte.

    A `str` is encoded with the default codec of `str.encode()` (UTF-8);
    `bytes` are stored unchanged.
    """
    if isinstance(value, str):
        # BytesList only holds raw bytes, so text has to be encoded first.
        value = value.encode()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def float_feature(value):
    """Returns a float_list holding the single float / double `value`."""
    wrapped = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=wrapped)

def int64_feature(value):
    """Returns an int64_list holding the single bool / enum / int / uint `value`."""
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)

def float_feature_list(value):
    """Returns one float_list holding every float / double in `value`."""
    wrapped = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=wrapped)

def create_example(image, label, path):
    """Bundle one sample into a `tf.train.Example`.

    The example has three features: "image" (JPEG bytes), "label" (int64)
    and "path" (bytes).
    """
    features = tf.train.Features(
        feature={
            "image": image_feature(image),
            "label": int64_feature(label),
            "path": bytes_feature(path),
        }
    )
    return tf.train.Example(features=features)

In [None]:
# Number of TFRecord files needed for each of the three splits.
train_len_records, val_len_records, test_len_records = (
    calculate_num_tf_records(split_paths)
    for split_paths in (train_imgs, val_imgs, test_paths)
)
print(train_len_records, val_len_records, test_len_records)

### Writing TFRecords

In [None]:
def write_to_tfrecords(labels, paths, len_records: int, prefix: str):
    """Resize images and write them with their labels into TFRecord files.

    File ``i`` holds samples ``[i * num_samples, (i + 1) * num_samples)`` and
    is written to ``<prefix>-file_<i>-<count>.tfrec`` in the current
    directory, where ``<count>`` is the number of samples in that file.

    Args:
        labels: Sequence of integer labels, aligned with ``paths``.
        paths: Sequence of JPEG image file paths (``str``).
        len_records: Number of TFRecord files to write.
        prefix: Filename prefix of the written files (e.g. ``'train'``).

    Raises:
        ValueError: If ``labels`` and ``paths`` differ in length.
    """
    # zip() below stops at the shorter input, which would silently drop
    # samples and make the count in the filename wrong, so fail early.
    if len(labels) != len(paths):
        raise ValueError(
            "labels and paths must have the same length, got %i and %i"
            % (len(labels), len(paths))
        )

    target_size = (CFG.img_height, CFG.img_width)
    for tfrec_num in range(len_records):
        start, end = (tfrec_num * num_samples), ((tfrec_num + 1) * num_samples)
        sample_labels = labels[start:end]
        sample_paths = paths[start:end]
        filename = prefix + "-file_%.2i-%i.tfrec" % (tfrec_num, len(sample_paths))
        with tf.io.TFRecordWriter(filename) as writer:
            for sample_label, sample_path in zip(sample_labels, sample_paths):
                image = tf.io.decode_jpeg(tf.io.read_file(sample_path))
                # Resize in float32, then go back to uint8 so the image can
                # be JPEG-encoded again by create_example().
                image = tf.cast(image, tf.float32)
                image = tf.image.resize(image, size=target_size)
                image = tf.cast(image, tf.uint8)
                example = create_example(image, sample_label, sample_path)
                writer.write(example.SerializeToString())
        print('Wrote', filename)

In [None]:
%%time

# Serialize the training split into train-file_XX-N.tfrec files.
write_to_tfrecords(train_labels, train_imgs, train_len_records, 'train')

In [None]:
%%time

# Serialize the validation split into val-file_XX-N.tfrec files.
write_to_tfrecords(val_labels, val_imgs, val_len_records, 'val')

In [None]:
%%time

# Serialize the test images (placeholder label -1) into test-file_XX-N.tfrec files.
write_to_tfrecords(test_labels, test_paths, test_len_records, 'test')

### Reading TFRecords

In [None]:
def read_tfrecord(example):
    """Parse one serialized example into an (image, label, path) tuple.

    The image is decoded from its JPEG bytes with 3 channels; label and path
    are returned as parsed (int64 and string tensors).
    """
    schema = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
        'path': tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(example, schema)
    return (
        tf.io.decode_jpeg(parsed['image'], channels=3),
        parsed['label'],
        parsed['path'],
    )

In [None]:
# Build a shuffled dataset of (image, label, path) from the written train files.
filenames = tf.io.gfile.glob('./train*.tfrec')
ds = (
    tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
    .map(read_tfrecord, num_parallel_calls=AUTOTUNE)
    .shuffle(300, seed=1)
)
print(ds)

In [None]:
# fast read!
# Show three samples read back from the TFRecords (fast read!).
for record in ds.take(3):
    # Convert the three tensors of the record to plain numpy values.
    image, label, path = (tensor.numpy() for tensor in record)

    plt.imshow(image)
    plt.title(f'label: {label}\npath: {path.decode().split("/")[-1]}')
    plt.show();