# CSE 291 Assignment 2 BiLSTM CRF

## Download Data/Eval Script

In [1]:
# !wget https://raw.githubusercontent.com/sighsmile/conlleval/master/conlleval.py
# !wget https://raw.githubusercontent.com/tberg12/cse291spr21/main/assignment2/train.data.quad
# !wget https://raw.githubusercontent.com/tberg12/cse291spr21/main/assignment2/dev.data.quad

In [2]:
import conlleval
from tqdm import tqdm
import numpy as np
from collections import defaultdict, Counter
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
from torchtext.vocab import Vocab
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import pdb 
from sklearn.metrics import precision_recall_fscore_support


torch.manual_seed(291)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


## Data Preparation

In [3]:
TRAIN_DATA = 'train.data.quad'
VALID_DATA = 'dev.data.quad'
UNK = '<unk>'
PAD = '<pad>'
START_TAG = "<start>"  # you can add this explicitly or use it implicitly in your CRF layer
STOP_TAG = "<stop>"    # you can add this explicitly or use it implicitly in your CRF layer


def read_conll_sentence(path):
    """ Read a CONLL-format sentence into vocab objects
    Args:
        :param path: path to CONLL-format data file
        :param word_vocab: Vocabulary object for source
        :param label_vocab: Vocabulary object for target
    """
    sent = [[], []]
    with open(path) as f:
        for line in f:
            line = line.strip().split()
            if line:
                # replace numbers with 0000
                word = line[0]
                word = '0000' if word.isnumeric() else word
                sent[0].append(word)
                sent[1].append(line[3])
            else:
                yield sent[0], sent[1]
                sent = [[], []]



def prepare_dataset(dataset, word_vocab, label_vocab):
    dataset = [
      [
        torch.tensor([word_vocab.stoi[word] for word in sent[0]], dtype=torch.long),
        torch.tensor([label_vocab.stoi[label] for label in sent[1]], dtype=torch.long),
      ]
      for sent in dataset
    ]
    return dataset


# load a list of sentences, where each word in the list is a tuple containing the word and the label
train_data = list(read_conll_sentence(TRAIN_DATA))
train_word_counter = Counter([word for sent in train_data for word in sent[0]])
train_label_counter = Counter([label for sent in train_data for label in sent[1]])
word_vocab = Vocab(train_word_counter, specials=(UNK, PAD), min_freq=2)
label_vocab = Vocab(train_label_counter, specials=(), min_freq=1)
print(train_data[0])

train_raw_sent = [" ".join(sent[0]) for sent in train_data]

train_data = prepare_dataset(train_data, word_vocab, label_vocab)
print('Train word vocab:', len(word_vocab), 'symbols.')
print('Train label vocab:', len(label_vocab), f'symbols: {list(label_vocab.stoi.keys())}')
valid_data = list(read_conll_sentence(VALID_DATA))

val_raw_sent = [" ".join(sent[0]) for sent in valid_data]

valid_data = prepare_dataset(valid_data, word_vocab, label_vocab)
print('Train data:', len(train_data), 'sentences.')
print('Valid data:', len(valid_data))

print(' '.join([word_vocab.itos[i.item()] for i in train_data[1][0]]))
print(' '.join([label_vocab.itos[i.item()] for i in train_data[1][1]]))

print(' '.join([word_vocab.itos[i.item()] for i in valid_data[1][0]]))
print(' '.join([label_vocab.itos[i.item()] for i in valid_data[1][1]]))

(['Pusan', '0000', '0000', '0000', '0000', '0000', '0000'], ['I-ORG', 'O', 'O', 'O', 'O', 'O', 'O'])
Train word vocab: 3947 symbols.
Train label vocab: 8 symbols: ['O', 'I-PER', 'I-ORG', 'I-LOC', 'I-MISC', 'B-MISC', 'B-ORG', 'B-LOC']
Train data: 3420 sentences.
Valid data: 800
Martin Brooks , president of Miss Universe Inc , said he spoke with Machado to <unk> her that organisers were not <unk> pressure on her .
I-PER I-PER O O O I-ORG I-ORG I-ORG O O O O O I-PER O O O O O O O O O O O O
Earlier this month , <unk> denied a Kabul government statement that the two sides had agreed to a ceasefire in the north .
O O O O I-PER O O I-LOC O O O O O O O O O O O O O O O


