# Import packages and load dataset

In [None]:
from datasets import load_dataset
import nltk as nltk
from nltk.tag import pos_tag
from nltk.tag import CRFTagger
import numpy as np
import re, unicodedata


nltk.download('averaged_perceptron_tagger')
dataset = load_dataset(
    "tner/bionlp2004", 
    cache_dir='./data_cache'
)

print(f'The dataset is a dictionary with {len(dataset)} splits: \n\n{dataset}')

# Formatting the dataset splits

In [29]:
# Formatting the dataset splits for the tagger.

train_sentences_ner = [item['tokens'] for item in dataset['train']]
train_labels_ner = [[str(tag) for tag in item['tags']] for item in dataset['train']]

val_sentences_ner = [item['tokens'] for item in dataset['validation']]
val_labels_ner = [[str(tag) for tag in item['tags']] for item in dataset['validation']]

test_sentences_ner = [item['tokens'] for item in dataset['test']]
test_labels_ner = [[str(tag) for tag in item['tags']] for item in dataset['test']]

In [48]:
print('Number of training sentences = {}'.format(len(train_sentences_ner)))
print('Number of validation sentences = {}'.format(len(val_sentences_ner)))
print('Number of test sentences = {}'.format(len(test_sentences_ner)))

Number of training sentences = 16619
Number of validation sentences = 1927
Number of test sentences = 3856


In [50]:
print('What does one instance look like from the training set? \n\n{}'.format(train_sentences_ner[101]))
print('Corresponding label: \n\n{}'.format(train_labels_ner[101]))

What does one instance look like from the training set? 

['Normal', 'T', 'lymphocytes', 'whose', 'surface', 'expression', 'of', 'CD3', 'was', 'depleted', 'showed', 'impaired', 'UV-induced', 'tyrosine', 'phosphorylation', 'and', 'Ca2+', 'signals', '.']
Corresponding label: 

