<a href="https://colab.research.google.com/github/wolfisberg/zhaw-ba-online/blob/main/tfrecord_writer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
import numpy as np
import librosa
import os
import math
import scipy
import pathlib


# Maximum number of .wav files whose records go into a single TFRecord shard.
_FILES_PER_SHARD = 600
# Names of the two dataset directories that main() converts.
_SPEECH_DATA_DIR = 'speech'
_NOISE_DATA_DIR = 'noise'
# Names of the per-split sub-directories looked up inside each dataset directory.
_VALIDATION_DATA_DIR = 'cv'
_TRAINING_DATA_DIR = 'tr'
_TEST_DATA_DIR = 'tt'
# Prefix of every generated shard file name.
_SHARD_BASE_NAME = 'shard'
# Sample rate the audio is resampled to when a file's own rate differs
# (presumably in Hz, as used by librosa).
_SAMPLE_RATE = 16000


def main():
    """Write TFRecord shards for the speech and the noise dataset.

    Only the speech dataset is converted with reference (f0) data attached.
    """
    print('Starting to create tfrecords.')
    jobs = (
        (_SPEECH_DATA_DIR, True),
        (_NOISE_DATA_DIR, False),
    )
    for dataset_dir, with_ref_data in jobs:
        create_tfrecords_from_directory(dataset_dir, add_ref_data=with_ref_data)


def create_tfrecords_from_directory(dir, add_ref_data=False):
    """Convert the .wav files of one dataset into sharded TFRecord files.

    For each of the training, test and validation sub-directories of
    ``<data>/downsampled/<dir>`` the .wav files are read, resampled to
    ``_SAMPLE_RATE`` if necessary, converted to 16-bit PCM and written in
    groups of at most ``_FILES_PER_SHARD`` files to
    ``<data>/tfrecords/<dir>/shard_<dir>_<split>_<NNNN>.tfrecord``.

    Args:
        dir: Name of the dataset directory (e.g. ``'speech'``). The name
            shadows the builtin but is kept because it is part of the
            signature callers may use by keyword.
        add_ref_data: When true, every record also carries the ``pitch`` and
            ``pitch_confidence`` columns of the matching ``.f0`` file; wav
            files without exactly one usable ``.f0`` file are skipped.
    """
    print(f'Reading from {dir}')

    data_root = os.path.join('/', 'content', 'drive', 'MyDrive', 'fs2021-ba', 'data')
    output_dir = os.path.join(data_root, 'tfrecords', dir)

    for data_dir in [_TRAINING_DATA_DIR, _TEST_DATA_DIR, _VALIDATION_DATA_DIR]:
        print(f'Data dir {data_dir}')

        data_dir_path = os.path.join(data_root, 'downsampled', dir, data_dir)
        wav_files = librosa.util.find_files(directory=data_dir_path, ext=['wav'], recurse=False, case_sensitive=False)

        if not wav_files:
            print(f'No .wav files found in directory {dir}/{data_dir}')
            continue

        # TFRecordWriter does not create missing parent directories.
        os.makedirs(output_dir, exist_ok=True)

        # List the reference files once per split instead of walking the
        # directory tree again for every single wav file.
        ref_files = list(pathlib.Path(data_dir_path).rglob('*.[fF]0')) if add_ref_data else []

        number_of_shards = math.ceil(len(wav_files) / _FILES_PER_SHARD)
        for shard_number in range(number_of_shards):
            print(f'Shard number {shard_number}')

            shard_name = f'{_SHARD_BASE_NAME}_{dir}_{data_dir}_{shard_number:04d}.tfrecord'
            shard_path = os.path.join(output_dir, shard_name)

            lower_index = shard_number * _FILES_PER_SHARD
            upper_index = min(lower_index + _FILES_PER_SHARD, len(wav_files))

            with tf.io.TFRecordWriter(shard_path) as out:
                for file_index in range(lower_index, upper_index):
                    print(f'File index {file_index}')

                    example = _build_example(wav_files[file_index], ref_files, add_ref_data)
                    if example is None:
                        # Reference data was requested but is not usable.
                        continue

                    out.write(example.SerializeToString())
            print(f'Shard [ {shard_name} ] successfully written to disk.')