In [5]:
# inVocabWord = set((val_word_vocab.freqs)).intersection(set(word_vocab.freqs))
# oovWord = set((val_word_vocab.freqs)).difference(set(word_vocab.freqs))

In [6]:
embeddingTrainDict = {}

for i, (sent, tags) in enumerate(train_data):
    sentence = train_raw_sent[i]
    embeddingTrainDict[i] = (torch.Tensor([[1] if word[0].isupper() else [0] for word in train_raw_sent[i].split()])).unsqueeze(0).to(device)

embeddingValDict = {}

for i, (sent, tags) in enumerate(valid_data):
    
    embeddingValDict[i] = (torch.Tensor([[1] if word[0].isupper() else [0] for word in val_raw_sent[i].split()])).unsqueeze(0).to(device)


In [7]:
embeddingTrainDict[0].shape

torch.Size([1, 7, 1])

In [8]:
train_raw_sent[0]

'Pusan 0000 0000 0000 0000 0000 0000'

In [9]:

def log_sum_exp(vec):
    max_score = vec.max()
    max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
    return max_score + torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))

## BiLSTMTagger

In [10]:
class CRF():
    def __init__(self, vocab_size, tag_vocab_size):
        
        self.vocab_size = vocab_size
        self.tagset_size = tag_vocab_size

        self.transitions = nn.Parameter(
            torch.randn(self.tagset_size, self.tagset_size)).to(device)

        self.start_index = tag_vocab_size - 2
        self.end_index = tag_vocab_size - 1
        
        self.transitions.data[self.start_index, :] = -10000
        self.transitions.data[:, self.end_index] = -10000
        
    
    def _viterbi_decode(self, feats):
        backtrace = []
        alpha = torch.full((1, self.tagset_size), -10000.0, device=device)
        alpha[0][self.start_index] = 0
        feats = feats.squeeze(0)
#         pdb.set_trace()
        for feat in feats:
            t2 = (self.transitions.T + (feat)).T + alpha
            t3 = t2.max(dim=1).values
            alpha = t3 + (t2.T - t3).T.exp().sum(dim=1).log().view(1,-1)
            backtrace.append(t2.argmax(dim=1))

        vec = alpha.T + self.transitions[:, [self.end_index]]
        
        best_tag = vec.squeeze(1).argmax().item()
        optimal_path = [best_tag]
        
        seq = reversed(backtrace[1:])
        
        for bp in seq: 
            best_tag = bp[best_tag].item()
            optimal_path.append(best_tag)
        
        score = (vec - vec.max()).exp().sum(axis=0, keepdim=True).log() + vec.max()
        return score,[ optimal_path[::-1]] 
    
     
    def _forward_alg(self, feats):
        init_alphas = torch.full((1, self.tagset_size), -10000.)

        init_alphas[0][self.start_index] = 0.


        forward_var = init_alphas.to(device)

        feats = feats.squeeze(0)

        for feat in feats:
            alphas_t = []  # The forward tensors at this timestep
            t2 = (self.transitions.T + (feat)).T + forward_var
            t3 = t2.max(dim=1).values
            forward_var = t3 + (t2.T - t3).T.exp().sum(dim=1).log().view(1,-1)

        terminal_var = forward_var + self.transitions[self.end_index]
        alpha = log_sum_exp(terminal_var)
        return alpha
    
    
    def _score_sentence(self, feats, tags):

        score = torch.zeros(1).to(device)
        feats = feats.squeeze(0)
        tags = tags.squeeze(0)
        tags = torch.cat([torch.tensor([self.start_index], dtype=torch.long).to(device), tags]).to(device)
        for i, feat in enumerate(feats):
            score = score + self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]]

        score = score + self.transitions[self.end_index, tags[-1]]
        return score


In [11]:
# Starter code implementing a BiLSTM Tagger
# which makes locally normalized, independent
# tag classifications at each time step

