In [None]:
# Reload edited local modules (such as `de`) automatically before each cell executes,
# so changes to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Sanity check: overfit a single repeated batch

In [None]:
import torch.nn.functional as F
from de import DependencyEvolver
from transformers import BertTokenizer

# Tiny model for the overfit sanity check; token vocabulary size comes from the BERT tokenizer.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

sanity_config = dict(
    # vocabulary sizes
    tok_v=tokenizer.vocab_size,
    rel_v=2,
    pos_v=3,
    # architecture
    d_model=16,
    nhead=1,
    dim_feedforward=8,
    encoder_layers=1,
    decoder_layers=1,
    N=5,
    # regularisation off so the single batch can be memorised
    dropout=0,
)
model = DependencyEvolver(**sanity_config)

In [None]:
import torch
from torch.optim import AdamW

# One hand-written example: a root tuple plus six label rows, each wrapped into a 1x1x5 tensor.
# NOTE(review): presumably the same six label kinds (and order) that label() produces — confirm against de.py.
label_rows = [
    [1, 0, 2, 0, 3],
    [-1, -1, 1, -1, -1],
    [-1, 2, -1, 2, -1],
    [-1, 0, -1, 1, -1],
    [-1, 0, -1, 1, -1],
    [-1, 2016, -1, 2833, -1],
]
batch = ((8823, 2), [torch.tensor([[row]]) for row in label_rows])

# The same batch object repeated: training should drive its loss towards zero.
train_loader = [batch] * 1000
eval_loader = [batch] * 20
optim = AdamW(model.parameters())

In [None]:
# Train on the repeated batch. The numeric positional arguments after the loaders are
# forwarded to DependencyEvolver._train (defined in `de`); their meaning is not visible here — TODO name them.
model._train(
    optim,
    train_loader,
    eval_loader,
    len(train_loader),
    20,
    1,
    1e9,
    100,
)

## Full data-creation pipeline: UD GUM CoNLL-U → pickled training trajectories

In [None]:
import conllu

# CoNLL-U files are UTF-8; pass the encoding explicitly so parsing does not depend on
# the platform's default text encoding (e.g. cp1252 on Windows).
with open('./data/ud/en_gum-ud-train.conllu', 'r', encoding='utf-8') as f:
    train_sentences = conllu.parse(f.read())

with open('./data/ud/en_gum-ud-dev.conllu', 'r', encoding='utf-8') as f:
    dev_sentences = conllu.parse(f.read())

In [None]:
# Collect every word form in train + dev and write one form per line.
# The vocab is written in sorted order: ids are assigned from line numbers when the file
# is read back, and iterating a set of strings depends on the per-process hash seed, so an
# unsorted dump would reshuffle token ids on every run (breaking saved pickles/models).
seen = {tok['form'] for sentence in train_sentences + dev_sentences for tok in sentence}

print(len(seen))

with open('vocab/gum_tok.vocab', 'w', encoding='utf-8') as f:
    f.writelines(f'{tok}\n' for tok in sorted(seen))

### Construct adjacent sequences (successive partial trees of each sentence)

In [None]:
def family_tree(parsed):
    """Split one parsed sentence into a sequence of growing partial trees.

    Args:
        parsed: one sentence from ``conllu.parse`` — an iterable of token dicts
            with 'id', 'form', 'upos', 'deprel' and 'head' (0 for the root).

    Returns:
        A list of sequences. ``seqs[0]`` holds only the root; each later
        sequence lists, in sentence order, the tokens present at that step as
        ``(form, upos, deprel, index, is_leaf, par)`` where ``index`` is the
        token's position in ``parsed``, ``is_leaf`` is True when its head is in
        the current frontier, and ``par`` is the position of its head within
        the same sequence (-1 for the root, None if the head is absent).

        Returns None if any token has no head (in CoNLL-U, multiword-token and
        empty-node lines carry none) or if no token is attached to 0 (no root).
    """
    children = {tok['id']: [] for tok in parsed}
    for tok in parsed:
        if tok['head'] is None:
            return None
        if tok['head'] != 0:
            children[tok['head']].append(tok['id'])

    root_index = next((i for i, tok in enumerate(parsed) if tok['head'] == 0), None)
    if root_index is None:
        return None
    root = parsed[root_index]

    seqs = [[(root['form'], root['upos'], root['deprel'], root_index, True, -1)]]
    cur_leaves = {root['id']}
    all_leaves = {root['id']}

    while cur_leaves:
        seq = []
        next_leaves = set()

        for i, tok in enumerate(parsed):
            if tok['id'] in all_leaves or tok['head'] in all_leaves:
                is_leaf = tok['head'] in cur_leaves
                seq.append((tok['form'], tok['upos'], tok['deprel'], i, is_leaf, None))
                if is_leaf:
                    # NOTE(review): this queues the *children* of the token just added, i.e. the
                    # grandchildren of `cur_leaves`. Those ids enter `all_leaves` before they appear
                    # in any sequence, so later steps add two depths at once. If one depth per step
                    # was intended this should be `next_leaves.add(tok['id'])` — confirm before
                    # changing, since it alters the generated trajectories.
                    next_leaves.update(children[tok['id']])

        # Resolve parents by token identity (sentence index -> CoNLL-U id), not by
        # (form, upos, deprel): that triple repeats within a sentence (e.g. two "a"
        # determiners) and would pick the wrong token and the wrong parent.
        seq_pos_by_id = {parsed[j]['id']: k for k, (_, _, _, j, _, _) in enumerate(seq)}
        for k, (form, upos, deprel, j, is_leaf, _) in enumerate(seq):
            head = parsed[j]['head']
            par = -1 if head == 0 else seq_pos_by_id.get(head)
            seq[k] = (form, upos, deprel, j, is_leaf, par)

        seqs.append(seq)
        cur_leaves = next_leaves
        all_leaves |= next_leaves

    return seqs

