In [1]:
%load_ext autoreload
%autoreload 2
# Add ../src to the module search path (only once) so that modules located
# there can be imported by the cells below.
import sys
if "../src" not in sys.path:
    sys.path.append("../src")

In [2]:
import re
import pickle
import json
import random
from datetime import datetime
from pathlib import Path
from dataclasses import dataclass
from typing import List

import numpy as np
import torch
from tqdm.auto import tqdm

from transformers import MT5TokenizerFast

from CwnGraph import CwnImage
import vec4gloss
from vec4gloss import check_hashes
from vec4gloss import Vec4GlossModel

## Data dependencies
```
..\data\annotation.json 2ed250
..\data\models\vec4gloss-defgen-220629-1250\pytorch_model.bin 9f894f
```

In [3]:
vec4gloss_model_dir = "../data/models/vec4gloss-defgen-220629-1250"
# Input files this notebook reads; check_hashes prints each path with its hash
# so they can be compared against the values listed under "Data dependencies".
input_paths = [
    "../data/annotation.json",
    f"{vec4gloss_model_dir}/pytorch_model.bin",
]
_ = check_hashes(input_paths)

..\data\annotation.json 2ed250
..\data\models\vec4gloss-defgen-220629-1250\pytorch_model.bin 9f894f


## Loading resources

In [4]:
# Parse the annotation file (UTF-8 encoded JSON) into Python objects.
with open("../data/annotation.json", encoding="UTF-8") as annot_file:
    annot_data = json.load(annot_file)

In [5]:
# Show how many annotated entries were loaded, together with one sample record.
len(annot_data), annot_data[45]

(288,
 {'sense_id': 3048001,
  'head_word': '沿街',
  'POS': 'D',
  'definition': '表同一事件在經過的街道中重複發生。',
  'event_role': 'agent',
  'schemas': [{'type': 'event', 'start': 1, 'end': 5},
   {'type': 'scope', 'start': 5, 'end': 6},
   {'type': 'place', 'start': 6, 'end': 11},
   {'type': 'scope', 'start': 11, 'end': 12},
   {'type': 'mod', 'start': 12, 'end': 14},
   {'type': 'action', 'start': 14, 'end': 16}]})

In [6]:
## Loading model
# CUDA is selected only when it is available and the reported device name
# does not contain "GeForce"; otherwise everything runs on the CPU.
use_cuda = False
if torch.cuda.is_available():
    use_cuda = "GeForce" not in torch.cuda.get_device_name()
device = "cuda" if use_cuda else "cpu"
print("Using", device)

tokenizer = MT5TokenizerFast.from_pretrained(vec4gloss_model_dir)
model = Vec4GlossModel.from_pretrained(vec4gloss_model_dir)
model = model.to(device)
gen = vec4gloss.gen_func(tokenizer, model)

Using cpu


In [7]:
# Version string handed to CwnImage.load; the resulting `cwn` object is used
# further down to look up senses by their id.
CWN_VER = "v.2022.06.21"
cwn = CwnImage.load(CWN_VER)

## Annotation Frame

In [9]:
from dataclasses import dataclass
@dataclass
class Scheme:
    """One labelled span of a definition.

    ``start`` and ``end`` are used as slice bounds into the definition text
    (``definition[start:end]``), so ``end`` is exclusive.
    """
    type: str
    start: int
    end: int


@dataclass
class AnnotFrame:
    """An annotated sense definition together with its labelled spans.

    On construction the spans in ``schemas`` are run through
    :meth:`preprocess`, which inserts ``'--'`` filler spans for text that is
    not annotated before or between the given spans.
    """
    sense_id: int  # numeric in the annotation file; formatted with "{:08d}" elsewhere
    POS: str
    head_word: str
    definition: str
    event_role: str
    schemas: List[Scheme]

    def __post_init__(self):
        self.schemas = self.preprocess()

    def get_frame(self, frame_idx):
        """Return ``(type, text)`` for the span at position ``frame_idx``."""
        span = self.schemas[frame_idx]
        return (span.type, self.definition[span.start:span.end])

    def preprocess(self):
        """Return the spans with ``'--'`` fillers inserted for uncovered gaps.

        The spans are expected in ascending, non-overlapping order. Text after
        the last span is left uncovered (no trailing filler is added).

        Raises:
            ValueError: if a span starts before the previous span ended.
        """
        filled = []
        cursor = 0  # end offset of the previously emitted span
        for frame_x in self.schemas:
            if cursor < frame_x.start:
                filled.append(Scheme('--', cursor, frame_x.start))
            elif cursor > frame_x.start:
                # The original `assert("shouldn't happen")` asserted a
                # non-empty string, which is always true and never fired.
                raise ValueError(
                    "overlapping or unordered span {!r} in sense {}: "
                    "starts at {} but previous span ended at {}".format(
                        frame_x.type, self.sense_id, frame_x.start, cursor))
            filled.append(frame_x)
            cursor = frame_x.end
        return filled

    def show(self):
        """Print the definition as space-separated ``<type>text`` segments."""
        segments = [
            "<{}>{}".format(x.type, self.definition[x.start:x.end])
            for x in self.schemas
        ]
        print(" ".join(segments))

