In [1]:
import os
import datetime
import torch
import random
import pickle
import numpy as np
import pandas as pd

#os.chdir(os.path.join(os.getcwd(), 'LAS Model'))
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from data import SpeechDataset, AudioDataLoader
from listener import Listener
from attend_and_spell import AttendAndSpell
from seq2seq import Seq2Seq
from utils import  train

### Load Training data

In [4]:
# Used for ai_shell dataset
def make_train_df(dataset_dir):
    data = []
    files = os.listdir(dataset_dir)
    for f in files:
        if '.txt' in f:
            with open(os.path.join(dataset_dir, f), 'r') as text_file:
                data_list = text_file.readlines()
            for example in data_list:
                id_, sent = str(example.split(' ')[0]), str(' '.join(example.split(' ')[1:])) # -1 to remove '\n'
                data.append((id_, sent))

    train_df = pd.DataFrame(data, columns=['id', 'sent'])
    train_df.to_csv(os.path.join(dataset_dir, 'train_df.csv'), header=None)#save


dataset_dir = '../../../Dataset/data_aishell/'
DEVICE = torch.device('cuda:1') if torch.cuda.is_available() else 'cpu'
print('DEVICE :', DEVICE)

train_df = pd.read_csv(os.path.join(dataset_dir, 'train_df.csv'), names=['id', 'sent'])
train_df = train_df.dropna(how='any')
print(train_df.head())

DEVICE : cuda:1
                 id                        sent
0  BAC009S0002W0122     而 对 楼市 成交 抑制 作用 最 大 的 限
1  BAC009S0002W0123             也 成为 地方 政府 的 眼中
2  BAC009S0002W0124  自 六月 底 呼和浩特 市 率先 宣布 取消 限 购
3  BAC009S0002W0125                  各地 政府 便 纷纷
4  BAC009S0002W0126              仅 一 个 多 月 的 时间


### Dataset Analysis and Cleaning

Mandarin data contains relatively small sentences and does not need removal of long sentences.

In [6]:
import torchaudio
from torchaudio.transforms import MelSpectrogram

specgram = MelSpectrogram()
audio, sent = train_df.iloc[2]
waveform, sample_rate = torchaudio.load(os.path.join(dataset_dir, audio+'.wav'))
x = specgram(waveform)

print("sample rate:", sample_rate)
print("x.shape:", x.shape)
print("sent len:", len(sent))

sample rate: 16000
x.shape: torch.Size([1, 128, 433])
sent len: 26


### DataLoaders and hyperparams

In [7]:
def get_chars(save_file, train_df=None):
    try:
        with open(save_file, 'rb') as f:
            chars = pickle.load(f) # load file
    except FileNotFoundError:
        chars = [' ', '<sos>']
        for idx in range(train_df.shape[0]):
            _, sent = train_df.iloc[idx]
            for c in sent:
                if c not in chars:
                    chars.append(c)
        chars = chars + ['<eos>', '<pad>', '<unk>']
        with open(save_file, 'wb') as f:
            pickle.dump(chars, f) # save file
    print('Number of chars', len(chars))
    return chars


save_file = os.path.join(dataset_dir, 'chars')
chars = get_chars(save_file, train_df)
char_to_token = {c:i for i,c in enumerate(chars)} 
token_to_char = {i:c for c,i in char_to_token.items()}
sos_token = char_to_token['<sos>']
eos_token = char_to_token['<eos>']
pad_token = char_to_token['<pad>']


tensorboard_dir = os.path.join('tb_summary')
train_dataset = SpeechDataset(train_df, dataset_dir, sos_token, char_to_token, 
                              eos_token, device=DEVICE, file_extension='.wav')
train_loader = AudioDataLoader(pad_token, train_dataset, batch_size=32, 
                               shuffle=True, drop_last=True)

Number of chars 4256


### Instantiate model

In [4]:
input_size = 128    # num rows in instagram
hidden_dim = 64  # 256*2 nodes in each LSTM
num_layers = 3
dropout = 0.1
layer_norm = False   
encoder = Listener(input_size, hidden_dim, num_layers, dropout=dropout, layer_norm=layer_norm)

hid_sz = 64
embed_dim = 30
vocab_size = len(chars)
decoder = AttendAndSpell(embed_dim, hid_sz, encoder.output_size, vocab_size)

