In [8]:
import numpy as np
from tqdm.auto import tqdm
from CwnGraph import CwnBase, CwnImage
from sklearn.model_selection import train_test_split
from datasets import Dataset, DatasetDict, DatasetInfo
from datasets import Value, Sequence, Features
from datasets import concatenate_datasets

In [9]:
# One seeded generator, drawn from by corrupt_sentence, so corruption is repeatable.
rng = np.random.RandomState(seed=12052)

In [10]:
# Load the latest CWN image and keep only senses that have at least one
# example and no empty/falsy example.
cwn = CwnImage.latest()


def _has_usable_examples(sense):
    """Return True when `sense` has examples and every one of them is truthy."""
    examples = sense.all_examples()
    return bool(examples) and all(examples)


senses = [sense for sense in tqdm(cwn.get_all_senses()) if _has_usable_examples(sense)]
# Display: number of kept senses, and number of distinct examples across them.
len(senses), sum(len(set(sense.all_examples())) for sense in senses)

  0%|          | 0/29415 [00:00<?, ?it/s]

(29306, 93208)

In [11]:
def make_example(sense_x):
    """Flatten a CWN sense object into a plain record dict.

    Keys, in order: cwnid, word, pos, definition, examples. ``examples`` is
    whatever ``sense_x.all_examples()`` returns (duplicates are kept).
    """
    record = {}
    record["cwnid"] = sense_x.id
    record["word"] = sense_x.head_word
    record["pos"] = sense_x.pos
    record["definition"] = sense_x.definition
    record["examples"] = sense_x.all_examples()
    return record

In [12]:
sense_data = [make_example(sense) for sense in senses]
# Pivot the list of row dicts into a dict of column lists, following the key
# order of the first record.
sense_cols = {
    field: [row[field] for row in sense_data]
    for field in sense_data[0]
}
# Build the dataset and cast it to an explicit all-string schema
# (examples become a sequence of strings).
sense_ds = Dataset.from_dict(sense_cols).cast(
    Features({
        "cwnid": Value(dtype="string"),
        "word": Value(dtype="string"),
        "pos": Value(dtype="string"),
        "definition": Value(dtype="string"),
        "examples": Sequence(feature=Value(dtype="string")),
    })
)

Casting the dataset:   0%|          | 0/3 [00:00<?, ?ba/s]

In [13]:
# Sanity check: column names and row count of the sense-level dataset.
sense_ds

Dataset({
    features: ['cwnid', 'word', 'pos', 'definition', 'examples'],
    num_rows: 29306
})

## Definition Corruption

In [14]:
def corrupt_sentence(instance, win=3, random_state=None, sentinel_fmt="<extra_id_{:02d}>"):
    """Build a T5-style span-corruption (denoising) pair from a sense definition.

    One span (two when the definition is longer than 20 characters) is replaced
    by a sentinel in ``src``; ``tgt`` lists each sentinel followed by the text it
    replaced and ends with one extra sentinel. Overlapping spans are merged, so
    filling the sentinels in ``src`` with their ``tgt`` text always reproduces
    the original definition.

    Args:
        instance: mapping with at least ``"cwnid"`` and ``"definition"``.
        win: definitions with ``len <= win + 1`` are returned uncorrupted
            (``src == tgt == definition``); the first span starts in
            ``[0, len(definition) - win)``.
        random_state: object with ``randint``/``poisson`` like
            ``np.random.RandomState``; defaults to the module-level ``rng``.
        sentinel_fmt: format string for the i-th sentinel. NOTE(review): the
            default yields ``<extra_id_00>``; T5/mT5 tokenizers register their
            sentinels as ``<extra_id_0>`` (no zero padding), so the default is
            tokenized as ordinary text. Pass ``"<extra_id_{}>"`` to use the
            built-in sentinels; the default is kept for compatibility with data
            already generated.

    Returns:
        dict with ``"cwnid"``, ``"src"`` (corrupted text) and ``"tgt"`` (targets).
    """
    rs = rng if random_state is None else random_state
    cwnid = instance["cwnid"]
    text = instance["definition"]
    if len(text) <= win + 1:
        return {"cwnid": cwnid, "src": text, "tgt": text}

    # Draw from the RNG in the same order as before, so seeded runs consume
    # the same random stream.
    cor_x = rs.randint(len(text) - win)
    cor_len = np.clip(rs.poisson(2), 1, 4)
    cor_sites = [(cor_x, cor_len)]
    if len(text) > 20:
        # The second span starts after the first (wrapping around the end), so
        # it may start *before* the first span and run into it.
        cor_x = ((cor_x + cor_len) + rs.randint(len(text) - cor_len)) % len(text)
        cor_len = np.clip(rs.poisson(2), 1, 4)
        cor_sites.append((cor_x, cor_len))
    cor_sites.sort(key=lambda site: site[0])

    cor_text = ""
    target_text = ""
    cur_pos = 0
    cor_idx = 0
    for cor_x, cor_len in cor_sites:
        span_end = cor_x + cor_len
        if cor_x < cur_pos:
            # Overlaps the previous span: extend that span's target instead of
            # skipping. (Skipping while still moving cur_pos used to drop
            # text[prev_end:span_end] from both src and tgt.)
            target_text += text[cur_pos:span_end]
            cur_pos = max(cur_pos, span_end)
            continue
        sentinel = sentinel_fmt.format(cor_idx)
        cor_text += text[cur_pos:cor_x] + sentinel
        target_text += sentinel + text[cor_x:span_end]
        cor_idx += 1
        cur_pos = span_end
    cor_text += text[cur_pos:]
    target_text += sentinel_fmt.format(cor_idx)
    return {"cwnid": cwnid, "src": cor_text, "tgt": target_text}
    

