In [1]:
import os
import sys
import datetime
import string
import random
import pickle
import numpy as np
import pandas as pd


# Run tag: used as the prefix of the checkpoint / TensorBoard directory name.
# The tag spells out the optimizer and scheduler settings chosen further down
# (Adam with amsgrad=True, ReduceLROnPlateau patience=10, cooldown=5), so it
# should be updated whenever those settings change.
NAME = 'AMSGrad_pat_10_cool_5' # helps to differentiate between various training instances

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchtext.data import Field, BucketIterator, TabularDataset

sys.path.append(os.path.abspath(os.path.join('..')))

from models.las_model.data import SpeechDataset, AudioDataLoader
from models.las_model.listener import Listener
from models.las_model.attend_and_spell import AttendAndSpell
from models.las_model.seq2seq import Seq2Seq
# from models.las_model.utils import  train

In [2]:
# GPU index this experiment prefers when several GPUs are present.
PREFERRED_GPU_INDEX = 1

if torch.cuda.is_available():
    # Fall back to GPU 0 when the preferred index does not exist (single-GPU
    # hosts); otherwise the first `.to(DEVICE)` would fail with an invalid
    # device ordinal.
    gpu_index = PREFERRED_GPU_INDEX if torch.cuda.device_count() > PREFERRED_GPU_INDEX else 0
    DEVICE = torch.device(f'cuda:{gpu_index}')
else:
    # Always a torch.device (never a bare string) so both branches have the same type.
    DEVICE = torch.device('cpu')
print('DEVICE :', DEVICE)

DEVICE : cuda:1


### Preprocessing

In [3]:
# Root folder of the Sinhala corpus; the transcript TSV is read from here and
# the filtered CSV is written back to it.
data_dir = '../../../Dataset/Sinhala'

# A sentence containing any of these characters is discarded entirely:
# ASCII digits, lowercase Latin letters, and a handful of stray symbols
# (curly quotes, zero-width characters, mis-encoded bytes).
remove_chars = [
    *string.digits,
    *string.ascii_lowercase,
    '“', '”', '\u200b', '\u200c', '\u200d', 'µ', '\x94', '»', 'ª', '’', '‘',
]


def preprocess(s):
    """Return `s` with every newline and every ASCII punctuation character removed.

    Only the characters in `string.punctuation` are stripped; non-ASCII
    punctuation (for example curly quotes) is left untouched.
    """
    # A translation table that maps each punctuation character to None deletes it.
    drop_punctuation = str.maketrans(dict.fromkeys(string.punctuation))
    without_newlines = s.replace('\n', '')
    return without_newlines.translate(drop_punctuation)


# Read the main transcript. Every non-blank line of utt_spk_text.tsv is
# unpacked into exactly three tab-separated fields; only the first (utterance
# id) and the third (sentence text) are used.
with open(os.path.join(data_dir, 'utt_spk_text.tsv'), 'r', encoding='utf-8') as f:
    lines = f.readlines()

# Set copy of remove_chars: O(1) membership test per character instead of a
# linear scan of the list for every character of every sentence.
banned_chars = set(remove_chars)

examples = []
for line in lines:
    if not line.strip():
        continue  # tolerate blank lines (e.g. a trailing empty line) instead of crashing on unpack
    id_, _spk, sent = line.split('\t')
    sent = preprocess(sent)
    if not sent:
        # Nothing left after cleaning; an empty string would come back from the CSV as NaN.
        continue
    if any(c in banned_chars for c in sent):
        continue  # removing sentences with eng_chars / digits / stray symbols
    examples.append((id_ + '.flac', sent))

data_df = pd.DataFrame(examples, columns=['path', 'sent'])
# Saved with the default index column; the loading cell selects 'path' and 'sent' by name.
data_df.to_csv(os.path.join(data_dir, 'data_df.csv')) # save
print("Number of Training examples:", data_df.shape[0])
data_df.head(5)

Number of Training examples: 149569


Unnamed: 0,path,sent
0,0000f47c22.flac,මහවැලි ගඟට ගොස් ආපසු එන ගමනේදී
1,000101700f.flac,උන්වහන්සේ කපාපු
2,000107b539.flac,එය එතනින් අවසන් නොවී
3,00016825d3.flac,සිතින් අයහපතෙහි හැසිරීම නිසයි
4,0002205a57.flac,ඊට අවසරයද හිමිවූ බව ඇය කියන්නීය


