In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.models as models
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
import cv2
from tqdm import tqdm_notebook as tqdm
from sklearn.model_selection import train_test_split
import torchvision
from torchvision import transforms
import math
from PIL import Image
from torchsummary import summary 
from tqdm import trange 
import glob
import os
import unicodedata
import string
import time 
import random 

%matplotlib inline

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#device = "cpu"
print(device)

cuda:0


In [2]:
data = open('train.txt', encoding='utf-8').read().strip().split('\n')

print(data[0])

chardict =  {'': 0, '.': 1, "'": 2, '-':3, 'A': 4, 'B': 5, 'C': 6, 'D': 7, 'E': 8, 'F': 9, 'G': 10, 'H': 11, 'I': 12, 'J': 13, 'K': 14, 'L': 15, 'M': 16, 'N': 17, 'O': 18, 'P': 19, 'Q': 20, 'R': 21, 'S': 22, 'T': 23, 'U': 24, 'V': 25, 'W': 26, 'X': 27, 'Y': 28, 'Z': 29}
phonedict = {'AA' : 0, 'AE' : 1, 'AH' : 2, 'AO' : 3, 'AW' : 4, 'AY' : 5, 'B' : 6, 'CH' : 7, 'D' : 8, 'DH' : 9, 'EH' : 10, 'ER' : 11, 'EY' : 12, 'F' : 13, 'G' : 14, 'HH' : 15, 'IH' : 16, 'IY' : 17, 'JH' : 18, 'K' : 19, 'L' : 20, 'M' : 21, 'N' : 22, 'NG' : 23, 'OW' : 24, 'OY' : 25, 'P' : 26, 'R' : 27, 'S' : 28, 'SH' : 29, 'T' : 30, 'TH' : 31, 'UH' : 32, 'UW' : 33, 'V' : 34, 'W' : 35, 'Y' : 36, 'Z' : 37, 'ZH' : 38}

rev_chardict = {v:k for k,v in chardict.items()}
rev_phonedict = {v:k for k,v in phonedict.items()}

n_chars = len(chardict) + 3
n_phones = len(phonedict) + 3
n_words = len(data)

def wordToTensor(line):
    tensor = torch.zeros(len(line) + 1, n_chars)
    for li, letter in enumerate(line):
        tensor[li][chardict[letter] + 3] = 1
    tensor[len(line)][2] = 1
    return tensor

def phoneToTensor(line):
    tensor = torch.zeros(len(line), n_phones)
    linelist = line.split('_')
    for li, letter in enumerate(linelist):
        tensor[li][phonedict[letter] + 3] = 1
    tensor[len(linelist)][2] = 1
    return tensor

def pairTensor(i):
    linelist = data[i].split(' ')
    return (wordToTensor(linelist[0]), phoneToTensor(linelist[1]))

max_wordlen = 36
max_phonelen = 20

data = [word for word in data if len(word.split(' ')[0]) <= max_wordlen]

for word in data:
    max_phonelen = max(max_phonelen, len(word.split(' ')[1]))
    
print(max_phonelen)

LEMIEUX L_AH_M_Y_UW
79


In [3]:
n_hidden = 128
batch_size = 128

In [4]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        
        self.gru = nn.LSTM(n_chars, hidden_size, bidirectional = False, num_layers = 2, dropout = 0.2)

    def forward(self, inp, hidden):
        output, hidden = self.gru(inp, hidden)
        return output, hidden

    def initHidden(self):
        return (torch.zeros(2, batch_size, self.hidden_size, device=device), torch.zeros(2, batch_size, self.hidden_size, device=device))
    
class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.bn = nn.BatchNorm1d(n_phones)
        self.gru = nn.LSTM(output_size, hidden_size, bidirectional = False, num_layers = 2, dropout = 0.2)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, inp, hidden):
        output = self.bn(inp.squeeze(0))
        output, hidden = self.gru(output.view(1, batch_size, -1), hidden)
        output = self.out(output[0])
        return output, hidden
    
