In [1]:
# Reference tutorial: https://eli5.readthedocs.io/en/latest/tutorials/sklearn_crfsuite.html
%matplotlib inline
import matplotlib.pyplot as plt
plt.style.use('ggplot')  # apply matplotlib's 'ggplot' style sheet to any figures drawn later

In [2]:
from itertools import chain

import nltk
import sklearn
import scipy.stats
from sklearn.metrics import make_scorer
# `sklearn.cross_validation` and `sklearn.grid_search` were deprecated in
# scikit-learn 0.18 and removed in 0.20. Prefer their replacement,
# `sklearn.model_selection`, and only fall back on older installations.
try:
    from sklearn.model_selection import cross_val_score
    from sklearn.model_selection import RandomizedSearchCV
except ImportError:
    from sklearn.cross_validation import cross_val_score
    from sklearn.grid_search import RandomizedSearchCV

# eli5 is used by the show_weights / explain_weights cells further down;
# without this import those cells raise NameError on a fresh kernel.
import eli5

import sklearn_crfsuite
from sklearn_crfsuite import scorers
from sklearn_crfsuite import metrics



In [3]:
# Fetch the 'conll2002' package into the local nltk data directory, then list
# the corpus files it provides (Spanish `esp.*` and Dutch `ned.*` splits).
nltk.download('conll2002')
nltk.corpus.conll2002.fileids()

[nltk_data] Downloading package conll2002 to /home/pablo/nltk_data...
[nltk_data]   Package conll2002 is already up-to-date!


['esp.testa', 'esp.testb', 'esp.train', 'ned.testa', 'ned.testb', 'ned.train']

In [4]:
%time
train_sents = list(nltk.corpus.conll2002.iob_sents('esp.train'))
test_sents = list(nltk.corpus.conll2002.iob_sents('esp.testb'))

CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 10 µs


In [14]:
# First training sentence: a list of (token, POS tag, IOB label) triples.
train_sents[0]

[('Melbourne', 'NP', 'B-LOC'),
 ('(', 'Fpa', 'O'),
 ('Australia', 'NP', 'B-LOC'),
 (')', 'Fpt', 'O'),
 (',', 'Fc', 'O'),
 ('25', 'Z', 'O'),
 ('may', 'NC', 'O'),
 ('(', 'Fpa', 'O'),
 ('EFE', 'NC', 'B-ORG'),
 (')', 'Fpt', 'O'),
 ('.', 'Fp', 'O')]

# Features

In [6]:
def word2features(sent, i):
    """Build the CRF feature dict for the token at position ``i`` of ``sent``.

    Each item of ``sent`` must expose the token text at index 0 and its POS
    tag at index 1; anything after that is ignored.

    The returned dict describes the token itself (lower-cased form, 3-char
    suffix, casing/digit flags, POS tag and its 2-char prefix) followed by
    the same casing/POS information for the previous (``-1:``) and next
    (``+1:``) tokens.  When a neighbour does not exist, ``'BOS'`` or
    ``'EOS'`` is set to ``True`` instead.
    """
    token, tag = sent[i][0], sent[i][1]

    # Features of the current token.
    feats = {
        'bias': 1.0,
        'word.lower()': token.lower(),
        'word[-3:]': token[-3:],
        'word.isupper()': token.isupper(),
        'word.istitle()': token.istitle(),
        'word.isdigit()': token.isdigit(),
        'postag': tag,
        'postag[:2]': tag[:2],
    }

    # Left context, or the beginning-of-sentence marker.
    if not i > 0:
        feats['BOS'] = True
    else:
        prev_token, prev_tag = sent[i - 1][0], sent[i - 1][1]
        feats['-1:word.lower()'] = prev_token.lower()
        feats['-1:word.istitle()'] = prev_token.istitle()
        feats['-1:word.isupper()'] = prev_token.isupper()
        feats['-1:postag'] = prev_tag
        feats['-1:postag[:2]'] = prev_tag[:2]

    # Right context, or the end-of-sentence marker.
    if not i < len(sent) - 1:
        feats['EOS'] = True
    else:
        next_token, next_tag = sent[i + 1][0], sent[i + 1][1]
        feats['+1:word.lower()'] = next_token.lower()
        feats['+1:word.istitle()'] = next_token.istitle()
        feats['+1:word.isupper()'] = next_token.isupper()
        feats['+1:postag'] = next_tag
        feats['+1:postag[:2]'] = next_tag[:2]

    return feats