hyperparams = {'input_size':input_size, 'hidden_dim':hidden_dim, 'num_layers':num_layers,
                'dropout':dropout, 'layer_norm':layer_norm, 'hid_sz':hid_sz, 'embed_dim':embed_dim}

criterion = nn.CrossEntropyLoss()
model = Seq2Seq(encoder, decoder, criterion, tf_ratio = 1.0, device=DEVICE).to(DEVICE)

### Training

In [None]:
optimizer = optim.ASGD(model.parameters(), lr=0.2)  # lr = 0.2 used in paper
# optimizer = optim.Adadelta(model.parameters())
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.98)

load = False
if load:
    saved_file = 'Trained Models/Training_2019-12-25 00:09:23.921978/las_model_6'
    model.load_state_dict(torch.load(saved_file))
    start_epoch = int(saved_file[-1]) + 1
    time = os.listdir(tensorboard_dir)[-1]  # use the last one
else:
    start_epoch = 0
    time = str(datetime.datetime.now())

save_dir = os.path.join('trained_models_mandarin', f'Training_{time}')
try:    
    os.mkdir(save_dir);
except FileExistsError:
    pass

summary_dir = os.path.join(tensorboard_dir, time)
writer = SummaryWriter(summary_dir)

# Saving hyperparmas
with open(os.path.join(save_dir, 'info.txt'), 'wb') as f:
    pickle.dump(hyperparams, f)

    
log_interval = 5
print_interval = 40
epochs = 20

for epoch in range(start_epoch, epochs):
    print("\nTeacher forcing ratio:", model.tf_ratio)
    train(model, DEVICE, train_loader, optimizer, epoch, print_interval, writer, log_interval)
    scheduler.step()                                    # Decrease learning rate
    torch.save(model.state_dict(), os.path.join(save_dir, f'las_model_{epoch}'))
    model.tf_ratio = max(model.tf_ratio - 0.05, 0.8)    # Decrease teacher force ratio


Teacher forcing ratio: 1.0
Training, Logging: Mean loss of previous 40 batches 


Teacher forcing ratio: 0.95
Training, Logging: Mean loss of previous 40 batches 



### TEST

In [6]:
def decode_pred_sent(out):
    pred_sent = []
    for t in out:
        lol = t.max(dim=1)[1].item()
        pred_sent.append(token_to_char[lol])
    return ''.join(pred_sent)


def decode_true_sent(y):
    sent = []
    for t in y:
        sent.append(token_to_char[t.item()])
    return ''.join(sent)

In [9]:
DEVICE = torch.device('cpu')

In [13]:
num_sent = 10
model.eval()
model.to(DEVICE)
model.device = DEVICE
model.tf_ratio = 0.9

for _ in range(num_sent):
    
    idx = random.randint(0, train_df.shape[0])
    trial_dataset = SpeechDataset(train_df, dataset_dir, sos_token, char_to_token, eos_token, file_extension='.flac')

    x, y = trial_dataset.__getitem__(idx)
    # plt.imshow(x[0,:,:].detach())

    # Model output
    target = y.unsqueeze(dim=0).to(DEVICE)
    data = x.permute(0, 2, 1).to(DEVICE)
    loss, output = model(data, target)
    print("True sent : ", decode_true_sent(y), end='\n\n')
    print("Pred sent : ", decode_pred_sent(output))
    print("Loss :", loss.item())    
    print("\n")

True sent :  <sos>who was now approaching womanhood he would sometimes talk with her differently from the manner in which he would speak to a mere girl but on her part she seemed not to notice the difference and for their daily amusement either go<eos>

Pred sent :   uholshs sornsnpropsh ng shmpndsun sorshuld shmp hnpl shll shlh hor shgfordds f soom shv sord r sn sholl sorshmld shrrssh nsnd r dsorlpsurosf sor srrshsho shnnud sor shvsor sorshvososfordd krsnd sor shv r sonnh snosh ond sngh<eos>r soo
Loss : 722.2744140625


True sent :  <sos>now sworn to the service of his most christian majesty<eos>

Pred sent :   uor shordssh shj shnoongdshosos sors soooshgnn soruss  
Loss : 178.7652587890625


True sent :  <sos>and a paper cap on his head has the strong conscience and the strong sense the blended susceptibility and self command of our friend adam he was not an average man yet such men as he are reared here and there in every generation of our peasant artisans<eos>

Pred sent :   und sh