We remove every sentence that contains unwanted characters (digits, Latin letters and stray symbols). Any remaining character that is not in the vocabulary will be replaced by the unknown token during training.

### Load data

In [4]:
from sklearn.model_selection import train_test_split

# Fixed seed for the train/validation split. The vocabulary built below assigns
# token ids in order of first appearance in train_df, so an unseeded split
# would give a different char -> token mapping on every run and make saved
# checkpoints impossible to decode after a kernel restart.
SPLIT_SEED = 42

data_df = pd.read_csv(os.path.join(data_dir, 'data_df.csv'), usecols=['path', 'sent'])
# 1% of the rows are held out for validation.
train_df, val_df = train_test_split(data_df, test_size=0.01, random_state=SPLIT_SEED)
print("Num training example:", train_df.shape)
print("Num validation example", val_df.shape)
train_df.head()

Num training example: (148073, 2)
Num validation example (1496, 2)


Unnamed: 0,path,sent
30931,354cb92d3a.flac,ඇවිල්ල මහන්සි හින්ද
97697,a7957688f2.flac,ඒවායේ නියම කරන
46228,4fa1969ee0.flac,ඉතින් මචන් ලක්ෂයක් නොවේ
37775,412e85986b.flac,ඔවුන් තුළ ඇතිවන්නේ
105650,b5683e2366.flac,ඊට එරෙහිව දකුණේ සිංහල තරුණ කැරලි මගින් ඇති කරන


### Vocabulary

In [5]:
def get_chars(train_df):
    """Collect the character vocabulary of the training sentences.

    Args:
        train_df: DataFrame whose second column holds the sentence strings
            (the notebook passes the ('path', 'sent') frame).

    Returns:
        A list starting with the four special tokens '<pad>', '<unk>',
        '<sos>', '<eos>', followed by every distinct character of the
        sentences in order of first appearance.
    """
    chars = ['<pad>', '<unk>', '<sos>', '<eos>']
    # Companion set gives O(1) "already seen?" checks; the list keeps the order.
    seen = set(chars)
    # Iterate the sentence column directly instead of one .iloc row lookup per example.
    for sent in train_df.iloc[:, 1]:
        for c in sent:
            if c not in seen:
                seen.add(c)
                chars.append(c)
    return chars
    

# Vocabulary: position in `chars` is the token id.
chars = get_chars(train_df)
token_to_char = dict(enumerate(chars))
char_to_token = {char: token for token, char in token_to_char.items()}

# Ids of the four special tokens.
sos_token = char_to_token['<sos>']
eos_token = char_to_token['<eos>']
pad_token = char_to_token['<pad>']
unk_token = char_to_token['<unk>']

print("Number of characters:", len(chars))
print(chars)

Number of characters: 82
['<pad>', '<unk>', '<sos>', '<eos>', 'ඇ', 'ව', 'ි', 'ල', '්', ' ', 'ම', 'හ', 'න', 'ස', 'ද', 'ඒ', 'ා', 'ය', 'ේ', 'ක', 'ර', 'ඉ', 'ත', 'ච', 'ෂ', 'ො', 'ඔ', 'ු', 'ළ', 'ඊ', 'ට', 'එ', 'ෙ', 'ණ', 'ං', 'ැ', 'ග', 'අ', 'ප', 'ධ', 'ී', 'ජ', 'ශ', 'ඞ', 'බ', 'ඳ', 'ඩ', 'ඕ', 'ෑ', 'ආ', 'ඥ', 'ූ', 'උ', 'ෞ', 'ෝ', 'ඛ', 'ථ', 'ඟ', 'භ', 'ෘ', 'ඹ', '–', 'ඬ', 'ඝ', 'ෆ', 'ඨ', 'ඈ', 'ඡ', 'ඓ', 'ෛ', 'ඌ', 'ඤ', 'ඃ', 'ඖ', 'ඵ', 'ෲ', 'ඣ', 'ඍ', 'ෳ', 'ඪ', 'ෟ', '෴']


### Instantiate model

In [6]:
# Encoder (Listener) settings.
input_size = 128    # feature rows per frame — presumably spectrogram bins; TODO confirm against SpeechDataset
hidden_dim = 640    # hidden units per direction in each LSTM
num_layers = 4
dropout = 0.1
layer_norm = True

