Skip to content

Commit

Permalink
Fix #1, #2
Browse files Browse the repository at this point in the history
  • Loading branch information
tjysdsg committed Aug 28, 2019
1 parent 057c5e1 commit 08d8992
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 21 deletions.
24 changes: 14 additions & 10 deletions src/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,34 @@ def cli():
import argparse
parser = argparse.ArgumentParser(description='CLI for speech recognition functionality.')
parser.add_argument('action', metavar='ACTION', type=str, nargs=1,
choices=['train', 'test', 'record', 'continuous'],
choices=['isolated_train', 'isolated_test', 'record', 'continuous_train'],
help='Action to perform.')
parser.add_argument("-d", "--model-directory", default='models-4gaussians-em',
help="Directory which the trained models are stored, or test models are used.")
parser.add_argument("-i", "--input", default='models-continuous-4gaussians-em-realign/',
help="Directory which the input models are stored for continuous training, or test models are"
"used. \nIgnored for isolated model training")
parser.add_argument("-o", "--output", default='models-continuous-4gaussians-em-realign/',
help="Directory which the trained models are stored")
# TODO: use -g and -e for continuous training
parser.add_argument("-g", "--gmm", help="Use GMM-HMM as the model.", default=False, action='store_true')
parser.add_argument("-e", "--em", help="Use EM algorithm to train models.", default=False, action='store_true')
args = parser.parse_args()

if args.action[0] == 'train':
if args.action[0] == 'isolated_train':
print('training...')
# data folder location
folder = os.path.join(data_path, 'train')
for digit in digit_names:
filenames = [os.path.join(folder, f) for f in os.listdir(folder) if
re.match('[A-Z]+_' + digit + '[AB].wav', f)]

train(filenames, args.model_directory, digit, n_segs=5, use_gmm=args.gmm, use_em=args.em)
train(filenames, args.output, digit, n_segs=5, use_gmm=args.gmm, use_em=args.em)

if args.action[0] == 'test':
if args.action[0] == 'isolated_test':
print('testing...')
# get all the models from pickle files
models = []
for digit in digit_names:
file = open(os.path.join(args.model_directory, digit + '.pkl'), 'rb')
file = open(os.path.join(args.input, digit + '.pkl'), 'rb')
models.append(pickle.load(file))
file.close()
# get file patterns for each digit
Expand All @@ -46,7 +50,7 @@ def cli():
# get all the models from pickle files
models = []
for digit in digit_names:
file = open(args.model_directory + digit + '.pkl', 'rb')
file = open(args.input + digit + '.pkl', 'rb')
models.append(pickle.load(file))
file.close()

Expand All @@ -61,8 +65,8 @@ def cli():
best_model = i
print(best_model)

if args.action[0] == 'continuous':
aurora_continuous_train()
if args.action[0] == 'continuous_train':
aurora_continuous_train(args.input, args.output)


if __name__ == "__main__":
Expand Down
12 changes: 5 additions & 7 deletions src/sr/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ def make_HMM(filenames, n_segs, use_gmm, use_em):
return model


def train(filenames, model_folder, model_name, n_segs, use_gmm, use_em):
def train(filenames, output_path, model_name, n_segs, use_gmm, use_em):
models = make_HMM(filenames, n_segs, use_gmm=use_gmm, use_em=use_em)
file = open(os.path.join(model_folder, model_name + '.pkl'), 'wb')
file = open(os.path.join(output_path, model_name + '.pkl'), 'wb')
pickle.dump(models, file)


Expand Down Expand Up @@ -94,14 +94,12 @@ def test(models, folder, file_patterns):
return n_passed / n_tests


def aurora_continuous_train():
def aurora_continuous_train(input_path, output_path):
models = []
hmm_index = 0
# for digit in digit_names:
for digit in range(11):
# TODO: use command line argument for input model path
file = open('models-continuous-4gaussians-em-realign/' + str(digit) + '.pkl', 'rb')
# file = open('models-4gaussians-em/' + str(digit) + '.pkl', 'rb')
file = open(input_path + str(digit) + '.pkl', 'rb')

model: HMM = pickle.load(file)
# set value of hmm_state.parent to the index of the hmm it belongs to
Expand Down Expand Up @@ -140,4 +138,4 @@ def aurora_continuous_train():
pickle.dump(data, f)
f.close()

continuous_train(data, models, labels)
continuous_train(data, models, labels, output_path)
8 changes: 4 additions & 4 deletions src/sr/recognition/continuous_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .hmm import HMM
import os
import pickle
from typing import List
from typing import List, AnyStr
import copy
import numpy as np

Expand Down Expand Up @@ -53,7 +53,8 @@ def build_state_sequences(hmms: List[HMM], label_matrix: List[List[int]]):
return seq, trans, end_state_indices[-1]


def continuous_train(data: List[np.ndarray], models: List[HMM], label_seqs: List[List[int]], n_gaussians: int = 4,
def continuous_train(data: List[np.ndarray], models: List[HMM], label_seqs: List[List[int]], output_path: AnyStr,
n_gaussians: int = 4,
n_segments: int = 5,
max_iteration: int = 1000):
# remember old models
Expand Down Expand Up @@ -163,9 +164,8 @@ def continuous_train(data: List[np.ndarray], models: List[HMM], label_seqs: List
new_models[mi].transitions[si, si] = -np.log(1 - p_jump)

# save new models to files
# TODO: use command line argument for output model path
for i in range(len(new_models)):
file = open(os.path.join('models-continuous-4gaussians-em-realign', str(i) + '.pkl'), 'wb')
file = open(os.path.join(output_path, str(i) + '.pkl'), 'wb')
pickle.dump(new_models[i], file)
file.close()
# check if converged
Expand Down

0 comments on commit 08d8992

Please sign in to comment.