# Set up

In [6]:
import numpy as np
import torch
from torch.utils.data import Dataset,DataLoader
from torch.utils import data
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from torch.nn.utils.rnn import *
import gtn
import soundfile as sf
from IPython.display import display, Image

In [7]:
etc_path = "/home/ubuntu/project/an4/etc/"
wav_path = "/home/ubuntu/project/an4/wav/"

In [8]:
train_fileids_path = etc_path + "an4_train.fileids"
train_trans_path = etc_path + "an4_train.transcription"

test_fileids_path = etc_path + "an4_test.fileids"
test_trans_path = etc_path + "an4_test.transcription"

phones_path = etc_path + "an4.phone"
filler_path = etc_path + "an4.filler"
dict_path = etc_path + "an4.dic"

In [9]:
with open(phones_path, 'r') as f:
    phones = f.read().splitlines()

print(phones)
print(len(phones))
phones_plus_one = ['#'] + phones
label_to_phone = dict(enumerate(phones_plus_one))

['AA', 'AE', 'AH', 'AO', 'AW', 'AY', 'B', 'CH', 'D', 'EH', 'ER', 'EY', 'F', 'G', 'HH', 'IH', 'IY', 'JH', 'K', 'L', 'M', 'N', 'OW', 'P', 'R', 'S', 'SIL', 'T', 'TH', 'UW', 'V', 'W', 'Y', 'Z']
34


In [10]:
num_phones = len(phones_plus_one)
print(num_phones)

35


In [11]:
word_to_phones = {}
phones_to_words = {}

for d in [dict_path, filler_path]:
    with open(d, 'r') as f:
        mappings = f.read().splitlines()
        for pair in mappings:
            items = pair.split()
            word = items[0]
            transc = tuple(phones.index(i)+1 for i in items[1:])

            if word.endswith(")"):
                word = word[:-3]
                word_to_phones[word].append(transc)
            else:
                word_to_phones[word] = [transc]

            if transc not in phones_to_words:
                phones_to_words[transc] = []
            phones_to_words[transc].append(word)

In [12]:
# word_graphs = {}
def make_word_graph(word, calc_grad=True, blank=0):
    pron = word_to_phones[word]
    G = None
    for pidx in range(len(pron)):
        pronunciation = pron[pidx]
        g = gtn.Graph(calc_grad)
        u = len(pronunciation) * 2 + 1
        for l in range(u):
            idx = (l-1) // 2
            is_start = (l == 0)
            is_accept = (l == u-1) or (l == u-2)
            g.add_node(is_start, is_accept)
            label = pronunciation[idx] if l % 2 else blank
            g.add_arc(l, l, label, label)
            if l > 0:
                g.add_arc(l-1, l, label, label)
            if (l % 2) and (l > 1) and label != pronunciation[idx-1]: # not a repetition
                g.add_arc(l-2, l, label, label)
        
        if pidx == 0:
            G = g
        else:
            G = gtn.union((G, g))
    return G


In [16]:
with open(test_trans_path, 'r') as f:
    transc = f.read().splitlines()
    for line in transc:
        print(line.split()[:-1])

