In [1]:
%load_ext autoreload
%autoreload 2

In [30]:
import spacy

# Load spaCy's 'en_core_web_sm' pipeline; later cells reuse `nlp`.
nlp = spacy.load('en_core_web_sm')

# Parse a single example sentence; `parse` is the Doc the following cells inspect.
parse = nlp('Reform number one, of course, is to ensure that the next World Bank President is not an American.')

In [31]:
# Token -> id table: the two sentinels first, then every string known to the
# pipeline's vocabulary, then the spaCy glossary keys.
# The three sources may contain the same string more than once. Numbering the
# raw concatenation would let the later duplicate overwrite the earlier id
# (leaving holes in the id range), so duplicates are dropped first;
# dict.fromkeys keeps the first occurrence and preserves order.
vocab = {
    s: i
    for i, s in enumerate(dict.fromkeys(
        ['BOS', 'EOS'] + list(nlp.vocab.strings) + list(spacy.glossary.GLOSSARY)
    ))
}

In [None]:
def depth(doc):
    """Return the height, counted in nodes, of the tree under the first root of ``doc``.

    A token is treated as a root when it is its own head (``tok.head == tok``).
    Only the first such token is used, matching how ``gen`` picks its root.

    Args:
        doc: iterable of tokens exposing ``.head`` and ``.children``.

    Returns:
        1 for a lone root, 1 + the longest chain of children otherwise, and
        0 when ``doc`` is empty or contains no root (previously an IndexError).
    """
    root = next((tok for tok in doc if tok.head == tok), None)
    if root is None:
        return 0

    # Explicit stack instead of recursion so a very deep tree cannot hit the
    # interpreter's recursion limit.
    deepest = 0
    stack = [(root, 1)]
    while stack:
        node, level = stack.pop()
        deepest = max(deepest, level)
        stack.extend((child, level + 1) for child in node.children)
    return deepest

depth(parse)

In [None]:
def gen(doc):
    """Build a top-down generation trajectory for a parse, plus per-step edit tuples.

    Args:
        doc: parsed document whose tokens expose ``.i``, ``.text``, ``.pos_``,
            ``.head`` and ``.children``.

    Returns:
        ``(res, edit_traj)``:
        * ``res`` - one list of strings per step. Step 0 is ``[root.pos_]``; each
          later step lists, left to right, the label of every token already
          placed in that row of the trajectory (a POS tag in the row where the
          token first appears, its text in every row after that).
        * ``edit_traj`` - one list of ``(op, value, index)`` tuples per step,
          with ``op`` one of INS/CPY/SUB. Step 0 is three INS tuples for
          'BOS', the root POS and 'EOS'.

    NOTE(review): the third tuple element looks like a position in the
    previous step's sequence (0 presumably being BOS) - confirm against the
    consumer of these edits.
    """
    # Operation codes. INS is only used for the initial step below.
    INS, CPY, SUB = 0, 1, 2 
    
    # 2*depth(doc) rows, len(doc) slots each; '_' marks "token not generated yet".
    traj = [['_' for _ in range(len(doc))] for _ in range(2*depth(doc))]
    
    # NOTE(review): the parameter name `depth` shadows the module-level depth()
    # inside this helper; here it is the row at which `token` first appears.
    def traverse(token, depth):
        # From row `depth` to the last row, fill this token's slot with
        # (label, token index, head index). The label is pos_ for the early rows...
        for i in range(depth, len(traj)):
            traj[i][token.i] = (token.text if (i > depth+1) else token.pos_, token.i, token.head.i)
        
        # ...and row depth+1 is then overwritten with the text, so only row
        # `depth` itself ends up holding the POS tag.
        traj[depth+1][token.i] = (token.text, token.i, token.head.i)
        
        # Children first appear two rows below their head.
        for child in token.children:
            traverse(child, depth+2)
    
    # First token that is its own head; raises StopIteration if there is none.
    root = next(token for token in doc if token.head == token)
    traverse(root, 0)
   
    res = [[root.pos_]]
    edit_traj = [[(INS, 'BOS', -1), (INS, root.pos_, -1), (INS, 'EOS', -1)]]
    # token index -> position k assigned during the most recent even-i step.
    m = {}
    
    # i enumerates traj[1:], so i == 0 corresponds to traj row 1.
    for i, seq in enumerate(traj[1:]):
        # Every step starts with a copy tuple at index 0.
        cur_edits = [(CPY, -1, 0)]
           
        if i % 2 == 0:
            # Even i: walk the placed tokens left to right, numbering them from 1.
            # A token already in `m` yields CPY at k, otherwise SUB with its label;
            # either way its position in `m` is refreshed to k.
            k = 1
            for t in seq:
                if t == '_': continue
                if t[1] in m:
                    cur_edits.append((CPY, -1, k))
                else:
                    cur_edits.append((SUB, t[0], k))
                m[t[1]] = k
                k += 1
                
        else:
            # Odd i: a token already in `m` yields CPY at its recorded position;
            # a new one yields SUB with its label at its head's recorded position.
            # `m` is not updated in this branch.
            # NOTE(review): `k` is assigned here but never read in this branch.
            k = 1
            for t in seq:
                if t == '_': continue
                if t[1] in m:
                    cur_edits.append((CPY, -1, m[t[1]]))
                else:
                    cur_edits.append((SUB, t[0], m[t[2]]))
           
        # Labels of the placed tokens for this row, in token order.
        res.append([t[0] for t in seq if t != '_'])         
        # Closing copy tuple, indexed by the previous step's edit count plus one.
        cur_edits.append((CPY, -1, len(edit_traj[-1])+1))
        edit_traj.append(cur_edits)
    
    return res, edit_traj

