# CSE 291 Assignment 2 BiLSTM CRF

## Download Data/Eval Script

In [1]:
# !wget https://raw.githubusercontent.com/sighsmile/conlleval/master/conlleval.py
# !wget https://raw.githubusercontent.com/tberg12/cse291spr21/main/assignment2/train.data.quad
# !wget https://raw.githubusercontent.com/tberg12/cse291spr21/main/assignment2/dev.data.quad

In [2]:
# !pip install pytorch-transformers
!pip install -U numpy

You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m


In [3]:
import torch
from pytorch_transformers import BertTokenizer
from pytorch_transformers import BertModel## Load pretrained model/tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


In [4]:

bertmodel = BertModel.from_pretrained('bert-base-uncased',
                                  output_hidden_states = True, # Whether the model returns all hidden-states.
                                  )


def get_word_embedding(marked_text):

    # Tokenize our sentence with the BERT tokenizer.
    tokenized_text = [word.lower() for word in marked_text.split(" ")]

    # Map the token strings to their vocabulary indeces.
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    
    segments_ids = [1] * len(tokenized_text)

    tokens_tensor = torch.tensor([indexed_tokens])
    
    segments_tensors = torch.tensor([segments_ids])
    
    with torch.no_grad():

        outputs = bertmodel(tokens_tensor, segments_tensors)

        # Evaluating the model will return a different number of objects based on 
        # how it's  configured in the `from_pretrained` call earlier. In this case, 
        # becase we set `output_hidden_states = True`, the third item will be the 
        # hidden states from all layers. See the documentation for more details:
        # https://huggingface.co/transformers/model_doc/bert.html#bertmodel
        hidden_states = outputs[2]
        
        token_embeddings = torch.stack(hidden_states, dim=0)
        
        token_embeddings = torch.squeeze(token_embeddings, dim=1)

        token_embeddings = token_embeddings.permute(1,0,2)

        token_vecs_sum = []

        # `token_embeddings` is a [22 x 12 x 768] tensor.

        # For each token in the sentence...
        for token in token_embeddings:

            # `token` is a [12 x 768] tensor

            # Sum the vectors from the last four layers.
            sum_vec = torch.sum(token[-4:], dim=0)

            # Use `sum_vec` to represent `token`.
            token_vecs_sum.append(sum_vec)
        
        
        return token_vecs_sum

In [5]:
import conlleval
from tqdm import tqdm
import numpy as np
from collections import defaultdict, Counter
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
from torchtext.vocab import Vocab
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import pdb 
from sklearn.metrics import precision_recall_fscore_support


torch.manual_seed(291)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


## Data Preparation

In [6]:
TRAIN_DATA = 'train.data.quad'
VALID_DATA = 'dev.data.quad'
UNK = '<unk>'
PAD = '<pad>'
START_TAG = "<start>"  # you can add this explicitly or use it implicitly in your CRF layer
STOP_TAG = "<stop>"    # you can add this explicitly or use it implicitly in your CRF layer


def read_conll_sentence(path):
    """ Read a CONLL-format sentence into vocab objects
    Args:
        :param path: path to CONLL-format data file
        :param word_vocab: Vocabulary object for source
        :param label_vocab: Vocabulary object for target
    """
    sent = [[], []]
    with open(path) as f:
        for line in f:
            line = line.strip().split()
            if line:
                # replace numbers with 0000
                word = line[0]
                word = '0000' if word.isnumeric() else word
                sent[0].append(word)
                sent[1].append(line[3])
            else:
                yield sent[0], sent[1]
                sent = [[], []]



def prepare_dataset(dataset, word_vocab, label_vocab):
    dataset = [
      [
        torch.tensor([word_vocab.stoi[word] for word in sent[0]], dtype=torch.long),
        torch.tensor([label_vocab.stoi[label] for label in sent[1]], dtype=torch.long),
      ]
      for sent in dataset
    ]
    return dataset


