In [None]:
import os
import pickle
from multiprocessing import Pool

import numpy as np
import tensorflow as tf
import pandas as pd
from tqdm.auto import tqdm
import json

In [None]:
from kaggle_secrets import UserSecretsClient

# Copy the Kaggle API credentials from the notebook's secret store into the
# environment, where the `kaggle` CLI calls later in this notebook read them.
secrets = UserSecretsClient()
for secret_name in ("KAGGLE_USERNAME", "KAGGLE_KEY"):
    os.environ[secret_name] = secrets.get_secret(secret_name)

## Helper functions

In [None]:
def to_feature(value, dtype):
    """Wrap a sequence of values in a ``tf.train.Feature`` of the matching list type.

    Args:
        value: Iterable of values to store (e.g. a flat numpy array or a
            one-element list). Scalars must be wrapped in a list by the caller.
        dtype: Name of the element type.
            - 'int', 'int64', 'uint32', 'uint64', 'bool', 'enum' -> ``Int64List``
            - 'float', 'double', 'float32', 'float64'             -> ``FloatList``
            - 'string', 'bytes'                                   -> ``BytesList``

    Returns:
        tf.train.Feature: The feature holding ``value``.

    Raises:
        ValueError: If ``dtype`` is not one of the names listed above.
    """
    # Tuples (not sets) on purpose: membership uses ==, so numpy dtype objects
    # such as np.dtype('float32') still match their string names.
    if dtype in ('int', 'int64', 'uint32', 'uint64', 'bool', 'enum'):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

    if dtype in ('float', 'double', 'float32', 'float64'):
        # FloatList stores 32-bit floats, so 'double'/'float64' input is narrowed.
        return tf.train.Feature(float_list=tf.train.FloatList(value=value))

    if dtype in ('string', 'bytes'):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

    raise ValueError(
        f"dtype '{dtype}' not recognized. You need to use a dtype compatible with TF protos. "
        "See: https://www.tensorflow.org/tutorials/load_data/tfrecord"
    )

def to_example(feature):
    """Bundle a ``{name: tf.train.Feature}`` mapping into a ``tf.train.Example``."""
    features = tf.train.Features(feature=feature)
    return tf.train.Example(features=features)

def serialize(example):
    """Encode a protobuf message (here a ``tf.train.Example``) to bytes for a TFRecord."""
    encoded = example.SerializeToString()
    return encoded

def deserialize(string):
    """Parse bytes produced by ``serialize`` back into a ``tf.train.Example``."""
    parse_example = tf.train.Example.FromString
    return parse_example(string)

## Initialize Kaggle Dataset Metadata

In [None]:
os.makedirs('/kaggle/dataset/', exist_ok=True)

# Edit these fields to target your own Kaggle dataset; the `kaggle datasets`
# CLI commands below read this file from the upload folder.
meta = {
    "id": "xhlulu/seti-tfrecords-train",
    "title": "SETI Train Split in TF Records",
    "isPrivate": False,
    "licenses": [{"name": "other"}],
}

metadata_path = '/kaggle/dataset/dataset-metadata.json'
with open(metadata_path, 'w') as fh:
    json.dump(meta, fh)

In [None]:
# # First time only:
# !touch /kaggle/dataset/dummy.txt
# !kaggle datasets create -p "/kaggle/dataset" --dir-mode zip
# !rm /kaggle/dataset/dummy.txt

## Generate the TFRecords

In [None]:
# One row per sample; write_record unpacks each row as (id, target).
labels_csv = '../input/seti-breakthrough-listen/train_labels.csv'
labels = pd.read_csv(labels_csv)

In [None]:
split = 'train'  # input subdirectory name and output filename prefix
n_jobs = 8       # parallelism hint; also sets the shard count below

# Cut the N label rows into about n_jobs * 10 contiguous shards of equal size
# (the last one may be shorter). `indices` holds the first row of each shard.
N = len(labels)
target_n_shards = n_jobs * 10
chunk_size = np.ceil(N / target_n_shards).astype(int)
indices = np.arange(0, N, chunk_size)

