In [1]:
import torch.nn as nn
import torch
from torch.autograd import Variable
# Make GPU index 2 the current CUDA device for the whole notebook.
# NOTE(review): the index is machine-specific — presumably this fails on hosts
# exposing fewer than three GPUs; confirm before running elsewhere.
torch.cuda.set_device(2)
# Seed torch's CPU generator and every CUDA generator with the same value.
# NOTE(review): numpy's and Python's `random` generators are used further down
# (teacher-forcing draw, dataset shuffle) and are not seeded here.
torch.manual_seed(7)
torch.cuda.manual_seed_all(7)

class BLSTM(nn.Module):
    """One pyramidal LSTM layer that halves the time resolution of its input.

    Every two consecutive time steps are concatenated along the feature axis
    before being fed to the LSTM, so an input of shape
    ``(batch, seq_len, input_dim)`` is processed as
    ``(batch, ceil(seq_len / 2), 2 * input_dim)``.

    Args:
        input_dim: feature width of one incoming time step.
        hidden_dim: hidden size of the LSTM (per direction).
        dropout: dropout passed to ``nn.LSTM``.
        n_layers: number of stacked LSTM layers.
        bidirectional: whether the LSTM runs in both directions.
    """

    def __init__(self, input_dim, hidden_dim, dropout=0.0, n_layers=1, bidirectional=True):
        super(BLSTM, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.bidirectional = bidirectional

        # Two frames are stacked per step, hence twice the input width.
        self.rnn = nn.LSTM(
            input_size=self.input_dim * 2,
            hidden_size=self.hidden_dim,
            num_layers=self.n_layers,
            bidirectional=self.bidirectional,
            dropout=dropout,
            batch_first=True
        )

    def forward(self, inputs):
        """Run the layer.

        Args:
            inputs: either a tensor ``(batch, seq_len, input_dim)`` or a tuple
                ``(tensor, hc)`` where ``hc`` is the initial ``(h, c)`` state
                handed to the LSTM. The tuple form is what lets several BLSTM
                layers be chained inside ``nn.Sequential``.

        Returns:
            ``(output, hc)``: the LSTM output and its final state.
        """
        # A bare tensor means "no initial state"; nn.LSTM then starts from zeros.
        hc = None
        if isinstance(inputs, tuple):
            inputs, hc = inputs

        batch_size = inputs.size(0)
        seq_len = inputs.size(1)
        input_size = inputs.size(2)

        if seq_len % 2:
            # Odd length: append one zero frame so frames can be paired up.
            # new_zeros keeps the padding on the same device/dtype as the input.
            padding = inputs.new_zeros((batch_size, 1, input_size))
            inputs = torch.cat([inputs, padding], dim=1)
            seq_len += 1

        # Fold each pair of neighbouring frames into a single, wider frame.
        inputs = inputs.contiguous().view(batch_size, seq_len // 2, input_size * 2)

        output, hc = self.rnn(inputs, hc)
        return (output, hc)


class Listener(nn.Module):
    """Encoder built from three stacked ``BLSTM`` layers.

    Each ``BLSTM`` halves the sequence length, so the output is eight times
    shorter than the input along the time axis. The layers are chained through
    ``nn.Sequential`` by passing ``(tensor, hc)`` tuples, which means the final
    state of one layer is used as the initial state of the next.

    Args:
        input_dim: feature width of one input frame.
        hidden_dim: hidden size of every LSTM (per direction).
        dropout: dropout passed to every ``BLSTM``.
        n_layers: number of LSTM layers inside every ``BLSTM``.
        bidirectional: whether the LSTMs are bidirectional.
    """

    def __init__(self, input_dim, hidden_dim, dropout=0.0, n_layers=1, bidirectional=True):
        super(Listener, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        # Honour the caller's choice (it used to be forced to True).
        self.bidirectional = bidirectional

        # Width of what one BLSTM emits and the next one consumes.
        stacked_input_dim = self.hidden_dim * self._num_directions()

        self.pblstm = nn.Sequential(
            BLSTM(
                input_dim=self.input_dim,
                hidden_dim=self.hidden_dim,
                dropout=dropout,
                n_layers=n_layers,
                bidirectional=self.bidirectional
            ),
            BLSTM(
                input_dim=stacked_input_dim,
                hidden_dim=self.hidden_dim,
                dropout=dropout,
                n_layers=n_layers,
                bidirectional=self.bidirectional
            ),
            BLSTM(
                input_dim=stacked_input_dim,
                hidden_dim=self.hidden_dim,
                dropout=dropout,
                n_layers=n_layers,
                bidirectional=self.bidirectional
            ))

    def _num_directions(self):
        """Return 2 for a bidirectional encoder, otherwise 1."""
        return 2 if self.bidirectional else 1

    def init_hidden(self, batch_size, device=None):
        """Create an all-zero initial ``(hidden, cell)`` state.

        Args:
            batch_size: number of sequences in the batch.
            device: where to allocate the state; defaults to the device that
                holds this module's parameters.

        Returns:
            Two tensors of shape
            ``(n_layers * num_directions, batch_size, hidden_dim)``.
        """
        if device is None:
            device = next(self.parameters()).device
        shape = (self.n_layers * self._num_directions(), batch_size, self.hidden_dim)
        hidden = torch.zeros(shape, device=device)
        cell = torch.zeros(shape, device=device)
        return (hidden, cell)

    def forward(self, inputs):
        """Encode ``inputs`` (``(batch, seq_len, input_dim)``) into listener features."""
        hc = self.init_hidden(inputs.size(0), device=inputs.device)
        listener_features, _ = self.pblstm((inputs, hc))

        return listener_features

In [2]:
import numpy as np

class ScaledDotProductAttention(nn.Module):
    """Dot-product attention scaled by ``sqrt(dim)``, with values doubling as keys."""

    def __init__(self, dim):
        super(ScaledDotProductAttention, self).__init__()
        # Only used as the scaling constant in forward().
        self.dim = dim

    def forward(self, query, value):
        """Attend from ``query`` over ``value``.

        Both arguments are batched 3-D tensors suitable for ``torch.bmm``.

        Returns:
            ``(context, attn)``: the attention-weighted sum of ``value`` and
            the softmax-normalised weights.
        """
        keys = value.transpose(1, 2)
        scale = np.sqrt(self.dim)
        energy = torch.bmm(query, keys) / scale
        # Normalise over the positions of `value`.
        weights = F.softmax(energy, dim=-1)
        return torch.bmm(weights, value), weights


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with a residual output projection.

    ``query`` and ``value`` are projected, split into ``num_heads`` heads of
    width ``hidden_dim // num_heads``, attended head by head, merged again,
    concatenated with the un-projected query and squashed back to
    ``hidden_dim`` through a linear layer followed by ``tanh``.

    Args:
        hidden_dim: feature width of query, value and output.
        num_heads: number of attention heads; must divide ``hidden_dim``.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``hidden_dim`` (forward() would otherwise fail later with a shape
            mismatch in the output projection).
    """

    def __init__(self, hidden_dim, num_heads=8):
        super(MultiHeadAttention, self).__init__()
        if num_heads <= 0 or hidden_dim % num_heads != 0:
            raise ValueError(
                "hidden_dim ({}) must be divisible by a positive num_heads ({})".format(
                    hidden_dim, num_heads))
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.dim = hidden_dim // num_heads
        self.scaled_dot = ScaledDotProductAttention(self.dim)
        self.query_projection = nn.Linear(hidden_dim, self.dim * num_heads)
        self.value_projection = nn.Linear(hidden_dim, self.dim * num_heads)
        # Consumes [merged heads ; residual query], hence 2 * hidden_dim inputs.
        self.out_projection = nn.Linear(hidden_dim << 1, hidden_dim, bias=True)

    def forward(self, query, value, prev_attn=None):
        """Attend from ``query`` over ``value``.

        Args:
            query: tensor whose last dimension is ``hidden_dim``.
            value: tensor whose last dimension is ``hidden_dim``; its first
                dimension is taken as the batch size.
            prev_attn: accepted for interface compatibility; not used.

        Returns:
            ``(output, context)``: the projected output of width ``hidden_dim``
            and the merged per-head context of width ``num_heads * dim``.
        """
        batch_size = value.size(0)
        residual = query

        # (batch, len, hidden) -> (batch, len, heads, dim)
        query = self.query_projection(query).view(batch_size, -1, self.num_heads, self.dim)
        value = self.value_projection(value).view(batch_size, -1, self.num_heads, self.dim)

        # Move heads in front and fold them into the batch axis for bmm.
        query = query.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.dim)
        value = value.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.dim)

        context, _ = self.scaled_dot(query, value)
        context = context.view(self.num_heads, batch_size, -1, self.dim)

        # Undo the head folding: (heads, batch, len, dim) -> (batch, len, heads * dim)
        context = context.permute(1, 2, 0, 3).contiguous().view(batch_size, -1, self.num_heads * self.dim)
        combined = torch.cat([context, residual], dim=2)

        output = torch.tanh(self.out_projection(combined.view(-1, self.hidden_dim << 1))).view(batch_size, -1, self.hidden_dim)
        return output, context

In [3]:
class Speller(nn.Module):
    """Attention-based character decoder.

    At every step the previous character is embedded and fed to a
    unidirectional LSTM; the LSTM output attends over the listener features
    through ``MultiHeadAttention`` and the result is projected to
    log-probabilities over ``num_classes``.

    Args:
        num_classes: size of the output vocabulary.
        hidden_dim: width of embeddings, LSTM state and attention.
        max_step: number of decoding steps when no ground truth is given.
        sos_token: id fed as the first decoder input.
        eos_token: id of the end-of-sequence symbol (stored; decoding does not
            stop early on it).
        dropout: dropout passed to the LSTM.
        n_layers: number of LSTM layers.
        num_heads: number of attention heads.
    """

    def __init__(self, num_classes, hidden_dim, max_step=300, sos_token=1, eos_token=2, dropout=0.0, n_layers=2, num_heads=4):
        super(Speller, self).__init__()
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.num_heads = num_heads

        self.max_step = max_step
        self.eos_token = eos_token
        self.sos_token = sos_token

        self.emb = nn.Embedding(self.num_classes, self.hidden_dim)
        self.rnn = nn.LSTM(
            input_size=self.hidden_dim,
            hidden_size=self.hidden_dim,
            num_layers=self.n_layers,
            bidirectional=False,
            dropout=dropout,
            batch_first=True
        )
        self.init_rnn_weights()
        self.attention = MultiHeadAttention(self.hidden_dim, self.num_heads)
        self.character_distribution = nn.Linear(self.hidden_dim, num_classes)
        self.softmax = nn.LogSoftmax(dim=-1)

    def init_rnn_weights(self, low=-0.1, high=0.1):
        """Initialise LSTM weights uniformly in ``[low, high]`` and zero its biases."""
        for name, param in self.rnn.named_parameters():
            if 'weight_ih' in name or 'weight_hh' in name:
                torch.nn.init.uniform_(param.data, a=low, b=high)
            elif 'bias' in name:
                param.data.fill_(0)

    def init_hidden(self, batch_size, device=None):
        """Create an all-zero initial ``(hidden, cell)`` state.

        Args:
            batch_size: number of sequences in the batch.
            device: where to allocate the state; defaults to the device that
                holds this module's parameters.

        Returns:
            Two tensors of shape ``(n_layers, batch_size, hidden_dim)``.
        """
        if device is None:
            device = next(self.parameters()).device
        shape = (self.n_layers, batch_size, self.hidden_dim)
        hidden = torch.zeros(shape, device=device)
        cell = torch.zeros(shape, device=device)
        return (hidden, cell)

    def forward_step(self, inputs, hc, listener_features):
        """Decode a single step.

        Args:
            inputs: embedded previous characters, ``(batch, 1, hidden_dim)``.
            hc: current LSTM ``(hidden, cell)`` state.
            listener_features: encoder output attended over.

        Returns:
            ``(logit, hc, context)``: log-probabilities for this step, the
            updated LSTM state and the attention context.
        """
        decoder_output, hc = self.rnn(inputs, hc)
        att_out, context = self.attention(decoder_output, listener_features)
        logit = self.softmax(self.character_distribution(att_out))

        return logit, hc, context

    def forward(self, listener_features, ground_truth=None, teacher_forcing_rate=0.9, use_beam=False, beam_size=3):
        """Decode a batch of utterances.

        Args:
            listener_features: encoder output, batch first.
            ground_truth: optional ``(batch, length)`` id tensor starting with
                the sos token; fixes the number of steps to ``length - 1`` and
                enables teacher forcing.
            teacher_forcing_rate: probability that this whole call uses
                teacher forcing (forced to 0 without ground truth).
            use_beam: use the multi-hypothesis decoding branch instead of
                greedy decoding.
            beam_size: number of hypotheses in that branch.

        Returns:
            Without ``use_beam``: log-probabilities of shape
            ``(batch, steps, num_classes)``.
            With ``use_beam``: predicted ids of shape ``(batch, steps)``.
        """
        if ground_truth is None:
            teacher_forcing_rate = 0
        # One draw per call: the whole batch is either teacher-forced or not.
        teacher_forcing = bool(np.random.random_sample() < teacher_forcing_rate)

        if ground_truth is None:
            max_step = self.max_step
        else:
            max_step = ground_truth.size(1) - 1

        # Allocate on the features' device instead of assuming CUDA.
        device = listener_features.device
        batch_size = listener_features.size(0)

        input_word = torch.full((batch_size, 1), self.sos_token, dtype=torch.long, device=device)
        inputs = self.emb(input_word)
        hc = self.init_hidden(batch_size, device=device)
        logits = []

        if not use_beam:
            for step in range(max_step):
                logit, hc, context = self.forward_step(inputs, hc, listener_features)
                # Drop only the length-1 time axis; a bare squeeze() would also
                # drop the batch axis when the batch holds a single utterance.
                logits.append(logit.squeeze(1))
                if teacher_forcing:
                    output_word = ground_truth[:, step + 1:step + 2]
                else:
                    output_word = logit.topk(1)[1].squeeze(-1)
                inputs = self.emb(output_word)

            logits = torch.stack(logits, dim=1)
            return logits
        else:
            # NOTE(review): this is not a full beam search. The top `beam_size`
            # first characters seed independent greedy continuations; those
            # seed characters are neither written to y_hats nor counted in the
            # hypothesis scores — confirm whether that is intended.
            y_hats = torch.zeros(batch_size, max_step, dtype=torch.long, device=device)
            logit, hc, context = self.forward_step(inputs, hc, listener_features)
            output_words = logit.topk(beam_size)[1].squeeze(1)
            for bi in range(batch_size):
                # One candidate first character per hypothesis: (beam_size, 1).
                b_output_words = output_words[bi, :].unsqueeze(0).transpose(1, 0).contiguous()
                b_inputs = self.emb(b_output_words)
                # Replicate this utterance's features and state for every hypothesis.
                b_listener_features = listener_features[bi, :, :].unsqueeze(0).expand((beam_size, -1, -1)).contiguous()
                if isinstance(hc, tuple):
                    b_h = hc[0][:, bi, :].unsqueeze(1).expand((-1, beam_size, -1)).contiguous()
                    b_c = hc[1][:, bi, :].unsqueeze(1).expand((-1, beam_size, -1)).contiguous()
                    b_hc = (b_h, b_c)
                else:
                    b_hc = hc[:, bi, :].unsqueeze(1).expand((-1, beam_size, -1)).contiguous()

                scores = torch.zeros(beam_size, 1, device=device)
                ids = torch.zeros(beam_size, max_step, 1, dtype=torch.long, device=device)
                for step in range(max_step):
                    logit, b_hc, context = self.forward_step(b_inputs, b_hc, b_listener_features)
                    step_score, step_ids = logit.topk(1)
                    scores += step_score.squeeze(1)
                    ids[:, step, :] = step_ids.squeeze(1)
                    b_inputs = self.emb(step_ids.squeeze(-1))
                # Keep the hypothesis with the highest accumulated log-probability.
                best_beam = scores.squeeze(1).topk(1)[1][0]
                y_hats[bi, :] = ids[best_beam, :, 0]
            return y_hats

In [4]:
class LAS(nn.Module):
    """Listen-Attend-Spell wrapper: runs the listener, then hands its features to the speller."""

    def __init__(self, listener, speller):
        super(LAS, self).__init__()
        self.listener = listener
        self.speller = speller

    def forward(self, inputs, ground_truth=None, teacher_forcing_rate=0.9, use_beam=False, beam_size=16):
        """Encode ``inputs`` and decode them.

        All decoding options are forwarded unchanged to the speller, whose
        return value is returned as is.
        """
        encoded = self.listener(inputs)
        decode_options = dict(
            teacher_forcing_rate=teacher_forcing_rate,
            use_beam=use_beam,
            beam_size=beam_size,
        )
        return self.speller(encoded, ground_truth, **decode_options)

In [19]:
import pickle

# Load the character <-> id vocabulary tables.
# NOTE(security): pickle.load executes arbitrary code embedded in the file;
# only load vocabulary files that come from a trusted source.
with open('data/aihub/uchar2id.pkl', 'rb') as f:
    char2id = pickle.load(f)
    
with open('data/aihub/uid2char.pkl', 'rb') as f:
    id2char = pickle.load(f)

# Ids of the special symbols, coerced to plain Python ints.
PAD_TOKEN = int(char2id['<pad>'])
SOS_TOKEN = int(char2id['<sos>'])
EOS_TOKEN = int(char2id['<eos>'])
UNK_TOKEN = int(char2id['<unk>'])

In [22]:
# Load the pickled train / test sets.
# Presumably each is an iterable of (audio_path, transcript) pairs, which is
# how SpeechDataset unpacks its items — verify against the data preparation step.
# NOTE(security): pickle.load executes arbitrary code embedded in the file;
# only load files that come from a trusted source.
with open('data/aihub/train_data.pkl', 'rb') as f:
    train_data = pickle.load(f)
    
with open('data/aihub/test_data.pkl', 'rb') as f:
    test_data = pickle.load(f)

In [23]:
from torch.utils.data import Dataset, DataLoader, random_split
import torchaudio
from torchaudio import transforms, functional
import torch.nn.functional as F
from specaugment.spec_augment_pytorch import spec_augment
import random

class SpeechDataset(Dataset):
    """Dataset of ``(audio_path, transcript)`` pairs yielding filter-bank features and label ids.

    With ``specaugment=True`` the list of pairs is duplicated: indices in the
    first half return plain features, indices in the second half (``aug_ids``)
    return SpecAugment-ed features of the same utterances.

    Args:
        pair_data: iterable of ``(audio_path, transcript)`` pairs.
        char2id: mapping from character to label id; must contain ``'<unk>'``
            if a transcript can hold unknown characters.
        max_len: if non-zero, features shorter than this many frames are
            zero-padded along the time axis up to ``max_len``.
        specaugment: whether to append an augmented copy of the data.
        pkwargs: stored on the instance; not used by this class.
    """

    def __init__(self, pair_data, char2id, max_len=0, specaugment=False, pkwargs=None):

        super(SpeechDataset, self).__init__()
        # Always keep the pristine order so shuffle() works with or without augmentation.
        self.origin_data = list(pair_data)
        self.char2id = char2id
        self.max_len = max_len
        self.specaugment = specaugment
        self.pkwargs = pkwargs
        self._set_pairs(list(self.origin_data))

    def _set_pairs(self, pairs):
        """Install ``pairs`` as the active item list, duplicating it when augmenting."""
        count = len(pairs)
        if self.specaugment:
            # A range gives O(1) membership tests in __getitem__.
            self.aug_ids = range(count, count * 2)
            self.pair_data = pairs + pairs
        else:
            self.aug_ids = range(0)
            self.pair_data = pairs

    def __len__(self):
        return len(self.pair_data)

    def __getitem__(self, idx):
        """Return ``(features, labels)`` for item ``idx``.

        ``features`` is a ``(frames, 80)`` tensor of Kaldi-style filter-bank
        energies; ``labels`` is a numpy array of ids wrapped in the global
        ``SOS_TOKEN`` / ``EOS_TOKEN``.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        audio_path, label = self.pair_data[idx]
        if audio_path.split('.')[-1] == 'pcm':
            # Headerless 16-bit samples, scaled to roughly [-1, 1].
            audio = np.memmap(audio_path, dtype='h', mode='r')
            audio = audio.astype('float32') / 32767

            audio = torch.from_numpy(audio).type(torch.FloatTensor)
        else:
            audio, _ = torchaudio.load(audio_path)

        # fbank wants (channel, samples): raw pcm is 1-D and needs a channel
        # axis, while torchaudio.load already returns a 2-D waveform.
        if audio.dim() == 1:
            audio = audio.unsqueeze(0)

        # (frames, 80) -> (1, 80, frames), the layout handed to spec_augment.
        x = torchaudio.compliance.kaldi.fbank(audio, num_mel_bins=80).t().unsqueeze(0)
        if self.specaugment and idx in self.aug_ids:
            x = spec_augment(x, time_warping_para=40, frequency_masking_para=13,
                             time_masking_para=30, frequency_mask_num=2, time_mask_num=2)
        # Back to (frames, 80); no squeeze, so single-frame clips keep both axes.
        x = x[0, :, :].t()
        if self.max_len:
            missing = self.max_len - x.shape[0]
            if missing > 0:
                # Pad the time axis (dim 0) and stay a tensor so collation still works.
                x = F.pad(x, (0, 0, 0, missing))

        y = [SOS_TOKEN]
        for char in label:
            try:
                y.append(self.char2id[char])
            except KeyError:
                # Characters outside the vocabulary map to <unk>.
                y.append(self.char2id['<unk>'])
        y.append(EOS_TOKEN)
        y = np.array(y)
        return (x, y)

    def shuffle(self):
        """Re-shuffle the utterances, keeping plain and augmented halves aligned."""
        shuffled = list(self.origin_data)
        random.shuffle(shuffled)
        self._set_pairs(shuffled)

    def trim(self, sig, hop_size=64, threshhold=0.002):
        """Cut leading and trailing low-activity samples from ``sig``.

        The signal is scanned in windows of ``hop_size`` samples from both
        ends; the cut is placed at the first window boundary where the change
        in summed absolute amplitude per sample exceeds ``threshhold``. If no
        such boundary is found on a side, that side is left untouched.
        """
        head = None
        tail = None
        sig_len = len(sig)
        num_windows = sig_len // hop_size
        for i in range(num_windows):
            pre = sig[i * hop_size:(i + 1) * hop_size].abs().sum().item()
            post = sig[(i + 1) * hop_size:(i + 2) * hop_size].abs().sum().item()
            grad = abs((post - pre) / hop_size)
            if grad > threshhold:
                head = (i + 1) * hop_size
                break

        for i in range(num_windows):
            pre = sig[sig_len - (i + 1) * hop_size:sig_len - i * hop_size].abs().sum().item()
            post = sig[sig_len - (i + 2) * hop_size:sig_len - (i + 1) * hop_size].abs().sum().item()
            grad = abs((post - pre) / hop_size)
            if grad > threshhold:
                tail = sig_len - (i + 1) * hop_size
                break
        return sig[head:tail]

"""
def pad(batch):
    global transform, do_normalize
    max_len_x = 0
    max_len_y = 0
    data = []
    target = []
    for sample in batch:
        data += [sample[0]]
        target += [sample[1]]
        n = sample[0].shape[1]
        m = sample[1].shape[0]
        if max_len_x < n:
            max_len_x = n
        if max_len_y < m:
            max_len_y = m
    data = torch.tensor([F.pad(input=x, pad=(0, max_len_x - x.shape[1])).numpy() for x in data])
    target = [np.pad(y, (0, max_len_y - y.shape[0]), 'constant') for y in target]
    
    return data.contiguous(), torch.tensor(target).contiguous()
"""

def _collate_fn(batch):
    """ functions that pad to the maximum sequence length """
    def seq_length_(p):
        return len(p[0])

    def target_length_(p):
        return len(p[1])

    seq_lengths = [len(s[0]) for s in batch]
    target_lengths = [len(s[1]) for s in batch]

    max_seq_sample = max(batch, key=seq_length_)[0]
    max_target_sample = max(batch, key=target_length_)[1]

    max_seq_size = max_seq_sample.size(0)
    max_target_size = len(max_target_sample)

    feat_size = max_seq_sample.size(1)
    batch_size = len(batch)

    seqs = torch.zeros(batch_size, max_seq_size, feat_size)

    targets = torch.zeros(batch_size, max_target_size).to(torch.long)
    targets.fill_(0)

    for x in range(batch_size):
        sample = batch[x]
        tensor = sample[0]
        target = sample[1]
        seq_length = tensor.size(0)
        seqs[x].narrow(0, 0, seq_length).copy_(tensor)
        targets[x].narrow(0, 0, len(target)).copy_(torch.LongTensor(target))
    return seqs, targets

In [24]:
# Mini-batch size shared by both loaders.
# NOTE(review): the training cell later rebinds `batch_size` to the size of the
# batch being processed, so this global does not survive that cell.
batch_size = 16

# Width of one input feature frame handed to the Listener; matches the
# num_mel_bins=80 used by SpeechDataset.
num_channels = 80
# The training set doubles itself with a SpecAugment-ed copy; the test set is left plain.
train_dataset = SpeechDataset(train_data, char2id, specaugment=True, max_len=0)
test_dataset = SpeechDataset(test_data, char2id, max_len=0)
# shuffle=False on purpose: the training cell reorders the data between epochs
# by calling train_dataloader.dataset.shuffle().
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, pin_memory=True,
                              collate_fn=_collate_fn, num_workers=16)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=True,
                             collate_fn=_collate_fn, num_workers=16)

In [25]:
import Levenshtein as Lev


def WER(s1, s2):
    """
    Word-level edit distance between two sentences.

    Every distinct word is mapped to a single character so that the
    character-level Levenshtein distance of the encoded strings equals the
    number of word insertions, deletions and substitutions. The value is the
    raw distance; it is not normalised by the sentence length.
    Arguments:
        s1 (string): space-separated sentence
        s2 (string): space-separated sentence
    """
    words_1 = s1.split()
    words_2 = s2.split()

    # One code point per distinct word across both sentences.
    vocabulary = set(words_1 + words_2)
    word_code = {word: code for code, word in enumerate(vocabulary)}

    encoded_1 = ''.join(chr(word_code[word]) for word in words_1)
    encoded_2 = ''.join(chr(word_code[word]) for word in words_2)
    return Lev.distance(encoded_1, encoded_2)


def CER(s1, s2):
    """
    Character-level edit distance between two sentences, ignoring spaces.

    The value is the raw Levenshtein distance; it is not normalised by the
    sentence length.
    Arguments:
        s1 (string): space-separated sentence
        s2 (string): space-separated sentence
    """
    compact_1 = ''.join(s1.split(' '))
    compact_2 = ''.join(s2.split(' '))
    return Lev.distance(compact_1, compact_2)

def label_to_string(labels, id2char):
    """
    Decode label ids into text.

    ``<sos>`` ids are skipped and decoding of a sequence stops at the first
    ``<eos>`` id.

    Args:
        labels: 1-D or 2-D array/tensor of label ids (elements must support ``.item()``)
        id2char (sequence): id2char[id] = ch; must contain '<sos>' and '<eos>'

    Returns: sentence
        - **sentence** (str or list): a string for 1-D input, a list of
          strings (one per row) for 2-D input, ``None`` for any other rank
    """
    sos_id = id2char.index('<sos>')
    eos_id = id2char.index('<eos>')

    def decode(sequence):
        """Turn one sequence of ids into a string."""
        pieces = []
        for token in sequence:
            token_id = token.item()
            if token_id == sos_id:
                continue
            if token_id == eos_id:
                break
            pieces.append(id2char[token_id])
        return ''.join(pieces)

    rank = len(labels.shape)
    if rank == 1:
        return decode(labels)
    if rank == 2:
        return [decode(row) for row in labels]

In [26]:
import torch
import torch.nn as nn

class LabelSmoothingLoss(nn.Module):
    """Summed cross-entropy against label-smoothed targets.

    The target class receives weight ``1 - smoothing`` and every other class
    ``smoothing / (vocab_size - 1)``; rows whose target equals
    ``ignore_index`` contribute nothing.

    Args:
        vocab_size: number of classes used to spread the smoothing mass.
        ignore_index: target id whose rows are excluded from the loss.
        smoothing: probability mass moved away from the target class.
        dim: stored on the module; forward() always works along dimension 1.
    """

    def __init__(self, vocab_size, ignore_index, smoothing=0.1, dim=-1):
        super(LabelSmoothingLoss, self).__init__()
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing
        self.vocab_size = vocab_size
        self.ignore_index = ignore_index
        self.dim = dim

    def forward(self, logit, target):
        """Return the loss summed over all rows.

        Args:
            logit: ``(N, classes)`` tensor that is multiplied directly with
                the target distribution, i.e. it is expected to already hold
                log-probabilities.
            target: ``(N,)`` long tensor of class ids.
        """
        with torch.no_grad():
            off_target_weight = self.smoothing / (self.vocab_size - 1)
            soft_targets = torch.full_like(logit, off_target_weight)
            soft_targets.scatter_(1, target.data.unsqueeze(1), self.confidence)
            # Blank out whole rows that belong to ignored positions.
            ignored_rows = target == self.ignore_index
            soft_targets[ignored_rows, :] = 0
        return torch.sum(-soft_targets * logit)

In [27]:
import torch.optim as optim

# Encoder with 256 hidden units per direction; with the default bidirectional
# LSTMs its output width is 512, which is the hidden size given to the speller.
listener = Listener(num_channels, 256)
speller = Speller(len(id2char), 512, num_heads=4, dropout=0.3)
model = LAS(listener, speller).cuda()

optimizer = optim.Adam(model.parameters(), lr=3e-4)
#optimizer = optim.ASGD(model.parameters(), lr=3e-4)
# NOTE(review): the loss is sized with len(char2id) while the speller's output
# layer is sized with len(id2char) — presumably equal; verify.
criterion = LabelSmoothingLoss(len(char2id), ignore_index = PAD_TOKEN, smoothing = 0.1, dim = -1).cuda()
#criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=PAD_TOKEN).cuda()

In [28]:
"""
checkpoint = torch.load('checkpoint/aihub-3/last_checkpoint.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
"""

"\ncheckpoint = torch.load('checkpoint/aihub-3/last_checkpoint.pt')\nmodel.load_state_dict(checkpoint['model_state_dict'])\noptimizer.load_state_dict(checkpoint['optimizer_state_dict'])\nepoch = checkpoint['epoch']\n"

In [31]:
def last_checkpoint(path, epoch, model, optimizer, loss):
    """Save a resumable checkpoint as ``<path>/last_checkpoint.pt``.

    The file holds the epoch number, the model and optimizer state dicts and
    the given loss. The target directory is created when missing so that a
    finished epoch is not lost to a missing folder.

    Args:
        path: directory to write into (a trailing slash is fine).
        epoch: epoch number stored in the checkpoint.
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.
        loss: value stored under the ``'loss'`` key.
    """
    import os

    if path:
        os.makedirs(path, exist_ok=True)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, os.path.join(path, 'last_checkpoint.pt'))
    