# load a list of sentences, where each word in the list is a tuple containing the word and the label
train_data = list(read_conll_sentence(TRAIN_DATA))
train_word_counter = Counter([word for sent in train_data for word in sent[0]])
train_label_counter = Counter([label for sent in train_data for label in sent[1]])
word_vocab = Vocab(train_word_counter, specials=(UNK, PAD), min_freq=2)
label_vocab = Vocab(train_label_counter, specials=(), min_freq=1)
print(train_data[0])

train_raw_sent = [" ".join(sent[0]) for sent in train_data]

train_data = prepare_dataset(train_data, word_vocab, label_vocab)
print('Train word vocab:', len(word_vocab), 'symbols.')
print('Train label vocab:', len(label_vocab), f'symbols: {list(label_vocab.stoi.keys())}')
valid_data = list(read_conll_sentence(VALID_DATA))

val_raw_sent = [" ".join(sent[0]) for sent in valid_data]

valid_data = prepare_dataset(valid_data, word_vocab, label_vocab)
print('Train data:', len(train_data), 'sentences.')
print('Valid data:', len(valid_data))

print(' '.join([word_vocab.itos[i.item()] for i in train_data[1][0]]))
print(' '.join([label_vocab.itos[i.item()] for i in train_data[1][1]]))

print(' '.join([word_vocab.itos[i.item()] for i in valid_data[1][0]]))
print(' '.join([label_vocab.itos[i.item()] for i in valid_data[1][1]]))

(['Pusan', '0000', '0000', '0000', '0000', '0000', '0000'], ['I-ORG', 'O', 'O', 'O', 'O', 'O', 'O'])
Train word vocab: 3947 symbols.
Train label vocab: 8 symbols: ['O', 'I-PER', 'I-ORG', 'I-LOC', 'I-MISC', 'B-MISC', 'B-ORG', 'B-LOC']
Train data: 3420 sentences.
Valid data: 800
Martin Brooks , president of Miss Universe Inc , said he spoke with Machado to <unk> her that organisers were not <unk> pressure on her .
I-PER I-PER O O O I-ORG I-ORG I-ORG O O O O O I-PER O O O O O O O O O O O O
Earlier this month , <unk> denied a Kabul government statement that the two sides had agreed to a ceasefire in the north .
O O O O I-PER O O I-LOC O O O O O O O O O O O O O O O


In [7]:
embeddingTrainDict = {}

for i, (sent, tags) in enumerate(train_data):
    
    embeddingTrainDict[i] = torch.vstack(get_word_embedding(train_raw_sent[i])).unsqueeze(0).to(device)

embeddingValDict = {}

for i, (sent, tags) in enumerate(valid_data):
    
    embeddingValDict[i] = torch.vstack(get_word_embedding(val_raw_sent[i])).unsqueeze(0).to(device)


In [8]:
train_raw_sent[1], tokenizer.tokenize(train_raw_sent[1])

('Martin Brooks , president of Miss Universe Inc , said he spoke with Machado to assure her that organisers were not putting pressure on her .',
 ['martin',
  'brooks',
  ',',
  'president',
  'of',
  'miss',
  'universe',
  'inc',
  ',',
  'said',
  'he',
  'spoke',
  'with',
  'mach',
  '##ado',
  'to',
  'assure',
  'her',
  'that',
  'organise',
  '##rs',
  'were',
  'not',
  'putting',
  'pressure',
  'on',
  'her',
  '.'])

In [9]:
embeddingTrainDict[0].shape

torch.Size([1, 7, 768])

In [10]:

def log_sum_exp(vec):
    max_score = vec.max()
    max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
    return max_score + torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))

## BiLSTMTagger

In [11]:
class CRF():
    def __init__(self, vocab_size, tag_vocab_size):
        
        self.vocab_size = vocab_size
        self.tagset_size = tag_vocab_size

        self.transitions = nn.Parameter(
            torch.randn(self.tagset_size, self.tagset_size)).to(device)

        self.start_index = tag_vocab_size - 2
        self.end_index = tag_vocab_size - 1
        
        self.transitions.data[self.start_index, :] = -10000
        self.transitions.data[:, self.end_index] = -10000
        
    
    def _viterbi_decode(self, feats):
        backtrace = []
        alpha = torch.full((1, self.tagset_size), -10000.0, device=device)
        alpha[0][self.start_index] = 0
        feats = feats.squeeze(0)
