In [None]:
# Auto-reload edited modules (e.g. the local `mt` package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import spacy

# Small English spaCy pipeline; `parse` is the running example sentence.
nlp = spacy.load('en_core_web_sm')

parse = nlp('Reform number one, of course, is to ensure that the next World Bank President is not an American.')

In [None]:
# Token -> id map: BOS and EOS first, then every string in the pipeline's
# StringStore, then spaCy's glossary labels. A string can plausibly occur in
# both sources (labels the pipeline registers), and a plain {s: i}
# comprehension keeps the *last* index for a repeated string, leaving gaps in
# the id range. dict.fromkeys de-duplicates in first-seen order, so the ids
# are exactly 0..len(vocab)-1.
vocab = {
    s: i
    for i, s in enumerate(dict.fromkeys(
        ['BOS', 'EOS'] + list(nlp.vocab.strings) + list(spacy.glossary.GLOSSARY)))
}

In [None]:
def depth(doc):
    """Return how many levels deep the dependency tree of ``doc`` is.

    The tree is rooted at the first token that is its own head. A root with no
    children has depth 1. Raises IndexError if no token in ``doc`` is its own
    head.
    """
    roots = [tok for tok in doc if tok.head == tok]
    level, frontier = 0, [roots[0]]
    while frontier:
        level += 1
        frontier = [child for node in frontier for child in node.children]
    return level

# Number of levels in the example sentence's dependency tree (root alone = 1).
depth(parse)

In [None]:
def gen(doc):
    """Turn a dependency parse into a coarse-to-fine edit trajectory.

    The tree rooted at the first token that is its own head is laid out in a
    grid ``traj`` of ``2 * depth(doc)`` rows by ``len(doc)`` columns (indexed
    by ``token.i``). A token at tree level ``l`` (root = 0) holds
    ``(pos_, i, head.i)`` in row ``2*l`` and ``(text, i, head.i)`` in every
    later row. Cells not reached yet hold ``'_'``. Tokens outside that tree
    stay ``'_'`` throughout.

    The odd rows are then converted into edit sequences, each of which
    rewrites the previous sequence into the next. Edits are
    ``(op, token, pos)``:

    * ``(INS, text, -1)``  insert ``text`` from scratch (first step only);
    * ``(CPY, -1, j)``     copy position ``j`` of the previous sequence;
    * ``(SUB, text, j)``   emit a new token anchored at position ``j`` of its
      head in the previous sequence.

    Position 0 of every sequence is BOS and the last position is EOS.

    Args:
        doc: a spaCy ``Doc``, or anything iterable with ``len`` whose tokens
            expose ``i``, ``text``, ``pos_``, ``head`` and ``children``.

    Returns:
        ``(res, edit_traj)``, both with one entry per step. ``res[0]`` is the
        root's POS tag. ``res[s]`` for ``s >= 1`` is the surface text
        produced by ``edit_traj[s]``. ``edit_traj[0]`` inserts BOS, the root
        text and EOS.
    """
    INS, CPY, SUB = 0, 1, 2

    n_rows = 2 * depth(doc)
    traj = [['_'] * len(doc) for _ in range(n_rows)]

    def traverse(token, row):
        # POS tag on the token's own row, surface text on every row below it.
        traj[row][token.i] = (token.pos_, token.i, token.head.i)
        for r in range(row + 1, n_rows):
            traj[r][token.i] = (token.text, token.i, token.head.i)
        for child in token.children:
            traverse(child, row + 2)

    root = next(token for token in doc if token.head == token)
    traverse(root, 0)

    res = [[root.pos_]]
    edit_traj = [[(INS, 'BOS', -1), (INS, root.text, -1), (INS, 'EOS', -1)]]

    # Position of each token (keyed by token.i) in the previous sequence.
    # BOS is position 0, so the root starts at position 1.
    prev_pos = {root.i: 1}
    # Row 1 (root text only) is already produced by edit_traj[0].
    for seq in traj[3::2]:
        present = [cell for cell in seq if cell != '_']
        cur_edits = [(CPY, -1, 0)]  # keep BOS
        cur_pos = {}
        # Positions are numbered 1.. once per step (not reset per token), and
        # lookups only read the previous step's positions, never ones that
        # were reassigned earlier in this same step.
        for k, (text, idx, head_idx) in enumerate(present, start=1):
            if idx in prev_pos:
                cur_edits.append((CPY, -1, prev_pos[idx]))
            else:
                # A new token's head got its text one odd row earlier, so the
                # head is always present in the previous sequence.
                cur_edits.append((SUB, text, prev_pos[head_idx]))
            cur_pos[idx] = k
        # Keep EOS, which is the last position of the previous sequence.
        cur_edits.append((CPY, -1, len(edit_traj[-1]) - 1))
        edit_traj.append(cur_edits)
        res.append([text for text, _, _ in present])
        prev_pos = cur_pos

    return res, edit_traj

