In [1]:
# Reload edited local modules (e.g. `mt`, imported further down) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import datasets
import spacy

# NOTE(review): no later cell reads this cs-en split; `dataset` is rebound
# further down (to an MTDataset, then to wmt14 de-en).
dataset = datasets.load_dataset(path="wmt14", name="cs-en", split="train")

In [5]:
nlp = spacy.load('en_core_web_sm')

# Toy sentence used to exercise depth() and gen() below.
sent = 'The world needs the World Bank a lot more than it needs another condominium.'
parse = nlp(sent)

In [28]:
# Token -> id table: BOS and EOS first, then every string in the spaCy StringStore,
# then the spaCy glossary keys (annotation label names). dict.fromkeys drops any
# duplicates (a label can appear in both sources) while keeping first-seen order,
# so ids are contiguous (0 .. len(vocab) - 1) and BOS/EOS stay at ids 0 and 1.
# A plain dict comprehension over the raw list would re-assign duplicated keys to
# their later index and leave holes in the id range.
vocab = {s: i for i, s in enumerate(dict.fromkeys(['BOS', 'EOS', *nlp.vocab.strings, *spacy.glossary.GLOSSARY]))}

In [None]:
def depth(doc):
    """Return the height of the dependency tree under the first root token of `doc`.

    A root is a token that is its own head; only the first one found is used.
    A root with no children has height 1 and every further level adds 1.
    Raises IndexError when `doc` contains no root token (e.g. an empty doc).
    """
    roots = [tok for tok in doc if tok.head == tok]
    frontier = [roots[0]]
    height = 0
    # Walk the tree one level at a time; the number of non-empty levels is the height.
    while frontier:
        height += 1
        frontier = [child for node in frontier for child in node.children]
    return height

# Height of the toy sentence's dependency tree; gen() below allocates 2 * depth rows.
depth(parse)

In [None]:
def gen(doc):
    """Build a coarse-to-fine generation trajectory for the first dependency tree in `doc`.

    `doc` is a spaCy Doc (any sized iterable of tokens exposing .i, .text, .pos_,
    .head and .children works). Only the subtree of the first root token (a token
    that is its own head) is traversed; other tokens stay '_' in every row.

    Returns (traj, edit_traj):
      traj      -- 2 * depth(doc) rows, one cell per token of `doc`. A node at tree
                   level L (root = 0) appears in row 2L as (pos_, i, head.i) and in
                   every later row as (text, i, head.i); unfilled cells hold '_'.
      edit_traj -- one edit list per row of traj. Row 0 inserts BOS, the root's
                   POS tag and EOS. Every later row is a BOS copy, one edit per
                   visible token in doc order, then an EOS copy. Each edit is
                   (op, value, src) with `src` indexing the previous row's
                   sequence (BOS at 0, tokens, EOS last):
                     (CPY, -1, src)     keep the element at `src`;
                     (SUB, value, src)  emit `value`; on text rows `src` is the
                                        token's own (POS) slot, on POS rows it is
                                        the slot of the token's head.
    Op codes: INS = 0, CPY = 1, SUB = 2.
    """
    INS, CPY, SUB = 0, 1, 2
    BLANK = '_'

    traj = [[BLANK for _ in range(len(doc))] for _ in range(2 * depth(doc))]

    def traverse(token, level):
        # Row `level` introduces the token as its POS tag; every later row shows
        # its surface text. Children are introduced two rows further down.
        # (`level` rather than `depth` so the module-level depth() is not shadowed.)
        for row in range(level, len(traj)):
            shown = token.pos_ if row == level else token.text
            traj[row][token.i] = (shown, token.i, token.head.i)
        for child in token.children:
            traverse(child, level + 2)

    root = next(token for token in doc if token.head == token)
    traverse(root, 0)

    edit_traj = [[(INS, 'BOS', -1), (INS, root.pos_, -1), (INS, 'EOS', -1)]]
    # token index -> position of that token in the most recent text row's sequence
    slot = {}

    for step, seq in enumerate(traj[1:]):
        cells = [cell for cell in seq if cell != BLANK]
        cur_edits = [(CPY, -1, 0)]  # keep BOS

        if step % 2 == 0:
            # Text row: it holds the same tokens as the previous (POS) row, so
            # positions line up; tokens just introduced swap their POS for text.
            for k, (value, tok_i, _head_i) in enumerate(cells, start=1):
                if tok_i in slot:
                    cur_edits.append((CPY, -1, k))
                else:
                    cur_edits.append((SUB, value, k))
                slot[tok_i] = k
        else:
            # POS row: existing tokens are copied from their previous slot, new
            # children are sourced from their head's slot.
            for value, tok_i, head_i in cells:
                if tok_i in slot:
                    cur_edits.append((CPY, -1, slot[tok_i]))
                else:
                    cur_edits.append((SUB, value, slot[head_i]))

        # Keep EOS: it is the last element of the previous sequence, i.e. index
        # len(prev) - 1 (len(prev) + 1 would point two slots past the end).
        cur_edits.append((CPY, -1, len(edit_traj[-1]) - 1))
        edit_traj.append(cur_edits)

    return traj, edit_traj

traj, edit_traj = gen(parse)

# One line per trajectory row: POS tag or surface text per token, '_' where empty.
for row in traj:
    print(' '.join(cell[0] for cell in row))

print()
# One line per step: the (op, value, src) edits that produce that row.
for edits in edit_traj:
    print(edits)

In [11]:
from mt import MTDataset
from torch.utils.data import DataLoader

# Training split of the project's MTDataset (defined in the local `mt` module).
dataset = MTDataset(max_len=128, split='train')

In [None]:
from mt import Tokenizer

# Project tokenizer from the local `mt` module.
tokenizer = Tokenizer()

# Displayed as the cell output; presumably to compare against the vocab-file
# line count in the next cell — confirm.
tokenizer.vocab_size

In [None]:
# Number of entries (one per line) in the static vocab file. That file is written
# by the "build static vocab" cell further down, so on a fresh kernel it must
# already exist on disk.
with open('vocab/wmt14_de_en.vocab', 'r') as f:
    n_lines = sum(1 for _ in f)
    print(n_lines)

In [None]:
# Largest id in the tokenizer's token -> id mapping.
m = max(tokenizer.vocab.values())
print(m)

In [None]:
from datasets import load_dataset
from tqdm import tqdm

dataset = load_dataset('wmt14', 'de-en')

# Number of leading train pairs whose encoded lengths are collected below.
N_LENGTH_SAMPLES = 5000

de_len = []
en_len = []

# Slice the Dataset itself: `dataset['train'][:N]` returns a dict of column lists
# for only the first N rows. `dataset['train']['translation'][:N]` would first
# pull the entire translation column (millions of pairs) into memory.
for pair in tqdm(dataset['train'][:N_LENGTH_SAMPLES]['translation']):
    de_len.append(len(tokenizer.encode(pair['de'], lang='de')))
    en_len.append(len(tokenizer.encode(pair['en'], lang='en')))

In [None]:
import matplotlib.pyplot as plt

# Distribution of encoded German sentence lengths over the sampled train pairs.
fig, ax = plt.subplots()
ax.hist(de_len)
ax.set(title='WMT14 de-en train sample: German (de) encoded lengths',
       xlabel='encoded length', ylabel='number of sentences')
plt.show()

In [None]:
# Distribution of encoded English sentence lengths over the sampled train pairs.
fig, ax = plt.subplots()
ax.hist(en_len)
ax.set(title='WMT14 de-en train sample: English (en) encoded lengths',
       xlabel='encoded length', ylabel='number of sentences')
plt.show()

In [None]:
# Launch mt.py with the WMT `ar-toy` config; see mt.py's argument parser for what `--local` changes.
!python mt.py --config=./configs/wmt/ar-toy.json --local

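## Build the static vocabulary

Crawl every split of WMT14 de-en with spaCy and write the union of German/English token texts and POS tags to `vocab/wmt14_de_en.vocab`, one entry per line.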

In [None]:
import spacy
from tqdm import tqdm
from datasets import load_dataset

de_nlp = spacy.load('de_core_news_sm')
en_nlp = spacy.load('en_core_web_sm')

en_toks = set()
en_pos = set()
de_toks = set()
de_pos = set()


def _crawl(nlp, texts, toks, pos, total, desc):
    """Tag `texts` with `nlp`, adding each token's text to `toks` and its POS tag to `pos`.

    `nlp.pipe` streams the texts through the pipeline in batches, which is much
    faster than calling `nlp(text)` once per sentence.
    """
    for doc in tqdm(nlp.pipe(texts), total=total, desc=desc):
        for tok in doc:
            toks.add(tok.text)
            pos.add(tok.pos_)


dataset = load_dataset('wmt14', 'de-en')
for split in ['train', 'test', 'validation']:
    translations = dataset[split]['translation']
    n_pairs = len(translations)
    _crawl(de_nlp, (pair['de'] for pair in translations), de_toks, de_pos,
           n_pairs, f'crawling {split} (de)')
    # The English pipeline must see the English side (it previously read pair['de'],
    # which filled the "en" sets with German tokens).
    _crawl(en_nlp, (pair['en'] for pair in translations), en_toks, en_pos,
           n_pairs, f'crawling {split} (en)')

    print(f'crawled {len(en_toks)} en_toks, {len(en_pos)} en_pos, {len(de_toks)} de_toks, {len(de_pos)} de_pos')
    vocab = en_toks.union(en_pos, de_toks, de_pos)
    # The file is rewritten after every split (the union only grows), presumably so
    # a usable vocab exists even if a later split is interrupted.
    # - sorted(): set iteration order changes between runs (str hash randomisation),
    #   which would reorder the file and any ids a reader assigns by line number.
    # - entries containing a line break would spill over several lines of this
    #   one-entry-per-line format, so they are skipped.
    # - explicit UTF-8 so non-ASCII tokens don't depend on the platform locale.
    entries = sorted(v for v in vocab if '\n' not in v and '\r' not in v)
    with open('vocab/wmt14_de_en.vocab', 'w', encoding='utf-8') as f:
        for v in tqdm(entries, desc=f'dumping {split}'):
            f.write(v)
            f.write('\n')