In [None]:
# This notebook relies on Colab paths (/content/...) and on kaggle.json being
# present in the working directory, so stop early anywhere else.
if 'google.colab' in str(get_ipython()):
    print('Running on Colab')
else:
    print('Not running on Colab')
    # `assert False` is removed under `python -O` and carries no message;
    # raise explicitly so execution halts with an actionable error.
    raise RuntimeError('This notebook must be run on Google Colab.')

# Download Kaggle Dataset

In [None]:
# Reinstall the Kaggle CLI at its latest release (--no-deps leaves its dependencies
# untouched), presumably to replace the version preinstalled on Colab.
# NOTE(review): pin a version (kaggle==X.Y.Z) so reruns are reproducible.
!pip install --upgrade --force-reinstall --no-deps --quiet kaggle

[?25l[K     |█████▋                          | 10 kB 29.5 MB/s eta 0:00:01[K     |███████████▏                    | 20 kB 9.7 MB/s eta 0:00:01[K     |████████████████▊               | 30 kB 8.3 MB/s eta 0:00:01[K     |██████████████████████▎         | 40 kB 3.6 MB/s eta 0:00:01[K     |███████████████████████████▉    | 51 kB 4.1 MB/s eta 0:00:01[K     |████████████████████████████████| 58 kB 3.0 MB/s 
[?25h  Building wheel for kaggle (setup.py) ... [?25l[?25hdone


In [None]:
# Install the Kaggle API token: kaggle.json must already have been uploaded to the
# current working directory. It is copied to ~/.kaggle/ and chmod 600 makes it
# readable by the current user only.
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

In [None]:
# Download the happy-whale-and-dolphin competition archive into /content/happywhale.
!kaggle competitions download -c happy-whale-and-dolphin -p /content/happywhale

In [None]:
# Extract the downloaded archive into /content/happywhale/.
# NOTE(review): if the shell glob matches more than one .zip, unzip treats the extra
# names as members of the first archive; quote the pattern ('*.zip') to let unzip expand it.
!unzip -q /content/happywhale/*.zip -d /content/happywhale/

In [None]:
# Delete the zip archive(s) once extracted (presumably to free disk space).
# Re-running the download cell afterwards will fetch the archive again.
!rm /content/happywhale/*.zip

# Create TF Dataset

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import glob
import tensorflow as tf
from tensorflow.data import Dataset

In [None]:
# Competition metadata; the 'image', 'species' and 'individual_id' columns are used below.
train_csv_path = '/content/happywhale/train.csv'
train_df = pd.read_csv(filepath_or_buffer=train_csv_path)

In [None]:
# List the training images. glob order depends on the filesystem, and everything
# downstream (including the train/validation split) follows this order, so sort it
# to make the notebook reproducible across runs and machines.
train_paths = sorted(glob.glob("/content/happywhale/train_images/*.jpg"))
train_img_names = [os.path.basename(path) for path in train_paths]
# Index the metadata once and look up every listed file name; .loc raises KeyError
# if an image on disk has no row in train.csv.
train_meta = train_df.set_index('image').loc[train_img_names]
train_species_names = train_meta['species']
train_ids = train_meta['individual_id']

In [None]:
# One element per training image: (path, file name, species, individual id).
ds = Dataset.from_tensor_slices((train_paths, train_img_names, train_species_names, train_ids))
N = tf.data.experimental.cardinality(ds).numpy()

VAL_FRACTION = 0.2
SPLIT_SEED = 42
val_size = int(N * VAL_FRACTION)
# Shuffle once with a fixed seed before splitting, so the validation subset does not
# inherit whatever ordering train_paths has. reshuffle_each_iteration=False keeps the
# same permutation for the take() and skip() iterators, so the subsets are disjoint
# and together cover every element.
split_ds = ds.shuffle(buffer_size=max(int(N), 1), seed=SPLIT_SEED,
                      reshuffle_each_iteration=False)
train_ds = split_ds.skip(val_size)
val_ds = split_ds.take(val_size)

IMG_HEIGHT = 128
IMG_WIDTH = 128
def get_family_name(species_name):
    """Map a species label to a coarse family: 'whale', 'dolphin' or 'unknown'.

    The family is decided by the last '_'-separated token of the label.
    'whale', 'beluga' and 'globis' count as whales; 'dolphin' and 'dolpin'
    (presumably a misspelling present in the labels) count as dolphins.

    Args:
        species_name: scalar tf.string tensor (or str/bytes), e.g. b'bottlenose_dolphin'.

    Returns:
        Scalar tf.string tensor with the family name.
    """
    last_token = tf.strings.split(species_name, '_')[-1]
    is_whale = tf.reduce_any(tf.equal(last_token, ['whale', 'beluga', 'globis']))
    is_dolphin = tf.reduce_any(tf.equal(last_token, ['dolphin', 'dolpin']))
    # Pure tensor ops (no Python `if` on tensors), so the result is the same tensor
    # whether called eagerly or traced inside Dataset.map, without relying on autograph.
    return tf.where(is_whale, 'whale', tf.where(is_dolphin, 'dolphin', 'unknown'))

def load_img(path, height=None, width=None):
    """Read a JPEG file and return it resized as a uint8 RGB tensor.

    Args:
        path: scalar tf.string tensor (or str) with the image file path.
        height: target height; defaults to IMG_HEIGHT.
        width: target width; defaults to IMG_WIDTH.

    Returns:
        uint8 tensor of shape (height, width, 3). The aspect ratio is not preserved.
    """
    height = IMG_HEIGHT if height is None else height
    width = IMG_WIDTH if width is None else width
    img = tf.io.read_file(path)
    img = tf.io.decode_jpeg(img, channels=3)  # force 3 channels even for greyscale JPEGs
    img = tf.image.resize(img, [height, width])  # returns float32
    # Round and saturate instead of a plain cast: tf.cast truncates toward zero,
    # which biases every pixel downward by about 0.5 on average.
    return tf.saturate_cast(tf.round(img), tf.uint8)

def cast_to_float(img):
    """Convert a [0, 255] image (e.g. uint8) to float32 scaled into [0, 1]."""
    as_float = tf.cast(img, dtype=tf.float32)
    return as_float / 255.

def _to_example(path, image_name, species_name, individual_id):
    """Turn one (path, name, species, id) element into the feature dict used throughout."""
    return {'image': load_img(path),
            'image_name': image_name,
            'species_name': species_name,
            'individual_id': individual_id,
            'family_name': get_family_name(species_name)}


def _image_to_float(example):
    """Return a copy of `example` with its image rescaled to float32 in [0, 1]."""
    return {**example, 'image': cast_to_float(example['image'])}


# The full dataset (serialized to TFRecords below) and both splits share one mapping
# function instead of three copies of the same lambda.
ds = ds.map(_to_example, num_parallel_calls=tf.data.AUTOTUNE)
train_ds = train_ds.map(_to_example, num_parallel_calls=tf.data.AUTOTUNE)
val_ds = val_ds.map(_to_example, num_parallel_calls=tf.data.AUTOTUNE)

# Cache as 8-bit ints (4x smaller than float32) so the slow decode + resize of the
# large source images only runs during the first full pass over each split.
train_ds = train_ds.cache()
val_ds = val_ds.cache()

# Then convert to 32-bit float on the fly and prefetch so the input pipeline
# overlaps with whatever consumes it.
train_ds = (train_ds
            .map(_image_to_float, num_parallel_calls=tf.data.AUTOTUNE)
            .prefetch(tf.data.AUTOTUNE))
val_ds = (val_ds
          .map(_image_to_float, num_parallel_calls=tf.data.AUTOTUNE)
          .prefetch(tf.data.AUTOTUNE))

In [None]:
# Preview the first nine training images, titled with their species label.
fig, axs = plt.subplots(nrows=3, ncols=3, figsize=(10, 10))
axs = axs.ravel()
preview = train_ds.take(len(axs))
for ax, sample in zip(axs, preview):
    species = sample['species_name'].numpy().decode()
    ax.imshow(sample['image'])
    ax.set_title(species)

# Create TFRecords

In [None]:
import os, json, random, cv2
import numpy as np, pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf, re, math
from tqdm import tqdm

In [None]:
def _bytes_feature(value):
    """Wrap a bytes value (or an eager tensor holding one) in a tf.train.Feature."""
    eager_tensor_type = type(tf.constant(0))
    # BytesList needs raw bytes, so unpack eager tensors first.
    raw = value.numpy() if isinstance(value, eager_tensor_type) else value
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw]))

def _bytes_array_feature(nonscalar):
    """Serialize an array/tensor with tf.io.serialize_tensor and wrap it in a Feature.

    serialize_tensor accepts eager tensors and numpy arrays alike, so no manual
    conversion is needed; the proto bytes are identical either way.
    """
    payload = tf.io.serialize_tensor(nonscalar).numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[payload]))

def serialize_example(image, image_name, species_name, individual_id, family_name):
    """Build a tf.train.Example from one record and return its serialized bytes.

    The image goes through _bytes_array_feature (tf.io.serialize_tensor); the
    remaining fields are stored as single bytes values.
    """
    string_fields = (
        ('image_name', image_name),
        ('species_name', species_name),
        ('individual_id', individual_id),
        ('family_name', family_name),
    )
    feature = {'image': _bytes_array_feature(image)}
    feature.update((key, _bytes_feature(value)) for key, value in string_fields)
    proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return proto.SerializeToString()

def tf_serialize_example(data_dict):
    """Graph-compatible wrapper around serialize_example, for use in Dataset.map.

    serialize_example needs eager values (.numpy()), so it runs via tf.py_function.
    The result is reshaped to a scalar string tensor.
    """
    ordered_keys = ('image', 'image_name', 'species_name', 'individual_id', 'family_name')
    inputs = tuple(data_dict[key] for key in ordered_keys)
    serialized = tf.py_function(serialize_example, inputs, tf.string)
    return tf.reshape(serialized, ())

# Parsing spec: every feature is a single bytes string (the image holds a
# tf.io.serialize_tensor payload that _parse_function decodes).
feature_description = {
    name: tf.io.FixedLenFeature([], tf.string, default_value='')
    for name in ('image', 'image_name', 'species_name', 'individual_id', 'family_name')
}

def _parse_function(example_proto):
    """Decode one serialized tf.train.Example produced by serialize_example.

    Args:
        example_proto: scalar tf.string tensor (use tf.io.parse_example for batches).

    Returns:
        dict keyed like feature_description; 'image' is a uint8 tensor of shape
        (height, width, 3), the other values are scalar tf.string tensors.
    """
    parsed_ex = tf.io.parse_single_example(example_proto, feature_description)
    image = tf.io.parse_tensor(parsed_ex['image'], tf.uint8)
    # parse_tensor yields a tensor of unknown rank; pin rank 3 with 3 channels
    # (load_img decodes with channels=3) so tf.image ops and batching can rely on it.
    parsed_ex['image'] = tf.ensure_shape(image, [None, None, 3])
    return parsed_ex

# Round-trip one element through serialize -> parse as a sanity check. Printing the
# raw serialized proto would dump the whole image as escaped bytes, so print sizes only.
sample_ex = next(iter(ds))
serial_ex = tf_serialize_example(sample_ex)
print(f'Serialized example: {len(serial_ex.numpy()):,} bytes')
parsed_sample = _parse_function(serial_ex)
assert np.array_equal(parsed_sample['image'].numpy(), sample_ex['image'].numpy())
for key in ('image_name', 'species_name', 'individual_id', 'family_name'):
    assert parsed_sample[key].numpy() == sample_ex[key].numpy(), key
print({key: (value.shape if key == 'image' else value.numpy())
       for key, value in parsed_sample.items()})

In [None]:
# Serialize every element of the full dataset (decoded images included) to Example bytes.
serialized_ds = ds.map(map_func=tf_serialize_example)

In [None]:
fname = 'happywhale.tfrecord'
# tf.data.experimental.TFRecordWriter is deprecated; tf.io.TFRecordWriter is the
# supported writer, and the `with` block guarantees the file is flushed and closed.
with tf.io.TFRecordWriter(fname) as writer:
    for record in tqdm(serialized_ds, total=N, desc='writing TFRecord'):
        writer.write(record.numpy())

## Read TFRecord

In [None]:
# Read back the file written above. If it is absent (e.g. a fresh runtime), fall back
# to a copy on Google Drive, presumably placed there manually; Drive must then be
# mounted at /content/drive first.
local_tfrecord = 'happywhale.tfrecord'
drive_tfrecord = '/content/drive/MyDrive/happywhale.tfrecord'
fname = [local_tfrecord if os.path.exists(local_tfrecord) else drive_tfrecord]
raw_ds = tf.data.TFRecordDataset(fname)
raw_ds

In [None]:
# Parsing uses only TF ops, so it can run in parallel; element order is still preserved
# (map is deterministic by default).
parsed_ds = raw_ds.map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE)

In [None]:
# Display the first parsed record.
first_record = next(iter(parsed_ds))
first_record