Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GRU model #53

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ import random
from audio import read_mfcc
from batcher import sample_from_mfcc
from constants import SAMPLE_RATE, NUM_FRAMES
from conv_models import DeepSpeakerModel
from models import ResCNNModel
from test import batch_cosine_similarity

# Reproducible results.
Expand Down
15 changes: 1 addition & 14 deletions batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from audio import pad_mfcc, Audio
from constants import NUM_FRAMES, NUM_FBANKS
from conv_models import DeepSpeakerModel
from models import DeepSpeakerModel
from utils import ensures_dir, load_pickle, load_npy, train_test_sp_to_utt

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -489,16 +489,3 @@ def get_speaker_verification_data(self, positive_speaker, num_different_speakers
data = [anchor, positive]
data.extend([self._select_speaker_data(n) for n in negative_speakers])
return np.vstack(data)


if __name__ == '__main__':
np.random.seed(123)
ltb = LazyTripletBatcher(working_dir='/Users/premy/deep-speaker/',
max_length=NUM_FRAMES,
model=DeepSpeakerModel())
for i in range(1000):
print(i)
start = time()
ltb.get_batch_train(batch_size=9)
print(time() - start)
# ltb.get_batch(batch_size=96)
20 changes: 9 additions & 11 deletions cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
from audio import Audio
from batcher import KerasFormatConverter
from constants import SAMPLE_RATE, NUM_FRAMES
from models import GRU_NAME, RES_CNN_NAME
from test import test
from train import start_training
from utils import ClickType as Ct, ensures_dir
from utils import init_pandas

logger = logging.getLogger(__name__)

VERSION = '3.0a'
VERSION = '3.0b'


@click.group()
Expand Down Expand Up @@ -53,8 +54,9 @@ def build_keras_inputs(working_dir, counts_per_speaker):

@cli.command('test-model', short_help='Test a Keras model.')
@click.option('--working_dir', required=True, type=Ct.input_dir())
@click.option('--model_name', required=True, type=click.Choice([RES_CNN_NAME, GRU_NAME]))
@click.option('--checkpoint_file', required=True, type=Ct.input_file())
def test_model(working_dir, checkpoint_file=None):
def test_model(working_dir, model_name, checkpoint_file):
# export CUDA_VISIBLE_DEVICES=0; python cli.py test-model
# --working_dir /home/philippe/ds-test/triplet-training/
# --checkpoint_file ../ds-test/checkpoints-softmax/ResCNN_checkpoint_102.h5
Expand All @@ -64,20 +66,16 @@ def test_model(working_dir, checkpoint_file=None):
# --working_dir /home/philippe/ds-test/triplet-training/
# --checkpoint_file ../ds-test/checkpoints-triplets/ResCNN_checkpoint_175.h5
# f-measure = 0.849, true positive rate = 0.798, accuracy = 0.997, equal error rate = 0.025
test(working_dir, checkpoint_file)
test(working_dir, model_name, checkpoint_file)


@cli.command('train-model', short_help='Train a Keras model.')
@click.option('--working_dir', required=True, type=Ct.input_dir())
@click.option('--model_name', required=True, type=click.Choice([RES_CNN_NAME, GRU_NAME]))
@click.option('--pre_training_phase/--no_pre_training_phase', default=False, show_default=True)
def train_model(working_dir, pre_training_phase):
def train_model(working_dir, model_name, pre_training_phase):
# PRE TRAINING

# commit a5030dd7a1b53cd11d5ab7832fa2d43f2093a464
# Merge: a11d13e b30e64e
# Author: Philippe Remy <premy.enseirb@gmail.com>
# Date: Fri Apr 10 10:37:59 2020 +0900
# LibriSpeech train-clean-data360 (600, 100). 0.985 on test set (enough for pre-training).
# LibriSpeech train-clean-data360 (600, 100). 0.991 on test set (enough for pre-training).

# TRIPLET TRAINING
# [...]
Expand All @@ -89,7 +87,7 @@ def train_model(working_dir, pre_training_phase):
# 2000/2000 [==============================] - 927s 464ms/step - loss: 0.0075 - val_loss: 0.0059
# Epoch 178/1000
# 2000/2000 [==============================] - 948s 474ms/step - loss: 0.0073 - val_loss: 0.0058
start_training(working_dir, pre_training_phase)
start_training(working_dir, model_name, pre_training_phase)


if __name__ == '__main__':
Expand Down
4 changes: 2 additions & 2 deletions deep-speaker
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ build_model_inputs)
train_softmax)
# Pre-training (0.92k speakers).
echo "[train_softmax] selected."
python cli.py train-model --working_dir "${PRE_TRAINING_WORKING_DIR}" --pre_training_phase
python cli.py train-model --model_name "$2" --working_dir "${PRE_TRAINING_WORKING_DIR}" --pre_training_phase
;;

