In [1]:
import sys
import os
import torch
import torch.nn as nn
import torch.optim as optim
import time
import numpy as np
# script_path_1 = "./"
# sys.path.append(os.path.abspath(script_path_1))
# script_path_2 ="./timit/steps/"
# sys.path.append(os.path.abspath(script_path_2))
# script_path_3 = "./timit/utils/"
# sys.path.append(os.path.abspath(script_path_3))
from libs.utils import get_recursive_files
from scipy.io import wavfile

from models.model_ctc import *
from utils.ctcDecoder import GreedyDecoder, BeamDecoder
from utils.data_loader import Vocab, SpeechDataset, SpeechDataLoader
from steps.train_ctc import Config

import yaml
import argparse

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# parser = argparse.ArgumentParser()
# parser.add_argument('--conf', help='conf file for training')            

In [5]:
def test(confile):
    """Decode the test set with the best saved CTC model and print CER / WER.

    The YAML file at ``confile`` is loaded and every top-level key is copied
    onto a ``Config`` instance (each key/value pair is also printed). The
    checkpoint ``<checkpoint_dir>/<exp_name>/ctc_best_model.pkl`` is then
    loaded, a ``CTC_Model`` is rebuilt from the hyper-parameters stored in it,
    and every batch of the test loader is decoded with either a greedy or a
    beam-search decoder.

    Options read from the config: ``use_gpu``, ``checkpoint_dir``,
    ``exp_name``, ``beam_width``, ``lm_alpha``, ``decode_type``,
    ``vocab_file``, ``test_scp_path``, ``test_lab_path``, ``batch_size``,
    ``num_workers`` and, for beam search only, ``lm_path``.

    Args:
        confile: Path to the YAML configuration file.

    Returns:
        None. Reference/decoded pairs, the character and word error rates
        (in percent) and the elapsed decoding time are printed.

    Raises:
        SystemExit: With exit code 1 if the config file cannot be opened,
            is not valid YAML, or does not contain a top-level mapping.
    """
    # Close the file deterministically and only translate the failures we
    # actually expect; a bare ``except`` would also hide unrelated errors.
    try:
        with open(confile, 'r') as conf_file:
            conf = yaml.safe_load(conf_file)
    except OSError:
        print("Config file not exist!")
        sys.exit(1)
    except yaml.YAMLError as err:
        print("Config file is not valid YAML: {}".format(err))
        sys.exit(1)

    # An empty YAML document loads as None, which has no ``items()``.
    if not isinstance(conf, dict):
        print("Config file must contain a mapping of option names to values!")
        sys.exit(1)

    opts = Config()
    for k, v in conf.items():
        setattr(opts, k, v)
        print('{:50}:{}'.format(k, v))

    # Fall back to the CPU when a GPU is requested but not present instead of
    # failing later in ``model.to(device)``.
    use_cuda = bool(opts.use_gpu) and torch.cuda.is_available()
    if opts.use_gpu and not use_cuda:
        print("CUDA is not available, decoding on CPU instead.")
    device = torch.device('cuda') if use_cuda else torch.device('cpu')

    model_path = os.path.join(opts.checkpoint_dir, opts.exp_name, 'ctc_best_model.pkl')
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only
    # machine. NOTE(review): torch.load unpickles the file - only load
    # checkpoints from a trusted source.
    package = torch.load(model_path, map_location=device)

    rnn_param = package["rnn_param"]
    add_cnn = package["add_cnn"]
    cnn_param = package["cnn_param"]
    num_class = package["num_class"]
    drop_out = package['_drop_out']

    beam_width = opts.beam_width
    lm_alpha = opts.lm_alpha
    decoder_type = opts.decode_type
    vocab_file = opts.vocab_file
    print(vocab_file)
    vocab = Vocab(vocab_file)
    test_dataset = SpeechDataset(vocab, opts.test_scp_path, opts.test_lab_path, opts)
    test_loader = SpeechDataLoader(test_dataset, batch_size=opts.batch_size, shuffle=False,
                                   num_workers=opts.num_workers, pin_memory=False)

    model = CTC_Model(rnn_param=rnn_param, add_cnn=add_cnn, cnn_param=cnn_param,
                      num_class=num_class, drop_out=drop_out)
    model.to(device)
    model.load_state_dict(package['state_dict'])
    model.eval()

    # Any value other than 'Greedy' selects beam search.
    if decoder_type == 'Greedy':
        decoder = GreedyDecoder(vocab.index2word, space_idx=-1, blank_index=0)
    else:
        decoder = BeamDecoder(vocab.index2word, beam_width=beam_width, blank_index=0,
                              space_idx=-1, lm_path=opts.lm_path, lm_alpha=lm_alpha)

    total_wer = 0
    total_cer = 0
    start = time.time()
    with torch.no_grad():
        for data in test_loader:
            inputs, input_sizes, targets, target_sizes, utt_list = data
            inputs = inputs.to(device)

            probs = model(inputs)

            # NOTE(review): input_sizes looks like a per-utterance fraction of
            # the padded length that is scaled by the model's output length
            # (dim 0 of probs) - confirm against SpeechDataLoader.
            max_length = probs.size(0)
            input_sizes = (input_sizes * max_length).long()

            probs = probs.cpu()
            decoded = decoder.decode(probs, input_sizes.numpy().tolist())

            targets, target_sizes = targets.numpy(), target_sizes.numpy()
            # Rebuild each reference transcript from its unpadded label ids.
            labels = [
                ' '.join(vocab.index2word[num] for num in target[:size])
                for target, size in zip(targets, target_sizes)
            ]

            for idx, label in enumerate(labels):
                print("origin : " + label)
                print("decoded: " + decoded[idx])

            for idx, label in enumerate(labels):
                total_cer += decoder.cer(decoded[idx], label)
                total_wer += decoder.wer(decoded[idx], label)
                decoder.num_word += len(label.split())
                decoder.num_char += len(label)

    # An empty test set leaves the counters at zero; report NaN rather than
    # raising ZeroDivisionError.
    CER = (float(total_cer) / decoder.num_char) * 100 if decoder.num_char else float('nan')
    WER = (float(total_wer) / decoder.num_word) * 100 if decoder.num_word else float('nan')
    print("Character error rate on test set: %.4f" % CER)
    print("Word error rate on test set: %.4f" % WER)
    end = time.time()
    time_used = (end - start) / 60.0
    print("time used for decode %d sentences: %.4f minutes." % (len(test_dataset), time_used))


