# CSE 291 Assignment 2 BiLSTM CRF

## Download Data/Eval Script

In [1]:
!wget https://raw.githubusercontent.com/sighsmile/conlleval/master/conlleval.py
!wget https://raw.githubusercontent.com/tberg12/cse291spr21/main/assignment2/train.data.quad
!wget https://raw.githubusercontent.com/tberg12/cse291spr21/main/assignment2/dev.data.quad

--2021-05-31 05:08:32--  https://raw.githubusercontent.com/sighsmile/conlleval/master/conlleval.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 7502 (7.3K) [text/plain]
Saving to: ‘conlleval.py.3’


2021-05-31 05:08:32 (55.5 MB/s) - ‘conlleval.py.3’ saved [7502/7502]

--2021-05-31 05:08:32--  https://raw.githubusercontent.com/tberg12/cse291spr21/main/assignment2/train.data.quad
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 745734 (728K) [text/plain]
Saving to: ‘train.data.quad.2’


2021-05-31 05:08:32 (11.9 MB/s) - ‘train.dat

In [2]:
import conlleval
from tqdm import tqdm
import numpy as np
from collections import defaultdict, Counter
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
from torchtext.vocab import Vocab
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import pdb 
from sklearn.metrics import precision_recall_fscore_support


torch.manual_seed(291)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


## Data Preparation

In [3]:
TRAIN_DATA = 'train.data.quad'
VALID_DATA = 'dev.data.quad'
UNK = '<unk>'
PAD = '<pad>'
START_TAG = "<start>"  # you can add this explicitly or use it implicitly in your CRF layer
STOP_TAG = "<stop>"    # you can add this explicitly or use it implicitly in your CRF layer


def read_conll_sentence(path):
    """ Read a CONLL-format sentence into vocab objects
    Args:
        :param path: path to CONLL-format data file
        :param word_vocab: Vocabulary object for source
        :param label_vocab: Vocabulary object for target
    """
    sent = [[], []]
    with open(path) as f:
        for line in f:
            line = line.strip().split()
            if line:
                # replace numbers with 0000
                word = line[0]
                word = '0000' if word.isnumeric() else word
                sent[0].append(word)
                sent[1].append(line[3])
            else:
                yield sent[0], sent[1]
                sent = [[], []]


def prepare_dataset(dataset, word_vocab, label_vocab):
    dataset = [
      [
        torch.tensor([word_vocab.stoi[word] for word in sent[0]], dtype=torch.long),
        torch.tensor([label_vocab.stoi[label] for label in sent[1]], dtype=torch.long),
      ]
      for sent in dataset
    ]
    return dataset


# load a list of sentences, where each word in the list is a tuple containing the word and the label
train_data = list(read_conll_sentence(TRAIN_DATA))
train_word_counter = Counter([word for sent in train_data for word in sent[0]])
train_label_counter = Counter([label for sent in train_data for label in sent[1]])
word_vocab = Vocab(train_word_counter, specials=(UNK, PAD), min_freq=2)
label_vocab = Vocab(train_label_counter, specials=(), min_freq=1)
train_data = prepare_dataset(train_data, word_vocab, label_vocab)
print('Train word vocab:', len(word_vocab), 'symbols.')
print('Train label vocab:', len(label_vocab), f'symbols: {list(label_vocab.stoi.keys())}')
valid_data = list(read_conll_sentence(VALID_DATA))
valid_data = prepare_dataset(valid_data, word_vocab, label_vocab)
print('Train data:', len(train_data), 'sentences.')
print('Valid data:', len(valid_data))

print(' '.join([word_vocab.itos[i.item()] for i in train_data[0][0]]))
print(' '.join([label_vocab.itos[i.item()] for i in train_data[0][1]]))

print(' '.join([word_vocab.itos[i.item()] for i in valid_data[1][0]]))
print(' '.join([label_vocab.itos[i.item()] for i in valid_data[1][1]]))

Train word vocab: 3947 symbols.
Train label vocab: 8 symbols: ['O', 'I-PER', 'I-ORG', 'I-LOC', 'I-MISC', 'B-MISC', 'B-ORG', 'B-LOC']
Train data: 3420 sentences.
Valid data: 800
Pusan 0000 0000 0000 0000 0000 0000
I-ORG O O O O O O
Earlier this month , <unk> denied a Kabul government statement that the two sides had agreed to a ceasefire in the north .
O O O O I-PER O O I-LOC O O O O O O O O O O O O O O O


## BiLSTMTagger

In [4]:
# Starter code implementing a BiLSTM Tagger
# which makes locally normalized, independent
# tag classifications at each time step