### Vocabularies (token / relation / POS → id maps)

In [None]:
def _load_vocab(path, offset=0):
    """Map each stripped line of the UTF-8 file `path` to ``offset + line_number``.

    A repeated line keeps the id of its last occurrence.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return {line.strip(): offset + i for i, line in enumerate(f)}


# Token, relation and POS ids share one id space: tokens first, then relations, then POS tags.
tok_map = _load_vocab('vocab/gum_tok.vocab')
rel_map = _load_vocab('vocab/rel.vocab', offset=len(tok_map))
pos_map = _load_vocab('vocab/pos.vocab', offset=len(tok_map) + len(rel_map))

# Offsets are computed from map sizes, so a duplicate line in any file would make the
# ranges overlap; fail loudly instead of silently aliasing ids.
assert len({*tok_map.values(), *rel_map.values(), *pos_map.values()}) == \
    len(tok_map) + len(rel_map) + len(pos_map), 'vocab id ranges overlap (duplicate lines?)'

In [None]:
# Vocabulary sizes: tokens, relations, POS tags.
print(*(len(vocab) for vocab in (tok_map, rel_map, pos_map)))

In [None]:
# Reverse lookup id -> string across all three vocabularies (tokens, then relations, then POS).
decode = {
    idx: name
    for vocab in (tok_map, rel_map, pos_map)
    for name, idx in vocab.items()
}

### Get labels (edit operations between adjacent sequences)

In [None]:
from de import INS_ID, CPY_ID, PRO_ID, EOS_ID

def label(seqs, N):
    """Turn a sentence's sequence of partial trees into per-step edit labels.

    Args:
        seqs: output of `family_tree` — a list of sequences of
            ``(form, upos, deprel, index, is_leaf, par)`` tuples.
        N: fixed label length per step. Slot 0 holds a leading CPY entry and one
            slot is reserved for EOS, so at most ``N - 2`` tokens of each next
            sequence are labelled; the remainder is padded with -1.

    Returns:
        ``(root, traj)``: ``root`` is ``(tok_id, rel_id, pos_id)`` of the root
        token; ``traj`` is six lists — ops, copy pointers, parent pointers,
        relation ids, POS ids, token ids — each holding one length-``N`` list
        per step. Out-of-vocabulary forms (root included) map to the unknown
        id ``len(tok_map) + len(rel_map) + len(pos_map)``.

    Raises:
        TypeError: an inserted token has ``par is None`` (its head was not found
            in the same sequence); the caller counts these as "missing parent".
        KeyError: a relation or POS tag is missing from `rel_map` / `pos_map`.

    Reads the module-level `tok_map`, `rel_map`, `pos_map` and the op ids
    `INS_ID`, `CPY_ID`, `PRO_ID`, `EOS_ID`.
    """
    unk_tok = len(tok_map) + len(rel_map) + len(pos_map)
    traj = [[] for _ in range(6)]

    # NOTE(review): starts at step 1, so the seqs[0] -> seqs[1] transition (root -> root
    # plus its children) is never labelled and one-level trees yield no steps. Confirm intended.
    for step in range(1, len(seqs) - 1):
        a, b = seqs[step], seqs[step + 1]

        op_list = [CPY_ID]
        cpy_list = [-1]
        par_list = [-1]
        tok_list = [-1]
        pos_list = [-1]
        rel_list = [-1]

        # sentence index -> (position in `a`, whether it was a leaf in `a`)
        prev = {j: (k, is_leaf) for k, (_, _, _, j, is_leaf, _) in enumerate(a)}

        for tok, pos, rel, j, _, par in b[:N - 2]:
            if j in prev:
                # token already present: promote former leaves, copy the rest
                a_pos, was_leaf = prev[j]
                op_list.append(PRO_ID if was_leaf else CPY_ID)
                cpy_list.append(a_pos + 1)  # +1 because slot 0 is the leading entry
                par_list.append(-1)
                tok_list.append(-1)
                rel_list.append(-1)
                pos_list.append(-1)
            else:
                # new token: insert with its parent pointer and vocab ids
                op_list.append(INS_ID)
                cpy_list.append(-1)
                par_list.append(par + 1)  # TypeError when par is None (missing parent)
                tok_list.append(tok_map.get(tok, unk_tok))
                rel_list.append(rel_map[rel])
                pos_list.append(pos_map[pos])

        op_list.append(EOS_ID)
        for slot, values in enumerate([op_list, cpy_list, par_list, rel_list, pos_list, tok_list]):
            values.extend([-1] * (N - len(values)))
            traj[slot].append(values)

    # Same OOV fallback as inserted tokens, so an unknown root form does not raise KeyError
    # (which the caller's `except TypeError` would not catch).
    root_form, root_upos, root_deprel = seqs[0][0][:3]
    root = (tok_map.get(root_form, unk_tok), rel_map[root_deprel], pos_map[root_upos])

    return root, traj

In [None]:
import os
import pickle

from tqdm import tqdm

# Turn every training sentence into a (root, labels) trajectory, counting the
# sentences that are skipped and why.
traj_list = []
t1 = t2 = t3 = 0  # skipped: missing head / missing parent / tree depth 1

for parsed in tqdm(train_sentences):
    seqs = family_tree(parsed)

    # missing head: family_tree could not build a tree for this sentence
    if seqs is None:
        t1 += 1
        continue

    try:
        root, labels = label(seqs, 64)
    except TypeError:
        # missing parent: label() evaluates `par + 1` with par=None when a token's head
        # was not found in the same sequence.
        # NOTE(review): any other TypeError raised inside label() is counted here too.
        t2 += 1
        continue

    # tree depth is 1: label() produced no steps
    if len(labels[0]) == 0:
        t3 += 1
        continue

    traj_list.append((root, labels))

print(f'kept {len(traj_list)} / {len(train_sentences)} sentences; '
      f'skipped: missing head {t1}, missing parent {t2}, depth 1 {t3}')

os.makedirs('data/gum', exist_ok=True)
with open('data/gum/train_1.pkl', 'wb') as f:
    pickle.dump(traj_list, f)

## Test: load the pickled trajectories and train on them

In [None]:
# Needed on a fresh kernel: no cell above this one defines or imports `single_loader`.
from de import single_loader

# Peek at the first batch: decode the root ids and show the label tensor shape.
loader = single_loader('data/gum/train_1.pkl')
root, tgts = next(iter(loader))

# NOTE(review): assumes the root entries are plain ints usable as `decode` keys; if the
# loader yields tensors, use `.item()` (tensors hash by identity, so the lookup would fail).
print(decode[root[0]], decode[root[1]], decode[root[2]])
print(tgts.shape)

In [None]:
# Compute the trajectory loss on one batch drawn from the GUM loader.
# NOTE(review): `model` is whichever model was built last in the kernel. On a top-to-bottom run
# that is the sanity-check model (N=5, BERT vocabulary size), while these trajectories were
# labelled with N=64 and GUM vocab ids — confirm they are compatible, or run this after the
# full-size model cell.
root, tgts = next(iter(loader))

model.traj_loss(root, tgts)

In [None]:
# Explicit imports instead of `from de import *`, so it is clear which names this cell uses.
from de import DependencyEvolver, single_loader
from torch.optim import AdamW

# Full-size model trained on the GUM trajectories built above.
# N=64 matches the label length passed to label(seqs, 64) when creating train_1.pkl.
model = DependencyEvolver(
    d_model=512,
    nhead=8,
    dim_feedforward=2048,
    dropout=0.1,
    encoder_layers=6,
    decoder_layers=6,
    N=64,
)

optim = AdamW(model.parameters(), lr=1e-3)
# The same unshuffled loader serves as both the training and the evaluation loader.
loader = single_loader('data/gum/train_1.pkl', shuffle=False)

# Numeric positional arguments are forwarded to DependencyEvolver._train (defined in `de`);
# their meaning is not visible in this notebook — TODO name them.
model._train(
    optim, loader, loader,
    1000, 20, 4,
    1001, 100
)