In [1]:
from torch.utils.data import DataLoader
import transformers as T
negative_rate = 0.15
change_rate = 0.9

tokenizer = T.BertTokenizer.from_pretrained('hfl/chinese-pert-base')
print(' -- Data complete -- ')

from tqdm import tqdm

def _is_chinese_char(cp):
    cp = ord(cp)
    if ((cp >= 0x4E00 and cp <= 0x9FFF) or  #
        (cp >= 0x3400 and cp <= 0x4DBF) or  #
        (cp >= 0x20000 and cp <= 0x2A6DF) or  #
        (cp >= 0x2A700 and cp <= 0x2B73F) or  #
        (cp >= 0x2B740 and cp <= 0x2B81F) or  #
        (cp >= 0x2B820 and cp <= 0x2CEAF) or
        (cp >= 0xF900 and cp <= 0xFAFF) or  #
        (cp >= 0x2F800 and cp <= 0x2FA1F)):  #
      return True

def is_chinese(s):
    for each in s:
        if not _is_chinese_char(each):
            return False
    return True

chinese_char = []
for i in range(len(tokenizer)):
    if is_chinese(tokenizer.decode([i])):
        chinese_char.append(i)

  from .autonotebook import tqdm as notebook_tqdm


 -- Data complete -- 


In [2]:
import torch
import random
import jieba
import copy
from model_rr import Model4QA
def get_batch(sample):
    lengths = [len([101]+each[0]+[102]) for each in sample]
    input_seq = [[101]+each[0]+[102]+each[1]+[102] for each in sample]
    cn_in_seq = []
    for i in range(len(input_seq)):
        cn_in_seq.append([])
        for j in range(len(input_seq[i])):
            if input_seq[i][j] in chinese_char:
                cn_in_seq[i].append(j)

    sampled_char = [random.sample(each, int(len(each)*negative_rate)) for each in cn_in_seq]
    replace_char = [random.sample(each, int(len(each)*change_rate)) for each in sampled_char]
    # char_ids = [[each for each in zip([input_seq[i][t] for t in replace_char[i]], range(len(replace_char[i])))] for i in range(len(input_seq))]
    # shuffle_ids = copy.copy(char_ids)
    shuffle_char = [random.sample(each, len(each)) for each in replace_char]
    shuffle_ids = [[input_seq[i][t] for t in replace_char[i]] for i in range(len(input_seq))]
    
    for i in range(len(input_seq)):
        for j in range(len(shuffle_ids[i])):
            input_seq[i][replace_char[i][j]] = shuffle_ids[i][j]
    
    maxlen = max([len(each) for each in input_seq])
    input_ids = torch.zeros((len(sample), maxlen), dtype=torch.int64)
    mask = torch.zeros((len(sample), maxlen), dtype=torch.int64)
    token_type_ids = torch.zeros((len(sample), maxlen), dtype=torch.int64)
    labels = torch.zeros((len(sample), maxlen), dtype=torch.int64)
    negative_samples = torch.zeros((len(sample), maxlen), dtype=torch.int64)
    for i in range(len(input_seq)):
        for j in range(len(input_seq[i])):
            input_ids[i, j] = input_seq[i][j]
            mask[i, j] = 1
            if j < lengths[i]:
                token_type_ids[i, j] = 1
        for j in range(len(sampled_char[i])):
            if sampled_char[i][j] in replace_char[i]: labels[i, sampled_char[i][j]] = shuffle_char[i].index(sampled_char[i][j])
            else: labels[i, sampled_char[i][j]] = sampled_char[i][j]
            negative_samples[i, sampled_char[i][j]] = 1
    return input_ids, mask, token_type_ids, labels, negative_samples


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
myModel = Model4QA()
# myModel = torch.nn.DataParallel(myModel, device_ids=[0])
myModel = myModel.to(device)

lr_max = 1e-5
lr_min = 1e-7

def loss_func(scores, labels, mask, negative_samples):
    total_loss = 0
    labels = [each.to(device) for each in labels]
    for i in range(scores.size(0)):
        for j in range(scores.size(1)):
            if mask[i, j] == 1 and negative_samples[i, j] == 1:
                total_loss += torch.nn.CrossEntropyLoss()(scores[i, j], labels[i][j])
    return total_loss

Some weights of the model checkpoint at hfl/chinese-pert-base were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
import copy
import math
from data_cmrc import CMRC
print(' -- Start training -- ')
warmup = 1000
optim = torch.optim.Adam(myModel.parameters(), lr=1)
scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lambda cur_iter: ((cur_iter/warmup)*(lr_max-lr_min)+lr_min) if cur_iter <= warmup else \
                                (lr_max*((1-lr_min/lr_max)/2+((1-lr_min/lr_max)/2)*math.cos((cur_iter-warmup)/(4000-warmup)*math.pi))))
pd = CMRC(split='train')
for epoch in range(5):
    
    dl = DataLoader(pd, batch_size=8, collate_fn=get_batch, shuffle=True)
    loop = tqdm(dl, leave=True)
    running_loss = 0.
    match = 0.
    all = 0.

    rp_match = 0.
    rp_all = 0.

    myModel.train()
    for index, batch in enumerate(loop):
        
        input_ids, mask, token_type_ids, labels, negative_samples = batch
        input_ids, mask, token_type_ids = input_ids.to(device), mask.to(device), token_type_ids.to(device)
        optim.zero_grad()
        scores = myModel(input_ids, mask=mask, token_type_ids=token_type_ids)
        loss = loss_func(scores, labels, mask, negative_samples)
        loss.backward()
        optim.step()
        scheduler.step()
        running_loss+=loss.item()
        

        predict_result = torch.argmax(scores, dim=-1)
        for i in range(predict_result.size(0)):
            for j in range(predict_result.size(1)):
                if mask[i, j] == 1 and negative_samples[i, j] == 1:
                    if predict_result[i, j] == labels[i][j]:
                        match += 1
                    all += 1
                    if labels[i][j] != j:
                        if predict_result[i, j] == labels[i][j]:
                            rp_match += 1
                        rp_all += 1

    torch.save(myModel, f'./saved_model/epoch{epoch}')
    print('\nepoch {}, Batch {}, Loss {:.4f}, acc {:.4f}, rp acc {:.4f}, rm acc  {:.4f}'.format(
        epoch, index + 1, running_loss / len(dl), match / all, rp_match / rp_all, (match-rp_match) / (all-rp_all)))
    running_loss = 0
    match = 0.
    all = 0.
    rp_match = 0.
    rp_all = 0.

 -- Start training -- 


100%|██████████| 1836/1836 [46:04<00:00,  1.51s/it]



epoch 0, Batch 1836, Loss 1574.1019, acc 0.0214, rp acc 0.0238, rm acc  0.0025


100%|██████████| 1836/1836 [44:51<00:00,  1.47s/it]



epoch 1, Batch 1836, Loss 1456.6943, acc 0.0237, rp acc 0.0263, rm acc  0.0033


100%|██████████| 1836/1836 [40:31<00:00,  1.32s/it]



epoch 2, Batch 1836, Loss 1454.9361, acc 0.0236, rp acc 0.0262, rm acc  0.0031


 35%|███▌      | 650/1836 [13:44<25:03,  1.27s/it]


KeyboardInterrupt: 

In [None]:
random.sample([1,2,3,4,5], 5)

[2, 3, 4, 1, 5]