In [44]:
import os
import pandas as pd
from tensorflow import keras
import tensorflow as tf
import numpy as np
from jiwer import wer
import pickle


class LabelEncoder:
    """Character-level codec between transcription text and integer ids.

    Two Keras ``StringLookup`` layers are held: ``char_to_num`` for encoding
    and ``num_to_char`` (its inverse) for decoding. Both use the empty string
    as the out-of-vocabulary token.
    """

    def __init__(self):
        vocabulary = list("abcdefghijklmnopqrstuvwxyz'?! ")
        # Forward lookup: character -> integer id.
        self.char_to_num = keras.layers.StringLookup(
            vocabulary=vocabulary,
            oov_token="",
        )
        # Reverse lookup, built from the forward layer's own vocabulary so the
        # two directions always agree on the id assignment.
        self.num_to_char = keras.layers.StringLookup(
            vocabulary=self.char_to_num.get_vocabulary(),
            oov_token="",
            invert=True,
        )

    def encode(self, label):
        """Lower-case ``label``, split it into UTF-8 characters and map them to ids."""
        lowered = tf.strings.lower(label)
        chars = tf.strings.unicode_split(lowered, input_encoding="UTF-8")
        return self.char_to_num(chars)

    def decode(self, nums):
        """Map ids back to characters and join them into a single string tensor."""
        chars = self.num_to_char(nums)
        return tf.strings.reduce_join(chars)


def wav_to_audio(filepath):
    """Read a WAV file and return its samples as a float32 tensor.

    The sample rate reported by the decoder is discarded, and the last axis of
    the decoded audio is squeezed away (presumably the channel axis of
    single-channel audio; verify against the recordings being fed in).
    """
    raw_bytes = tf.io.read_file(filepath)
    samples, _sample_rate = tf.audio.decode_wav(raw_bytes)
    squeezed = tf.squeeze(samples, axis=-1)
    return tf.cast(squeezed, tf.float32)


def audio_to_spectrogram(audio, frame_length, frame_step, fft_length):
    """Turn an audio tensor into a normalised, square-root-compressed spectrogram.

    Args:
        audio: audio samples as accepted by ``tf.signal.stft``.
        frame_length: STFT window length, forwarded to ``tf.signal.stft``.
        frame_step: STFT hop size, forwarded to ``tf.signal.stft``.
        fft_length: FFT size, forwarded to ``tf.signal.stft``.

    Returns:
        The STFT magnitude raised to the power 0.5, standardised along axis 1.
    """
    stft = tf.signal.stft(
        audio,
        frame_length=frame_length,
        frame_step=frame_step,
        fft_length=fft_length,
    )

    # Keep only the magnitude and compress its dynamic range with a square root.
    magnitude = tf.abs(stft)
    compressed = tf.math.pow(magnitude, 0.5)

    # Standardise along axis 1; the small epsilon guards against a zero std-dev.
    mean = tf.math.reduce_mean(compressed, 1, keepdims=True)
    std = tf.math.reduce_std(compressed, 1, keepdims=True)
    return (compressed - mean) / (std + 1e-10)


def wav_to_features(
        wav_file,
        wavs_path,
        frame_length,
        frame_step,
        fft_length,
        add_suffix=True,
):
    """Load one WAV file and convert it to spectrogram features.

    The file location is ``wavs_path + wav_file``, with ``".wav"`` appended
    when ``add_suffix`` is true. Plain ``+`` concatenation is used on purpose:
    inside a ``tf.data`` pipeline ``wav_file`` is a string tensor.

    The STFT arguments are forwarded unchanged to ``audio_to_spectrogram``.
    """
    location = wavs_path + wav_file
    if add_suffix:
        location = location + ".wav"

    return audio_to_spectrogram(
        audio=wav_to_audio(location),
        frame_length=frame_length,
        frame_step=frame_step,
        fft_length=fft_length,
    )


