In [1]:
import os
import argparse

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import *
from tensorflow.keras.utils import plot_model

from sklearn.model_selection import train_test_split

import tensorflow_io as tfio

os.environ["CUDA_VISIBLE_DEVICES"] = '0'

In [2]:
def ArgParser():
    parser = argparse.ArgumentParser()

    # RNN layer
    parser.add_argument("--units", dest="units", type=int, default=50)
    parser.add_argument("--n_layers", dest="n_layers", type=int, default=2)
    parser.add_argument("--dropout", dest="dropout", type=int, default=0.1)
    parser.add_argument("--bidirectional", dest="bidirectional", type=bool, default=True, choices=[True, False])

    # Segmentor
    parser.add_argument("--n_classes", dest="n_classes", type=int, default=61)
    parser.add_argument("--batch_size", dest="batch_size", type=int, default=2)
    parser.add_argument("--n_mels", dest="n_mels", type=int, default=64)
    parser.add_argument("--max_seg_size", dest="max_seg_size", type=int, default=100)
    parser.add_argument("--min_seg_size", dest="min_seg_size", type=int, default=0)

    # Dataset
    parser.add_argument("--main_dir", dest="main_dir", type=str, default="Datasets/TIMIT-dataset/tfrec_data")
    parser.add_argument("--buffer_size", dest="buffer_size", type=int, default=512)
    parser.add_argument("--test_size", dest="test_size", type=float, default=0.2)

    return parser.parse_known_args()[0]

args = ArgParser()
args

Namespace(batch_size=2, bidirectional=True, buffer_size=512, dropout=0.1, main_dir='Datasets/TIMIT-dataset/tfrec_data', max_seg_size=100, min_seg_size=0, n_classes=61, n_layers=2, n_mels=64, test_size=0.2, units=50)

In [3]:
class TIMITDataset():
    def __init__(self, args):
        self.files = [os.path.join(args.main_dir, f) for f in os.listdir(args.main_dir)]
        self.AUTOTUNE = tf.data.experimental.AUTOTUNE
        self.args = args
        self.train_files, self.test_files = train_test_split(
            self.files, test_size=args.test_size, shuffle=True)

    def decode_audio(self, audio):
        audio = tf.audio.decode_wav(audio)[0]
        return tf.squeeze(audio, axis=-1)

    def read_tfrecord(self, example):
        feature_description = {
            'spectrogram': tf.io.FixedLenFeature([], tf.string),
            'framewise_label': tf.io.FixedLenFeature([], tf.string),
            'binary_label': tf.io.FixedLenFeature([], tf.string),
            'filename': tf.io.FixedLenFeature([], tf.string)}
        
        example = tf.io.parse_single_example(example, feature_description)
        example['spectrogram'] = tf.io.parse_tensor(
            example['spectrogram'], out_type=tf.float32)
        example['framewise_label'] = tf.io.parse_tensor(
            example['framewise_label'], out_type=tf.int32)
        example['binary_label'] = tf.io.parse_tensor(
            example['binary_label'], out_type=tf.int32)
        return example


    def load_dataset(self, files):
        ignore_order = tf.data.Options()
        ignore_order.experimental_deterministic = False
        dataset = tf.data.TFRecordDataset(files)
        dataset = dataset.with_options(ignore_order)
        dataset = dataset.map(self.read_tfrecord, num_parallel_calls=self.AUTOTUNE)
        return dataset


    def SpecAugment(self, sample):
        spectrogram = sample['spectrogram']
        spectrogram = tfio.audio.freq_mask(spectrogram, param=10)
        spectrogram = tfio.audio.time_mask(spectrogram, param=10)
        sample['spectrogram'] = spectrogram
        return sample


    def train(self):
        dataset = self.load_dataset(self.train_files)
        dataset = dataset.map(self.SpecAugment, num_parallel_calls=self.AUTOTUNE)
        dataset = dataset.repeat()
        dataset = dataset.shuffle(self.args.buffer_size)
        dataset = dataset.batch(self.args.batch_size)
        dataset = dataset.prefetch(self.AUTOTUNE)
        return dataset


    def test(self):
        dataset = self.load_dataset(self.test_files)
        dataset = dataset.shuffle(self.args.buffer_size)
        dataset = dataset.batch(self.args.batch_size)
        dataset = dataset.cache()
        dataset = dataset.prefetch(self.AUTOTUNE)
        return dataset

dataset = TIMITDataset(args).train()
binary_label, filename, framewise_label, spectrogram = list(next(iter(dataset)).values())

print("spectrogram shape:", spectrogram.shape)
print("binary_label shape:", binary_label.shape)
print("framewise_label shape:", framewise_label.shape)

spectrogram shape: (2, 438, 64)
binary_label shape: (2, 438)
framewise_label shape: (2, 438)


