In [1]:
#!pip install protobuf



In [2]:
#!pip install sklearn_crfsuite eli5



In [61]:
import eli5
import nltk
import scipy.stats
import sklearn
import sklearn_crfsuite

from itertools import chain
from sklearn.metrics import make_scorer
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import RandomizedSearchCV
from sklearn_crfsuite import scorers
from sklearn_crfsuite import metrics

nltk.download('conll2002')

[nltk_data] Downloading package conll2002 to
[nltk_data]     /Users/bootcamp/nltk_data...
[nltk_data]   Package conll2002 is already up-to-date!


True

### Загрузим данные:

In [62]:
train_sents = list(nltk.corpus.conll2002.iob_sents('esp.train'))
test_sents = list(nltk.corpus.conll2002.iob_sents('esp.testb'))
train_sents[0]

[('Melbourne', 'NP', 'B-LOC'),
 ('(', 'Fpa', 'O'),
 ('Australia', 'NP', 'B-LOC'),
 (')', 'Fpt', 'O'),
 (',', 'Fc', 'O'),
 ('25', 'Z', 'O'),
 ('may', 'NC', 'O'),
 ('(', 'Fpa', 'O'),
 ('EFE', 'NC', 'B-ORG'),
 (')', 'Fpt', 'O'),
 ('.', 'Fp', 'O')]

In [63]:
len(train_sents)

8323

### Добавим фичи для каждого слова, чтобы обучить CRF (смотри лекцию:)):

In [73]:
def word2features(sent, i):
    word = sent[i][0]
    postag = sent[i][1]
    articles = ("a", "de", "con", "por", "para", "en", "ante", "bajo", "contra", "desde", "entre", "hasta", "segun", "sin", "sobre", "tras", "hacia", "detrás")
    sigles = ("uno", "un", "una", "uno")
    pointers = ("yo", "tú", "ella", "Él".lower(), "usted", "ustedes", "vosotros", "vosotras", "nosotros", "nosotras")
    this = ("este", "esto", "esta", "estos", "estas", "eso", "esa", "esos")
    
    features = {
        'bias': 1.0,
        'postag': postag,
        'word.len': len(word),
        ### YOUR CODE HERE
        'word.isupper': word.isupper(),
        'word.istitle': word.istitle(),
        'word.numeric': word.isnumeric(),
        'word.isaplha': word.isalpha(),
        'word.xx': word in ("lo", "la", "le"),
        "word.endswith_punctuation": word in (".", "?", "!"),
        "word.article": word in articles,
        "word.sigle": word in sigles,
        "word.pointer": word in pointers,
        "word.last2letters": word[:-2] if len(word) > 1 else word,
        "word.last_letter": word[:-1] if len(word) > 0 else word
    }
    
    if len(word) >= 2:
        features["word.verb_inf"] = word[-2:] in ("ar", "ir", "er")
    
    if i > 0:
        word1 = sent[i-1][0]
        postag1 = sent[i-1][1]
        features.update({
            '-1:word.istitle': word1.istitle(),
            ### YOUR CODE HERE
            '-1.postag': postag1,
            '-1.word.len': len(word1),
            '-1.word.isupper': word1.isupper(),
            '-1.word.numeric': word1.isnumeric(),
            '-1.word.isaplha': word1.isalpha(),
            '-1.word.verb_inf': word1[-2:] in ("ar", "ir", "er"),
            '-1.word.xx': word1 in ("lo", "la", "le"),
            "-1.word.endswith_punctuation": word1 in (".", "?", "!"),
            "-1.word.article": word1 in articles,
            "-1.word.sigle": word in sigles,
            "-1.word.pointer": word in pointers,
            "-1.word.last2letters": word[:-2] if len(word) > 1 else word,
            "-1.word.last_letter": word[:-1] if len(word) > 0 else word
        })
        if len(word1) >= 2:
            features["-1.word.verb_inf"] = word1[-2:] in ("ar", "ir", "er")
    else:
        features['BOS'] = True

    if i < len(sent)-1:
        word1 = sent[i+1][0]
        postag1 = sent[i+1][1]
        features.update({
            '+1:word.istitle': word1.istitle(),
            ### YOUR CODE HERE
            '+1.postag': postag1,
            '+1.word.len': len(word1),
            '+1.word.isupper': word1.isupper(),
            '+1.word.numeric': word1.isnumeric(),
            '+1.word.isaplha': word1.isalpha(),
            '+1.word.verb_inf': word1[-2:] in ("ar", "ir", "er"),
            '+1.word.xx': word1 in ("lo", "la", "le"),
            "+1.word.endswith_punctuation": word1 in (".", "?", "!"),
            "+1.word.article": word1 in articles,
            "+1.word.sigle": word in sigles,
            "+1.word.pointer": word in pointers,
            "+1.word.last2letters": word[:-2] if len(word) > 1 else word,
            "+1.word.last_letter": word[:-1] if len(word) > 0 else word
        })
        if len(word1) >= 2:
            features["+1.word.verb_inf"] = word1[-2:] in ("ar", "ir", "er")
    else:
        features['EOS'] = True

    return features


