## Phrase Type
### Identify the type of multiword expressions

The task is to identify the spans of multiword expressions (MWEs) and their types; it is framed as a sequence labeling (BIO) task. We use the STREUSLE corpus [(Schneider and Smith, 2015)](http://aclweb.org/anthology/N15-1177) (Supersense-Tagged Repository of English with a Unified Semantics for Lexical Expressions), whose text comes from the web reviews portion of the English Web Treebank. We use only its MWE annotations. 

In [None]:
import random
# Seed Python's global RNG once, at import time. The train/test/validation split
# further down uses random.shuffle on this global state, so re-running only the
# split cell (without re-running this one) produces a different split.
random.seed(133)

import os
import re
import csv
import json
import spacy
import shutil
import codecs
import random  # duplicate of the import above; harmless
import fileinput

from nltk import agreement
from itertools import count
from collections import Counter, defaultdict

Download the STREUSLE corpus.

In [None]:
# Fetch the STREUSLE corpus (CoNLL-U-Lex format) unless it is already on disk.
# The check is on the corpus file itself, not just on the directory, so a previous
# run that created `mwe_type/` but failed to download is retried. The data is first
# written under a temporary name and then renamed, so an interrupted transfer never
# leaves a truncated file that later runs would mistake for the complete corpus.
import urllib.request

STREUSLE_URL = 'https://github.com/nert-nlp/streusle/raw/master/streusle.conllulex'
streusle_path = os.path.join('mwe_type', 'streusle.conllulex')
if not os.path.exists(streusle_path):
    os.makedirs('mwe_type', exist_ok=True)
    partial_path = streusle_path + '.part'
    urllib.request.urlretrieve(STREUSLE_URL, partial_path)
    os.replace(partial_path, streusle_path)

We remove discontinuous (gappy) MWEs and label the "weak" (compositional) MWEs as "COMP". Thanks to Nathan Schneider for the following code, which we adapted.

In [None]:
'''
Converts tagged sentences into a simpler form without gaps and/or weak links. 
If the mode is gaps+weak, 4 versions of each sentence are produced. 
Otherwise, 2 versions are produced, corresponding to liberal and conservative 
conversion rules.

Args: gaps|weak|gaps+weak TAGGED_FILE

Adapted here to the ASCII MWE position tags of the STREUSLE .conllulex
LEXTAG column: O, B, I_, I~ and their in-gap counterparts o, b, i_, i~.

@author: Nathan Schneider (nschneid@cs.cmu.edu)
@since: 2013-06-30
'''
# MWE position tags handled below:
#   O / o    outside an MWE (lowercase: inside the gap of a discontinuous MWE)
#   B / b    first token of an MWE
#   I_ / i_  strong continuation
#   I~ / i~  weak continuation
#   I / i    plain continuation, produced by the simplifications below
def is_tag(t):
    '''Return True if `t` is one of the MWE position tags listed above.'''
    return t in {'B', 'b', 'O', 'o', 'I', 'I_', 'I~', 'i', 'i_', 'i~'}

# Tags are multi-character strings, so the patterns use alternations: a character
# class such as [II_I~] only matches the single characters 'I', '_' and '~', and
# cannot accept the in-gap strong continuation 'i_' at all.
RE_TAGGING = re.compile(r'^(?:O|B(?:o|b(?:i_|i~|i)+|I_|I~|I)*(?:I_|I~|I)+)+$')
RE_TAGGING_NO_GAPS = re.compile(r'^(?:O|B(?:I_|I~|I)+)+$')
RE_TAGGING_NO_WEAK = re.compile(r'^(?:O|B(?:o|bi+|I)*I+)+$')

def require_valid_tagging(tagging, simplify_gaps, simplify_weak):
    '''
    Assert that `tagging` (one sentence's tags joined into a single string) is well formed.

    With `simplify_gaps` the tagging must additionally contain no gaps; with
    `simplify_weak` it may only use the plain tags O, B, I, o, b, i (no '_'/'~'
    strong/weak markers).

    Raises:
        AssertionError: if the tagging does not satisfy the requested constraints.
    '''
    assert RE_TAGGING.match(tagging), tagging
    if simplify_gaps:
        assert RE_TAGGING_NO_GAPS.match(tagging), tagging
    if simplify_weak:
        assert RE_TAGGING_NO_WEAK.match(tagging), tagging

IN_GAP_TAGS = {'o', 'b', 'i', 'i_', 'i~'}
I_TILDE = 'I~'   # weak continuation
i_TILDE = 'i~'   # weak continuation inside a gap
I_BAR = 'I_'     # strong continuation
i_BAR = 'i_'     # strong continuation inside a gap

def simplify(sentid, tokens, poses, tags, simplification='gaps+weak', policy='best'):
    '''
    Produce simplified versions of one sentence's MWE position tags.

    Depending on `simplification`, gappy (discontinuous) MWEs and/or weak MWE
    links are removed. Each simplification has a conservative and a liberal
    variant, so 'gaps' or 'weak' yields 2 taggings and 'gaps+weak' yields 4, in the
    order (conservative gaps, conservative weak), (conservative gaps, liberal weak),
    (liberal gaps, conservative weak), (liberal gaps, liberal weak). The input
    `tags` sequence is never modified.

    Args:
        sentid: sentence identifier; not used by this function.
        tokens: the sentence's tokens; only used to check the output length.
        poses: POS information; not used by this function.
        tags: MWE position tags, one per token (see the tag table above).
        simplification: 'gaps', 'weak' or 'gaps+weak'.
        policy: 'all' returns every variant; 'best' returns a one-element list
            holding the variant chosen by BEST_POLICY_RESULT (for 'gaps+weak':
            conservative gaps combined with liberal weak links).

    Returns:
        A list of taggings; each is a list of str with one tag per token.

    Raises:
        AssertionError: on unsupported arguments, unknown tags, or a tagging the
            conversion rules cannot turn into a valid result.
    '''
    assert simplification in {'gaps', 'weak', 'gaps+weak'}
    assert policy in {'all', 'best'}
    BEST_POLICY_RESULT = {'weak': 1, # the high-recall (liberal) policy: convert weak to strong
                          'gaps': 0, # the high-precision (conservative) policy: remove cross-gap links
                          'gaps+weak': 1}   # combination of the above
    
    gold_tags = set(tags)
    assert gold_tags <= set('OoBbIi') | {I_BAR, I_TILDE, i_BAR, i_TILDE}
    # Only do work for the phenomena the sentence actually contains
    simplify_gaps = simplification in {'gaps', 'gaps+weak'} and not gold_tags <= {'O', 'B', 'I', I_BAR, I_TILDE}
    simplify_weak = simplification in {'weak', 'gaps+weak'} and not gold_tags <= set('OoBbIi')
    
    results = []
    
    if simplify_gaps or simplify_weak:
        tt = list(tags) # output tags
        if simplify_gaps:
            assert 'o' in gold_tags or 'b' in gold_tags
            # conservative: remove gaps from gappy expressions
            for i,orig in enumerate(tags):
                if tt[i] in IN_GAP_TAGS:
                    if tt[i-1] not in IN_GAP_TAGS:
                        if tt[i-1]=='B':
                            tt[i-1] = 'O'
                    if tt[i+1] not in IN_GAP_TAGS:
                        # a strong or weak I tag
                        tt[i+1] = 'B' if i+2<len(tags) else 'O'
                elif tt[i] in {'O','B'} and i>0 and tt[i-1]=='B':
                    # we introduced a B after a gap which is actually a singleton, so remove it
                    tt[i-1] = 'O'
                    
            result = [t.upper().replace(i_TILDE, I_TILDE).replace(i_BAR, I_BAR) for t in tt]
            require_valid_tagging(''.join(result), simplify_gaps, False)
            results.append(result)
            
            tt = list(tags)
            # liberal: link across gaps (weakly if possible, preserving in-gap strong MWEs)
            for i,orig in enumerate(tags):
                if tt[i] in IN_GAP_TAGS:
                    if simplify_weak:
                        tt[i] = 'I'
                    elif tt[i]=='b':
                        if tt[i-1]=='B':
                            tt[i] = I_TILDE
                    else:
                        if tt[i]==i_TILDE and tags[i-1]=='b':
                            # weak link within a gap: merge it with the cross-gap weak MWE
                            tt[i-1] = I_TILDE
                        if tt[i]!=i_BAR:
                            tt[i] = I_TILDE
                elif not simplify_weak and tt[i]==I_BAR and tags[i-1] in IN_GAP_TAGS:
                    # post-gap continuation should be weak
                    tt[i] = I_TILDE
            
            result = [t.upper().replace(i_TILDE, I_TILDE).replace(i_BAR, I_BAR) for t in tt]
            require_valid_tagging(''.join(result), simplify_gaps, False)
            results.append(result)
        else:
            # no gaps: both gap variants are the unchanged tags
            results.append(list(tags))
            if 'gaps' in simplification:
                results.append(list(tags))
        
        assert len(results)==(2 if 'gaps' in simplification else 1),results
        
        if simplify_weak:
            # expand each gap variant into its conservative and liberal weak variants
            partial_results = results
            results = []
            for partial_result in partial_results:
                # conservative: remove weak links
                tt = list(partial_result)
                
                # - convert weak I's to B's, and strong I's to plain I's
                tt[:] = [{i_TILDE: 'b', I_TILDE: 'B', i_BAR: 'i', I_BAR: 'I'}.get(t,t) for t in tt]
                # - remove trans-gap weak links
                for i,t in enumerate(tt):
                    if t=='B' and i>0 and tt[i-1] in {'o','i'}:	# B after gap. the trans-gap link was weak, so everything inside the gap becomes no longer gappy
                        j = i-1
                        while tt[j].islower():
                            tt[j] = tt[j].upper()
                            j -= 1
                        if tt[j]=='B':
                            tt[j] = 'O'
                # - remove singleton B's (B must be followed by I or a gap)
                for i,t in enumerate(tt):
                    if t=='B':
                        if i+1==len(tt):    # B at end of sequence
                            tt[i] = 'O'
                        elif i>0 and tt[i-1]=='b':
                            assert False
                        elif tt[i+1] in {'O', 'B'}: # singleton B
                            tt[i] = 'O'
                    elif t=='b':
                        if tt[i+1]!='i':
                            tt[i] = 'o'
                # TODO: weak trans-gap link
                require_valid_tagging(''.join(tt), simplify_gaps, simplify_weak)
                results.append(tt)
                
                # liberal: convert weak links to strong links
                tt = list(partial_result)
                for i,t in enumerate(tt):
                    if t in {i_TILDE, i_BAR}:
                        tt[i] = 'i'
                    elif t in {I_TILDE, I_BAR}:
                        tt[i] = 'I'

                require_valid_tagging(''.join(tt), simplify_gaps, simplify_weak)
                results.append(tt)
        elif 'weak' in simplification:
            # no weak links: both weak variants equal the gap variants
            results.append(list(results[0]))
            results.append(list(results[1]))
            
        assert len(results)==(4 if simplification=='gaps+weak' else 2),(simplify_gaps,simplify_weak,results)

    else:       # nothing to do for this sentence
        for x in range(4 if simplification=='gaps+weak' else 2):
            results.append(list(tags))
    
    if policy=='best':
        results = [results[BEST_POLICY_RESULT[simplification]]]
    
    for result in results:
        assert len(tokens)==len(result)
        
    return results

Simplify the MWE tags and convert the corpus to JSON Lines format (one sentence per line).

In [None]:
sent = []  # rows of the sentence currently being read: [token, pos, tag, label]


def _drop_unwanted_spans(tags):
    '''
    Return a copy of `tags` in which unwanted spans are replaced by 'O'.

    `tags` are 'O', 'B[-LABEL]' or 'I[-LABEL]' strings. A span is dropped when its
    B tag has no label, or a label containing '@' ('!!@' = "needs to be manually
    corrected") or one of the rare categories NUM, INF, INTJ. An I tag whose label
    contains '@' drops the whole span it belongs to. The BIO prefix and the label
    are tested separately, so a label that merely contains the letter 'I' (e.g.
    'VID', 'DISC', 'IAV') is never mistaken for an I tag.
    '''
    labels_to_remove = {'@', 'NUM', 'INF', 'INTJ'}
    new_tags = []
    repl_next = False  # True while inside a span that is being dropped

    for tag in tags:
        bio, _, label = tag.partition('-')

        # B with no label or with an unwanted label: drop it and its I's
        if bio == 'B' and (not label or any(bad in label for bad in labels_to_remove)):
            new_tags.append('O')
            repl_next = True

        # I continuing a dropped span
        elif bio == 'I' and repl_next:
            new_tags.append('O')

        # I with '@' after a regular B: drop the whole span
        elif bio == 'I' and '@' in label:
            new_tags.append('O')
            # Walk back over this span's earlier I's (stopping at its B, which
            # may be the first token) without touching anything before the span
            prev = len(new_tags) - 2
            while prev >= 0 and new_tags[prev].startswith('I'):
                new_tags[prev] = 'O'
                prev -= 1
            if prev >= 0 and new_tags[prev].startswith('B'):
                new_tags[prev] = 'O'
            repl_next = True  # the span's remaining I's must go too

        # Regular tag
        else:
            new_tags.append(tag)
            repl_next = False

    return new_tags


def add_sent_to_json(sent, f_out, sentid=None):
    '''
    Simplify one parsed sentence's MWE tags and append it to `f_out` as a JSON line.

    Args:
        sent: non-empty list of [token, pos, tag, label] rows, as built by the
            corpus reader loop.
        f_out: writable text file; receives one line holding
            {"sentence_words": [...], "sentence_tags": [...]}.
        sentid: sentence identifier, passed through to `simplify` (which does not
            use it). Defaults to None so the function does not rely on a global.

    Raises:
        Whatever `simplify` raises for taggings it cannot handle; nothing is
        written in that case.
    '''
    words, poses, prev_tags, labels = zip(*sent)
    new_tags = simplify(sentid, words, poses, prev_tags)[0]
    # Attach the label to every non-O tag that has one, e.g. 'B-VPC'
    tags = ['-'.join((t, l)) if l != '' and t != 'O' else t
            for t, l in zip(new_tags, labels)]

    f_out.write(json.dumps({'sentence_words': words,
                            'sentence_tags': _drop_unwanted_spans(tags)}) + '\n')
    

def _flush_sentence(tokens, f_out, skip_reasons):
    '''Write `tokens` via add_sent_to_json; count rejected sentences by exception type.'''
    if not tokens:
        return
    try:
        add_sent_to_json(tokens, f_out)
    except Exception as e:  # best effort: sentences that cannot be converted are skipped, not fatal
        skip_reasons[type(e).__name__] += 1


sent = []       # rows of the sentence being read: [token, [lemma, xpos], tag, label]
sentid = None   # id of the sentence being read, from its '# streusle_sent_id' comment
skip_reasons = Counter()

with codecs.open('mwe_type/streusle.jsonl', 'w', 'utf-8') as f_out, \
        codecs.open('mwe_type/streusle.conllulex', 'r', 'utf-8') as f_in:
    for line in f_in:
        # A blank line ends the current sentence
        if not line.strip():
            _flush_sentence(sent, f_out, skip_reasons)
            sent = []
            sentid = None
            continue

        # Comment lines: only the sentence id is kept
        if line.startswith('#'):
            if 'streusle_sent_id' in line:
                sentid = line.partition('=')[2].strip()
            continue

        data = line.strip().split('\t')

        # Keep only word lines: 19 CoNLL-U-Lex columns and an integer ID. In CoNLL-U,
        # IDs such as '3-4' (multiword-token ranges) and '8.1' (empty nodes) do not
        # belong to the sentence's word sequence.
        if len(data) != 19 or not data[0].isdigit():
            continue

        _, tok, lemma, _, xpos = data[:5]
        lemma_pos = [lemma, xpos]
        tag = data[-1]

        # Separate the specific label (MWE type) from the BIO tag
        label = ''
        if '-' in tag:
            tag, label = tag.split('-')[:2]

            # Weak MWE - we want the label to be just for the first word in the
            # expression: mark the nearest preceding B/b (possibly token 0) as COMP
            if 'I' in tag.upper() and label != '':
                for prev in range(len(sent) - 1, -1, -1):
                    if 'B' in sent[prev][-2].upper():
                        sent[prev][-1] = 'COMP'
                        break
                label = ''

        # B / b: keep a coarse label
        if not tag.upper().startswith('O') and not tag.upper().startswith('I'):
            # Verbal MWE: 'V.VPC.full' -> 'VPC'
            if label.startswith('V.'):
                label = label.split('.')[1]
            # Regular label: 'INF.P' -> 'INF'
            else:
                label = label.split('.')[0]

        # O / I - no label
        else:
            label = ''

        sent.append([tok, lemma_pos, tag, label])

    # Last sentence (the file may not end with a blank line)
    _flush_sentence(sent, f_out, skip_reasons)
    sent = []

sentences_skipped = sum(skip_reasons.values())
print(f'Sentences skipped: {sentences_skipped}')
if skip_reasons:
    print(f'Skip reasons: {dict(skip_reasons)}')

In [None]:
# Load the converted corpus back and report the tag inventory with frequencies.
with codecs.open('mwe_type/streusle.jsonl', 'r', 'utf-8') as f_in:
    dataset = [json.loads(row) for row in map(str.strip, f_in)]

all_tags = []
for example in dataset:
    all_tags.extend(example['sentence_tags'])
tagset = set(all_tags)
print(f'Number of tags: {len(tagset)}')

# Human-readable names for the labels (the part after the last '-' of a tag);
# tags whose label is not listed are printed as-is.
explanation = {
    # Verbal MWE subtypes
    'VPC': 'Verb-Particle Construction',
    'LVC': 'Light Verb Construction',
    'IAV': 'Inherently Adpositional Verb',
    'VID': 'Verbal Idiom',
    # Lexical categories
    'N': 'Noun, common or proper',
    'POSS.PRON': 'Possessive Pronoun',
    'PRON': 'Non-possessive Pronoun',
    'POSS': 'Possessive Clitic',
    'V': 'Full verb or copula',
    'AUX': 'Auxiliary',
    'P': 'Adposition',
    'PP': 'Prepositional Phrase MWE',
    'INF': 'Infinitive marker',
    'DISC': 'Discourse / Pragmatic expression',
    'ADJ': 'Adjective',
    'ADV': 'Adverb',
    'DET': 'Determiner',
    'CCONJ': 'Conjunction',
    'SCONJ': 'Subordinating Conjunction',
    'INTJ': 'Interjection',
    'NUM': 'Numeral',
    'SYM': 'Symbol',
    'PUNCT': 'Punctuation',
    'X': 'Other',
    # Added by the conversion above
    'COMP': 'Weak Compositional MWE',
}

tag_counts = Counter(all_tags)
rows = ['\t'.join((tag, explanation.get(tag.split('-')[-1], tag), str(freq)))
        for tag, freq in tag_counts.most_common()]
print('\n'.join(rows))

Split into train, test, and validation sets.

In [None]:
# Shuffle and split 80/10/10 into train / test / validation, remove sentences
# shared between splits, and write each split as JSON Lines.
print('Dataset size: {}'.format(len(dataset)))

train_size = 8 * len(dataset) // 10
val_size = test_size = len(dataset) // 10
random.shuffle(dataset)
train = dataset[:train_size]
test = dataset[train_size:train_size + test_size]
# Validation takes all remaining examples, so every example lands in exactly one split
val = dataset[train_size + test_size:]


def _sentence_key(example):
    '''Text of an example's sentence, used to detect the same sentence in two splits.'''
    return ' '.join(example['sentence_words'])


# Remove examples that repeat across sets (a shared sentence is dropped from every
# split containing it). The "other splits" sets are built once, not per example.
train_sents = {_sentence_key(e) for e in train}
val_sents = {_sentence_key(e) for e in val}
test_sents = {_sentence_key(e) for e in test}

not_for_train = val_sents | test_sents
not_for_val = train_sents | test_sents
not_for_test = val_sents | train_sents

train = [e for e in train if _sentence_key(e) not in not_for_train]
val = [e for e in val if _sentence_key(e) not in not_for_val]
test = [e for e in test if _sentence_key(e) not in not_for_test]

train_sents = {_sentence_key(e) for e in train}
val_sents = {_sentence_key(e) for e in val}
test_sents = {_sentence_key(e) for e in test}

print('Train set size: {}, test set size: {}, validation set size: {}'.format(len(train), len(test), len(val)))
assert len(train_sents & val_sents) == 0
assert len(train_sents & test_sents) == 0
assert len(test_sents & val_sents) == 0

data_dir = '../diagnostic_classifiers/data/mwe_type/streusle'
# makedirs also creates missing parent directories
os.makedirs(data_dir, exist_ok=True)

for s, filename in zip([train, test, val], ['train', 'test', 'val']):
    with codecs.open(os.path.join(data_dir, '{}.jsonl'.format(filename)), 'w', 'utf-8') as f_out:
        for e in s:
            f_out.write(json.dumps(e) + '\n')