In [6]:
# Run the decoding pass using the CTC config; the path is relative to the
# notebook's working directory.
test("conf/ctc_config.yaml")

exp_name                                          :ctc_fbank_cnn
checkpoint_dir                                    :checkpoint/
vocab_file                                        :data/units
train_scp_path                                    :data/train/fbank.scp
train_lab_path                                    :data/train/phn_text
valid_scp_path                                    :data/dev/fbank.scp
valid_lab_path                                    :data/dev/phn_text
left_ctx                                          :0
right_ctx                                         :2
n_skip_frame                                      :2
n_downsample                                      :2
num_workers                                       :1
shuffle_train                                     :True
feature_dim                                       :81
output_class_dim                                  :39
mel                                               :False
feature_type                              

origin : sil ah sil k eh r ih sil jh sil k r ah n ch sil b ay ih sil s sil d ih m l ay sil t s f ih l sil t r ih ng th uw dh ih sil g l uw m sil
decoded:  sil ah sil k eh r sil jh sil k r ih n sh s sil b ay sil z sil d ih m l ay n sil t s sil th ow sil t ih ng th uw dh ih sil g l uw m sil
origin : sil dh ey m ey sil k ah s sil k n f aa r m ih s l uh sil g uh sil sil
decoded:  sil dh ey m ey sil k ah s sil k n f aa r m ih s l uh sil g uh sil sil
origin : sil w iy ow w ih z th aa w ih w uh sil d ay w ih th aa r sil b uw sil t s aa n sil
decoded:  sil w iy aa l w ih z dh th aw w iy w ih sil d ay w ah th er sil b uw sil d s ah n sil
origin : sil hh uw aa th er ay z dh iy ih n l ih m ih dx ih dx ih sil k s sil p eh n s ih sil k aw n sil sil
decoded:  sil hh uw aa aw th sil t er r ay th iy ih n l ih m ih n ih dx ih sil k s sil p eh s ih sil k aw sil sil
origin : sil t uw f er dh er hh ih z sil p r ih s sil t iy sil jh hh iy sil k ey sh n l ih r iy sil z dh ah w aa l s sil t r iy sil jh er n 