In [15]:
# One denoising pair per sense; every input column except cwnid is dropped.
ds_corrupt = sense_ds.map(
    function=corrupt_sentence,
    remove_columns=["word", "pos", "examples", "definition"],
)

  0%|          | 0/29306 [00:00<?, ?ex/s]

In [16]:
# Random 90/10 split of the denoising pairs. Seeded (same seed as `rng`) so a
# Restart & Run All reproduces the same split; unseeded, every run drew a new one.
# NOTE(review): this split is independent of the definition-generation split,
# so a definition held out as a defgen test target can still appear (partly
# masked) in the denoising training data — consider splitting by cwnid.
ds_corrupt_split = ds_corrupt.train_test_split(test_size=0.1, seed=12052)

## Definition Generation

In [17]:
def defgen_mapper(instance):
    """Expand one sense into definition-generation pairs, one per distinct example.

    ``src`` is ``"<word>。<example>"`` and ``tgt`` is ``"<pos>。<definition>"``.
    Every output value is a list, so a following batched map can flatten
    them into rows.

    Duplicate examples are removed in first-occurrence order. Deduplicating
    with ``set`` (as before) made the row order depend on Python's per-process
    string hash seed. Because ``stratify_split`` holds out the *last* row of
    each sense, that also made the test set change between runs.
    """
    cwnid = instance["cwnid"]
    examples = list(dict.fromkeys(instance["examples"]))
    n_examples = len(examples)
    word_prefix = f"{instance['word']}。"
    pos_prefix = f"{instance['pos']}。"
    target = pos_prefix + instance["definition"]
    return {
        "cwnid": [cwnid] * n_examples,
        "src": [word_prefix + ex for ex in examples],
        "tgt": [target] * n_examples,
    }

def flatten_list(instances):
    """Concatenate the per-row lists in every column of a batch.

    Meant for ``Dataset.map(..., batched=True)``: ``{"src": [[a, b], [c]]}``
    becomes ``{"src": [a, b, c]}``, letting one input row yield several rows.

    Single pass per column. ``sum(lists, [])`` (the previous form) re-copies
    the accumulated list on every addition, which is quadratic in batch size.
    """
    return {
        key: [item for row in instances[key] for item in row]
        for key in instances.keys()
    }

In [18]:
# Expand each sense into per-example list columns, then flatten those lists
# into individual definition-generation rows.
ds_defgen = (
    sense_ds
    .map(defgen_mapper, remove_columns=["word", "pos", "examples", "definition"])
    .map(flatten_list, batched=True)
)

  0%|          | 0/29306 [00:00<?, ?ex/s]

  0%|          | 0/30 [00:00<?, ?ba/s]

In [20]:
# Spot-check one flattened row: src is "<word>。<example>", tgt is "<pos>。<definition>".
ds_defgen[121]

{'cwnid': '03001702',
 'src': '抱。不要整天躲宿舍<抱>電腦，有時間應該出去曬曬太陽。',
 'tgt': 'VC。花費大量時間於特定事物上。'}