In [10]:
from collections import Counter
# Wrap every raw annotation dict in an AnnotFrame: the "schemas" entries become
# Scheme objects, all remaining keys are passed through as keyword arguments.
annot_frames = [
    AnnotFrame(
        **{key: val for key, val in raw.items() if key != "schemas"},
        schemas=[Scheme(**span) for span in raw["schemas"]],
    )
    for raw in annot_data
]

In [11]:
# Print one frame as "<type>text" segments to eyeball the span labelling.
annot_frames[20].show()

<-->表 <reason>為同一特定原因 <-->而 <mod>一起 <action>做 <event>後述事件


## Decoding

In [12]:
@dataclass
class DepWinOutput:
    """Per-token result record produced by ``compute_dependents``.

    ``token_text`` and ``dep_text`` are created as ``None`` by
    ``compute_dependents`` and filled in afterwards by the caller.
    """
    dep_win: int               # selected window index
    lratio: float              # ratio value associated with the selected window
    token_probs: List[float]   # target-token probability per tried window
    lratios: List[float]       # ratio value per tried window
    token_text: str            # text of the target token (may be None until set)
    dep_text: str              # decoded text of the dependent tokens (may be None until set)

    def __repr__(self):
        # "{:>2s}" rejects None with a TypeError, so a record whose
        # token_text has not been filled in yet is rendered with an empty
        # token instead of crashing the display.
        token_text = "" if self.token_text is None else self.token_text
        return "[{:5.2f}] {:\u3000>2s} <- {}".format(
            self.lratio, token_text, self.dep_text)

In [13]:
def compute_mean_vec(examples):
    """Average the encoder vectors of ``examples`` into a single vector.

    Each example is passed to ``vec4gloss.extract_encoder_vector`` together
    with the notebook-level ``tokenizer`` and ``model``. Examples for which
    that call raises ``AssertionError`` are skipped.

    Args:
        examples: iterable of example sentences for one sense.

    Returns:
        The mean over dimension 0 (kept with size 1) of the concatenated
        vectors, moved to ``device``.

    Raises:
        ValueError: if no encoder vector could be extracted, e.g. because
            ``examples`` is empty. Previously this surfaced as an opaque
            error from ``torch.cat`` on an empty list.
    """
    enc_vecs = []
    with torch.no_grad():
        for ex in examples:
            try:
                enc_vecs.append(vec4gloss.extract_encoder_vector(ex, tokenizer, model))
            except AssertionError:
                # Deliberate best-effort: an example the extractor rejects is
                # left out of the average instead of aborting the sense.
                continue
    if not enc_vecs:
        raise ValueError("no encoder vector could be extracted from the given examples")
    return torch.cat(enc_vecs).mean(0, keepdim=True).to(device)

In [14]:
def compute_dependents(tgt_seq, tgt_idx, dbg=False, enc_vec=None,
                       max_window=4, lr_threshold=2):
    """Find how many preceding tokens the prediction of one target token depends on.

    The decoder is run on the tokens preceding ``tgt_seq[tgt_idx]`` several
    times; in run ``cursor`` the last ``cursor`` input positions are masked
    out through the decoder attention mask. For each run the probability of
    the target token at the last position is recorded, and for every run
    after the first the ratio ``unmasked probability / masked probability``
    is computed (the first run's ratio is fixed to 0).

    Args:
        tgt_seq: list of token ids of the full target sequence.
        tgt_idx: index into ``tgt_seq`` of the token to analyse.
        dbg: when true, print the target token and one line per run.
        enc_vec: encoder vector passed to the model as
            ``decoder_encoder_vector``. Defaults to the notebook-level
            ``mean_vec`` when omitted.
        max_window: upper bound on the number of runs (default 4).
        lr_threshold: the first run whose ratio exceeds this value is
            selected (default 2).

    Returns:
        A ``DepWinOutput`` holding the selected run index and its ratio (the
        last run is used if no ratio exceeds the threshold), plus the
        per-run probabilities and ratios. ``dep_text`` and ``token_text``
        are left as ``None``.
    """
    assert tgt_idx < len(tgt_seq)
    if enc_vec is None:
        # Backward-compatible fallback to the global set by the calling cell.
        enc_vec = mean_vec

    # Decoder input: id 0 followed by the tokens before the target
    # (presumably 0 is the decoder start / pad id -- it decodes to "<pad>").
    generated = [0] + tgt_seq[:tgt_idx]
    tgt_token = tgt_seq[tgt_idx]
    n_generated = len(generated)

    if dbg:
        print("target: ", tokenizer.decode([tgt_token]))

    # The input ids do not change between runs; only the mask does.
    input_ids = torch.tensor([generated]).to(device)
    token_probs = []
    lrs = []
    for cursor in range(min(max_window, n_generated)):
        seq_mask = torch.ones(1, n_generated, dtype=torch.long)
        seq_mask[0, n_generated - cursor:] = 0
        # Tensor.to() is not in-place: the result has to be rebound, otherwise
        # the mask stays on the CPU while input_ids may live on the GPU.
        seq_mask = seq_mask.to(device)
        with torch.no_grad():
            out = model(decoder_encoder_vector=enc_vec,
                        decoder_input_ids=input_ids,
                        decoder_attention_mask=seq_mask)
        token_prob = out.logits.softmax(-1)[0, -1, tgt_token].item()
        token_probs.append(token_prob)

        if len(token_probs) == 1:
            lr = 0
        elif token_prob > 0:
            lr = token_probs[0] / token_prob
        elif token_probs[0] > 0:
            # Masked probability underflowed to zero: treat as an unbounded
            # ratio instead of raising ZeroDivisionError.
            lr = float("inf")
        else:
            lr = 0
        lrs.append(lr)

        if dbg:
            # Masked positions are zeroed and therefore decode to "<pad>",
            # which is displayed as a box character.
            mask_text = tokenizer.decode(torch.mul(input_ids, seq_mask)[0, 1:])
            mask_text = mask_text.replace("<pad>", "□")
            print("({:>2d}) [{:.4f}] {} ({:.2f})".format(cursor, token_prob, mask_text, lr))

    # Default to the widest tried window; prefer the first one over threshold.
    win = len(lrs) - 1
    ret_lr = lrs[-1]
    for lr_i, lr_x in enumerate(lrs):
        if lr_x > lr_threshold:
            win, ret_lr = lr_i, lr_x
            break
    return DepWinOutput(win, ret_lr, token_probs, lrs, dep_text=None, token_text=None)

