In [1]:
# Enable IPython's autoreload extension in mode 2 so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch.utils.data import Dataset

In [15]:
import sys
# Make the local pyASBC and CompoTree source directories importable.
# The entries are relative paths, so they resolve against the notebook's
# working directory. The membership check keeps the cell idempotent:
# re-running it no longer appends duplicate entries to sys.path.
for src_dir in ("../../pyASBC/src", "../../CompoTree/src"):
    if src_dir not in sys.path:
        sys.path.append(src_dir)

In [17]:
import pyASBC
from CompoTree import ComponentTree

In [19]:
# Create the Asbc5Corpus object and load the ComponentTree. Both are kept as
# notebook-level globals: `asbc` is iterated in later cells and `ctree` is
# read inside serialize_components().
asbc = pyASBC.Asbc5Corpus()
ctree = ComponentTree.load()

In [13]:
from itertools import islice
# Peek at the corpus: materialize only the first 10 items yielded by
# asbc.iter_sentences() instead of consuming the whole iterator.
sent_iter = asbc.iter_sentences()
sents = list(islice(sent_iter, 0, 10))

In [21]:
# First element of the first sampled sentence.
# NOTE(review): presumably a word token; confirm the item structure of iter_sentences().
word = sents[0][0]

In [25]:
# Look up "時" in the component tree (use_flag="shortest", max_depth=1) and
# keep the first returned result for inspection.
cc = ctree.query("時", use_flag="shortest", max_depth=1)[0]

In [28]:
# Display the components of that first query result.
cc.components()

['日', '寺']

In [99]:
from itertools import chain
def serialize_components(ch):
    """Turn one character into a list of component tokens.

    The character is looked up in the notebook-level ``ctree`` with
    ``use_flag="shortest"`` and ``max_depth=1``, and the first result is used.

    Args:
        ch: The character to decompose.

    Returns:
        A list of strings. Each string component of the first result becomes
        ``f"{idc}{position}{component}"``, where ``idc`` is the result's
        ``idc`` attribute and ``position`` is the component's 0-based index.
        Non-string components become ``"<COMPO_NA>"``. If the lookup raises,
        or the first result is itself a string, the return value is
        ``["<COMPO_NA>"]``.
    """
    placeholder = "<COMPO_NA>"

    # Any failure during the lookup (including an empty result list) is
    # treated as "no decomposition available".
    try:
        decomposition = ctree.query(ch, use_flag="shortest", max_depth=1)[0]
    except Exception:
        return [placeholder]

    if isinstance(decomposition, str):
        return [placeholder]

    operator = decomposition.idc
    return [
        f"{operator}{position}{part}" if isinstance(part, str) else placeholder
        for position, part in enumerate(decomposition.components())
    ]

def make_charpos(word):
    """Mark word-boundary positions on each character of ``word``.

    Args:
        word: The word to split into characters.

    Returns:
        One string per character, in order. The first character is prefixed
        with ``"_"`` and the last is suffixed with ``"_"``, so a
        single-character word yields ``["_x_"]`` and an empty word yields ``[]``.
    """
    last = len(word) - 1
    return [
        ("_" if idx == 0 else "") + ch + ("_" if idx == last else "")
        for idx, ch in enumerate(word)
    ]
    
def make_word_tuple(word):
    """Build the (components, chars, word) triple for one word.

    Args:
        word: The word to process.

    Returns:
        A 3-tuple ``(compos, chars, word)``: ``compos`` is the flat list of
        tokens from ``serialize_components`` over every character in order,
        ``chars`` is the output of ``make_charpos(word)``, and ``word`` is the
        input unchanged.
    """
    chars = make_charpos(word)
    # Flatten the per-character token lists into a single list.
    flat_compos = [token for ch in word for token in serialize_components(ch)]
    return (flat_compos, chars, word)

