In [9]:
%matplotlib inline
import matplotlib.pyplot as plt
plt.style.use('ggplot')
import nltk
from itertools import chain

import nltk
import sklearn
import scipy.stats
from sklearn.metrics import make_scorer
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import RandomizedSearchCV

import sklearn_crfsuite
from sklearn_crfsuite import scorers
from sklearn_crfsuite import metrics
import re

from nltk.corpus import stopwords
import pickle
import string
from datetime import datetime

from nltk.stem import WordNetLemmatizer
from nltk.stem import PorterStemmer
from nltk.stem import LancasterStemmer

In [10]:
# Shared NLTK normalizers, created once and reused by word2features for the
# 'lemma', 'stem.portar' and 'stem.lancaster' features.
wordnet_lemmatizer = WordNetLemmatizer()
porter = PorterStemmer()
lancaster=LancasterStemmer()


# Gazetteer (lexicon) files, grouped by the entity type they give evidence for.
# Paths use forward slashes so they resolve on POSIX systems as well as on
# Windows (Python's open() accepts '/' on both); the previous '\\' separators
# only worked on Windows.
persongazetteerfilenames = [
    'lexicon/people.person.lastnames.modified',
    'lexicon/people.family_name',
    'lexicon/firstname.5000',
    'lexicon/lastname.5000',
]
companygazetteerfilenames = [
    'lexicon/business.consumer_company',
    'lexicon/venture_capital.venture_funded_company',
    'lexicon/business.brand',
]
locationgazetteerfilenames = [
    'lexicon/location.country',
    'lexicon/location',
    'lexicon/education.university',
    'lexicon/venues',
    'lexicon/architecture.museum',
]
productgazetteerfilenames = [
    'lexicon/product',
    'lexicon/business.consumer_product',
    'lexicon/automotive.model',
    'lexicon/automotive.make',
]
titlegazetteerfilenames = [
    'lexicon/award.award',
    'lexicon/base.events.festival_series',
    'lexicon/book.newspaper',
    'lexicon/tv.tv_program',
]
groupgazetteerfilenames = ['lexicon/sports.sports_team']
# NOTE: 'base.events.festival_series' is listed under both "title" and "other",
# exactly as in the original configuration.
othergazetteerfilenames = [
    'lexicon/time.holiday',
    'lexicon/time.recurring_event',
    'lexicon/base.events.festival_series',
    'lexicon/broadcast.tv_channel',
    'lexicon/cvg.cvg_platform',
    'lexicon/sports.sports_league',
    'lexicon/transportation.road',
    'lexicon/tv.tv_network',
]


def loadGazetteer(filenames):
    """Load one or more gazetteer files into a single set of entries.

    Each file is read as UTF-8 text with one entry per line. Entries are
    lower-cased and stripped of their trailing newline only (other
    whitespace is preserved).

    Args:
        filenames: iterable of paths to gazetteer text files.

    Returns:
        set of str: the union of the entries of all files. An empty set is
        returned when `filenames` is empty (previously an empty list), so the
        result type is now always the same.

    Raises:
        OSError: if one of the files cannot be opened.
    """
    entries = set()
    for filename in filenames:
        # The context manager closes each file deterministically; the earlier
        # bare open() left every handle to the garbage collector.
        with open(filename, encoding="utf8") as handle:
            entries.update(line.rstrip('\n').lower() for line in handle)
    return entries

def isWordInGazette(gazetteer,word):
    """Return True when the lower-cased `word` is a member of `gazetteer`."""
    return word.lower() in gazetteer
    
def isWordGroupInGazette(gazetteer,entity,sentence):
    """Check whether `entity` is part of a gazetteer entry that occurs in `sentence`.

    Returns True if some entry of `gazetteer` contains the lower-cased
    `entity` as a substring and that entry is itself a substring of
    `sentence`. `sentence` is compared as given (it is not lower-cased here).
    """
    return any(
        entity.lower() in entry and entry in sentence
        for entry in gazetteer
    )

# Load every gazetteer once at module level. Each name below holds the
# combined, lower-cased entries of the files listed in the matching
# *gazetteerfilenames list; addgazetteer reads these globals.
persongazetteer = loadGazetteer(persongazetteerfilenames)
companygazetteer = loadGazetteer(companygazetteerfilenames)  
locationgazetteer = loadGazetteer(locationgazetteerfilenames)
productgazetteer = loadGazetteer(productgazetteerfilenames)
titlegazetteer = loadGazetteer(titlegazetteerfilenames)
groupgazetteer = loadGazetteer(groupgazetteerfilenames)
othergazetteer = loadGazetteer(othergazetteerfilenames)

In [11]:
def isURL(string):
    """Return True if `string` contains an http:// or https:// URL.

    A URL is the scheme followed by at least one character that is a letter,
    a digit, in the range `$`-`_`, one of `@.&+!*(),`, or a %XX escape.

    The previous pattern had stray spaces inside the alternation
    (`[$-_@.&+] ` and `[!*\\(\\), ]`), so punctuation only counted when
    followed by a space and a bare space counted as a URL character. It was
    also a non-raw literal containing invalid escape sequences.

    Args:
        string: the token or text to inspect (the parameter name shadows the
            stdlib `string` module inside this function only).

    Returns:
        bool
    """
    pattern = r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
    # search() stops at the first match; we only need to know whether one exists.
    return re.search(pattern, string) is not None

def isHashtagUserName(string):
    """Return True if `string` starts with '#' (hashtag) or '@' (user name).

    An empty string returns False; the previous `string[0]` lookup raised
    IndexError on empty tokens.
    """
    return string.startswith(('#', '@'))

def isAnyDigit(s):
    """Return True if at least one character of `s` is a digit."""
    for character in s:
        if character.isdigit():
            return True
    return False

def shape(string):
    """Return the orthographic shape of `string`.

    ASCII upper-case letters become 'X', ASCII lower-case letters become 'x'
    and ASCII digits become 'd'; every other character is kept unchanged.
    """
    def _classify(match):
        character = match.group(0)
        if character.isdigit():
            return 'd'
        return 'X' if character.isupper() else 'x'

    # One pass over the ASCII alphanumerics; the three replacement classes
    # never overlap, so the result equals three successive substitutions.
    return re.sub('[A-Za-z0-9]', _classify, string)

def isAbbr(string):
    """Return True if `string` contains an abbreviation-like run.

    A match is an upper-case ASCII letter followed by one or more characters
    that are each an upper-case ASCII letter, '.' or '&' (e.g. 'IBM', 'U.S.',
    'AT&T').

    The pattern is a raw string so `\\.` is a regex escape rather than an
    invalid string escape sequence.
    """
    return re.search(r'[A-Z]([A-Z]|\.|&)+', string) is not None

def isPostUpper(post):
    """Return True if the text of `post` is entirely upper-case.

    `post` is an iterable of (word, label) pairs; the words are concatenated,
    each followed by a space, and the result is tested with str.isupper().
    """
    pieces = []
    for token, _tag in post:
        pieces.append(token + " ")
    return "".join(pieces).isupper()

def isStopWord(string):
    """Return True if `string` is in NLTK's English stop-word list (exact match)."""
    return string in stopwords.words('english')

