In [2]:
%load_ext autoreload
%autoreload 2

import sys
if '..' not in sys.path:
    sys.path.append('..')

In [2]:
import conllu
import numpy as np
import random
from nltk.tokenize.treebank import TreebankWordDetokenizer

detok = TreebankWordDetokenizer()

hierarchy = {
    'NOUN': ['NOUN', 'PROPN'],
    'VERB': ['VERB', 'AUX'],
    'MOD': ['ADJ', 'ADV'],
    'FUNC': ['DET', 'PRON', 'ADP', 'CCONJ', 'SCONJ', 'PART'],
    'NUM': ['NUM'],
    'OTHER': ['INTJ', 'SYM', 'PUNCT', 'X', '_']
}

to_parent = {}
for k, vs in hierarchy.items():
    to_parent[k] = k
    for v in vs:
        to_parent[v] = k

def minimize(path):
    with open(path, 'r') as f:
        sentences = conllu.parse(f.read())
       
    res = [] 
    for sent in sentences:
        cur = [(tok['form'], to_parent[tok['upostag']]) for tok in sent]
        res.append(cur)
        
    return res

def deterministic_noise(sent):
    traj = [detok.detokenize([s[0] for s in sent])]
    
    for drop in ['OTHER', 'MOD', 'FUNC', 'NUM', 'VERB', 'NOUN']:
        new_sent = []
        for tok, pos in sent:
            if pos != drop: new_sent.append((tok, pos))
        sent = new_sent
        new_seq = detok.detokenize(s[0] for s in sent)
        if new_seq != traj[-1]: traj.append(new_seq)
        
    return traj[::-1], 0

def sample_pos_noise(sent, traj_length=6):
    importance = {
        'OTHER': 1.5, 'FUNC': 2, 'NUM': 4,
        'MOD': 5, 'VERB': 8, 'NOUN': 9
    }
   
    log_prob = 0 
    traj = [detok.detokenize([s[0] for s in sent])]
   
    for i in range(traj_length):
        if not sent: break
        
        weights = [1 / importance[pos] for _, pos in sent]
        tot = sum(weights)
        weights = [w / tot for w in weights]
        
        N = max(1, len(sent) // (traj_length - i))
        to_drop = random.choices(range(len(sent)), k=min(N, len(sent)), weights=weights)
        log_prob += sum(np.log(weights[i]) for i in to_drop)
        
        sent  = [tok for j, tok in enumerate(sent) if j not in to_drop]
        traj.append(detok.detokenize(s[0] for s in sent))
        
    if traj[-1] != '': traj.append('')
    return traj[::-1], log_prob

In [3]:
import json

def create_corpus(output, input, func):
    m_sentences = minimize(input)
    with open(output, 'w') as f:
        for m_sent in m_sentences:
            output = func(m_sent)
            json.dump(output, f)
            f.write('\n')

In [4]:
create_corpus('../data/ud/ud_train_2.0.0.jsonl', '../data/ud/en_gum-ud-train.conllu', sample_pos_noise)
create_corpus('../data/ud/ud_dev_2.0.0.jsonl', '../data/ud/en_gum-ud-dev.conllu', sample_pos_noise)
create_corpus('../data/ud/ud_train_2.1.0.jsonl', '../data/ud/en_gum-ud-train.conllu', deterministic_noise)
create_corpus('../data/ud/ud_dev_2.1.0.jsonl', '../data/ud/en_gum-ud-dev.conllu', deterministic_noise)

In [None]:
!python ../dep.py ../data/ud/en_ewt-ud-train.conllu ../data/ud_train.jsonl --redundant=1 --weight=0.1

In [None]:
!python ../dep.py ../data/ud/en_ewt-ud-dev.conllu ../data/ud_dev.jsonl --redundant=3 --weight=0.1

In [5]:
from data import TrajectoryDataset
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

train_dataset = TrajectoryDataset.from_disk(
    path='../data/ud/ud_train_2.0.0.jsonl',
    max_len=64,
    tokenizer=tokenizer
)

# eval_dataset = TrajectoryDataset.from_disk(
#     path='../data/ud/ud_dev.jsonl',
#     max_len=64,
#     tokenizer=tokenizer,
#     limit=100
# )

Tokenizing inputs: 100%|██████████| 9521/9521 [00:13<00:00, 690.99it/s] 


In [13]:
# what is the average number of non-pad tokens in each batch?

import torch

for thing, _ in train_loader:
    tot = sum(torch.sum(traj != 0) for traj in thing)
    print(tot)

tensor(3040)
tensor(2991)
tensor(2953)
tensor(2571)
tensor(2556)
tensor(2725)
tensor(2904)
tensor(2812)
tensor(2628)
tensor(3396)
tensor(3065)
tensor(2718)
tensor(3148)
tensor(3089)
tensor(3193)
tensor(3150)
tensor(2514)
tensor(3123)
tensor(3762)
tensor(2584)
tensor(2684)
tensor(3129)
tensor(2824)
tensor(2285)
tensor(3134)
tensor(2710)
tensor(3854)
tensor(3067)
tensor(2627)
tensor(2701)
tensor(2663)
tensor(3191)
tensor(3483)
tensor(2553)
tensor(2726)
tensor(3209)
tensor(2884)
tensor(3391)
tensor(2483)
tensor(3073)
tensor(2960)
tensor(2804)
tensor(3019)
tensor(3371)
tensor(3089)
tensor(3381)
tensor(3271)
tensor(3330)
tensor(2743)
tensor(2772)
tensor(3225)
tensor(2907)
tensor(2240)
tensor(3245)
tensor(3304)
tensor(2874)
tensor(2971)
tensor(2696)
tensor(3084)
tensor(3104)
tensor(2713)
tensor(2965)
tensor(2338)
tensor(3038)
tensor(3252)
tensor(2044)
tensor(2861)
tensor(3404)
tensor(2725)
tensor(2994)
tensor(3373)
tensor(2523)
tensor(3230)
tensor(2517)
tensor(2993)
tensor(2680)
tensor(2320)

KeyboardInterrupt: 

In [11]:
from torch.utils.data import DataLoader
from data import StratifiedInfiniteSampler

train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    sampler=StratifiedInfiniteSampler(train_dataset, 32),
    collate_fn=lambda x: zip(*x)
)