def sent2features(sent):
    """Return one ``word2features`` dict per token of ``sent``, in sentence order."""
    per_token = []
    for position in range(len(sent)):
        per_token.append(word2features(sent, position))
    return per_token

def sent2labels(sent):
    """Return the third field (the label) of every 3-item entry in ``sent``."""
    tags = []
    for _token, _postag, tag in sent:
        tags.append(tag)
    return tags

def sent2tokens(sent):
    """Return the first field (the token text) of every 3-item entry in ``sent``."""
    words = []
    for word, _postag, _label in sent:
        words.append(word)
    return words

In [7]:
# Feature dict produced for the first token of the first training sentence.
sent2features(train_sents[0])[0]

{'bias': 1.0,
 'word.lower()': 'melbourne',
 'word[-3:]': 'rne',
 'word.isupper()': False,
 'word.istitle()': True,
 'word.isdigit()': False,
 'postag': 'NP',
 'postag[:2]': 'NP',
 'BOS': True,
 '+1:word.lower()': '(',
 '+1:word.istitle()': False,
 '+1:word.isupper()': False,
 '+1:postag': 'Fpa',
 '+1:postag[:2]': 'Fp'}

Extract features and labels from the training and test sentences.

In [8]:
%time
X_train = [sent2features(s) for s in train_sents]
y_train = [sent2labels(s) for s in train_sents]

X_test = [sent2features(s) for s in test_sents]
y_test = [sent2labels(s) for s in test_sents]

CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 6.91 µs


In [20]:
# Feature dicts of the first two tokens of the first training sentence.
X_train[0][0:2]

[{'bias': 1.0,
  'word.lower()': 'melbourne',
  'word[-3:]': 'rne',
  'word.isupper()': False,
  'word.istitle()': True,
  'word.isdigit()': False,
  'postag': 'NP',
  'postag[:2]': 'NP',
  'BOS': True,
  '+1:word.lower()': '(',
  '+1:word.istitle()': False,
  '+1:word.isupper()': False,
  '+1:postag': 'Fpa',
  '+1:postag[:2]': 'Fp'},
 {'bias': 1.0,
  'word.lower()': '(',
  'word[-3:]': '(',
  'word.isupper()': False,
  'word.istitle()': False,
  'word.isdigit()': False,
  'postag': 'Fpa',
  'postag[:2]': 'Fp',
  '-1:word.lower()': 'melbourne',
  '-1:word.istitle()': True,
  '-1:word.isupper()': False,
  '-1:postag': 'NP',
  '-1:postag[:2]': 'NP',
  '+1:word.lower()': 'australia',
  '+1:word.istitle()': True,
  '+1:word.isupper()': False,
  '+1:postag': 'NP',
  '+1:postag[:2]': 'NP'}]

In [21]:
# Labels of the first two tokens of the first training sentence.
y_train[0][0:2]

['B-LOC', 'O']

In [10]:
# Feature dicts of the first two tokens of the third training sentence.
X_train[2][0:2]

[{'bias': 1.0,
  'word.lower()': 'el',
  'word[-3:]': 'El',
  'word.isupper()': False,
  'word.istitle()': True,
  'word.isdigit()': False,
  'postag': 'DA',
  'postag[:2]': 'DA',
  'BOS': True,
  '+1:word.lower()': 'abogado',
  '+1:word.istitle()': True,
  '+1:word.isupper()': False,
  '+1:postag': 'NC',
  '+1:postag[:2]': 'NC'},
 {'bias': 1.0,
  'word.lower()': 'abogado',
  'word[-3:]': 'ado',
  'word.isupper()': False,
  'word.istitle()': True,
  'word.isdigit()': False,
  'postag': 'NC',
  'postag[:2]': 'NC',
  '-1:word.lower()': 'el',
  '-1:word.istitle()': True,
  '-1:word.isupper()': False,
  '-1:postag': 'DA',
  '-1:postag[:2]': 'DA',
  '+1:word.lower()': 'general',
  '+1:word.istitle()': True,
  '+1:word.isupper()': False,
  '+1:postag': 'AQ',
  '+1:postag[:2]': 'AQ'}]