def poststring(post, separator=" "):
    """Concatenate the words of `post`, each followed by `separator`.

    `post` is an iterable of (word, label) pairs. The result keeps a trailing
    separator after the last word.
    """
    return "".join(token + separator for token, _tag in post)

def poststringgazetteer(postwithgaz, separator=" "):
    """Concatenate the words of a gazetteer-annotated post.

    Every record of `postwithgaz` must have exactly nine fields (word, label
    and seven gazetteer flags); only the word is used. Each word is followed
    by `separator`, including the last one.
    """
    chunks = []
    for record in postwithgaz:
        (token, _label, _person, _location, _product,
         _title, _group, _other, _company) = record
        chunks.append(token + separator)
    return "".join(chunks)

def sentenceTag(post):
    """POS-tag the words of `post` with nltk.pos_tag.

    `post` is an iterable of (word, label) pairs. The words are joined with
    spaces and re-split on whitespace before tagging.
    """
    joined = "".join(token + " " for token, _tag in post)
    return nltk.pos_tag(joined.split())

def sentenceTagGazetteer(postwithgaz):
    """POS-tag the words of a gazetteer-annotated post with nltk.pos_tag.

    The text is built by poststringgazetteer and split on whitespace before
    tagging.
    """
    tokens = poststringgazetteer(postwithgaz).split()
    return nltk.pos_tag(tokens)


#gaz_wiki_place = open("gazetteer\\wikipedia_place_titles.pickle", 'rb')
#gaz_wiki_place_db = list(pickle.load(gaz_wiki_place, encoding='bytes'))
#def isPlace(key):
    #for keys in gaz_wiki_place_db:
    #    if key in keys:#
#    if key in gaz_wiki_place_db:
#        return True
#    return False

def isromannum(word):
    """Return True if every character of `word` (upper-cased) is a Roman numeral letter.

    Only the alphabet M, D, C, L, X, V, I is checked, not numeral validity.
    An empty `word` returns True.
    """
    numerals = ("M", "D", "C", "L", "X", "V", "I")
    return all(character in numerals for character in word.upper())

def haspunctuation(word):
    """Return True if any character of `word` is in string.punctuation."""
    marks = string.punctuation
    return any(character in marks for character in word)

def ispunctuation(word):
    """Return True if every character of `word` is in string.punctuation.

    An empty `word` returns True.
    """
    marks = string.punctuation
    return all(character in marks for character in word)

# Trigger words, one per line in triggerwordlist.txt, lower-cased.
# The context manager closes the file once the list is built.
with open("triggerwordlist.txt", encoding="utf8") as _triggerwordfile:
    triggerwordlist = [line.rstrip('\n').lower() for line in _triggerwordfile]

# Built once at load time. The lookup previously rebuilt set(triggerwordlist)
# on every call, which made each feature lookup linear in the list size.
# If triggerwordlist is reassigned later, rebuild this set as well.
_triggerwordset = frozenset(triggerwordlist)


def istriggerword(word):
    """Return True if the lower-cased `word` is one of the trigger words."""
    return word.lower() in _triggerwordset

def wordtypepatterns(poststring):
    """Encode each whitespace-separated word of `poststring` as one character.

    'l' = all lower-case, 'C' = all upper-case, 'T' = title-case,
    '.' = the word is a substring of string.punctuation, 'x' = anything else.
    The first matching rule wins, in that order.
    """
    def _code(token):
        if token.islower():
            return "l"
        if token.isupper():
            return "C"
        if token.istitle():
            return "T"
        # Substring test against the punctuation string, not per character.
        return "." if token in string.punctuation else "x"

    return "".join(_code(token) for token in poststring.split())

def addgazetteer(posts):
    """Annotate every (word, label) pair of every post with gazetteer flags.

    Args:
        posts: iterable of posts, each an iterable of (word, label) pairs.

    Returns:
        list of posts, each a list of 9-tuples
        (word, label, persongaz, locationgaz, productgaz, titlegaz,
        groupgaz, othergaz, companygaz). This is the field order unpacked by
        poststringgazetteer, word2features and post2labels.

    The previous version produced only eight fields (no company flag), so its
    output could not be consumed by those functions.
    """
    finalresult = []
    for post in posts:
        # NOTE(review): gazetteer entries are lower-cased but this text is
        # not, so multi-word lookups only hit on lower-case text; confirm
        # whether that is intended.
        fullstring = poststring(post)
        result = []
        for word, label in post:
            result.append((
                word,
                label,
                isWordInGazette(persongazetteer, word),
                isWordGroupInGazette(locationgazetteer, word, fullstring),
                isWordGroupInGazette(productgazetteer, word, fullstring),
                isWordGroupInGazette(titlegazetteer, word, fullstring),
                isWordGroupInGazette(groupgazetteer, word, fullstring),
                isWordGroupInGazette(othergazetteer, word, fullstring),
                # NOTE(review): the company flag is assumed to use the same
                # word-group lookup as the other multi-word gazetteers;
                # confirm against how the stored *.data files were built.
                isWordGroupInGazette(companygazetteer, word, fullstring),
            ))
        finalresult.append(result)

    return finalresult
        

def preprocess(raw_data):
    """Parse tagged raw text into posts of (word, label) pairs.

    Posts are separated by a blank line ("\\n\\n"); each non-empty line of a
    post must hold exactly "word<TAB>label". A post with no non-empty lines
    is returned as an empty list.
    """
    parsed = []
    for block in raw_data.split("\n\n"):
        pairs = []
        for row in block.split("\n"):
            if row == "":
                continue
            word, label = row.split("\t")
            pairs.append((word, label))
        parsed.append(pairs)
    return parsed

def preprocessnotag(raw_data):
    """Parse untagged raw text into posts of (word, " ") pairs.

    Posts are separated by a blank line ("\\n\\n") and hold one word per line.
    Each word is paired with a single-space placeholder label.
    """
    return [
        [(token, " ") for token in block.split("\n") if token != ""]
        for block in raw_data.split("\n\n")
    ]



# REMOVE DUPLICATE POSTS

def removeDuplicate(postswithgaz):
    """Drop posts whose text duplicates an earlier post.

    Two posts are duplicates when poststringgazetteer gives the same string
    for both. The first occurrence is kept and order is preserved.

    Already-seen texts are tracked in a set, so the pass is linear in the
    number of posts rather than quadratic.
    """
    seen = set()
    result = []
    for postwithgaz in postswithgaz:
        text = poststringgazetteer(postwithgaz)
        if text not in seen:
            seen.add(text)
            result.append(postwithgaz)

    return result


def postPunctuationAsNER(postwithgaz):
    """Return True if the post has a punctuation token with a non-'O' label.

    Only the first two fields of each record (word, label) are read. Both
    checks are substring tests: the word must be a substring of
    string.punctuation and the label must not be a substring of 'O'.
    """
    marks = string.punctuation
    pairs = ((record[0], record[1]) for record in postwithgaz)
    return any(token in marks and tag not in 'O' for token, tag in pairs)

def removePunctuationAsNER(postswithgaz):
    """Keep only the posts for which postPunctuationAsNER is False."""
    return [
        postwithgaz
        for postwithgaz in postswithgaz
        if not postPunctuationAsNER(postwithgaz)
    ]

