In [1]:
import sys

# Make the project's model code importable by putting its directory at the
# front of the module search path.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
MODELS_DIR = '/home/ubuntu/bawk/models'
# Drop any earlier copy first so re-running this cell keeps exactly one entry
# while still guaranteeing that the directory has top priority.
sys.path[:] = [MODELS_DIR] + [entry for entry in sys.path if entry != MODELS_DIR]

In [2]:
from seqtoseq_v2 import *
from  create_dataset import *
from train_v2 import trainIters
from predict_v2 import evaluateRandomly, inference_from_file
import torch
import librosa

################################################################################
###          (please add 'export KALDI_ROOT=<your_path>' in your $HOME/.profile)
###          (or run as: KALDI_ROOT=<your_path> python <your_script>.py)
################################################################################



In [3]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
# Build the two halves of the sequence-to-sequence model and move each one
# onto `device`.
# NOTE(review): 80 is presumably the input (mel) feature size and 40 the
# encoder hidden size — confirm against Encoder's signature.
encoder = Encoder(80, 40, 2, dropout=0.1, bidirectional=True)
encoder = encoder.to(device)

decoder = Decoder(vocab_size=29, embedding_dim=15, hidden_size=20, num_layers=2)
decoder = decoder.to(device)

In [5]:
# Restore the saved encoder/decoder weights and switch both modules to eval mode.
#
# map_location=device remaps the stored tensors onto the device selected for
# this session, so a checkpoint written on a GPU machine still loads when
# `device` has fallen back to the CPU.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source.
enc_path = '../models/output/enc_model_las'
encoder.load_state_dict(torch.load(enc_path, map_location=device))
encoder.eval()

dec_path = '../models/output/dec_model_las'
decoder.load_state_dict(torch.load(dec_path, map_location=device))
# Left as the last expression so the notebook displays the decoder's layout.
decoder.eval()

Decoder(
  (embedding): Embedding(29, 15)
  (rnn): ModuleList(
    (0): LSTMCell(35, 20)
    (1): LSTMCell(20, 20)
  )
  (attention): DotProductAttention()
  (mlp): Sequential(
    (0): Linear(in_features=40, out_features=20, bias=True)
    (1): Tanh()
    (2): Linear(in_features=20, out_features=29, bias=True)
  )
)

In [6]:
# Wav clips used for the transcription demo below.
# NOTE(review): absolute, machine-specific paths.
curated = [
    '/home/ubuntu/wav_clips/common_voice_en_17849044.wav',
    '/home/ubuntu/wav_clips/common_voice_en_19787247.wav',
    '/home/ubuntu/wav_clips/common_voice_en_22270958.wav',
]

# Convert every clip with mel_from_wav and prepend an axis of size 1
# (presumably the batch axis expected by the encoder — TODO confirm).
processed_files = [
    mel_from_wav(wav_path).unsqueeze(0)
    for wav_path in curated
]

In [7]:
def evaluate_from_file(encoder, decoder, features, beam=5, nbest=5):
    """Decode one utterance and return the tokens of its first hypothesis.

    Relies on the module-level names ``device`` and ``dictOfindex``.

    Args:
        encoder: Callable invoked as ``encoder(input_tensor, input_length)``;
            the first element of its return value is used as the encoder
            outputs.
        decoder: Object exposing ``recognize_beam(encoder_output, beam, nbest)``.
        features: Feature tensor for a single utterance. ``features.size(1)``
            is passed to the encoder as the sequence length (assumes a
            ``(1, frames, feature_dim)`` layout — TODO confirm).
        beam: Forwarded to ``decoder.recognize_beam`` as its second argument.
        nbest: Forwarded to ``decoder.recognize_beam`` as its third argument.

    Returns:
        list: ``dictOfindex[i]`` for every index ``i`` in the ``'yseq'`` entry
        of the first hypothesis returned by ``recognize_beam``. Nothing is
        stripped here, so any boundary tokens the decoder emits are included.
    """
    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        input_tensor = features.to(device)
        input_length = torch.tensor([features.size(1)], device=device)

        encoder_outputs, _ = encoder(input_tensor, input_length)
        # Only the first item of the batch is decoded.
        nbest_hyps = decoder.recognize_beam(encoder_outputs[0], beam, nbest)
        # Keep the first hypothesis of the returned list.
        best_indices = nbest_hyps[0]['yseq']

    return [dictOfindex[index] for index in best_indices]

In [10]:
def transcribe(lists):
    """Print the decoded transcript of every feature tensor in ``lists``.

    Each item is decoded with ``evaluate_from_file`` using the module-level
    ``encoder`` and ``decoder``. The first and last decoded tokens are dropped
    (presumably start/end markers — TODO confirm), the remaining tokens are
    joined with single spaces and printed, followed by a separator line.

    Args:
        lists: Iterable of inputs accepted by ``evaluate_from_file``.

    Returns:
        None. The transcripts are written to stdout.
    """
    separator = "########\n"
    for features in lists:
        tokens = evaluate_from_file(encoder, decoder, features)
        inner_tokens = tokens[1:-1]
        print(" ".join(inner_tokens))
        print(separator)


In [11]:
transcribe(processed_files)

t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   t o m e   t h e   s
########

t h e   c o n s t e n   t h e   t o u n d   t h e   t o u n d   t h e   t o u n d   t h e   t o u n d   t h e   t o u n d   t h e   t o u n d   t h e   t o u n d   t h e   t o u n d   t h 