In [1]:
import os
import random
import time
import logging
import argparse
from dataclasses import dataclass, field
from typing import Optional,Dict, Union, Any, Tuple, List

import numpy as np
import datasets
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import transformers
from transformers import (
    DataCollatorForSeq2Seq,
    AutoConfig,
    AutoTokenizer,
    HfArgumentParser,
    TrainingArguments,
    Seq2SeqTrainingArguments,
    set_seed,
)
from transformers import Trainer, Seq2SeqTrainer
from transformers import TrainingArguments
from transformers import trainer_utils, training_args
from transformers.trainer_pt_utils import nested_detach
from transformers import BertForMaskedLM
from transformers.file_utils import PaddingStrategy
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
from transformers.training_args import TrainingArguments

import sys
sys.path.append("/remote-home/xtzhang/CTC/CTC2021/SpecialEdition/")
from core import get_super_magic_dataset, get_magic_dataset, get_metrics, argument_init
from lib import subTrainer  
from models.bert.modeling_bert_v3 import BertForMaskedLM_v2 



In [8]:
# Hard-coded absolute checkpoint directory; adjust for your environment.
checkpoint_dir = "/remote-home/xtzhang/CTC/CTC2021/SpecialEdition/tmp/sighan/bert_MaskedLM_v2_std_super_mask.epoch10.bs32"
model = BertForMaskedLM_v2.from_pretrained(checkpoint_dir)

In [2]:
# Load the three "sighan" splits together with the tokenizer that the
# dataset helper returns alongside them.
(
    train_dataset,
    eval_dataset,
    test_dataset,
    tokenizer,
) = get_super_magic_dataset("sighan", "../")

Loading Dataset !
Fri Dec 17 01:46:25 UTC 2021
Loading Abs_Pos and Special Token Bert SigHan Dataset ...


100%|██████████| 568392/568392 [00:03<00:00, 151615.14it/s]
100%|██████████| 1100/1100 [00:00<00:00, 541200.52it/s]
100%|██████████| 1100/1100 [00:00<00:00, 315231.92it/s]


Save cache to cache/sighan_abs_pos_expo_token.
Loading Succeed !
Fri Dec 17 01:48:45 UTC 2021


In [9]:
@dataclass
class MyDataCollatorForSeq2Seq:
    """
    Collate a list of feature dicts into a dict of right-padded tensors.

    The list-valued keys ``input_ids``, ``lattice``, ``labels`` and
    ``attention_mask`` are padded on the right so every row of a key has the
    same length. Every other key found on the first feature (for example a
    per-example scalar) is stacked unchanged with ``torch.tensor``.

    Fields:
        padding, max_length, pad_to_multiple_of: kept for signature
            compatibility; ``__call__`` does not consult them.
        label_pad_token_id: value used to pad ``labels`` (default ``-100``).
    """
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100

    def __call__(self, features):
        """
        Pad and tensorize one batch.

        Args:
            features: sequence of mappings. Each must contain ``input_ids``;
                ``lattice``, ``labels`` and ``attention_mask`` are padded when
                present on the first feature. Values to be padded must be lists.

        Returns:
            dict mapping each key of the first feature (except ``raw_length``,
            which is never forwarded) to a ``torch.Tensor``. An empty batch
            yields an empty dict.

        Raises:
            KeyError: if a feature lacks ``input_ids`` or a key that the first
                feature has.
        """
        # Shallow-copy each feature: padding below rebinds keys to *new* lists,
        # so the caller's dicts and lists are never mutated and no deepcopy
        # of the whole batch is needed.
        rows = [dict(feature) for feature in features]
        if not rows:
            return {}

        # Labels may be shorter than the inputs, so they are padded at least
        # up to the longest input sequence.
        shared_max_length = max(len(row["input_ids"]) for row in rows)

        pad_values = {
            "input_ids": 0,
            "lattice": 0,
            "labels": self.label_pad_token_id,
            "attention_mask": 0,
        }

        for key, pad_value in pad_values.items():
            if key not in rows[0]:
                # e.g. label-free inference batches
                continue

            target_length = max(len(row[key]) for row in rows)
            if key == "labels":
                target_length = max(target_length, shared_max_length)

            for row in rows:
                # A non-positive count multiplies to [], leaving the row as is.
                row[key] = row[key] + [pad_value] * (target_length - len(row[key]))

        return {
            key: torch.tensor([row[key] for row in rows])
            for key in rows[0]
            if key != "raw_length"
        }

