In [22]:
import pandas as pd
import math
import os
import nltk
from sklearn import metrics
from collections import Counter

In [23]:
## Module to load training, dev and test data
def get_token_tag_tuples(sent):
    return([nltk.tag.str2tuple(t) for t in sent.split()])

def get_tagged_sentences(text):
    sentences = []
    blocks = text.split("======================================")
    for block in blocks:
        sents = block.split("\n\n")
        for sent in sents:
            sent = sent.replace("\n", "").replace("[", "").replace("]", "")
            if sent is not "":
                sentences.append(sent)
    return sentences

def load_treebank_splits(datadir):
    train = []
    dev = []
    test = []
    print("Loading treebank data...")
    for subdir, dirs, files in os.walk(datadir):
        for filename in files:
            if filename.endswith(".pos"):
                if subdir.split(os.sep)[-1].startswith('.'):
                    pass
                else:
                    filepath = subdir + os.sep + filename
                    # print(filepath)
                    with open(filepath, "r") as fh:
                        text = fh.read()
                        if int(subdir.split(os.sep)[-1]) in range(0,19):
                            train += get_tagged_sentences(text)
                        if int(subdir.split(os.sep)[-1]) in range(19, 22):
                            dev += get_tagged_sentences(text)
                        if int(subdir.split(os.sep)[-1]) in range(22, 25):
                            test += get_tagged_sentences(text)
    print("Train set size: ", len(train))
    print("Dev set size: ", len(dev))
    print("Test set size: ", len(test))
    return train, dev, test


  if sent is not "":


In [24]:
datadir = 'data/penn-treeban3-wsj/wsj'
train, dev, test = load_treebank_splits(datadir)

Loading treebank data...
Train set size:  51681
Dev set size:  7863
Test set size:  9046


In [25]:
## Tokenize data and add <START> and <STOP> tags
def tokenize_data(data):
    processed_data = []
    for sentence in data:
        temp = []
        temp.append(('<START>', '<START>'))
        processed_sent = get_token_tag_tuples(sentence)
        temp.extend(processed_sent)
        temp.append(('<STOP>', '<STOP>'))
        processed_data.append(temp)
    return processed_data

def flatten_pairs(data):
    return [tup for sent in data for tup in sent]

In [26]:
train_pairs = tokenize_data(train)
dev_pairs = tokenize_data(dev)
test_pairs = tokenize_data(test)

In [27]:
print('Train set: ', len(train_pairs))
print('Dev set: ', len(dev_pairs))
print('Test set: ', len(test_pairs))

Train set:  51681
Dev set:  7863
Test set:  9046


In [28]:
train_pairs_flat = flatten_pairs(train_pairs)
dev_pairs_flat = flatten_pairs(dev_pairs)
test_pairs_flat = flatten_pairs(test_pairs)

In [29]:
print(len(train_pairs_flat))
print(len(dev_pairs_flat))
print(len(test_pairs_flat))

1073238
163883
189230


In [30]:
# Create train token and tag vocab

In [31]:
# Create tokena and tag vocabulary for training data
train_tokens = [tag[0] for tag in train_pairs_flat]
train_tags = [tag[1] for tag in train_pairs_flat]

train_vocab = set(train_tokens)
train_tag_vocab = set(train_tags)

print('Train vocabulary:', len(train_vocab))
print('Train tag vocabulary:', len(train_tag_vocab))

Train vocabulary: 44547
Train tag vocabulary: 79


In [32]:
## Create useful counter dictionaries
def get_train_freq_dict(train_pairs, train_tags):
    word_tag_bigrams = []
    tags_bigram = []
    for sent in train_pairs:
        word_tag_bigrams.extend(sent)
        sent_tags = [tup[1] for tup in sent]
        tags_bigram.extend([(s1, s2) for s1, s2 in zip(sent_tags, sent_tags[1:])])
    tags_counter = Counter(train_tags)
    tags_bigram_counter = Counter(tags_bigram)
    word_tag_bigram_counter = Counter(word_tag_bigrams)
    return word_tag_bigram_counter, tags_bigram_counter, tags_counter

