Importing libraries

In [49]:
from nltk import word_tokenize, pos_tag, ne_chunk, tree2conlltags, bigrams
from nltk.corpus import treebank
from sklearn.model_selection import train_test_split
import numpy as np
from collections import defaultdict
from tqdm import tqdm

Preparing data

In [46]:
all_sents = treebank.tagged_sents()
preprocessed_sents = []
for sent in tqdm(all_sents):
    preprocessed_sent = []
    ne = ne_chunk(sent)
    bio = tree2conlltags(ne)
    preprocessed_sent.append(("###", "START"))
    for (word, _, tag) in bio:
        preprocessed_sent.append((word, tag))
    preprocessed_sent.append(("&&&", "END"))
    preprocessed_sents.append(preprocessed_sent)
np.random.seed(42)
train_data, test_data = train_test_split(preprocessed_sents, train_size=0.85, shuffle=True)
all_tags = set()
all_words = set()
for sent in train_data:
    for (word, tag) in sent:
        all_tags.add(tag)
        all_words.add(word)
all_tags = list(all_tags)
all_words = list(all_words)

100%|██████████| 3914/3914 [00:36<00:00, 107.38it/s]


Calculate emission probabilities

In [47]:
epsilon = 0.0001
default_dict = {word.lower(): epsilon for word in all_words}
emission_freq = defaultdict(lambda: default_dict.copy())
for sent in train_data:
    for (word, tag) in sent:
        word = word.lower()
        emission_freq[tag][word] += 1

default_dict = defaultdict(lambda: epsilon)
emission_prob = defaultdict(lambda: default_dict.copy())
for tag in emission_freq.keys():
    count = sum(emission_freq[tag].values())
    for word, freq in emission_freq[tag].items():
        emission_prob[tag][word] = freq / count

Calculate transition probabilities

In [50]:
default_dict = {tag: epsilon for tag in all_tags}
transition_freq = defaultdict(lambda: default_dict.copy())
for sent in train_data:
    sent_bigrams = list(bigrams(sent))
    for (word1, tag1), (word2, tag2) in sent_bigrams:
        transition_freq[tag1][tag2] += 1


transition_prob = defaultdict(dict)
for tag1 in transition_freq.keys():
    count = sum(transition_freq[tag1].values())
    for tag2, freq in transition_freq[tag1].items():
        transition_prob[tag1][tag2] = transition_freq[tag1][tag2] / count


Prepare test dataset

In [51]:
test_sents = []
test_true_tags = []
test_pred_tags = []

for sent in test_data:
    sent_words = []
    sent_tags = []
    del sent[0]
    del sent[-1]
    for (word, tag) in sent:
        sent_words.append(word.lower())
        sent_tags.append(tag)
    test_sents.append(sent_words)
    test_true_tags.append(sent_tags)

Viterbi algorithm

In [52]:
transition_prob['END'] = {tag: 0 for tag in all_tags}
epsilon = 0.00001
num_of_tags = len(all_tags)
for sent in tqdm(test_sents):
    viterbi = np.zeros((num_of_tags, len(sent))) + epsilon
    backpointer = np.zeros((num_of_tags, len(sent)))
    for i, tag in enumerate(all_tags):
        viterbi[i][0] = transition_prob['START'][tag] * emission_prob[tag][sent[0]]
    for t in range(1, len(sent)):
        for s, current_tag in enumerate(all_tags):
            viterbi_row = []
            backpointer_row = []
            for s_prime, prev_tag in enumerate(all_tags):
                viterbi_row.append(viterbi[s_prime][t - 1] * transition_prob[all_tags[s_prime]][all_tags[s]] * emission_prob[all_tags[s]][sent[t]])
                backpointer_row.append(viterbi[s_prime][t - 1] * transition_prob[all_tags[s_prime]][all_tags[s]])
            viterbi[s][t] = max(viterbi_row)
            backpointer[s][t] = np.argmax(backpointer_row)
    argmax = np.argmax(viterbi[:, -1])
    sent_pred_tags = []
    highest_idx = np.argmax(viterbi[:, -1])
    sent_pred_tags.insert(0, all_tags[highest_idx])
    for i in list(reversed(range(1, len(sent)))):
        highest_idx = int(backpointer[highest_idx][i])
        sent_pred_tags.insert(0, all_tags[highest_idx])
    test_pred_tags.append(sent_pred_tags)

100%|██████████| 588/588 [00:11<00:00, 52.90it/s]


In [53]:
from sklearn.metrics import classification_report

flat_true_tags, flat_pred_tags = [], []
for true_tags, pred_tags in zip(test_true_tags, test_pred_tags):
    true_temp, pred_temp = [], []
    for t1, t2 in zip(true_tags, pred_tags):
        if t1 == 'START' or t1 == 'END' or t2 == 'START' or t2 == 'END':
            pass
        else:
            true_temp.append(t1)
            pred_temp.append(t2)
        flat_true_tags.extend(true_temp)
        flat_pred_tags.extend(pred_temp)
labels = all_tags.copy()
labels.remove('START')
labels.remove('END')
print(classification_report(flat_true_tags, flat_pred_tags, labels=labels))

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


                precision    recall  f1-score   support

    I-FACILITY       0.89      0.97      0.93       166
    B-LOCATION       0.72      0.84      0.78        89
      B-PERSON       0.76      0.58      0.66      5960
I-ORGANIZATION       0.58      0.58      0.58      2865
         I-GPE       0.72      0.93      0.81       555
    I-LOCATION       0.74      0.79      0.77        89
      I-PERSON       0.58      0.60      0.59      2983
         B-GPE       0.82      0.74      0.77      5131
             O       0.98      0.99      0.98    228479
    B-FACILITY       0.82      0.98      0.89       168
         B-GSP       1.00      0.38      0.55        77
         I-GSP       0.00      0.00      0.00         2
B-ORGANIZATION       0.64      0.56      0.60      4462

      accuracy                           0.96    251026
     macro avg       0.71      0.69      0.69    251026
  weighted avg       0.96      0.96      0.96    251026



  _warn_prf(average, modifier, msg_start, len(result))