In [5]:
# Load a stock tokenizer by hub name.
# NOTE(review): in the saved run this cell was interrupted (KeyboardInterrupt),
# so `tokenizer` kept the value returned by get_super_magic_dataset. If this
# cell completes it rebinds `tokenizer`; confirm the stock vocabulary contains
# the "<RAW>" token the later cells compare against before relying on it.
tokenizer_model_name_path="hfl/chinese-roberta-wwm-ext"

tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_model_name_path
)

KeyboardInterrupt: 

In [6]:
# Collator with default settings; turns a list of feature dicts into padded tensors.
pre = MyDataCollatorForSeq2Seq()

In [10]:
# Smoke test: collate the first test example as a batch of one and run it
# through the model, keeping only the logits.
first_batch = pre([test_dataset[0]])
p = model(**first_batch).logits

In [11]:
# Most probable vocabulary id at each position (softmax over dim 2, argmax over the last dim).
out_ = torch.softmax(p, dim=2).argmax(dim=-1)

In [12]:
# Display the raw (un-collated) features of the first test example.
test_dataset[0]

{'input_ids': [872,
  1962,
  106,
  2769,
  3221,
  2476,
  4263,
  3152,
  511,
  2644,
  2356,
  6421,
  5687,
  2002],
 'lattice': [0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 4, 6, 6, 6],
 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 'labels': [21128, 21128, 21128, 21128, 21128, 21128, 21128, 21128, 21128],
 'sub_length': 9}

In [13]:
# Display the predicted token ids for the single-example batch.
out_

tensor([[21128, 21128, 21128, 21128, 21128, 21128, 21128, 21128, 21128, 21128,
          3221, 21128, 21128, 21128]])

In [60]:
bs = 32

# One entry per test example: the decoded prediction split into tokens and
# truncated to that example's label length.
result = []

# Inference only: no_grad() avoids building an autograd graph on every forward pass.
with torch.no_grad():
    # Step by `bs` so the slice width always follows the configured batch size
    # and no empty trailing batch is produced when len(test_dataset) is an
    # exact multiple of bs.
    for start in tqdm(range(0, len(test_dataset), bs)):
        batch = test_dataset[start : start + bs]

        logits = model(**pre(batch)).logits

        # softmax is monotonic, so argmax over the raw logits selects the same ids.
        pred = torch.argmax(logits, -1)

        mid = tokenizer.batch_decode(pred)

        # NOTE(review): splitting the decoded string on whitespace assumes one
        # whitespace-separated piece per predicted position — confirm for
        # inputs that contain multi-piece words.
        for j in range(len(batch)):
            result.append(mid[j].split()[: len(batch[j]["labels"])])



100%|██████████| 35/35 [01:18<00:00,  2.25s/it]


In [53]:
from utils.io import read_csv

# Hard-coded absolute path to the source side of the test split.
test_src_path = "/remote-home/xtzhang/CTC/CTC2021/SpecialEdition/data/rawdata/sighan/lattice_balanced/test.src"
tmp_source = read_csv(test_src_path)

In [58]:
import re

# Keep every second entry of tmp_source (indices 0, 2, 4, ...) with all
# newline characters removed.
# NOTE(review): presumably the odd-indexed entries are separators — confirm
# against what read_csv returns for this file.
source = [
    re.sub("\n", "", tmp_source[idx])
    for idx in range(0, len(tmp_source), 2)
]


In [64]:

# Rebuild one output sentence per prediction: a "<RAW>" token means "keep the
# character at the same position of the source sentence"; any other token is
# used as the replacement.
finals = [
    "".join(
        source[row][col] if piece == "<RAW>" else piece
        for col, piece in enumerate(pieces)
    )
    for row, pieces in enumerate(result)
]