In [33]:
word_tag_bigram_counter, tags_bigram_counter, tags_counter = get_train_freq_dict(train_pairs, train_tags)

In [34]:
## Transition and emission probabilities
def get_transition_prob(t_prev, t, alpha):
    v = len(train_vocab)
    prob = (tags_bigram_counter[(t_prev, t)]+ alpha)/(tags_counter[t_prev] + v*alpha)
    return prob

def get_emission_prob(w, t, alpha):
    v = len(train_vocab)
    prob = (word_tag_bigram_counter[(w,t)] + alpha)/(tags_counter[t] + v*alpha)
    return prob

In [37]:
# Viterbi ALgorithm

import numpy as np
from tqdm import tqdm

def evaluate(test_sentences, predictions):
    gold = [str(tag) for sentence in test_sentences for token, tag in sentence]
    pred = [tag for sentence in predictions for tag in sentence]
    print(metrics.confusion_matrix(gold, pred))
    print(metrics.classification_report(gold, pred))

def viterbi(data_pair, train_tag_vocab, alpha):
    all_sequences = []
    train_tag_vocab = list(train_tag_vocab)
    predictions = []
    for sent in tqdm(data_pair):
        v_matrix = np.zeros((len(train_tag_vocab), len(sent)-1))
        for key, token in enumerate(sent[1:]):
            for i, possible_tag in enumerate(train_tag_vocab):
                if key == 0:
                    previous_tag = '<START>'
                    log_t_prob = math.log(get_transition_prob(previous_tag, possible_tag, alpha),2)
                    log_e_prob = math.log(get_emission_prob(token[0], possible_tag, alpha), 2)
                    joint_prob = log_t_prob + log_e_prob

                    v_matrix[i][key] =  joint_prob
                else:
                    state = []
                    for j, prev_tag in enumerate(train_tag_vocab):
                        log_t_prob = math.log(get_transition_prob(prev_tag, possible_tag, alpha),2)
                        log_e_prob = math.log(get_emission_prob(token[0], possible_tag, alpha), 2)
                        v_prev = v_matrix[j][key-1]
                        joint_prob = log_t_prob + log_e_prob + v_prev
                        state.append((prev_tag, joint_prob))
                    max_prob = max([tup[1] for tup in state])
                    max_state = [tup for tup in state if tup[1] == max_prob][0][0]
                    if not isinstance(max_prob, float):
                        # print(max_prob)
                        raise Exception
                    v_matrix[i][key] = max_prob
        # backtrack to get tag seq
        temp_df = pd.DataFrame(v_matrix, columns = range(len(sent)-1))
        seq = []
        for token in reversed(range(0, len(sent)-1)):
            loc = temp_df[token].idxmax()
            seq.append(train_tag_vocab[loc])
        seq = seq[::-1]
        all_sequences.append(seq)
    return all_sequences
    
          

In [15]:
alpha_list = [0.2, 0.1, 0.01, 0.001, 1]
for alpha in alpha_list:
    print('Alpha value: ', alpha)
    seq = viterbi(dev_pairs, train_tag_vocab, alpha)
    new_seq = []
    for sent in seq:
        n_seq = []
        n_seq.append('<START>')
        n_seq.extend(sent)
        new_seq.append(n_seq)
    evaluate(dev_pairs, new_seq)
    

Alpha value:  0.2