def score(pred, y):
    """Sum the per-utterance character and word error rates of a batch.

    Predictions and references are decoded with ``label_to_string`` using the
    global ``id2char``; each edit distance is divided by the length of its
    reference. The caller is expected to divide the sums by the number of
    utterances.

    Args:
        pred: 2-D tensor of predicted ids.
        y: 2-D tensor of reference ids.

    Returns:
        ``(cer, wer)``: sums over the batch of the normalised distances.
    """
    cer = 0.
    wer = 0.
    pred_sts = label_to_string(pred, id2char)
    gold_sts = label_to_string(y, id2char)
    for res, gt in zip(pred_sts, gold_sts):
        # An empty reference would divide by zero; count its distance as is.
        c_length = max(len(gt.replace(' ', '')), 1)
        cer += CER(res, gt) / c_length

        w_length = max(len(gt.split()), 1)
        wer += WER(res, gt) / w_length
    return cer, wer

In [32]:
def scheduler_sampling(epoch, e_min=5, ratio_s=0.9, ratio_e=0, n_epoch_ramp=10):
    """Return the teacher-forcing ratio for ``epoch``.

    The ratio stays at ``ratio_s`` up to and including epoch ``e_min``, then
    falls linearly over ``n_epoch_ramp`` epochs and never drops below
    ``ratio_e``.

    Args:
        epoch: current epoch number.
        e_min: last epoch that still uses the starting ratio.
        ratio_s: starting ratio.
        ratio_e: final (minimum) ratio.
        n_epoch_ramp: number of epochs over which the ratio decays.
    """
    if epoch <= e_min:
        # Warm-up: honour ratio_s instead of a hard-coded 0.9.
        return ratio_s
    elapsed = epoch - e_min
    return max(ratio_s - (ratio_s - ratio_e) * elapsed / n_epoch_ramp, ratio_e)

