In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell execution while the code is being edited.
%load_ext autoreload
%autoreload 2
import sys
# Add the sibling source directory to the import path; the membership check
# keeps re-runs of this cell from appending duplicates.
# NOTE(review): presumably this is where the `vec4gloss` package lives - confirm.
if "../src" not in sys.path:
    sys.path.append("../src")

In [266]:
import re
import pickle
import json
import random
from datetime import datetime
from pathlib import Path
from dataclasses import dataclass
from typing import List

import numpy as np
import torch
from tqdm.auto import tqdm

from transformers import MT5TokenizerFast

from CwnGraph import CwnImage
import vec4gloss
from vec4gloss import check_hashes
from vec4gloss import Vec4GlossModel

## Data dependencies
```
..\data\annotation.json 2ed250
..\data\models\vec4gloss-defgen-220629-1250\pytorch_model.bin 9f894f
```

In [3]:
# Directory holding the fine-tuned model; reused below when loading the
# model and tokenizer.
vec4gloss_model_dir = "../data/models/vec4gloss-defgen-220629-1250"

# Run the hash check on every data dependency; the reported values can be
# compared against the ones listed under "Data dependencies".
_ = check_hashes(
    ["../data/annotation.json", f"{vec4gloss_model_dir}/pytorch_model.bin"]
)

..\data\annotation.json 2ed250
..\data\models\vec4gloss-defgen-220629-1250\pytorch_model.bin 9f894f


## Loading resources

In [4]:
# Read the annotation file as UTF-8 text and parse it as JSON.
annot_data = json.loads(Path("../data/annotation.json").read_text(encoding="UTF-8"))

In [5]:
# Sanity check: number of annotation records plus one sample record (index 45).
len(annot_data), annot_data[45]

(288,
 {'sense_id': 3048001,
  'head_word': '沿街',
  'POS': 'D',
  'definition': '表同一事件在經過的街道中重複發生。',
  'event_role': 'agent',
  'schemas': [{'type': 'event', 'start': 1, 'end': 5},
   {'type': 'scope', 'start': 5, 'end': 6},
   {'type': 'place', 'start': 6, 'end': 11},
   {'type': 'scope', 'start': 11, 'end': 12},
   {'type': 'mod', 'start': 12, 'end': 14},
   {'type': 'action', 'start': 14, 'end': 16}]})

In [6]:
## Loading model
# CUDA is used only when it is available AND the reported device name does not
# contain "GeForce"; in every other case the model runs on the CPU.
# NOTE(review): the reason for excluding GeForce devices is not visible here.
use_cuda = torch.cuda.is_available() and "GeForce" not in torch.cuda.get_device_name()
if use_cuda:
    device = "cuda"
else:
    device = "cpu"
print("Using", device)

# Model and tokenizer are both restored from the same checkpoint directory.
model = Vec4GlossModel.from_pretrained(vec4gloss_model_dir)
model = model.to(device)
tokenizer = MT5TokenizerFast.from_pretrained(vec4gloss_model_dir)
gen = vec4gloss.gen_func(tokenizer, model)

Using cpu


In [7]:
# Version tag of the CWN image to load; `cwn` is used later to look up the
# example sentences of a sense.
CWN_VER = "v.2022.06.21"
cwn = CwnImage.load(CWN_VER)

## Decoding

In [267]:
@dataclass
class DepWinOutput:
    """Result record produced by ``compute_dependents`` for one target token."""
    # Index of the first masking step whose likelihood-ratio value exceeds the
    # threshold; the last step's index when no step exceeds it.
    dep_win: int
    # Likelihood-ratio value at step `dep_win`.
    lratio: float
    # Probability assigned to the target token at each masking step.
    token_probs: List[float]
    # Likelihood-ratio value of each masking step (0 for the first step).
    lratios: List[float]