encoder = EncoderRNN(n_chars, n_hidden).to(device)
decoder = DecoderRNN(n_hidden, n_phones).to(device)

encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=0.001)
decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=0.001)

In [5]:
def train(input_tensor, target_tensor, max_length=max_wordlen):
    criterion = nn.CrossEntropyLoss(ignore_index = 0)
    encoder_hidden = encoder.initHidden()
    
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    #input_length = input_tensor.size(0)
    #target_length = target_tensor.size(0)

    loss = 0

    #for ei in range(input_length):
    encoder_output, encoder_hidden = encoder(input_tensor, encoder_hidden)
        
    sos = torch.zeros(1, batch_size, n_phones)
    for i in range (batch_size):
        sos[0][i][1] = 1
    sos = sos.to(device)
    
    decoder_input = sos

    decoder_hidden = encoder_hidden

    decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
    
    target_tensor, target_lengths = nn.utils.rnn.pad_packed_sequence(target_tensor, padding_value=0)
    
    target_length, _ = torch.max(target_lengths, 0)
    target_length = target_length.item()
    
    #print('szz')
    #print(encoder_hidden[0].size())
    #print(target_tensor.size())
    
    for di in range(target_length):
        decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
        loss += criterion(decoder_output, torch.max(target_tensor[di, :, :], 1)[1])
        #print("decoder")
        #print(decoder_output)
        #print("target")
        #print(torch.max(target_tensor[di, :, :], 1)[1])
        decoder_input = target_tensor[di, :, :]
    
    #print(decoder_output.size())
    #print(nn.utils.rnn.pad_packed_sequence(target_tensor, batch_first = True)[0].size())

    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length

In [6]:
import math


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))

In [7]:
def trainIters(encoder, decoder, print_every=1000, learning_rate=0.001, batch_size = 256):
    start = time.time()
    plot_losses = []
    print_loss_total = 0  
    plot_loss_total = 0  
    
    training_pairs = [pairTensor(i)
                      for i in tqdm(range(len(data)))]
    
    random.shuffle(training_pairs)

    for iter in range(1, ((len(data) + 1) // batch_size - 1)):
        input_tensor = [word[0] for word in training_pairs[(iter - 1) * batch_size : iter * batch_size]]
        input_tensor.sort(key=len, reverse = True)
        input_tensor = nn.utils.rnn.pack_sequence(input_tensor)
        
        target_tensor = [word[1] for word in training_pairs[(iter - 1) * batch_size : iter * batch_size]]
        target_tensor.sort(key=len, reverse = True)
        target_tensor = nn.utils.rnn.pack_sequence(target_tensor)
        
        loss = train(input_tensor.to(device), target_tensor.to(device))
        print_loss_total += loss
        plot_loss_total += loss

        if iter % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (timeSince(start, iter * batch_size / len(data)),
                                         iter * batch_size, iter * batch_size / len(data) * 100, print_loss_avg))

In [8]:
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import matplotlib.ticker as ticker
import numpy as np


def showPlot(points):
    plt.figure()
    fig, ax = plt.subplots()
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)

In [9]:
n_epochs = 100

for i in range(n_epochs):
    print("Epoch %d" % (i + 1))
    trainIters(encoder, decoder, print_every = 200, batch_size = batch_size)