In [11]:
# Labels of the first two tokens of the third training sentence.
y_train[2][0:2]

['O', 'B-PER']

# Training

In [12]:
%%time
# Fit a CRF on the training split (one list of feature dicts / labels per sentence).
# NOTE(review): c1 / c2 are presumably the L1 / L2 regularisation strengths — confirm
# against the sklearn-crfsuite documentation.
crf = sklearn_crfsuite.CRF(
    algorithm='lbfgs',
    c1=0.1,
    c2=0.1,
    max_iterations=30,
    all_possible_transitions=False,
)
crf.fit(X_train, y_train)

CPU times: user 18.4 s, sys: 68 ms, total: 18.5 s
Wall time: 18.5 s


# Evaluation

In [15]:
# Labels to score: every class the model knows except the outside tag 'O'
# (presumably because the majority 'O' class would otherwise dominate the
# averaged metrics). Filtering, rather than list(...).remove('O'), is a
# single non-mutating expression and does not raise ValueError when the
# model has no 'O' class.
labels = [label for label in crf.classes_ if label != 'O']
labels

['B-LOC', 'B-ORG', 'B-PER', 'I-PER', 'B-MISC', 'I-ORG', 'I-LOC', 'I-MISC']

In [42]:
# Predict labels for the test sentences, then compute the weighted-average F1
# restricted to `labels` (the classes other than 'O').
y_pred = crf.predict(X_test)
metrics.flat_f1_score(y_test, y_pred, average='weighted', labels=labels)

0.5843169201751595

In [44]:
# group B and I results
# The sort key (name[1:], name[0]) orders labels first by everything after the
# leading character (e.g. "-LOC") and then by that character, so each B-X row
# is immediately followed by its I-X row in the report.
sorted_labels = sorted(
    labels, 
    key=lambda name: (name[1:], name[0])
)
# Per-label precision / recall / F1 on the test split, printed with 3 decimals.
print(metrics.flat_classification_report(
    y_test, y_pred, labels=sorted_labels, digits=3
))

             precision    recall  f1-score   support

      B-LOC      0.610     0.643     0.626      1084
      I-LOC      0.295     0.369     0.328       325
     B-MISC      0.230     0.145     0.178       339
     I-MISC      0.259     0.404     0.315       557
      B-ORG      0.721     0.642     0.679      1400
      I-ORG      0.713     0.426     0.533      1104
      B-PER      0.697     0.774     0.734       735
      I-PER      0.737     0.886     0.805       634

avg / total      0.608     0.581     0.584      6178



# Older tests: inspecting the CRF weights with eli5

In [27]:
# Render the fitted CRF's transition-weight table and the per-label feature weights.
eli5.show_weights(crf, top=30)

From \ To,O,B-LOC,I-LOC,B-MISC,I-MISC,B-ORG,I-ORG,B-PER,I-PER
O,3.281,2.204,0.0,2.101,0.0,3.468,0.0,2.325,0.0
B-LOC,-0.259,-0.098,4.058,0.0,0.0,0.0,0.0,-0.212,0.0
I-LOC,-0.173,-0.609,3.436,0.0,0.0,0.0,0.0,0.0,0.0
B-MISC,-0.673,-0.341,0.0,0.0,4.069,-0.308,0.0,-0.331,0.0
I-MISC,-0.803,-0.998,0.0,-0.519,4.977,-0.817,0.0,-0.611,0.0
B-ORG,-0.096,-0.242,0.0,-0.57,0.0,-1.012,4.739,-0.306,0.0
I-ORG,-0.339,-1.758,0.0,-0.841,0.0,-1.382,5.062,-0.472,0.0
B-PER,-0.4,-0.851,0.0,0.0,0.0,-1.013,0.0,-0.937,4.329
I-PER,-0.676,-0.47,0.0,0.0,0.0,0.0,0.0,-0.659,3.754