origin : sil hh iy sil t r eh m sil b l sil l eh s sil t hh ih z sil p iy sh ih sil f ey l sil
decoded:  sil hh iy sil ch r ae m v ah l sil l eh s sil t ih z sil p iy sh uh sil d f r ey l sil
origin : sil n ow dh ah m ae n w ah z sil n aa sil d r ah ng sil k sil hh iy w ah n sil d er sil d hh aw iy sil g aa sil t ay dx ah sil w ih th ih s sil t r ey n jh er sil
decoded:  sil n ow dh ah m ae n w ah z sil n aa r sil d r ah n sil t sil hh iy w ah n sil d er sil d hh aw iy sil g aa r sil t ay sil d ah sil w ah th ih s sil t r ey n jh er sil
origin : sil hh ih z w er sil b iy sil g ae n jh ih s ih sil k s sil d ey z ae f sil t ah dh ah f l ah sil sil
decoded:  sil hh ih z w er sil b ih sil g ae n sil jh ih s ih sil k s sil t ey z ae f sil t ah sil t ah f l ah sil sil
origin : sil d ih sil jh uw sil b ay ih n iy sil k aa r dx er oy ow v er aa l s sil
decoded:  sil d ih sil jh uw sil b ay ih n iy sil k aa r dx er ey ow v er aa l z sil
origin : sil s sil t iy v w aa r ah sil b r ay sil r eh si

origin : sil w ih m ey s ey dh eh sil hh ih z sil p r aa sil l ah m w ah z sil d ay sil n ow s sil t sil sil b ah sil dh ae sil t hh iy r uw f uw z sil t r iy sil m ih n sil sil
decoded:  sil w ih m ey s ey dh eh sil hh ih s sil p r aa sil b ah m w ah s sil b ay sil n ow z sil t sil b ih sil dh eh sil k hh iy r ih f y uw s sil t r iy sil m ih sil sil
origin : sil w iy sil k ih ng sil g eh dx ih sil ih f w ih sil d ih ih sil g sil hh iy s eh sil p ey sh n sil t l iy sil
decoded:  sil w iy sil k ih ng sil g ih dx ih sil ih f er sil d iy sil g sil ey s ih sil p ey sh ih sil d l iy sil
origin : sil d uw w ih th aw sil f ae n sil s iy sil t ey sil b l sil k l aa dh s sil
decoded:  sil d uw w ah th aw f eh n sil s iy sil t ey sil b l k w aa l ah z sil
origin : sil dh ah sil b eh s sil t w ey sil t ah l er n sil ih z sil t ih s aa l v eh sil k s sil t er sil p r aa sil b l ah m s sil
decoded:  sil dh ah sil b ae s sil k w ey sil t ah m er n sil ih z sil t ih s aa l v eh sil k s sil t er sil p

In [None]:
# Toy (2, 7, 4) integer array: the same 7-row table stacked twice.
# np.array copies the nested lists, so the two slices do not alias each other.
three_dim_ary = np.array(
    [
        [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]
    ] * 2
)

In [12]:
# Check the array layout before converting it to a tensor.
print(three_dim_ary.shape)

(2, 7, 4)


In [13]:
# torch.from_numpy wraps the NumPy buffer without copying, so tensor_ary and
# three_dim_ary share memory.
tensor_ary = torch.from_numpy(three_dim_ary)

In [29]:
# Swap the last two axes: (2, 7, 4) -> (2, 4, 7).
a = tensor_ary.transpose(1, 2)
# Swap the first and last axes: (2, 7, 4) -> (4, 7, 2).
b = tensor_ary.transpose(0, -1)

In [30]:
# Show both transposed shapes, then the contents of b, one item per line.
print(a.shape, b.shape, b, sep="\n")

torch.Size([2, 4, 7])
torch.Size([4, 7, 2])
tensor([[[  1,   1],
         [  6,   6],
         [  6,   6],
         [  6,   6],
         [  6,   6],
         [  6,   6],
         [  6,   6]],

        [[ 16,  16],
         [ 24,  24],
         [ 32,  32],
         [ 64,  64],
         [ 96,  96],
         [160, 160],
         [320, 320]],

        [[  1,   1],
         [  2,   2],
         [  3,   3],
         [  4,   4],
         [  3,   3],
         [  3,   3],
         [  1,   1]],

        [[  1,   1],
         [  2,   2],
         [  2,   2],
         [  2,   2],
         [  1,   1],
         [  2,   2],
         [  1,   1]]])