class BiLSTMTagger(nn.Module):
    def __init__(self, vocab_size, tag_vocab_size, embedding_dim, hidden_dim, dropout=0.3, feature_size=1):
        super(BiLSTMTagger, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.tagset_size = tag_vocab_size
        self.word_embeds = nn.Embedding(vocab_size, embedding_dim).to(device)
        self.bilstm = nn.LSTM(embedding_dim+feature_size, hidden_dim // 2,
                            num_layers=1, bidirectional=True, batch_first=True).to(device)
        self.tag_projection_layer = nn.Linear(hidden_dim, self.tagset_size).to(device)
        self.dropout = nn.Dropout(p=dropout)

        self.crf = CRF(vocab_size, tag_vocab_size)


    def init_hidden(self):
        return (torch.randn(2, 1, self.hidden_dim // 2).to(device),
                torch.randn(2, 1, self.hidden_dim // 2).to(device))

    def compute_lstm_emission_features(self, sentence, feature):
        hidden = self.init_hidden()
        embeds = self.dropout(self.word_embeds(sentence))
        bilstm_out, hidden = self.bilstm(torch.cat((embeds, feature),dim=2), hidden)
        bilstm_out = self.dropout(bilstm_out)
        bilstm_out = bilstm_out
        bilstm_feats = self.tag_projection_layer(bilstm_out)
        return bilstm_feats

    
    def forward(self, sentence, feature):
        bilstm_feats = self.compute_lstm_emission_features(sentence, feature)

        return self.crf._viterbi_decode(bilstm_feats)
    

    def loss(self, sentence, tags, feature):
        feats = self.compute_lstm_emission_features(sentence, feature)
        forward_score = self.crf._forward_alg(feats)
        gold_score = self.crf._score_sentence(feats, tags)
        return forward_score - gold_score



## Train / Eval loop

In [12]:
def train(model, train_data, valid_data, word_vocab, label_vocab, epochs, log_interval=25):
    losses_per_epoch = []
    for epoch in range(epochs):
        print(f'--- EPOCH {epoch} ---')
        model.train()
        losses_per_epoch.append([])
        for i, (sent, tags) in enumerate(train_data):
            model.zero_grad()
            sent, tags = sent.to(device), tags.to(device)
            sent = sent.unsqueeze(0)
            tags = tags.unsqueeze(0)
            loss = model.loss(sent, tags, embeddingTrainDict[i])
            loss.backward()
            optimizer.step()

            losses_per_epoch[-1].append(loss.detach().cpu().item())
            if i > 0 and i % log_interval == 0:
                print(f'Avg loss over last {log_interval} updates: {np.mean(losses_per_epoch[-1][-log_interval:])}')

        evaluate(model, valid_data, word_vocab, label_vocab)


def evaluate(model, dataset, word_vocab, label_vocab):
    model.eval()
    losses = []
    scores = []
    true_tags = []
    pred_tags = []
    sents = []
    for i, (sent, tags) in enumerate(dataset):
        if i == 0:
            pass
        with torch.no_grad():
            sent, tags = sent.to(device), tags.to(device)
            sent = sent.unsqueeze(0)
            tags = tags.unsqueeze(0)
            losses.append(model.loss(sent, tags, embeddingValDict[i]).cpu().detach().item())
            score, pred_tag_seq = model(sent, embeddingValDict[i])
            scores.append(score)
            try:
                true_tags.append([label_vocab.itos[i] for i in tags.tolist()[0]])
                pred_tags.append([ label_vocab.itos[0] if i == 9 else label_vocab.itos[i] for i in pred_tag_seq[0]])
            except:
                print(pred_tag_seq)
                pass

            sents.append([word_vocab.itos[i] for i in sent[0]])

    print('Avg evaluation loss:', np.mean(losses))
    a = [tag for tags in true_tags for tag in tags]
    b = [tag for tags in pred_tags for tag in tags]
    print(conlleval.evaluate(a, b, verbose=True))


    scores_token = precision_recall_fscore_support(a,b, average=None, labels=label_vocab.itos)
    for i in range(len(label_vocab.itos)):
        print(label_vocab.itos[i] , "precision", scores_token[0][i], "recall", scores_token[1][i], "f1", scores_token[2][i], "count", scores_token[3][i])

#     print('\n5 random evaluation samples:')
#     for i in np.random.randint(0, len(sents), size=2):
#         print('SENT:', ' '.join(sents[i]))
#         print('TRUE:', ' '.join(true_tags[i]))
#         print('PRED:', ' '.join(pred_tags[i]))
    return sents, true_tags, pred_tags


## Training

In [13]:
# Train BiLSTM Tagger Baseline
model = BiLSTMTagger(len(word_vocab), len(label_vocab)+2, 128, 256).to(device)


optimizer = optim.Adam(model.parameters(), lr=1e-3)
train(model, train_data, valid_data, word_vocab, label_vocab, epochs=10, log_interval=500)

--- EPOCH 0 ---
Avg loss over last 500 updates: 10.385956935405732
Avg loss over last 500 updates: 7.233801934480667
Avg loss over last 500 updates: 4.807391242742538
Avg loss over last 500 updates: 4.248341681003571
Avg loss over last 500 updates: 3.521166482448578
Avg loss over last 500 updates: 3.534409637928009
Avg evaluation loss: 3.278457882106304
processed 11170 tokens with 1231 phrases; found: 1323 phrases; correct: 667.
accuracy:  60.85%; (non-O)
accuracy:  92.09%; precision:  50.42%; recall:  54.18%; FB1:  52.23
              LOC: precision:  72.66%; recall:  51.24%; FB1:  60.10  256
             MISC: precision:  38.78%; recall:  29.69%; FB1:  33.63  147
              ORG: precision:  33.20%; recall:  53.09%; FB1:  40.85  491
              PER: precision:  60.84%; recall:  70.73%; FB1:  65.41  429
(50.41572184429327, 54.18359057676686, 52.23179326546593)
O precision 0.9779529236340399 recall 0.9816121445370964 f1 0.9797791175372139 count 9354
I-PER precision 0.68894952251023

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Avg loss over last 500 updates: 3.0160853567123413
Avg loss over last 500 updates: 3.3547597320079805
Avg loss over last 500 updates: 2.601821985244751
Avg loss over last 500 updates: 2.670689762592316
Avg loss over last 500 updates: 2.082819594860077
Avg loss over last 500 updates: 2.3880317096710204
Avg evaluation loss: 2.594623070061207
processed 11170 tokens with 1231 phrases; found: 1316 phrases; correct: 822.
accuracy:  70.32%; (non-O)
accuracy:  93.98%; precision:  62.46%; recall:  66.77%; FB1:  64.55
              LOC: precision:  80.31%; recall:  70.80%; FB1:  75.26  320
             MISC: precision:  62.00%; recall:  48.44%; FB1:  54.39  150
              ORG: precision:  42.86%; recall:  62.54%; FB1:  50.86  448
              PER: precision:  70.35%; recall:  75.88%; FB1:  73.01  398
(62.462006079027354, 66.77497969130788, 64.54652532391049)
O precision 0.9824206264649478 recall 0.9857814838571735 f1 0.9840981856990395 count 9354
I-PER precision 0.7906626506024096 recall 0.8

In [14]:
label_vocab.__dict__

{'freqs': Counter({'I-ORG': 2258,
          'O': 38899,
          'I-PER': 2544,
          'I-LOC': 1836,
          'I-MISC': 1011,
          'B-MISC': 9,
          'B-ORG': 5,
          'B-LOC': 3}),
 'itos': ['O',
  'I-PER',
  'I-ORG',
  'I-LOC',
  'I-MISC',
  'B-MISC',
  'B-ORG',
  'B-LOC'],
 'unk_index': None,
 'stoi': defaultdict(None,
             {'O': 0,
              'I-PER': 1,
              'I-ORG': 2,
              'I-LOC': 3,
              'I-MISC': 4,
              'B-MISC': 5,
              'B-ORG': 6,
              'B-LOC': 7}),
 'vectors': None}

In [15]:
sents, true_tags, pred_tags =  evaluate(model, valid_data, word_vocab, label_vocab)


Avg evaluation loss: 2.505929661244154
processed 11170 tokens with 1231 phrases; found: 1279 phrases; correct: 940.
accuracy:  79.79%; (non-O)
accuracy:  95.84%; precision:  73.49%; recall:  76.36%; FB1:  74.90
              LOC: precision:  82.58%; recall:  80.99%; FB1:  81.78  356
             MISC: precision:  75.16%; recall:  61.46%; FB1:  67.62  157
              ORG: precision:  60.11%; recall:  68.73%; FB1:  64.13  351
              PER: precision:  76.39%; recall:  85.91%; FB1:  80.87  415
(73.49491790461298, 76.36068237205524, 74.9003984063745)
O precision 0.9868855954792621 recall 0.9895231986316014 f1 0.9882026370575989 count 9354
I-PER precision 0.8369723435225619 recall 0.9012539184952978 f1 0.8679245283018868 count 638
I-ORG precision 0.7272727272727273 recall 0.7510204081632653 f1 0.7389558232931728 count 490
I-LOC precision 0.8653366583541147 recall 0.8261904761904761 f1 0.8453105968331303 count 420
I-MISC precision 0.8112244897959183 recall 0.6045627376425855 f1 0.6928

In [16]:
a = []
b = []
for i,sent in enumerate(sents):
    for j,word in enumerate(sent):
        if word == "<unk>":
            a.append(true_tags[i][j])
            b.append(pred_tags[i][j])

In [17]:
scores_token = precision_recall_fscore_support(a,b, average=None, labels=label_vocab.itos)


In [18]:
scores_token = precision_recall_fscore_support(a,b, average=None, labels=label_vocab.itos)

for i in range(len(label_vocab.itos)):
    print(label_vocab.itos[i] , "precision", round(scores_token[0][i] * 100,2), 
          "recall", round(scores_token[1][i] * 100, 2), "f1", round(scores_token[2][i] * 100, 2), "count", scores_token[3][i])


O precision 94.71 recall 93.85 f1 94.28 count 1431
I-PER precision 80.16 recall 88.2 f1 83.99 count 449
I-ORG precision 58.62 recall 69.96 f1 63.79 count 243
I-LOC precision 65.0 recall 51.49 f1 57.46 count 101
I-MISC precision 53.85 recall 22.83 f1 32.06 count 92
B-MISC precision 0.0 recall 0.0 f1 0.0 count 1
B-ORG precision 0.0 recall 0.0 f1 0.0 count 0
B-LOC precision 0.0 recall 0.0 f1 0.0 count 4


In [19]:
sents, true_tags, pred_tags = evaluate(model, valid_data, word_vocab, label_vocab )


Avg evaluation loss: 2.5649742801487445
processed 11170 tokens with 1231 phrases; found: 1284 phrases; correct: 940.
accuracy:  79.57%; (non-O)
accuracy:  95.68%; precision:  73.21%; recall:  76.36%; FB1:  74.75
              LOC: precision:  83.90%; recall:  81.82%; FB1:  82.85  354
             MISC: precision:  74.38%; recall:  61.98%; FB1:  67.61  160
              ORG: precision:  57.85%; recall:  68.40%; FB1:  62.69  363
              PER: precision:  77.15%; recall:  85.09%; FB1:  80.93  407
(73.20872274143302, 76.36068237205524, 74.75149105367794)
O precision 0.9870767916266154 recall 0.9880265127218303 f1 0.9875514238392905 count 9354
I-PER precision 0.8404726735598228 recall 0.8918495297805643 f1 0.8653992395437263 count 638
I-ORG precision 0.699047619047619 recall 0.7489795918367347 f1 0.7231527093596058 count 490
I-LOC precision 0.8592592592592593 recall 0.8285714285714286 f1 0.8436363636363636 count 420
I-MISC precision 0.805 recall 0.6121673003802282 f1 0.6954643628509719

In [20]:
for i in range(len(sents)):
    print(" ".join(sents[i]))
    print(true_tags[i])
    print(pred_tags[i])
    print("\n\n")

-DOCSTART-
['O']
['O']



Earlier this month , <unk> denied a Kabul government statement that the two sides had agreed to a ceasefire in the north .
['O', 'O', 'O', 'O', 'I-PER', 'O', 'O', 'I-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
['O', 'O', 'O', 'O', 'I-PER', 'O', 'O', 'I-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']



0000 <unk> <unk> 0000 0000 0000 0000
['O', 'I-PER', 'I-PER', 'O', 'O', 'O', 'O']
['O', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O']



9. <unk> <unk> ( China ) <unk>
['O', 'I-PER', 'I-PER', 'O', 'I-LOC', 'O', 'O']
['O', 'I-PER', 'I-PER', 'O', 'I-LOC', 'O', 'O']



<unk> 0000 0000 0000 0000 0000 0000 0000
['I-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
['I-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O']



The court <unk> <unk> 's <unk> that <unk> 's <unk> from Denmark , where he was arrested in March last year at the request of German authorities , was illegal .
['O', 'O', 'O', 'I-PER', 'O', 'O', 'O', 'I-PER',