Weight?,Feature,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0
Weight?,Feature,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
Weight?,Feature,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2
Weight?,Feature,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3
Weight?,Feature,Unnamed: 2_level_4,Unnamed: 3_level_4,Unnamed: 4_level_4,Unnamed: 5_level_4,Unnamed: 6_level_4,Unnamed: 7_level_4,Unnamed: 8_level_4
Weight?,Feature,Unnamed: 2_level_5,Unnamed: 3_level_5,Unnamed: 4_level_5,Unnamed: 5_level_5,Unnamed: 6_level_5,Unnamed: 7_level_5,Unnamed: 8_level_5
Weight?,Feature,Unnamed: 2_level_6,Unnamed: 3_level_6,Unnamed: 4_level_6,Unnamed: 5_level_6,Unnamed: 6_level_6,Unnamed: 7_level_6,Unnamed: 8_level_6
Weight?,Feature,Unnamed: 2_level_7,Unnamed: 3_level_7,Unnamed: 4_level_7,Unnamed: 5_level_7,Unnamed: 6_level_7,Unnamed: 7_level_7,Unnamed: 8_level_7
Weight?,Feature,Unnamed: 2_level_8,Unnamed: 3_level_8,Unnamed: 4_level_8,Unnamed: 5_level_8,Unnamed: 6_level_8,Unnamed: 7_level_8,Unnamed: 8_level_8
+4.416,postag[:2]:Fp,,,,,,,
+3.116,BOS,,,,,,,
+2.401,bias,,,,,,,
+2.297,postag:Fc,,,,,,,
+2.297,postag[:2]:Fc,,,,,,,
+2.297,"word[-3:]:,",,,,,,,
+2.297,"word.lower():,",,,,,,,
+2.124,postag:CC,,,,,,,
+2.124,postag[:2]:CC,,,,,,,
+1.984,EOS,,,,,,,

Weight?,Feature
+4.416,postag[:2]:Fp
+3.116,BOS
+2.401,bias
+2.297,postag:Fc
+2.297,postag[:2]:Fc
+2.297,"word[-3:]:,"
+2.297,"word.lower():,"
+2.124,postag:CC
+2.124,postag[:2]:CC
+1.984,EOS

Weight?,Feature
+2.530,word.istitle()
+2.224,-1:word.lower():en
+0.906,word[-3:]:rid
+0.905,word.lower():madrid
+0.646,word.lower():españa
+0.640,word[-3:]:ona
+0.595,word[-3:]:aña
+0.595,+1:postag[:2]:Fp
+0.515,word.lower():parís
+0.514,word[-3:]:rís

Weight?,Feature
+0.886,-1:word.istitle()
+0.664,-1:word.lower():de
+0.582,word[-3:]:de
+0.578,word.lower():de
+0.529,-1:word.lower():san
+0.444,+1:word.istitle()
+0.441,word.istitle()
+0.335,-1:word.lower():la
+0.262,postag[:2]:SP
+0.262,postag:SP

Weight?,Feature
+1.770,word.isupper()
+0.693,word.istitle()
+0.606,postag[:2]:Fe
+0.606,postag:Fe
+0.606,"word[-3:]:"""
+0.606,"word.lower():"""
+0.538,+1:word.istitle()
+0.508,"-1:word.lower():"""
+0.508,-1:postag:Fe
+0.508,-1:postag[:2]:Fe

Weight?,Feature
+1.364,-1:word.istitle()
+0.675,-1:word.lower():de
+0.597,"+1:word.lower():"""
+0.597,+1:postag:Fe
+0.597,+1:postag[:2]:Fe
+0.369,-1:postag:NC
+0.369,-1:postag[:2]:NC
+0.324,-1:word.lower():liga
+0.318,word[-3:]:de
+0.304,word.lower():de

Weight?,Feature
+2.695,word.lower():efe
+2.519,word.isupper()
+2.084,word[-3:]:EFE
+1.174,word.lower():gobierno
+1.142,word.istitle()
+1.018,-1:word.lower():del
+0.958,word[-3:]:rno
+0.671,word.lower():pp
+0.671,word[-3:]:PP
+0.667,-1:word.lower():al

