In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.python as tfp
from tqdm import tqdm

In [None]:
# Load the two CSV tables shipped with the competition data; both are
# expected to carry an 'id' column, which later cells read.
train_df = pd.read_csv(
    '../input/g2net-gravitational-wave-detection/training_labels.csv'
)
test_df = pd.read_csv(
    '../input/g2net-gravitational-wave-detection/sample_submission.csv'
)

def get_train_file_path(image_id):
    """Return the relative path of the training ``.npy`` file for ``image_id``.

    The file lives three directories deep under ``train/``; each directory
    level is named after one of the first three elements of ``image_id``.
    Indexing raises ``IndexError`` when ``image_id`` has fewer than three
    elements.
    """
    template = "../input/g2net-gravitational-wave-detection/train/{}/{}/{}/{}.npy"
    nested_dirs = (image_id[0], image_id[1], image_id[2])
    return template.format(*nested_dirs, image_id)

def get_test_file_path(image_id):
    """Return the relative path of the test ``.npy`` file for ``image_id``.

    The file lives three directories deep under ``test/``; each directory
    level is named after one of the first three elements of ``image_id``.
    Indexing raises ``IndexError`` when ``image_id`` has fewer than three
    elements.
    """
    template = "../input/g2net-gravitational-wave-detection/test/{}/{}/{}/{}.npy"
    nested_dirs = (image_id[0], image_id[1], image_id[2])
    return template.format(*nested_dirs, image_id)

# Attach an 'image_path' column to each table by mapping every id through
# the path builder that matches the table's split.
for frame, build_path in (
    (train_df, get_train_file_path),
    (test_df, get_test_file_path),
):
    frame['image_path'] = frame['id'].apply(build_path)

In [None]:
def _bytes_feature(value):
    """Wrap ``value`` as the only element of a ``BytesList`` feature.

    When ``value`` is an ``EagerTensor`` it is first converted with
    ``.numpy()``; any other value is passed through unchanged.
    """
    eager_tensor_type = tfp.framework.ops.EagerTensor
    payload = value.numpy() if isinstance(value, eager_tensor_type) else value
    bytes_list = tf.train.BytesList(value=[payload])
    return tf.train.Feature(bytes_list=bytes_list)


def _float_feature(value):
    """Wrap ``value`` as the only element of a ``FloatList`` feature."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)


def _int64_feature(value):
    """Wrap ``value`` as the only element of an ``Int64List`` feature."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)

In [None]:
def create_tf_example(wave_id: str, wave: bytes) -> tf.train.Example:
    """Build a ``tf.train.Example`` holding one wave and its identifier.

    Args:
        wave_id: Identifier of the wave. A ``str`` is encoded to UTF-8 here;
            already-encoded ``bytes`` are used as they are.
        wave: Raw bytes of the wave data.

    Returns:
        An example with two bytes features, ``"wave_id"`` and ``"wave"``.
    """
    # A BytesList only accepts bytes values, so a plain ``str`` id (which is
    # what the signature advertises) must be encoded before it is wrapped.
    if isinstance(wave_id, str):
        wave_id = wave_id.encode()
    feature = {
        "wave_id": _bytes_feature(wave_id),
        "wave": _bytes_feature(wave)
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))


def write_tfrecord(df: pd.DataFrame, filename: str):
    """Serialize every row of ``df`` into a GZIP-compressed TFRecord file.

    Args:
        df: Table with an ``"id"`` column and an ``"image_path"`` column;
            each ``image_path`` is loaded with ``np.load``.
        filename: Path of the TFRecord file to create.

    One record is written per row, in row order. Each record stores the
    encoded id and the raw bytes of the loaded array.
    """
    gzip_options = tf.io.TFRecordOptions("GZIP")
    # Walk the two columns in lock-step; ``total`` keeps the progress bar
    # sized to the number of rows.
    rows = zip(df["id"], df["image_path"])
    with tf.io.TFRecordWriter(filename, options=gzip_options) as writer:
        for raw_id, wave_path in tqdm(rows, total=len(df)):
            wave_id = str.encode(raw_id)
            wave_bytes = np.load(wave_path).tobytes()
            example = create_tf_example(wave_id, wave_bytes)
            writer.write(example.SerializeToString())

In [None]:
# Number of test rows stored in each TFRecord shard.
test_samples_per_file = 22600
# Ceiling division, so a trailing partial shard is counted as well.
test_number_of_files = -(-len(test_df) // test_samples_per_file)

# At most the first five shards (test0 .. test4) are written here.
# NOTE(review): presumably the remaining shards are produced elsewhere --
# confirm. The bound is clamped to the actual shard count so that no empty
# TFRecord file is written when the table holds fewer than five shards.
for i in range(0, min(5, test_number_of_files)):
    start = i * test_samples_per_file
    end = (i + 1) * test_samples_per_file
    # iloc slicing truncates at the end of the table, so the last shard may
    # hold fewer than test_samples_per_file rows.
    df = test_df.iloc[start:end].reset_index(drop=True)
    filename = f"test{i}.tfrecords"
    write_tfrecord(df, filename)