# CSE 291 Assignment 2 BiLSTM CRF

## Download Data/Eval Script

In [1]:
# !wget https://raw.githubusercontent.com/sighsmile/conlleval/master/conlleval.py
# !wget https://raw.githubusercontent.com/tberg12/cse291spr21/main/assignment2/train.data.quad
# !wget https://raw.githubusercontent.com/tberg12/cse291spr21/main/assignment2/dev.data.quad

In [2]:
import conlleval
from tqdm import tqdm
import numpy as np
from collections import defaultdict, Counter
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
from torchtext.vocab import Vocab
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import pdb 
from sklearn.metrics import precision_recall_fscore_support


torch.manual_seed(291)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


## Data Preparation

In [3]:
TRAIN_DATA = 'train.data.quad'
VALID_DATA = 'dev.data.quad'
UNK = '<unk>'
PAD = '<pad>'
START_TAG = "<start>"  # you can add this explicitly or use it implicitly in your CRF layer
STOP_TAG = "<stop>"    # you can add this explicitly or use it implicitly in your CRF layer


def read_conll_sentence(path):
    """ Read a CONLL-format sentence into vocab objects
    Args:
        :param path: path to CONLL-format data file
        :param word_vocab: Vocabulary object for source
        :param label_vocab: Vocabulary object for target
    """
    sent = [[], []]
    with open(path) as f:
        for line in f:
            line = line.strip().split()
            if line:
                # replace numbers with 0000
                word = line[0]
                word = '0000' if word.isnumeric() else word
                sent[0].append(word)
                sent[1].append(line[3])
            else:
                yield sent[0], sent[1]
                sent = [[], []]


def prepare_dataset(dataset, word_vocab, label_vocab):
    dataset = [
      [
        torch.tensor([word_vocab.stoi[word] for word in sent[0]], dtype=torch.long),
        torch.tensor([label_vocab.stoi[label] for label in sent[1]], dtype=torch.long),
      ]
      for sent in dataset
    ]
    return dataset


# load a list of sentences, where each word in the list is a tuple containing the word and the label
train_data = list(read_conll_sentence(TRAIN_DATA))
train_word_counter = Counter([word for sent in train_data for word in sent[0]])
train_label_counter = Counter([label for sent in train_data for label in sent[1]])
word_vocab = Vocab(train_word_counter, specials=(UNK, PAD), min_freq=2)

label_vocab = Vocab(train_label_counter, specials=(), min_freq=1)
train_data = prepare_dataset(train_data, word_vocab, label_vocab)
print('Train word vocab:', len(word_vocab), 'symbols.')
print('Train label vocab:', len(label_vocab), f'symbols: {list(label_vocab.stoi.keys())}')

valid_data = list(read_conll_sentence(VALID_DATA))

val_word_counter = Counter([word for sent in valid_data for word in sent[0]])
val_word_vocab = Vocab(val_word_counter, specials=(UNK, PAD), min_freq=2)

print('Val word vocab:', len(val_word_vocab), 'symbols.')


valid_data = prepare_dataset(valid_data, word_vocab, label_vocab)
print('Train data:', len(train_data), 'sentences.')
print('Valid data:', len(valid_data))

print(' '.join([word_vocab.itos[i.item()] for i in train_data[0][0]]))
print(' '.join([label_vocab.itos[i.item()] for i in train_data[0][1]]))

print(' '.join([word_vocab.itos[i.item()] for i in valid_data[1][0]]))
print(' '.join([label_vocab.itos[i.item()] for i in valid_data[1][1]]))

Train word vocab: 3947 symbols.
Train label vocab: 8 symbols: ['O', 'I-PER', 'I-ORG', 'I-LOC', 'I-MISC', 'B-MISC', 'B-ORG', 'B-LOC']
Val word vocab: 1078 symbols.
Train data: 3420 sentences.
Valid data: 800
Pusan 0000 0000 0000 0000 0000 0000
I-ORG O O O O O O
Earlier this month , <unk> denied a Kabul government statement that the two sides had agreed to a ceasefire in the north .
O O O O I-PER O O I-LOC O O O O O O O O O O O O O O O


In [4]:
inVocabWord = set((val_word_vocab.freqs)).intersection(set(word_vocab.freqs))
oovWord = set((val_word_vocab.freqs)).difference(set(word_vocab.freqs))

In [5]:
word_vocab.itos[0]

'<unk>'

In [6]:

def log_sum_exp(vec):
    max_score = vec.max()
    max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
    return max_score + torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))

