In [None]:
import os
import torch
from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertTokenizerFast, BertForSequenceClassification
# Expose only GPU 0 to this process.
# NOTE(review): this only takes effect if CUDA has not yet been initialized in the
# current kernel (e.g. by an earlier run of the later .cuda() cells) — restart the
# kernel after changing it.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [None]:
import numpy as np
import kaldiio
import json
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from utils import *

TRAIN_JSON_PATH = "/path/espnet/egs/aishell/asrNoperAddBERTToken/dump/train/deltafalse/data.json"
DEV_JSON_PATH = "/path/espnet/egs/aishell/asrNoperAddBERTToken/dump/dev/deltafalse/data.json"


def _load_batchset(json_path):
    """Load the ``"utts"`` mapping from an ESPnet-style ``data.json`` and batch it.

    Args:
        json_path: path to the ``data.json`` file.

    Returns:
        ``(utts, batches)``: the raw utterance dict and the result of
        ``make_batchset`` built with the settings shared by train and dev.
    """
    # Explicit UTF-8 so loading does not depend on the platform's default locale
    # encoding (the transcripts may contain non-ASCII text).
    with open(json_path, "r", encoding="utf-8") as f:
        utts = json.load(f)["utts"]
    batches = make_batchset(
        utts,
        min_batch_size=1,
        shortest_first=True,
        count="frame",
        batch_frames_in=10000,
    )
    return utts, batches


# train_json / dev_json still hold the loaded utterance dicts, as before.
train_json, trainset = _load_batchset(TRAIN_JSON_PATH)
dev_json, devset = _load_batchset(DEV_JSON_PATH)

MAX_TOKENS = 60  # every token-id sequence is padded/truncated to exactly this length
TOKEN_PAD_ID = 0  # id used to pad short token-id sequences


def _fixed_length_token_ids(tokenid_str, max_len=MAX_TOKENS, pad_id=TOKEN_PAD_ID):
    """Parse a whitespace-separated token-id string into an int64 tensor of length ``max_len``.

    Longer sequences are truncated; shorter ones are right-padded with ``pad_id``.
    """
    ids = [int(tok) for tok in tokenid_str.split()][:max_len]
    ids += [pad_id] * (max_len - len(ids))
    return torch.tensor(ids)


def _collate_batch(minibatch, augment):
    """Shared body of ``collate`` / ``collate_dev``.

    ``minibatch`` is the DataLoader's list of items; with the default batch_size=1
    it holds a single pre-built batch of ``(key, info)`` pairs, hence ``minibatch[0]``.

    Returns:
        ``(fbanks, tokens)``: features zero-padded along the first (time) axis via
        ``pad_sequence(batch_first=True)``, and token ids of shape
        ``(batch, MAX_TOKENS)``.
    """
    fbanks = []
    tokens = []
    for _key, info in minibatch[0]:
        feat = kaldiio.load_mat(info["input"][0]["feat"])
        if augment:
            feat = spec_augment(feat)
        fbanks.append(torch.tensor(feat))
        tokens.append(_fixed_length_token_ids(info["output"][0]["tokenid"]))
    return pad_sequence(fbanks, batch_first=True), pad_sequence(tokens, batch_first=True)


def collate(minibatch):
    """Training collate_fn: loads features, applies ``spec_augment``, fixes token length."""
    return _collate_batch(minibatch, augment=True)


def collate_dev(minibatch):
    """Dev collate_fn: same as ``collate`` but without ``spec_augment``."""
    return _collate_batch(minibatch, augment=False)

# Each dataset item is one pre-built batch from make_batchset, and the DataLoader's
# default batch_size=1 wraps it in a one-element list (the collate functions read
# minibatch[0]). Only the training order is reshuffled each epoch.
train_loader = DataLoader(
    trainset,
    shuffle=True,
    collate_fn=collate,
    pin_memory=True,
)
dev_loader = DataLoader(
    devset,
    collate_fn=collate_dev,
    pin_memory=True,
)

In [None]:
from distutils.version import LooseVersion
from typing import Union
from torch.optim.lr_scheduler import _LRScheduler