# Run the generator on the example parse and show both views of the result.
traj, edit_traj = gen(parse)

# One line per step: the labels placed so far, separated by spaces.
print('\n'.join(' '.join(step) for step in traj))

# Blank separator, then one line per step with the raw edit tuples.
print()
print(*edit_traj, sep='\n')

## TODO: figure out the alignment between spaCy tokens and Marian subword pieces

In [None]:
from transformers import MarianTokenizer

# Load the tokenizer that ships with the 'Helsinki-NLP/opus-mt-en-de' checkpoint.
tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
# Round-trip the example sentence: encode its text to ids, then decode the ids
# with special tokens skipped (the decoded string is the cell's output).
ids = tokenizer.encode(str(parse))
tokenizer.decode(ids, skip_special_tokens=True)

In [None]:
# Lower-cased string form of every spaCy token in the example parse.
spacy_tokens = [str(tok).lower() for tok in parse]

# Lower-cased Marian pieces for the same text. A piece whose first character is
# U+2581 (code point 9601) has that character stripped before lower-casing.
marian_tokens = [
    (piece[1:] if piece[0] == '\u2581' else piece).lower()
    for piece in tokenizer.tokenize(str(parse))
]

print(spacy_tokens)
print(marian_tokens)

In [None]:
import json
from datasets import load_dataset
from tqdm import tqdm

dataset = load_dataset('wmt14', 'de-en', split='test')

# For every English test sentence, write two JSON lines (spaCy tokens, then
# Marian pieces) followed by an empty line - presumably for side-by-side
# inspection of the two tokenizations.
# NOTE(review): the loop rebinds the notebook-level `parse`, `spacy_tokens` and
# `marian_tokens`, so the example parse from the earlier cell is overwritten.
with open('tmp', 'w') as f:
    for pair in tqdm(dataset['translation']):
        en = pair['en']
        parse = nlp(en)

        spacy_tokens = [str(t).lower() for t in parse]
        marian_tokens = [
            (t[1:] if t[0] == '\u2581' else t).lower()
            for t in tokenizer.tokenize(str(parse))
        ]

        f.write(json.dumps(spacy_tokens) + '\n' + json.dumps(marian_tokens) + '\n\n')

