In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import spacy
from datasets import load_dataset
from transformers import MarianTokenizer
from tqdm import tqdm
from collections import defaultdict

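Sanity check: we need to make sure the character-based span matching will _actually_ match correctly, so we run the aligner over every sentence of the WMT14 validation split and let its assertions flag any sentence that cannot be fully aligned.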

In [None]:
def get_spans(text, tokens):
    """Locate each token in `text` and return its character span.

    The text is scanned left to right with a single cursor. For every token,
    leading whitespace is skipped, the span starts at the cursor, and it ends
    right after the first occurrence of the token at or beyond the cursor.

    Args:
        text: the original string the tokens were produced from.
        tokens: token strings, in the order they appear in `text`.

    Returns:
        A list of `(start, end)` character offsets, one per token. A token
        that cannot be found in the remaining text (for example a literal
        `'<unk>'` placeholder) gets the zero-width span `(start, start)`.

    Raises:
        AssertionError: if the text is exhausted before every token received
            a span.
    """
    spans = []
    i = 0
    for token in tokens:
        while i < len(text) and text[i].isspace(): i += 1
        if i < len(text):
            s = i
            found = text.find(token, i)
            # str.find returns -1 on a miss; adding len(token) to that would
            # move the cursor to an unrelated (often earlier) offset and
            # corrupt every span that follows. Leave the cursor where it is.
            if found != -1:
                i = found + len(token)
            spans.append((s, i))

    assert len(spans) == len(tokens), f'number of spans does not match number of tokens for: {text}'
    return spans

def get_alignments(text, spacy_tokens, marian_tokens):
    """Map every Marian token index to the index of the spaCy token it belongs to.

    Both token sequences are first turned into character spans over `text`;
    a Marian token is aligned to the spaCy token with which it shares the most
    characters (the earliest one on ties). Tokens without a usable span are
    then filled in:

    * `'<unk>'` is placed one spaCy token after the closest preceding aligned
      token (word-break markers are skipped), clamped to the last spaCy index;
      a leading `'<unk>'` maps to spaCy token 0.
    * `''` (a bare word-break marker) copies the alignment of the token that
      follows it, or of the token before it when nothing follows.

    Args:
        text: the original string.
        spacy_tokens: spaCy token strings for `text`.
        marian_tokens: decoded Marian token strings for `text`.

    Returns:
        dict mapping Marian token index -> spaCy token index.

    Raises:
        AssertionError: if some Marian token could not be aligned (or if
            `get_spans` could not produce a span for every token).
    """
    spacy_spans = get_spans(text, spacy_tokens)
    marian_spans = get_spans(text, marian_tokens)
    alignment = {} # map marian_tokens[i] to spacy_tokens[j]
    best_overlap = defaultdict(int) # track max overlap

    # just bruteforce check (who needs DP?)
    for i, marian_span in enumerate(marian_spans):
        for j, spacy_span in enumerate(spacy_spans):
            overlap = max(0, min(spacy_span[1], marian_span[1]) - max(spacy_span[0], marian_span[0]))
            if overlap > 0 and overlap > best_overlap[i]:
                alignment[i] = j
                best_overlap[i] = overlap

    last_spacy = len(spacy_spans) - 1

    # Pass 1 - unknown pieces. They are resolved before word breaks because a
    # word break directly in front of an '<unk>' borrows the '<unk>' alignment;
    # doing both in one forward sweep raised KeyError for that pair.
    for i, tok in enumerate(marian_tokens):
        if tok == '<unk>' and last_spacy >= 0:
            prev = next(
                (alignment[k] for k in range(i - 1, -1, -1)
                 if marian_tokens[k] != '' and k in alignment),
                None,
            )
            # clamp so the result is always a valid spaCy index
            alignment[i] = 0 if prev is None else min(prev + 1, last_spacy)

    # Pass 2 - word breaks, right to left so runs of consecutive breaks chain.
    for i in range(len(marian_tokens) - 1, -1, -1):
        if marian_tokens[i] == '':
            if i + 1 in alignment:
                alignment[i] = alignment[i + 1]
            elif i - 1 in alignment:
                alignment[i] = alignment[i - 1]
            # otherwise leave it out; the assertion below reports the sentence

    assert len(alignment) == len(marian_tokens), f'did not find a complete alignment for: {text}'
    return alignment

# Notebook-scope resources used by the helper functions and the validation loop below.
# `marian` is the tokenizer of the English->German OPUS-MT checkpoint.
marian = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
# Swap the two `nlp` lines to run the check with the English pipeline instead.
# nlp = spacy.load('en_core_web_sm')
nlp = spacy.load('de_core_news_sm')
# WMT14 de-en validation split; the loop below reads its 'translation' column.
wmt = load_dataset('wmt14', 'de-en', split='validation')

def get_spacy_tokens(text):
    """Run the notebook-level `nlp` pipeline on `text` and return each token's `.text`, in order."""
    surface_forms = []
    for tok in nlp(text):
        surface_forms.append(tok.text)
    return surface_forms

def get_marian_tokens(text):
    """Encode `text` with the notebook-level `marian` tokenizer and decode each id back to a string.

    The final id of the encoding is dropped before decoding
    (presumably the end-of-sequence marker - confirm against the tokenizer).
    """
    piece_ids = marian(text)['input_ids']
    decoded_pieces = []
    for piece_id in piece_ids[:-1]:
        decoded_pieces.append(marian.decode(piece_id))
    return decoded_pieces

# Exhaustive check: align every validation sentence. Nothing is collected; the
# point is that the assertions inside get_spans / get_alignments raise on the
# first sentence that cannot be aligned.
# NOTE(review): `i`, `example`, `text`, the token lists and `alignments` stay
# bound at notebook scope after the loop finishes.
for i, example in enumerate(tqdm(wmt['translation'])):
    # Swap the two `text` lines (together with the `nlp` choice above) to check English.
    # text = example['en']
    text = example['de']
    spacy_tokens = get_spacy_tokens(text)
    marian_tokens = get_marian_tokens(text)
    alignments = get_alignments(text, spacy_tokens, marian_tokens)

In [47]:
class Parser:
    """Project a spaCy dependency parse onto Marian subword tokens.

    The Marian tokens and the spaCy tokens of a sentence are both mapped to
    character spans of the original text; each Marian token is then attached
    to the spaCy token it overlaps most, and inherits that token's
    part-of-speech tag, head and children.
    """

    def __init__(self):
        """Load the English and German spaCy pipelines and the en-de Marian tokenizer."""
        self.en_nlp = spacy.load('en_core_web_sm')
        self.de_nlp = spacy.load('de_core_news_sm')
        self.marian = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')

    def get_spans(self, text, tokens):
        """Locate each token in `text` and return its `(start, end)` character span.

        The text is scanned left to right with one cursor: whitespace is
        skipped, the span starts at the cursor and ends right after the first
        occurrence of the token at or beyond the cursor. A token that does not
        occur in the remaining text (e.g. a literal `'<unk>'`) gets the
        zero-width span `(start, start)` and does not move the cursor.

        Raises:
            AssertionError: if the text runs out before every token has a span.
        """
        spans = []
        i = 0
        for token in tokens:
            while i < len(text) and text[i].isspace(): i += 1
            if i < len(text):
                s = i
                found = text.find(token, i)
                # str.find returns -1 on a miss; `-1 + len(token)` would send
                # the cursor to an unrelated offset and break all later spans.
                if found != -1:
                    i = found + len(token)
                spans.append((s, i))

        assert len(spans) == len(tokens), f'number of spans does not match number of tokens for: {text}'
        return spans

    def get_alignment(self, text, spacy_tokens, marian_tokens, spacy_spans=None, marian_spans=None):
        """Map every Marian token index to the index of the spaCy token it belongs to.

        Args:
            text: the original string.
            spacy_tokens: spaCy token strings for `text`.
            marian_tokens: decoded Marian token strings for `text`.
            spacy_spans: optional precomputed spans for `spacy_tokens`;
                computed from `text` when omitted.
            marian_spans: optional precomputed spans for `marian_tokens`;
                computed from `text` when omitted.

        Returns:
            dict mapping Marian token index -> spaCy token index. A token is
            aligned to the spaCy token it shares the most characters with
            (earliest on ties). `'<unk>'` is placed one spaCy token after the
            closest preceding aligned token (clamped to the last index, 0 when
            nothing precedes it); `''` copies the alignment of the following
            token, or the preceding one when nothing follows.

        Raises:
            AssertionError: if some Marian token could not be aligned.
        """
        # Each default is filled from its own token list, and get_spans needs the text.
        if spacy_spans is None: spacy_spans = self.get_spans(text, spacy_tokens)
        if marian_spans is None: marian_spans = self.get_spans(text, marian_tokens)

        alignment = {} # map marian_tokens[i] to spacy_tokens[j]
        best_overlap = defaultdict(int) # track max overlap

        # just bruteforce check (who needs DP?)
        for i, marian_span in enumerate(marian_spans):
            for j, spacy_span in enumerate(spacy_spans):
                overlap = max(0, min(spacy_span[1], marian_span[1]) - max(spacy_span[0], marian_span[0]))
                if overlap > 0 and overlap > best_overlap[i]:
                    alignment[i] = j
                    best_overlap[i] = overlap

        last_spacy = len(spacy_spans) - 1

        # Pass 1 - unknown pieces. Resolved before word breaks because a word
        # break directly in front of an '<unk>' borrows the '<unk>' alignment;
        # a single forward sweep raised KeyError for that pair.
        for i, tok in enumerate(marian_tokens):
            if tok == '<unk>' and last_spacy >= 0:
                prev = next(
                    (alignment[k] for k in range(i - 1, -1, -1)
                     if marian_tokens[k] != '' and k in alignment),
                    None,
                )
                # clamp so the result is always a valid spaCy index
                alignment[i] = 0 if prev is None else min(prev + 1, last_spacy)

        # Pass 2 - word breaks, right to left so consecutive breaks chain.
        for i in range(len(marian_tokens) - 1, -1, -1):
            if marian_tokens[i] == '':
                if i + 1 in alignment:
                    alignment[i] = alignment[i + 1]
                elif i - 1 in alignment:
                    alignment[i] = alignment[i - 1]
                # otherwise leave it out; the assertion below reports the sentence

        assert len(alignment) == len(marian_tokens), f'did not find a complete alignment for: {text}'
        return alignment

    def get_spacy_tokens(self, text, lang):
        """Return the token strings spaCy produces for `text` (English pipeline if `lang == 'en'`, German otherwise)."""
        return [token.text for token in (self.en_nlp if lang == 'en' else self.de_nlp)(text)]

    def get_marian_tokens(self, text):
        """Encode `text` with Marian and decode every id but the last back to a string.

        The dropped final id is presumably the end-of-sequence marker - confirm.
        """
        return [self.marian.decode(token_id) for token_id in self.marian(text)['input_ids'][:-1]]

    def align(self, text, lang):
        """Annotate the Marian tokens of `text` with the spaCy parse.

        Args:
            text: sentence to analyse.
            lang: `'en'` selects the English spaCy pipeline; anything else the
                German one.

        Returns:
            A list with one dict per token of `self.marian.tokenize(text)`:
            `'text'` (the raw Marian piece), `'pos'` (POS tag of the aligned
            spaCy token), `'i'` (position in the list), `'is_head'` (whether
            the aligned spaCy token is the root of the parse), `'children'`
            (the dicts of all Marian pieces aligned to the spaCy children) and
            `'par'` (index of the first Marian piece aligned to the spaCy head).
        """
        # Parse once and reuse the Doc for both the token strings and the tree.
        doc = (self.en_nlp if lang == 'en' else self.de_nlp)(text)
        spacy_tokens = [token.text for token in doc]
        spacy_spans = self.get_spans(text, spacy_tokens)
        normalized_marian_tokens = self.get_marian_tokens(text)
        marian_spans = self.get_spans(text, normalized_marian_tokens)

        raw_marian_tokens = self.marian.tokenize(text)
        alignment = self.get_alignment(text, spacy_tokens, normalized_marian_tokens, spacy_spans, marian_spans)

        # reverse map spacy to marian tokens
        reverse = defaultdict(list)
        for k, v in alignment.items(): reverse[v].append(k)

        # get original root idx (the root is the token that is its own head)
        root_i = next(i for i, tok in enumerate(doc) if tok.head == tok)

        seq = []
        # `piece`, not `text`: the loop must not clobber the method's argument
        for i, piece in enumerate(raw_marian_tokens):
            spacy_tok = doc[alignment[i]]
            seq.append({'text': piece, 'pos': spacy_tok.pos_, 'i': i, 'is_head': alignment[i] == root_i})

        for i, node in enumerate(seq):
            spacy_tok = doc[alignment[i]]
            node['children'] = []
            for child in spacy_tok.children:
                node['children'].extend(seq[k] for k in reverse[child.i])

            ### heuristic for lineage: the parent is the first Marian piece
            ### recorded for the spaCy head of this piece's token
            node['par'] = reverse[spacy_tok.head.i][0]

        return seq

In [None]:
from mt import WMTForEvolver

parser = Parser()
wmt = WMTForEvolver(split='validation')

# Define the demo sentence in this cell. Previously `text` was whatever an
# earlier cell had left behind (on a fresh run: the last German validation
# sentence), which was then parsed with lang='en'.
text = 'The new restrictions disproportionately affect young people, minorities and people with low incomes.'

traj = wmt._get_short_input_traj(parser.align(text, 'en'))

for thing in traj:
    print(thing)

In [None]:
# Example sentence: print every Marian piece next to the piece chosen as its parent.
text = 'The new restrictions disproportionately affect young people, minorities and people with low incomes.'

seq = parser.align(text, 'en')
for thing in seq:
    print(f"{thing['text']} <- {seq[thing['par']]['text']}")

▁The <- ▁restrictions
▁new <- ▁restrictions
▁restrictions <- ▁affect
▁ <- ▁affect
disproportionate <- ▁affect
ly <- ▁affect
▁affect <- ▁affect
▁young <- ▁people
▁people <- ▁affect
, <- ▁people
▁minorities <- ▁people
▁and <- ▁minorities
▁people <- ▁minorities
▁with <- ▁affect
▁low <- ▁incomes
▁incomes <- ▁with
. <- ▁affect


In [42]:
# Reference output: the word-level spaCy heads for the same sentence, to compare
# against the subword-level lineage printed by the previous cell.
en_nlp = spacy.load('en_core_web_sm')

for token in en_nlp(text):
    print(f'{token.text} <- {token.head.text}')

The <- restrictions
new <- restrictions
restrictions <- affect
disproportionately <- affect
affect <- affect
young <- people
people <- affect
, <- people
minorities <- people
and <- minorities
people <- minorities
with <- affect
low <- incomes
incomes <- with
. <- affect