class BiLSTMTagger(nn.Module):
    def __init__(self, vocab_size, tag_vocab_size, embedding_dim, hidden_dim, dropout=0.3):
        super(BiLSTMTagger, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.tagset_size = tag_vocab_size
        self.word_embeds = nn.Embedding(vocab_size, embedding_dim).to(device)
        self.bilstm = nn.LSTM(embedding_dim, hidden_dim // 2,
                            num_layers=1, bidirectional=True, batch_first=True).to(device)
        self.tag_projection_layer = nn.Linear(hidden_dim, self.tagset_size).to(device)
        self.dropout = nn.Dropout(p=dropout)

    def init_hidden(self):
        return (torch.randn(2, 1, self.hidden_dim // 2).to(device),
                torch.randn(2, 1, self.hidden_dim // 2).to(device))

    def compute_lstm_emission_features(self, sentence):
        hidden = self.init_hidden()
        embeds = self.dropout(self.word_embeds(sentence))
        bilstm_out, hidden = self.bilstm(embeds, hidden)
        bilstm_out = self.dropout(bilstm_out)
        bilstm_out = bilstm_out
        bilstm_feats = self.tag_projection_layer(bilstm_out)
        return bilstm_feats

    def forward(self, sentence):
        bilstm_feats = self.compute_lstm_emission_features(sentence)

        return bilstm_feats.max(-1)[0].sum(), bilstm_feats.argmax(-1)

    def loss(self, sentence, tags):
        bilstm_feats = self.compute_lstm_emission_features(sentence)
        # transform predictions to (n_examples, n_classes) and ground truth to (n_examples)
        return torch.nn.functional.cross_entropy(
              bilstm_feats.view(-1, self.tagset_size), 
              tags.view(-1), 
              reduction='sum'
            )



## Train / Eval loop

In [5]:

def train(model, train_data, valid_data, word_vocab, label_vocab, epochs, log_interval=25):
    losses_per_epoch = []
    for epoch in range(epochs):
        print(f'--- EPOCH {epoch} ---')
        model.train()
        losses_per_epoch.append([])
        for i, (sent, tags) in enumerate(train_data):
            model.zero_grad()
            sent, tags = sent.to(device), tags.to(device)
            sent = sent.unsqueeze(0)
            tags = tags.unsqueeze(0)
            loss = model.loss(sent, tags)
            loss.backward()
            optimizer.step()

            losses_per_epoch[-1].append(loss.detach().cpu().item())
            if i > 0 and i % log_interval == 0:
                print(f'Avg loss over last {log_interval} updates: {np.mean(losses_per_epoch[-1][-log_interval:])}')

        evaluate(model, valid_data, word_vocab, label_vocab)


def evaluate(model, dataset, word_vocab, label_vocab):
    model.eval()
    losses = []
    scores = []
    true_tags = []
    pred_tags = []
    sents = []
    for i, (sent, tags) in enumerate(dataset):
        if i == 0:
            pass
        with torch.no_grad():
            sent, tags = sent.to(device), tags.to(device)
            sent = sent.unsqueeze(0)
            tags = tags.unsqueeze(0)
            losses.append(model.loss(sent, tags).cpu().detach().item())
            score, pred_tag_seq = model(sent)
            scores.append(score)
            try:
                true_tags.append([label_vocab.itos[i] for i in tags.tolist()[0]])
                pred_tags.append([ label_vocab.itos[0] if i == 9 else label_vocab.itos[i] for i in pred_tag_seq[0]])
            except:
                print(pred_tag_seq)
                pass

            sents.append([word_vocab.itos[i] for i in sent[0]])
    
    print('Avg evaluation loss:', np.mean(losses))
    a = [tag for tags in true_tags for tag in tags]
    b = [tag for tags in pred_tags for tag in tags]
    print(conlleval.evaluate(a, b, verbose=True))

    scores_token = precision_recall_fscore_support(a,b, average=None, labels=label_vocab.itos)
    for i in range(len(label_vocab.itos)):
        print(label_vocab.itos[i] , "precision", scores_token[0][i], "recall", scores_token[1][i], "f1", scores_token[2][i], "count", scores_token[3][i])
#     print('\n5 random evaluation samples:')
#     for i in np.random.randint(0, len(sents), size=2):
#         print('SENT:', ' '.join(sents[i]))
#         print('TRUE:', ' '.join(true_tags[i]))
#         print('PRED:', ' '.join(pred_tags[i]))
    return sents, true_tags, pred_tags


## Training

In [6]:
# Train BiLSTM Tagger Baseline
model = BiLSTMTagger(len(word_vocab), len(label_vocab), 128, 256).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
train(model, train_data, valid_data, word_vocab, label_vocab, epochs=10, log_interval=500)

--- EPOCH 0 ---
Avg loss over last 500 updates: 9.35207405039668
Avg loss over last 500 updates: 8.168168832413853
Avg loss over last 500 updates: 6.62803889903985
Avg loss over last 500 updates: 5.7294120835699145
Avg loss over last 500 updates: 5.023957918460248
Avg loss over last 500 updates: 5.084533827416599
Avg evaluation loss: 4.702194057154993
processed 11170 tokens with 1231 phrases; found: 743 phrases; correct: 471.
accuracy:  44.05%; (non-O)
accuracy:  89.64%; precision:  63.39%; recall:  38.26%; FB1:  47.72
              LOC: precision:  80.43%; recall:  40.77%; FB1:  54.11  184
             MISC: precision:  53.23%; recall:  17.19%; FB1:  25.98  62
              ORG: precision:  60.31%; recall:  38.11%; FB1:  46.71  194
              PER: precision:  57.10%; recall:  46.88%; FB1:  51.49  303
(63.39165545087483, 38.26157595450853, 47.72036474164134)
O precision 0.9116366514941618 recall 0.9849262347658756 f1 0.9468653648509764 count 9354
I-PER precision 0.7270992366412213 r

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Avg loss over last 500 updates: 4.534435162678128
Avg loss over last 500 updates: 4.958389579779817
Avg loss over last 500 updates: 4.050141175725388
Avg loss over last 500 updates: 3.8272592019468576
Avg loss over last 500 updates: 3.3551137455590068
Avg loss over last 500 updates: 3.4140576019254003
Avg evaluation loss: 3.7935850188352833
processed 11170 tokens with 1231 phrases; found: 954 phrases; correct: 635.
accuracy:  57.98%; (non-O)
accuracy:  91.76%; precision:  66.56%; recall:  51.58%; FB1:  58.12
              LOC: precision:  85.19%; recall:  57.02%; FB1:  68.32  243
             MISC: precision:  72.90%; recall:  40.62%; FB1:  52.17  107
              ORG: precision:  51.85%; recall:  45.60%; FB1:  48.53  270
              PER: precision:  62.87%; recall:  56.91%; FB1:  59.74  334
(66.56184486373165, 51.58407798537774, 58.12356979405035)
O precision 0.9376083188908145 recall 0.9832157365832799 f1 0.9598705839377968 count 9354
I-PER precision 0.784965034965035 recall 0.703

In [7]:
label_vocab.__dict__

{'freqs': Counter({'I-ORG': 2258,
          'O': 38899,
          'I-PER': 2544,
          'I-LOC': 1836,
          'I-MISC': 1011,
          'B-MISC': 9,
          'B-ORG': 5,
          'B-LOC': 3}),
 'itos': ['O',
  'I-PER',
  'I-ORG',
  'I-LOC',
  'I-MISC',
  'B-MISC',
  'B-ORG',
  'B-LOC'],
 'unk_index': None,
 'stoi': defaultdict(None,
             {'O': 0,
              'I-PER': 1,
              'I-ORG': 2,
              'I-LOC': 3,
              'I-MISC': 4,
              'B-MISC': 5,
              'B-ORG': 6,
              'B-LOC': 7}),
 'vectors': None}

In [9]:
sents, true_tags, pred_tags = evaluate(model, valid_data, word_vocab, label_vocab )


Avg evaluation loss: 3.480769702887892
processed 11170 tokens with 1231 phrases; found: 1117 phrases; correct: 844.
accuracy:  72.69%; (non-O)
accuracy:  94.32%; precision:  75.56%; recall:  68.56%; FB1:  71.89
              LOC: precision:  87.50%; recall:  77.13%; FB1:  81.99  320
             MISC: precision:  79.58%; recall:  58.85%; FB1:  67.66  142
              ORG: precision:  68.05%; recall:  58.96%; FB1:  63.18  266
              PER: precision:  69.41%; recall:  73.17%; FB1:  71.24  389
(75.55953446732319, 68.5621445978879, 71.89097103918229)
O precision 0.9605962681121651 recall 0.9851400470387001 f1 0.9727133583153007 count 9354
I-PER precision 0.8143074581430746 recall 0.8385579937304075 f1 0.8262548262548264 count 638
I-ORG precision 0.8198924731182796 recall 0.6224489795918368 f1 0.7076566125290024 count 490
I-LOC precision 0.8888888888888888 recall 0.780952380952381 f1 0.8314321926489227 count 420
I-MISC precision 0.8539325842696629 recall 0.5779467680608364 f1 0.68934

In [10]:
a = []
b = []
for i,sent in enumerate(sents):
    for j,word in enumerate(sent):
        if word == "<unk>":
            a.append(true_tags[i][j])
            b.append(pred_tags[i][j])

In [18]:
scores_token = precision_recall_fscore_support(a,b, average=None, labels=label_vocab.itos)

for i in range(len(label_vocab.itos)):
    print(label_vocab.itos[i] , "precision", round(scores_token[0][i] * 100,2), 
          "recall", round(scores_token[1][i] * 100, 2), "f1", round(scores_token[2][i] * 100, 2), "count", scores_token[3][i])


O precision 81.05 recall 90.85 f1 85.67 count 1431
I-PER precision 76.38 recall 79.96 f1 78.13 count 449
I-ORG precision 69.83 recall 51.44 f1 59.24 count 243
I-LOC precision 75.51 recall 36.63 f1 49.33 count 101
I-MISC precision 52.63 recall 10.87 f1 18.02 count 92
B-MISC precision 0.0 recall 0.0 f1 0.0 count 1
B-ORG precision 0.0 recall 0.0 f1 0.0 count 0
B-LOC precision 0.0 recall 0.0 f1 0.0 count 4


In [13]:
scores_token = precision_recall_fscore_support(a,b)