Epoch 1


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 35s) (25600 30%) 1.1584
0m 24s (- 0m 15s) (51200 61%) 0.9620
0m 33s (- 0m 2s) (76800 92%) 0.8706
Epoch 2


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 35s) (25600 30%) 0.8300
0m 24s (- 0m 15s) (51200 61%) 0.8007
0m 33s (- 0m 2s) (76800 92%) 0.7728
Epoch 3


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 34s) (25600 30%) 0.7643
0m 24s (- 0m 15s) (51200 61%) 0.7522
0m 33s (- 0m 2s) (76800 92%) 0.7344
Epoch 4


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 16s (- 0m 36s) (25600 30%) 0.7122
0m 24s (- 0m 15s) (51200 61%) 0.7177
0m 33s (- 0m 2s) (76800 92%) 0.7185
Epoch 5


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 34s) (25600 30%) 0.6976
0m 24s (- 0m 15s) (51200 61%) 0.6980
0m 34s (- 0m 2s) (76800 92%) 0.6946
Epoch 6


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 34s) (25600 30%) 0.6778
0m 24s (- 0m 15s) (51200 61%) 0.6792
0m 32s (- 0m 2s) (76800 92%) 0.6842
Epoch 7


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 33s) (25600 30%) 0.6765
0m 23s (- 0m 14s) (51200 61%) 0.6661
0m 32s (- 0m 2s) (76800 92%) 0.6664
Epoch 8


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 35s) (25600 30%) 0.6588
0m 24s (- 0m 15s) (51200 61%) 0.6543
0m 33s (- 0m 2s) (76800 92%) 0.6604
Epoch 9


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 34s) (25600 30%) 0.6479
0m 24s (- 0m 15s) (51200 61%) 0.6438
0m 32s (- 0m 2s) (76800 92%) 0.6518
Epoch 10


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 33s) (25600 30%) 0.6456
0m 23s (- 0m 14s) (51200 61%) 0.6356
0m 32s (- 0m 2s) (76800 92%) 0.6418
Epoch 11


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 34s) (25600 30%) 0.6323
0m 23s (- 0m 14s) (51200 61%) 0.6344
0m 32s (- 0m 2s) (76800 92%) 0.6365
Epoch 12


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 34s) (25600 30%) 0.6188
0m 24s (- 0m 15s) (51200 61%) 0.6326
0m 32s (- 0m 2s) (76800 92%) 0.6342
Epoch 13


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 34s) (25600 30%) 0.6219
0m 23s (- 0m 14s) (51200 61%) 0.6208
0m 32s (- 0m 2s) (76800 92%) 0.6170
Epoch 14


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 34s) (25600 30%) 0.6136
0m 23s (- 0m 14s) (51200 61%) 0.6236
0m 32s (- 0m 2s) (76800 92%) 0.6184
Epoch 15


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 35s) (25600 30%) 0.6088
0m 24s (- 0m 15s) (51200 61%) 0.6060
0m 33s (- 0m 2s) (76800 92%) 0.6127
Epoch 16


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 34s) (25600 30%) 0.6077
0m 24s (- 0m 15s) (51200 61%) 0.6013
0m 32s (- 0m 2s) (76800 92%) 0.6173
Epoch 17


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 34s) (25600 30%) 0.5968
0m 24s (- 0m 15s) (51200 61%) 0.5994
0m 33s (- 0m 2s) (76800 92%) 0.6032
Epoch 18


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 34s) (25600 30%) 0.5947
0m 24s (- 0m 15s) (51200 61%) 0.5936
0m 33s (- 0m 2s) (76800 92%) 0.5973
Epoch 19


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 16s (- 0m 36s) (25600 30%) 0.5932
0m 25s (- 0m 15s) (51200 61%) 0.5945
0m 34s (- 0m 2s) (76800 92%) 0.5987
Epoch 20


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 16s (- 0m 36s) (25600 30%) 0.5858
0m 24s (- 0m 15s) (51200 61%) 0.5969
0m 34s (- 0m 2s) (76800 92%) 0.5911
Epoch 21


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 34s) (25600 30%) 0.5887
0m 23s (- 0m 14s) (51200 61%) 0.5896
0m 32s (- 0m 2s) (76800 92%) 0.5886
Epoch 22


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 34s) (25600 30%) 0.5888
0m 24s (- 0m 15s) (51200 61%) 0.5884
0m 32s (- 0m 2s) (76800 92%) 0.5834
Epoch 23


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 34s) (25600 30%) 0.5803
0m 24s (- 0m 14s) (51200 61%) 0.5868
0m 32s (- 0m 2s) (76800 92%) 0.5859
Epoch 24


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 34s) (25600 30%) 0.5841
0m 24s (- 0m 15s) (51200 61%) 0.5795
0m 33s (- 0m 2s) (76800 92%) 0.5763
Epoch 25


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 35s) (25600 30%) 0.5792
0m 24s (- 0m 15s) (51200 61%) 0.5796
0m 33s (- 0m 2s) (76800 92%) 0.5829
Epoch 26


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 34s) (25600 30%) 0.5749
0m 24s (- 0m 15s) (51200 61%) 0.5838
0m 32s (- 0m 2s) (76800 92%) 0.5746
Epoch 27


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 34s) (25600 30%) 0.5697
0m 24s (- 0m 15s) (51200 61%) 0.5708
0m 33s (- 0m 2s) (76800 92%) 0.5808
Epoch 28


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 35s) (25600 30%) 0.5642
0m 24s (- 0m 15s) (51200 61%) 0.5754
0m 33s (- 0m 2s) (76800 92%) 0.5815
Epoch 29


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 34s) (25600 30%) 0.5687
0m 24s (- 0m 15s) (51200 61%) 0.5707
0m 33s (- 0m 2s) (76800 92%) 0.5710
Epoch 30


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 34s) (25600 30%) 0.5634
0m 24s (- 0m 15s) (51200 61%) 0.5706
0m 33s (- 0m 2s) (76800 92%) 0.5764
Epoch 31


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 35s) (25600 30%) 0.5659
0m 25s (- 0m 15s) (51200 61%) 0.5684
0m 33s (- 0m 2s) (76800 92%) 0.5687
Epoch 32


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 35s) (25600 30%) 0.5650
0m 25s (- 0m 15s) (51200 61%) 0.5694
0m 35s (- 0m 2s) (76800 92%) 0.5706
Epoch 33


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 16s (- 0m 36s) (25600 30%) 0.5557
0m 25s (- 0m 15s) (51200 61%) 0.5686
0m 33s (- 0m 2s) (76800 92%) 0.5667
Epoch 34


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 16s (- 0m 36s) (25600 30%) 0.5587
0m 25s (- 0m 15s) (51200 61%) 0.5630
0m 34s (- 0m 2s) (76800 92%) 0.5717
Epoch 35


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 16s (- 0m 36s) (25600 30%) 0.5561
0m 25s (- 0m 16s) (51200 61%) 0.5628
0m 35s (- 0m 2s) (76800 92%) 0.5713
Epoch 36


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 16s (- 0m 36s) (25600 30%) 0.5600
0m 25s (- 0m 15s) (51200 61%) 0.5604
0m 34s (- 0m 2s) (76800 92%) 0.5626
Epoch 37


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 16s (- 0m 36s) (25600 30%) 0.5658
0m 25s (- 0m 15s) (51200 61%) 0.5611
0m 34s (- 0m 2s) (76800 92%) 0.5575
Epoch 38


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 16s (- 0m 38s) (25600 30%) 0.5547
0m 26s (- 0m 16s) (51200 61%) 0.5587
0m 35s (- 0m 2s) (76800 92%) 0.5597
Epoch 39


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 35s) (25600 30%) 0.5498
0m 25s (- 0m 15s) (51200 61%) 0.5524
0m 35s (- 0m 2s) (76800 92%) 0.5598
Epoch 40


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 16s (- 0m 37s) (25600 30%) 0.5533
0m 25s (- 0m 16s) (51200 61%) 0.5589
0m 35s (- 0m 2s) (76800 92%) 0.5507
Epoch 41


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))