In [268]:
def compute_mean_vec(examples):
    """Average the encoder vectors of ``examples`` into a single vector.

    Every example is passed to ``vec4gloss.extract_encoder_vector`` together
    with the notebook-level ``tokenizer`` and ``model``. Examples for which
    that call raises ``AssertionError`` are skipped (best-effort behaviour).

    Parameters
    ----------
    examples : iterable
        Example items for one sense, in whatever form
        ``vec4gloss.extract_encoder_vector`` accepts.

    Returns
    -------
    torch.Tensor
        The extracted vectors concatenated with ``torch.cat`` and averaged over
        dim 0 with ``keepdim=True``, moved to the notebook-level ``device``.

    Raises
    ------
    RuntimeError
        If no vector could be extracted (empty input, or every example was
        skipped). The exception type matches what ``torch.cat`` raised for an
        empty list, but the message now states the actual cause.
    """
    enc_vecs = []
    n_seen = 0
    n_skipped = 0
    # Inference only: no gradients are needed for any of the extractions.
    with torch.no_grad():
        for ex in examples:
            n_seen += 1
            try:
                enc_vecs.append(vec4gloss.extract_encoder_vector(ex, tokenizer, model))
            except AssertionError:
                # Deliberate best-effort: unusable examples are dropped, but
                # counted so a fully failed run can be reported below.
                n_skipped += 1

    if not enc_vecs:
        raise RuntimeError(
            "compute_mean_vec: no encoder vector could be extracted "
            "({} examples given, {} skipped)".format(n_seen, n_skipped)
        )
    return torch.cat(enc_vecs).mean(0, keepdim=True).to(device)

In [329]:
def compute_dependents(tgt_seq, tgt_idx, dbg=False, enc_vec=None, max_win=8, lr_threshold=0.2):
    """Estimate how many preceding tokens the token at ``tgt_idx`` depends on.

    The decoder is run repeatedly on the prefix ``[0] + tgt_seq[:tgt_idx]``.
    At step ``k`` (``k = 0 .. max_win-1``, capped by the prefix length) the
    last ``k`` prefix positions are masked out through the decoder attention
    mask, and the probability ``p_k`` of the target token at the final
    position is recorded. For ``k >= 1`` the statistic
    ``lr_k = -2 * ln(p_k / p_{k-1})`` is computed (``lr_0`` is 0). The first
    step whose ``lr_k`` exceeds ``lr_threshold`` defines the dependency window.

    Parameters
    ----------
    tgt_seq : list of int
        Token ids of the target text (must be a list: it is concatenated with
        ``[0]``).
    tgt_idx : int
        Index into ``tgt_seq`` of the token to analyse.
    dbg : bool, default False
        When true, print the target token and one line per masking step.
    enc_vec : torch.Tensor, optional
        Vector passed to the model as ``decoder_encoder_vector``. When omitted,
        the notebook-level ``mean_vec`` is used (the original behaviour).
    max_win : int, default 8
        Maximum number of masking steps to evaluate.
    lr_threshold : float, default 0.2
        A step is selected as the window as soon as its ``lr`` exceeds this.

    Returns
    -------
    DepWinOutput
        ``dep_win`` and ``lratio`` of the selected step (the last step when no
        step exceeds the threshold), plus all ``token_probs`` and ``lratios``.

    Raises
    ------
    AssertionError
        If ``tgt_idx`` is not smaller than ``len(tgt_seq)``.
    ValueError
        If ``max_win`` is smaller than 1 (no step could be evaluated).
    """
    assert tgt_idx < len(tgt_seq)
    if max_win < 1:
        raise ValueError("max_win must be at least 1, got {}".format(max_win))
    if enc_vec is None:
        # Backward-compatible fallback to the notebook global set by the
        # example cell; pass `enc_vec` explicitly to avoid the hidden state.
        enc_vec = mean_vec

    # Prefix fed to the decoder: id 0 followed by the tokens before the target.
    # NOTE(review): id 0 is presumably the decoder start token - confirm.
    generated = [0] + tgt_seq[:tgt_idx]
    tgt_token = tgt_seq[tgt_idx]
    n_ctx = len(generated)
    # The input ids are the same at every step; only the mask changes.
    input_ids = torch.tensor([generated]).to(device)

    if dbg:
        print("target: ", tokenizer.decode([tgt_token]))

    token_probs = []
    lrs = []
    for cursor in range(min(max_win, n_ctx)):
        # Create the mask directly on `device`. The original called
        # `seq_mask.to(device)` without using the result, which left the mask
        # on the CPU and broke the model call whenever `device` was "cuda".
        seq_mask = torch.ones(1, n_ctx, dtype=torch.long, device=device)
        seq_mask[0, n_ctx - cursor:] = 0  # hide the last `cursor` positions
        with torch.no_grad():
            out = model(decoder_encoder_vector=enc_vec,
                        decoder_input_ids=input_ids,
                        decoder_attention_mask=seq_mask)
        token_prob = out.logits.softmax(-1)[0, -1, tgt_token].item()
        token_probs.append(token_prob)

        if len(token_probs) > 1:
            lr = -2*np.log(token_probs[-1] / token_probs[-2])
        else:
            lr = 0
        lrs.append(lr)

        if dbg:
            # Masked ids become 0 after the multiplication and decode as
            # "<pad>", which is shown as a box character for readability.
            mask_text = tokenizer.decode(torch.mul(input_ids, seq_mask)[0, 1:])
            mask_text = mask_text.replace("<pad>", "□")
            print("({:>2d}) [{:.4f}] {} ({:.2f})".format(cursor, token_prob, mask_text, lr))

    # Default to the last evaluated step, then take the first step whose
    # statistic exceeds the threshold, if any.
    win = len(lrs) - 1
    ret_lr = lrs[-1]
    for lr_i, lr_x in enumerate(lrs):
        if lr_x > lr_threshold:
            win = lr_i
            ret_lr = lr_x
            break
    return DepWinOutput(win, ret_lr, token_probs, lrs)