In [33]:
def set_lr(optimizer, lr):
    """Overwrite the learning rate of every parameter group of ``optimizer`` with ``lr``."""
    for group_index in range(len(optimizer.param_groups)):
        optimizer.param_groups[group_index]['lr'] = lr

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

#scheduler = ReduceLROnPlateau(optimizer, 'min')

# Training configuration.
epochs = 100
checkpoint = 1  # NOTE(review): not read anywhere in this cell
print_step = 50  # log every `print_step` training batches
save_epoch = 1  # pickle the whole model every `save_epoch` epochs
checkpoint_path = 'checkpoint/aihub'

# Each epoch: one pass over the training data, reshuffle, one pass over the
# test data, then write the checkpoints.
for epoch in range(epochs):
    # Running sums for this epoch; they are divided by the number of
    # utterances seen so far whenever they are reported.
    total_loss = 0.
    num_samples = 0
    cer = 0.
    wer = 0.
    step= 0
    # One-off learning-rate drop at the start of epoch 5.
    if epoch==5:
        set_lr(optimizer, 1e-5)
    total_step = len(train_dataloader)
    for batch in train_dataloader:
        model.train()
        optimizer.zero_grad()

        x, y = batch

        # Size of this batch. NOTE(review): rebinds the module-level
        # `batch_size` that was used to build the DataLoaders.
        batch_size = x.size(0)
        
        x = x.cuda()
        y = y.cuda()
        # Drop the first column of y (the <sos> id prepended by the dataset)
        # so the targets line up with the decoder's step-by-step outputs.
        target = y[:, 1:].contiguous().cuda()
        teacher_forcing_rate = scheduler_sampling(epoch)
        logits = model(x, ground_truth=y, teacher_forcing_rate=teacher_forcing_rate)
        
        # Greedy predictions, used only for the CER/WER report.
        y_hats = torch.max(logits, dim=-1)[1]
        #print(label_to_string(target, id2char))
        # The criterion sums over every position of the flattened batch.
        loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))

        total_loss += loss.item()
        num_samples += batch_size

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=400)
        optimizer.step()
        cer_, wer_ = score(y_hats.long(), target)
        cer += cer_
        wer += wer_
        if step%print_step==0:
            print('timestep: {:4d}/{:4d}, loss: {:.4f}, cer: {:.2f}, wer: {:.2f}, tf_rate: {:.2f}'.format(
                step, total_step, total_loss/num_samples, cer/num_samples, wer/num_samples, teacher_forcing_rate))
            with open('aihub-4.log', 'at') as f:
                f.write('timestep: {:4d}/{:4d}, loss: {:.4f}, cer: {:.2f}, wer: {:.2f}, tf_rate: {:.2f}\n'.format(
                step, total_step, total_loss/num_samples, cer/num_samples, wer/num_samples, teacher_forcing_rate))
        step += 1
        
    total_loss /= num_samples
    cer /= num_samples
    wer /= num_samples
    print('Epoch %d (Training) Total Loss %0.4f CER %0.4f WER %0.4f' % (epoch, total_loss, cer, wer))
    with open('aihub-4.log', 'at') as f:
        f.write('Epoch %d (Training) Total Loss %0.4f CER %0.4f WER %0.4f\n' % (epoch, total_loss, cer, wer))
    # The DataLoader was built with shuffle=False; the dataset reorders itself
    # here so the next epoch sees a new order.
    train_dataloader.dataset.shuffle()
        
    # Evaluation pass over the test set.
    total_loss = 0.
    num_samples = 0
    cer = 0.
    wer = 0.
    with torch.no_grad():
        model.eval()
        for batch in test_dataloader:
            x, y = batch

            batch_size = x.size(0)

            x = x.cuda()
            y = y.cuda()
            target = y[:, 1:].contiguous().cuda()

            # Without ground truth the speller decodes for its full max_step.
            logits = model(x, ground_truth=None, teacher_forcing_rate=0.0)

            y_hats = torch.max(logits, dim=-1)[1]
            # NOTE(review): slicing only shortens the logits; a target longer
            # than the decoded length would presumably make the reshaped
            # logits and targets disagree in size — confirm this cannot occur.
            logits = logits[:,:target.size(1),:].contiguous() # cut over length to calculate loss

            loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))
            total_loss += loss.item()

            cer_, wer_ = score(y_hats.long(), target)
            cer += cer_
            wer += wer_
            num_samples += batch_size
    val_loss = total_loss/num_samples
    cer /= num_samples
    wer /= num_samples
    #scheduler.step(val_loss)
    print('Epoch %d (Evaluate) Total Loss %0.4f CER %0.4f WER %0.4f' % (epoch, val_loss, cer, wer))
    with open('aihub-4.log', 'at') as f:
        f.write('Epoch %d (Evaluate) Total Loss %0.4f CER %0.4f WER %0.4f\n' % (epoch, val_loss, cer, wer))
    # NOTE(review): `loss` here is the loss tensor of the last evaluation
    # batch, not `val_loss` — confirm which one the checkpoint should record.
    last_checkpoint(checkpoint_path+'/', epoch, model, optimizer, loss)
    if epoch%save_epoch==0:
        # Saves the whole pickled module; the file name carries the evaluation CER/WER.
        torch.save(model, "{}/epoch{}-cer{:.2f}-wer{:.2f}.pt".format(checkpoint_path, epoch, cer, wer))