0m 15s (- 0m 35s) (25600 30%) 0.5531
0m 24s (- 0m 15s) (51200 61%) 0.5538
0m 34s (- 0m 2s) (76800 92%) 0.5613
Epoch 42


HBox(children=(IntProgress(value=0, max=83194), HTML(value='')))




KeyboardInterrupt: 

In [14]:
batch_size = 1

def evaluate(encoder, decoder, sentence, max_length=max_wordlen):
    encoder.eval()
    decoder.eval()
    
    with torch.no_grad():
        input_tensor = wordToTensor(sentence).to(device)
        input_length = input_tensor.size()[0]
        encoder_hidden = encoder.initHidden()

        for ei in range(input_length):
            encoder_output, encoder_hidden = encoder(input_tensor[ei].view(1, 1, -1),
                                                     encoder_hidden)

        sos = torch.zeros(1, 1, n_phones)
        sos[0][0][1] = 1
        sos = sos.to(device)
    
        decoder_input = sos
    
        decoder_hidden = encoder_hidden

        decoded_chars = []

        for di in range(max_length):
            decoder_output, decoder_hidden = decoder(
            decoder_input, decoder_hidden)
            
            topv, topi = decoder_output.data.topk(1)
            
            if topi.item() == 2:
                break
            else:
                decoded_chars.append(rev_phonedict[topi.item() - 3])
            
            decoder_input = torch.zeros(1, 1, n_phones)
            decoder_input[0][0][topi.item()] = 1
            decoder_input = decoder_input.to(device)
            
        print(sentence)
        print("_".join(decoded_chars))
        return "_".join(decoded_chars)

