In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell execution so edits to the project's .py files are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

import sys
# Put the parent directory on the import path exactly once (re-running the
# cell must not append duplicates). Presumably this is where the project
# modules imported by later cells (`data`, `model`, `train`) live.
if all(entry != '..' for entry in sys.path):
    sys.path.append('..')

In [None]:
import conllu
import numpy as np
import random
from nltk.tokenize.treebank import TreebankWordDetokenizer

# Shared detokenizer: the noise functions below call `detok.detokenize(...)`
# to turn a list of token forms back into a single surface string.
detok = TreebankWordDetokenizer()

# Coarse part-of-speech groups. Each key is a coarse label; each value lists
# the fine-grained tags (as found in a token's `upostag` field, with '_' for
# an unspecified tag) that are folded into that label.
hierarchy = {
    'NOUN': ['NOUN', 'PROPN'],
    'VERB': ['VERB', 'AUX'],
    'MOD': ['ADJ', 'ADV'],
    'FUNC': ['DET', 'PRON', 'ADP', 'CCONJ', 'SCONJ', 'PART'],
    'NUM': ['NUM'],
    'OTHER': ['INTJ', 'SYM', 'PUNCT', 'X', '_']
}

# Reverse lookup: tag -> coarse label. Every coarse label also maps to itself,
# so the table accepts both fine-grained and already-coarse tags.
to_parent = {
    tag: parent
    for parent, children in hierarchy.items()
    for tag in (parent, *children)
}

def minimize(path):
    """Load a CoNLL-U file and reduce every sentence to (form, coarse POS) pairs.

    Args:
        path: Location of the CoNLL-U file to read.

    Returns:
        A list with one entry per parsed sentence. Each entry is a list of
        ``(form, coarse_tag)`` tuples, where ``coarse_tag`` is the value
        ``to_parent`` holds for the token's ``upostag`` field.

    Raises:
        KeyError: If a token's ``upostag`` value is not a key of ``to_parent``.
    """
    # CoNLL-U is defined as UTF-8 text; decode explicitly instead of relying on
    # the platform's default encoding (which is not UTF-8 everywhere).
    with open(path, 'r', encoding='utf-8') as f:
        sentences = conllu.parse(f.read())

    return [
        [(tok['form'], to_parent[tok['upostag']]) for tok in sent]
        for sent in sentences
    ]

def deterministic_noise(sent):
    """Build a trajectory by removing whole POS groups in a fixed order.

    Groups are stripped in the order OTHER, MOD, FUNC, NUM, VERB, NOUN. After
    each removal the remaining tokens are detokenized, and the resulting string
    is recorded only if it differs from the previously recorded one.

    Args:
        sent: Sequence of ``(form, coarse_tag)`` pairs.

    Returns:
        A ``(trajectory, 0)`` tuple. ``trajectory`` is a list of strings ordered
        from the most reduced state to the full sentence. The second element is
        always ``0`` (presumably a placeholder matching the log-probability
        returned by ``sample_pos_noise`` -- confirm with the consumer).
    """
    def render(pairs):
        # Surface string for the tokens that are still present.
        return detok.detokenize([form for form, _ in pairs])

    states = [render(sent)]
    remaining = sent

    for dropped in ('OTHER', 'MOD', 'FUNC', 'NUM', 'VERB', 'NOUN'):
        remaining = [(tok, pos) for tok, pos in remaining if pos != dropped]
        text = render(remaining)
        # Skip no-op steps (the sentence had no token of this group).
        if text != states[-1]:
            states.append(text)

    states.reverse()
    return states, 0

