In [1]:
import sys
sys.path.append("/remote-home/xtzhang/CTC/CTC2021/SpecialEdition")

import os
import random
import time
import logging
import argparse
from dataclasses import dataclass, field
from typing import Optional,Dict, Union, Any, Tuple, List

import numpy as np
import datasets
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import transformers
from transformers import (
    BertConfig,
    DataCollatorForSeq2Seq,
    AutoConfig,
    AutoTokenizer,
    HfArgumentParser,
    TrainingArguments,
    Seq2SeqTrainingArguments,
    set_seed,
)
from transformers import Trainer, Seq2SeqTrainer
from transformers import TrainingArguments
from transformers import trainer_utils, training_args
from transformers.trainer_pt_utils import nested_detach
from transformers import BertForMaskedLM
from transformers.file_utils import PaddingStrategy
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
from transformers.training_args import TrainingArguments

from core import get_dataset, get_metrics, argument_init
from lib import subTrainer  
from data.DatasetLoadingHelper import load_ctc2021, load_sighan, load_lattice_sighan
#from models.bart.modeling_bart_v2 import BartForConditionalGeneration
from models.bert.modeling_bert_v2 import BertForFlat


In [2]:
# Load the tokenizer for the hfl/chinese-roberta-wwm-ext checkpoint; `tokenizer`
# is used by the batch-encoding cell that follows.
tokenizer_model_name_path="hfl/chinese-roberta-wwm-ext"

tokenizer = AutoTokenizer.from_pretrained(tokenizer_model_name_path)

In [3]:
# Encode two sentences of different lengths in one call. No padding is
# requested, so the resulting id lists keep their individual lengths.
data = ["今天天气不错", "今天天气还行哦"]

res = tokenizer.batch_encode_plus(data)

In [4]:
# Plain-dict view of the batch: input_ids / token_type_ids / attention_mask.
res.data

{'input_ids': [[101, 791, 1921, 1921, 3698, 679, 7231, 102],
  [101, 791, 1921, 1921, 3698, 6820, 6121, 1521, 102]],
 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0]],
 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1]]}

In [5]:
# Underlying tokenizers.Encoding object for the first sentence.
res.encodings[0]

Encoding(num_tokens=8, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])

In [6]:
from tokenizers import Encoding

In [12]:
class myEncoding():
    """Plain attribute container for one encoded sample.

    Each constructor argument is stored unchanged as an attribute of the same
    name; no validation or conversion is performed. `pos_s` / `pos_e` are
    presumably lattice start/end positions (confirm against the data loader).
    """

    def __init__(self, ids, atten_masks, target, pos_s, pos_e):
        # Assigned in the same order as the parameter list.
        fields = dict(ids=ids, atten_masks=atten_masks, target=target, pos_s=pos_s, pos_e=pos_e)
        for name, value in fields.items():
            setattr(self, name, value)

In [17]:
# Inspect the fields of the first sentence's Encoding. Integer-indexing the
# BatchEncoding (`res[0]`) yields a tokenizers.Encoding, not a dict row.
print(res[0].ids)
print(res[0].type_ids)
print(res[0].tokens)
print(res[0].offsets)
print(res[0].attention_mask)
# NOTE(review): the Encoding attribute list shown for this object names
# `special_tokens_mask`, not `sequence_tokens_mask`.
#print(res[0].sequence_tokens_mask)
#print(res[0].overflowing)
res[0].n_sequences