## Finding dependent span

In [334]:
# Worked example: take one annotated sense (index 10).
annot_x = annot_data[10] 
# The sense is looked up in CWN by its id formatted as a zero-padded 8-digit string.
example_sentences = cwn.from_sense_id("{:08d}".format(annot_x["sense_id"])).examples
# `mean_vec` is a notebook-level name that compute_dependents reads when it
# runs the decoder, so this cell must be executed before calling it.
mean_vec = compute_mean_vec(example_sentences)
print(annot_x["head_word"])
# Target text has the form "<POS>。<definition>" and is tokenized into ids.
tgt_text = "{}。{}".format(annot_x["POS"], annot_x["definition"])
tgt_seq = tokenizer(tgt_text)["input_ids"]
# Print the annotated target next to what the model decodes from `mean_vec`.
print("Tgt:", tgt_text)
print("Gen:", vec4gloss.decode_vector(mean_vec, tokenizer, model))

豈料
Tgt: D。表後述事件不在預期之中發生。
Gen: Dk。表後述事件不在預期之中發生。


In [335]:
# For each target position (skipping the first two and the last two token
# ids), print the selected likelihood-ratio value, the token, and the span of
# preceding tokens it was found to depend on.
# NOTE(review): the skipped ids look like the POS tag, the "。" separator and
# the sequence ending - confirm against the tokenizer output.
for idx in range(2, len(tgt_seq)-2):    
    token_text = tokenizer.convert_ids_to_tokens(tgt_seq[idx])    
    dep_out = compute_dependents(tgt_seq, idx, dbg=False)
    # The dependent span is the `dep_win` tokens immediately before `idx`.
    dep_text = tokenizer.decode(tgt_seq[idx-dep_out.dep_win:idx])
    # The token column is right-aligned to width 2 using U+3000 as the fill character.
    print("{:>2d}. [{:5.2f}] {:\u3000>2s} <- {}".format(idx, dep_out.lratio, token_text, dep_text))

 2. [ 3.29] 　表 <- D。
 3. [ 0.23] 　後 <- 表
 4. [ 0.34] 　述 <- 。表後
 5. [ 2.78] 事件 <- D。表後述
 6. [ 0.85] 不在 <- 述事件
 7. [ 0.22] 　預 <- 後述事件不在
 8. [-0.08] 　期 <- 。表後述事件不在預
 9. [13.63] 之中 <- 預期
10. [ 1.37] 　發 <- 不在預期之中
11. [ 3.35] 　生 <- 之中發


In [336]:
# Inspect the single token id immediately before position 11.
tgt_seq[11-1:11]

[31875]

In [338]:
# Re-run position 11 with debug printing to see the probability and
# likelihood-ratio value at every masking step.
dep_out = compute_dependents(tgt_seq, 11, dbg=True)

target:  生
( 0) [0.9972] D。表後述事件不在預期之中發 (0.00)
( 1) [0.9985] D。表後述事件不在預期之中□ (-0.00)
( 2) [0.1868] D。表後述事件不在預期□□ (3.35)
( 3) [0.1659] D。表後述事件不在預□□□ (0.24)
( 4) [0.4924] D。表後述事件不在□□□□ (-2.18)
( 5) [0.2968] D。表後述事件□□□□□ (1.01)
( 6) [0.8529] D。表後述□□□□□□ (-2.11)
( 7) [0.9953] D。表後□□□□□□□ (-0.31)