traj, edit_traj = gen(parse)

# One line per trajectory step, a blank line, then one edit sequence per line.
for line in map(' '.join, traj):
    print(line)
print()
for edits in edit_traj:
    print(edits)

## Figure out token alignment (spaCy vs. Marian)

In [None]:
from transformers import MarianTokenizer

# Round-trip the example sentence through the Marian en->de tokenizer to see
# how well its subword segmentation reconstructs the text.
tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
ids = tokenizer.encode(str(parse))
tokenizer.decode(ids, skip_special_tokens=True)

In [None]:
# spaCy surface tokens, lower-cased.
spacy_tokens = [str(tok).lower() for tok in parse]

# Marian SentencePiece pieces, lower-cased. A word-initial piece carries the
# marker U+2581 ('▁', ord 9601); strip it so pieces line up with spaCy tokens.
marian_tokens = [
    (piece[1:] if piece[0] == '\u2581' else piece).lower()
    for piece in tokenizer.tokenize(str(parse))
]

print(spacy_tokens)
print(marian_tokens)

In [None]:
import json
from datasets import load_dataset
from tqdm import tqdm

# For every WMT14 de-en test pair, write to ./tmp the spaCy tokens (JSON
# list), then the Marian tokens (JSON list), then a blank line.
dataset = load_dataset('wmt14', 'de-en', split='test')

with open('tmp', 'w') as f:
    for pair in tqdm(dataset['translation']):
        en = pair['en']
        parse = nlp(en)
        spacy_tokens = [str(tok).lower() for tok in parse]
        # Same normalisation as above: drop the U+2581 word-start marker.
        marian_tokens = [
            (piece[1:] if piece[0] == '\u2581' else piece).lower()
            for piece in tokenizer.tokenize(str(parse))
        ]
        print(json.dumps(spacy_tokens), file=f)
        print(json.dumps(marian_tokens), file=f)
        print(file=f)

## Regression runs (toy configs)

In [None]:
# Regression run with the ar-toy config (presumably the autoregressive
# baseline); see mt.py for what --local does.
!python mt.py --config=./configs/wmt/ar-toy.json --local

In [None]:
# Regression run with the evolver-toy config; see mt.py for --skip / --local.
!python mt.py --config=./configs/wmt/evolver-toy.json --skip --local

In [None]:
# Regression run with the teacher-toy config; see mt.py for --local.
!python mt.py --config=./configs/wmt/teacher-toy.json --local

## Build static vocabulary

In [None]:
import os

import spacy
from tqdm import tqdm
from datasets import load_dataset

# Collect every surface token and coarse POS tag seen on either side of WMT14
# de-en, and dump their union (one entry per line) to a static vocab file.
VOCAB_PATH = 'vocab/wmt14_de_en.vocab'

de_nlp = spacy.load('de_core_news_sm')
en_nlp = spacy.load('en_core_web_sm')

en_toks = set()
en_pos = set()
de_toks = set()
de_pos = set()