#         pdb.set_trace()
        for feat in feats:
            t2 = (self.transitions.T + (feat)).T + alpha
            t3 = t2.max(dim=1).values
            alpha = t3 + (t2.T - t3).T.exp().sum(dim=1).log().view(1,-1)
            backtrace.append(t2.argmax(dim=1))

        vec = alpha.T + self.transitions[:, [self.end_index]]
        best_tag = vec.squeeze(1).argmax().item()
        optimal_path = [best_tag]
        
        seq = reversed(backtrace[1:])
        
        for bp in seq: 
            best_tag = bp[best_tag].item()
            optimal_path.append(best_tag)
        
        score = (vec - vec.max()).exp().sum(axis=0, keepdim=True).log() + vec.max()
        return score,[ optimal_path[::-1]] 
    
     
    def _forward_alg(self, feats):
        init_alphas = torch.full((1, self.tagset_size), -10000.)

        init_alphas[0][self.start_index] = 0.


        forward_var = init_alphas.to(device)

        feats = feats.squeeze(0)

        for feat in feats:
            alphas_t = []  # The forward tensors at this timestep
            t2 = (self.transitions.T + (feat)).T + forward_var
            t3 = t2.max(dim=1).values
            forward_var = t3 + (t2.T - t3).T.exp().sum(dim=1).log().view(1,-1)

        terminal_var = forward_var + self.transitions[self.end_index]
        alpha = log_sum_exp(terminal_var)
        return alpha
    
    
    def _score_sentence(self, feats, tags):

        score = torch.zeros(1).to(device)
        feats = feats.squeeze(0)
        tags = tags.squeeze(0)
        tags = torch.cat([torch.tensor([self.start_index], dtype=torch.long).to(device), tags]).to(device)
        for i, feat in enumerate(feats):
            score = score + self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]]

        score = score + self.transitions[self.end_index, tags[-1]]
        return score


In [12]:
# Starter code implementing a BiLSTM Tagger
# which makes locally normalized, independent
# tag classifications at each time step