['0', '5', '6', '0', '0', '0', '0', '3', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0']


In [51]:
print('Number of unique labels: {}'.format(np.unique(np.concatenate(train_labels_ner))))

Number of unique labels: ['0' '1' '10' '2' '3' '4' '5' '6' '7' '8' '9']


In [52]:
# mapping from labels to the tags

all_labels = {
    "O": 0,
    "B-DNA": 1,
    "I-DNA": 2,
    "B-protein": 3,
    "I-protein": 4,
    "B-cell_type": 5,
    "I-cell_type": 6,
    "B-cell_line": 7,
    "I-cell_line": 8,
    "B-RNA": 9,
    "I-RNA": 10
}

mapping = {value:key for key, value in all_labels.items()}
print(mapping)

{0: 'O', 1: 'B-DNA', 2: 'I-DNA', 3: 'B-protein', 4: 'I-protein', 5: 'B-cell_type', 6: 'I-cell_type', 7: 'B-cell_line', 8: 'I-cell_line', 9: 'B-RNA', 10: 'I-RNA'}


In [34]:
train_set = [list(zip(train_sentences_ner[index],[mapping[int(i)]for i in train_labels_ner[index]]))for index, sentence in enumerate(train_sentences_ner)]

val_set = [list(zip(val_sentences_ner[index],[mapping[int(i)]for i in val_labels_ner[index]]))for index, sentence in enumerate(val_sentences_ner)]
val_tokens = [tok for tok in val_sentences_ner]
val_tags = [[mapping[int(i)]for i in item] for item in val_labels_ner]

test_set = [list(zip(test_sentences_ner[index],[mapping[int(i)]for i in test_labels_ner[index]]))for index, sentence in enumerate(test_sentences_ner)]
test_tokens = [tok for tok in test_sentences_ner]

test_tags = [[mapping[int(i)]for i in item] for item in test_labels_ner]
print(val_set[0])

[('IL-2', 'B-DNA'), ('gene', 'I-DNA'), ('expression', 'O'), ('and', 'O'), ('NF-kappa', 'B-protein'), ('B', 'I-protein'), ('activation', 'O'), ('through', 'O'), ('CD28', 'B-protein'), ('requires', 'O'), ('reactive', 'O'), ('oxygen', 'O'), ('production', 'O'), ('by', 'O'), ('5-lipoxygenase', 'B-protein'), ('.', 'O')]


In [35]:
model = CRFTagger(verbose= True)
model.train(train_set,'model.crf.my_tagger')

Feature generation
type: CRF1d
feature.minfreq: 0.000000
feature.possible_states: 0
feature.possible_transitions: 0
0....1....2....3....4....5....6....7....8....9....10
Number of features: 41166
Seconds required: 0.243

L-BFGS optimization
c1: 0.000000
c2: 1.000000
num_memories: 6
max_iterations: 2147483647
epsilon: 0.000010
stop: 10
delta: 0.000010
linesearch: MoreThuente
linesearch.max_iterations: 20

***** Iteration #1 *****
Loss: 670701.292937
Feature norm: 5.000000
Error norm: 153439.508663
Active features: 41166
Line search trials: 2
Line search step: 0.000016
Seconds required for this iteration: 0.740

***** Iteration #2 *****
Loss: 450069.016166
Feature norm: 3.656978
Error norm: 128870.363823
Active features: 41166
Line search trials: 1
Line search step: 1.000000
Seconds required for this iteration: 0.257

***** Iteration #3 *****
Loss: 371186.909816
Feature norm: 3.071075
Error norm: 51633.075776
Active features: 41166
Line search trials: 2
Line search step: 0.120890
Seconds 

***** Iteration #45 *****
Loss: 78921.787259
Feature norm: 56.154482
Error norm: 6459.665318
Active features: 41166
Line search trials: 2
Line search step: 0.421550
Seconds required for this iteration: 0.490

***** Iteration #46 *****
Loss: 78460.630137
Feature norm: 56.704766
Error norm: 3746.629139
Active features: 41166
Line search trials: 1
Line search step: 1.000000
Seconds required for this iteration: 0.242

***** Iteration #47 *****
Loss: 78012.679385
Feature norm: 57.474117
Error norm: 2405.390143
Active features: 41166
Line search trials: 1
Line search step: 1.000000
Seconds required for this iteration: 0.238

***** Iteration #48 *****
Loss: 77703.264361
Feature norm: 58.123588
Error norm: 2954.226033
Active features: 41166
Line search trials: 1
Line search step: 1.000000
Seconds required for this iteration: 0.251

***** Iteration #49 *****
Loss: 77253.919317
Feature norm: 59.055096
Error norm: 2654.172693
Active features: 41166
Line search trials: 1
Line search step: 1.000000

***** Iteration #92 *****
Loss: 65403.235706
Feature norm: 88.752375
Error norm: 1086.827478
Active features: 41166
Line search trials: 1
Line search step: 1.000000
Seconds required for this iteration: 0.258

***** Iteration #93 *****
Loss: 65192.893147
Feature norm: 89.695897
Error norm: 877.756609
Active features: 41166
Line search trials: 1
Line search step: 1.000000
Seconds required for this iteration: 0.244

***** Iteration #94 *****
Loss: 65135.569269
Feature norm: 89.988836
Error norm: 1777.675853
Active features: 41166
Line search trials: 2
Line search step: 0.203936
Seconds required for this iteration: 0.477

***** Iteration #95 *****
Loss: 65023.211767
Feature norm: 90.607369
Error norm: 1461.181196
Active features: 41166
Line search trials: 1
Line search step: 1.000000
Seconds required for this iteration: 0.253

***** Iteration #96 *****
Loss: 64802.971478
Feature norm: 91.473481
Error norm: 788.982954
Active features: 41166
Line search trials: 1
Line search step: 1.000000
S

***** Iteration #137 *****
Loss: 61717.559832
Feature norm: 95.411800
Error norm: 641.703506
Active features: 41166
Line search trials: 1
Line search step: 1.000000
Seconds required for this iteration: 0.240

***** Iteration #138 *****
Loss: 61681.578597
Feature norm: 95.473132
Error norm: 612.074900
Active features: 41166
Line search trials: 1
Line search step: 1.000000
Seconds required for this iteration: 0.251

***** Iteration #139 *****
Loss: 61632.357018
Feature norm: 95.572851
Error norm: 747.017782
Active features: 41166
Line search trials: 1
Line search step: 1.000000
Seconds required for this iteration: 0.264

***** Iteration #140 *****
Loss: 61608.649229
Feature norm: 95.745230
Error norm: 2348.120747
Active features: 41166
Line search trials: 1
Line search step: 1.000000
Seconds required for this iteration: 0.253

***** Iteration #141 *****
Loss: 61546.203281
Feature norm: 95.797042
Error norm: 838.778786
Active features: 41166
Line search trials: 1
Line search step: 1.00000

***** Iteration #177 *****
Loss: 60790.540770
Feature norm: 98.640710
Error norm: 284.364438
Active features: 41166
Line search trials: 1
Line search step: 1.000000
Seconds required for this iteration: 0.246

***** Iteration #178 *****
Loss: 60773.674785
Feature norm: 98.592766
Error norm: 354.428553
Active features: 41166
Line search trials: 1
Line search step: 1.000000
Seconds required for this iteration: 0.246

***** Iteration #179 *****
Loss: 60761.266564
Feature norm: 98.625826
Error norm: 262.811699
Active features: 41166
Line search trials: 1
Line search step: 1.000000
Seconds required for this iteration: 0.239

***** Iteration #180 *****
Loss: 60747.916512
Feature norm: 98.783004
Error norm: 858.351229
Active features: 41166
Line search trials: 2
Line search step: 0.412630
Seconds required for this iteration: 0.480

***** Iteration #181 *****
Loss: 60732.968479
Feature norm: 98.929606
Error norm: 230.201021
Active features: 41166
Line search trials: 1
Line search step: 1.000000

***** Iteration #226 *****
Loss: 60436.957960
Feature norm: 102.189405
Error norm: 162.461860
Active features: 41166
Line search trials: 1
Line search step: 1.000000
Seconds required for this iteration: 0.240

***** Iteration #227 *****
Loss: 60431.994141
Feature norm: 102.229023
Error norm: 200.825565
Active features: 41166
Line search trials: 1
Line search step: 1.000000
Seconds required for this iteration: 0.259

***** Iteration #228 *****
Loss: 60430.781840
Feature norm: 102.237557
Error norm: 432.012674
Active features: 41166
Line search trials: 2
Line search step: 0.213716
Seconds required for this iteration: 0.484

***** Iteration #229 *****
Loss: 60428.367415
Feature norm: 102.246457
Error norm: 299.646469
Active features: 41166
Line search trials: 1
Line search step: 1.000000
Seconds required for this iteration: 0.241

***** Iteration #230 *****
Loss: 60425.017424
Feature norm: 102.255769
Error norm: 137.257551
Active features: 41166
Line search trials: 1
Line search step: 1.0

***** Iteration #266 *****
Loss: 60351.982730
Feature norm: 102.962940
Error norm: 175.138001
Active features: 41166
Line search trials: 1
Line search step: 1.000000
Seconds required for this iteration: 0.252

***** Iteration #267 *****
Loss: 60351.153783
Feature norm: 102.961499
Error norm: 95.071355
Active features: 41166
Line search trials: 1
Line search step: 1.000000
Seconds required for this iteration: 0.248

***** Iteration #268 *****
Loss: 60350.044737
Feature norm: 102.970749
Error norm: 79.054395
Active features: 41166
Line search trials: 1
Line search step: 1.000000
Seconds required for this iteration: 0.244

***** Iteration #269 *****
Loss: 60349.177344
Feature norm: 102.991250
Error norm: 161.893382
Active features: 41166
Line search trials: 1
Line search step: 1.000000
Seconds required for this iteration: 0.261

***** Iteration #270 *****
Loss: 60348.080772
Feature norm: 103.032833
Error norm: 122.698086
Active features: 41166
Line search trials: 1
Line search step: 1.000

***** Iteration #311 *****
Loss: 60322.266265
Feature norm: 103.470652
Error norm: 50.913828
Active features: 41166
Line search trials: 1
Line search step: 1.000000
Seconds required for this iteration: 0.242

***** Iteration #312 *****
Loss: 60322.029045
Feature norm: 103.473678
Error norm: 50.150578
Active features: 41166
Line search trials: 1
Line search step: 1.000000
Seconds required for this iteration: 0.255

***** Iteration #313 *****
Loss: 60321.810351
Feature norm: 103.477890
Error norm: 60.050839
Active features: 41166
Line search trials: 1
Line search step: 1.000000
Seconds required for this iteration: 0.239

***** Iteration #314 *****
Loss: 60321.483813
Feature norm: 103.485522
Error norm: 50.531721
Active features: 41166
Line search trials: 1
Line search step: 1.000000
Seconds required for this iteration: 0.247

***** Iteration #315 *****
Loss: 60321.229654
Feature norm: 103.493157
Error norm: 156.605374
Active features: 41166
Line search trials: 1
Line search step: 1.00000

***** Iteration #361 *****
Loss: 60313.143963
Feature norm: 103.620500
Error norm: 44.639552
Active features: 41166
Line search trials: 1
Line search step: 1.000000
Seconds required for this iteration: 0.245

***** Iteration #362 *****
Loss: 60313.042984
Feature norm: 103.620456
Error norm: 24.727746
Active features: 41166
Line search trials: 1
Line search step: 1.000000
Seconds required for this iteration: 0.239

***** Iteration #363 *****
Loss: 60312.967499
Feature norm: 103.619827
Error norm: 23.945915
Active features: 41166
Line search trials: 1
Line search step: 1.000000
Seconds required for this iteration: 0.238

***** Iteration #364 *****
Loss: 60312.943016
Feature norm: 103.620885
Error norm: 54.389954
Active features: 41166
Line search trials: 2
Line search step: 0.174804
Seconds required for this iteration: 0.472

***** Iteration #365 *****
Loss: 60312.897681
Feature norm: 103.621655
Error norm: 39.088340
Active features: 41166
Line search trials: 1
Line search step: 1.000000

In [None]:
predicted_tags = model.tag_sents(val_tokens)
print(predicted_tags)

In [37]:
def span_fill(index, token, label, spans,start,id_):
    # Check if beginning
    if 'B-' in label:
        start = index
        ending = index + 1
        named_entity_type = label[2:]
    # check if inside
    elif 'I-' in label:
        ending = index + 1
    # check if not entity type
    elif start >= 0 and label == 'O':
        if named_entity_type not in spans:
            spans[named_entity_type] = []
        spans[named_entity_type].append((start, ending, id_))
        start = -1   
    if start >= 0:    
        if named_entity_type not in spans:
            spans[named_entity_type] = []
        spans[named_entity_type].append((start, ending, id_))

In [38]:
def score_printer(named_entity_types, true_spans, predicted_spans, F1_score_for_each_class):
    
    # Manually calculating F1, precision, recall. 
    for named_entity_type in named_entity_types:
        # We loop through all the named entity tpes
        # set TP, FN, and FP to zero.
        true_positive = 0
        false_positive = 0
        false_negative = 0
        
        for span in true_spans[named_entity_type]:
            # check if current true span not in the predicted spans
            if span not in predicted_spans[named_entity_type]:
                # If so...increment false negative value.
                false_negative = false_negative + 1
        
        for span in predicted_spans[named_entity_type]:
            # check if current predicted span in the true spans
            if span in true_spans[named_entity_type]:
                # If so, increment true positive val
                true_positive = true_positive + 1
            else:
                # otherwise increment false negative val
                false_positive = false_positive + 1       
        
            
        if true_positive + false_negative== 0:
            # set recall
            recall = 0
        else:
            # calculate recall using TP and FN
            recall = true_positive / float(true_positive + false_negative)

            
        if true_positive + false_positive == 0:
            # Set precision
            precision = 0
        else:
            # calculate precision using FP and TP
            precision = true_positive / float(false_positive + true_positive)
            

        if recall + precision == 0:
            # Set F1 score
            F1 = 0
        else:
            # Calculate F1 using precision and recall
            F1 = 2 * precision * recall / (precision + recall)
            

        F1_score_for_each_class.append(F1)
        print('F1 score for Class: {} = {}'.format(named_entity_type, F1))
        
    print('Macro averaged F1 score for all classes: {}'.format(np.mean(F1_score_for_each_class)))

In [39]:

def get_spans(tagged_sentences):
    # Create a dict to hold spans
    spans_dict = {}   
    for id_, sentence in enumerate(tagged_sentences):
        start = -1
        entity_type = None
        for index, (token, label) in enumerate(sentence):
            span_fill(index, token, label, spans_dict, start, id_)  
    return spans_dict

def get_f1_scores(test_sents, test_sents_with_pred):
    true_spans = get_spans(test_sents)
    predicted_spans = get_spans(test_sents_with_pred)
    # A list to hold F1 scores
    F1_score_for_each_class = []
    # Set named entity types
    named_entity_types = true_spans.keys()
    
    score_printer(named_entity_types, true_spans, predicted_spans, F1_score_for_each_class)
    

In [40]:
# Get F1 scores for validation set
get_f1_scores(val_set, predicted_tags)

F1 score for Class: DNA = 0.6493860845839017
F1 score for Class: protein = 0.7885906040268457
F1 score for Class: cell_type = 0.6825657894736842
F1 score for Class: cell_line = 0.6267605633802816
F1 score for Class: RNA = 0.7008547008547009
Macro averaged F1 score for all classes: 0.6896315484638829


In [41]:
# Now we add previous and next words as features.
class Current_next_previous_word_CRFTagger(CRFTagger):
    def _get_features(self, toks, i):
            tok = toks[i]
            # Get features from original method
            features = super()._get_features(toks,i)
            # Append the current word
            features.append("CURRENT_WORD" + tok)
            if i < len(toks)-1:
                # Append the next word
                features.append("NEXT_WORD_" + toks[i+1])
                # Append the previous word
            if i > 0:
                features.append("PREVIOUS_WORD_" + toks[i-1])
            return features
                

In [None]:
# Train a new model that uses the new features
model = Current_next_previous_word_CRFTagger(verbose=True)
model.train(train_set, 'model.crf.next_previous_word_CRFTagger')

In [43]:
predicted_tags = model.tag_sents(val_tokens)
get_f1_scores(val_set, predicted_tags)

F1 score for Class: DNA = 0.692885550154662
F1 score for Class: protein = 0.8195902048975513
F1 score for Class: cell_type = 0.7516233766233765
F1 score for Class: cell_line = 0.7122381477398015
F1 score for Class: RNA = 0.6942148760330579
Macro averaged F1 score for all classes: 0.7341104310896898


In [44]:
class POSBasedTagger(Current_next_previous_word_CRFTagger):
    _tokens = None
    def _get_features(self, toks, i):
        # Adding POS tags as a feature on top of the current features
        features = super()._get_features(toks,i)
        # Set Pos tagged toks
        if toks != self._tokens:
            self._pos_tagged_toks = pos_tag(toks)
            self._tokens = toks
        features.append(self._pos_tagged_toks[i][1])
        return features

In [None]:
model = POSBasedTagger(verbose=True)
model.train(train_set, 'model.crf.POS_Based_Tagger')

In [46]:
predicted_tags = model.tag_sents(val_tokens)
get_f1_scores(val_set, predicted_tags)

F1 score for Class: DNA = 0.6997802197802196
F1 score for Class: protein = 0.8212576332728172
F1 score for Class: cell_type = 0.7497975708502024
F1 score for Class: cell_line = 0.7130242825607065
F1 score for Class: RNA = 0.680672268907563
Macro averaged F1 score for all classes: 0.7329063950743018


In [47]:
# Now let's predict on unseen (test data) split to see how well it generalises.
predicted_tags = model.tag_sents(test_tokens)
get_f1_scores(test_set, predicted_tags)

F1 score for Class: protein = 0.7504676393565283
F1 score for Class: cell_type = 0.7271676300578035
F1 score for Class: DNA = 0.6768774703557312
F1 score for Class: cell_line = 0.5847373637264618
F1 score for Class: RNA = 0.6666666666666666
Macro averaged F1 score for all classes: 0.6811833540326383