In [21]:
from itertools import groupby
def stratify_split(cwnids, min_group_size=3, n_test=1):
    """Hold out the last ``n_test`` rows of each sufficiently large sense group.

    Rows are grouped by cwnid. A group with at least ``min_group_size`` rows
    sends its last ``n_test`` row indices to the test set and the rest to
    training. Smaller groups go entirely to training, so every held-out sense
    still has training rows. (This is a per-sense holdout, not a stratified
    sample.)

    Groups are collected by value in first-appearance order, so rows of one
    sense need not be contiguous. ``itertools.groupby`` (used before) only
    merged adjacent runs and silently treated a split-up sense as several
    groups. For contiguous input the returned indices are the same as before.

    Args:
        cwnids: sequence of sense ids, one per row.
        min_group_size: smallest group that contributes test rows (default 3,
            i.e. the previous ``len > 2`` rule).
        n_test: rows held out per eligible group; ``0 < n_test < min_group_size``.

    Returns:
        ``(train_idxs, test_idxs)`` as lists of int row indices.

    Raises:
        ValueError: if ``n_test`` is not in ``(0, min_group_size)``.
    """
    if not 0 < n_test < min_group_size:
        raise ValueError(
            f"n_test must satisfy 0 < n_test < min_group_size, "
            f"got n_test={n_test}, min_group_size={min_group_size}"
        )
    groups = {}
    for row_idx, cwnid in enumerate(cwnids):
        groups.setdefault(cwnid, []).append(row_idx)

    train_idxs, test_idxs = [], []
    for grp_idx in groups.values():
        if len(grp_idx) >= min_group_size:
            train_idxs.extend(grp_idx[:-n_test])
            test_idxs.extend(grp_idx[-n_test:])
        else:
            train_idxs.extend(grp_idx)
    return train_idxs, test_idxs

In [22]:
# Row indices for the defgen split: one held-out example per sense with enough examples forms the test set.
defgen_train_idxs, defgen_test_idxs = stratify_split(ds_defgen["cwnid"])

In [23]:
# Materialise the defgen train/test subsets from the split indices.
ds_defgen_train, ds_defgen_test = (
    ds_defgen.select(idxs) for idxs in (defgen_train_idxs, defgen_test_idxs)
)

In [24]:
info = DatasetInfo(description="CWN seq2seq data with denoising and definition generation")
# Each split = definition-generation rows followed by denoising rows.
cwn_seq2seq_ds = DatasetDict({
    split: concatenate_datasets([defgen_part, ds_corrupt_split[split]], info=info)
    for split, defgen_part in (("train", ds_defgen_train), ("test", ds_defgen_test))
})

In [25]:
# Preview the first training rows (definition-generation rows come first after concatenation).
cwn_seq2seq_ds["train"][:10]

{'cwnid': ['03000101',
  '03000101',
  '03000102',
  '03000102',
  '03000201',
  '03000201',
  '03000202',
  '03000202',
  '03000203',
  '03000203'],
 'src': ['啊唷。<啊唷>，這麼多蟑螂和老鼠屎！髒透了！',
  '啊唷。門砰的一聲，阿姨跳了起來，喊一聲「<啊唷>」。',
  '啊唷。只聽老爺「<啊唷>」一聲，說是一條腿跌斷了。',
  '啊唷。在用剪刀剪去死肉時，老人叫著「<啊唷>」、「<啊唷>」。',
  '唉唷。哈、哈、哈，<唉唷>，這也是，這也是有意思的說法。',
  '唉唷。<唉唷>，你的東西掉下去了啦！',
  '唉唷。<唉唷>，我的頭好痛。',
  '唉唷。<唉唷>，你踩到我的腳了。',
  '唉唷。<唉唷>！什麼情什麼愛,說來說去還不都是為了自己。',
  '唉唷。<唉唷>，好啦！你不要老是學那個電視上啦！'],
 'tgt': ['I。表驚訝的語氣。',
  'I。表驚訝的語氣。',
  'I。表痛苦、呼痛的聲音。',
  'I。表痛苦、呼痛的聲音。',
  'I。表驚訝的語氣。',
  'I。表驚訝的語氣。',
  'I。表痛苦、呼痛的聲音。',
  'I。表痛苦、呼痛的聲音。',
  'I。表示不耐煩的語氣。',
  'I。表示不耐煩的語氣。']}