In [12]:
"""train_raw_data = open("train.txt","r").read()
dev_raw_data = open("dev.txt","r").read()
test_raw_data = open("test_no_tag.txt","r",encoding="utf8").read()"""

'train_raw_data = open("train.txt","r").read()\ndev_raw_data = open("dev.txt","r").read()\ntest_raw_data = open("test_no_tag.txt","r",encoding="utf8").read()'

In [13]:
"""train_posts = preprocess(train_raw_data)
dev_posts = preprocess(dev_raw_data)
test_posts = preprocessnotag(test_raw_data)"""

'train_posts = preprocess(train_raw_data)\ndev_posts = preprocess(dev_raw_data)\ntest_posts = preprocessnotag(test_raw_data)'

In [14]:
# Load the pre-computed, gazetteer-annotated posts instead of recomputing them
# with addgazetteer (the cells that build these files are disabled below).
# SECURITY: pickle.load can execute arbitrary code while deserialising; only
# load .data files that were produced locally or come from a trusted source.
with open('train_posts_with_gazetteer.data', 'rb') as filehandle:
    # read the data as binary data stream
    train_posts_with_gazetteer = pickle.load(filehandle)   
with open('dev_posts_with_gazetteer.data', 'rb') as filehandle:
    # read the data as binary data stream
    dev_posts_with_gazetteer = pickle.load(filehandle)       
with open('test_posts_with_gazetteer.data', 'rb') as filehandle:
    # read the data as binary data stream
    test_posts_with_gazetteer = pickle.load(filehandle)

In [15]:
train_posts_with_gazetteer[0][0]

('@SammieLynnsMom', 'O', False, False, False, False, False, False, False)

In [16]:
train_posts_with_gazetteer = removeDuplicate(train_posts_with_gazetteer)   
dev_posts_with_gazetteer = removeDuplicate(dev_posts_with_gazetteer)

train_posts_with_gazetteer =  removePunctuationAsNER(train_posts_with_gazetteer)  
dev_posts_with_gazetteer =  removePunctuationAsNER(dev_posts_with_gazetteer)    

In [17]:
"""%time
print(datetime.now().strftime("%Y%m%d_%H%M"))
train_posts_with_gazetteer = addgazetteer(train_posts)
with open('train_posts_with_gazetteer'+datetime.now().strftime("%Y%m%d_%H%M")+'.data', 'wb') as filehandle:
    # store the data as binary data stream
    pickle.dump(train_posts_with_gazetteer, filehandle)
print(datetime.now().strftime("%Y%m%d_%H%M"))"""

'%time\nprint(datetime.now().strftime("%Y%m%d_%H%M"))\ntrain_posts_with_gazetteer = addgazetteer(train_posts)\nwith open(\'train_posts_with_gazetteer\'+datetime.now().strftime("%Y%m%d_%H%M")+\'.data\', \'wb\') as filehandle:\n    # store the data as binary data stream\n    pickle.dump(train_posts_with_gazetteer, filehandle)\nprint(datetime.now().strftime("%Y%m%d_%H%M"))'

In [18]:
"""%time
print(datetime.now().strftime("%Y%m%d_%H%M"))
dev_posts_with_gazetteer = addgazetteer(dev_posts)
with open('dev_posts_with_gazetteer'+datetime.now().strftime("%Y%m%d_%H%M")+'.data', 'wb') as filehandle:
    # store the data as binary data stream
    pickle.dump(dev_posts_with_gazetteer, filehandle)
print(datetime.now().strftime("%Y%m%d_%H%M"))"""

'%time\nprint(datetime.now().strftime("%Y%m%d_%H%M"))\ndev_posts_with_gazetteer = addgazetteer(dev_posts)\nwith open(\'dev_posts_with_gazetteer\'+datetime.now().strftime("%Y%m%d_%H%M")+\'.data\', \'wb\') as filehandle:\n    # store the data as binary data stream\n    pickle.dump(dev_posts_with_gazetteer, filehandle)\nprint(datetime.now().strftime("%Y%m%d_%H%M"))'

In [19]:
"""%time
print(datetime.now().strftime("%Y%m%d_%H%M"))
test_posts_with_gazetteer = addgazetteer(test_posts)
with open('test_posts_with_gazetteer'+datetime.now().strftime("%Y%m%d_%H%M")+'.data', 'wb') as filehandle:
    # store the data as binary data stream
    pickle.dump(test_posts_with_gazetteer, filehandle)
print(datetime.now().strftime("%Y%m%d_%H%M"))"""

'%time\nprint(datetime.now().strftime("%Y%m%d_%H%M"))\ntest_posts_with_gazetteer = addgazetteer(test_posts)\nwith open(\'test_posts_with_gazetteer\'+datetime.now().strftime("%Y%m%d_%H%M")+\'.data\', \'wb\') as filehandle:\n    # store the data as binary data stream\n    pickle.dump(test_posts_with_gazetteer, filehandle)\nprint(datetime.now().strftime("%Y%m%d_%H%M"))'