[101, 791, 1921, 1921, 3698, 679, 7231, 102]
[0, 0, 0, 0, 0, 0, 0, 0]
['[CLS]', '今', '天', '天', '气', '不', '错', '[SEP]']
[(0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (0, 0)]
[1, 1, 1, 1, 1, 1, 1, 1]


1

In [None]:

class mydataset(Dataset):
    """Map-style torch Dataset over a column-oriented dict of equal-length lists.

    `data` maps field name -> list with one value per sample (the layout of
    `BatchEncoding.data`, e.g. {"input_ids": [...], "attention_mask": [...]}).

    Indexing:
        ds["input_ids"] -> the whole column (the list stored in `data`).
        ds[i]           -> sample i as a {field_name: value} dict.

    Raises:
        ValueError: if the columns do not all have the same length.
    """

    def __init__(self, data):
        self.data = data
        # Materialise a list rather than indexing the keys view:
        # dict views are not subscriptable on Python 3.
        keys = list(data.keys())
        lengths = {len(data[key]) for key in keys}
        if len(lengths) > 1:
            raise ValueError(
                "all columns must have the same length, got "
                + str({key: len(data[key]) for key in keys})
            )
        n_samples = lengths.pop() if lengths else 0
        # Row-oriented view built once; each row is a dict so field names
        # are kept alongside their values.
        self.data_iter = [{key: data[key][i] for key in keys} for i in range(n_samples)]

    def __getitem__(self, key):
        # String keys select a whole column, anything else indexes a sample.
        if isinstance(key, str):
            return self.data[key]
        return self.data_iter[key]

    def __len__(self):
        # Number of samples (what DataLoader / samplers expect), not fields.
        return len(self.data_iter)

In [11]:
def get_offset(length):
    """Build (start, end) character offsets for a [CLS] + characters + [SEP] sequence.

    Matches the offsets layout the fast tokenizer produced for character-level
    Chinese input in this notebook: both special tokens map to (0, 0) and the
    i-th character maps to (i, i + 1).

    Args:
        length: total number of tokens, including [CLS] and [SEP].

    Returns:
        A list of exactly `length` (start, end) tuples.

    Raises:
        ValueError: if `length < 2`, since there is no room for both special
            tokens and the result could not have `length` entries.
    """
    if length < 2:
        raise ValueError(f"length must be >= 2 to hold [CLS] and [SEP], got {length}")
    return [(0, 0)] + [(i, i + 1) for i in range(length - 2)] + [(0, 0)]

# quick check: expect [(0, 0), (0, 1), (0, 0)]
get_offset(3)

[(0, 0), (0, 1), (0, 0)]

In [8]:
# Attempt to hand-build a tokenizers.Encoding.
# NOTE(review): displaying `test_encoding` afterwards shows num_tokens=0, so
# these keyword arguments did not populate the object. Also, `[(0)]` is just
# `[0]` (parentheses alone do not make a tuple); an offset pair would be
# written like `(0, 1)`.
test_encoding = Encoding(ids=[1], type_ids=[1], tokens=["是"], offsets=[(0)], atten_masks=[1])

In [13]:
# Show the hand-constructed Encoding (check num_tokens to see whether it was populated).
test_encoding

Encoding(num_tokens=0, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])

In [43]:
# print() of an Encoding gives the same summary repr as displaying it.
print(res[0])

Encoding(num_tokens=8, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])


In [36]:
# Short handle for interactive inspection; not referenced later in this section.
p = res.encodings[0]

In [32]:
# List the Encoding's attributes/methods to see what can be read from it.
son = dir(res.encodings[0])

In [5]:
# Repr of the whole BatchEncoding; it shows the same content as res.data.
res

{'input_ids': [[101, 791, 1921, 1921, 3698, 679, 7231, 102], [101, 791, 1921, 1921, 3698, 6820, 6121, 1521, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1]]}

In [6]:
from transformers.tokenization_utils_base import BatchEncoding

In [None]:
# Bare no-argument Encoding construction.
# NOTE(review): this cell was never executed (In [None]); `tmp_encoding` is not used later.
tmp_encoding = Encoding()

In [7]:
# Hand-built BatchEncoding mimicking the tokenizer's output, plus an extra
# "pos_s" field (presumably lattice start positions - confirm). Keys follow
# the tokenizer's own naming ("token_type_ids", as in `res.data`) so code that
# looks up the standard keys finds them.
tmp = BatchEncoding({"input_ids":[[1]], "token_type_ids":[[1]], "attention_mask":[[1]], "pos_s":[[1]]})

In [8]:
# Toy columns for a two-sample fastNLP DataSet: each outer list has one entry
# per sample, and each entry is that sample's (single-element) value list.
finals = [[1],[2]]
atten_masks = [[1], [2]]
target = [[1], [2]]
pos_s = [[1], [2]]
pos_e = [[1], [2]]