In [100]:
class Vocabulary:
    """Term <-> integer index mapping with per-index occurrence counts.

    Index 0 is reserved for ``"<UNK>"`` and index 1 for ``"<PAD>"``. New terms
    receive the next free index in insertion order.
    """

    def __init__(self):
        # term -> index
        self.vocab = {"<UNK>": 0, "<PAD>": 1}
        # index -> number of times add() was called for that term
        self.freq = {v: 0 for v in self.vocab.values()}
        # index -> term; rebuilt lazily, see decode(). Also clears is_dirty.
        self.make_rev_vocab()

    def __repr__(self):
        return f"<Vocabulary: {len(self.vocab)} term(s)>"

    def __len__(self):
        """Return the number of terms, including the two reserved ones."""
        return len(self.vocab)

    def add(self, term):
        """Register one occurrence of ``term``, assigning an index if it is new."""
        term_idx = self.vocab.get(term)
        if term_idx is None:
            term_idx = len(self.vocab)
            self.vocab[term] = term_idx
            # Only a new term invalidates the reverse mapping; counting an
            # already-known term again leaves it valid.
            self.is_dirty = True
        self.freq[term_idx] = self.freq.get(term_idx, 0) + 1

    def make_rev_vocab(self):
        """Rebuild the index -> term mapping from ``vocab`` and clear the dirty flag."""
        self.rev_vocab = {v: k for k, v in self.vocab.items()}
        self.is_dirty = False

    def encode(self, term):
        """Return the index of ``term``, or 0 (``"<UNK>"``) if it is unknown."""
        return self.vocab.get(term, 0)

    def decode(self, index):
        """Return the term stored at ``index``, or ``"<UNK>"`` if there is none."""
        if not self.rev_vocab or self.is_dirty:
            self.make_rev_vocab()
        return self.rev_vocab.get(index, "<UNK>")

    def save(self, fpath):
        """Pickle ``(vocab, freq)`` to ``fpath``."""
        with open(fpath, "wb") as fout:
            pickle.dump((self.vocab, self.freq), fout)

    @classmethod
    def load(cls, fpath):
        """Create a vocabulary from a file written by :meth:`save`.

        Args:
            fpath: Path to a pickle containing a ``(vocab, freq)`` tuple.

        Returns:
            A new instance of ``cls`` holding the stored mappings.
        """
        # SECURITY: pickle.load can execute arbitrary code; only load files
        # that this notebook (or another trusted source) produced.
        with open(fpath, "rb") as fin:
            vocab, freq = pickle.load(fin)
        loaded = cls()
        loaded.vocab = vocab
        loaded.freq = freq
        # The reverse mapping built in __init__ only covers the two reserved
        # terms; rebuild it so decode() sees the loaded vocabulary.
        loaded.make_rev_vocab()
        return loaded

In [101]:
# Three independent vocabularies, filled later from the corpus with whole
# words, position-marked characters, and component tokens respectively.
word_vocab, char_vocab, compo_vocab = Vocabulary(), Vocabulary(), Vocabulary()

In [102]:
# Smoke-test Vocabulary on a throwaway instance. Running these checks against
# the shared word_vocab would leave the sample terms, with inflated counts,
# in the vocabulary that is later filled from the corpus and pickled.
scratch_vocab = Vocabulary()
scratch_vocab.add("測試")
scratch_vocab.add("測試")
scratch_vocab.add("程式")
# Indices 0 and 1 are reserved for <UNK> and <PAD>, so the first added term gets 2.
assert scratch_vocab.encode("測試") == 2
assert scratch_vocab.decode(2) == "測試"
# Adding the same term twice must not create a second entry.
assert len(scratch_vocab) == 4
assert scratch_vocab.freq[2] == 2

In [103]:
import pickle
from tqdm.auto import tqdm
# Build the three vocabularies in a single pass over asbc.iter_words().
# NOTE(review): the vocabularies are mutated in place, so re-running this cell
# without re-creating them counts every word a second time.
word_iter = asbc.iter_words()
# Uncomment to limit the pass to the first 100 words while iterating on the code:
# word_iter = islice(word_iter, 0, 100)
for word in tqdm(word_iter):
    # make_word_tuple returns (component tokens, position-marked chars, word);
    # `word` is rebound to the same value that was passed in.
    components, chars, word = make_word_tuple(word)
    for compo_x in components:
        compo_vocab.add(compo_x)
    for char_x in chars:
        char_vocab.add(char_x)
    word_vocab.add(word)
# Persist each vocabulary as a pickle under ../data/.
compo_vocab.save("../data/compo_vocab.pkl")
char_vocab.save("../data/char_vocab.pkl")
word_vocab.save("../data/word_vocab.pkl")

# Display the sizes of the component, character and word vocabularies.
len(compo_vocab), len(char_vocab), len(word_vocab)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




(3289, 20396, 217350)

In [97]:
# Reload the pickled vocabularies from ../data/ and display their sizes.
# NOTE(review): this cell's execution count (97) is lower than that of the
# cell that writes the files (103), and the recorded sizes (171, 149, 89) do
# not match the (3289, 20396, 217350) shown there. The output looks stale,
# possibly from a run on a truncated word list; re-run to confirm.
compo_vocab = Vocabulary.load("../data/compo_vocab.pkl")
char_vocab = Vocabulary.load("../data/char_vocab.pkl")
word_vocab = Vocabulary.load("../data/word_vocab.pkl")
len(compo_vocab), len(char_vocab), len(word_vocab)

(171, 149, 89)