class BiLSTMTagger(nn.Module):
    def __init__(self, vocab_size, tag_vocab_size, embedding_dim, hidden_dim, dropout=0.3):
        super(BiLSTMTagger, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.tagset_size = tag_vocab_size
#         self.word_embeds = nn.Embedding(vocab_size, embedding_dim).to(device)
        self.bilstm = nn.LSTM(embedding_dim, hidden_dim // 2,
                            num_layers=1, bidirectional=True, batch_first=True).to(device)
        self.tag_projection_layer = nn.Linear(hidden_dim, self.tagset_size).to(device)
        self.dropout = nn.Dropout(p=dropout)

        self.crf = CRF(vocab_size, tag_vocab_size)


    def init_hidden(self):
        return (torch.randn(2, 1, self.hidden_dim // 2).to(device),
                torch.randn(2, 1, self.hidden_dim // 2).to(device))

    def compute_lstm_emission_features(self, sentence, embedding):
        hidden = self.init_hidden()
        embeds = self.dropout(embedding)
        bilstm_out, hidden = self.bilstm(embeds, hidden)
        bilstm_out = self.dropout(bilstm_out)
        bilstm_out = bilstm_out
        bilstm_feats = self.tag_projection_layer(bilstm_out)
        return bilstm_feats

     
    def forward(self, sentence, embedding):
        bilstm_feats = self.compute_lstm_emission_features(sentence, embedding)

        return self.crf._viterbi_decode(bilstm_feats)
  
    def loss(self, sentence, tags, embedding):
        feats = self.compute_lstm_emission_features(sentence, embedding)
        forward_score = self.crf._forward_alg(feats)
        gold_score = self.crf._score_sentence(feats, tags)
        return forward_score - gold_score

    

## Train / Eval loop

In [13]:
def train(model, train_data, valid_data, word_vocab, label_vocab, epochs, log_interval=25):
    losses_per_epoch = []
    for epoch in range(epochs):
        print(f'--- EPOCH {epoch} ---')
        model.train()
        losses_per_epoch.append([])
        for i, (sent, tags) in enumerate(train_data):
            model.zero_grad()
            sent, tags = sent.to(device), tags.to(device)
            sent = sent.unsqueeze(0)
            tags = tags.unsqueeze(0)
            loss = model.loss(sent, tags, embeddingTrainDict[i])
            loss.backward()
            optimizer.step()

            losses_per_epoch[-1].append(loss.detach().cpu().item())
            if i > 0 and i % log_interval == 0:
                print(f'Avg loss over last {log_interval} updates: {np.mean(losses_per_epoch[-1][-log_interval:])}')

        evaluate(model, valid_data, word_vocab, label_vocab)


def evaluate(model, dataset, word_vocab, label_vocab):
    model.eval()
    losses = []
    scores = []
    true_tags = []
    pred_tags = []
    sents = []
    for i, (sent, tags) in enumerate(dataset):
        if i == 0:
            pass
        with torch.no_grad():
            sent, tags = sent.to(device), tags.to(device)
            sent = sent.unsqueeze(0)
            tags = tags.unsqueeze(0)
            losses.append(model.loss(sent, tags, embeddingValDict[i]).cpu().detach().item())
            score, pred_tag_seq = model(sent, embeddingValDict[i])
            scores.append(score)
            try:
                true_tags.append([label_vocab.itos[i] for i in tags.tolist()[0]])
                pred_tags.append([ label_vocab.itos[0] if i == 9 else label_vocab.itos[i] for i in pred_tag_seq[0]])
            except:
                print(pred_tag_seq)
                pass

            sents.append([word_vocab.itos[i] for i in sent[0]])

    print('Avg evaluation loss:', np.mean(losses))
    a = [tag for tags in true_tags for tag in tags]
    b = [tag for tags in pred_tags for tag in tags]
    print(conlleval.evaluate(a, b, verbose=True))


    scores_token = precision_recall_fscore_support(a,b, average=None, labels=label_vocab.itos)
    for i in range(len(label_vocab.itos)):
        print(label_vocab.itos[i] , "precision", scores_token[0][i], "recall", scores_token[1][i], "f1", scores_token[2][i], "count", scores_token[3][i])

#     print('\n5 random evaluation samples:')
#     for i in np.random.randint(0, len(sents), size=2):
#         print('SENT:', ' '.join(sents[i]))
#         print('TRUE:', ' '.join(true_tags[i]))
#         print('PRED:', ' '.join(pred_tags[i]))
    return sents, true_tags, pred_tags


## Training

In [14]:
# Train BiLSTM Tagger Baseline
model = BiLSTMTagger(len(word_vocab), len(label_vocab)+2, 768, 256).to(device)


optimizer = optim.Adam(model.parameters(), lr=1e-3)
train(model, train_data, valid_data, word_vocab, label_vocab, epochs=10, log_interval=500)

--- EPOCH 0 ---
Avg loss over last 500 updates: 7.185383933067322
Avg loss over last 500 updates: 5.238719256401062
Avg loss over last 500 updates: 4.438768275260926
Avg loss over last 500 updates: 4.228360407352447
Avg loss over last 500 updates: 4.062265147686005
Avg loss over last 500 updates: 4.0703121407032015
Avg evaluation loss: 4.02682015940547
processed 11170 tokens with 1231 phrases; found: 829 phrases; correct: 523.
accuracy:  51.10%; (non-O)
accuracy:  91.32%; precision:  63.09%; recall:  42.49%; FB1:  50.78
              LOC: precision:  81.17%; recall:  53.44%; FB1:  64.45  239
             MISC: precision:  63.64%; recall:  25.52%; FB1:  36.43  77
              ORG: precision:  43.59%; recall:  27.69%; FB1:  33.86  195
              PER: precision:  61.32%; recall:  52.85%; FB1:  56.77  318
(63.08805790108565, 42.48578391551584, 50.77669902912621)
O precision 0.9247033010870649 recall 0.9912336968141972 f1 0.9568133739229142 count 9354
I-PER precision 0.8476953907815631 

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Avg loss over last 500 updates: 3.7595493021011355
Avg loss over last 500 updates: 4.0923817048072815
Avg loss over last 500 updates: 3.6457367897033692
Avg loss over last 500 updates: 3.5603699798583985
Avg loss over last 500 updates: 3.51646524477005
Avg loss over last 500 updates: 3.605425795793533
Avg evaluation loss: 3.8896447882056235
processed 11170 tokens with 1231 phrases; found: 858 phrases; correct: 556.
accuracy:  53.80%; (non-O)
accuracy:  91.86%; precision:  64.80%; recall:  45.17%; FB1:  53.23
              LOC: precision:  82.86%; recall:  55.92%; FB1:  66.78  245
             MISC: precision:  58.76%; recall:  29.69%; FB1:  39.45  97
              ORG: precision:  47.92%; recall:  29.97%; FB1:  36.87  192
              PER: precision:  62.96%; recall:  55.28%; FB1:  58.87  324
(64.80186480186481, 45.166531275385864, 53.231211105792255)
O precision 0.927936031984008 recall 0.9925165704511439 f1 0.9591404514696007 count 9354
I-PER precision 0.8772277227722772 recall 0.69

In [15]:
evaluate(model, valid_data, word_vocab, label_vocab)


Avg evaluation loss: 3.4975596888363363
processed 11170 tokens with 1231 phrases; found: 983 phrases; correct: 664.
accuracy:  61.34%; (non-O)
accuracy:  92.77%; precision:  67.55%; recall:  53.94%; FB1:  59.98
              LOC: precision:  77.52%; recall:  63.64%; FB1:  69.89  298
             MISC: precision:  56.19%; recall:  30.73%; FB1:  39.73  105
              ORG: precision:  55.61%; recall:  40.39%; FB1:  46.79  223
              PER: precision:  70.03%; recall:  67.75%; FB1:  68.87  357
(67.54832146490337, 53.93988627132412, 59.98193315266486)
O precision 0.9426154316583427 recall 0.9886679495403036 f1 0.9650926167492826 count 9354
I-PER precision 0.8595600676818951 recall 0.7962382445141066 f1 0.8266883645240032 count 638
I-ORG precision 0.75 recall 0.4959183673469388 f1 0.5970515970515972 count 490
I-LOC precision 0.8368580060422961 recall 0.6595238095238095 f1 0.7376830892143809 count 420
I-MISC precision 0.7610619469026548 recall 0.3269961977186312 f1 0.4574468085106383 

([['-DOCSTART-'],
  ['Earlier',
   'this',
   'month',
   ',',
   '<unk>',
   'denied',
   'a',
   'Kabul',
   'government',
   'statement',
   'that',
   'the',
   'two',
   'sides',
   'had',
   'agreed',
   'to',
   'a',
   'ceasefire',
   'in',
   'the',
   'north',
   '.'],
  ['0000', '<unk>', '<unk>', '0000', '0000', '0000', '0000'],
  ['9.', '<unk>', '<unk>', '(', 'China', ')', '<unk>'],
  ['<unk>', '0000', '0000', '0000', '0000', '0000', '0000', '0000'],
  ['The',
   'court',
   '<unk>',
   '<unk>',
   "'s",
   '<unk>',
   'that',
   '<unk>',
   "'s",
   '<unk>',
   'from',
   'Denmark',
   ',',
   'where',
   'he',
   'was',
   'arrested',
   'in',
   'March',
   'last',
   'year',
   'at',
   'the',
   'request',
   'of',
   'German',
   'authorities',
   ',',
   'was',
   'illegal',
   '.'],
  ['<unk>',
   'Sudan',
   'plane',
   'expected',
   'at',
   'London',
   "'s",
   'Stansted',
   '.'],
  ['The',
   'town',
   "'s",
   '<unk>',
   '<unk>',
   'ground',
   'is',
   '