# eval_loader = DataLoader(
#     eval_dataset,
#     batch_size=1,
#     shuffle=True
# )

In [None]:
from model import Evolver
from torch.optim import AdamW

evolver = Evolver(
    d_model=512,
    nhead=8,
    max_len=64,
    encoder_layers=6,
    decoder_layers=6,
    device='cpu'
)

optim = AdamW(evolver.parameters(), lr=3e-4)

In [None]:
from train import train_evolver

train_evolver(
    evolver, optim, None,
    train_loader, eval_loader,
    train_steps=1,
    grad_accum_steps=1,
    checkpoint_at=2,
    eval_at=1,
    num_particles=5,
    threshold=2,
    temperature=1.0,
    device='cpu',
    prefix='test-local'
)

In [None]:
from train import evaluate_evolver

evaluate_evolver(evolver, eval_loader, 'cpu')

In [None]:
!python ../train.py \
    --train ../data/ud/ud.jsonl \
    --eval ../data/ud/en_ewt-ud-dev.conllu \
    --config ../configs/ud.json \
    --prefix ud-1.0.0 \
    --device cpu

In [15]:
from data import Seq2SeqDataset
from transformers import BertTokenizer

dataset = Seq2SeqDataset.from_trajectories(
    '../data/ud/ud_train_2.0.0.jsonl',
    denoising=True,
    max_len=64,
    tokenizer=tokenizer
)

INFO:data:tokenizing input/output pairs...
INFO:data:done in 24.05 seconds!


In [16]:
from torch.utils.data import DataLoader
from data import StratifiedInfiniteSampler

train_loader = DataLoader(
    dataset,
    batch_size=128,
    sampler=StratifiedInfiniteSampler(dataset, 128),
)

In [18]:
for _, output in train_loader:
    print(torch.sum(output[0] != 0))

tensor(21)
tensor(5)
tensor(8)
tensor(6)
tensor(28)
tensor(7)
tensor(20)
tensor(3)
tensor(11)
tensor(3)
tensor(3)
tensor(31)
tensor(12)
tensor(24)
tensor(4)
tensor(4)
tensor(13)
tensor(3)
tensor(29)
tensor(10)
tensor(42)
tensor(20)
tensor(3)
tensor(7)
tensor(25)
tensor(27)
tensor(4)
tensor(10)
tensor(4)
tensor(4)
tensor(13)
tensor(12)
tensor(30)
tensor(12)
tensor(10)
tensor(13)
tensor(8)
tensor(8)
tensor(5)
tensor(11)
tensor(34)
tensor(14)
tensor(31)
tensor(35)
tensor(17)
tensor(3)
tensor(7)
tensor(6)
tensor(7)
tensor(12)
tensor(14)
tensor(26)
tensor(23)
tensor(4)
tensor(8)
tensor(6)
tensor(21)
tensor(3)
tensor(17)
tensor(7)
tensor(27)
tensor(64)
tensor(11)
tensor(6)
tensor(4)
tensor(13)
tensor(9)
tensor(28)
tensor(8)
tensor(5)
tensor(20)
tensor(23)
tensor(23)
tensor(27)
tensor(10)
tensor(11)
tensor(13)
tensor(10)
tensor(4)
tensor(7)
tensor(3)
tensor(5)
tensor(5)
tensor(3)
tensor(4)
tensor(7)
tensor(27)
tensor(34)
tensor(28)
tensor(24)
tensor(24)
tensor(8)
tensor(11)
tensor(3)
tensor(8

KeyboardInterrupt: 

In [11]:
from data import SupervisedTrajectoryDataset
from transformers import BertTokenizer

dataset = SupervisedTrajectoryDataset.from_disk(
    '../data/ud/ud_dev.jsonl',
    max_len=64,
    tokenizer=BertTokenizer.from_pretrained('bert-base-uncased')
)

Computing alignments:   0%|          | 0/12544 [07:25<?, ?it/s]
Computing alignments:   0%|          | 0/6003 [05:27<?, ?it/s]
Tokenizing inputs: 100%|██████████| 6003/6003 [00:04<00:00, 1465.51it/s]
2024-06-25 22:20:47,760 - simalign.simalign - INFO - Initialized the EmbeddingLoader with model: bert-base-uncased
Computing alignments:   1%|▏         | 86/6003 [01:07<1:54:07,  1.16s/it]

KeyboardInterrupt: 