In [26]:
# Persist the train/test DatasetDict; the relative path resolves against the notebook's working directory.
cwn_seq2seq_ds.save_to_disk("../data/cwn_seq2seq_charlie_ds")

Flattening the indices:   0%|          | 0/95 [00:00<?, ?ba/s]

Flattening the indices:   0%|          | 0/28 [00:00<?, ?ba/s]

## Collator

In [39]:
from transformers import MT5TokenizerFast
# mT5 tokenizer used by the collator below.
tokenizer = MT5TokenizerFast.from_pretrained("google/mt5-base")

In [40]:
# Exploratory probe of tokenizer output (scratch cell; consider removing from the final notebook).
tokenizer("!23")

{'input_ids': [259, 309, 2116, 1], 'attention_mask': [1, 1, 1, 1]}

In [58]:
def CwnSeq2Seq_collator_fn(batch, tok=None):
    """Tokenize a batch of src/tgt pairs into model inputs plus ``labels``.

    Accepts either a dict of columns (``{"src": [...], "tgt": [...]}``, as
    returned by slicing a ``Dataset``) or a list of row dicts (what a
    ``torch.utils.data.DataLoader`` / HF ``Trainer`` passes to a collator).
    Previously only the dict-of-columns form worked.

    Sources and targets are padded to the longest item in the batch. Padded
    target positions are set to -100 in ``labels`` so the loss ignores them.

    Args:
        batch: dict of lists, or list of dicts, with ``"src"`` and ``"tgt"``.
        tok: tokenizer to use; defaults to the module-level ``tokenizer``.

    Returns:
        dict with the source encoding (``input_ids``, ``attention_mask``)
        and ``labels``.

    NOTE(review): with ``Trainer`` the raw ``src``/``tgt`` columns are removed
    before collation unless ``remove_unused_columns=False`` — confirm in the
    training setup.
    """
    from collections.abc import Mapping

    tok = tokenizer if tok is None else tok
    if isinstance(batch, Mapping):
        src_texts, tgt_texts = batch["src"], batch["tgt"]
    else:
        src_texts = [row["src"] for row in batch]
        tgt_texts = [row["tgt"] for row in batch]

    src_batch = tok(src_texts, padding="longest", return_tensors="pt")
    tgt_batch = tok(tgt_texts, padding="longest", return_tensors="pt")
    # Padding positions (attention_mask == 0) are ignored by the loss via -100.
    labels = tgt_batch["input_ids"].masked_fill(~tgt_batch["attention_mask"].bool(), -100)
    return {**src_batch, "labels": labels}

In [59]:
# NOTE(review): the saved output of this cell shows an older data format
# (POS in src, bare definition in tgt) that differs from the In [25] preview
# above — it is stale output from a previous run; re-run to refresh.
cwn_seq2seq_ds["train"][:2]

{'cwnid': ['03000101', '03000101'],
 'src': ['啊唷，I。<啊唷>，這麼多蟑螂和老鼠屎！髒透了！', '啊唷，I。門砰的一聲，阿姨跳了起來，喊一聲「<啊唷>」。'],
 'tgt': ['表驚訝的語氣。', '表驚訝的語氣。']}

In [60]:
# Collate the first five training rows (a dict-of-columns slice) to inspect the resulting tensors.
CwnSeq2Seq_collator_fn(cwn_seq2seq_ds["train"][:5])

{'input_ids': tensor([[   259,  33332, 239895,    261,    566,    306,   2709,  33332, 239895,
             669,    261,  31752, 120775,   3139, 242812, 241213,   1107,   5991,
           94738, 239113,    309, 242228,  39378,   1322,    309,      1,      0,
               0,      0,      0,      0],
         [   259,  33332, 239895,    261,    566,    306,  27304, 241982,  96543,
          137266,    261,  15101, 194762,  52195,   1322, 210707,    261, 116100,
            1374, 137266,    939,   2709,  33332, 239895,    669,  54172,      1,
               0,      0,      0,      0],
         [   259,  33332, 239895,    261,    566,    306,   8779, 165023,   5991,
          222111,    939,   2709,  33332, 239895,    669,    879,   1374, 137266,
             261,  43454, 136363, 124894,  87917,  63687, 164686,   1322,    306,
               1,      0,      0,      0],
         [   259,  33332, 239895,    261,    566,    306,   1083,   2151, 132582,
           36990, 132582,   6072,  111