Weight?,Feature
+1.499,-1:word.istitle()
+1.200,-1:word.lower():de
+0.539,-1:word.lower():real
+0.511,word[-3:]:rid
+0.446,word[-3:]:de
+0.433,word.lower():de
+0.428,-1:postag:SP
+0.428,-1:postag[:2]:SP
+0.399,word.lower():madrid
+0.368,word[-3:]:la

Weight?,Feature
+1.698,word.istitle()
+0.683,-1:postag:VMI
+0.601,+1:postag[:2]:VM
+0.589,postag[:2]:NP
+0.589,postag:NP
+0.589,+1:postag:VMI
+0.565,-1:word.lower():a
+0.520,word[-3:]:osé
+0.503,word.lower():josé
+0.476,-1:postag[:2]:VM

Weight?,Feature
+2.742,-1:word.istitle()
+0.736,word.istitle()
+0.660,-1:word.lower():josé
+0.598,-1:postag[:2]:AQ
+0.598,-1:postag:AQ
+0.510,-1:postag[:2]:VM
+0.487,-1:word.lower():juan
+0.419,-1:word.lower():maría
+0.413,-1:postag:VMI
+0.345,-1:word.lower():luis


In [34]:
# Re-fit with all_possible_transitions=True (and 20 instead of 30 iterations).
# NOTE: this rebinds `crf`; any `y_pred` / metrics computed earlier still refer
# to the previous model. The trailing ';' suppresses the cell's output.
crf = sklearn_crfsuite.CRF(
    algorithm='lbfgs',
    c1=0.1,
    c2=0.1,
    max_iterations=20,
    all_possible_transitions=True,
)
crf.fit(X_train, y_train);

In [35]:
# Show only the label-to-label transition weights of the re-fitted model.
eli5.show_weights(crf, top=5, show=['transition_features'])

From \ To,O,B-LOC,I-LOC,B-MISC,I-MISC,B-ORG,I-ORG,B-PER,I-PER
O,2.732,1.217,-4.675,1.515,-5.785,1.36,-6.19,0.968,-6.236
B-LOC,-0.226,-0.091,3.378,-0.433,-1.065,-0.861,-1.783,-0.295,-1.57
I-LOC,-0.184,-0.585,2.404,-0.276,-0.485,-0.582,-0.749,-0.442,-0.647
B-MISC,-0.714,-0.353,-0.539,-0.278,3.512,-0.412,-1.047,-0.336,-0.895
I-MISC,-0.697,-0.846,-0.587,-0.297,4.252,-0.84,-1.206,-0.523,-1.001
B-ORG,0.419,-0.187,-1.074,-0.567,-1.607,-1.13,5.392,-0.223,-2.122
I-ORG,-0.117,-1.715,-0.863,-0.631,-1.221,-1.442,5.141,-0.397,-1.908
B-PER,-0.127,-0.806,-0.834,-0.52,-1.228,-1.089,-2.076,-1.01,4.04
I-PER,-0.766,-0.242,-0.67,-0.418,-0.856,-0.903,-1.472,-0.692,2.909


In [36]:
# Restrict the explanation to the 'O', 'B-ORG' and 'I-ORG' labels.
eli5.show_weights(crf, top=10, targets=['O', 'B-ORG', 'I-ORG'])

From \ To,O,B-ORG,I-ORG
O,2.732,1.36,-6.19
B-ORG,0.419,-1.13,5.392
I-ORG,-0.117,-1.442,5.141

Weight?,Feature,Unnamed: 2_level_0
Weight?,Feature,Unnamed: 2_level_1
Weight?,Feature,Unnamed: 2_level_2
+4.931,BOS,
+3.754,postag[:2]:Fp,
+3.539,bias,
+2.328,"word.lower():,",
+2.328,"word[-3:]:,",
+2.328,postag:Fc,
+2.328,postag[:2]:Fc,
… 15039 more positive …,… 15039 more positive …,
… 3905 more negative …,… 3905 more negative …,
-2.187,postag[:2]:NP,