In [None]:
!python mt.py --config=./configs/wmt/ar-toy.json --local

In [None]:
!python mt.py --config=./configs/wmt/evolver-toy.json --local

In [None]:
!python mt.py --config=./configs/wmt/teacher-toy.json --local

In [3]:
from mt import MTTransformer, BertTokenizer, MTEditDataset, SpacyTokenizer
from torch.utils.data import DataLoader

tokenizer = SpacyTokenizer()

# Two encoder and two decoder layers; vocabulary size and special-token ids
# are taken from the tokenizer.
model = MTTransformer(
    d_model=512,
    dim_feedforward=2048,
    nhead=8,
    dropout=0,
    layer_norm_eps=1e-5,
    encoder_layers=2,
    decoder_layers=2,
    vocab_size=tokenizer.vocab_size,
    max_len=256,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    name=None,
)

dataset = MTEditDataset(split='test', tokenizer=tokenizer)

# Batches of two, collated by the dataset's own collate_fn.
loader = DataLoader(dataset, batch_size=2, collate_fn=dataset.collate_fn)

In [None]:
from tqdm import tqdm

# Iterate the loader once and discard every batch; only the iteration itself
# (with tqdm's progress display) is exercised here.
for _ in tqdm(loader):
    continue

In [4]:
from mt import MTTransformer, BertTokenizer, MTDataset
from torch.utils.data import DataLoader

tokenizer = BertTokenizer()

# Two encoder and two decoder layers; vocabulary size and special-token ids
# are taken from the tokenizer.
model = MTTransformer(
    d_model=256,
    dim_feedforward=1024,
    nhead=8,
    dropout=0,
    layer_norm_eps=1e-5,
    encoder_layers=2,
    decoder_layers=2,
    vocab_size=tokenizer.vocab_size,
    max_len=256,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    name=None,
)

dataset = MTDataset(split='test', tokenizer=tokenizer)

# Batches of two, collated by the dataset's own collate_fn.
loader = DataLoader(dataset, batch_size=2, collate_fn=dataset.collate_fn)

In [5]:
# Take the first batch the loader yields, for inspection.
batch = next(iter(loader))

## build static vocab

In [None]:
import os

import spacy
from tqdm import tqdm
from datasets import load_dataset

de_nlp = spacy.load('de_core_news_sm')
en_nlp = spacy.load('en_core_web_sm')

# Distinct token texts and POS tags seen on each side of the corpus.
en_toks = set()
en_pos = set()
de_toks = set()
de_pos = set()

dataset = load_dataset('wmt14', 'de-en')
for split in ['train', 'test', 'validation']:
    for pair in tqdm(dataset[split]['translation'], desc=f'crawling {split}'):
        for de in de_nlp(pair['de']):
            de_toks.add(de.text)
            de_pos.add(de.pos_)
        # The English pipeline must read the English side of the pair; it was
        # previously fed pair['de'], so no English text ever reached the vocab.
        for en in en_nlp(pair['en']):
            en_toks.add(en.text)
            en_pos.add(en.pos_)

    print(f'crawled {len(en_toks)} en_toks, {len(en_pos)} en_pos, {len(de_toks)} de_toks, {len(de_pos)} de_pos')
    vocab = en_toks.union(en_pos, de_toks, de_pos)

    # The file is rewritten after every split, so a partial vocabulary exists
    # even if a later split fails; the last write contains everything.
    os.makedirs('vocab', exist_ok=True)
    # Explicit UTF-8: the entries include non-ASCII text and the platform
    # default encoding is not guaranteed to handle it.
    with open('vocab/wmt14_de_en.vocab', 'w', encoding='utf-8') as f:
        # Sorted so the line order (and any ids derived from it) is the same on
        # every run; plain set iteration order is not stable across runs.
        # NOTE(review): an entry containing a newline would break the
        # one-entry-per-line format - confirm none occur.
        for v in tqdm(sorted(vocab), desc=f'dumping {split}'):
            f.write(v)
            f.write('\n')