['RUBOUT', 'G', 'M', 'E', 'F', 'THREE', 'NINE']
['ERASE', 'C', 'Q', 'Q', 'F', 'SEVEN']
['B', 'A', 'O', 'Z', 'FIVE', 'THREE']
['GO']
['RUBOUT', 'N', 'I', 'M', 'N', 'ONE']
['W', 'O', 'O', 'D']
['C', 'I', 'N', 'D', 'Y']
['ONE', 'THREE', 'SEVEN']
['M', 'E', 'L', 'V', 'I', 'N']
['P', 'L', 'E', 'A', 'S', 'A', 'N', 'T', 'H', 'I', 'L', 'L', 'S']
['ONE', 'FIVE', 'TWO', 'THREE', 'SIX']
['SIX', 'FIVE', 'FIVE', 'EIGHT', 'SEVEN', 'FOUR', 'ZERO']
['ELEVEN', 'TWENTY', 'SEVEN', 'FIFTY', 'SEVEN']
['NO']
['ENTER', 'NINE', 'ONE', 'SIX', 'NINE']
['ENTER', 'EIGHT', 'NINETY', 'SEVEN']
['J', 'P', 'E', 'G', 'FOUR']
['ERASE', 'X', 'A', 'G', 'N', 'A', 'SIX', 'THIRTY', 'FIVE']
['P', 'A', 'T', 'T', 'E', 'R', 'S', 'O', 'N']
['J', 'A', 'N', 'E', 'T']
['ONE', 'FIFTY']
['S', 'P', 'E', 'E', 'R']
['M', 'C', 'K', 'E', 'E', 'S', 'R', 'O', 'C', 'K', 'S']
['ONE', 'FIVE', 'ONE', 'THREE', 'SIX']
['THREE', 'THREE', 'ONE', 'OH', 'ONE', 'EIGHT', 'EIGHT']
['TWELVE', 'TWENTY', 'NINE', 'FIFTY', 'NINE']
['ENTER', 'NINE', 'TWO', 'EI

In [17]:
def make_sent_graph(word_list, calc_grad=True, blank=0):
    interleaved_list = ['<s>']
    for word in word_list:
        interleaved_list.append(word)
        interleaved_list.append('<sil>')
    interleaved_list = interleaved_list[:-1]
    interleaved_list.append('</s>')
    graphs = [make_word_graph(wrd, calc_grad, blank) for wrd in interleaved_list]
    return gtn.concat(graphs)

# Dataset

In [18]:
kernel_size = 100
stride = 50
padding = 0

def get_conv_out_size(in_len):
    return ((in_len - kernel_size) // stride) + 1

In [19]:
class MyDataset(Dataset):
    def __init__(self, fileids_path, trans_path, end_marked=False):
        super().__init__()
        
        with open(fileids_path, 'r') as f:
            self.audio_files = f.read().splitlines()
        with open(trans_path, 'r') as f:
            transc = f.read().splitlines()   
            self.trans_graphs = []
            for line in transc:
                # make the corresponding graph
                transcription = line.split()[:-1]
                if end_marked:
                    transcription = transcription[1:-1]
                self.trans_graphs.append(make_sent_graph(transcription))
        self.len = len(self.audio_files)
        assert(self.len == len(self.trans_graphs))
        
    def __len__(self):
        return self.len

    def __getitem__(self, ind):
        full_audio_path = wav_path + self.audio_files[ind] + ".sph"
        audio_tensor = torch.FloatTensor(sf.read(full_audio_path)[0]).unsqueeze(1)
        audio_tensor_len = get_conv_out_size(audio_tensor.shape[0])
#         print(audio_tensor.shape, audio_tensor_len)
        return audio_tensor, audio_tensor_len, ind


In [20]:
train_dataset = MyDataset(train_fileids_path, train_trans_path, True)

In [22]:
cuda = torch.cuda.is_available()
print(cuda)
num_workers = 8 if cuda else 0

True


In [23]:
class MyBatch:
    def __init__(self, data):
        zipped = list(zip(*data))
#         print([i.shape for i in zipped[0]])
#         print([a.shape for a in zipped[0]])
        self.X = pad_sequence(zipped[0])
#         print(self.X.shape, [a.shape for a in zipped[0]])
        self.X_len = torch.LongTensor(zipped[1])
#         self.Y = pad_sequence(zipped[2], batch_first=True)
#         self.Y_len = torch.LongTensor(zipped[3])
        self.Y = torch.LongTensor(zipped[2])

    def pin_memory(self):
        self.X = self.X.pin_memory()
        self.X_len = self.X_len.pin_memory()
        self.Y = self.Y.pin_memory()
        return self

def collate_wrapper(batch):
    return MyBatch(batch)

In [36]:
batch_size = 20

In [37]:
train_loader_args = dict(shuffle=True, batch_size=batch_size,
                            num_workers=num_workers, pin_memory=False, drop_last=True, collate_fn=collate_wrapper) if cuda\
                    else dict(shuffle=True, batch_size=batch_size, collate_fn=collate_wrapper)
train_loader = data.DataLoader(train_dataset, **train_loader_args)

In [38]:
test_dataset = MyDataset(test_fileids_path, test_trans_path, True)
test_loader_args = dict(shuffle=False, batch_size=batch_size,
                            num_workers=num_workers, pin_memory=False, drop_last=True, collate_fn=collate_wrapper) if cuda\
                    else dict(shuffle=False, batch_size=batch_size, collate_fn=collate_wrapper)
test_loader = data.DataLoader(test_dataset, **test_loader_args)

# Model

In [112]:
class Model(nn.Module):
    def __init__(self, in_vocab, out_vocab, embed_size, hidden_size):
        super(Model, self).__init__()
        self.cnn = nn.Conv1d(1, 5, kernel_size, stride=stride, padding=padding)
        self.bn = nn.BatchNorm1d(5)
        self.relu = F.relu

        self.lstm = nn.LSTM(5, hidden_size, num_layers=2, bidirectional=True, batch_first=False)
        self.output = nn.Linear(hidden_size*2, out_vocab)
        self.lsm = nn.LogSoftmax(2)
    
    def forward(self, X, lengths):
        X_t = X.permute(1,2,0)
        cnn_xt = self.cnn(X_t)
        cnn_xt = self.bn(cnn_xt)
        cnn_xt = self.relu(cnn_xt)
        cnn_x = cnn_xt.permute(2,0,1)
        
        packed_X = pack_padded_sequence(cnn_x, lengths, enforce_sorted=False)
        packed_out = self.lstm(packed_X)[0]
        out, out_lens = pad_packed_sequence(packed_out)
        out = self.output(out)
        out = self.lsm(out)
        return out, out_lens

In [115]:
class GTNLossFunction(torch.autograd.Function):
        
    @staticmethod
    def forward(ctx, inputs, input_lengths, targets): #targets is index in trans_graph
        _, B, _ = inputs.shape
        losses = [None] * B
        emissions_graphs = [None] * B
    
        def forward_single(b):
            T = input_lengths[b]
            weights = torch.clamp(inputs[:T, b, :], min=-1e-20).flatten().contiguous()
            flat_weights = weights.data_ptr()
            
            ind = targets[b]

            emit = gtn.linear_graph(T, num_phones, inputs.requires_grad)
            emit.set_weights(flat_weights) # set in the same order as the phoneme order

            constraint = train_dataset.trans_graphs[ind]

            alignment = gtn.compose(constraint, emit)
            ctc_loss = gtn.subtract(gtn.forward_score(emit), gtn.forward_score(alignment))

            losses[b] = ctc_loss
            emissions_graphs[b] = emit

        gtn.parallel_for(forward_single, range(B))

        ctx.auxiliary_data = (losses, emissions_graphs, inputs.shape, input_lengths, targets)
        
        retval = torch.tensor([l.item() for l in losses])
        return retval

    @staticmethod
    def backward(ctx, grad_output, retain_graph=True):
        losses, emissions_graphs, in_shape, input_lengths, targets = ctx.auxiliary_data
        T, B, C = in_shape
        input_grad = torch.empty((T, B, C))

        def backward_single(b):
            gtn.backward(losses[b], retain_graph=retain_graph)
            emissions = emissions_graphs[b]
            grad = emissions.grad().weights_to_numpy()
            input_grad[:int(input_lengths[b].detach()),b,:] = torch.from_numpy(grad).view(
                                                                int(input_lengths[b].detach()), C)
            train_dataset.trans_graphs[targets[b]].zero_grad()


        gtn.parallel_for(backward_single, range(B))
        return input_grad, None, None

# make an alias for the loss function:
GTNLoss = GTNLossFunction.apply

# Training Utilities

In [116]:
import time

def train_epoch(model, train_loader, criterion, optimizer):

    model.train()
    
    running_loss = 0.0
    start_time = time.time()
    for batch_ids, batch in enumerate(train_loader):
        optimizer.zero_grad()
        
        x = batch.X.to(device)
        x_len = batch.X_len.to(device)
        y = batch.Y
        out, out_lens = model(x, x_len)
        out = out.to("cpu")
        out_lens = out_lens.to("cpu")
        loss = criterion(out, out_lens, y).mean()
        print("batch:", batch_ids, "loss", loss.item())
        running_loss += loss.item()
        loss.backward(retain_graph=False)

        optimizer.step()
        
    end_time = time.time()
    
    running_loss /= len(train_loader)
    
    print("Training Loss: ", running_loss, "Time: ", end_time - start_time, "s")
    return running_loss

In [117]:
def test_model(model, test_loader, criterion):
    with torch.no_grad():
        model.eval()
        
        running_loss = 0.0
        
        for batch_idx, batch in enumerate(test_loader):
            X = batch.X.to(device)
            X_len = batch.X_len.to(device)
            Y = batch.Y
            
            out, out_lens = model(X, X_len)
            out = out.to("cpu")
            out_lens = out_lens.to("cpu")

            loss = criterion(out, out_lens, Y).mean()
            running_loss += loss.item()
            
        running_loss /= len(test_loader)
        print("Testing Loss: ", running_loss)
        return running_loss

In [123]:
torch.manual_seed(32)
model = Model(40, num_phones, 5, 50)
criterion = GTNLoss
optimizer = torch.optim.SGD(model.parameters(), lr=1e-1
                           )
torch.autograd.set_detect_anomaly(True)

device = torch.device("cuda" if cuda else "cpu")
model.to(device)
print(model)

Model(
  (cnn): Conv1d(1, 5, kernel_size=(100,), stride=(50,))
  (bn): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (lstm): LSTM(5, 50, num_layers=2, bidirectional=True)
  (output): Linear(in_features=100, out_features=35, bias=True)
  (lsm): LogSoftmax()
)


In [124]:
num_epochs = 5
training_losses = []
testing_losses = []

for i in range(num_epochs):
    train_loss = train_epoch(model, train_loader, criterion, optimizer)
    torch.save({'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
               }, "Model_"+str(i))
    
    test_loss = test_model(model, test_loader, criterion)
    training_losses.append(train_loss)
    testing_losses.append(test_loss)
#     for name, param in model.named_parameters():
#         if param.requires_grad:
#             print(name, param.data)
    print("=====End of Epoch", i, "=====")

batch: 0 loss 2855.896728515625
batch: 1 loss 3096.24658203125
batch: 2 loss 3302.682373046875
batch: 3 loss 2775.615234375
batch: 4 loss 3078.640869140625
batch: 5 loss 2957.93115234375
batch: 6 loss 2454.426513671875
batch: 7 loss 2680.68994140625
batch: 8 loss 2464.47607421875
batch: 9 loss 3280.743408203125
batch: 10 loss 2879.221923828125
batch: 11 loss 3005.934814453125
batch: 12 loss 2687.364501953125
batch: 13 loss 2483.84716796875
batch: 14 loss 2571.17919921875
batch: 15 loss 2900.325439453125
batch: 16 loss 3227.47509765625
batch: 17 loss 2280.962158203125
batch: 18 loss 3034.94921875
batch: 19 loss 2818.93408203125
batch: 20 loss 2659.01904296875
batch: 21 loss 2890.999755859375
batch: 22 loss 2499.42138671875
batch: 23 loss 2743.695068359375
batch: 24 loss 3134.795654296875
batch: 25 loss 3098.53076171875
batch: 26 loss 3130.005859375
batch: 27 loss 2535.150146484375
batch: 28 loss 2206.78564453125
batch: 29 loss 3119.938232421875
batch: 30 loss 2510.900634765625
batch: 31

# Decoding

In [88]:
def decode_model(model, test_loader):
    with torch.no_grad():
        model.eval()
        
        running_loss = 0.0
        
        for batch_idx, batch in enumerate(test_loader):
            X = batch.X.to(device)
            X_len = batch.X_len.to(device)
            Y = batch.Y
            
            out, out_lens = model(X, X_len)
            out = out.to("cpu")
            out_lens = out_lens.to("cpu")
            
            for b in range(batch_size):
                print(b)
                T = out_lens[b]
                weights = out[:T, b, :].flatten().contiguous()
                flat_weights = weights.data_ptr()

                ind = Y[b]

                emit = gtn.linear_graph(T, num_phones, out.requires_grad)
                emit.set_weights(flat_weights)
                prediction = gtn.viterbi_path(emit).labels_to_list(False)
                print(prediction)
                break  
            break
        

In [89]:
decode_model(model, test_loader)

0
[6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6