Weight?,Feature
+4.931,BOS
+3.754,postag[:2]:Fp
+3.539,bias
+2.328,"word.lower():,"
+2.328,"word[-3:]:,"
+2.328,postag:Fc
+2.328,postag[:2]:Fc
… 15039 more positive …,… 15039 more positive …
… 3905 more negative …,… 3905 more negative …
-2.187,postag[:2]:NP

Weight?,Feature
+3.041,word.isupper()
+2.952,word.lower():efe
+1.851,word[-3:]:EFE
+1.278,word.lower():gobierno
+1.033,word[-3:]:rno
+1.005,word.istitle()
+0.864,-1:word.lower():del
… 3524 more positive …,… 3524 more positive …
… 621 more negative …,… 621 more negative …
-0.842,-1:word.lower():en

Weight?,Feature
+1.159,-1:word.lower():de
+0.993,-1:word.istitle()
+0.637,-1:postag:SP
+0.637,-1:postag[:2]:SP
+0.570,-1:word.lower():real
+0.547,word.istitle()
… 3517 more positive …,… 3517 more positive …
… 676 more negative …,… 676 more negative …
-0.480,postag:VMI
-0.508,postag[:2]:VM


In [37]:
# Keep only features whose name starts with "word.is" (isupper / istitle / isdigit)
# and show just the per-label tables, stacked vertically.
# The pattern is a raw string: "\." is a regex escape, not a Python string escape,
# and in a plain literal it triggers an invalid-escape-sequence warning.
eli5.show_weights(crf, top=10, feature_re=r'^word\.is',
                  horizontal_layout=False, show=['targets'])

Weight?,Feature
-3.685,word.isupper()
-7.025,word.istitle()

Weight?,Feature
2.397,word.istitle()
0.099,word.isupper()
-0.152,word.isdigit()

Weight?,Feature
0.46,word.istitle()
-0.018,word.isdigit()
-0.345,word.isupper()

Weight?,Feature
2.017,word.isupper()
0.603,word.istitle()
-0.012,word.isdigit()

Weight?,Feature
0.271,word.isdigit()
-0.072,word.isupper()
-0.106,word.istitle()

Weight?,Feature
3.041,word.isupper()
1.005,word.istitle()
-0.044,word.isdigit()

Weight?,Feature
0.547,word.istitle()
0.014,word.isdigit()
-0.012,word.isupper()

Weight?,Feature
1.757,word.istitle()
0.05,word.isupper()
-0.123,word.isdigit()

Weight?,Feature
0.976,word.istitle()
0.193,word.isupper()
-0.106,word.isdigit()


In [38]:
# Build the explanation for the 'O', 'B-LOC' and 'I-LOC' labels as an object,
# then print it as plain text rather than rendering it as rich output.
expl = eli5.explain_weights(crf, top=5, targets=['O', 'B-LOC', 'I-LOC'])
print(eli5.format_as_text(expl))

Explained as: CRF

Transition features:
            O    B-LOC    I-LOC
-----  ------  -------  -------
O       2.732    1.217   -4.675
B-LOC  -0.226   -0.091    3.378
I-LOC  -0.184   -0.585    2.404

y='O' top features
Weight  Feature       
------  --------------
+4.931  BOS           
+3.754  postag[:2]:Fp 
+3.539  bias          
… 15043 more positive …
… 3906 more negative …
-3.685  word.isupper()
-7.025  word.istitle()

y='B-LOC' top features
Weight  Feature           
------  ------------------
+2.397  word.istitle()    
+2.147  -1:word.lower():en
  … 2284 more positive …  
  … 433 more negative …   
-1.080  postag:SP         
-1.080  postag[:2]:SP     
-1.273  -1:word.istitle() 

y='I-LOC' top features
Weight  Feature           
------  ------------------
+0.882  -1:word.lower():de
+0.780  -1:word.istitle() 
+0.718  word[-3:]:de      
+0.711  word.lower():de   
  … 1684 more positive …  
  … 268 more negative …   
-1.965  BOS               