dataset = load_dataset('wmt14', 'de-en')
os.makedirs(os.path.dirname(VOCAB_PATH), exist_ok=True)
for split in ['train', 'test', 'validation']:
    for pair in tqdm(dataset[split]['translation'], desc=f'crawling {split}'):
        for de in de_nlp(pair['de']):
            de_toks.add(de.text)
            de_pos.add(de.pos_)
        # Parse the English side with the English pipeline (this used to be
        # fed pair['de'], so en_toks/en_pos were built from German text).
        for en in en_nlp(pair['en']):
            en_toks.add(en.text)
            en_pos.add(en.pos_)

    # The file is rewritten after every split with the union collected so
    # far, presumably so an interrupted crawl still leaves a usable vocab.
    print(f'crawled {len(en_toks)} en_toks, {len(en_pos)} en_pos, {len(de_toks)} de_toks, {len(de_pos)} de_pos')
    vocab = en_toks.union(en_pos, de_toks, de_pos)
    # Sorted so the line order (and any ids derived from it) is reproducible:
    # set iteration order for str depends on per-process hash randomisation.
    # NOTE(review): a token containing '\n' would span two lines here;
    # confirm the vocab reader handles or excludes such tokens.
    with open(VOCAB_PATH, 'w', encoding='utf-8') as f:
        for v in tqdm(sorted(vocab), desc=f'dumping {split}'):
            f.write(v)
            f.write('\n')

## Seq2seq edit dataset

In [None]:
from mt import MTEditDataset

# Edit-based view of the WMT test split; the meaning of max_len and
# buffer_size is defined in mt.MTEditDataset.
dataset = MTEditDataset(split='test', max_len=128, buffer_size=1000)

In [None]:
# Each item unpacks into (src_ids, input_ids, edit_ids).
src_ids, input_ids, edit_ids = dataset[0]

In [None]:
# NOTE(review): torch is not used in this cell.
import torch
from mt import SpacyTokenizer

tok = SpacyTokenizer()

# Show the first example's raw ids, then decode the input side back to text.
print(input_ids)
print(edit_ids)

print(tok.decode(input_ids))

In [None]:
from transformers import BertTokenizer

# Multilingual uncased BERT WordPiece tokenizer from the Hugging Face hub.
# NOTE(review): the next cell imports mt's `BertTokenizer`, which shadows this class.
tok = BertTokenizer.from_pretrained('bert-base-multilingual-uncased')


In [None]:
# NOTE(review): this `BertTokenizer` is mt's wrapper; it shadows the
# transformers class imported in the previous cell.
from mt import MTDataset, BertTokenizer

dataset = MTDataset(split='test', tokenizer=BertTokenizer())

## Debug transformer

In [None]:
from mt import MarianTokenizer
from datasets import load_dataset
from tqdm import tqdm

# Token count of each side of every WMT14 de-en test pair under mt's
# MarianTokenizer (presumably to choose a sensible max_len).
tokenizer = MarianTokenizer()
dataset = load_dataset('wmt14', 'de-en', split='test')

len_de, len_en = [], []
for thing in tqdm(dataset['translation']):
    de, en = thing['de'], thing['en']
    len_de.append(len(tokenizer.encode(de)))
    len_en.append(len(tokenizer.encode(en)))

In [None]:
def truncate(x, n=4):
    """Return a new dataset holding only the first ``n`` rows of ``x``.

    Args:
        x: a Hugging Face ``Dataset`` (anything with ``select`` and ``len``).
        n: maximum number of rows to keep; defaults to 4.

    Returns:
        ``x.select(range(min(n, len(x))))``. ``x`` itself is not modified.
    """
    # Row selection is the supported way to subset a Dataset. The earlier
    # per-example version sliced x['translation'], which is a {'de', 'en'}
    # dict here, so it raised TypeError. It also returned None, which makes
    # Dataset.map keep every example unchanged.
    return x.select(range(min(n, len(x))))

truncate(dataset)