## BiLSTMTagger

In [7]:
class CRF():
    def __init__(self, vocab_size, tag_vocab_size):
        
        self.vocab_size = vocab_size
        self.tagset_size = tag_vocab_size

        self.transitions = nn.Parameter(
            torch.randn(self.tagset_size, self.tagset_size)).to(device)

        self.start_index = tag_vocab_size - 2
        self.end_index = tag_vocab_size - 1
        
        self.transitions.data[self.start_index, :] = -10000
        self.transitions.data[:, self.end_index] = -10000
        
    
    def _viterbi_decode(self, feats):
        backtrace = []
        alpha = torch.full((1, self.tagset_size), -10000.0, device=device)
        alpha[0][self.start_index] = 0
        feats = feats.squeeze(0)
#         pdb.set_trace()
        for feat in feats:
            t2 = (self.transitions.T + (feat)).T + alpha
            t3 = t2.max(dim=1).values
            alpha = t3 + (t2.T - t3).T.exp().sum(dim=1).log().view(1,-1)
            backtrace.append(t2.argmax(dim=1))

        vec = alpha.T + self.transitions[:, [self.end_index]]

        
        best_tag = vec.squeeze(1).argmax().item()

        optimal_path = [best_tag]
        
        seq = reversed(backtrace[1:])
        
        for bp in seq: 
            best_tag = bp[best_tag].item()
            optimal_path.append(best_tag)
        
        score = (vec - vec.max()).exp().sum(axis=0, keepdim=True).log() + vec.max()
        return score,[ optimal_path[::-1]] 
    
     
    def _forward_alg(self, feats):
        init_alphas = torch.full((1, self.tagset_size), -10000.)

        init_alphas[0][self.start_index] = 0.


        forward_var = init_alphas.to(device)

        feats = feats.squeeze(0)

        for feat in feats:
            alphas_t = []  # The forward tensors at this timestep
            t2 = (self.transitions.T + (feat)).T + forward_var
            t3 = t2.max(dim=1).values
            forward_var = t3 + (t2.T - t3).T.exp().sum(dim=1).log().view(1,-1)

        terminal_var = forward_var + self.transitions[self.end_index]
        alpha = log_sum_exp(terminal_var)
        return alpha
    
    
    def _score_sentence(self, feats, tags):

        score = torch.zeros(1).to(device)
        feats = feats.squeeze(0)
        tags = tags.squeeze(0)
        tags = torch.cat([torch.tensor([self.start_index], dtype=torch.long).to(device), tags]).to(device)
        for i, feat in enumerate(feats):
            score = score + self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]]

        score = score + self.transitions[self.end_index, tags[-1]]
        return score


In [8]:
# Starter code implementing a BiLSTM Tagger
# which makes locally normalized, independent
# tag classifications at each time step