In [20]:
def word2features(postwithgaz,i, postag, fullpost):
    """Build the CRF feature dictionary for the token at position `i`.

    Args:
        postwithgaz: sequence of 9-field records
            (word, label, persongaz, locationgaz, productgaz, titlegaz,
            groupgaz, othergaz, companygaz) for one post.
        i: index of the token to describe.
        postag: sequence indexed like `postwithgaz`; element [1] of each
            entry is used as the POS tag (post2features passes the output of
            sentenceTagGazetteer).
        fullpost: the post text; only used for the upper-case check.

    Returns:
        dict mapping feature name to value for the current token, plus
        context features for the tokens at i-1, i-2, i+1 and i+2 when they
        exist, and 'BOS' / 'EOS' flags at the post boundaries.
    """
    # `label` is unpacked but not used: features must not depend on the target.
    word, label,persongaz, locationgaz, productgaz, titlegaz, groupgaz, othergaz,companygaz = postwithgaz[i]
    
    # Features of the current token: affixes, casing, shape, POS tag,
    # lemma/stems and the pre-computed gazetteer flags.
    features = {
        'bias': 1.0,
       #'word.lower()': word.lower(),
        #'word[-6:]': word[-6:],
        #'word[-5:]': word[-5:],
        'word[-4:]': word[-4:],
        'word[-3:]': word[-3:],
        'word[-2:]': word[-2:],
        'word[:4]': word[:4],
        'word[:3]': word[:3],
        'word[:2]': word[:2],
        'len(word)': len(word),
        'word.isupper()': word.isupper(),
        'word.istitle()': word.istitle(),
        'word.isdigit()': word.isdigit(),
        'word.isalpha()': word.isalpha(),
        #'word.isalnum()': word.isalnum(),
        'isHashTagUserName(word)':isHashtagUserName(word),
        'istriggerword(word)':istriggerword(word),
        'isAnyDigit(word)':isAnyDigit(word),
        'isPostUpper(post)':fullpost.isupper(),
       # 'isStopWord(word)':isStopWord(word),
       # 'isAbbr(word)':isAbbr(word),
        'shape(word)':shape(word),
        'isURL(word)':isURL(word),
        'postag': postag[i][1],
        'postag[:2]': postag[i][1][:2],
       # 'isPlace(word)':isPlace(word),
        'lemma':wordnet_lemmatizer.lemmatize(word),
        'stem.portar':porter.stem(word),
        'stem.lancaster':lancaster.stem(word),
        'word.isromannum':isromannum(word),
        #'word.haspunctuation':haspunctuation(word),
        'word.ispunctuation':ispunctuation(word),
        
        #'sentpattern':sentpattern,
        'person.gazetteer':persongaz,
        'company.gazetteer':companygaz,
        'location.gazetteer':locationgaz,
        'product.gazetteer':productgaz,
        'title.gazetteer':titlegaz,
        'group.gazetteer':groupgaz,
        'other.gazetteer':othergaz,
    }
    # Left context: previous token (i-1) and, when present, the one before (i-2).
    if i > 0:
        word1 = postwithgaz[i-1][0]
        features.update({
            '-1:word.lower()': word1.lower(),
            '-1:word.istitle()': word1.istitle(),
            '-1:word.isupper()': word1.isupper(),
            '-1:word.isdigit()': word1.isdigit(),
            '-1:word.isalpha()': word1.isalpha(),
            #'-1:isAnyDigit(word)':isAnyDigit(word1),
            '-1:istriggerword(word)':istriggerword(word1),
            '-1:shape(word)':shape(word1),
            '-1:isURL(word)':isURL(word1),
            '-1:word[-4:]': word1[-4:],
            '-1:word[-3:]': word1[-3:],
            '-1:word[-2:]': word1[-2:],
            '-1:word[:4]': word1[:4],
            '-1:word[:3]': word1[:3],
            '-1:word[:2]': word1[:2],
            '-1:len(word)': len(word1),
            #'-1:isHashTagUserName(word)':isHashtagUserName(word1),
            '-1:postag': postag[i-1][1],
            '-1:postag[:2]': postag[i-1][1][:2]
        })
        if i>1:
            word2 = postwithgaz[i-2][0]
            features.update({
                '-2:word.lower()': word2.lower(),
                '-2:word.istitle()': word2.istitle(),
                '-2:word.isupper()': word2.isupper(),
                '-2:word.isdigit()': word2.isdigit(),
                '-2:word.isalpha()': word2.isalpha(),
                '-2:istriggerword(word)':istriggerword(word2),
                '-2:postag': postag[i-2][1],
                '-2:postag[:2]': postag[i-2][1][:2],
                
            })
            # Disabled i-3 window, kept as a no-op string literal.
            """if i>2:
                word3 = postwithgaz[i-3][0]
                features.update({
                '-3:word.lower()': word3.lower(),
                '-3:word.istitle()': word3.istitle(),
                '-3:word.isupper()': word3.isupper(),
                '-3:word.isdigit()': word3.isdigit(),
                '-3:word.isalpha()': word3.isalpha(),
                '-3:postag': postag[i-3][1],
                '-3:postag[:2]': postag[i-3][1][:2],
                })""" 
    else:
        # First token of the post.
        features['BOS'] = True

    # Right context: next token (i+1) and, when present, the one after (i+2).
    # `word1` / `word2` are reused here for the following tokens.
    if i < len(postwithgaz)-1:
        word1 = postwithgaz[i+1][0]
        features.update({
            '+1:word.lower()': word1.lower(),
            '+1:word.istitle()': word1.istitle(),
            '+1:word.isupper()': word1.isupper(),
            '+1:word.isdigit()': word1.isdigit(),
            '+1:word.isalpha()': word1.isalpha(),
            #'+1:isAnyDigit(word)':isAnyDigit(word1),
            '+1:shape(word)':shape(word1),
            '+1:isURL(word)':isURL(word1),
            '+1:word[-4:]': word1[-4:],
            '+1:word[-3:]': word1[-3:],
            '+1:word[-2:]': word1[-2:],
            '+1:word[:4]': word1[:4],
            '+1:word[:3]': word1[:3],
            '+1:word[:2]': word1[:2],
            '+1:len(word)': len(word1),
            #'+1:isHashTagUserName(word)':isHashtagUserName(word1),
            '+1:postag': postag[i+1][1],
            '+1:postag[:2]': postag[i+1][1][:2]
        })
        if i < len(postwithgaz) - 2:
            word2 = postwithgaz[i+2][0]
            features.update({
                '+2:word.lower()': word2.lower(),
                '+2:word.istitle()': word2.istitle(),
                '+2:word.isupper()': word2.isupper(),
                '+2:word.isdigit()': word2.isdigit(),
                '+2:word.isalpha()': word2.isalpha(),
                '+2:postag': postag[i+2][1],
                '+2:postag[:2]': postag[i+2][1][:2],
            })
            # Disabled i+3 window, kept as a no-op string literal.
            """if i < len(postwithgaz) - 3:
                word3= postwithgaz[i+3][0]
                features.update({
                '+3:word.lower()': word3.lower(),
                '+3:word.istitle()': word3.istitle(),
                '+3:word.isupper()': word3.isupper(),
                '+3:word.isdigit()': word3.isdigit(),
                '+3:word.isalpha()': word3.isalpha(),
                '+3:postag': postag[i+3][1],
                '+3:postag[:2]': postag[i+3][1][:2],
                })"""
    else:
        # Last token of the post.
        features['EOS'] = True

    return features


def post2features(postwithgaz):
    """Return one feature dict per token of a gazetteer-annotated post.

    The post is POS-tagged once and its text built once; both are shared by
    every word2features call.
    """
    tags = sentenceTagGazetteer(postwithgaz)
    text = poststringgazetteer(postwithgaz)
    token_features = []
    for index in range(len(postwithgaz)):
        token_features.append(word2features(postwithgaz, index, tags, text))
    return token_features

def post2labels(postwithgaz):
    """Return the label (second field) of every 9-field record in the post."""
    tags = []
    for record in postwithgaz:
        (_word, tag, _person, _location, _product,
         _title, _group, _other, _company) = record
        tags.append(tag)
    return tags


In [21]:
%%time
post2features(train_posts_with_gazetteer[5])[5]

Wall time: 3.3 s