class FeatureEncoder:
    """Bundles the path and STFT settings needed to featurise WAV files."""

    def __init__(self, wavs_path, frame_length, frame_step, fft_length, add_suffix):
        """Store the settings; nothing is loaded or computed here."""
        # Where the files live and whether ".wav" must be appended to names.
        self.wavs_path, self.add_suffix = wavs_path, add_suffix
        # STFT configuration.
        self.frame_length, self.frame_step, self.fft_length = (
            frame_length,
            frame_step,
            fft_length,
        )

    def features(self, wav_file):
        """Return the spectrogram features of ``wav_file`` using the stored settings."""
        settings = dict(
            wavs_path=self.wavs_path,
            frame_length=self.frame_length,
            frame_step=self.frame_step,
            fft_length=self.fft_length,
            add_suffix=self.add_suffix,
        )
        return wav_to_features(wav_file=wav_file, **settings)


class AudioContext:
    """Process-wide holder for the label encoder and the feature encoder.

    The class is used as a singleton: it is configured once through
    ``AudioContext.set(...)`` and read through ``AudioContext.get()``.
    Calling the constructor directly raises ``ValueError``.
    """

    # The currently configured context, or None until set() succeeds.
    _instance = None
    # Guard flag: True only while set() is constructing the instance.
    _allowed = False

    def __init__(
            self,
            wavs_path,
            frame_length,
            frame_step,
            fft_length,
            add_suffix,
    ):
        """Build the encoders. Only reachable from ``set()``.

        Raises:
            ValueError: when invoked outside of ``AudioContext.set``.
        """
        if not AudioContext._allowed:
            raise ValueError("Cannot instantiate AudioContext")

        self.label_encoder = LabelEncoder()
        self.feature_encoder = FeatureEncoder(
            wavs_path=wavs_path,
            frame_length=frame_length,
            frame_step=frame_step,
            fft_length=fft_length,
            add_suffix=add_suffix,
        )

    @staticmethod
    def set(
            wavs_path,
            frame_length,
            frame_step,
            fft_length,
            add_suffix,
    ):
        """Create (or replace) the shared context with the given settings.

        If construction fails, the exception propagates and any previously
        configured instance stays in place.
        """
        AudioContext._allowed = True
        try:
            AudioContext._instance = AudioContext(
                wavs_path=wavs_path,
                frame_length=frame_length,
                frame_step=frame_step,
                fft_length=fft_length,
                add_suffix=add_suffix,
            )
        finally:
            # Always close the guard again, even when construction raised;
            # otherwise direct instantiation would stay permitted forever.
            AudioContext._allowed = False

    @staticmethod
    def get():
        """Return the shared context.

        Raises:
            ValueError: when ``set()`` has not been called successfully yet.
        """
        if AudioContext._instance is None:
            raise ValueError("AudioContext not set yet")

        return AudioContext._instance


def encode_single_sample(wav_file, label):
    """Convert one (file name, transcription) pair into (features, label ids).

    Both encoders are taken from the shared ``AudioContext``, which must have
    been configured beforehand.
    """
    context = AudioContext.get()
    spectrogram = context.feature_encoder.features(wav_file)
    label_ids = context.label_encoder.encode(label)
    return spectrogram, label_ids