def sent2features(sent):
    return [word2features(sent, i) for i in range(len(sent))]

def sent2labels(sent):
    return [label for token, postag, label in sent]

def sent2tokens(sent):
    return [token for token, postag, label in sent]

X_train = [sent2features(s) for s in train_sents]
y_train = [sent2labels(s) for s in train_sents]

X_test = [sent2features(s) for s in test_sents]
y_test = [sent2labels(s) for s in test_sents]

### Посмотрим на пример фичей для одного слова:

In [74]:
X_train[0][1]

{'bias': 1.0,
 'postag': 'Fpa',
 'word.len': 1,
 'word.isupper': False,
 'word.istitle': False,
 'word.numeric': False,
 'word.isaplha': False,
 'word.xx': False,
 'word.endswith_punctuation': False,
 'word.article': False,
 'word.sigle': False,
 'word.pointer': False,
 'word.last2letters': '(',
 'word.last_letter': '',
 '-1:word.istitle': True,
 '-1.postag': 'NP',
 '-1.word.len': 9,
 '-1.word.isupper': False,
 '-1.word.numeric': False,
 '-1.word.isaplha': True,
 '-1.word.verb_inf': False,
 '-1.word.xx': False,
 '-1.word.endswith_punctuation': False,
 '-1.word.article': False,
 '-1.word.sigle': False,
 '-1.word.pointer': False,
 '-1.word.last2letters': '(',
 '-1.word.last_letter': '',
 '+1:word.istitle': True,
 '+1.postag': 'NP',
 '+1.word.len': 9,
 '+1.word.isupper': False,
 '+1.word.numeric': False,
 '+1.word.isaplha': True,
 '+1.word.verb_inf': False,
 '+1.word.xx': False,
 '+1.word.endswith_punctuation': False,
 '+1.word.article': False,
 '+1.word.sigle': False,
 '+1.word.pointer':

### Обучим CRF:

In [66]:
%%time
### YOUR CODE HERE (Probably you will change some hyperparameters)
crf = sklearn_crfsuite.CRF(
    algorithm='lbfgs',
    c1=0.5,
    c2=0.5,
    max_iterations=100,
    all_possible_transitions=True
)
crf.fit(X_train, y_train)

CPU times: user 41.1 s, sys: 1.4 s, total: 42.5 s
Wall time: 44.5 s


### Посмотрим на веса признаков:

In [78]:
crf = sklearn_crfsuite.CRF(
    algorithm='lbfgs',
    c1=0.5,
    c2=0.5,
    max_iterations=150,
    all_possible_transitions=True
)
crf.fit(X_train, y_train)
eli5.show_weights(crf, top=30)

From \ To,O,B-LOC,I-LOC,B-MISC,I-MISC,B-ORG,I-ORG,B-PER,I-PER
O,2.286,0.537,-5.842,0.422,-6.192,0.719,-5.911,0.218,-5.461
B-LOC,-0.23,0.861,5.239,-1.557,-1.655,-1.671,-1.414,-1.139,-1.277
I-LOC,0.015,-0.96,4.828,-1.008,-1.429,-1.08,-1.199,-1.727,-1.16
B-MISC,-0.661,-0.52,-1.137,-2.293,4.524,-0.711,-1.449,-1.955,-1.386
I-MISC,-0.565,-1.437,-1.508,-0.93,4.737,-1.509,-1.683,-1.103,-1.26
B-ORG,-0.161,-0.45,-1.618,-2.18,-2.232,-2.724,5.037,-1.35,-1.635
I-ORG,-0.557,-2.157,-2.039,-1.872,-2.314,-1.609,4.879,-1.256,-1.808
B-PER,-0.075,-0.63,-0.962,-1.53,-1.524,-1.645,-1.157,-3.243,5.648
I-PER,-0.411,0.62,-1.203,-1.464,-1.52,-1.795,-1.266,-1.5,4.73

Weight?,Feature,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0
Weight?,Feature,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
Weight?,Feature,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2
Weight?,Feature,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3
Weight?,Feature,Unnamed: 2_level_4,Unnamed: 3_level_4,Unnamed: 4_level_4,Unnamed: 5_level_4,Unnamed: 6_level_4,Unnamed: 7_level_4,Unnamed: 8_level_4
Weight?,Feature,Unnamed: 2_level_5,Unnamed: 3_level_5,Unnamed: 4_level_5,Unnamed: 5_level_5,Unnamed: 6_level_5,Unnamed: 7_level_5,Unnamed: 8_level_5
Weight?,Feature,Unnamed: 2_level_6,Unnamed: 3_level_6,Unnamed: 4_level_6,Unnamed: 5_level_6,Unnamed: 6_level_6,Unnamed: 7_level_6,Unnamed: 8_level_6
Weight?,Feature,Unnamed: 2_level_7,Unnamed: 3_level_7,Unnamed: 4_level_7,Unnamed: 5_level_7,Unnamed: 6_level_7,Unnamed: 7_level_7,Unnamed: 8_level_7
Weight?,Feature,Unnamed: 2_level_8,Unnamed: 3_level_8,Unnamed: 4_level_8,Unnamed: 5_level_8,Unnamed: 6_level_8,Unnamed: 7_level_8,Unnamed: 8_level_8
+6.283,BOS,,,,,,,
+3.116,postag:RG,,,,,,,
+3.016,-1.postag:Z,,,,,,,
+2.847,word.last_letter:R/,,,,,,,
+2.847,+1.word.last_letter:R/,,,,,,,
+2.736,-1.postag:Fd,,,,,,,
+2.592,postag:CS,,,,,,,
+2.403,word.last2letters:Y,,,,,,,
+2.403,+1.word.last2letters:Y,,,,,,,
+2.344,-1.word.isupper,,,,,,,

Weight?,Feature
+6.283,BOS
+3.116,postag:RG
+3.016,-1.postag:Z
+2.847,word.last_letter:R/
+2.847,+1.word.last_letter:R/
+2.736,-1.postag:Fd
+2.592,postag:CS
+2.403,word.last2letters:Y
+2.403,+1.word.last2letters:Y
+2.344,-1.word.isupper

Weight?,Feature
+1.614,-1.word.last_letter:L
+1.603,BOS
+1.462,+1.word.numeric
+1.432,word.istitle
+1.359,+1.word.last_letter:Madri
+1.359,word.last_letter:Madri
+1.301,-1.word.isupper
+1.232,word.last2letters:Par
+1.232,+1.word.last2letters:Par
+1.194,+1.postag:Fg

Weight?,Feature
+1.270,-1.postag:DA
+0.775,word.last_letter:Unid
+0.775,-1.word.last_letter:Unid
+0.775,+1.word.last_letter:Unid
+0.733,+1.word.last_letter:Chil
+0.733,-1.word.last_letter:Chil
+0.733,word.last_letter:Chil
+0.662,+1.postag:VMS
+0.646,+1.word.last2letters:Monta
+0.646,word.last2letters:Monta

Weight?,Feature
+3.751,word.isupper
+1.415,word.istitle
+1.412,+1.word.last_letter:AV
+1.412,-1.word.last_letter:AV
+1.412,word.last_letter:AV
+1.380,+1.word.last_letter:OP
+1.380,-1.word.last_letter:OP
+1.380,word.last_letter:OP
+1.346,word.last2letters:Diversia.c
+1.346,+1.word.last2letters:Diversia.c

Weight?,Feature
+1.877,word.numeric
+1.381,-1.postag:RG
+1.238,-1.postag:DI
+1.227,-1.postag:DA
+1.127,+1.postag:Z
+1.125,+1.postag:Fe
+0.945,postag:AO
+0.877,-1.postag:Z
+0.850,postag:RG
+0.785,-1.postag:DP

Weight?,Feature
+3.682,word.isupper
+2.254,+1.word.last_letter:Ci
+2.254,-1.word.last_letter:Ci
+2.254,word.last_letter:Ci
+2.166,word.last2letters:EFE-Cantabr
+2.166,word.last_letter:EFE-Cantabri
+2.166,+1.word.last2letters:EFE-Cantabr
+2.166,+1.word.last_letter:EFE-Cantabri
+1.924,+1.word.last_letter:EUi
+1.924,word.last_letter:EUi

Weight?,Feature
+2.731,word.numeric
+2.257,word.isupper
+1.768,-1.postag:DA
+0.879,-1:word.istitle
+0.808,+1.word.last_letter:4
+0.808,-1.word.last_letter:4
+0.808,word.last_letter:4
+0.780,word.last_letter:Dynamic
+0.780,-1.word.last2letters:Dynami
+0.780,-1.word.last_letter:Dynamic

Weight?,Feature
+1.894,word.istitle
+1.862,BOS
+1.493,+1.word.last_letter:McFarlan
+1.493,word.last_letter:McFarlan
+1.493,word.last2letters:McFarla
+1.493,+1.word.last2letters:McFarla
+1.386,-1.postag:VMI
+1.319,word.isupper
+1.243,-1.word.last2letters:McManam
+1.243,-1.word.last_letter:McManama

Weight?,Feature
+1.728,-1.word.xx
+1.385,-1:word.istitle
+1.283,+1.postag:Fx
+1.136,word.istitle
+0.884,+1.postag:Fz
+0.830,-1.word.last2letters:Gánda
+0.830,-1.word.last_letter:Gándar
+0.830,+1.word.last2letters:Gánda
+0.830,+1.word.last_letter:Gándar
+0.830,word.last2letters:Gánda


### Посчитаем предсказание на тесте:

In [76]:
labels = list(crf.classes_)
labels.remove('O')
labels

['B-LOC', 'B-ORG', 'B-PER', 'I-PER', 'B-MISC', 'I-ORG', 'I-LOC', 'I-MISC']

In [77]:
y_pred = crf.predict(X_test)
metrics.flat_f1_score(y_test, y_pred, average='weighted', labels=labels)

0.7717518512048152

### А теперь отдельно для каждого тэга:

In [21]:
sorted_labels = sorted(
    labels,
    key=lambda name: (name[1:], name[0])
)
print(metrics.flat_classification_report(
    y_test, y_pred, labels=sorted_labels, digits=3
))

              precision    recall  f1-score   support

       B-LOC      0.589     0.590     0.590      1084
       I-LOC      0.280     0.289     0.284       325
      B-MISC      0.433     0.115     0.182       339
      I-MISC      0.465     0.205     0.284       557
       B-ORG      0.669     0.683     0.676      1400
       I-ORG      0.634     0.753     0.688      1104
       B-PER      0.684     0.747     0.714       735
       I-PER      0.758     0.868     0.809       634

   micro avg      0.626     0.611     0.618      6178
   macro avg      0.564     0.531     0.528      6178
weighted avg      0.608     0.611     0.598      6178



### Посмотрим на наиболее и наименее вероятные переходы модели: 

In [22]:
from collections import Counter

def print_transitions(trans_features):
    for (label_from, label_to), weight in trans_features:
        print("%-6s -> %-7s %0.6f" % (label_from, label_to, weight))

print("Top likely transitions:")
print_transitions(Counter(crf.transition_features_).most_common(20))

print("\nTop unlikely transitions:")
print_transitions(Counter(crf.transition_features_).most_common()[-20:])

Top likely transitions:
B-LOC  -> I-LOC   5.412020
I-LOC  -> I-LOC   5.172153
I-MISC -> I-MISC  5.144828
B-PER  -> I-PER   4.802833
B-MISC -> I-MISC  4.414769
I-PER  -> I-PER   4.343266
B-ORG  -> I-ORG   3.883694
I-ORG  -> I-ORG   3.715221
O      -> O       2.557314
O      -> B-ORG   0.960046
O      -> B-LOC   0.708120
O      -> B-MISC  0.626017
O      -> B-PER   0.279963
B-LOC  -> B-LOC   0.124138
I-LOC  -> O       0.035224
I-PER  -> B-LOC   -0.087076
B-LOC  -> O       -0.102480
I-PER  -> O       -0.187406
B-PER  -> O       -0.227854
B-ORG  -> O       -0.230901

Top unlikely transitions:
B-PER  -> I-ORG   -2.175640
B-LOC  -> I-PER   -2.176378
I-ORG  -> I-LOC   -2.208642
I-ORG  -> I-PER   -2.225142
B-LOC  -> I-ORG   -2.252632
I-ORG  -> B-MISC  -2.253004
I-ORG  -> B-ORG   -2.435161
B-ORG  -> I-PER   -2.480095
I-PER  -> B-ORG   -2.491719
I-MISC -> I-ORG   -2.564156
B-ORG  -> B-MISC  -2.639295
I-ORG  -> I-MISC  -2.799916
B-LOC  -> B-ORG   -2.814176
B-PER  -> B-ORG   -2.846970
B-ORG  -> B-