In [9]:
from fastNLP import DataSet

# Column-oriented fastNLP DataSet: field name -> list with one entry per sample.
# The toy `finals` column is stored under the field name "lattice".
src = DataSet({ "lattice":finals, "atten_masks":atten_masks, "target": target, "pos_s":pos_s, "pos_e":pos_e})

In [10]:
# Field names of the DataSet, in the order they were supplied.
src.field_arrays.keys()

dict_keys(['lattice', 'atten_masks', 'target', 'pos_s', 'pos_e'])

In [11]:
# Iterating a fastNLP DataSet yields one row per sample; each prints as a small table.
for i in src:
    print(i)

+---------+-------------+--------+-------+-------+
| lattice | atten_masks | target | pos_s | pos_e |
+---------+-------------+--------+-------+-------+
| [1]     | [1]         | [1]    | [1]   | [1]   |
+---------+-------------+--------+-------+-------+
+---------+-------------+--------+-------+-------+
| lattice | atten_masks | target | pos_s | pos_e |
+---------+-------------+--------+-------+-------+
| [2]     | [2]         | [2]    | [2]   | [2]   |
+---------+-------------+--------+-------+-------+


In [12]:
# Turn every fastNLP field into a plain Python list, then wrap the resulting
# {field_name: [per-sample values]} dict in a BatchEncoding.
new = {}
for key in src.field_arrays:
    new[key] = list(src.field_arrays[key])
final = BatchEncoding(new)

In [18]:
# Compare the hand-built container with the tokenizer's: both are BatchEncoding,
# but only the tokenizer's carries Encoding objects (visible through res[0]).
print(final, type(final))
print(res, type(res), res[0])

{'lattice': [[1], [2]], 'atten_masks': [[1], [2]], 'target': [[1], [2]], 'pos_s': [[1], [2]], 'pos_e': [[1], [2]]} <class 'transformers.tokenization_utils_base.BatchEncoding'>
{'input_ids': [[101, 791, 1921, 1921, 3698, 679, 7231, 102], [101, 791, 1921, 1921, 3698, 6820, 6121, 1521, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1]]} <class 'transformers.tokenization_utils_base.BatchEncoding'> Encoding(num_tokens=8, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])


In [19]:
# Iterating a BatchEncoding yields its keys (dict-like behaviour), not Encodings.
for i in res:
    print(i)

input_ids
token_type_ids
attention_mask


In [14]:
# Toy split dict; both splits reuse the same DataSet for now.
dataset = dict(train=src, test=src)

def tmp_transform(fnlp_dataset):
    """Convert a fastNLP DataSet into a transformers BatchEncoding.

    Each field becomes a key whose value is a plain Python list with one entry
    per sample, taken from `fnlp_dataset.field_arrays` in its iteration order.

    Args:
        fnlp_dataset: a fastNLP DataSet (anything exposing a `field_arrays`
            mapping of field name -> iterable column).

    Returns:
        A BatchEncoding wrapping {field_name: list_of_values}.
    """
    columns = {
        field_name: list(fnlp_dataset.field_arrays[field_name])
        for field_name in fnlp_dataset.field_arrays
    }
    return BatchEncoding(columns)


test = tmp_transform(dataset["train"])
print(test)

# Convert every split, keeping the split names ("train", "test", ...) as keys.
dataset2 = {split: tmp_transform(split_data) for split, split_data in dataset.items()}

dataset2

{'lattice': [[1], [2]], 'atten_masks': [[1], [2]], 'target': [[1], [2]], 'pos_s': [[1], [2]], 'pos_e': [[1], [2]]}


{'train': {'lattice': [[1], [2]], 'atten_masks': [[1], [2]], 'target': [[1], [2]], 'pos_s': [[1], [2]], 'pos_e': [[1], [2]]},
 'test': {'lattice': [[1], [2]], 'atten_masks': [[1], [2]], 'target': [[1], [2]], 'pos_s': [[1], [2]], 'pos_e': [[1], [2]]}}