def create_dataset(data_df, batch_size):
    """Build a padded, prefetched ``tf.data.Dataset`` from a DataFrame.

    Args:
        data_df: frame with ``file_name`` and ``normalized_transcription`` columns.
        batch_size: number of samples per padded batch.

    Returns:
        A dataset yielding (features, encoded label) batches.
    """
    file_names = list(data_df["file_name"])
    transcriptions = list(data_df["normalized_transcription"])

    samples = tf.data.Dataset.from_tensor_slices((file_names, transcriptions))
    encoded = samples.map(
        encode_single_sample,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    # padded_batch pads every sample up to the longest one in its batch.
    batched = encoded.padded_batch(batch_size)
    return batched.prefetch(buffer_size=tf.data.AUTOTUNE)


def CTCLoss(y_true, y_pred):
    """CTC loss for a batch, computed with ``keras.backend.ctc_batch_cost``.

    Every sample is given the same input length (dimension 1 of ``y_pred``)
    and the same label length (dimension 1 of ``y_true``).
    """
    batch_size = tf.cast(tf.shape(y_true)[0], dtype="int64")
    time_steps = tf.cast(tf.shape(y_pred)[1], dtype="int64")
    label_steps = tf.cast(tf.shape(y_true)[1], dtype="int64")

    # Column vector of ones, used to broadcast the scalar lengths per sample.
    per_sample = tf.ones(shape=(batch_size, 1), dtype="int64")
    input_length = time_steps * per_sample
    # NOTE(review): this is the full width of y_true, so padding added by
    # padded_batch is counted as part of each label; confirm that is intended.
    label_length = label_steps * per_sample

    return keras.backend.ctc_batch_cost(y_true, y_pred, input_length, label_length)


def decode_batch_predictions(pred):
    """Greedy-CTC-decode a batch of model outputs into a list of Python strings."""
    batch_size, time_steps = pred.shape[0], pred.shape[1]
    # Every sample is decoded over the full second dimension of ``pred``.
    input_len = np.ones(batch_size) * time_steps

    # Greedy search; beam search is the alternative for harder tasks.
    decoded = keras.backend.ctc_decode(
        pred,
        input_length=input_len,
        greedy=True,
    )[0][0]

    # Map each id sequence back to text.
    lencoder = AudioContext.get().label_encoder
    return [lencoder.decode(sequence).numpy().decode("utf-8") for sequence in decoded]


def accuracy(labels, predictions):
    """Average position-wise character accuracy over label/prediction pairs.

    For each pair, characters are compared position by position and the number
    of matches is divided by the length of the longer string. The per-pair
    scores are summed and divided by ``len(labels)``.

    Args:
        labels: sequence of reference strings.
        predictions: sequence of predicted strings, aligned with ``labels``.

    Returns:
        The mean score as a float. A pair of two empty strings scores 1.0
        (nothing to get wrong), and an empty ``labels`` yields 0.0 instead of
        raising ``ZeroDivisionError``.
    """

    def _pair_accuracy(label, prediction):
        longest = max(len(label), len(prediction))
        if longest == 0:
            # Both strings are empty: identical, and avoids dividing by zero.
            return 1.0
        # zip stops at the shorter string, i.e. only overlapping positions count.
        matches = sum(1 for expected, got in zip(label, prediction) if expected == got)
        return matches / longest

    if len(labels) == 0:
        return 0.0

    total = sum(
        _pair_accuracy(label, prediction)
        for label, prediction in zip(labels, predictions)
    )
    return total / len(labels)


def abs_accuracy(labels, predictions):
    """Fraction of predictions that equal their label exactly.

    Args:
        labels: sequence of reference strings.
        predictions: sequence of predicted strings, aligned with ``labels``.

    Returns:
        Number of exact matches divided by ``len(labels)``; 0.0 when
        ``labels`` is empty (instead of raising ``ZeroDivisionError``).
    """
    if len(labels) == 0:
        return 0.0

    exact_matches = sum(
        1 for label, prediction in zip(labels, predictions) if label == prediction
    )
    return exact_matches / len(labels)


class ModelEvaluator(keras.callbacks.Callback):
    """Keras callback that scores a model on a dataset after every epoch.

    It can also be used stand-alone by calling ``do_prediction()`` directly.
    Results of each epoch are appended to ``self.history`` as
    ``(wer_score, res_acc, abs_acc, pred_df)`` tuples.
    """

    def __init__(self, dataset, model):
        """Remember the dataset to evaluate on and the model to evaluate."""
        super().__init__()
        self.dataset = dataset
        self.model = model
        self.history = []

    @staticmethod
    def print(wer_score, res_acc, abs_acc, pred_df):
        """Print the three metrics followed by up to five random predictions."""
        print("-" * 100)
        print(f"Word Error Rate: {wer_score:.4f}")
        print("-" * 100)
        print(f"Accuracy: {res_acc:.4f}")
        print("-" * 100)
        print(f"Abs accuracy: {abs_acc:.4f}")
        print("-" * 100)
        # DataFrame.sample raises when asked for more rows than exist, so cap
        # the sample size for small evaluation sets.
        print(pred_df.sample(n=min(5, len(pred_df))))

    def do_prediction(self):
        """Run the model over the whole dataset and compute the metrics.

        Returns:
            ``(wer_score, res_acc, abs_acc, pred_df)`` where ``pred_df`` has a
            ``Label`` and a ``predictions`` column with the decoded strings.
        """
        predictions = []
        targets = []
        lencoder = AudioContext.get().label_encoder
        for batch in self.dataset:
            X, y = batch
            batch_predictions = self.model.predict(X)
            batch_predictions = decode_batch_predictions(batch_predictions)
            predictions.extend(batch_predictions)
            for label in y:
                # Decode the encoded reference label back into text.
                label = lencoder.decode(label).numpy().decode("utf-8")
                targets.append(label)

        wer_score = wer(targets, predictions)
        res_acc = accuracy(targets, predictions)
        abs_acc = abs_accuracy(targets, predictions)
        pred_df = pd.DataFrame(data={"Label": targets, "predictions": predictions})

        return wer_score, res_acc, abs_acc, pred_df

    def on_epoch_end(self, epoch: int, logs=None):
        """Evaluate, store the result in ``history`` and print it."""
        wer_score, res_acc, abs_acc, pred_df = self.do_prediction()
        self.history.append((wer_score, res_acc, abs_acc, pred_df))
        self.print(wer_score, res_acc, abs_acc, pred_df)


def evaluate(model, test_dataset):
    """Score ``model`` on ``test_dataset``, print the report and return the metrics.

    Returns:
        ``(wer_score, res_acc, abs_acc, pred_df)`` as produced by
        ``ModelEvaluator.do_prediction``.
    """
    print("evaluate model against test dataset")
    evaluator = ModelEvaluator(dataset=test_dataset, model=model)
    results = evaluator.do_prediction()
    evaluator.print(*results)
    return results


def write_pickle(data, file, ensure_exist=True):
    """Pickle ``data`` to ``file`` using the highest available protocol.

    Args:
        data: any picklable object.
        file: destination path.
        ensure_exist: when true, create the parent directory of ``file``
            (if the path has one) before writing.
    """
    parent = os.path.dirname(file) if ensure_exist else ""
    if parent:
        os.makedirs(parent, exist_ok=True)

    with open(file, 'wb') as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)