class BiLSTMTagger(nn.Module):
    def __init__(self, vocab_size, tag_vocab_size, embedding_dim, hidden_dim, dropout=0.3):
        super(BiLSTMTagger, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.tagset_size = tag_vocab_size
        self.word_embeds = nn.Embedding(vocab_size, embedding_dim).to(device)
        self.bilstm = nn.LSTM(embedding_dim, hidden_dim // 2,
                            num_layers=1, bidirectional=True, batch_first=True).to(device)
        self.tag_projection_layer = nn.Linear(hidden_dim, self.tagset_size).to(device)
        self.dropout = nn.Dropout(p=dropout)

        self.crf = CRF(vocab_size, tag_vocab_size)


    def init_hidden(self):
        return (torch.randn(2, 1, self.hidden_dim // 2).to(device),
                torch.randn(2, 1, self.hidden_dim // 2).to(device))

    def compute_lstm_emission_features(self, sentence):
        hidden = self.init_hidden()
        embeds = self.dropout(self.word_embeds(sentence))
        bilstm_out, hidden = self.bilstm(embeds, hidden)
        bilstm_out = self.dropout(bilstm_out)
        bilstm_out = bilstm_out
        bilstm_feats = self.tag_projection_layer(bilstm_out)
        return bilstm_feats

    
    def forward(self, sentence):
        bilstm_feats = self.compute_lstm_emission_features(sentence)

        return self.crf._viterbi_decode(bilstm_feats)
   

    def loss(self, sentence, tags):
        feats = self.compute_lstm_emission_features(sentence)
        forward_score = self.crf._forward_alg(feats)
        gold_score = self.crf._score_sentence(feats, tags)
        return forward_score - gold_score


## Train / Eval loop

In [9]:

def train(model, train_data, valid_data, word_vocab, label_vocab, epochs, log_interval=25):
    losses_per_epoch = []
    for epoch in range(epochs):
        print(f'--- EPOCH {epoch} ---')
        model.train()
        losses_per_epoch.append([])
        for i, (sent, tags) in enumerate(train_data):
            model.zero_grad()
            sent, tags = sent.to(device), tags.to(device)
            sent = sent.unsqueeze(0)
            tags = tags.unsqueeze(0)
            loss = model.loss(sent, tags)
            loss.backward()
            optimizer.step()

            losses_per_epoch[-1].append(loss.detach().cpu().item())
            if i > 0 and i % log_interval == 0:
                print(f'Avg loss over last {log_interval} updates: {np.mean(losses_per_epoch[-1][-log_interval:])}')

        evaluate(model, valid_data, word_vocab, label_vocab)


def evaluate(model, dataset, word_vocab, label_vocab):
    model.eval()
    losses = []
    scores = []
    true_tags = []
    pred_tags = []
    sents = []
    for i, (sent, tags) in enumerate(dataset):
        if i == 0:
            pass
        with torch.no_grad():
            sent, tags = sent.to(device), tags.to(device)
            sent = sent.unsqueeze(0)
            tags = tags.unsqueeze(0)
            losses.append(model.loss(sent, tags).cpu().detach().item())
            score, pred_tag_seq = model(sent)
            scores.append(score)
            try:
                true_tags.append([label_vocab.itos[i] for i in tags.tolist()[0]])
                pred_tags.append([ label_vocab.itos[0] if i == 9 else label_vocab.itos[i] for i in pred_tag_seq[0]])
            except:
                print(pred_tag_seq)
                pass

            sents.append([word_vocab.itos[i] for i in sent[0]])
    
    print('Avg evaluation loss:', np.mean(losses))
    a = [tag for tags in true_tags for tag in tags]
    b = [tag for tags in pred_tags for tag in tags]
    print(conlleval.evaluate(a, b, verbose=True))

    scores_token = precision_recall_fscore_support(a,b, average=None, labels=label_vocab.itos)
        
    for i in range(len(label_vocab.itos)):
        print(label_vocab.itos[i] , "precision", scores_token[0][i], "recall", scores_token[1][i], "f1", scores_token[2][i], "count", scores_token[3][i])
#     print('\n5 random evaluation samples:')
#     for i in np.random.randint(0, len(sents), size=2):
#         print('SENT:', ' '.join(sents[i]))
#         print('TRUE:', ' '.join(true_tags[i]))
#         print('PRED:', ' '.join(pred_tags[i]))
    return sents, true_tags, pred_tags


## Training

In [10]:
# Train BiLSTM Tagger Baseline
model = BiLSTMTagger(len(word_vocab), len(label_vocab)+2, 128, 256).to(device)


optimizer = optim.Adam(model.parameters(), lr=1e-3)
train(model, train_data, valid_data, word_vocab, label_vocab, epochs=10, log_interval=500)

--- EPOCH 0 ---
Avg loss over last 500 updates: 10.90044804263115
Avg loss over last 500 updates: 8.894040877342224
Avg loss over last 500 updates: 7.136677108049392
Avg loss over last 500 updates: 6.2077221388816834
Avg loss over last 500 updates: 5.5089131655693055
Avg loss over last 500 updates: 5.408730796813964
Avg evaluation loss: 5.117264733016491
processed 11170 tokens with 1231 phrases; found: 817 phrases; correct: 438.
accuracy:  42.35%; (non-O)
accuracy:  88.88%; precision:  53.61%; recall:  35.58%; FB1:  42.77
              LOC: precision:  79.35%; recall:  40.22%; FB1:  53.38  184
             MISC: precision:  49.15%; recall:  15.10%; FB1:  23.11  59
              ORG: precision:  49.77%; recall:  34.85%; FB1:  41.00  215
              PER: precision:  43.45%; recall:  42.28%; FB1:  42.86  359
(53.61077111383109, 35.580828594638504, 42.7734375)
O precision 0.9102564102564102 recall 0.9791533033996151 f1 0.9434487021013597 count 9354
I-PER precision 0.6814814814814815 reca

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Avg loss over last 500 updates: 4.909052533149719
Avg loss over last 500 updates: 5.382206438064575
Avg loss over last 500 updates: 4.474524360179901
Avg loss over last 500 updates: 4.233311654567719
Avg loss over last 500 updates: 3.79393673324585
Avg loss over last 500 updates: 3.767577681541443
Avg evaluation loss: 3.996925927102566
processed 11170 tokens with 1231 phrases; found: 967 phrases; correct: 614.
accuracy:  56.33%; (non-O)
accuracy:  91.48%; precision:  63.50%; recall:  49.88%; FB1:  55.87
              LOC: precision:  82.30%; recall:  55.10%; FB1:  66.01  243
             MISC: precision:  66.12%; recall:  41.67%; FB1:  51.12  121
              ORG: precision:  50.19%; recall:  42.67%; FB1:  46.13  261
              PER: precision:  59.36%; recall:  55.01%; FB1:  57.10  342
(63.49534643226473, 49.87814784727863, 55.868971792538666)
O precision 0.9344512195121951 recall 0.9830019243104554 f1 0.9581119099718662 count 9354
I-PER precision 0.7924187725631769 recall 0.688087

In [11]:
label_vocab.__dict__

{'freqs': Counter({'I-ORG': 2258,
          'O': 38899,
          'I-PER': 2544,
          'I-LOC': 1836,
          'I-MISC': 1011,
          'B-MISC': 9,
          'B-ORG': 5,
          'B-LOC': 3}),
 'itos': ['O',
  'I-PER',
  'I-ORG',
  'I-LOC',
  'I-MISC',
  'B-MISC',
  'B-ORG',
  'B-LOC'],
 'unk_index': None,
 'stoi': defaultdict(None,
             {'O': 0,
              'I-PER': 1,
              'I-ORG': 2,
              'I-LOC': 3,
              'I-MISC': 4,
              'B-MISC': 5,
              'B-ORG': 6,
              'B-LOC': 7}),
 'vectors': None}

In [12]:
sents, true_tags, pred_tags =  evaluate(model, valid_data, word_vocab, label_vocab)


Avg evaluation loss: 3.389647416174412
processed 11170 tokens with 1231 phrases; found: 1149 phrases; correct: 835.
accuracy:  73.02%; (non-O)
accuracy:  94.23%; precision:  72.67%; recall:  67.83%; FB1:  70.17
              LOC: precision:  86.85%; recall:  78.24%; FB1:  82.32  327
             MISC: precision:  76.00%; recall:  59.38%; FB1:  66.67  150
              ORG: precision:  61.90%; recall:  55.05%; FB1:  58.28  273
              PER: precision:  67.17%; recall:  72.63%; FB1:  69.79  399
(72.67188859878155, 67.83103168155971, 70.16806722689076)
O precision 0.9630443886097152 recall 0.9834295488561043 f1 0.973130223209563 count 9354
I-PER precision 0.8140243902439024 recall 0.8369905956112853 f1 0.8253477588871716 count 638
I-ORG precision 0.762962962962963 recall 0.6306122448979592 f1 0.6905027932960894 count 490
I-LOC precision 0.888 recall 0.7928571428571428 f1 0.8377358490566037 count 420
I-MISC precision 0.8287292817679558 recall 0.5703422053231939 f1 0.6756756756756757 c

In [13]:
a = []
b = []
for i,sent in enumerate(sents):
    for j,word in enumerate(sent):
        if word == "<unk>":
            a.append(true_tags[i][j])
            b.append(pred_tags[i][j])

In [14]:
scores_token = precision_recall_fscore_support(a,b, average=None, labels=label_vocab.itos)


In [15]:
scores_token = precision_recall_fscore_support(a,b, average=None, labels=label_vocab.itos)

for i in range(len(label_vocab.itos)):
    print(label_vocab.itos[i] , "precision", round(scores_token[0][i] * 100,2), 
          "recall", round(scores_token[1][i] * 100, 2), "f1", round(scores_token[2][i] * 100, 2), "count", scores_token[3][i])


O precision 81.71 recall 90.22 f1 85.75 count 1431
I-PER precision 76.99 recall 79.73 f1 78.34 count 449
I-ORG precision 62.62 recall 53.09 f1 57.46 count 243
I-LOC precision 73.08 recall 37.62 f1 49.67 count 101
I-MISC precision 38.89 recall 7.61 f1 12.73 count 92
B-MISC precision 0.0 recall 0.0 f1 0.0 count 1
B-ORG precision 0.0 recall 0.0 f1 0.0 count 0
B-LOC precision 0.0 recall 0.0 f1 0.0 count 4