train_triplet)
# Triplet-training (2.48k speakers).
echo "[train_triplet] selected."
python cli.py train-model --working_dir "${TRIPLET_TRAINING_WORKING_DIR}"
python cli.py train-model --model_name "$2" --working_dir "${TRIPLET_TRAINING_WORKING_DIR}"
;;

*)
Expand Down
8 changes: 5 additions & 3 deletions example.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import numpy as np
import random

import numpy as np

from audio import read_mfcc
from batcher import sample_from_mfcc
from constants import SAMPLE_RATE, NUM_FRAMES
from conv_models import DeepSpeakerModel
from models import ResCNNModel
from test import batch_cosine_similarity

np.random.seed(123)
random.seed(123)

model = DeepSpeakerModel()
model = ResCNNModel()
model.m.load_weights('/Users/premy/deep-speaker/checkpoints/ResCNN_triplet_training_checkpoint_175.h5', by_name=True)

mfcc_001 = sample_from_mfcc(read_mfcc('samples/PhilippeRemy/PhilippeRemy_001.wav', SAMPLE_RATE), NUM_FRAMES)
Expand Down
183 changes: 84 additions & 99 deletions conv_models.py → models.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,72 @@
import abc
import logging
import os

import numpy as np
import tensorflow.keras.backend as K
from tensorflow.keras import layers
from tensorflow.keras import regularizers
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import GRU
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Lambda, Dense
from tensorflow.keras.layers import Reshape
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

from constants import NUM_FBANKS, NUM_FRAMES
from triplet_loss import deep_speaker_loss

logger = logging.getLogger(__name__)

RES_CNN_NAME = 'ResCNN'
GRU_NAME = 'GRU'


def select_model_class(name: str):
    """Return the model class registered under ``name``.

    Args:
        name: Model identifier; compared against ``RES_CNN_NAME`` and
            ``GRU_NAME``.

    Returns:
        ``ResCNNModel`` or ``GRUModel`` (the class itself, not an instance).

    Raises:
        ValueError: If ``name`` matches neither identifier. ``ValueError`` is a
            subclass of ``Exception``, so handlers written against the previous
            bare ``Exception`` keep working.
    """
    if name == RES_CNN_NAME:
        return ResCNNModel
    if name == GRU_NAME:
        return GRUModel
    raise ValueError(f'Unknown model name: {name}.')


class DeepSpeakerModel:

# I thought it was 3, but maybe energy is added as a 4th dimension.
# would be better to have 4 dimensions:
# MFCC, DIFF(MFCC), DIFF(DIFF(MFCC)), ENERGIES (probably tiled across the frequency domain).
# this seems to help match the parameter counts.
def __init__(self, batch_input_shape=(None, NUM_FRAMES, NUM_FBANKS, 1), include_softmax=False,
num_speakers_softmax=None):
def __init__(self,
batch_input_shape=(None, NUM_FRAMES, NUM_FBANKS, 1),
include_softmax=False,
num_speakers_softmax=None,
name=RES_CNN_NAME):
self.include_softmax = include_softmax
self.num_speakers_softmax = num_speakers_softmax
if self.include_softmax:
assert num_speakers_softmax > 0
assert self.num_speakers_softmax > 0
self.clipped_relu_count = 0

# http://cs231n.github.io/convolutional-networks/
# conv weights
# #params = ks * ks * nb_filters * num_channels_input

# Conv128-s
# 5*5*128*128/2+128
# ks*ks*nb_filters*channels/strides+bias(=nb_filters)

# take 100 ms -> 4 frames.
# if signal is 3 seconds, then take 100ms per 100ms and average out this network.
# 8*8 = 64 features.

# used to share all the layers across the inputs

# num_frames = K.shape() - do it dynamically after.
inputs = Input(batch_shape=batch_input_shape, name='input')
x = self.cnn_component(inputs)
x = self.graph_with_avg_softmax_and_ln(inputs)
self.m = Model(inputs, x, name=name)