def read_pickle(file):
    """Load and return the object pickled in ``file``.

    WARNING: unpickling can execute arbitrary code; only use this on files
    that come from a trusted source.
    """
    with open(file, 'rb') as handle:
        return pickle.load(handle)


def load_model(model_name, save_path):
    """Load a saved Keras model together with its pickled validation results.

    Args:
        model_name: name of the saved model inside ``save_path``.
        save_path: directory holding the model and ``<model_name>_validation.pkl``.

    Returns:
        ``(model, model_res)``: the Keras model (loaded with ``CTCLoss``
        registered as a custom object) and the unpickled results.
    """
    model_location = os.path.join(save_path, model_name)
    results_location = os.path.join(save_path, f"{model_name}_validation.pkl")

    model = keras.models.load_model(
        model_location,
        custom_objects={"CTCLoss": CTCLoss},
    )
    model_res = read_pickle(results_location)
    return model, model_res


# STFT settings shared by every feature extraction in this notebook.
frame_length, frame_step, fft_length = 256, 160, 384

# File names are passed to the pipeline as complete paths, so there is no
# directory prefix and no ".wav" suffix to add.
AudioContext.set(
    wavs_path="",
    frame_length=frame_length,
    frame_step=frame_step,
    fft_length=fft_length,
    add_suffix=False,
)

# Both models live in the same directory.
save_path = "C:/data/models"

model_name = "model005"
print("Loading Model")
model005, model_res = load_model(model_name, save_path)