timestep:    0/77661, loss: 145.7933, cer: 7.19, wer: 1.00, tf_rate: 0.90
timestep:   50/77661, loss: 180.5205, cer: 1.20, wer: 1.00, tf_rate: 0.90
timestep:  100/77661, loss: 173.5473, cer: 1.13, wer: 1.12, tf_rate: 0.90
timestep:  150/77661, loss: 167.2752, cer: 1.11, wer: 1.18, tf_rate: 0.90
timestep:  200/77661, loss: 163.2152, cer: 1.07, wer: 1.13, tf_rate: 0.90
timestep:  250/77661, loss: 159.5642, cer: 1.04, wer: 1.11, tf_rate: 0.90
timestep:  300/77661, loss: 156.9817, cer: 1.01, wer: 1.09, tf_rate: 0.90
timestep:  350/77661, loss: 154.1604, cer: 1.00, wer: 1.09, tf_rate: 0.90
timestep:  400/77661, loss: 153.4419, cer: 0.98, wer: 1.08, tf_rate: 0.90
timestep:  450/77661, loss: 152.6465, cer: 0.97, wer: 1.07, tf_rate: 0.90
timestep:  500/77661, loss: 151.5509, cer: 0.96, wer: 1.07, tf_rate: 0.90
timestep:  550/77661, loss: 149.7598, cer: 0.95, wer: 1.07, tf_rate: 0.90
timestep:  600/77661, loss: 148.8346, cer: 0.94, wer: 1.06, tf_rate: 0.90
timestep:  650/77661, loss: 147.8108, 