{'bias': 1.0,
 'word[-4:]': 'if',
 'word[-3:]': 'if',
 'word[-2:]': 'if',
 'word[:4]': 'if',
 'word[:3]': 'if',
 'word[:2]': 'if',
 'len(word)': 2,
 'word.isupper()': False,
 'word.istitle()': False,
 'word.isdigit()': False,
 'word.isalpha()': True,
 'isHashTagUserName(word)': False,
 'istriggerword(word)': False,
 'isAnyDigit(word)': False,
 'isPostUpper(post)': False,
 'shape(word)': 'xx',
 'isURL(word)': False,
 'postag': 'IN',
 'postag[:2]': 'IN',
 'lemma': 'if',
 'stem.portar': 'if',
 'stem.lancaster': 'if',
 'word.isromannum': False,
 'word.ispunctuation': False,
 'person.gazetteer': False,
 'company.gazetteer': False,
 'location.gazetteer': False,
 'product.gazetteer': False,
 'title.gazetteer': False,
 'group.gazetteer': False,
 'other.gazetteer': False,
 '-1:word.lower()': 'this',
 '-1:word.istitle()': False,
 '-1:word.isupper()': False,
 '-1:word.isdigit()': False,
 '-1:word.isalpha()': True,
 '-1:istriggerword(word)': False,
 '-1:shape(word)': 'xxxx',
 '-1:isURL(word)': Fal

In [22]:
%%time
post2labels(train_posts_with_gazetteer[5])[5]

Wall time: 0 ns


'O'

In [23]:
%%time
X_train = [post2features(s) for s in train_posts_with_gazetteer]
y_train = [post2labels(s) for s in train_posts_with_gazetteer]

Wall time: 11.6 s


In [24]:
print(datetime.now().strftime("%Y%m%d_%H%M"))

20200211_2257


In [25]:
%%time
X_dev = [post2features(s) for s in dev_posts_with_gazetteer]
y_dev = [post2labels(s) for s in dev_posts_with_gazetteer]

Wall time: 4.3 s


In [26]:
print(datetime.now().strftime("%Y%m%d_%H%M"))

20200211_2257


In [27]:
%%time
X_test = [post2features(s) for s in test_posts_with_gazetteer]

Wall time: 12.4 s


In [28]:
print(datetime.now().strftime("%Y%m%d_%H%M"))

20200211_2257


In [29]:
%%time
##(0.15,0.2)
crf = sklearn_crfsuite.CRF(
    algorithm='lbfgs',
    c1=0.15,
    c2=0.2,
    max_iterations=100,
    all_possible_transitions=True
)
crf.fit(X_train, y_train)

Wall time: 42.2 s




CRF(algorithm='lbfgs', all_possible_states=None, all_possible_transitions=True,
    averaging=None, c=None, c1=0.15, c2=0.2, calibration_candidates=None,
    calibration_eta=None, calibration_max_trials=None, calibration_rate=None,
    calibration_samples=None, delta=None, epsilon=None, error_sensitive=None,
    gamma=None, keep_tempfiles=None, linesearch=None, max_iterations=100,
    max_linesearch=None, min_freq=None, model_filename=None, num_memories=None,
    pa_type=None, period=None, trainer_cls=None, variance=None, verbose=False)

In [30]:
# Persist the fitted CRF under a timestamped name so earlier runs are kept.
filename = 'crf_lbfgs_gazetteer_'+datetime.now().strftime("%Y%m%d_%H%M")+'.sav'
# The context manager flushes and closes the file; the previous bare open()
# left the handle to the garbage collector.
with open(filename, 'wb') as model_file:
    pickle.dump(crf, model_file)

In [31]:
labels = list(crf.classes_)
labels.remove('O')
#labels

In [32]:
#0.9397331037451572

y_pred = crf.predict(X_dev)
metrics.flat_f1_score(y_dev, y_pred,
                      average='weighted', labels=labels)
#metrics.flat_accuracy_score(y_dev, y_pred)

0.312360562076091

In [33]:
test_pred = crf.predict(X_test)

In [36]:
def savepredictions(test_posts, test_pred, filename):
    """Write predicted labels to `filename`, one "word label" line per token.

    Args:
        test_posts: sequence of posts; each post is a sequence of records
            whose first field is the word.
        test_pred: sequence of label sequences, aligned with `test_posts`
            (same post order, one label per token).
        filename: output path, written as UTF-8 text.

    Each post is followed by an empty line.

    Raises:
        IndexError: if `test_pred` has fewer posts, or fewer labels for a
            post, than `test_posts`. The file is still closed in that case;
            previously it was left open.
    """
    with open(filename, "w", encoding="utf8") as file:
        for j, post in enumerate(test_posts):
            predpostlabel = test_pred[j]
            for i, record in enumerate(post):
                file.write(record[0]+" "+predpostlabel[i]+"\n")
            file.write("\n")
    
    
savepredictions(test_posts_with_gazetteer, test_pred, filename="test_prediction_"+datetime.now().strftime("%Y%m%d_%H%M")+".txt")        

In [37]:
"""precision    recall  f1-score   support

      person      0.588     0.429     0.496       266
       title      0.400     0.062     0.108        32
    location      0.616     0.430     0.506       235
     company      0.462     0.122     0.194        49
     product      0.350     0.044     0.079       158
       group      0.333     0.025     0.047       159
       other      0.214     0.183     0.198       229

   micro avg      0.457     0.245     0.319      1128
   macro avg      0.423     0.185     0.232      1128
weighted avg      0.438     0.245     0.292      1128"""
# group B and I results
sorted_labels = sorted(
    labels,
    key=lambda name: (name[1:], name[0])
)
print(metrics.flat_classification_report(
    y_dev, y_pred, labels=sorted_labels, digits=3
))

              precision    recall  f1-score   support

   B-company      0.636     0.179     0.280        39
   I-company      0.500     0.100     0.167        10
     B-group      0.750     0.030     0.058       100
     I-group      0.000     0.000     0.000        43
  B-location      0.602     0.397     0.479       141
  I-location      0.622     0.418     0.500        67
     B-other      0.370     0.153     0.216       131
     I-other      0.179     0.215     0.195        93
    B-person      0.614     0.424     0.502       165
    I-person      0.605     0.634     0.619        82
   B-product      0.800     0.129     0.222        31
   I-product      0.500     0.028     0.053        72
     B-title      0.200     0.062     0.095        16
     I-title      0.200     0.083     0.118        12

   micro avg      0.490     0.264     0.343      1002
   macro avg      0.470     0.204     0.250      1002
weighted avg      0.513     0.264     0.312      1002



In [38]:
import seqeval.metrics  as seqevalmetrics
seqevalmetrics.f1_score(y_dev, y_pred)

0.33223322332233224

In [39]:
seqevalmetrics.accuracy_score(y_dev, y_pred)

0.9459168827882851

In [40]:
print(seqevalmetrics.classification_report(y_dev, y_pred, digits=3))

           precision    recall  f1-score   support

 location      0.559     0.369     0.444       141
    group      0.750     0.030     0.058       100
  company      0.636     0.179     0.280        39
   person      0.605     0.418     0.495       165
    other      0.315     0.130     0.184       131
  product      0.400     0.065     0.111        31
    title      0.200     0.062     0.095        16

micro avg      0.528     0.242     0.332       623
macro avg      0.538     0.242     0.305       623



In [None]:
# Hyper-parameter search: the algorithm, iteration cap and transition setting
# are fixed; the L1 (c1) and L2 (c2) regularisation strengths are sampled.
# NOTE: this rebinds `crf`, replacing the model fitted earlier in the notebook.
crf = sklearn_crfsuite.CRF(
    algorithm='lbfgs', 
    max_iterations=100, 
    all_possible_transitions=True
)
# c1 and c2 are drawn from exponential distributions with these scales.
params_space = {
    'c1': scipy.stats.expon(scale=0.5),
    'c2': scipy.stats.expon(scale=0.05),
}

# NOTE(review): despite its name, `f1_scorer` wraps flat_accuracy_score, so the
# search optimises token accuracy rather than the weighted F1 reported on the
# dev set. The F1 arguments tried earlier were:
#   flat_f1_score, average='weighted', labels=labels
# Confirm which metric is intended.
f1_scorer = make_scorer(metrics.flat_accuracy_score)

# search
# NOTE(review): no random_state is passed, so the sampled candidates differ
# between runs; set one if the search must be reproducible.
rs = RandomizedSearchCV(crf, params_space, 
                        cv=3, 
                        verbose=1, 
                        n_jobs=-1, 
                        n_iter=50, 
                        scoring=f1_scorer)
rs.fit(X_train, y_train)

In [None]:
# crf = rs.best_estimator_
print('best params:', rs.best_params_)
print('best CV score:', rs.best_score_)
print('model size: {:0.2f}M'.format(rs.best_estimator_.size_ / 1000000))

In [None]:
#filename = "crf_best_estimator_3_fold_"+datetime.now().strftime("%Y%m%d_%H%M")++".sav"
#pickle.dump(rs.best_estimator_, open(filename, 'wb'))



In [None]:
# Best model found by the randomized search.
crf_best = rs.best_estimator_

# `sklearn.externals.joblib` was deprecated in scikit-learn 0.21 and removed
# in 0.23. Use the standalone joblib package (installed as a scikit-learn
# dependency) and fall back to the old location only on very old versions.
try:
    import joblib
except ImportError:
    from sklearn.externals import joblib
joblib.dump(crf_best, "crf_best_estimator_3_fold_"+datetime.now().strftime("%Y%m%d_%H%M")+'.pkl')

# Re-evaluate on the dev set with the tuned model (this overwrites `y_pred`).
y_pred = crf_best.predict(X_dev)
print(metrics.flat_classification_report(
    y_dev, y_pred, labels=sorted_labels, digits=3
))

In [None]:
# Predict the test set with the tuned model and save the labels.
# `test_posts` is only created in a disabled cell, so it is undefined on a
# fresh kernel. Use the loaded `test_posts_with_gazetteer`, which X_test was
# built from and which the earlier savepredictions call also uses.
test_pred_best = crf_best.predict(X_test)
savepredictions(test_posts_with_gazetteer, test_pred_best, filename="test_prediction_3_gazetteer_best_"+datetime.now().strftime("%Y%m%d_%H%M")+".txt")

In [238]:
from collections import Counter
def print_state_features(state_features):
    """Print one "weight label attribute" line per state feature.

    `state_features` is an iterable of ((attribute, label), weight) items.
    The weight is printed with six decimals and the label is left-aligned in
    a field of at least eight characters.
    """
    for feature_key, weight in state_features:
        attr, label = feature_key
        print("%0.6f %-8s %s" % (weight, label, attr))

print("Top positive:")
print_state_features(Counter(crf.state_features_).most_common(30))

print("\nTop negative:")
print_state_features(Counter(crf.state_features_).most_common()[-30:])

Top positive:
4.742212 O        word.ispunctuation
3.054434 O        bias
2.868522 O        BOS
2.801316 O        isHashTagUserName(word)
2.502802 O        shape(word):xx
2.479192 O        EOS
2.367502 B-company company.gazetteer
2.147256 O        word[-3:]:day
2.127316 B-product word[:2]:iP
1.965681 O        shape(word):x
1.897500 O        postag[:2]:PR
1.860392 B-location word[-3:]:nia
1.858355 B-person person.gazetteer
1.857844 O        shape(word):xxx
1.851765 O        shape(word):xxxxx
1.829012 I-person person.gazetteer
1.795681 O        word[-2:]:ed
1.701880 B-person word[:2]:Je
1.682853 B-other  word[-2:]:GP
1.665031 B-company word.lower():twitter
1.665031 B-company stem.portar:twitter
1.609379 B-other  word[-3:]:mas
1.601490 O        shape(word):Xx
1.565082 B-person word[:2]:Jo
1.564179 O        shape(word):X
1.560749 I-location -2:word.lower():at
1.555447 B-location shape(word):XX
1.550178 B-location word[-2:]:ia
1.516298 B-location word[:2]:CA
1.496681 B-product product.gazet

In [43]:
import eli5

# Show the weights of `crf`, limited to the O / B-person / I-person targets.
eli5.show_weights(
    crf,
    top=(5, 5),
    targets=['O', 'B-person', 'I-person'],
)



From \ To,O,B-person,I-person
O,2.277,0.714,-2.737
B-person,-0.031,-0.802,5.264
I-person,0.022,-0.561,2.228

Weight?,Feature,Unnamed: 2_level_0
Weight?,Feature,Unnamed: 2_level_1
Weight?,Feature,Unnamed: 2_level_2
+4.201,word.ispunctuation,
+3.259,BOS,
+3.047,EOS,
+2.585,isHashTagUserName(word),
+2.162,shape(word):xx,
… 5562 more positive …,… 5562 more positive …,
… 3825 more negative …,… 3825 more negative …,
-1.237,word[-2:]:GP,
-1.463,word[:2]:iP,
-1.648,+2:word.lower():by,

Weight?,Feature
+4.201,word.ispunctuation
+3.259,BOS
+3.047,EOS
+2.585,isHashTagUserName(word)
+2.162,shape(word):xx
… 5562 more positive …,… 5562 more positive …
… 3825 more negative …,… 3825 more negative …
-1.237,word[-2:]:GP
-1.463,word[:2]:iP
-1.648,+2:word.lower():by

Weight?,Feature
+1.657,person.gazetteer
+1.388,word[:2]:Je
+1.242,word[:2]:Jo
+1.176,word[-2:]:ie
+0.995,word[-2:]:en
… 3205 more positive …,… 3205 more positive …
… 313 more negative …,… 313 more negative …
-0.736,-1:postag:NNP
-0.770,+1:word[-2:]:ss
-0.849,shape(word):Xx

Weight?,Feature
+1.456,person.gazetteer
+0.932,word[-2:]:on
+0.814,word[-2:]:ey
+0.806,-1:word[:2]:An
+0.798,-1:word[:2]:Do
… 1358 more positive …,… 1358 more positive …
… 136 more negative …,… 136 more negative …
-0.503,+2:postag[:2]:.
-0.503,+2:postag:.
-0.585,shape(word):xxxx


In [56]:
# Show only features whose name matches the `word.is...` prefix.
# The pattern is a raw string: in a normal literal '\.' is an invalid escape
# sequence (DeprecationWarning, SyntaxWarning on Python 3.12+). The regex
# itself is unchanged.
eli5.show_weights(crf, top=(10,10), feature_re=r'^word\.is',
                  horizontal_layout=False, show=['targets'])



Weight?,Feature
4.201,word.ispunctuation
0.141,word.isromannum
0.067,word.isdigit()
0.041,word.isalpha()
-0.167,word.isupper()
-0.894,word.istitle()

Weight?,Feature
0.561,word.isalpha()
-0.52,word.isupper()
-0.552,word.istitle()

Weight?,Feature
0.049,word.istitle()
-0.018,word.isalpha()
-0.125,word.isupper()

Weight?,Feature
0.803,word.isupper()
0.193,word.istitle()
-0.013,word.isalpha()
-0.109,word.isromannum

Weight?,Feature
0.777,word.istitle()
0.001,word.isdigit()
0.0,word.isalpha()
-0.023,word.ispunctuation
-0.224,word.isupper()

Weight?,Feature
0.782,word.isupper()
0.523,word.istitle()
0.391,word.isromannum
0.045,word.isalpha()
-0.272,word.isdigit()

Weight?,Feature
0.094,word.istitle()
0.028,word.isalpha()
-0.218,word.isdigit()
-0.234,word.ispunctuation

Weight?,Feature
1.048,word.isupper()
0.008,word.isalpha()
-0.123,word.istitle()

Weight?,Feature
0.225,word.isromannum
0.156,word.isdigit()
0.037,word.istitle()
-0.009,word.isupper()
-0.013,word.isalpha()
-0.691,word.ispunctuation

Weight?,Feature
0.504,word.istitle()
-0.056,word.isalpha()
-0.242,word.isupper()
-0.582,word.isromannum

Weight?,Feature
0.327,word.isromannum
0.205,word.isupper()
0.103,word.istitle()
-0.029,word.isalpha()
-0.258,word.ispunctuation

Weight?,Feature
0.006,word.isalpha()
-0.137,word.isupper()
-0.168,word.isdigit()
-0.18,word.istitle()

Weight?,Feature
0.315,word.isdigit()
0.157,word.isupper()
-0.017,word.isalpha()
-0.018,word.istitle()

Weight?,Feature
0.285,word.istitle()
0.037,word.isupper()
-0.003,word.isalpha()

Weight?,Feature
0.11,word.isdigit()
0.0,word.istitle()
-0.002,word.isalpha()
-0.244,word.ispunctuation


In [52]:
# Show only the lower-cased word features (names containing ".lower()").
# The dots and parentheses are escaped so they match literally; the previous
# pattern '.lower.' used regex wildcards and could also match unrelated
# feature names that merely contain "lower" between two other characters.
eli5.show_weights(crf, top=(5,5), feature_re=r'\.lower\(\)',
                  horizontal_layout=False, show=['targets'])



Weight?,Feature
+0.959,-2:word.lower():wintor
+0.922,-2:word.lower():&amp;
+0.806,+2:word.lower():of
+0.720,-2:word.lower():all
+0.703,+2:word.lower():before
… 685 more positive …,… 685 more positive …
… 624 more negative …,… 624 more negative …
-0.858,+2:word.lower():next
-0.862,-2:word.lower():went
-0.886,-2:word.lower():big

Weight?,Feature
+0.913,-2:word.lower():updates
+0.766,-2:word.lower():update
+0.756,-2:word.lower():win
+0.743,-2:word.lower():nose
+0.732,-2:word.lower():back
… 231 more positive …,… 231 more positive …
… 5 more negative …,… 5 more negative …
-0.067,"+1:word.lower():,"
-0.086,-1:word.lower():the
-0.089,-2:word.lower():the

Weight?,Feature
+0.553,-2:word.lower():with
+0.278,-2:word.lower():port
+0.271,-1:word.lower():city
+0.259,-2:word.lower():deep
+0.249,-2:word.lower():.
… 82 more positive …,… 82 more positive …

Weight?,Feature
+0.921,-1:word.lower():go
+0.848,-2:word.lower():vs
+0.559,-2:word.lower():for
+0.557,+2:word.lower():tomorrow
+0.550,-2:word.lower():rt
… 206 more positive …,… 206 more positive …
… 1 more negative …,… 1 more negative …
-0.009,+2:word.lower():is
-0.095,-1:word.lower():the
-0.108,+1:word.lower():.

Weight?,Feature
+0.782,-1:word.lower():dj
+0.635,-2:word.lower():kings
+0.544,-1:word.lower():green
+0.523,-2:word.lower():.
+0.442,-1:word.lower():la
… 148 more positive …,… 148 more positive …
-0.000,-2:word.lower():a
-0.003,+1:word.lower():.
-0.108,-2:word.lower()::

Weight?,Feature
+1.181,-1:word.lower():in
+1.176,-2:word.lower():in
+0.861,+2:word.lower():where
+0.736,+1:word.lower():york
+0.720,-1:word.lower():at
… 383 more positive …,… 383 more positive …
… 31 more negative …,… 31 more negative …
-0.198,-1:word.lower():with
-0.207,"+2:word.lower():,"
-0.212,+2:word.lower():(

Weight?,Feature
+1.596,-2:word.lower():at
+0.804,-1:word.lower():new
+0.578,-2:word.lower():the
+0.576,+2:word.lower():game
+0.440,+1:word.lower():in
… 187 more positive …,… 187 more positive …
-0.030,+1:word.lower():(
-0.066,-2:word.lower()::
-0.073,+1:word.lower()::
-0.205,-2:word.lower():from

Weight?,Feature
+0.765,-2:word.lower():be
+0.680,-2:word.lower():today
+0.639,+2:word.lower():voters
+0.587,+2:word.lower():fashion
+0.566,+1:word.lower():kippur
… 303 more positive …,… 303 more positive …
… 9 more negative …,… 9 more negative …
-0.075,"+1:word.lower():,"
-0.171,-1:word.lower():of
-0.193,-2:word.lower():the

Weight?,Feature
+0.670,"-2:word.lower():"""
+0.636,-2:word.lower():until
+0.581,-1:word.lower():fashion
+0.512,+1:word.lower():september
+0.480,+2:word.lower():time
… 352 more positive …,… 352 more positive …
… 9 more negative …,… 9 more negative …
-0.026,+2:word.lower():...
-0.075,"+2:word.lower():,"
-0.089,-2:word.lower():from

Weight?,Feature
+0.771,+2:word.lower():had
+0.765,+2:word.lower():coming
+0.656,-2:word.lower():girl
+0.645,+2:word.lower():as
+0.616,-1:word.lower():silly
… 449 more positive …,… 449 more positive …
… 22 more negative …,… 22 more negative …
-0.366,+2:word.lower():at
-0.369,+2:word.lower():.
-0.378,+2:word.lower():just

Weight?,Feature
+0.474,+2:word.lower():gwen
+0.354,-2:word.lower()::
+0.323,-2:word.lower():by
+0.317,+2:word.lower():your
+0.316,-1:word.lower():justin
… 177 more positive …,… 177 more positive …
… 5 more negative …,… 5 more negative …
-0.066,-2:word.lower():and
-0.076,+1:word.lower():tonight
-0.077,+2:word.lower():in

Weight?,Feature
+0.976,-2:word.lower():antivirus
+0.668,-2:word.lower():#1
+0.632,+2:word.lower():for
+0.504,+2:word.lower():baby
+0.466,-2:word.lower():chicago
… 180 more positive …,… 180 more positive …
-0.022,-1:word.lower():to
-0.029,"-1:word.lower():,"
-0.086,-1:word.lower():the

Weight?,Feature
+0.782,+2:word.lower():by
+0.355,-2:word.lower():like
+0.270,-1:word.lower():after
+0.268,-1:word.lower():club
+0.249,-2:word.lower():with
… 139 more positive …,… 139 more positive …
-0.000,-2:word.lower():the
-0.000,+1:word.lower()::
-0.080,+2:word.lower():in

Weight?,Feature
+0.809,-1:word.lower():watch
+0.701,-2:word.lower():watching
+0.653,-2:word.lower():!
+0.576,-2:word.lower():story
+0.454,+2:word.lower():really
… 134 more positive …,… 134 more positive …
-0.000,-1:word.lower():.
-0.055,"+2:word.lower():"""
-0.076,-1:word.lower()::

Weight?,Feature
+0.517,+2:word.lower():i
+0.510,+2:word.lower():next
+0.489,+2:word.lower():week
+0.378,+2:word.lower():will
+0.355,+2:word.lower():back
… 142 more positive …,… 142 more positive …
-0.002,-2:word.lower():the
-0.005,"+2:word.lower():,"
-0.048,+2:word.lower():.
-0.069,+1:word.lower():.


In [53]:
# Show only the gazetteer features (names containing ".gazetteer").
# The dot is escaped so it matches a literal '.', not any character as the
# previous pattern '.gazetteer' did.
eli5.show_weights(crf, top=(5,5), feature_re=r'\.gazetteer',
                  horizontal_layout=False, show=['targets'])



Weight?,Feature
0.46,title.gazetteer
0.054,other.gazetteer
-0.257,location.gazetteer
-0.431,group.gazetteer
-0.542,company.gazetteer
-0.577,product.gazetteer
-0.868,person.gazetteer

Weight?,Feature
2.27,company.gazetteer
-0.132,title.gazetteer
-0.78,location.gazetteer
-0.983,person.gazetteer

Weight?,Feature
0.169,company.gazetteer
0.155,other.gazetteer
-0.007,person.gazetteer

Weight?,Feature
0.979,group.gazetteer
0.57,title.gazetteer
0.018,location.gazetteer
0.013,product.gazetteer
-0.309,person.gazetteer

Weight?,Feature
0.531,group.gazetteer
0.013,location.gazetteer
-0.019,person.gazetteer
-0.147,company.gazetteer

Weight?,Feature
1.03,location.gazetteer
0.319,person.gazetteer
-0.0,group.gazetteer
-0.007,other.gazetteer
-0.117,company.gazetteer
-0.606,title.gazetteer

Weight?,Feature
0.507,person.gazetteer
0.311,location.gazetteer
-0.051,title.gazetteer
-0.085,other.gazetteer
-0.131,company.gazetteer
-0.211,product.gazetteer

Weight?,Feature
-0.111,title.gazetteer
-0.159,person.gazetteer
-0.389,product.gazetteer
-0.532,location.gazetteer

Weight?,Feature
0.648,title.gazetteer
0.521,company.gazetteer
0.129,location.gazetteer
0.078,person.gazetteer
0.039,product.gazetteer

Weight?,Feature
1.657,person.gazetteer
0.689,location.gazetteer
0.041,title.gazetteer
-0.164,other.gazetteer
-0.255,product.gazetteer

Weight?,Feature
1.456,person.gazetteer
0.371,location.gazetteer
0.226,group.gazetteer
0.157,product.gazetteer
-0.249,title.gazetteer

Weight?,Feature
1.362,product.gazetteer
-0.0,title.gazetteer
-0.583,person.gazetteer
-0.68,location.gazetteer

Weight?,Feature
0.285,title.gazetteer
0.201,product.gazetteer
0.115,location.gazetteer
0.024,other.gazetteer
-0.012,person.gazetteer

Weight?,Feature
-0.0,product.gazetteer
-0.005,person.gazetteer
-0.064,location.gazetteer

Weight?,Feature
0.892,other.gazetteer
0.142,location.gazetteer
0.075,person.gazetteer


In [54]:
from collections import Counter

def print_transitions(trans_features):
    """Print one line per transition: source label, target label, weight.

    `trans_features` is an iterable of ``((label_from, label_to), weight)``
    pairs, such as the list produced by ``Counter(...).most_common()``.
    Labels are left-aligned (minimum widths 6 and 7) and the weight is
    shown with six decimals.
    """
    row_template = "%-6s -> %-7s %0.6f"
    for label_pair, weight in trans_features:
        label_from, label_to = label_pair
        print(row_template % (label_from, label_to, weight))

print("Top likely transitions:")
# Build the (label_from, label_to) -> weight counter once and reuse it below.
transition_weights = Counter(crf.transition_features_)
print_transitions(transition_weights.most_common(20))

print("\nTop unlikely transitions:")
# most_common() is sorted by descending weight, so the tail holds the 20 lowest weights.
print_transitions(transition_weights.most_common()[-20:])

Top likely transitions:
B-person -> I-person 5.263763
B-other -> I-other 4.915519
I-other -> I-other 4.832353
B-product -> I-product 4.726522
B-group -> I-group 4.683012
B-title -> I-title 4.511986
B-location -> I-location 4.340661
I-title -> I-title 4.299823
I-group -> I-group 3.742417
I-product -> I-product 3.614964
I-location -> I-location 3.407666
B-company -> I-company 3.343142
I-company -> I-company 2.508343
O      -> O       2.276834
I-person -> I-person 2.227529
O      -> B-person 0.714408
O      -> B-group 0.713942
O      -> B-product 0.638069
O      -> B-company 0.606045
O      -> B-title 0.474205

Top unlikely transitions:
B-company -> I-person -0.449016
B-location -> I-person -0.462645
I-location -> I-other -0.499200
B-company -> I-other -0.505813
B-location -> I-group -0.532597
I-person -> B-person -0.560622
B-person -> I-other -0.569272
B-person -> I-title -0.590095
B-other -> O       -0.663970
B-group -> O       -0.801635
B-person -> B-person -0.801665
B-location -> I-ot

In [55]:
def print_state_features(state_features):
    """Print each state feature as ``weight label attribute``.

    Expects an iterable of ``((attr, label), weight)`` pairs.
    Note: this notebook defines a function with the same name and body in
    another cell; whichever cell ran last is the one in effect.
    """
    line_format = "%0.6f %-8s %s"
    for key, weight in state_features:
        attr, label = key
        print(line_format % (weight, label, attr))

print("Top positive:")
# Build the (attr, label) -> weight counter a single time for both listings.
state_weights = Counter(crf.state_features_)
print_state_features(state_weights.most_common(30))

print("\nTop negative:")
# The last 30 entries of the descending ranking are the lowest weights.
print_state_features(state_weights.most_common()[-30:])

Top positive:
4.201385 O        word.ispunctuation
3.258766 O        BOS
3.046755 O        EOS
2.584555 O        isHashTagUserName(word)
2.270356 B-company company.gazetteer
2.161867 O        shape(word):xx
1.856841 O        shape(word):xxx
1.854384 B-product word[:2]:iP
1.801530 O        word[-3:]:day
1.738099 O        shape(word):xxxxx
1.735327 O        shape(word):x
1.730097 O        shape(word):Xx
1.716465 B-company stem.portar:twitter
1.657414 B-person person.gazetteer
1.595931 I-location -2:word.lower():at
1.576426 B-location word[-3:]:nia
1.543778 O        word[-2:]:ed
1.529740 B-other  word[-3:]:mas
1.468842 B-other  word[-2:]:GP
1.455897 I-person person.gazetteer
1.433209 O        bias
1.430280 B-location shape(word):XX
1.424896 B-company stem.lancaster:twit
1.391891 B-location stem.lancaster:uk
1.388768 O        word[-2:]:me
1.387976 B-person word[:2]:Je
1.361885 B-product product.gazetteer
1.354040 B-company stem.portar:facebook
1.354040 B-company stem.lancaster:facebook
1.3