In [66]:
# Display every reconstructed sentence (the full list is echoed).
finals

['你好!我是张爱文。',
 '下个星期,我跟我朋友打算去法国玩儿。',
 '我听说,你找到新工作,我很高兴。',
 '对不起,最近我很忙,所以我不会去妳的。',
 '真麻烦你了。希望你们好好地跳舞。',
 '我本来要参加这个会的,可是我今天有一点儿事情一定要做完。',
 '所以我先去看医生,再去你的祝庆会。',
 '吃了早餐以后他去上课。',
 '走路差不多十分钟,我们到了。',
 '他知道今天,高中三年级的最后一天,是个很重要的天。',
 '他们看了一个很可爱的电影,一个小机一个人去火星。',
 '我起床的时候,吃早饭。',
 '在学校因为他很用功常常坐最前面的椅子。',
 '因为他学得很好,所以同班同学都喜欢问他问题。',
 '美美说因为学校没有冷气,所以今天教室里面热得不得了。',
 '在补习班他昨天晚上到夜里两点才读书,所以他一回家就累得睡觉了。',
 '她戴著眼镜跟袜子入睡了。',
 '张爱文很聪明,老师教他英文、地理什么的,他很快明白了。',
 '今天下了课,我打算跟我的女朋友去看电影,所以我有一点儿紧张,六点半我就起床了。',
 '在公车上有很多人,所以我们没有位子可以座。',
 '我觉得「站在公车上有一点麻烦没关系,有我的女朋友就好了」。',
 '我听过她说的话了,我高兴的不得了。',
 '看电影时候,我都觉得这个电影很有意思,可是现在我吧什么事都不懂的。',
 '那个晚上,我睡觉睡得比较更安心。',
 '除了功课以外我也得准备这个周末考试。连忙我都没有时间跟父母见面!怎么办!再忙我也高兴你找到那个工作。',
 '身体健康。',
 '我真的不好意可是今天不能参加,因为我要去台南接我的奶奶。',
 '我听说这个礼拜六你要开一个舞会。可是那天我会很忙,因为我男朋友要回国来看我。',
 '我已经打算了跟他去玩,所以我不能来,对不起!',
 '我请你吃饭,可以吗?',
 '祝你生日快乐!恭喜恭喜!',
 '因为下个星期有很多事情,所以我不能来。',
 '我以前想要告诉你,可是我忘了。我真相怕。',
 '希望你还高兴,也你生日的时候给你一个很好玩的感觉。',
 '不好意思,可是我觉得我那时间不能参加,因为我已经跟我的弟弟打算去台湾。',
 '我不能参加你的舞会,因为我有一点生病了,我真的愿意参加可惜我身体不健康。',
 '我觉得你们会好好的玩。',
 '我的

In [59]:
# Peek at the first entry of `result` and count its entries.
# NOTE(review): the saved output (a nested list and a count of 35) looks like it
# came from an earlier run that stored one list per batch; re-run this cell
# after the inference cell to refresh it.
print(result[0])
len(result)

[['<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>'], ['<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '友', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>'], ['<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>'], ['<RAW>', '<RAW>', '起', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>'], ['<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '地', '<RAW>', '舞', '<RAW>'], ['<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>'], ['<RAW>', '<RAW>', '<RAW>', '<RAW>', '<RAW>', '<R

35

In [68]:
from utils.io import write_to

# One reconstructed sentence per line.
predictions_text = "\n".join(finals)
write_to("no_mask_preds_super.txt", predictions_text)

In [21]:
# Display the first test example again.
# NOTE(review): the saved output shows ordinary token ids in `labels`, unlike the
# earlier inspection cell — presumably a different dataset variant was loaded
# when this ran; confirm before comparing the two.
test_dataset[0]


{'input_ids': [872,
  1962,
  106,
  2769,
  3221,
  2476,
  4263,
  3152,
  511,
  2644,
  2356,
  2002,
  6421,
  5687],
 'lattice': [0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 4, 6, 6, 6],
 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 'labels': [872, 1962, 106, 2769, 3221, 2476, 4263, 3152, 511],
 'sub_length': 9}