timestep: 5550/77661, loss: 127.9559, cer: 0.82, wer: 1.01, tf_rate: 0.90
timestep: 5600/77661, loss: 127.9134, cer: 0.82, wer: 1.01, tf_rate: 0.90
timestep: 5650/77661, loss: 127.7805, cer: 0.82, wer: 1.01, tf_rate: 0.90
timestep: 5700/77661, loss: 127.7302, cer: 0.82, wer: 1.01, tf_rate: 0.90
timestep: 5750/77661, loss: 127.6622, cer: 0.82, wer: 1.01, tf_rate: 0.90
timestep: 5800/77661, loss: 127.5743, cer: 0.82, wer: 1.01, tf_rate: 0.90
timestep: 5850/77661, loss: 127.5810, cer: 0.82, wer: 1.01, tf_rate: 0.90
timestep: 5900/77661, loss: 127.5753, cer: 0.82, wer: 1.01, tf_rate: 0.90
timestep: 5950/77661, loss: 127.5845, cer: 0.82, wer: 1.01, tf_rate: 0.90
timestep: 6000/77661, loss: 127.5700, cer: 0.82, wer: 1.01, tf_rate: 0.90
timestep: 6050/77661, loss: 127.5033, cer: 0.82, wer: 1.01, tf_rate: 0.90
timestep: 6100/77661, loss: 127.4560, cer: 0.82, wer: 1.01, tf_rate: 0.90
timestep: 6150/77661, loss: 127.3824, cer: 0.82, wer: 1.01, tf_rate: 0.90
timestep: 6200/77661, loss: 127.3549, 