100% 7863/7863 [17:44<00:00,  7.38it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           #       1.00      1.00      1.00        31
           $       0.98      1.00      0.99      1248
          ''       0.83      1.00      0.91      1168
           (       1.00      1.00      1.00       244
           )       1.00      1.00      1.00       244
           ,       1.00      1.00      1.00      7931
           .       1.00      1.00      1.00      6125
           :       0.99      0.99      0.99       775
     <START>       1.00      1.00      1.00      7863
      <STOP>       0.99      1.00      0.99      7863
          CC       0.99      1.00      0.99      3777
          CD       0.97      0.94      0.95      5766
          DT       0.95      0.99      0.97     12639
          EX       0.91      0.62      0.74       133
          FW       0.00      0.00      0.00        25
          IN       0.95      0.98      0.97     15497
       IN|RB       0.00      0.00      0.00         1
          JJ       0.87    

100% 7863/7863 [17:39<00:00,  7.42it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           #       1.00      1.00      1.00        31
           $       0.98      1.00      0.99      1248
          ''       0.85      1.00      0.92      1168
           (       1.00      1.00      1.00       244
           )       1.00      1.00      1.00       244
           ,       1.00      1.00      1.00      7931
           .       0.97      1.00      0.99      6125
           :       0.99      0.99      0.99       775
     <START>       1.00      1.00      1.00      7863
      <STOP>       0.99      1.00      1.00      7863
          CC       0.99      1.00      0.99      3777
          CD       0.96      0.95      0.95      5766
          DT       0.95      0.99      0.97     12639
          EX       0.92      0.77      0.84       133
          FW       0.00      0.00      0.00        25
          IN       0.97      0.98      0.97     15497
       IN|RB       0.00      0.00      0.00         1
          JJ       0.88    

100% 7863/7863 [17:38<00:00,  7.43it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           #       1.00      1.00      1.00        31
           $       0.99      1.00      0.99      1248
          ''       0.86      1.00      0.92      1168
           (       1.00      1.00      1.00       244
           )       1.00      1.00      1.00       244
           ,       1.00      1.00      1.00      7931
           .       0.98      1.00      0.99      6125
           :       0.99      0.99      0.99       775
     <START>       1.00      1.00      1.00      7863
      <STOP>       1.00      1.00      1.00      7863
          CC       0.99      1.00      0.99      3777
          CD       0.97      0.95      0.96      5766
          DT       0.99      0.99      0.99     12639
          EX       0.40      0.85      0.54       133
          FW       0.07      0.44      0.11        25
          IN       0.97      0.98      0.97     15497
       IN|RB       0.00      0.00      0.00         1
          JJ       0.88    

100% 7863/7863 [17:42<00:00,  7.40it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           #       1.00      1.00      1.00        31
           $       1.00      1.00      1.00      1248
          ''       1.00      1.00      1.00      1168
           (       0.95      1.00      0.97       244
           )       0.66      1.00      0.79       244
           ,       1.00      1.00      1.00      7931
           .       0.98      1.00      0.99      6125
           :       0.99      0.99      0.99       775
     <START>       1.00      1.00      1.00      7863
      <STOP>       1.00      1.00      1.00      7863
          CC       0.99      1.00      0.99      3777
          CD       0.99      0.95      0.97      5766
          DT       0.99      0.99      0.99     12639
          EX       0.30      0.97      0.46       133
          FW       0.05      0.56      0.09        25
          IN       0.97      0.97      0.97     15497
       IN|RB       0.00      0.00      0.00         1
          JJ       0.88    

100% 7863/7863 [17:14<00:00,  7.60it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           #       0.00      0.00      0.00        31
           $       0.98      1.00      0.99      1248
          ''       1.00      1.00      1.00      1168
           (       1.00      0.94      0.97       244
           )       1.00      0.99      1.00       244
           ,       1.00      1.00      1.00      7931
           .       1.00      1.00      1.00      6125
           :       1.00      0.99      1.00       775
     <START>       1.00      1.00      1.00      7863
      <STOP>       0.98      1.00      0.99      7863
          CC       0.93      0.99      0.96      3777
          CD       0.96      0.91      0.94      5766
          DT       0.86      0.99      0.92     12639
          EX       0.94      0.35      0.51       133
          FW       0.00      0.00      0.00        25
          IN       0.90      0.99      0.94     15497
       IN|RB       0.00      0.00      0.00         1
          JJ       0.88    

  _warn_prf(average, modifier, msg_start, len(result))


In [17]:
# generating test resullts
alpha = 1
seq = viterbi(test_pairs, train_tag_vocab, alpha)

100% 9046/9046 [20:14<00:00,  7.45it/s]


In [19]:
new_seq = []
for sent in seq:
    n_seq = []
    n_seq.append('<START>')
    n_seq.extend(sent)
    new_seq.append(n_seq)
evaluate(test_pairs, new_seq)

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           #       0.00      0.00      0.00        22
           $       0.98      0.99      0.99      1138
          ''       1.00      1.00      1.00      1423
           (       1.00      0.96      0.98       249
           )       1.00      0.98      0.99       252
           ,       1.00      1.00      1.00      9056
           .       1.00      1.00      1.00      7035
           :       1.00      0.99      1.00       983
     <START>       1.00      1.00      1.00      9046
      <STOP>       0.98      1.00      0.99      9046
          CC       0.93      1.00      0.96      4289
          CD       0.96      0.90      0.93      6023
          DT       0.87      0.99      0.93     14946
          EX       0.96      0.45      0.61       174
          FW       0.00      0.00      0.00        38
          IN       0.90      0.99      0.94     18147
          JJ       0.88      0.83      0.85     10704
         JJR       0.67    

  _warn_prf(average, modifier, msg_start, len(result))


In [38]:
alpha = 0.01
seq = viterbi(test_pairs, train_tag_vocab, alpha)
new_seq = []
for sent in seq:
    n_seq = []
    n_seq.append('<START>')
    n_seq.extend(sent)
    new_seq.append(n_seq)
evaluate(test_pairs, new_seq)

100% 9046/9046 [20:28<00:00,  7.36it/s]


[[  21    0    0 ...    0    0    0]
 [   0 1138    0 ...    0    0    0]
 [   0    0 1418 ...    0    0    0]
 ...
 [   0    0    0 ...   47    0    0]
 [   0    0    0 ...    0  422    0]
 [   0    0    0 ...    0    0 1422]]


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           #       1.00      0.95      0.98        22
           $       0.98      1.00      0.99      1138
          ''       0.87      1.00      0.93      1423
           (       1.00      1.00      1.00       249
           )       1.00      1.00      1.00       252
           ,       1.00      1.00      1.00      9056
           .       0.98      1.00      0.99      7035
           :       1.00      1.00      1.00       983
     <START>       1.00      1.00      1.00      9046
      <STOP>       1.00      1.00      1.00      9046
          CC       0.99      1.00      1.00      4289
          CD       0.97      0.93      0.95      6023
          DT       0.99      0.99      0.99     14946
          EX       0.52      0.94      0.67       174
          FW       0.12      0.37      0.19        38
          IN       0.96      0.98      0.97     18147
       IN|RB       0.00      0.00      0.00         0
          JJ       0.89    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [42]:
gold = [str(tag) for sentence in test_pairs for token, tag in sentence]
tokens = [str(token) for sentence in test_pairs for token, tag in sentence]
pred = [tag for sentence in new_seq for tag in sentence]

for p_token, g_token, token in zip(pred, gold, tokens):
    if p_token != g_token:
        print('token:', token)
        print('Prediction:', p_token)
        print('Actual token:', g_token)
        print('.....')

token: Influential
Prediction: WDT
Actual token: JJ
.....
token: Ways
Prediction: NNPS
Actual token: NNP
.....
token: working
Prediction: NN
Actual token: VBG
.....
token: RTC-owned
Prediction: JJS
Actual token: JJ
.....
token: self-help
Prediction: EX
Actual token: NN
.....
token: '
Prediction: POS
Actual token: ''
.....
token: more
Prediction: JJR
Actual token: RBR
.....
token: issued
Prediction: VBD
Actual token: VBN
.....
token: working
Prediction: VBG
Actual token: NN
.....
token: Cooke
Prediction: POS
Actual token: NNP
.....
token: selling
Prediction: NN
Actual token: VBG
.....
token: Absent
Prediction: WDT
Actual token: VB
.....
token: complicated
Prediction: JJ
Actual token: VBN
.....
token: that
Prediction: IN
Actual token: DT
.....
token: subtracting
Prediction: JJR
Actual token: VBG
.....
token: working-capital
Prediction: JJS
Actual token: JJ
.....
token: Schumer
Prediction: POS
Actual token: NNP
.....
token: remiss
Prediction: RP
Actual token: JJ
.....
token: informed
Pred