# NOTE: model_res is overwritten here; only the second model's results remain.
model_name = "model005lj_half"
print("Loading Model")
model005lj, model_res = load_model(model_name, save_path)

Loading Model
Loading Model


In [49]:
import os
import queue
import sys
import threading
import uuid

import numpy as np
import sounddevice as sd
import soundfile as sf
from playsound import playsound
from scipy.io.wavfile import write


class Recorder:
    """Record microphone audio to a WAV file, play it back and transcribe it.

    ``record()`` blocks until ``stop()`` is called, so it is meant to run in
    its own thread. Audio chunks travel from the sounddevice callback to the
    writer loop through ``self.q``.
    """

    # How long the writer loop waits for a chunk before re-checking for stop.
    _POLL_SECONDS = 0.1
    # Upper bound on how long stop() waits for record() to finish.
    _STOP_TIMEOUT_SECONDS = 5.0

    def __init__(self, model):
        """Create an idle recorder that will transcribe with ``model``."""
        self.filepath = None
        # NOTE(review): wav_to_audio discards the sample rate of the file, so
        # the model sees audio at this rate — confirm it matches the training data.
        self.SAMPLE_RATE = 44100
        self.CHANNELS = 1 # our decoder only handles 1 channel right now
        self.sound_file = None
        self.q = queue.Queue()
        self.model = model
        # Set by stop() to ask the writer loop in record() to finish.
        self._stop_requested = threading.Event()
        # Held by record() for the whole duration of a recording session.
        self._session_lock = threading.Lock()

    def _discard_pending_chunks(self):
        """Drop chunks left in the queue by an earlier recording session."""
        while True:
            try:
                self.q.get_nowait()
            except queue.Empty:
                return

    def record(self):
        """Record from the default input device until ``stop()`` is called.

        A new uniquely named WAV file is created for every session and its
        path is stored in ``self.filepath``. Failures are reported on stderr
        instead of being raised, because this normally runs in a worker thread.
        """
        if not self._session_lock.acquire(blocking=False):
            print("Recording already in progress")
            return

        try:
            self._stop_requested.clear()
            self._discard_pending_chunks()
            self.filepath = f"{uuid.uuid1()}_recording.wav"
            with sf.SoundFile(self.filepath,
                              mode='x',
                              samplerate=self.SAMPLE_RATE,
                              channels=self.CHANNELS,
                              subtype=None) as file:
                self.sound_file = file

                with sd.InputStream(samplerate=self.SAMPLE_RATE,
                                    channels=self.CHANNELS,
                                    callback=self.callback):
                    print("Recording started")
                    while not self._stop_requested.is_set():
                        try:
                            chunk = self.q.get(timeout=self._POLL_SECONDS)
                        except queue.Empty:
                            # No audio yet; loop to re-check the stop flag.
                            continue
                        file.write(chunk)

                # The stream is closed now; keep what was captured just
                # before the stop request.
                while True:
                    try:
                        file.write(self.q.get_nowait())
                    except queue.Empty:
                        break
        except Exception as e:
            print(f"Recording failed: {e}", file=sys.stderr)
        finally:
            self.sound_file = None
            print("Recording stopped")
            self._session_lock.release()

    def callback(self, indata, frames, time, status):
        """sounddevice stream callback: queue a copy of the incoming chunk."""
        if status:
            print(status, file=sys.stderr)
        # Copy because the buffer passed in is only valid during the callback.
        self.q.put(indata.copy())

    def stop(self):
        """Ask ``record()`` to finish and wait until the file has been closed.

        Safe to call when no recording is running.
        """
        self._stop_requested.set()
        # record() holds the lock while a session is active, so acquiring it
        # here means the session has ended and the file is closed.
        if self._session_lock.acquire(timeout=self._STOP_TIMEOUT_SECONDS):
            self._session_lock.release()
        else:
            print("Recording did not stop in time", file=sys.stderr)

    def play(self):
        """Play the most recent recording, if there is one."""
        print("Playing recording")
        if self.filepath is not None and os.path.exists(self.filepath):
            playsound(self.filepath)
            print("Play recording complete")
        else:
            print("No recording found")

    def predict(self):
        """Transcribe the most recent recording with the model and print the text."""
        if self.filepath is None or not os.path.exists(self.filepath):
            print("No recording to transcribe")
            return

        # The transcription column is a placeholder; only the prediction is used.
        df = pd.DataFrame(data={"file_name": [self.filepath], "normalized_transcription": ["a"]})

        test_dataset = create_dataset(data_df=df, batch_size=2)
        test_evaluator = ModelEvaluator(dataset=test_dataset, model=self.model)
        _wer_score, _res_acc, _abs_acc, pred_df = test_evaluator.do_prediction()
        res = pred_df.loc[0, "predictions"]
        print(f"Prediction: {res}")
        
