* 

In [18]:
actions = ['shift', 'right-arc', 'left-arc', 'reduce']
elemnt = ('parent', 'child')

In [32]:
from collections import OrderedDict
from conllu import parse
from enum import Enum


def get_data(path):
    with open(path, "r") as f:
        data = f.read()

    trees = parse(data)
    return trees

# debug mode
"""
for tree in trees:
    for node in tree:
        head = node["head"]
        try:
            print("{} <-- {}".format(node["form"],
                                     tree[head - 1]["form"]
                                     if head > 0 else "root"))
        except TypeError:
            pass
"""

In [20]:
print(trees[1][2])

OrderedDict([('id', 3), ('form', 'у'), ('lemma', 'у'), ('upostag', 'ADP'), ('xpostag', 'Spsg'), ('feats', OrderedDict([('Case', 'Gen')])), ('head', 4), ('deprel', 'case'), ('deps', [('case', 4)]), ('misc', OrderedDict([('Id', '000k'), ('LTranslit', 'u'), ('Translit', 'u')]))])


In [36]:
from pprint import pprint as pp

def shift(stack, queue):
    stack.append(queue.pop(0))
    return stack, queue

def right_arc(stack, queue, dep_arcs):
    dep_arcs.append((stack[-1]['id'], queue[0]['id']))
    stack, queue = shift(stack, queue)
    return stack, queue, dep_arcs

def left_arc(stack, queue, dep_arcs):
    dep_arcs.append((queue[0]['id'], stack[-1]['id']))
    stack.pop(-1)
    return stack, queue, dep_arcs

def reduce(stack):
    stack.pop()
    return stack

def oracle(stack, queue, dep_arcs):
    global ROOT
    
    if stack[-1] and not queue[0]:
        return 'reduce'
    elif stack[-1]['head'] == queue[0]['id']:
        return 'left_arc'
    elif queue[0]['head'] == stack[-1]['id']:
        return 'right_arc'
    elif stack[-1]["id"] in [i[0] for i in dep_arcs] and \
         (queue[0]["head"] < stack[-1]["id"] or \
         [s for s in stack if s["head"] == queue[0]["id"]]):
        return 'reduce'    
    else:
        return 'shift'

def feature_extract(stack, queue, dep_arcs):
    features = {}
    
    # stk_0: form, lemma, postag, feats
    features['stk_0_form'] = stack[-1]['form']
    features['stk_0_lemma'] = stack[-1]['lemma']
    features['stk_0_postag'] = stack[-1]['upostag']
    
    if stack[-1]['feats'] != None:
        for feat in stack[-1]['feats'].keys():
            features['stk_0_'+feat] = stack[-1]['feats'][feat]
    
    # queue_0: form, lemma, postag, feats
    features['que_0_form'] = queue[0]['form']
    features['que_0_lemma'] = queue[0]['lemma']
    features['que_0_postag'] = queue[0]['upostag']
    
    if queue[0]['feats'] != None:
        for feat in queue[0]['feats'].keys():
            features['que_0_'+feat] = queue[0]['feats'][feat]
    
    # queue_1: form, postag
    try:
        features['que_1_form'] = queue[1]['form']
        features['que_1_postag'] = queue[1]['upostag']
    except IndexError:
        pass

    # queue_2: postag
    try:
        features['que_2_postag'] = queue[2]['upostag']
    except IndexError:
        pass
    
    # queue_3: postag
    try:
        features['que_3_postag'] = queue[3]['upostag']
    except IndexError:
        pass
    
    return features
      
    
    
ROOT = OrderedDict([('id', 0), ('form', 'ROOT'), ('lemma', 'ROOT'), ('upostag', 'ROOT'),
                    ('xpostag', None), ('feats', None), ('head', None), ('deprel', None),
                    ('deps', None), ('misc', None)])

def dep_parse(tree):
    """Parse dependencies for one sentence (tree)"""
    
    global ROOT
    stack = [ROOT]
    queue = tree[:]
    dep_arcs = []
    
    x, y = [], []
    
    while len(stack) > 0 and len(queue) > 0:
        
        features = feature_extract(stack, queue, dep_arcs)        
        
        try:
            action = oracle(stack, queue, dep_arcs)
        except TypeError:
            print(stack)
            print(queue)
            break
        
        x.append(features)
        y.append(action)
        
        if action == 'reduce':
            stack = reduce(stack)
        elif action == 'left_arc':
            stack, queue, dep_arcs = left_arc(stack, queue, dep_arcs)
        elif action == 'right_arc':
            stack, queue, dep_arcs = right_arc(stack, queue, dep_arcs)
        elif action == 'shift':
            stack, queue = shift(stack, queue)
    
    # return dep_arcs
    return x, y


def filter_trees(trees): 
    """Delete nodes from a tree where id is not an integer"""
    return [[token for token in tree if type(token['id']) == int] for tree in trees]


def prepare_data(path):

    X, Y = [], []
    trees = filter_trees(get_data(path))
    
    for tree in trees:
        x, y = dep_parse(tree)
        X.extend(x)
        Y.extend(y)

    assert len(X) == len(Y)
    
    return X, Y
    
train_path = "./corpus/uk_iu-ud-train.conllu"
test_path = "./corpus/uk_iu-ud-test.conllu"

X_train, Y_train = prepare_data(train_path)
X_test, Y_test = prepare_data(test_path)


In [37]:
# Vectorize features

from sklearn.feature_extraction import DictVectorizer

def vectorize(X_train, X_test):
    
    print('\nVectorizing...')
    v = DictVectorizer(sparse=True)
    
    vectorizer = v.fit(X_train)
    v_train = vectorizer.transform(X_train)
    v_test = vectorizer.transform(X_test)
    
    return v_train, v_test, vectorizer

X_train, X_test, vectorizer = vectorize(X_train, X_test)


Vectorizing...


In [39]:
# Try a different classifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report

rfc = RandomForestClassifier(n_estimators=20, criterion='entropy', max_depth=None, n_jobs=-1, verbose=True)
rfc.fit(X_train, Y_train)

predicted = rfc.predict(X_test)
print(classification_report(Y_test, predicted))

[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 4 concurrent workers.
[Parallel(n_jobs=-1)]: Done  20 out of  20 | elapsed:  5.1min finished
[Parallel(n_jobs=4)]: Using backend ThreadingBackend with 4 concurrent workers.
[Parallel(n_jobs=4)]: Done  20 out of  20 | elapsed:    0.4s finished


              precision    recall  f1-score   support

    left_arc       0.85      0.93      0.89      7346
      reduce       0.63      0.42      0.50      2552
   right_arc       0.79      0.82      0.80      5935
       shift       0.86      0.86      0.86     10336

   micro avg       0.83      0.83      0.83     26169
   macro avg       0.78      0.76      0.76     26169
weighted avg       0.82      0.83      0.82     26169