In [10]:
class Segmentor(Model):
    def __init__(self, args):
        super().__init__(name="Segmentor")
        self.args = args
        self.rnn = self.rnn_block()
        self.scorer = self.scorer_block()
        self.classifier = self.classifier_block()
        self.bi_classifier = self.bi_classifier_block()

    def rnn_block(self):
        LSTM_layer = Bidirectional(LSTM(self.args.units, return_sequences=True))
        return Sequential(
            [LSTM_layer for _ in range(self.args.n_layers-1)]
            + [Bidirectional(LSTM(self.args.units, return_sequences=True))], 
            name="rnn_block")

    def scorer_block(self):
        return Sequential([
            PReLU(),
            Dense(100),
            PReLU(),
            Dense(1)], name="scorer")

    def classifier_block(self):
        return Sequential([
            PReLU(),
            Dense(self.args.n_classes * 2),
            PReLU(),
            Dense(self.args.n_classes)], name="classifier")

    def bi_classifier_block(self):
        return Sequential([
            PReLU(),
            Dense(self.args.n_classes * 2),
            PReLU(),
            Dense(2)], name="bi_classifier")

    def compute_phi(self, rnn_out):
        batch_size, seq_len, feat_dim = rnn_out.shape
        batch_size = self.args.batch_size

        rnn_cum = tf.math.cumsum(rnn_out, axis=1)
        output_shape = [batch_size, seq_len, seq_len, feat_dim]
        
        a = tf.repeat(rnn_cum, [1, seq_len, 1])
        b = tf.reshape(tf.repeat(rnn_cum, [1, 1, seq_len]), [batch_size, -1, feat_dim])
        c = tf.reshape(tf.math.subtract(a, b), output_shape)
        d = tf.reshape(tf.repeat(rnn_out, [1, 1, seq_len]), output_shape)
        e = tf.reshape(tf.repeat(rnn_out, [1, seq_len, 1]), output_shape)
        return tf.concat([c, d, e], axis=-1)

    def segment_search(self, scores):
        batch_size, seq_len = scores.shape[:-1]
        batch_size = self.args.batch_size

        best_scores = tf.zeros([batch_size, seq_len])
        segments = tf.zeros([batch_size, 1, 2], dtype=tf.int32)

        for i in range(1, seq_len):
            start_idx = max(0, i - self.args.max_seg_size)
            end_idx = i
            current_score = tf.zeros([batch_size, end_idx - start_idx])
            indices, updates = [], []

            for j in range(start_idx, end_idx):
                index = tf.constant([[k, (j - start_idx)] for k in range(batch_size)])
                update = (best_scores[:, j] + scores[:, j, i])
                tf.tensor_scatter_nd_update(current_score, index, update)

            best_score, best_index = tf.math.top_k(current_score, k=1)
            best_score = tf.squeeze(best_score, axis=1)
            best_index += start_idx
            best_indices = tf.constant([[m, i] for m in range(batch_size)])
            tf.tensor_scatter_nd_update(best_scores, best_indices, best_score)

            segment_indices, segment_updates = [], []
            for n in range(batch_size):
                segment_indices.append(tf.concat([tf.constant([n]), best_index[n]], axis=-1))
                segment_updates.append(tf.concat([best_index[n], tf.constant([i])], axis=-1))

            segment_indices = tf.stack(segment_indices)
            segment_updates = tf.stack(segment_updates)
            tf.tensor_scatter_nd_update(segments, segment_indices, segment_updates)

        return segments


    def compute_segmentation_score(self, scores, segments):
        out_scores = tf.zeros([scores.shape[0]])
        print(out_scores, segments)
        
        return out_scores


    def call(self, inputs):
        spectrogram, binary_label, framewise_label = inputs
        rnn_out = self.rnn(spectrogram)
        phi = self.compute_phi(rnn_out)
        scores = tf.squeeze(self.scorer(phi), axis=-1)
        segments = self.segment_search(scores)

        return {
            "classifier_out": self.classifier(rnn_out),
            "bi_classifier_out": self.bi_classifier(rnn_out),
            "segments": segments,
            "segmentation_scores": self.compute_segmentation_score(scores, segments)
        }


model = Segmentor(args)
model([Input([438, 64]), Input([438]), Input([438])])
model.summary()

Tensor("Segmentor/zeros_439:0", shape=(2,), dtype=float32) Tensor("Segmentor/zeros_1:0", shape=(2, 1, 2), dtype=int32)
Model: "Segmentor"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
rnn_block (Sequential)       (None, 438, 100)          106400    
_________________________________________________________________
scorer (Sequential)          (2, 438, 438, 1)          76767801  
_________________________________________________________________
classifier (Sequential)      (None, 438, 61)           117061    
_________________________________________________________________
bi_classifier (Sequential)   (None, 438, 2)            109804    
Total params: 77,101,066
Trainable params: 77,101,066
Non-trainable params: 0
_________________________________________________________________