# Decoder (AttendAndSpell) settings.
hid_sz = 640
embed_dim = 50
vocab_size = len(chars)

# Snapshot of the settings above; pickled next to the checkpoints later on.
hyperparams = dict(
    input_size=input_size,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
    dropout=dropout,
    layer_norm=layer_norm,
    hid_sz=hid_sz,
    embed_dim=embed_dim,
    vocab_size=vocab_size,
)

encoder = Listener(input_size, hidden_dim, num_layers, dropout=dropout, layer_norm=layer_norm)
decoder = AttendAndSpell(embed_dim, hid_sz, encoder.output_size, vocab_size)

# Training starts with tf_ratio = 1.0; the training loop lowers it later.
model = Seq2Seq(encoder, decoder, tf_ratio=1.0, device=DEVICE)
model = model.to(DEVICE)
model.train()

Seq2Seq(
  (encoder): Listener(
    (layers): ModuleList(
      (0): piBLSTM(
        (lstm): LSTM(128, 640, batch_first=True, bidirectional=True)
        (ln): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (dp): Dropout(p=0.1, inplace=False)
      )
      (1): piBLSTM(
        (lstm): LSTM(2560, 640, batch_first=True, bidirectional=True)
        (ln): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (dp): Dropout(p=0.1, inplace=False)
      )
      (2): piBLSTM(
        (lstm): LSTM(2560, 640, batch_first=True, bidirectional=True)
        (ln): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (dp): Dropout(p=0.1, inplace=False)
      )
      (3): piBLSTM(
        (lstm): LSTM(2560, 640, batch_first=True, bidirectional=True)
        (ln): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (dp): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (decoder): AttendAndSpell(
    (embedding): Embedding(82, 50)
    (attention_layer): At

### Training

In [7]:
# Resuming from a checkpoint is currently disabled; the snippet is kept for reference.
# model.load_state_dict(torch.load(os.path.join(save_dir, 'las_model_1')))
# model.train()

# load = False
# if load:
#     saved_file = 'Trained Models/Training_2019-12-25 00:09:23.921978/las_model_6'
#     model.load_state_dict(torch.load(saved_file))
#     start_epoch = int(saved_file[-1]) + 1
#     time = os.listdir(tensorboard_dir)[-1]  # use the last one 

# One output directory per run: <NAME>_<timestamp> under trained_models/.
time = str(datetime.datetime.now())
save_dir = os.path.join('trained_models', f'{NAME}_{time}')
# makedirs also creates the parent 'trained_models' folder on a fresh checkout
# (plain os.mkdir raised FileNotFoundError there); exist_ok keeps re-runs harmless.
os.makedirs(save_dir, exist_ok=True)

# Saving hyperparams next to the checkpoints
with open(os.path.join(save_dir, 'info.pickle'), 'wb') as f:
    pickle.dump(hyperparams, f)


# Training data pipeline: batches of 64, reshuffled every epoch, last partial batch dropped.
train_dataset = SpeechDataset(train_df, data_dir, char_to_token, n_fft=1024, hop_length=256)
train_loader = AudioDataLoader(pad_token, train_dataset, batch_size=64, 
                               shuffle=True, drop_last=True, num_workers=8)

In [8]:
def train(model, device, train_loader, optimizer, epoch, 
          print_interval, writer=None, log_interval=-1, scheduler=None):
    """Run one training epoch over `train_loader`.

    Args:
        model: module called as `model(data, target)` and returning a
            `(loss, output)` pair; only the loss is used here.
        device: device the batches are moved to before the forward pass.
        train_loader: iterable of `(data, target)` batches; must support
            `len()` and expose `.dataset` (both are used for progress output).
        optimizer: optimizer stepped once per batch.
        epoch: zero-based epoch number, used for printing and for the
            TensorBoard global step.
        print_interval: print the mean loss of the last `print_interval`
            batches every `print_interval` batches, and step `scheduler`
            with that mean. A value <= 0 disables printing and scheduler steps.
        writer: optional object with `add_scalar(tag, value, step)`.
        log_interval: write the mean loss of the last `log_interval` batches
            to `writer` every `log_interval` batches. A value <= 0 (the
            default) disables logging.
        scheduler: optional scheduler whose `step(metric)` receives the
            recent mean training loss.
    """
    model.train()
    print(f'Training, Logging: Mean loss of previous {print_interval} batches \n')

    running_loss = []
    date1 = datetime.datetime.now()

    # Explicit guards: `x % -1 == 0` is always true in Python, so the old
    # default log_interval=-1 logged on every batch (with a NaN on the first
    # one), and an interval of 0 raised ZeroDivisionError.
    printing_enabled = print_interval > 0
    logging_enabled = writer is not None and log_interval > 0

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        loss, _ = model(data, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss.append(loss.item())    # update running loss

        # writing to console after print_interval batches
        if printing_enabled and (batch_idx+1) % print_interval == 0:
            recent_mean = np.mean(running_loss[-print_interval:])
            date2 = datetime.datetime.now()
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tMean Loss : {:.6f}\t lr {}\t time {}:'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), 
                recent_mean, 
                optimizer.param_groups[0]['lr'],  # read directly; no need to build a full state_dict
                date2 - date1))
            date1 = date2
            # The scheduler is stepped once per print_interval batches, not once per epoch.
            if scheduler is not None:
                scheduler.step(recent_mean)

        # Writing to tensorboard
        if logging_enabled and (batch_idx+1) % log_interval == 0:
            global_step = epoch * len(train_loader) + batch_idx
            writer.add_scalar('Loss', np.mean(running_loss[-log_interval:]), global_step)

In [9]:
# optimizer = optim.SGD(model.parameters(), lr=0.2)  # lr = 0.2 used in paper
# optimizer = optim.Adadelta(model.parameters())
optimizer = optim.Adam(model.parameters(), amsgrad=True)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, verbose=True, cooldown=5, min_lr=0.00001)

# train() steps the scheduler once every print_interval batches, so patience
# and cooldown are counted in units of print_interval batches: the loss gets
# roughly print_interval * (patience + cooldown) batches to improve.
log_interval = 5
print_interval = 50

epochs = 20
load = False

writer = SummaryWriter(save_dir)
print('save_dir', save_dir)

try:
    for epoch in range(0, epochs): 
        train(model, DEVICE, train_loader, optimizer, epoch, print_interval, writer, log_interval, scheduler)

        # save model (one checkpoint per epoch)
        torch.save(model.state_dict(), os.path.join(save_dir, f'las_model_{epoch}'))

        # Decrease tf_ratio by 0.5 every 10 epochs; clamped so that raising
        # `epochs` above 20 can never push the ratio below zero.
        if (epoch+1)%10 == 0:
            model.tf_ratio = max(0.0, model.tf_ratio - 0.5)
            print("\nTeacher forcing ratio:", model.tf_ratio)
finally:
    # Flush and close the event file even when training is interrupted
    # (e.g. KeyboardInterrupt), so buffered scalars are not lost.
    writer.close()

save_dir trained_models/AMSGrad_pat_10_cool_5_2019-12-30 22:04:33.443664
Training, Logging: Mean loss of previous 50 batches 

Training, Logging: Mean loss of previous 50 batches 



Training, Logging: Mean loss of previous 50 batches 

Epoch    99: reducing learning rate of group 0 to 1.0000e-04.
Epoch   115: reducing learning rate of group 0 to 1.0000e-05.
Epoch   135: reducing learning rate of group 0 to 1.0000e-06.
Training, Logging: Mean loss of previous 50 batches 

Epoch   161: reducing learning rate of group 0 to 1.0000e-07.


KeyboardInterrupt: 

In [None]:
### DOES DEEPER NETWORK HELP ?
YES

### DOES AMSGRAD HELP ?

### DOES LAYER NORMALIZATION HELP ?
YES, WITH SGD

### TEST

In [10]:
def decode_pred_sent(out):
    """Turn model scores into a string by taking the best-scoring token at each step.

    `out` has its leading dimension squeezed away when that dimension is 1;
    each remaining row is reduced to the index of its maximum, and that index
    is looked up in `token_to_char`. Special tokens are rendered verbatim.
    """
    steps = out.squeeze(0)
    return ''.join(token_to_char[step.argmax(dim=0).item()] for step in steps)


def decode_true_sent(y):
    """Map each token id in `y` through `token_to_char` and concatenate the results."""
    return ''.join(token_to_char[token.item()] for token in y)

In [11]:
num_sent = 10
model.eval()

# Built once instead of on every iteration, and with the same n_fft /
# hop_length that train_dataset was created with so the features match.
# NOTE: the samples are drawn from train_df, so this is a sanity check on
# training data rather than a held-out evaluation.
trial_dataset = SpeechDataset(train_df, data_dir, char_to_token, n_fft=1024, hop_length=256)

# No gradients are needed for inspection.
with torch.no_grad():
    for _ in range(num_sent):

        # randrange excludes the upper bound; randint(0, n) could return n
        # and index one past the last row.
        idx = random.randrange(train_df.shape[0])

        x, y = trial_dataset[idx]
        # plt.imshow(x[0,:,:].detach())

        # Model output
        print(y.shape)

        target = y.unsqueeze(dim=0).to(DEVICE)   # add a batch dimension of 1
        data = x.permute(0, 2, 1).to(DEVICE)     # swap the last two axes before feeding the model
        loss, output = model(data, target)
        print(output.shape)
        print("True sent : ", decode_true_sent(y), end='\n\n')
        print("Pred sent : ", decode_pred_sent(output))
        print("Loss :", loss.item())    
        print("\n")

torch.Size([14])
torch.Size([1, 14, 82])
True sent :  දුකක් දැනෙනවා<eos>

Pred sent :  මා නනනකනනනනනනන
Loss : 4.253574848175049


torch.Size([9])
torch.Size([1, 9, 82])
True sent :  මේ ශාසනය<eos>

Pred sent :  මා කනනනනන
Loss : 4.093790054321289


torch.Size([35])
torch.Size([1, 35, 82])
True sent :  හම්බවෙන පිළිවෙලටනේ සයිට් එකේ යන්නේ<eos>

Pred sent :  මා  නනනනකනනනනනනන න කනනන නකකන කනනනන 
Loss : 4.266133785247803


torch.Size([45])
torch.Size([1, 45, 82])
True sent :  මනා සංවරයෙන් යුතු භික්ෂුන් වහන්සේ නමක් ලෙසයි<eos>

Pred sent :  මා නකනකනනනනනනකනනනනකනනනනනනනනකනනනනන කනනන කනනනනන
Loss : 4.285437107086182


torch.Size([30])
torch.Size([1, 30, 82])
True sent :  කෙනෙකු විසින් මරාදමා ඇති බවයි<eos>

Pred sent :  මානනනනනකනනනනනනකනනනන නකනනනකනනනන
Loss : 4.340115070343018


torch.Size([25])
torch.Size([1, 25, 82])
True sent :  කොපි කරන එකේ අපි එයාගෙන්<eos>

Pred sent :  මාකනනකනනනකකන කනනනකකනනනනනන
Loss : 4.238231658935547


torch.Size([28])
torch.Size([1, 28, 82])
True sent :  නමින් හැඳින්වීම අර්ථාන්විත

### Trying with Torchtext

In [None]:
## Knowing the frequency of words

def process(s):
    """Return the items of `s` as a new list (a string yields its individual characters)."""
    return [*s]

# Text field for the Sinhala sentences: <sos>/<eos> markers are added and
# batches come out batch-first.
# NOTE(review): in legacy torchtext `preprocessing` appears to run after
# tokenization, in which case `process` receives a token list rather than the
# raw string — confirm this is the intended granularity.
si_field = Field(
    preprocessing=process,
    init_token='<sos>', 
    eos_token='<eos>',
    batch_first=True,
    lower=True, 
    tokenizer_language='si',
)

# Three CSV columns are declared; the first two are paired with None
# (presumably skipped by torchtext), so only 'sent' is bound to a field.
dataset = TabularDataset(
    fields=[('index', None),('unnamed', None), ('sent', si_field)],
    format='CSV',
    path=os.path.join(data_dir, 'temp.csv'),
)

In [None]:
# Build the field's vocabulary from `dataset`. min_freq=2 is passed, which
# presumably leaves out tokens seen only once — confirm against the torchtext docs.
si_field.build_vocab(dataset, min_freq=2)
# Number of entries in the string -> index mapping.
print(len(si_field.vocab.stoi))

### Hacking the optimizer

In [None]:
# Rebind `optimizer` to plain SGD (lr=0.03) over the same model parameters.
# NOTE: the `scheduler` created earlier was constructed with the previous
# optimizer object and is not re-created here, so it does not manage this one.
optimizer = optim.SGD(model.parameters(), lr=0.03)

In [None]:
# Learning rate of the optimizer's first parameter group.
optimizer.param_groups[0]['lr']