# One recorder per loaded model, so each can be tried on fresh recordings.
rec005, rec005lj = Recorder(model=model005), Recorder(model=model005lj)

In [50]:
from ipywidgets import GridspecLayout, Button, Layout, ButtonStyle
import threading


def click_record(btn):
    """Button handler: run ``rec005.record`` in a new thread so the UI stays usable."""
    threading.Thread(target=rec005.record).start()
    
def click_stop(btn):
    """Button handler: stop the recording running on ``rec005``."""
    rec005.stop()
    
def click_play(btn):
    """Button handler: play back the latest ``rec005`` recording, then print its transcription."""
    rec005.play()
    rec005.predict()

def create_player005():
    """Build a one-row grid with Play, Stop and Record buttons wired to ``rec005``."""
    grid = GridspecLayout(1, 3)

    # (caption, click handler) for each column, left to right.
    controls = (
        ("Play", click_play),
        ("Stop", click_stop),
        ("Record", click_record),
    )
    for column, (caption, handler) in enumerate(controls):
        button = Button(
            description=caption,
            layout=Layout(height='auto', width='auto'),
        )
        button.on_click(handler)
        grid[0, column] = button

    return grid

In [51]:
from ipywidgets import GridspecLayout, Button, Layout, ButtonStyle
import threading


def click_record_lj(btn):
    """Button handler: run ``rec005lj.record`` in a new thread so the UI stays usable."""
    threading.Thread(target=rec005lj.record).start()
    
def click_stop_lj(btn):
    """Button handler: stop the recording running on ``rec005lj``."""
    rec005lj.stop()
    
def click_play_lj(btn):
    """Button handler: play back the latest ``rec005lj`` recording, then print its transcription."""
    rec005lj.play()
    rec005lj.predict()

def create_player005lj():
    """Build a one-row grid with Play, Stop and Record buttons wired to ``rec005lj``."""
    grid = GridspecLayout(1, 3)

    # (caption, click handler) for each column, left to right.
    controls = (
        ("Play", click_play_lj),
        ("Stop", click_stop_lj),
        ("Record", click_record_lj),
    )
    for column, (caption, handler) in enumerate(controls):
        button = Button(
            description=caption,
            layout=Layout(height='auto', width='auto'),
        )
        button.on_click(handler)
        grid[0, column] = button

    return grid

In [52]:
# Build the control panel wired to rec005; the bare name on the last line renders it.
player = create_player005()
player

GridspecLayout(children=(Button(description='Play', layout=Layout(grid_area='widget001', height='auto', width=…

Recording started
Recording stopped
Playing recording
Play recording complete
Prediction: alexa
Recording started
Recording stopped
Playing recording
Play recording complete
Prediction: alexa


In [53]:
# Build the control panel wired to rec005lj; the bare name on the last line renders it.
player_lj = create_player005lj()
player_lj

GridspecLayout(children=(Button(description='Play', layout=Layout(grid_area='widget001', height='auto', width=…

Recording started
Recording stopped
Playing recording
Play recording complete
Prediction: f a
Recording started
Recording stopped
Playing recording
Play recording complete
Prediction: f ear
Recording started
Recording stopped
Playing recording
Play recording complete
Prediction: f  h