class WarmupLR(_LRScheduler):
    """Noam-style warmup / inverse-square-root learning-rate schedule.

    At (1-based) step ``n`` every param group gets::

        d_model ** -0.5 * min(n ** -0.5, n * warmup_steps ** -1.5) * scale

    i.e. linear growth for ``warmup_steps`` steps, then decay proportional to
    ``n ** -0.5``. The optimizer's own ``lr`` (``base_lrs``) is not used; the
    schedule alone determines the learning rate.

    Args:
        optimizer: wrapped optimizer.
        warmup_steps: number of linear-warmup steps.
        last_epoch: index of the last scheduler step (``-1`` to start fresh).
        d_model: width used for the ``d_model ** -0.5`` factor
            (default 768, the value previously hard-coded).
        scale: constant multiplier on the schedule
            (default 0.1, the value previously hard-coded).
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_steps: Union[int, float] = 25000,
        last_epoch: int = -1,
        d_model: int = 768,
        scale: float = 0.1,
    ):
        # Set before super().__init__, which performs an initial step and
        # therefore already calls get_lr().
        self.warmup_steps = warmup_steps
        self.d_model = d_model
        self.scale = scale
        super().__init__(optimizer, last_epoch)

    def __repr__(self):
        return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"

    def get_lr(self):
        # last_epoch counts scheduler steps taken; +1 makes the step 1-based.
        # Clamp to 1 so a non-positive step can never evaluate 0 ** -0.5.
        step_num = max(self.last_epoch + 1, 1)
        # Same evaluation order as the original formula, so values are bit-identical.
        lr = (
            self.d_model ** -0.5
            * min(step_num ** -0.5, step_num * self.warmup_steps ** -1.5)
            * self.scale
        )
        # One identical value per param group (base_lrs intentionally ignored).
        return [lr] * len(self.base_lrs)


In [None]:
import bert_asr
from tqdm.notebook import trange, tqdm

PRETRAINED_MODEL_NAME = "bert-base-chinese"

# Pretrained BERT, loaded through the project's bert_asr wrapper class.
bertmodel = bert_asr.BertForMaskedLMForBERTASR.from_pretrained(PRETRAINED_MODEL_NAME)

# Acoustic encoder, initialised from a pretraining checkpoint that is loaded onto the CPU.
# NOTE(review): 83 / 21128 are presumably the input feature dim and the output vocab size
# (21128 matches the bert-base-chinese vocabulary) — confirm against bert_asr.
device = torch.device('cpu')
encoder_model = bert_asr.BERTASR_Encoder(83, 21128)
encoder_model.load_state_dict(torch.load("./path/pretraining_avg10", map_location=device))

# Combined encoder + BERT model, moved to the GPU.
model = bert_asr.BERTASR(encoder_model, bertmodel).cuda()

# WarmupLR.get_lr does not read base_lrs, so the lr given to Adam is overwritten by the
# schedule as soon as the scheduler is constructed.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = WarmupLR(optimizer, 12000)  # 12000 warmup steps
EPOCHS = 130

ACCUM_STEPS = 12  # micro-batches whose gradients are accumulated per optimizer step
# Built once and reused for every dev batch. torch.nn is referenced explicitly because
# no visible cell imports a bare CrossEntropyLoss name.
loss_fct = torch.nn.CrossEntropyLoss()

print("----Training Start----")
optimizer.zero_grad()
# Micro-batches accumulated since the last optimizer step. Deliberately NOT reset per
# epoch: resetting it while .grad still holds an unfinished group used to fold an
# epoch's leftover micro-batches into the next epoch's first (oversized) update.
accum_count = 0
for epoch in range(EPOCHS):

    running_loss = 0.0
    train_step = 0
    model.train()
    for x, y in tqdm(train_loader):
        # forward pass; element 0 of the model output is used as the training loss
        outputs = model(x.cuda(), y.cuda())[0]
        # divide so the accumulated gradient averages over ACCUM_STEPS micro-batches
        loss = outputs / ACCUM_STEPS
        # backward: gradients accumulate in .grad until the next optimizer step
        loss.backward()
        accum_count += 1
        if accum_count == ACCUM_STEPS:
            optimizer.step()
            scheduler.step()  # schedule advances once per optimizer update
            optimizer.zero_grad()
            accum_count = 0

        # log batch loss (scaled; multiplied back by ACCUM_STEPS in the report)
        running_loss += loss.item()
        train_step += 1

    # Dev-set evaluation
    model.eval()
    dev_loss = 0.0
    dev_loss2 = 0.0
    dev_step = 0
    with torch.no_grad():
        for dev_x, dev_y in tqdm(dev_loader):
            # transfer each dev batch to the GPU once, not once per use
            dev_x = dev_x.cuda()
            dev_y = dev_y.cuda()
            # loss returned by the model when targets are supplied
            outputs = model(dev_x, dev_y)[0]
            dev_loss += outputs.item()
            dev_step += 1
            # cross-entropy of the model's output when called without targets.
            # NOTE(review): the collate pad id 0 is not excluded here (ignore_index
            # defaults to -100) — confirm that is intended.
            logits = model(dev_x)[0]
            ce_loss = loss_fct(logits.view(-1, model.encoder.odim), dev_y.view(-1))
            dev_loss2 += ce_loss.item()
    # max(..., 1) guards against ZeroDivisionError if a loader yields no batches
    print('[epoch %d] train loss: %.3f | dev loss: %.3f | dev loss(w/o sm): %.3f' %
          (epoch + 1,
           running_loss / max(train_step, 1) * ACCUM_STEPS,
           dev_loss / max(dev_step, 1),
           dev_loss2 / max(dev_step, 1)))
    torch.save(model.state_dict(), "./bertasr." + str(epoch + 1))
    print("save for epoch:" + str(epoch + 1))


In [None]:
# Average the checkpoints of the last 10 epochs (121-130, written by the training loop as
# ./bertasr.<epoch>) and save the result to ./bertasr_avg10.
# NOTE(review): last_num must stay in sync with EPOCHS in the training cell.
avg_model(
    save_path="./bertasr_avg10",
    root="./bertasr.",
    last_num=130,
    avg_num=10,
)