In [None]:
import numpy as np, pandas as pd, os
import tensorflow as tf, re, math
from tqdm.notebook import tqdm
import random
from sklearn.model_selection import train_test_split

In [None]:
def _float_feature(value):
    """Wrap one float or an array-like of floats in a float_list ``tf.train.Feature``.

    Args:
        value: A single float, or any array-like of floats (list, tuple,
            NumPy array of any rank). Multi-dimensional input is flattened
            in row-major order.

    Returns:
        A ``tf.train.Feature`` whose ``float_list`` holds the flattened values.
    """
    # FloatList needs a flat iterable of numbers. np.ravel turns a bare scalar
    # into a 1-element array and leaves an already-flat array as it is, so a
    # single float no longer fails with "object is not iterable".
    return tf.train.Feature(float_list=tf.train.FloatList(value=np.ravel(value)))

def _int64_feature(value):
    """Build an int64_list ``tf.train.Feature`` holding the single ``value`` (bool / enum / int / uint)."""
    # The scalar is wrapped in a one-element list because Int64List stores a sequence.
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)

In [None]:
def serialize_example(feature0, feature1):
    """Serialize one (image, label) pair as a ``tf.train.Example`` message.

    Args:
        feature0: Image values; passed to ``_float_feature`` and stored under
            the ``'image'`` key.
        feature1: Label; passed to ``_int64_feature`` and stored under the
            ``'label'`` key.

    Returns:
        The result of ``SerializeToString()`` on the assembled Example.
    """
    features = tf.train.Features(
        feature={
            'image': _float_feature(feature0),
            'label': _int64_feature(feature1),
        }
    )
    return tf.train.Example(features=features).SerializeToString()

In [None]:
def write_tfrecord(file_list, output, image_size=(150, 150)):
    """Write labelled ``.npy`` images into a single TFRecord file.

    Each array is loaded with ``np.load``, flattened to 1-D and stored together
    with its label via ``serialize_example``.

    Args:
        file_list: Iterable of ``(image_path, label)`` pairs.
        output: Path of the TFRecord file to create.
        image_size: Expected ``(shape[1], shape[2])`` of every loaded array.
            Arrays that do not match are reported on stdout but still written.
    """
    expected_h, expected_w = image_size
    with tf.io.TFRecordWriter(output) as writer:
        print('Writing TFRecord ...')
        for image_path, label in tqdm(file_list):
            img = np.load(image_path)
            # Check ndim first: indexing shape[1]/shape[2] on an array with
            # fewer than three axes would raise IndexError and abort the write.
            if img.ndim < 3 or img.shape[1] != expected_h or img.shape[2] != expected_w:
                # Name the offending file so an unexpected shape can be traced.
                print(image_path, img.shape)
            writer.write(serialize_example(img.reshape(-1), label))
    print('Finished!')

In [None]:
# Mount Google Drive at /content/drive and change into the project folder, so
# the relative paths used in the following cells (./dataset/..., *.tfrec)
# resolve inside Drive.
# NOTE(review): `google.colab` is presumably only importable on a Colab
# runtime; running this notebook elsewhere would need this cell skipped.
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/Colab\ Notebooks/ML4SCI

In [None]:
# Download Dataset
!gdown http://drive.google.com/uc?id=1B_UZtU4W65ZViTJsLeFfvK-xXCYUhw2A
!unzip -q dataset.zip

In [None]:
# Seed shared by the shuffle and the split so that a fresh "Restart & Run All"
# reproduces the same file order and the same train/hold-out partition.
SEED = 42

# One directory per class; the integer paired with each path is its label.
# Listings are sorted because os.listdir returns entries in arbitrary order.
train_path1 = './dataset/train/no'
train_files1 = [(os.path.join(train_path1, f), 0) for f in sorted(os.listdir(train_path1)) if f.endswith(".npy")]
train_path2 = './dataset/train/sphere'
train_files2 = [(os.path.join(train_path2, f), 1) for f in sorted(os.listdir(train_path2)) if f.endswith(".npy")]
train_path3 = './dataset/train/vort'
train_files3 = [(os.path.join(train_path3, f), 2) for f in sorted(os.listdir(train_path3)) if f.endswith(".npy")]

train_files = train_files1 + train_files2 + train_files3
# A dedicated Random instance gives a deterministic shuffle without touching
# the global random state.
random.Random(SEED).shuffle(train_files)

# Hold out 20% of the training files; `test` is the held-out part.
train, test = train_test_split(train_files, test_size=0.2, random_state=SEED)

In [None]:
# All shuffled training files (no hold-out) go into one TFRecord file.
write_tfrecord(file_list=train_files, output='tfrecord_train_full_shuffle.tfrec')

In [None]:
# Write the two parts of the train_test_split result: `train` is the larger
# part, `test` is the 20% held-out part.
write_tfrecord(file_list=train, output='tfrecord_train_shuffle.tfrec')
write_tfrecord(file_list=test, output='tfrecord_train_val_shuffle.tfrec')

In [None]:
# One directory per class; the integer paired with each path is its label.
# Listings are sorted because os.listdir returns entries in arbitrary order,
# which would otherwise make the record order in the output file vary by run.
val_path1 = './dataset/val/no'
val_files1 = [(os.path.join(val_path1, f), 0) for f in sorted(os.listdir(val_path1)) if f.endswith(".npy")]
val_path2 = './dataset/val/sphere'
val_files2 = [(os.path.join(val_path2, f), 1) for f in sorted(os.listdir(val_path2)) if f.endswith(".npy")]
val_path3 = './dataset/val/vort'
val_files3 = [(os.path.join(val_path3, f), 2) for f in sorted(os.listdir(val_path3)) if f.endswith(".npy")]

# The validation files are written class by class, without shuffling.
val_files = val_files1 + val_files2 + val_files3

write_tfrecord(val_files, 'tfrecord_val.tfrec')