def sample_pos_noise(sent, traj_length=6):
    """Sample a deletion trajectory, preferring to drop low-importance POS groups.

    For up to ``traj_length`` steps (stopping early once no tokens remain), a
    batch of token positions is drawn with probability proportional to
    ``1 / importance[tag]`` and those tokens are removed. The detokenized
    remainder is recorded after every step.

    Args:
        sent: Sequence of ``(form, coarse_tag)`` pairs; every tag must be a key
            of the importance table below.
        traj_length: Maximum number of deletion steps.

    Returns:
        A ``(trajectory, log_prob)`` tuple. ``trajectory`` is the list of
        recorded strings in reverse order of creation; ``''`` is appended before
        reversing when the last recorded state is not already empty, so the
        result starts with ``''`` and ends with the full sentence. ``log_prob``
        is the sum of the log of the normalized weight of every drawn index.
    """
    importance = {
        'OTHER': 1.5, 'FUNC': 2, 'NUM': 4,
        'MOD': 5, 'VERB': 8, 'NOUN': 9
    }

    def render(pairs):
        # Surface string for the tokens that are still present.
        return detok.detokenize([form for form, _ in pairs])

    total_log_prob = 0
    states = [render(sent)]
    remaining = sent

    for step in range(traj_length):
        if not remaining:
            break

        # Less important tags get a larger share of the drop probability.
        inverse = [1 / importance[pos] for _, pos in remaining]
        norm = sum(inverse)
        probs = [w / norm for w in inverse]

        # Spread the remaining tokens over the remaining steps, at least one per step.
        steps_left = traj_length - step
        quota = max(1, len(remaining) // steps_left)

        # NOTE(review): random.choices samples *with* replacement, so an index can
        # be drawn more than once; a step may then remove fewer than `quota`
        # tokens while each repeat still contributes to the log-probability.
        # Confirm whether without-replacement sampling was intended.
        picked = random.choices(
            range(len(remaining)),
            k=min(quota, len(remaining)),
            weights=probs,
        )
        total_log_prob += sum(np.log(probs[idx]) for idx in picked)

        dropped = set(picked)
        remaining = [pair for position, pair in enumerate(remaining) if position not in dropped]
        states.append(render(remaining))

    # Guarantee that the reversed trajectory begins at the empty string.
    if states[-1] != '':
        states.append('')
    return states[::-1], total_log_prob

In [None]:
import json

def create_corpus(output, input, func):
    """Apply ``func`` to every sentence of a CoNLL-U file and write JSON lines.

    Args:
        output: Path of the file to write; opened in ``'w'`` mode, so existing
            content is replaced.
        input: Path of the CoNLL-U file handed to ``minimize``. (The name
            shadows the ``input`` builtin inside this function.)
        func: Callable invoked once per minimized sentence; its return value is
            serialized with ``json.dump`` and must therefore be JSON-serializable.
    """
    # Parse first so a failure while reading leaves any existing output untouched.
    m_sentences = minimize(input)

    with open(output, 'w') as sink:
        # map() is lazy: each sentence is transformed right before it is written.
        for record in map(func, m_sentences):
            json.dump(record, sink)
            sink.write('\n')

In [None]:
# (output JSONL, input CoNLL-U, noise function) for every corpus to generate.
# The 2.0.0 files use the sampled noise, the 2.1.0 files the deterministic one.
# NOTE(review): no random seed is set in the cells shown, so the 2.0.0 files
# differ from run to run.
corpus_jobs = [
    ('../data/ud/ud_train_2.0.0.jsonl', '../data/ud/en_gum-ud-train.conllu', sample_pos_noise),
    ('../data/ud/ud_dev_2.0.0.jsonl', '../data/ud/en_gum-ud-dev.conllu', sample_pos_noise),
    ('../data/ud/ud_train_2.1.0.jsonl', '../data/ud/en_gum-ud-train.conllu', deterministic_noise),
    ('../data/ud/ud_dev_2.1.0.jsonl', '../data/ud/en_gum-ud-dev.conllu', deterministic_noise),
]

for out_path, in_path, noise_fn in corpus_jobs:
    create_corpus(out_path, in_path, noise_fn)

In [None]:
# Run ../dep.py on the EWT training treebank, writing ../data/ud_train.jsonl.
# The meaning of --redundant / --weight is defined in dep.py (not shown here).
# NOTE(review): the output lands in ../data/, while the datasets loaded in
# later cells are read from ../data/ud/ -- confirm the intended location.
!python ../dep.py ../data/ud/en_ewt-ud-train.conllu ../data/ud_train.jsonl --redundant=1 --weight=0.1

In [None]:
# Run ../dep.py on the EWT dev treebank, writing ../data/ud_dev.jsonl.
# NOTE(review): a later cell loads '../data/ud/ud_dev.jsonl' (note the extra
# ud/ directory) -- confirm whether the file is moved afterwards or one of the
# two paths is stale.
!python ../dep.py ../data/ud/en_ewt-ud-dev.conllu ../data/ud_dev.jsonl --redundant=3 --weight=0.1

In [None]:
from data import TrajectoryDataset
from transformers import BertTokenizer

# Pretrained `bert-base-uncased` tokenizer; the Seq2SeqDataset cell further
# down reuses this `tokenizer` variable.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Training trajectories: the 2.0.0 corpus generated with sample_pos_noise above.
train_dataset_options = dict(
    path='../data/ud/ud_train_2.0.0.jsonl',
    max_len=64,
    tokenizer=tokenizer,
)
train_dataset = TrajectoryDataset.from_disk(**train_dataset_options)

# NOTE(review): the evaluation dataset is disabled, yet later cells reference
# `eval_loader` -- re-enable this (and the loader) or drop those references.
# eval_dataset = TrajectoryDataset.from_disk(
#     path='../data/ud/ud_dev.jsonl',
#     max_len=64,
#     tokenizer=tokenizer,
#     limit=100
# )

In [None]:
# What is the average number of non-pad tokens in each batch?
#
# NOTE(review): `train_loader` is only assigned in the cell *below* this one,
# so on a fresh kernel this cell has to be run after it.

import itertools

import torch

# Upper bound on the number of batches inspected. The loader is driven by a
# sampler named StratifiedInfiniteSampler -- presumably it never runs dry, in
# which case an uncapped loop would never finish and would block "Run All".
NUM_BATCHES_TO_INSPECT = 100

batch_totals = []
for thing, _ in itertools.islice(train_loader, NUM_BATCHES_TO_INSPECT):
    # Token id 0 is counted as padding -- assumes the tokenizer's pad id is 0; TODO confirm.
    tot = sum(torch.sum(traj != 0) for traj in thing)
    print(tot)
    batch_totals.append(int(tot))

# Report the figure the question above asks for (skipped if no batch was seen).
if batch_totals:
    print('average non-pad tokens per batch:', sum(batch_totals) / len(batch_totals))

In [None]:
from torch.utils.data import DataLoader
from data import StratifiedInfiniteSampler

# The sampler's second argument mirrors `batch_size` in every loader of this
# notebook -- presumably it is the batch size; keep the two in sync.
TRAJECTORY_BATCH_SIZE = 32


def _transpose_batch(samples):
    """Regroup a list of per-example tuples into one group per field (lazy zip)."""
    return zip(*samples)


train_loader = DataLoader(
    train_dataset,
    batch_size=TRAJECTORY_BATCH_SIZE,
    sampler=StratifiedInfiniteSampler(train_dataset, TRAJECTORY_BATCH_SIZE),
    collate_fn=_transpose_batch,
)

# NOTE(review): disabled together with `eval_dataset`; later cells still pass
# `eval_loader` to train_evolver / evaluate_evolver.
# eval_loader = DataLoader(
#     eval_dataset,
#     batch_size=1,
#     shuffle=True
# )

In [None]:
from model import Evolver
from torch.optim import AdamW

# Model hyper-parameters for a CPU-only run. `max_len` equals the `max_len`
# the trajectory dataset above was built with (64).
evolver_config = dict(
    d_model=512,
    nhead=8,
    max_len=64,
    encoder_layers=6,
    decoder_layers=6,
    device='cpu',
)
evolver = Evolver(**evolver_config)

# AdamW over all model parameters.
LEARNING_RATE = 3e-4
optim = AdamW(evolver.parameters(), lr=LEARNING_RATE)

In [None]:
from train import train_evolver

# One-step local run of the training loop on CPU.
# NOTE(review): `eval_loader` is not assigned in any cell shown here (its
# DataLoader is commented out above), so this call raises NameError on a fresh
# kernel unless the name is defined elsewhere -- confirm.
local_run_options = dict(
    train_steps=1,
    grad_accum_steps=1,
    checkpoint_at=2,
    eval_at=1,
    num_particles=5,
    threshold=2,
    temperature=1.0,
    device='cpu',
    prefix='test-local',
)

train_evolver(
    evolver, optim, None,
    train_loader, eval_loader,
    **local_run_options
)

In [None]:
from train import evaluate_evolver

# Evaluate the current model; the third argument 'cpu' is presumably the device.
# NOTE(review): `eval_loader` is not assigned in any cell shown here (its
# DataLoader is commented out above), so this raises NameError on a fresh
# kernel unless the name is defined elsewhere -- confirm.
evaluate_evolver(evolver, eval_loader, 'cpu')

In [None]:
# Launch ../train.py as a subprocess on CPU with the ../configs/ud.json config.
# NOTE(review): no cell shown here writes ../data/ud/ud.jsonl, and --eval is
# given a raw .conllu treebank while --train is a .jsonl file -- confirm both
# paths against what train.py expects.
!python ../train.py \
    --train ../data/ud/ud.jsonl \
    --eval ../data/ud/en_ewt-ud-dev.conllu \
    --config ../configs/ud.json \
    --prefix ud-1.0.0 \
    --device cpu

In [None]:
from data import Seq2SeqDataset
from transformers import BertTokenizer

# Seq2seq view of the same 2.0.0 trajectory corpus used for `train_dataset`.
# Relies on the `tokenizer` variable created in an earlier cell; the
# BertTokenizer import in this cell is not used here.
seq2seq_source = '../data/ud/ud_train_2.0.0.jsonl'
seq2seq_options = dict(denoising=True, max_len=64, tokenizer=tokenizer)

dataset = Seq2SeqDataset.from_trajectories(seq2seq_source, **seq2seq_options)

In [None]:
from torch.utils.data import DataLoader
from data import StratifiedInfiniteSampler

# NOTE: this rebinds `train_loader`; from here on the name refers to the
# seq2seq loader, not the trajectory loader defined earlier.
SEQ2SEQ_BATCH_SIZE = 128

seq2seq_sampler = StratifiedInfiniteSampler(dataset, SEQ2SEQ_BATCH_SIZE)
train_loader = DataLoader(
    dataset,
    batch_size=SEQ2SEQ_BATCH_SIZE,
    sampler=seq2seq_sampler,
)

In [None]:
# Print, for every batch, how many entries of the first output item are non-zero.
# NOTE(review): `torch` is not imported in this cell (an earlier cell does it),
# and the loop has no stopping condition of its own -- with a sampler named
# StratifiedInfiniteSampler it presumably runs until interrupted.
for _, output in train_loader:
    first_item = output[0]
    nonzero_mask = first_item != 0
    print(torch.sum(nonzero_mask))

## test streaming dataset

In [None]:
from data import SupervisedTrajectoryDataset
from transformers import BertTokenizer

# NOTE: rebinds `dataset` (previously the Seq2SeqDataset).
# NOTE(review): the dep.py cell above writes ../data/ud_dev.jsonl, not this
# ../data/ud/ path -- confirm the file exists here.
supervised_dev_path = '../data/ud/ud_dev.jsonl'

dataset = SupervisedTrajectoryDataset.from_disk(
    supervised_dev_path,
    tokenizer=BertTokenizer.from_pretrained('bert-base-uncased'),
    max_len=64
)

In [None]:
from data import SupervisedTrajectoryDataset
from transformers import BertTokenizer

# Small toy corpus with short sequences (max_len=10).
# NOTE: rebinds `train_dataset`, which earlier held the UD TrajectoryDataset.
toy_dataset_options = dict(
    path='../data/toy/toy.jsonl',
    max_len=10,
    tokenizer=BertTokenizer.from_pretrained('bert-base-uncased'),
)
train_dataset = SupervisedTrajectoryDataset.from_disk(**toy_dataset_options)

In [None]:
from torch.utils.data import DataLoader
from data import StratifiedInfiniteSampler, collate_supervised

# Loader over the toy dataset using two worker processes, each prefetching
# two batches.
# NOTE(review): `collate_supervised` is imported in this cell but never passed
# as `collate_fn` -- confirm whether the default collation is intended.
TOY_BATCH_SIZE = 2

toy_sampler = StratifiedInfiniteSampler(train_dataset, TOY_BATCH_SIZE)
loader = DataLoader(
    train_dataset,
    batch_size=TOY_BATCH_SIZE,
    sampler=toy_sampler,
    num_workers=2,
    prefetch_factor=2,
)