<a href="https://colab.research.google.com/github/zhangxs131/NER/blob/main/ner_crf.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#使用CRF实现NER

*     数据集 CONLL
*     特征工程 bi-gram +pos
*     模型  CRF (sklearn_crfsuite)


In [16]:
#下载数据
!mkdir data
!wget https://data.deepai.org/conll2003.zip
!unzip -d data conll2003.zip

mkdir: cannot create directory ‘data’: File exists
--2022-03-06 08:48:17--  https://data.deepai.org/conll2003.zip
Resolving data.deepai.org (data.deepai.org)... 138.201.36.183
Connecting to data.deepai.org (data.deepai.org)|138.201.36.183|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 982975 (960K) [application/x-zip-compressed]
Saving to: ‘conll2003.zip.1’


2022-03-06 08:48:17 (2.09 MB/s) - ‘conll2003.zip.1’ saved [982975/982975]

Archive:  conll2003.zip
replace data/metadata? [y]es, [n]o, [A]ll, [N]one, [r]ename: n
replace data/test.txt? [y]es, [n]o, [A]ll, [N]one, [r]ename: n
replace data/train.txt? [y]es, [n]o, [A]ll, [N]one, [r]ename: n
replace data/valid.txt? [y]es, [n]o, [A]ll, [N]one, [r]ename: n


In [18]:
!pip install sklearn_crfsuite



In [47]:
#导入必要的包
import nltk
import sklearn
#下载nltk词性标注和分词的工具
nltk.download('averaged_perceptron_tagger')
nltk.download('punkt')
from nltk.tag import pos_tag
from sklearn_crfsuite import CRF ,metrics
from sklearn.metrics import make_scorer,confusion_matrix
from pprint import pprint
from sklearn.metrics import f1_score,classification_report
from sklearn.pipeline import Pipeline
import string
import warnings
warnings.filterwarnings('ignore')