timestep: 11100/77661, loss: 122.5346, cer: 0.78, wer: 0.99, tf_rate: 0.90
timestep: 11150/77661, loss: 122.4729, cer: 0.78, wer: 0.99, tf_rate: 0.90
timestep: 11200/77661, loss: 122.4100, cer: 0.78, wer: 0.99, tf_rate: 0.90
timestep: 11250/77661, loss: 122.3351, cer: 0.78, wer: 0.99, tf_rate: 0.90
timestep: 11300/77661, loss: 122.2965, cer: 0.78, wer: 0.99, tf_rate: 0.90
timestep: 11350/77661, loss: 122.2473, cer: 0.78, wer: 0.99, tf_rate: 0.90
timestep: 11400/77661, loss: 122.1725, cer: 0.78, wer: 0.99, tf_rate: 0.90
timestep: 11450/77661, loss: 122.1239, cer: 0.78, wer: 0.99, tf_rate: 0.90
timestep: 11500/77661, loss: 122.0430, cer: 0.77, wer: 0.99, tf_rate: 0.90
timestep: 11550/77661, loss: 121.9891, cer: 0.77, wer: 0.98, tf_rate: 0.90
timestep: 11600/77661, loss: 121.9349, cer: 0.77, wer: 0.98, tf_rate: 0.90
timestep: 11650/77661, loss: 121.8434, cer: 0.77, wer: 0.98, tf_rate: 0.90
timestep: 11700/77661, loss: 121.7864, cer: 0.77, wer: 0.98, tf_rate: 0.90
timestep: 11750/77661, lo