def _build_example(file_path, ref_files, add_ref_data):
    """Build the tf.train.Example for one wav file, or return None to skip it."""
    pitch_features = {}
    if add_ref_data:
        # Resolve the reference first so that no audio is decoded for files
        # that are going to be skipped anyway.
        reference = _load_pitch_reference(file_path, ref_files)
        if reference is None:
            return None
        pitch, pitch_confidence = reference
        pitch_features = {
            'pitch': _float_feature(pitch),
            'pitch_confidence': _float_feature(pitch_confidence),
        }

    y, sr = librosa.load(file_path, sr=None)

    if sr != _SAMPLE_RATE:
        y = librosa.to_mono(y)
        y = librosa.resample(y=y, orig_sr=sr, target_sr=_SAMPLE_RATE, res_type='kaiser_best', scale=True)
        # The stored rate must describe the samples that are actually written.
        sr = _SAMPLE_RATE

    y = _to_int16_pcm(y)

    feature = {
        'data': _bytes_feature(y.tobytes()),
        'data_sampling_rate': _int64_feature([sr]),
        'data_num_channels': _int64_feature([1]),
        'data_width': _int64_feature([len(y)]),
    }
    feature.update(pitch_features)
    return tf.train.Example(features=tf.train.Features(feature=feature))


def _load_pitch_reference(file_path, ref_files):
    """Return (pitch, pitch_confidence) columns for a wav file, or None if unavailable."""
    # The first four characters of the wav name are not part of the
    # reference file name.
    file_name_stem = os.path.splitext(os.path.basename(file_path))[0][4:]
    # Every entry ends in '.f0' / '.F0'; compare against the part before it.
    ref_matches = [ref for ref in ref_files if file_name_stem in ref.name[:-3]]
    if len(ref_matches) != 1:
        print(f'Cannot find ref (f0) file for [ {file_path} ], skipping...')
        return None

    # genfromtxt returns a 1-D array for a single-row file; force 2-D so the
    # column selection below always yields sequences.
    ref_data = np.atleast_2d(np.genfromtxt(ref_matches[0], delimiter=' '))
    if ref_data.shape[1] < 4:
        print(f'Ref (f0) file for [ {file_path} ] has fewer than 4 columns, skipping...')
        return None

    return ref_data[:, 2], ref_data[:, 3]


def _to_int16_pcm(y):
    """Peak-normalise audio to the int16 range; int16 input is returned unchanged."""
    if y.dtype == np.dtype(np.int16):
        return y

    peak = np.max(np.abs(y)) if y.size else 0
    if not peak > 0:
        # Empty, silent or NaN-peaked audio: dividing by the peak would
        # produce NaNs that turn into garbage when cast to int16.
        return np.zeros(y.shape, dtype=np.int16)

    return (y / peak * np.iinfo(np.int16).max).astype(np.int16)



def _int64_feature(value):
    """Wrap a sequence of integers in a ``tf.train.Feature`` backed by an ``Int64List``."""
    int64_list = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=int64_list)


def _bytes_feature(value):
    """Wrap a single bytes object in a ``tf.train.Feature`` backed by a one-element ``BytesList``."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)


def _float_feature(value):
    """Wrap a sequence of floats in a ``tf.train.Feature`` backed by a ``FloatList``."""
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)


# Run the conversion as soon as this cell/module is executed.
main()


Starting to create tfrecords.
Reading from speech
Data dir tr
Shard number 0
File index 0
File index 1
File index 2
File index 3
File index 4
File index 5
File index 6
File index 7
File index 8
File index 9
File index 10
File index 11
File index 12
File index 13
File index 14
File index 15
File index 16
File index 17
File index 18
File index 19
File index 20
File index 21
File index 22
File index 23
File index 24
File index 25
File index 26
File index 27
File index 28
File index 29
File index 30
File index 31
File index 32
File index 33
File index 34
File index 35
File index 36
File index 37
File index 38
File index 39
File index 40
File index 41
File index 42
File index 43
File index 44
File index 45
File index 46
File index 47
File index 48
File index 49
File index 50
File index 51
File index 52
File index 53
File index 54
File index 55
File index 56
File index 57
File index 58
File index 59
File index 60
File index 61
File index 62
File index 63
File index 64
File index 65
File inde