## seq2seq edit dataset

In [None]:
from mt import MTEditDataset

# Test-split edit dataset; this rebinds the notebook-level `dataset`.
dataset = MTEditDataset(split='test', max_len=128, buffer_size=1000)

In [18]:
# The first example unpacks into three parts, named here source ids, input ids and edit ids.
src_ids, input_ids, edit_ids = dataset[0]

In [None]:
import torch
from mt import SpacyTokenizer

tok = SpacyTokenizer()

# Show the raw input and edit ids unpacked from dataset[0] in the previous cell...
print(input_ids)
print(edit_ids)

# ...then the input ids as decoded by the tokenizer.
print(tok.decode(input_ids))

In [None]:
# Imported under an alias: binding the bare name `BertTokenizer` here would
# shadow the `BertTokenizer` imported from `mt`, which other cells construct
# with no arguments and which would then fail if run after this cell.
from transformers import BertTokenizer as HFBertTokenizer

tok = HFBertTokenizer.from_pretrained('bert-base-multilingual-uncased')


In [2]:
from mt import MTDataset, BertTokenizer

# Test-split dataset tokenized with the project's BertTokenizer (from `mt`); rebinds `dataset`.
dataset = MTDataset(split='test', tokenizer=BertTokenizer())

## debug transformer

In [19]:
from mt import MarianTokenizer
from datasets import load_dataset
from tqdm import tqdm

tokenizer = MarianTokenizer()

dataset = load_dataset('wmt14', 'de-en', split='test')

# Encoded length of every test sentence, one list per language; these are the
# lists the histogram cell plots.
len_de, len_en = [], []

for thing in tqdm(dataset['translation']):
    de, en = thing['de'], thing['en']
    len_de.append(len(tokenizer.encode(de)))
    len_en.append(len(tokenizer.encode(en)))

In [None]:
def truncate(x, n=4):
    """Keep only the first ``n`` items of ``x['translation']`` and return ``x``.

    Intended as a ``Dataset.map`` callback. ``map`` only records a change when
    the callback returns the updated mapping, so ``x`` is returned; the
    original returned ``None`` and therefore had no effect.

    Args:
        x: mapping with a sliceable ``'translation'`` entry; modified in place.
        n: number of leading items to keep (default 4, the original constant).

    Returns:
        The same mapping ``x`` with ``x['translation']`` truncated.

    NOTE(review): slicing needs ``x['translation']`` to be a sequence (for
    example the list seen with ``batched=True``); a per-example dict of
    language -> text cannot be sliced. Confirm how this is meant to be mapped.
    """
    x['translation'] = x['translation'][:n]
    return x

# NOTE(review): per the datasets documentation, `map` returns a new dataset
# rather than modifying `dataset`; the result is not assigned here, so
# `dataset` itself is left unchanged (the returned dataset is only displayed).
dataset.map(truncate)

In [37]:
from mt import MTDataset
from torch.utils.data import DataLoader

# Test-split dataset built with truncate=4 (presumably caps it at four
# examples - confirm in mt.MTDataset); rebinds `dataset`.
dataset = MTDataset(split='test', truncate=4)

# Batches of four, collated by the dataset's own collate_fn.
loader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn)

In [None]:
import matplotlib.pyplot as plt

# Both length distributions in one call so they share bin edges and are drawn
# side by side. Two separate plt.hist calls chose bins independently and
# painted the second histogram over the first, with nothing to tell them apart.
fig, ax = plt.subplots()
ax.hist([len_de, len_en], bins=30, label=['de', 'en'])
ax.set(
    xlabel='encoded length (tokens)',
    ylabel='sentences',
    title='WMT14 de-en test: encoded sentence lengths',
)
ax.legend()
plt.show()