timestep: 16600/77661, loss: 115.1103, cer: 0.71, wer: 0.94, tf_rate: 0.90
timestep: 16650/77661, loss: 115.0274, cer: 0.71, wer: 0.94, tf_rate: 0.90
timestep: 16700/77661, loss: 114.9565, cer: 0.71, wer: 0.94, tf_rate: 0.90
timestep: 16750/77661, loss: 114.8785, cer: 0.71, wer: 0.94, tf_rate: 0.90
timestep: 16800/77661, loss: 114.7892, cer: 0.71, wer: 0.94, tf_rate: 0.90
timestep: 16850/77661, loss: 114.6898, cer: 0.71, wer: 0.94, tf_rate: 0.90
timestep: 16900/77661, loss: 114.6168, cer: 0.71, wer: 0.94, tf_rate: 0.90
timestep: 16950/77661, loss: 114.5558, cer: 0.71, wer: 0.94, tf_rate: 0.90
timestep: 17000/77661, loss: 114.4757, cer: 0.71, wer: 0.94, tf_rate: 0.90
timestep: 17050/77661, loss: 114.4386, cer: 0.71, wer: 0.93, tf_rate: 0.90
timestep: 17100/77661, loss: 114.3673, cer: 0.70, wer: 0.93, tf_rate: 0.90
timestep: 17150/77661, loss: 114.2834, cer: 0.70, wer: 0.93, tf_rate: 0.90
timestep: 17200/77661, loss: 114.2038, cer: 0.70, wer: 0.93, tf_rate: 0.90
timestep: 17250/77661, lo