## Building decoder data objects

In [None]:
# For every annotated sense: average the encoder vectors of its example
# sentences, then analyse each definition token with compute_dependents.
annot_dec_objects = []
for annot_x in tqdm(annot_frames):
    try:
        example_sentences = cwn.from_sense_id("{:08d}".format(annot_x.sense_id)).examples
        # compute_dependents reads the notebook-level `mean_vec`, so it has to
        # be rebound here for each sense before the inner loop runs.
        mean_vec = compute_mean_vec(example_sentences)
        tgt_text = "{}。{}".format(annot_x.POS, annot_x.definition)
        tgt_seq = tokenizer(tgt_text)["input_ids"]
        dec_info = []
        # The first two and the last two token ids are not analysed
        # (presumably the POS prefix / separator and the trailing
        # punctuation / end marker -- TODO confirm against the tokenizer).
        for idx in range(2, len(tgt_seq)-2):
            token_text = tokenizer.convert_ids_to_tokens(tgt_seq[idx])
            dep_out = compute_dependents(tgt_seq, idx, dbg=False)
            dep_out.dep_text = tokenizer.decode(tgt_seq[idx-dep_out.dep_win:idx])
            dep_out.token_text = token_text
            dec_info.append(dep_out)
        annot_dec_objects.append((annot_x, dec_info))
    except Exception as ex:
        # Best-effort: a failing sense is skipped, but report which one it was
        # so that dropped entries can be traced afterwards.
        print("[skipped] sense {} ({}): {!r}".format(
            annot_x.sense_id, annot_x.head_word, ex))

In [17]:
annot_dec_objects_path = "../data/annot_dec_objects.pkl"
# Serialize the (frame, decoding info) pairs to disk.
with Path(annot_dec_objects_path).open("wb") as pkl_file:
    pickle.dump(annot_dec_objects, pkl_file)

In [21]:
## there are senses having no example sentences
# Collect the ids of all annotated senses whose example list is empty.
empty_senses = [
    frame.sense_id
    for frame in tqdm(annot_frames)
    if not cwn.from_sense_id("{:08d}".format(frame.sense_id)).examples
]
len(empty_senses)

  0%|          | 0/288 [00:00<?, ?it/s]

20

In [19]:
# Number of senses that were decoded without an error; senses whose
# processing raised an exception were not appended to the list.
len(annot_dec_objects)

265

## Output Hashes

```
..\data\annot_dec_objects.pkl 68ef27
```

In [22]:
# Print the hash of the pickle written above for comparison with "Output Hashes".
_ = check_hashes([annot_dec_objects_path])

..\data\annot_dec_objects.pkl 68ef27


In [None]:
# Scratch cell: prints the same frame as the earlier show() call.
annot_frames[20].show()

In [None]:
# Debug run with per-window output for the token at index 7.
# NOTE(review): `tgt_seq` (and the global `mean_vec` read inside
# compute_dependents) are whatever the decoding loop left behind, so the
# sense being inspected depends on kernel state -- confirm before relying on it.
dep_out = compute_dependents(tgt_seq, 7, dbg=True)