In [15]:
for i in range(40):
    evaluate(encoder, decoder, data[i].split(' ')[0])

LEMIEUX
K_S_EH_L_IH_NG
MINDING
K_S_EH_L_IH_NG
STRIPED
K_S_EH_L_IH_NG
KEN
K_L_AO
CONFERENCE
K_S_EH_L_AH_N_S_T_IY_Z
IMMOLATE
K_S_EH_L_AH_N_T_S
TRANSGRESS
K_S_EH_L_AH_N_S_T_IY_Z
RABBLE
K_S_EH_L_ER_Z
AIRSHARE
K_S_EH_L_AH_N_T_S
INTOLERANCE
K_S_EH_L_AH_N_T_IY_S_AH_N
ILVA
K_B_R_IY
RYGEL
K_S_IY_N_IY
MARLETTE
K_S_EH_L_AH_N_T_S
DILDO
K_S_IY_N_IY
ORELIA
K_S_EH_L_ER_Z
MCNISH
K_S_EH_L_ER_Z
FURBISHED
K_S_EH_L_AH_N_T_IY
COMFED
K_S_EH_L_ER_Z
WALKENHORST
K_S_EH_L_AH_N_T_R_IY_AH_L
MILLIRONS
K_S_EH_L_AH_N_T_IY
JERE
K_B_R_IY
LIVAN'S
K_S_EH_L_IH_NG
PREVIEW
K_S_EH_L_IH_NG
GRAYING
K_S_EH_L_IH_NG
KU
K_AH
FREEHOLD
K_S_EH_L_AH_N_T_S
CONCA
K_S_IY_N_IY
TECK'S
K_S_EH_L_ER_Z
QUINTER
K_S_EH_L_IH_NG
CIRCUMSTANTIAL
K_K_AA_N_T_ER_M_IH_L_Y_AH_N_EH_R
RYDELL
K_S_EH_L_ER_Z
ROTOTILLER
K_S_EH_L_AH_N_S_T_IY_Z
HAVINGTON'S
K_S_IH_M_AH_N_AH_L_AY_Z_D
DECALS
K_S_EH_L_ER_Z
DIBATTISTA
K_S_EH_L_AH_N_T_IY_Z
RAVI'S
K_S_EH_L_ER_Z
INTERCEPTING
K_K_AA_N_T_ER_M_EH_N_T_AH_L
FROMMELT
K_S_EH_L_AH_N_T_S
ACCOMPANIED
K_S_IH_M_AH_N_AH_L_AY_Z_D
WH

In [None]:
import pandas as pd

test = pd.read_csv('test.csv')
test_x = test['Word'].tolist()

In [None]:
test_y = [evaluate(encoder, decoder, word) for word in tqdm(test_x)]
print(test_y)

In [None]:
test.rename(columns={'Word':'Transcription'}, inplace=True) 
test['Transcription'] = test_y

In [None]:
test.to_csv("lstm_one_hot_2_layer_6_epochs.csv",index=False)