timestep: 22100/77661, loss: 107.7072, cer: 0.64, wer: 0.88, tf_rate: 0.90
timestep: 22150/77661, loss: 107.6431, cer: 0.64, wer: 0.88, tf_rate: 0.90
timestep: 22200/77661, loss: 107.5776, cer: 0.64, wer: 0.88, tf_rate: 0.90
timestep: 22250/77661, loss: 107.5277, cer: 0.64, wer: 0.88, tf_rate: 0.90
timestep: 22300/77661, loss: 107.4647, cer: 0.64, wer: 0.88, tf_rate: 0.90
timestep: 22350/77661, loss: 107.4170, cer: 0.64, wer: 0.88, tf_rate: 0.90
timestep: 22400/77661, loss: 107.3494, cer: 0.64, wer: 0.88, tf_rate: 0.90
timestep: 22450/77661, loss: 107.2908, cer: 0.64, wer: 0.88, tf_rate: 0.90
timestep: 22500/77661, loss: 107.2331, cer: 0.64, wer: 0.88, tf_rate: 0.90
timestep: 22550/77661, loss: 107.1894, cer: 0.64, wer: 0.88, tf_rate: 0.90
timestep: 22600/77661, loss: 107.1274, cer: 0.64, wer: 0.88, tf_rate: 0.90
timestep: 22650/77661, loss: 107.0627, cer: 0.64, wer: 0.88, tf_rate: 0.90
timestep: 22700/77661, loss: 106.9991, cer: 0.64, wer: 0.88, tf_rate: 0.90
timestep: 22750/77661, lo