In [None]:
from mt import TrajectoryDataset, Teacher, SpacyTokenizer
from torch.utils.data import DataLoader

tokenizer = SpacyTokenizer()

# Test split cut down with truncate=4 (presumably the first 4 examples),
# served one example per batch.
dataset = TrajectoryDataset(split='test', truncate=4, tokenizer=tokenizer)

loader = DataLoader(dataset, batch_size=1, collate_fn=dataset.collate_fn)

# Small, freshly initialised teacher (no checkpoint is loaded here), sized
# from the tokenizer's vocabulary and special-token ids.
teacher = Teacher(
    d_model=64,
    dim_feedforward=256,
    vocab_size=tokenizer.vocab_size,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    max_len=20)

batch = next(iter(loader))
batch

In [None]:
from itertools import islice

# Print (at most) the first 10 collated batches.
for thing in islice(loader, 10):
    print(thing)

In [None]:
# Roll out the teacher built above for T=3 steps from the batch's source and root ids.
traj = teacher.rollout({'src_ids': batch['src_ids'], 'root_ids': batch['root_ids']}, T=3)

In [None]:
# Display the rollout trajectory.
traj

## Debug evolver (toy setup)

Truncated dataset: the 4 `train` pairs used below (`TrajectoryDataset(split='train', truncate=4)`):

```json
{
 "paragraphs": [
   {
     "de": "Wiederaufnahme der Sitzungsperiode",
     "en": "Resumption of the session"
   },
   {
     "de": "Ich erkläre die am Freitag, dem 17. Dezember unterbrochene Sitzungsperiode des Europäischen Parlaments für wiederaufgenommen, wünsche Ihnen nochmals alles Gute zum Jahreswechsel und hoffe, daß Sie schöne Ferien hatten.",
     "en": "I declare resumed the session of the European Parliament adjourned on Friday 17 December 1999, and I would like once again to wish you a happy new year in the hope that you enjoyed a pleasant festive period."
   },
   {
     "de": "Wie Sie feststellen konnten, ist der gefürchtete \"Millenium-Bug\" nicht eingetreten. Doch sind Bürger einiger unserer Mitgliedstaaten Opfer von schrecklichen Naturkatastrophen geworden.",
     "en": "Although, as you will have seen, the dreaded 'millennium bug' failed to materialise, still the people in a number of countries suffered a series of natural disasters that truly were dreadful."
   },
   {
     "de": "Im Parlament besteht der Wunsch nach einer Aussprache im Verlauf dieser Sitzungsperiode in den nächsten Tagen.",
     "en": "You have requested a debate on this subject in the course of the next few days, during this part-session."
   }
 ]
}
```

In [None]:
import torch
from mt import SpacyTokenizer, TrajectoryDataset
from torch.utils.data import DataLoader

# The truncated train split shown above, in order and one example per batch,
# so `batch` should be the first pair.
tokenizer = SpacyTokenizer()
dataset = TrajectoryDataset(split='train', truncate=4, tokenizer=tokenizer)
loader = DataLoader(dataset, batch_size=1, collate_fn=dataset.collate_fn, shuffle=False)

batch = next(iter(loader))

tokenizer.decode(batch['src_ids'][0])

In [None]:
import torch
# `Teacher` must be imported here: this section re-imports its own
# dependencies, and otherwise `Teacher` is only in scope if the
# "debug transformer" cell happened to run first.
# NOTE(review): `Teacher1` is imported but unused. Confirm which class the
# teacher checkpoint below was trained with.
from mt import Evolver, Teacher, Teacher1

# Architecture shared by the evolver and the teacher; it must match the
# checkpoints loaded below.
config = {
    'd_model': 64,
    'dim_feedforward': 256,
    'nhead': 4,
    'dropout': 0.1,
    'layer_norm_eps': 1e-5,
    'decoder_layers': 4,
    'encoder_layers': 4,
    'max_len': 128,
    'bos_token_id': tokenizer.bos_token_id,
    'eos_token_id': tokenizer.eos_token_id,
    'pad_token_id': tokenizer.pad_token_id,
    'vocab_size': tokenizer.vocab_size,
    'name': None
}

