In [2]:
import re
from collections import Counter, defaultdict
from conllu import parse

import nltk
from nltk.corpus import conll2000
conll2000.ensure_loaded()

In [3]:
#accuracy counter
def accuracy(test_sents, postagger):
    errors = 0
    length = 0
    for sent in test_sents:
        length += len(sent)
        sent, real_tags = zip(*sent)
        my_tags = postagger.tag(sent)
        for i in range(len(my_tags)):
            if my_tags[i][1] != real_tags[i]:
                errors += 1
    return 1 - errors / length

#normalizer
class BaseNormalizer:
    def normalize(self, counter):
        sum_ = sum(counter.values())
        for token in counter:
            counter[token] /= sum_

In [4]:
#loading and parsing data
with open('en_ewt-ud-train.conllu.txt', 'r', encoding = 'utf-8') as f:
    train_raw = f.read()
train = parse(train_raw)
with open('en_ewt-ud-test.conllu.txt', 'r', encoding = 'utf-8') as f:
    test_raw = f.read()
test = parse(test_raw)

train_sample = []
for sent in train:
    sent_parsed = []
    for token in sent:
        sent_parsed.append((token['form'], token['upostag']))
    train_sample.append(sent_parsed)
test_sample = []
for sent in test:
    sent_parsed = []
    for token in sent:
        sent_parsed.append((token['form'], token['upostag']))
    test_sample.append(sent_parsed)
    
train_conll2000 = conll2000.tagged_sents()[:8000]
test_conll2000 = conll2000.tagged_sents()[8000:]

In [5]:
class BiTagger:
    def __init__(self, train_sample, normalizer=BaseNormalizer()):
        self.normalizer = normalizer
        self.create_emission(train_sample)
        self.create_transition(train_sample)
        self.tags_extraction(train_sample)
        self.taglist_extraction()
        
    def tags_extraction(self, train_sample):
        tags = []
        for sent in train_sample:
            sent_parsed = []
            for el in sent:
                sent_parsed.append(el[1])
            tags.append(sent_parsed)
        self.all_tags = tags
        return self.all_tags
    
    def taglist_extraction(self):
        unique = []
        for sent in self.all_tags:
            for tag in sent:
                if tag not in unique:
                    unique.append(tag)
        self.unique_tags = unique
        return self.unique_tags
    
    def create_emission(self, train_sample):
        self.em = defaultdict(Counter)
        for sent in train_sample:
            for word, tag in sent:
                self.em[tag][word] += 1
                self.em['UNK'][word] += 1
        for tag in self.em:
            self.normalizer.normalize(self.em[tag])
        return self.em
    
    def create_transition(self, train_sample):
        self.tm = defaultdict(Counter)
        all_tags = self.tags_extraction(train_sample)
        for sent in all_tags:
            for i, tag in enumerate(sent):
                if i == 0:
                    self.tm['START'][sent[0]] += 1
                else:
                    self.tm[sent[i-1]][sent[i]] += 1
        for el in self.tm:
            self.normalizer.normalize(self.tm[el])
        return self.tm
    
    def tag(self, sent):
        tags = []
        unique_tags = self.taglist_extraction()
        for i, word in enumerate(sent):
            max_prob = 0
            best_tag = 'UNK'
            if i == 0:
                t_back = 'START'
            else:
                t_back = tags[i-1]
            for t in unique_tags:
                prob = self.em[t][word[0]] * self.tm[t_back][t]
                if prob > max_prob:
                    max_prob, best_tag = prob, t
            tags.append(best_tag)
        return list(zip(sent, tags))

In [6]:
a = BiTagger(train_sample)
b = BiTagger(train_conll2000)

In [7]:
#никак не могу найти ошибку, из-за которой точность такая низкая :(#

ud_acc = accuracy(test_sample, a)
print('Точность разметки при обучении по корпусу UD:', ud_acc)

Точность разметки при обучении по корпусу UD: 0.21022432960114756


In [8]:
conll2000_acc = accuracy(test_conll2000, b)
print('Точность разметки при обучении по корпусу conll2000:', conll2000_acc)

Точность разметки при обучении по корпусу conll2000: 0.01968242990115554