x = Reshape((-1, 2048))(x)
@abc.abstractmethod
def graph(self, inputs):
    """Build the model-specific layers on top of ``inputs`` (subclass hook).

    Subclasses override this and return the tensor that the shared
    averaging / affine / softmax-or-normalisation head is applied to.

    The enclosing class is declared without an ABC metaclass, so the
    ``abstractmethod`` marker alone does not prevent instantiation; failing
    loudly here avoids silently handing ``None`` to the shared head.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError(f'{type(self).__name__} must implement graph().')

def graph_with_avg_softmax_and_ln(self, inputs):
x = self.graph(inputs)
# Temporal average layer. axis=1 is time.
x = Lambda(lambda y: K.mean(y, axis=1), name='average')(x)
if include_softmax:
if self.include_softmax:
logger.info('Including a Dropout layer to reduce overfitting.')
# used for softmax because the dataset we pre-train on might be too small. easy to overfit.
# x = Dropout(0.25)(x) # was for GRU. Does 0.5 work with GRU as well?
x = Dropout(0.5)(x)
x = Dense(512, name='affine')(x)
if include_softmax:
if self.include_softmax:
# These weights are only used when we train with softmax.
x = Dense(num_speakers_softmax, activation='softmax')(x)
x = Dense(self.num_speakers_softmax, activation='softmax')(x)
else:
# Does not contain any weights.
x = Lambda(lambda y: K.l2_normalize(y, axis=1), name='ln')(x)
self.m = Model(inputs, x, name='ResCNN')
return x

def keras_model(self):
    """Return the model object stored on this wrapper as ``self.m``."""
    wrapped = self.m
    return wrapped
Expand All @@ -82,6 +83,28 @@ def clipped_relu(self, inputs):
self.clipped_relu_count += 1
return relu

def set_weights(self, w):
    """Assign per-layer weights to the wrapped model, in layer order.

    Args:
        w: Iterable of per-layer weight collections, paired positionally with
            ``self.m.layers``. Pairing stops at the shorter of the two.
    """
    layer_weight_pairs = zip(self.m.layers, w)
    for current_layer, current_weights in layer_weight_pairs:
        current_layer.set_weights(current_weights)
        # Logged once the layer's weights have been assigned.
        logger.info(f'Setting weights for [{current_layer.name}]...')


class ResCNNModel(DeepSpeakerModel):

def __init__(self,
             batch_input_shape=(None, NUM_FRAMES, NUM_FBANKS, 1),
             include_softmax=False,
             num_speakers_softmax=None):
    """Create the residual-CNN model; the Keras model is named ``RES_CNN_NAME``.

    All arguments are forwarded unchanged to the base-class constructor.
    """
    super().__init__(
        batch_input_shape=batch_input_shape,
        include_softmax=include_softmax,
        num_speakers_softmax=num_speakers_softmax,
        name=RES_CNN_NAME,
    )

def graph(self, inputs):
    """Stack four conv+residual stages, then reshape the result.

    Stages 1 to 4 use 64, 128, 256 and 512 filters respectively; the output
    is reshaped to ``(-1, 2048)`` per sample.
    """
    x = inputs
    for stage, filters in enumerate((64, 128, 256, 512), start=1):
        x = self.conv_and_res_block(x, filters, stage=stage)
    return Reshape((-1, 2048))(x)

def identity_block(self, input_tensor, kernel_size, filters, stage, block):
conv_name_base = f'res{stage}_{block}_branch'

Expand Down Expand Up @@ -128,72 +151,34 @@ def conv_and_res_block(self, inp, filters, stage):
o = self.identity_block(o, kernel_size=3, filters=filters, stage=stage, block=i)
return o

def cnn_component(self, inp):
x = self.conv_and_res_block(inp, 64, stage=1)
x = self.conv_and_res_block(x, 128, stage=2)
x = self.conv_and_res_block(x, 256, stage=3)
x = self.conv_and_res_block(x, 512, stage=4)
return x

def set_weights(self, w):
for layer, layer_w in zip(self.m.layers, w):
layer.set_weights(layer_w)
logger.info(f'Setting weights for [{layer.name}]...')
class GRUModel(DeepSpeakerModel):
    """Deep Speaker variant: one strided convolution followed by three GRU layers."""

    def __init__(self,
                 batch_input_shape=(None, NUM_FRAMES, NUM_FBANKS, 1),
                 include_softmax=False,
                 num_speakers_softmax=None):
        """Create the GRU model; the Keras model is named ``GRU_NAME``.

        All arguments are forwarded unchanged to the base-class constructor.
        """
        super().__init__(
            batch_input_shape=batch_input_shape,
            include_softmax=include_softmax,
            num_speakers_softmax=num_speakers_softmax,
            name=GRU_NAME,
        )

    def graph(self, inputs):
        """Build conv -> batch-norm -> clipped ReLU -> reshape -> 3 x GRU.

        Returns the full output sequence of the last GRU layer
        (``return_sequences=True``), 1024 features per step.
        """
        conv1 = Conv2D(64, kernel_size=5, strides=2, padding='same',
                       kernel_initializer='glorot_uniform', name='conv1',
                       kernel_regularizer=regularizers.l2(l=0.0001))
        x = conv1(inputs)
        x = BatchNormalization(name='bn1')(x)  # does it work with BN?
        x = self.clipped_relu(x)

        # 4d -> 3d: fold the last two axes into a single feature axis.
        _, num_steps, num_bins, num_channels = K.int_shape(x)
        flat_shape = (num_steps, num_bins * num_channels)
        x = Reshape(flat_shape)(x)
        # NOTE(review): this second, identical Reshape does not change the
        # shape; it is kept so the layer list stays the same. Confirm whether
        # it is intentional.
        x = Reshape(flat_shape)(x)

        # Three stacked GRUs; when training with softmax, a Dropout(0.2)
        # follows the first two (not the last).
        for idx in (1, 2, 3):
            x = GRU(1024, name=f'GRU{idx}', return_sequences=True)(x)
            if self.include_softmax and idx < 3:
                x = Dropout(0.2)(x)
        return x

def main():
# Looks correct to me.
# I have 37K but paper reports 41K. which is not too far.
dsm = DeepSpeakerModel()
dsm.m.summary()

# I suspect num frames to be 32.
# Then fbank=64, then total would be 32*64 = 2048.
# plot_model(dsm.m, to_file='model.png', dpi=300, show_shapes=True, expand_nested=True)


def _train():
# x = np.random.uniform(size=(6, 32, 64, 4)) # 6 is multiple of 3.
# y_softmax = np.random.uniform(size=(6, 100))
# dsm = DeepSpeakerModel(batch_input_shape=(None, 32, 64, 4), include_softmax=True, num_speakers_softmax=100)
# dsm.m.compile(optimizer=Adam(lr=0.01), loss='categorical_crossentropy')
# print(dsm.m.predict(x).shape)
# print(dsm.m.evaluate(x, y_softmax))
# w = dsm.get_weights()
dsm = DeepSpeakerModel(batch_input_shape=(None, 32, 64, 4), include_softmax=False)
# dsm.m.set_weights(w)
dsm.m.compile(optimizer=Adam(lr=0.01), loss=deep_speaker_loss)

# it works!!!!!!!!!!!!!!!!!!!!
# unit_batch_size = 20
# anchor = np.ones(shape=(unit_batch_size, 32, 64, 4))
# positive = np.array(anchor)
# negative = np.ones(shape=(unit_batch_size, 32, 64, 4)) * (-1)
# batch = np.vstack((anchor, positive, negative))
# x = batch
# y = np.zeros(shape=(len(batch), 512)) # not important.
# print('Starting to fit...')
# while True:
# print(dsm.m.train_on_batch(x, y))

# should not work... and it does not work!
unit_batch_size = 20
negative = np.ones(shape=(unit_batch_size, 32, 64, 4)) * (-1)
batch = np.vstack((negative, negative, negative))
x = batch
y = np.zeros(shape=(len(batch), 512)) # not important.
print('Starting to fit...')
while True:
print(dsm.m.train_on_batch(x, y))


def _test_checkpoint_compatibility():
dsm = DeepSpeakerModel(batch_input_shape=(None, 32, 64, 4), include_softmax=True, num_speakers_softmax=10)
dsm.m.save_weights('test.h5')
dsm = DeepSpeakerModel(batch_input_shape=(None, 32, 64, 4), include_softmax=False)
dsm.m.load_weights('test.h5', by_name=True)
os.remove('test.h5')


if __name__ == '__main__':
_test_checkpoint_compatibility()