In [None]:
def write_record(arg):
    """Write one shard of the ``labels`` table to a TFRecord file.

    Args:
        arg: Tuple ``(chunk, start, chunk_size, split)``, packed as a single
            argument so the function can be mapped over a ``multiprocessing.Pool``.
            - chunk: shard number, used in the output filename.
            - start: index of the first ``labels`` row in this shard.
            - chunk_size: maximum number of rows in the shard.
            - split: input subdirectory and output filename prefix.

    Each row of the global ``labels`` DataFrame is unpacked as
    ``(sample_id, target)``. The array at
    ``../input/seti-breakthrough-listen/{split}/{sample_id[0]}/{sample_id}.npy``
    is stored flattened under ``"X"`` and the target under ``"y"`` in
    ``/kaggle/dataset/{split}-{chunk}.tfrecord``.

    NOTE(review): reads the module-level ``labels`` and helper functions, so
    this assumes worker processes inherit the notebook's globals (the 'fork'
    start method). Confirm before running on a platform with a different default.
    """
    chunk, start, chunk_size, split = arg

    record_path = f"/kaggle/dataset/{split}-{chunk}.tfrecord"
    # Slicing past the end is safe: the final shard is simply shorter.
    chunk_rows = labels.values[start:start + chunk_size]

    print("Starting to write", record_path)
    with tf.io.TFRecordWriter(record_path) as writer:
        # `sample_id` is deliberately distinct from `start` so the shard offset
        # is not clobbered by the loop.
        for sample_id, target in chunk_rows:
            # Sample files live in a subdirectory named after the id's first character.
            path = os.path.join(
                "../input/seti-breakthrough-listen", split, sample_id[0], sample_id + '.npy'
            )
            X = np.load(path)

            feature = {
                # ravel() gives the same 1-D values as flatten() without a
                # forced copy; readers must restore the original shape.
                "X": to_feature(X.ravel(), 'float'),
                "y": to_feature([target], 'int'),
            }
            writer.write(serialize(to_example(feature)))

    print("Finished to write", record_path)

In [None]:
# One task per shard: (shard number, first row, rows per shard, split).
args = [(shard, start, chunk_size, split) for shard, start in enumerate(indices)]

In [None]:
%%time
with Pool(4) as p:
    p.map(write_record, args)

## Upload dataset

In [None]:
# Publish the contents of /kaggle/dataset (the TFRecord shards plus
# dataset-metadata.json) as a new version of the dataset whose id is set in the
# metadata cell. Relies on KAGGLE_USERNAME / KAGGLE_KEY having been exported at
# the top of the notebook, and on the dataset already existing (see the
# "First time only" cell).
!kaggle datasets version -p "/kaggle/dataset" -m "Updated via notebook" --dir-mode zip

## Code for reading TF Records

For reference only; this cell is not executed in this notebook.

In [None]:
# Reference only (not executed here): how to read back the shards written above.
# "X" is stored as a flat float list, so FixedLenFeature must be given the
# original per-sample shape. (6, 273, 256) assumes every .npy has that shape;
# TODO confirm against the data. The transpose swaps the last two axes.
# NOTE(review): write_record produces up to n_jobs * 10 shards
# (train-0.tfrecord, train-1.tfrecord, ...), so range(4) below reads only the
# first four. A glob over "train-*.tfrecord" would pick up all of them.
# def parse_train_example(example):
#     feature_description = {
#         "X": tf.io.FixedLenFeature((6, 273, 256), tf.float32),
#         "y": tf.io.FixedLenFeature([], tf.int64)
#     }

#     example = tf.io.parse_single_example(example, feature_description)
#     X = tf.transpose(example['X'], (0, 2, 1))
    
#     return X, example['y']

# def parse_test_example(example):
#     feature_description = {
#         "X": tf.io.FixedLenFeature((6, 273, 256), tf.float32)
#     }

#     example = tf.io.parse_single_example(example, feature_description)
#     X = tf.transpose(example['X'], (0, 2, 1))
    
#     return X


# parsed_dataset = (
#     tf.data.TFRecordDataset([f"/kaggle/dataset/train-{chunk}.tfrecord" for chunk in range(4)])
#     .map(parse_train_example)
# )