In [1]:
import os
import random
import time
import logging
import argparse
from dataclasses import dataclass, field
from typing import Optional,Dict, Union, Any, Tuple, List

import numpy as np
import datasets
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import transformers
from transformers import (
    DataCollatorForSeq2Seq,
    AutoConfig,
    AutoTokenizer,
    HfArgumentParser,
    TrainingArguments,
    Seq2SeqTrainingArguments,
    set_seed,
)
from transformers import Trainer, Seq2SeqTrainer
from transformers import TrainingArguments
from transformers import trainer_utils, training_args
from transformers.trainer_pt_utils import nested_detach
from transformers import BertForMaskedLM
from transformers.file_utils import PaddingStrategy
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
from transformers.training_args import TrainingArguments

import sys
sys.path.append("/remote-home/xtzhang/CTC/CTC2021/SpecialEdition/")
from core import get_magic_dataset, get_metrics, argument_init
from lib import subTrainer  
from models.bert.modeling_bert_v3 import BertForMaskedLM_v2 



In [7]:
model = BertForMaskedLM_v2.from_pretrained("/remote-home/xtzhang/CTC/CTC2021/SpecialEdition/tmp/sighan/bert_MaskedLM_v2_std_mask.epoch20.bs32")

In [2]:
train_dataset, eval_dataset, test_dataset = get_magic_dataset("sighan", "../") 

Loading Dataset !
Thu Dec 16 09:55:52 UTC 2021
Loading Abs_Pos Bert SigHan Dataset ...


100%|██████████| 284196/284196 [00:02<00:00, 125311.72it/s]
100%|██████████| 600/600 [00:00<00:00, 456730.02it/s]
100%|██████████| 1100/1100 [00:00<00:00, 408186.71it/s]


Save cache to cache/sighan_abs_pos_test.
Loading Succeed !
Thu Dec 16 09:57:08 UTC 2021


In [3]:
@dataclass
class MyDataCollatorForSeq2Seq:
    """
    """
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100

    def __call__(self, features):
        """
        """
        from copy import deepcopy

        f_copy = deepcopy(features)

        shared_max_length = max([ len(i['input_ids']) for i in f_copy])


        for i in range(len(f_copy)):
            f_copy[i]["raw_length"] = []

        for i in range(len(f_copy)):
            f_copy[i]["raw_length"].append(len(f_copy[i]["input_ids"]))

        def simple_pad(f_copy, key):
            f_key = [ f[key] for f in f_copy ]
            if f_key is not None:
                max_length = max(len(l) for l in f_key)

                padding_side = "right"

                if key == "attention_mask":
                    label_pad_token_id = 0
                elif key in ["input_ids", "lattice"]:
                    label_pad_token_id = 0
                elif key == "labels":
                    max_length = shared_max_length
                    label_pad_token_id= -100
                else:
                    label_pad_token_id = self.label_pad_token_id 

                for f in f_copy: 
                    remainder = [label_pad_token_id] * (max_length - len(f[key]))
                    f[key] = (
                        f[key] + remainder if padding_side == "right" else remainder + f[key]
                    )
            
            return f_copy

        for key in ["input_ids", "lattice", "labels", "attention_mask"]:
            f_copy = simple_pad(f_copy, key)

        new = {}

        black_list = []

        for key in f_copy[0].keys():
            if key not in black_list:    
                new[key] = []
        
        for feature in f_copy:
            for key in feature.keys():
                if key not in black_list:
                    new[key].append(feature[key])

        for key in new.keys():
            if key not in  black_list:
                #print(key)
            #    new[key] = new[key]
                new[key] = torch.tensor(new[key]) 

        new.pop("raw_length")

        return new

In [18]:
tokenizer_model_name_path="hfl/chinese-roberta-wwm-ext"

tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_model_name_path
)

In [9]:
pre = MyDataCollatorForSeq2Seq()

In [13]:
p = model(**pre([test_dataset[0]])).logits

In [15]:
out_ = torch.argmax(torch.softmax(p, 2), -1)

In [16]:
out_

tensor([[ 872, 1962,  106, 2769, 3221, 2476, 4263, 3152,  511, 2644, 3221, 5687,
         3195, 5687]])

In [40]:
bs = 32

result = []

for i in tqdm(range(0, len(test_dataset) // bs + 1)):
    batch = test_dataset[ i*32 : (i+1) *32]

    logits = model(**pre(batch)).logits

    pred = torch.argmax(torch.softmax(logits, 2), -1)

    mid = tokenizer.batch_decode(pred)

    for j in range(len(batch)):
        mid[j] = "".join(mid[j].split())[:len(batch[j]["labels"])]

    result += mid


100%|██████████| 35/35 [01:16<00:00,  2.19s/it]


In [41]:
result[0]
len(result)

1100

In [42]:
from utils.io import write_to
write_to("no_mask_preds.txt", "\n".join(result))

In [21]:
test_dataset[0]


{'input_ids': [872,
  1962,
  106,
  2769,
  3221,
  2476,
  4263,
  3152,
  511,
  2644,
  2356,
  2002,
  6421,
  5687],
 'lattice': [0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 4, 6, 6, 6],
 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 'labels': [872, 1962, 106, 2769, 3221, 2476, 4263, 3152, 511],
 'sub_length': 9}