[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [20]:
#查看数据形式
!head -n 20 data/train.txt

-DOCSTART- -X- -X- O

EU NNP B-NP B-ORG
rejects VBZ B-VP O
German JJ B-NP B-MISC
call NN I-NP O
to TO B-VP O
boycott VB I-VP O
British JJ B-NP B-MISC
lamb NN I-NP O
. . O O

Peter NNP B-NP B-PER
Blackburn NNP I-NP I-PER

BRUSSELS NNP B-NP B-LOC
1996-08-22 CD I-NP O

The DT B-NP O
European NNP I-NP B-ORG


数据前两行为标题，一共四列，分别为
*    token
*    pos
*    ——
*    entity Labels
每列中间使用空格分隔，每句话中间使用空行进行分割

#导入数据

In [21]:

def load__data_conll(file_path):
    myoutput,words,tags = [],[],[]
    with open(file_path,'r',encoding='utf-8') as f:
      content=f.readlines()
    for line in content[2:]:
        line = line.strip()
        if " " not in line:
            #Sentence ended.
            myoutput.append([words,tags])
            words,tags = [],[]
        else:
            word,_,_,tag = line.split(" ")
            words.append(word)
            tags.append(tag)
    return myoutput

#特征工程
这里使用bi-gram作为特征，然后使用nlp.pos_tag词性作为辅助特征。

In [22]:
def sent2feats(sentence):
    feats = []
    sen_tags = pos_tag(sentence) #This format is specific to this POS tagger!
    for i in range(0,len(sentence)):
        word = sentence[i]
        wordfeats = {}
       #word features: word, prev 2 words, next 2 words in the sentence.
        wordfeats['word'] = word
        if i == 0:
            wordfeats["prevWord"] = wordfeats["prevSecondWord"] = "<S>"
        elif i==1:
            wordfeats["prevWord"] = sentence[0]
            wordfeats["prevSecondWord"] = "</S>"
        else:
            wordfeats["prevWord"] = sentence[i-1]
            wordfeats["prevSecondWord"] = sentence[i-2]
        #next two words as features
        if i == len(sentence)-2:
            wordfeats["nextWord"] = sentence[i+1]
            wordfeats["nextNextWord"] = "</S>"
        elif i==len(sentence)-1:
            wordfeats["nextWord"] = "</S>"
            wordfeats["nextNextWord"] = "</S>"
        else:
            wordfeats["nextWord"] = sentence[i+1]
            wordfeats["nextNextWord"] = sentence[i+2]
        
        #POS tag features: current tag, previous and next 2 tags.
        wordfeats['tag'] = sen_tags[i][1]
        if i == 0:
            wordfeats["prevTag"] = wordfeats["prevSecondTag"] = "<S>"
        elif i == 1:
            wordfeats["prevTag"] = sen_tags[0][1]
            wordfeats["prevSecondTag"] = "</S>"
        else:
            wordfeats["prevTag"] = sen_tags[i - 1][1]

            wordfeats["prevSecondTag"] = sen_tags[i - 2][1]
            # next two words as features
        if i == len(sentence) - 2:
            wordfeats["nextTag"] = sen_tags[i + 1][1]
            wordfeats["nextNextTag"] = "</S>"
        elif i == len(sentence) - 1:
            wordfeats["nextTag"] = "</S>"
            wordfeats["nextNextTag"] = "</S>"
        else:
            wordfeats["nextTag"] = sen_tags[i + 1][1]
            wordfeats["nextNextTag"] = sen_tags[i + 2][1]
        #That is it! You can add whatever you want!
        feats.append(wordfeats)
    return feats

# 提取特征

In [37]:
#计算并显示困惑矩阵

def get_confusion_matrix(y_true,y_pred,labels):
    trues,preds = [], []
    for yseq_true, yseq_pred in zip(y_true, y_pred):
        trues.extend(yseq_true)
        preds.extend(yseq_pred)
    print_cm(confusion_matrix(trues,preds),labels)

def print_cm(cm, labels):
    print("\n")
    """pretty print for confusion matrixes"""
    columnwidth = max([len(x) for x in labels] + [5])  # 5 is value length
    empty_cell = " " * columnwidth
    # Print header
    print("    " + empty_cell, end=" ")
    for label in labels:
        print("%{0}s".format(columnwidth) % label, end=" ")
    print()
    # Print rows
    for i, label1 in enumerate(labels):
        print("    %{0}s".format(columnwidth) % label1, end=" ")
        sum = 0
        for j in range(len(labels)):
            cell = "%{0}.0f".format(columnwidth) % cm[i, j]
            sum =  sum + int(cell)
            print(cell, end=" ")
        print(sum) #Prints the total number of instances per cat at the end.

In [24]:
def get_feats_conll(conll_data):
    feats = []
    labels = []
    for sentence in conll_data:
        feats.append(sent2feats(sentence[0]))
        labels.append(sentence[1])
    return feats, labels

#训练模型

In [51]:
#训练一个序列模型
def train_seq(X_train,Y_train,X_dev,Y_dev):
  crf=CRF(algorithm='lbfgs',c1=0.1,c2=10,max_iterations=50,all_possible_states=True)
  crf.fit(X_train,Y_train)
  labels=list(crf.classes_)

  #testing
  y_pred=crf.predict(X_dev)
  sorted_labels=sorted(labels,key=lambda name:(name[1:],name[0]))
  print('sorted labels :',sorted_labels)
  print(metrics.flat_f1_score(Y_dev,y_pred,average='weighted',labels=labels))
  print(metrics.sequence_accuracy_score(Y_dev, y_pred))
  #计算每一类的标签的p,r,f值，用于标签类别不平均的情况
  y_true_sum=[]
  y_pred_sum=[]
  for i in Y_dev:
    y_true_sum+=i
  for i in y_pred:
    y_pred_sum+=i
  print(sklearn.metrics.classification_report(y_true_sum, y_pred_sum,labels=sorted_labels))
  get_confusion_matrix(Y_dev, y_pred,labels=sorted_labels)

  return crf

#在main中调用我们的函数

In [54]:
def main():
    
    
    train_path = 'data/train.txt'
    test_path = 'data/test.txt'
        
    conll_train = load__data_conll(train_path)
    conll_dev = load__data_conll(test_path)
    
    print("Training a Sequence classification model with CRF")
    feats, labels = get_feats_conll(conll_train)
    devfeats, devlabels = get_feats_conll(conll_dev)
    model=train_seq(feats, labels, devfeats, devlabels)
    print("Done with sequence model")

if __name__=="__main__":
    main()

Training a Sequence classification model with CRF
sorted labels : ['O', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC', 'B-ORG', 'I-ORG', 'B-PER', 'I-PER']
0.9291277569768975
0.5929948411620961
              precision    recall  f1-score   support

           O       0.97      0.98      0.98     38553
       B-LOC       0.72      0.77      0.74      1668
       I-LOC       0.80      0.49      0.61       257
      B-MISC       0.70      0.36      0.48       702
      I-MISC       0.67      0.52      0.58       216
       B-ORG       0.68      0.57      0.62      1661
       I-ORG       0.56      0.71      0.62       835
       B-PER       0.77      0.79      0.78      1617
       I-PER       0.82      0.90      0.86      1156

    accuracy                           0.93     46665
   macro avg       0.74      0.68      0.70     46665
weighted avg       0.93      0.93      0.93     46665



                O  B-LOC  I-LOC B-MISC I-MISC  B-ORG  I-ORG  B-PER  I-PER 
         O   1286     36     91   