EVOLVER_CKPT = 'mt_evolver_64d_4enc_4dec-20241021_200919_10000.pt'
TEACHER_CKPT = 'mt_teacher_64d_4enc_4dec-20241021_200959_10000.pt'

evolver = Evolver(**config)
teacher = Teacher(**config)

# Each checkpoint is a dict whose 'model' entry holds the state dict.
evolver.load_state_dict(torch.load(EVOLVER_CKPT, map_location='cpu')['model'])
teacher.load_state_dict(torch.load(TEACHER_CKPT, map_location='cpu')['model'])

In [None]:
# Dataset-provided trajectory for example 2, one decoded step per line.
traj_ids = dataset[2]['traj_ids']

for step in traj_ids:
    print(tokenizer.decode(step))

In [None]:
evolver.eval()

# Sample a 5-step evolver rollout on `batch` at a low temperature (0.1).
traj = evolver.rollout(batch, T=5, temp=0.1, verbose=True)

In [None]:
# Decode element 0 of each rollout step (presumably the only batch entry).
for step in traj:
    print(tokenizer.decode(step[0]))

In [None]:
teacher.eval()

# Same for the teacher: roll out 5 steps and decode each one.
traj = teacher.rollout(batch, T=5, verbose=True)

for step in traj:
    print(tokenizer.decode(step[0]))

In [None]:
from mt import evaluate

# Evaluate the teacher on `loader` on CPU.
# NOTE(review): the positional `2` is opaque here; confirm its meaning in mt.evaluate.
evaluate(teacher, loader, 'cpu', 2, tokenizer)

### Observation: errors compound and rollouts fall off the rails

Idea 1: retrain with a simpler inductive bias.

Idea 2: can we tweak the sampling temperature?

In [None]:
# Idea 2: sweep the sampling temperature on the same batch and compare how
# quickly each rollout drifts. A bare `rollout()` omits the batch that every
# other call passes, so this reuses the signature of the rollout cell above.
for temp in (0.1, 0.5, 1.0):
    traj = evolver.rollout(batch, T=5, temp=temp)
    print(f'--- temp={temp}')
    for step in traj:
        print(tokenizer.decode(step[0]))

In [None]:
import spacy
from transformers import BertTokenizer

# Compare BERT WordPiece tokenization with spaCy on a sentence containing a
# contraction and trailing punctuation.
bert_tok = BertTokenizer.from_pretrained('bert-base-uncased')

nlp = spacy.load('en_core_web_sm')

sentence = "here's a sentence with possibly more complicated tokenization."

parse = nlp(sentence)

# Uncomment to check that spaCy's tokens (with trailing whitespace) rebuild the text:
# print(''.join(tok.text_with_ws for tok in parse))

bert_tok.tokenize(sentence)

# Refresh: `WMT` and `WMTForEvolver` datasets

In [None]:
from mt import WMT

# WMT validation split wrapped by mt.WMT.
wmt = WMT(split='validation')

In [None]:
from mt import SpacyTokenizer

tokenizer = SpacyTokenizer()

# Decode the first validation example's source ids back to text.
tokenizer.decode(wmt[0]['src_ids'])

In [None]:
from mt import WMTForEvolver

dataset = WMTForEvolver(split='validation')

# NOTE(review): `input` shadows the Python builtin for the rest of the
# session; a name such as `text` would avoid that.
input = 'Hello my name is John Doe.'

# Parse with the English spaCy pipeline held by the tokenizer from the cell above.
doc = tokenizer.en_nlp(input)

In [None]:
# Trajectory produced by the dataset's own generator for `doc`, one step per line.
traj = dataset.gen(doc)

for thing in traj:
    print(thing)

In [None]:
# Debug: call the dataset's private output-trajectory helper directly.
dataset._get_output_traj(doc)