# Project #2: Named Entity Recognition

In this assignment, you will implement a bidirectional LSTM-CNN-CRF for sequence labeling, following [this paper by Xuezhe Ma and Ed Hovy](https://www.aclweb.org/anthology/P16-1101.pdf), on the CoNLL named entity recognition dataset.  Before starting the assignment, we recommend reading the Ma and Hovy paper.

First, let's import some libraries and make sure the runtime has access to a GPU.


In [2]:
# Licensing Information:  You are free to use or extend this project for
# educational purposes provided that (1) you do not distribute or publish
# solutions, (2) you retain this notice, and (3) you provide clear
# attribution to The Georgia Institute of Technology, including a link to  https://aritter.github.io/CS-7650-sp22/

# Attribution Information: This assignment was developed at The Georgia Institute of Technology
# by Alan Ritter (alan.ritter@cc.gatech.edu)
# Contributors: Xurui Zhang (Spring 2022)

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim

gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
    print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
    print('and then re-execute this cell.')
else:
    print(gpu_info)

print(f'GPU available: {torch.cuda.is_available()}')

Tue Mar 14 02:21:22 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   57C    P0    28W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Download the Data

Run the following code to download the English part of the CoNLL 2003 dataset, the evaluation script and pre-filtered GloVe embeddings we are providing for this data.

In [4]:
#CoNLL 2003 data
!wget https://raw.githubusercontent.com/patverga/torch-ner-nlp-from-scratch/master/data/conll2003/eng.train
!wget https://raw.githubusercontent.com/patverga/torch-ner-nlp-from-scratch/master/data/conll2003/eng.testa
!wget https://raw.githubusercontent.com/patverga/torch-ner-nlp-from-scratch/master/data/conll2003/eng.testb
!cat eng.train | awk '{print $1 "\t" $4}' > train
!cat eng.testa | awk '{print $1 "\t" $4}' > dev
!cat eng.testb | awk '{print $1 "\t" $4}' > test

#Evaluation Script
!wget https://raw.githubusercontent.com/aritter/twitter_nlp/master/data/annotated/wnut16/conlleval.pl

#Pre-filtered GloVe embeddings
!wget https://raw.githubusercontent.com/aritter/aritter.github.io/master/files/glove.840B.300d.conll_filtered.txt

--2023-03-14 02:21:28--  https://raw.githubusercontent.com/patverga/torch-ner-nlp-from-scratch/master/data/conll2003/eng.train
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3283420 (3.1M) [text/plain]
Saving to: ‘eng.train’


2023-03-14 02:21:28 (125 MB/s) - ‘eng.train’ saved [3283420/3283420]

--2023-03-14 02:21:28--  https://raw.githubusercontent.com/patverga/torch-ner-nlp-from-scratch/master/data/conll2003/eng.testa
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 827443 (808K) [text/plain]
Saving to: ‘eng.testa’


2023-03-14 

## CoNLL Data Format

Run the following cell to see a sample of the data in CoNLL format.  As you can see, each line in the file represents a word and its labeled named entity tag in BIO format.  A blank line is used to seperate sentences.

In [5]:
!head -n 20 train

-DOCSTART-	O
	
EU	I-ORG
rejects	O
German	I-MISC
call	O
to	O
boycott	O
British	I-MISC
lamb	O
.	O
	
Peter	I-PER
Blackburn	I-PER
	
BRUSSELS	I-LOC
1996-08-22	O
	
The	O
European	I-ORG


## Reading in the Data

Below we provide a bit of code to read in data in the CoNLL format.  This also reads in the filtered GloVe embeddings, to save you some effort - we will discuss this more later.

In [6]:
#Read in the training data
def read_conll_format(filename):
    (words, tags, currentSent, currentTags) = ([],[],['-START-'],['START'])
    for line in open(filename).readlines():
        line = line.strip()
        #print(line)
        if line == "":
            currentSent.append('-END-')
            currentTags.append('END')
            words.append(currentSent)
            tags.append(currentTags)
            (currentSent, currentTags) = (['-START-'], ['START'])
        else:
            (word, tag) = line.split()
            currentSent.append(word)
            currentTags.append(tag)
    return (words, tags)

def sentences2char(sentences):
    return [[['start'] + [c for c in w] + ['end'] for w in l] for l in sentences]


(sentences_train, tags_train) = read_conll_format("train")
(sentences_dev, tags_dev)     = read_conll_format("dev")

print(sentences_train[2])
print(tags_train[2])

sentencesChar = sentences2char(sentences_train)

print(sentencesChar[2])

['-START-', 'Peter', 'Blackburn', '-END-']
['START', 'I-PER', 'I-PER', 'END']
[['start', '-', 'S', 'T', 'A', 'R', 'T', '-', 'end'], ['start', 'P', 'e', 't', 'e', 'r', 'end'], ['start', 'B', 'l', 'a', 'c', 'k', 'b', 'u', 'r', 'n', 'end'], ['start', '-', 'E', 'N', 'D', '-', 'end']]


In [7]:
#Read GloVe embeddings.
def read_GloVe(filename):
    embeddings = {}
    for line in open(filename).readlines():
        #print(line)
        fields = line.strip().split(" ")
        word = fields[0]
        embeddings[word] = [float(x) for x in fields[1:]]
    return embeddings

GloVe = read_GloVe("glove.840B.300d.conll_filtered.txt")

print(GloVe["the"])
print("dimension of glove embedding:", len(GloVe["the"]))

[0.27204, -0.06203, -0.1884, 0.023225, -0.018158, 0.0067192, -0.13877, 0.17708, 0.17709, 2.5882, -0.35179, -0.17312, 0.43285, -0.10708, 0.15006, -0.19982, -0.19093, 1.1871, -0.16207, -0.23538, 0.003664, -0.19156, -0.085662, 0.039199, -0.066449, -0.04209, -0.19122, 0.011679, -0.37138, 0.21886, 0.0011423, 0.4319, -0.14205, 0.38059, 0.30654, 0.020167, -0.18316, -0.0065186, -0.0080549, -0.12063, 0.027507, 0.29839, -0.22896, -0.22882, 0.14671, -0.076301, -0.1268, -0.0066651, -0.052795, 0.14258, 0.1561, 0.05551, -0.16149, 0.09629, -0.076533, -0.049971, -0.010195, -0.047641, -0.16679, -0.2394, 0.0050141, -0.049175, 0.013338, 0.41923, -0.10104, 0.015111, -0.077706, -0.13471, 0.119, 0.10802, 0.21061, -0.051904, 0.18527, 0.17856, 0.041293, -0.014385, -0.082567, -0.035483, -0.076173, -0.045367, 0.089281, 0.33672, -0.22099, -0.0067275, 0.23983, -0.23147, -0.88592, 0.091297, -0.012123, 0.013233, -0.25799, -0.02972, 0.016754, 0.01369, 0.32377, 0.039546, 0.042114, -0.088243, 0.30318, 0.087747, 0.1634

## Mapping Tokens to Indices

As in the last project, we will need to convert words in the dataset to numeric indices, so they can be presented as input to a neural network.  Code to handle this for you with sample usage is provided below.

In [8]:
#Create mappings between tokens and indices.

from collections import Counter
import random

#Will need this later to remove 50% of words that only appear once in the training data from the vocabulary (and don't have GloVe embeddings).
wordCounts = Counter([w for l in sentences_train for w in l])
charCounts = Counter([c for l in sentences_train for w in l for c in w])
singletons = set([w for (w,c) in wordCounts.items() if c == 1 and not w in GloVe.keys()])
charSingletons = set([w for (w,c) in charCounts.items() if c == 1])

#Build dictionaries to map from words, characters to indices and vice versa.
#Save first two words in the vocabulary for padding and "UNK" token.
word2i = {w:i+2 for i,w in enumerate(set([w for l in sentences_train for w in l] + list(GloVe.keys())))}
char2i = {w:i+2 for i,w in enumerate(set([c for l in sentencesChar for w in l for c in w]))}
i2word = {i:w for w,i in word2i.items()}
i2char = {i:w for w,i in char2i.items()}

vocab_size = max(word2i.values()) + 1
char_vocab_size = max(char2i.values()) + 1

#Tag dictionaries.
tag2i = {w:i for i,w in enumerate(set([t for l in tags_train for t in l]))}
i2tag = {i:t for t,i in tag2i.items()}

#When training, randomly replace singletons with UNK tokens sometimes to simulate situation at test time.
def getDictionaryRandomUnk(w, dictionary, train=False):
    if train and (w in singletons and random.random() > 0.5):
        return 1
    else:
        return dictionary.get(w, 1)

#Map a list of sentences from words to indices.
def sentences2indices(words, dictionary, train=False):
    #1.0 => UNK
    return [[getDictionaryRandomUnk(w,dictionary, train=train) for w in l] for l in words]

#Map a list of sentences containing to indices (character indices)
def sentences2indicesChar(chars, dictionary):
    #1.0 => UNK
    return [[[dictionary.get(c,1) for c in w] for w in l] for l in chars]

#Indices
X       = sentences2indices(sentences_train, word2i, train=True)
X_char  = sentences2indicesChar(sentencesChar, char2i)
Y       = sentences2indices(tags_train, tag2i)

print("vocab size:", vocab_size)
print("char vocab size:", char_vocab_size)
print()

print("index of word 'the':", word2i["the"])
print("word of index 253:", i2word[253])
print()

#Print out some examples of what the dev inputs will look like
for i in range(10):
    print(" ".join([i2word.get(w,'UNK') for w in X[i]]))

vocab size: 29148
char vocab size: 88

index of word 'the': 27702
word of index 253: Commons

-START- -DOCSTART- -END-
-START- EU rejects German call to boycott British lamb . -END-
-START- Peter Blackburn -END-
-START- BRUSSELS 1996-08-22 -END-
-START- The European Commission said on Thursday it disagreed with German advice to consumers to shun British lamb until scientists determine whether mad cow disease can be transmitted to sheep . -END-
-START- Germany 's representative to the European Union 's veterinary committee Werner Zwingmann said on Wednesday consumers should buy sheepmeat from countries other than Britain until the scientific advice was clearer . -END-
-START- " We do n't support any such recommendation because we do n't see any grounds for it , " the Commission 's chief spokesman Nikolaus van der Pas told a news briefing . -END-
-START- He said further scientific study was required and if it was found that action was needed it should be taken by the European Union . -EN

## Padding and Batching

In this assignment, you should train your models using minibatched SGD, rather than using a batch size of 1 as we did in the previous project.  When presenting multiple sentences to the network at the same time, we will need to pad them to be of the same length. We use [torch.nn.utils.rnn.pad_sequence](https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pad_sequence.html) to do so.

Below we provide some code to prepare batches of data to present to the network.  pad the sequence so that all the sequences have the same length.

**Side Note:** PyTorch includes utilities in [`torch.utils.data`](https://pytorch.org/docs/stable/data.html) to help with padding, batching, shuffling and some other things, but for this assignment we will do everything from scratch to help you see exactly how this works.

In [9]:
#Pad inputs to max sequence length (for batching)
def prepare_input(X_list):
    X_padded = torch.nn.utils.rnn.pad_sequence([torch.as_tensor(l) for l in X_list], batch_first=True).type(torch.LongTensor) # padding the sequences with 0
    X_mask   = torch.nn.utils.rnn.pad_sequence([torch.as_tensor([1.0] * len(l)) for l in X_list], batch_first=True).type(torch.FloatTensor) # consisting of 0 and 1, 0 for padded positions, 1 for non-padded positions
    return (X_padded, X_mask)

#Maximum word length (for character representations)
MAX_CLEN=32

def prepare_input_char(X_list):
    MAX_SLEN = max([len(l) for l in X_list])
    X_padded  = [l + [[]]*(MAX_SLEN-len(l))  for l in X_list]
    X_padded  = [[w[0:MAX_CLEN] for w in l] for l in X_padded]
    X_padded  = [[w + [1]*(MAX_CLEN-len(w)) for w in l] for l in X_padded]
    return torch.as_tensor(X_padded).type(torch.LongTensor)

#Pad outputs using one-hot encoding
def prepare_output_onehot(Y_list, NUM_TAGS=max(tag2i.values())+1):
    Y_onehot = [torch.zeros(len(l), NUM_TAGS) for l in Y_list]
    for i in range(len(Y_list)):
        for j in range(len(Y_list[i])):
            Y_onehot[i][j,Y_list[i][j]] = 1.0
    Y_padded = torch.nn.utils.rnn.pad_sequence(Y_onehot, batch_first=True).type(torch.FloatTensor)
    return Y_padded

print("max slen:", max([len(x) for x in X_char]))  #Max sequence length in the training data is 39.

(X_padded, X_mask) = prepare_input(X)
X_padded_char      = prepare_input_char(X_char)
Y_onehot           = prepare_output_onehot(Y)

print("X_padded:", X_padded.shape)
print("X_mask:", X_mask.shape)
print("X_padded_char:", X_padded_char.shape)
print("Y_onehot:", Y_onehot.shape)

max slen: 115
X_padded: torch.Size([14987, 115])
X_mask: torch.Size([14987, 115])
X_padded_char: torch.Size([14987, 115, 32])
Y_onehot: torch.Size([14987, 115, 10])


## **Your code starts here:** Basic LSTM Tagger (10 points)

OK, now you should have everything you need to get started.

Recall that your goal is to to implement the BiLSTM-CNN-CRF, as described in [(Ma and Hovy, 2016)](https://www.aclweb.org/anthology/P16-1101.pdf).  This is a relatively complex network with various components.  Below we provide starter code to break down your implementation into increasingly complex versions of the final model, starting with a Basic LSTM tagger.  This way you can be confident that each part is working correctly before incrementally increasing the complexity of your implementation.  This is generally a good approach to take when implementing complex models, since buggy PyTorch code is often partially working, but produces worse results than a correct implementation, so it's hard to know whether added complexities are helping or hurting.  Also, if you aren't able to match published results it's hard to know which component of your model has the problem (or even whether or not it is a problem in the published result!)

Fill in the functions marked as `TODO` in the code block below.  If everything is working correctly, you should be able to achieve an **F1 score of 0.87 on the dev set and 0.83 on the test set (with GloVe embeddings)**. You are required to initialize word embeddings with GloVe later, but you can randomly initialize the word embeddings in the beginning.

In [10]:
class BasicLSTMtagger(nn.Module):
    def __init__(self, DIM_EMB=10, DIM_HID=10):
        super(BasicLSTMtagger, self).__init__()
        NUM_TAGS = max(tag2i.values())+1

        (self.DIM_EMB, self.NUM_TAGS) = (DIM_EMB, NUM_TAGS)
        #TODO: initialize parameters - embedding layer, nn.LSTM, nn.Linear and nn.LogSoftmax
        self.embedding = nn.Embedding(vocab_size, DIM_EMB)
        self.embedding = self.embedding.to('cuda')
        self.LSTM = nn.LSTM(DIM_EMB, DIM_HID, bidirectional = True, batch_first = True)
        self.LSTM = self.LSTM.to('cuda')
        self.Linear = nn.Linear(2 * DIM_HID, NUM_TAGS)
        self.Linear = self.Linear.to('cuda')
        self.LogSoftmax = nn.LogSoftmax(dim = 2)
        self.LogSoftmax = self.LogSoftmax.to('cuda')

    def forward(self, X, train=False):
        #TODO: Implement the forward computation.
        word_embeddings = self.embedding(X.to('cuda'))
        word_embeddings = word_embeddings.to('cuda')
        lstm_out, _ = self.LSTM(word_embeddings)
        predictions = self.Linear(lstm_out)
        logsoftmax = self.LogSoftmax(predictions)
        return logsoftmax
        #return torch.randn((X.shape[0], X.shape[1], self.NUM_TAGS))  #Random baseline.

    def init_glove(self, GloVe):
        #TODO: initialize word embeddings using GloVe (you can skip this part in your first version, if you want, see instructions below).
        tmp = torch.zeros((vocab_size, self.DIM_EMB)).uniform_(-1, 1).to('cuda')
        for word, embedding in GloVe.items():
          tmp[word2i[word], :] = torch.FloatTensor(embedding).to('cuda')
        self.embedding = nn.Embedding.from_pretrained(tmp, freeze=False)
        #print("Successful Init Glove")
        #pass

    def inference(self, sentences):
        X       = prepare_input(sentences2indices(sentences, word2i))[0].cuda()
        pred = self.forward(X).argmax(dim=2)
        return [[i2tag[pred[i,j].item()] for j in range(len(sentences[i]))] for i in range(len(sentences))]

    def print_predictions(self, words, tags):
        Y_pred = self.inference(words)
        for i in range(len(words)):
            print("----------------------------")
            print(" ".join([f"{words[i][j]}/{Y_pred[i][j]}/{tags[i][j]}" for j in range(len(words[i]))]))
            print("Predicted:\t", Y_pred[i])
            print("Gold:\t\t", tags[i])

    def write_predictions(self, sentences, outFile):
        fOut = open(outFile, 'w')
        for s in sentences:
            y = self.inference([s])[0]
            #print("\n".join(y[1:len(y)-1]))
            fOut.write("\n".join(y[1:len(y)-1]))  #Skip start and end tokens
            fOut.write("\n\n")

#The following code will initialize a model and test that your forward computation runs without errors.
lstm_test   = BasicLSTMtagger(DIM_HID=7, DIM_EMB=300)
lstm_test.init_glove(GloVe)
#print((prepare_input(X[0:5])[0]).shape)
#print(len(X))
lstm_output = lstm_test.forward(prepare_input(X[0:5])[0])
Y_onehot    = prepare_output_onehot(Y[0:5])

#Check the shape of the lstm_output and one-hot label tensors.
print("lstm output shape:", lstm_output.shape)
print("Y onehot shape:", Y_onehot.shape)

lstm output shape: torch.Size([5, 32, 10])
Y onehot shape: torch.Size([5, 32, 10])


In [11]:
#Read in the data

(sentences_dev, tags_dev)     = read_conll_format('dev')
(sentences_train, tags_train) = read_conll_format('train')
(sentences_test, tags_test)   = read_conll_format('test')

# Train your Model (10 points)

Next, implement the function below to train your basic BiLSTM tagger.  See [torch.nn.lstm](https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html).  Make sure to save your predictions on the test set (`test_pred_lstm.txt`) for submission to GradeScope. Feel free to change number of epochs, optimizer, learning rate and batch size.

In [12]:
#Training

from random import sample
import tqdm
import os
import subprocess
import random

def shuffle_sentences(sentences, tags):
    shuffled_sentences = []
    shuffled_tags      = []
    indices = list(range(len(sentences)))
    random.shuffle(indices)
    for i in indices:
        shuffled_sentences.append(sentences[i])
        shuffled_tags.append(tags[i])
    return (shuffled_sentences, shuffled_tags)

nEpochs = 10

def train_basic_lstm(sentences, tags, lstm):
  optimizer = optim.Adadelta(lstm.parameters(), lr=1)
  #TODO: initialize optimizer
  batchSize = 50
  #loss_function = torch.nn.NLLLoss()
  totalLoss = 0.0
  for epoch in range(nEpochs):
      (sentences_shuffled, tags_shuffled) = shuffle_sentences(sentences, tags)
      for batch in tqdm.notebook.tqdm(range(0, len(sentences), batchSize), leave=False):
          #TODO: Implement gradient update.
          lstm.zero_grad()
          inputTensor = prepare_input(sentences2indices(sentences_shuffled[batch:batch + batchSize], word2i))[0]
          inputTensor = inputTensor.to('cuda')
          lstm_log_probs = lstm.forward(inputTensor)
          Y_one_hot = prepare_output_onehot(sentences2indices(tags_shuffled[batch:batch + batchSize], tag2i))
          Y_one_hot = Y_one_hot.to('cuda')
          #loss = loss_function(lstm_log_probs.reshape(-1, 10), Y_one_hot.argmax(dim=-1).reshape(-1, 1).squeeze())
          loss = torch.sum(torch.neg(lstm_log_probs) * (Y_one_hot)) / batchSize
          totalLoss += loss
          loss.backward()
          optimizer.step()
          
      print(f"loss on epoch {epoch} = {totalLoss}")
      lstm.write_predictions(sentences_dev, 'dev_pred')   #Performance on dev set
      print('conlleval:')
      print(subprocess.Popen('paste dev dev_pred | perl conlleval.pl -d "\t"', shell=True, stdout=subprocess.PIPE,stderr=subprocess.STDOUT).communicate()[0].decode('UTF-8'))

      if epoch % 10 == 0:
          s = sample(range(len(sentences_dev)), 5)
          lstm.print_predictions([sentences_dev[i] for i in s], [tags_dev[i] for i in s])

lstm = BasicLSTMtagger(DIM_HID=500, DIM_EMB=300).cuda()
lstm.init_glove(GloVe)
train_basic_lstm(sentences_train, tags_train, lstm)

  0%|          | 0/300 [00:00<?, ?it/s]

loss on epoch 0 = 706.5680541992188
conlleval:
processed 51578 tokens with 5942 phrases; found: 6027 phrases; correct: 5090.
accuracy:  97.68%; precision:  84.45%; recall:  85.66%; FB1:  85.05
              LOC: precision:  89.90%; recall:  92.05%; FB1:  90.96  1881
             MISC: precision:  79.37%; recall:  76.36%; FB1:  77.83  887
              ORG: precision:  78.24%; recall:  74.79%; FB1:  76.48  1282
              PER: precision:  85.58%; recall:  91.86%; FB1:  88.61  1977

----------------------------
-START-/START/START Russian/I-MISC/I-MISC peacemaker/O/O Alexander/I-PER/I-PER Lebed/I-PER/I-PER and/O/O Chechen/I-MISC/I-MISC separatist/O/O military/O/O leader/O/O Aslan/I-PER/I-PER Maskhadov/I-PER/I-PER started/O/O a/O/O new/O/O round/O/O of/O/O peace/O/O talks/O/O on/O/O Friday/O/O in/O/O this/O/O settlement/O/O just/O/O outside/O/O the/O/O rebel/O/O region/O/O ./O/O -END-/END/END
Predicted:	 ['START', 'I-MISC', 'O', 'I-PER', 'I-PER', 'O', 'I-MISC', 'O', 'O', 'O', 'I-PER', 

  0%|          | 0/300 [00:00<?, ?it/s]

loss on epoch 1 = 962.5924682617188
conlleval:
processed 51578 tokens with 5942 phrases; found: 5891 phrases; correct: 5263.
accuracy:  98.27%; precision:  89.34%; recall:  88.57%; FB1:  88.95
              LOC: precision:  94.59%; recall:  91.40%; FB1:  92.97  1775
             MISC: precision:  83.80%; recall:  81.34%; FB1:  82.55  895
              ORG: precision:  81.01%; recall:  84.64%; FB1:  82.79  1401
              PER: precision:  93.35%; recall:  92.24%; FB1:  92.79  1820



  0%|          | 0/300 [00:00<?, ?it/s]

loss on epoch 2 = 1124.33203125
conlleval:
processed 51578 tokens with 5942 phrases; found: 5928 phrases; correct: 5314.
accuracy:  98.37%; precision:  89.64%; recall:  89.43%; FB1:  89.54
              LOC: precision:  95.56%; recall:  91.29%; FB1:  93.37  1755
             MISC: precision:  81.35%; recall:  82.32%; FB1:  81.83  933
              ORG: precision:  81.95%; recall:  86.35%; FB1:  84.10  1413
              PER: precision:  94.14%; recall:  93.38%; FB1:  93.76  1827



  0%|          | 0/300 [00:00<?, ?it/s]

loss on epoch 3 = 1228.0670166015625
conlleval:
processed 51578 tokens with 5942 phrases; found: 6166 phrases; correct: 5407.
accuracy:  98.26%; precision:  87.69%; recall:  91.00%; FB1:  89.31
              LOC: precision:  95.15%; recall:  94.01%; FB1:  94.58  1815
             MISC: precision:  74.03%; recall:  85.03%; FB1:  79.15  1059
              ORG: precision:  81.70%; recall:  86.20%; FB1:  83.89  1415
              PER: precision:  92.70%; recall:  94.46%; FB1:  93.57  1877



  0%|          | 0/300 [00:00<?, ?it/s]

loss on epoch 4 = 1292.0419921875
conlleval:
processed 51578 tokens with 5942 phrases; found: 5959 phrases; correct: 5422.
accuracy:  98.64%; precision:  90.99%; recall:  91.25%; FB1:  91.12
              LOC: precision:  95.08%; recall:  94.61%; FB1:  94.84  1828
             MISC: precision:  85.38%; recall:  83.62%; FB1:  84.49  903
              ORG: precision:  85.40%; recall:  85.91%; FB1:  85.65  1349
              PER: precision:  93.72%; recall:  95.60%; FB1:  94.65  1879



  0%|          | 0/300 [00:00<?, ?it/s]

loss on epoch 5 = 1326.24072265625
conlleval:
processed 51578 tokens with 5942 phrases; found: 6001 phrases; correct: 5413.
accuracy:  98.55%; precision:  90.20%; recall:  91.10%; FB1:  90.65
              LOC: precision:  94.85%; recall:  95.32%; FB1:  95.09  1846
             MISC: precision:  82.06%; recall:  84.82%; FB1:  83.41  953
              ORG: precision:  85.95%; recall:  84.41%; FB1:  85.18  1317
              PER: precision:  92.73%; recall:  94.90%; FB1:  93.80  1885



  0%|          | 0/300 [00:00<?, ?it/s]

loss on epoch 6 = 1343.9320068359375
conlleval:
processed 51578 tokens with 5942 phrases; found: 6002 phrases; correct: 5401.
accuracy:  98.51%; precision:  89.99%; recall:  90.90%; FB1:  90.44
              LOC: precision:  95.06%; recall:  94.34%; FB1:  94.70  1823
             MISC: precision:  82.01%; recall:  84.06%; FB1:  83.02  945
              ORG: precision:  82.88%; recall:  87.02%; FB1:  84.90  1408
              PER: precision:  94.52%; recall:  93.70%; FB1:  94.11  1826



  0%|          | 0/300 [00:00<?, ?it/s]

loss on epoch 7 = 1352.5631103515625
conlleval:
processed 51578 tokens with 5942 phrases; found: 5989 phrases; correct: 5413.
accuracy:  98.52%; precision:  90.38%; recall:  91.10%; FB1:  90.74
              LOC: precision:  95.12%; recall:  94.39%; FB1:  94.75  1823
             MISC: precision:  83.28%; recall:  85.36%; FB1:  84.31  945
              ORG: precision:  83.43%; recall:  86.35%; FB1:  84.87  1388
              PER: precision:  94.60%; recall:  94.14%; FB1:  94.37  1833



  0%|          | 0/300 [00:00<?, ?it/s]

loss on epoch 8 = 1357.3115234375
conlleval:
processed 51578 tokens with 5942 phrases; found: 5985 phrases; correct: 5407.
accuracy:  98.54%; precision:  90.34%; recall:  91.00%; FB1:  90.67
              LOC: precision:  95.17%; recall:  94.34%; FB1:  94.75  1821
             MISC: precision:  83.66%; recall:  83.84%; FB1:  83.75  924
              ORG: precision:  84.38%; recall:  86.20%; FB1:  85.28  1370
              PER: precision:  93.32%; recall:  94.73%; FB1:  94.02  1870



  0%|          | 0/300 [00:00<?, ?it/s]

loss on epoch 9 = 1360.3857421875
conlleval:
processed 51578 tokens with 5942 phrases; found: 5993 phrases; correct: 5428.
accuracy:  98.62%; precision:  90.57%; recall:  91.35%; FB1:  90.96
              LOC: precision:  94.72%; recall:  94.67%; FB1:  94.69  1836
             MISC: precision:  82.54%; recall:  85.14%; FB1:  83.82  951
              ORG: precision:  85.94%; recall:  86.13%; FB1:  86.03  1344
              PER: precision:  93.93%; recall:  94.95%; FB1:  94.44  1862



In [13]:
#Evaluation on test data
lstm.write_predictions(sentences_test, 'test_pred_lstm.txt')
!wget https://raw.githubusercontent.com/aritter/twitter_nlp/master/data/annotated/wnut16/conlleval.pl
!paste test test_pred_lstm.txt | perl conlleval.pl -d "\t"

--2023-03-14 02:27:09--  https://raw.githubusercontent.com/aritter/twitter_nlp/master/data/annotated/wnut16/conlleval.pl
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 12754 (12K) [text/plain]
Saving to: ‘conlleval.pl.1’


2023-03-14 02:27:09 (83.4 MB/s) - ‘conlleval.pl.1’ saved [12754/12754]

processed 46666 tokens with 5648 phrases; found: 5715 phrases; correct: 4915.
accuracy:  97.47%; precision:  86.00%; recall:  87.02%; FB1:  86.51
              LOC: precision:  88.84%; recall:  89.75%; FB1:  89.29  1685
             MISC: precision:  70.07%; recall:  76.35%; FB1:  73.07  765
              ORG: precision:  83.82%; recall:  84.53%; FB1:  84.17  1675
              PER: precision:  92.96%; recall:  91.40%; FB1:  92.17  1590


## Initialization with GloVe Embeddings (5 points)

If you haven't already, implement the `init_glove()` method in `BasicLSTMtagger` above.

Rather than initializing word embeddings randomly, it is common to use learned word embeddings (GloVe or Word2Vec), as discussed in lecture.  To make this simpler, we have already pre-filtered [GloVe](https://nlp.stanford.edu/projects/glove/) embeddings to only contain words in the vocabulary of the CoNLL NER dataset, and loaded them into a dictionary (`GloVe`) at the beginning of this notebook.



## Character Embeddings (10 points)

Now that you have your basic LSTM tagger working, the next step is to add a convolutional network that computes word embeddings from character representations of words.  See Figure 2 and Figure 3 in the [Ma and Hovy](https://www.aclweb.org/anthology/P16-1101.pdf) paper.  We have provided code in `sentences2input_tensors` to convert sentences into lists of word and character indices.  See also [nn.Conv1d](https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html) and [MaxPool1d](https://pytorch.org/docs/stable/generated/torch.nn.MaxPool1d.html).

Hint: The nn.Conv1d accepts input size $(N, C_{in}, L_{in})$, but we have input size $(N, \text{SLEN}, \text{CLEN}, \text{EMB_DIM})$. We can reshape and [permute](https://pytorch.org/docs/stable/generated/torch.permute.html) our input to satisfy the nn.Conv1d, and recover the dimensions later.

Make sure to save your predictions on the test set, for submission to GradeScope.  You should be able to achieve **90 F1 / 85 F1 on the dev/test sets**.

In [14]:
import torch.nn.functional as F

class CharLSTMtagger(BasicLSTMtagger): 
    def __init__(self, DIM_EMB=10, DIM_CHAR_EMB=30, DIM_HID=10):
        super(CharLSTMtagger, self).__init__(DIM_EMB = DIM_EMB, DIM_HID = DIM_HID)
        NUM_TAGS = max(tag2i.values()) + 1

        self.DIM_CHAR_EMB = DIM_CHAR_EMB

        #TODO: Initialize parameters.
        self.dropout = nn.Dropout(p = 0.25)
        self.char_embedding = nn.Embedding(char_vocab_size, self.DIM_CHAR_EMB)
        self.char_embedding = self.char_embedding.to('cuda')
        self.dropout = self.dropout.to('cuda')
        self.conv = nn.Conv1d(self.DIM_CHAR_EMB, 10, 10)
        self.conv = self.conv.to('cuda')
        self.maxpool = nn.MaxPool1d(20)
        self.maxpool = self.maxpool.to('cuda')
        self.combined_LSTM = nn.LSTM(self.DIM_EMB + 10, DIM_HID, bidirectional = True, batch_first = True)
        self.combined_LSTM = self.combined_LSTM.to('cuda')

    def forward(self, X, X_char, train=False):
        #TODO: Implement the forward computation.
        word_embed_out = self.embedding(X.to('cuda'))
        char_embed_out = self.char_embedding(X_char.to('cuda'))
        char_embed_out = self.dropout(char_embed_out)
        char_embed_out = char_embed_out.permute(0, 1, 3, 2)
        char_embed_out = char_embed_out.reshape(word_embed_out.shape[0] * word_embed_out.shape[1], self.DIM_CHAR_EMB, -1)
        conv_out = self.conv(char_embed_out)
        maxpool_out = self.maxpool(conv_out)
        maxpool_out = maxpool_out.permute(0, 2, 1)
        maxpool_out = maxpool_out.reshape(word_embed_out.shape[0], word_embed_out.shape[1], -1)
        lstm_out, _ = self.combined_LSTM(torch.cat((word_embed_out, maxpool_out), dim = -1))
        lstm_out = lstm_out.reshape(word_embed_out.shape[0], word_embed_out.shape[1], -1)
        linear_out = self.Linear(lstm_out)
        predictions = self.LogSoftmax(linear_out)
        return predictions

    def sentences2input_tensors(self, sentences):
        (X, X_mask)   = prepare_input(sentences2indices(sentences, word2i))
        X_char        = prepare_input_char(sentences2indicesChar(sentences, char2i))
        return (X, X_mask, X_char)

    def inference(self, sentences):
        (X, X_mask, X_char) = self.sentences2input_tensors(sentences)
        pred = self.forward(X.cuda(), X_char.cuda()).argmax(dim=2)
        return [[i2tag[pred[i,j].item()] for j in range(len(sentences[i]))] for i in range(len(sentences))]

    def print_predictions(self, words, tags):
        Y_pred = self.inference(words)
        for i in range(len(words)):
            print("----------------------------")
            print(" ".join([f"{words[i][j]}/{Y_pred[i][j]}/{tags[i][j]}" for j in range(len(words[i]))]))
            print("Predicted:\t", Y_pred[i])
            print("Gold:\t\t", tags[i])

char_lstm_test = CharLSTMtagger(DIM_HID=7, DIM_EMB=300)
lstm_output    = char_lstm_test.forward(prepare_input(X[0:5])[0], prepare_input_char(X_char[0:5]))
Y_onehot       = prepare_output_onehot(Y[0:5])

print("lstm output shape:", lstm_output.shape)
print("Y onehot shape:", Y_onehot.shape)

lstm output shape: torch.Size([5, 32, 10])
Y onehot shape: torch.Size([5, 32, 10])


In [15]:
#Training LSTM w/ character embeddings. Feel free to change number of epochs, optimizer, learning rate and batch size.

nEpochs = 20

def train_char_lstm(sentences, tags, lstm):
  #TODO: initialize optimizer

    optimizer = optim.Adadelta(lstm.parameters(), lr = 1)
    batchSize = 10
    loss_function = nn.NLLLoss()

    for epoch in range(nEpochs):
        totalLoss = 0.0

        (sentences_shuffled, tags_shuffled) = shuffle_sentences(sentences, tags)
        for batch in tqdm.notebook.tqdm(range(0, len(sentences), batchSize), leave=False):
            lstm.zero_grad()
            #TODO: Gradient update
            batch_sentences = sentences_shuffled[batch: batch + batchSize]
            batch_tags = tags_shuffled[batch: batch + batchSize]
            Y_one_hot = prepare_output_onehot(sentences2indices(batch_tags, tag2i)).reshape(-1, 10).to('cuda')
            X = prepare_input(sentences2indices(batch_sentences, word2i, train = True))[0]
            X_char = prepare_input_char(sentences2indicesChar(batch_sentences, char2i))
            predictions = lstm(X.to('cuda'), X_char.to('cuda'), train = True).reshape(-1, 10)
            loss = loss_function(predictions, torch.argmax(Y_one_hot, dim = 1))
            totalLoss += loss_function(predictions, torch.argmax(Y_one_hot, dim = 1))
            loss.backward()
            optimizer.step()

        print(f"loss on epoch {epoch} = {totalLoss}")
        lstm.write_predictions(sentences_dev, 'dev_pred')   #Performance on dev set
        print('conlleval:')
        print(subprocess.Popen('paste dev dev_pred | perl conlleval.pl -d "\t"', shell=True, stdout=subprocess.PIPE,stderr=subprocess.STDOUT).communicate()[0].decode('UTF-8'))

        if epoch % 10 == 0:
            s = sample(range(len(sentences_dev)), 5)
            lstm.print_predictions([sentences_dev[i] for i in s], [tags_dev[i] for i in s])

char_lstm = CharLSTMtagger(DIM_HID=500, DIM_EMB=300).cuda()
char_lstm.init_glove(GloVe)
train_char_lstm(sentences_train, tags_train, char_lstm)

  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 0 = 102.0999526977539
conlleval:
processed 51578 tokens with 5942 phrases; found: 6006 phrases; correct: 4861.
accuracy:  97.12%; precision:  80.94%; recall:  81.81%; FB1:  81.37
              LOC: precision:  90.50%; recall:  81.98%; FB1:  86.03  1664
             MISC: precision:  80.55%; recall:  66.92%; FB1:  73.10  766
              ORG: precision:  63.21%; recall:  80.84%; FB1:  70.94  1715
              PER: precision:  88.88%; recall:  89.79%; FB1:  89.33  1861

----------------------------
-START-/START/START Chechnya/I-LOC/I-LOC must/O/O be/O/O accorded/O/O the/O/O maximum/O/O autonomy/O/O possible/O/O within/O/O the/O/O framework/O/O of/O/O Russian/I-MISC/I-MISC integrity/O/O ,/O/O "/O/O said/O/O Cotti/I-PER/I-PER ./O/O -END-/END/END
Predicted:	 ['START', 'I-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-MISC', 'O', 'O', 'O', 'O', 'I-PER', 'O', 'END']
Gold:		 ['START', 'I-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-MISC', 'O

  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 1 = 48.29219436645508
conlleval:
processed 51578 tokens with 5942 phrases; found: 6069 phrases; correct: 5085.
accuracy:  97.69%; precision:  83.79%; recall:  85.58%; FB1:  84.67
              LOC: precision:  92.42%; recall:  85.63%; FB1:  88.90  1702
             MISC: precision:  79.59%; recall:  75.27%; FB1:  77.37  872
              ORG: precision:  68.89%; recall:  83.07%; FB1:  75.32  1617
              PER: precision:  90.73%; recall:  92.51%; FB1:  91.61  1878



  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 2 = 37.52080154418945
conlleval:
processed 51578 tokens with 5942 phrases; found: 6143 phrases; correct: 5232.
accuracy:  98.05%; precision:  85.17%; recall:  88.05%; FB1:  86.59
              LOC: precision:  91.87%; recall:  91.62%; FB1:  91.74  1832
             MISC: precision:  75.12%; recall:  82.21%; FB1:  78.51  1009
              ORG: precision:  77.98%; recall:  78.15%; FB1:  78.06  1344
              PER: precision:  89.02%; recall:  94.63%; FB1:  91.74  1958



  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 3 = 30.010623931884766
conlleval:
processed 51578 tokens with 5942 phrases; found: 6083 phrases; correct: 5304.
accuracy:  98.29%; precision:  87.19%; recall:  89.26%; FB1:  88.22
              LOC: precision:  91.41%; recall:  93.85%; FB1:  92.61  1886
             MISC: precision:  84.09%; recall:  78.52%; FB1:  81.21  861
              ORG: precision:  78.72%; recall:  83.30%; FB1:  80.94  1419
              PER: precision:  90.71%; recall:  94.41%; FB1:  92.52  1917



  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 4 = 24.957435607910156
conlleval:
processed 51578 tokens with 5942 phrases; found: 6087 phrases; correct: 5328.
accuracy:  98.35%; precision:  87.53%; recall:  89.67%; FB1:  88.59
              LOC: precision:  89.19%; recall:  96.08%; FB1:  92.51  1979
             MISC: precision:  83.67%; recall:  80.04%; FB1:  81.82  882
              ORG: precision:  80.84%; recall:  80.54%; FB1:  80.69  1336
              PER: precision:  92.33%; recall:  94.73%; FB1:  93.52  1890



  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 5 = 20.986459732055664
conlleval:
processed 51578 tokens with 5942 phrases; found: 6071 phrases; correct: 5353.
accuracy:  98.44%; precision:  88.17%; recall:  90.09%; FB1:  89.12
              LOC: precision:  94.10%; recall:  92.92%; FB1:  93.51  1814
             MISC: precision:  77.80%; recall:  84.38%; FB1:  80.96  1000
              ORG: precision:  81.47%; recall:  82.62%; FB1:  82.04  1360
              PER: precision:  92.78%; recall:  95.55%; FB1:  94.14  1897



  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 6 = 17.769893646240234
conlleval:
processed 51578 tokens with 5942 phrases; found: 6051 phrases; correct: 5392.
accuracy:  98.51%; precision:  89.11%; recall:  90.74%; FB1:  89.92
              LOC: precision:  93.37%; recall:  95.10%; FB1:  94.23  1871
             MISC: precision:  86.21%; recall:  81.34%; FB1:  83.71  870
              ORG: precision:  80.76%; recall:  86.73%; FB1:  83.64  1440
              PER: precision:  92.62%; recall:  94.03%; FB1:  93.32  1870



  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 7 = 14.776846885681152
conlleval:
processed 51578 tokens with 5942 phrases; found: 6040 phrases; correct: 5369.
accuracy:  98.44%; precision:  88.89%; recall:  90.36%; FB1:  89.62
              LOC: precision:  95.37%; recall:  91.89%; FB1:  93.60  1770
             MISC: precision:  84.91%; recall:  81.78%; FB1:  83.31  888
              ORG: precision:  77.81%; recall:  87.62%; FB1:  82.43  1510
              PER: precision:  93.59%; recall:  95.11%; FB1:  94.35  1872



  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 8 = 12.94572639465332
conlleval:
processed 51578 tokens with 5942 phrases; found: 6058 phrases; correct: 5400.
accuracy:  98.57%; precision:  89.14%; recall:  90.88%; FB1:  90.00
              LOC: precision:  94.58%; recall:  93.09%; FB1:  93.83  1808
             MISC: precision:  83.50%; recall:  83.95%; FB1:  83.72  927
              ORG: precision:  80.57%; recall:  86.88%; FB1:  83.60  1446
              PER: precision:  93.29%; recall:  95.06%; FB1:  94.17  1877



  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 9 = 10.80263614654541
conlleval:
processed 51578 tokens with 5942 phrases; found: 6055 phrases; correct: 5360.
accuracy:  98.45%; precision:  88.52%; recall:  90.21%; FB1:  89.36
              LOC: precision:  95.54%; recall:  89.77%; FB1:  92.56  1726
             MISC: precision:  85.91%; recall:  82.00%; FB1:  83.91  880
              ORG: precision:  75.79%; recall:  89.19%; FB1:  81.95  1578
              PER: precision:  94.01%; recall:  95.49%; FB1:  94.75  1871



  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 10 = 9.038496971130371
conlleval:
processed 51578 tokens with 5942 phrases; found: 6051 phrases; correct: 5397.
accuracy:  98.51%; precision:  89.19%; recall:  90.83%; FB1:  90.00
              LOC: precision:  95.37%; recall:  93.09%; FB1:  94.21  1793
             MISC: precision:  83.82%; recall:  83.73%; FB1:  83.78  921
              ORG: precision:  81.49%; recall:  85.01%; FB1:  83.21  1399
              PER: precision:  91.59%; recall:  96.36%; FB1:  93.92  1938

----------------------------
-START-/START/START Jones/I-ORG/I-ORG stock/O/O closed/O/O down/O/O 1/8/O/O at/O/O 40/O/O Friday/O/O ./O/O -END-/END/END
Predicted:	 ['START', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'END']
Gold:		 ['START', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'END']
----------------------------
-START-/START/START U.S./I-LOC/I-LOC physicist/O/O who/O/O discovered/O/O the/O/O two/O/O zones/O/O of/O/O radiation/O/O encircling/O/O the/O/O earth/O/O to/O/O which/O/O he/O/O g

  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 11 = 7.997614860534668
conlleval:
processed 51578 tokens with 5942 phrases; found: 6067 phrases; correct: 5371.
accuracy:  98.45%; precision:  88.53%; recall:  90.39%; FB1:  89.45
              LOC: precision:  93.66%; recall:  94.94%; FB1:  94.30  1862
             MISC: precision:  77.94%; recall:  86.23%; FB1:  81.87  1020
              ORG: precision:  84.61%; recall:  79.12%; FB1:  81.77  1254
              PER: precision:  91.71%; recall:  96.15%; FB1:  93.88  1931



  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 12 = 6.332531452178955
conlleval:
processed 51578 tokens with 5942 phrases; found: 6020 phrases; correct: 5401.
accuracy:  98.59%; precision:  89.72%; recall:  90.90%; FB1:  90.30
              LOC: precision:  94.67%; recall:  93.85%; FB1:  94.26  1821
             MISC: precision:  83.26%; recall:  83.62%; FB1:  83.44  926
              ORG: precision:  82.54%; recall:  85.31%; FB1:  83.90  1386
              PER: precision:  93.38%; recall:  95.66%; FB1:  94.50  1887



  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 13 = 5.509416103363037
conlleval:
processed 51578 tokens with 5942 phrases; found: 6032 phrases; correct: 5418.
accuracy:  98.57%; precision:  89.82%; recall:  91.18%; FB1:  90.50
              LOC: precision:  93.85%; recall:  94.77%; FB1:  94.31  1855
             MISC: precision:  83.07%; recall:  84.60%; FB1:  83.83  939
              ORG: precision:  84.12%; recall:  84.56%; FB1:  84.34  1348
              PER: precision:  93.28%; recall:  95.71%; FB1:  94.48  1890



  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 14 = 4.7082295417785645
conlleval:
processed 51578 tokens with 5942 phrases; found: 5995 phrases; correct: 5403.
accuracy:  98.59%; precision:  90.13%; recall:  90.93%; FB1:  90.53
              LOC: precision:  95.17%; recall:  93.25%; FB1:  94.20  1800
             MISC: precision:  84.09%; recall:  84.27%; FB1:  84.18  924
              ORG: precision:  82.31%; recall:  86.06%; FB1:  84.14  1402
              PER: precision:  94.11%; recall:  95.49%; FB1:  94.80  1869



  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 15 = 3.990530014038086
conlleval:
processed 51578 tokens with 5942 phrases; found: 6000 phrases; correct: 5402.
accuracy:  98.59%; precision:  90.03%; recall:  90.91%; FB1:  90.47
              LOC: precision:  94.99%; recall:  93.85%; FB1:  94.41  1815
             MISC: precision:  86.93%; recall:  82.97%; FB1:  84.91  880
              ORG: precision:  81.07%; recall:  86.20%; FB1:  83.56  1426
              PER: precision:  93.51%; recall:  95.39%; FB1:  94.44  1879



  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 16 = 3.048219680786133
conlleval:
processed 51578 tokens with 5942 phrases; found: 6055 phrases; correct: 5428.
accuracy:  98.61%; precision:  89.64%; recall:  91.35%; FB1:  90.49
              LOC: precision:  94.45%; recall:  94.45%; FB1:  94.45  1837
             MISC: precision:  83.68%; recall:  83.95%; FB1:  83.81  925
              ORG: precision:  82.50%; recall:  86.13%; FB1:  84.28  1400
              PER: precision:  93.19%; recall:  95.77%; FB1:  94.46  1893



  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 17 = 2.677410364151001
conlleval:
processed 51578 tokens with 5942 phrases; found: 6033 phrases; correct: 5413.
accuracy:  98.60%; precision:  89.72%; recall:  91.10%; FB1:  90.41
              LOC: precision:  94.18%; recall:  94.23%; FB1:  94.20  1838
             MISC: precision:  83.91%; recall:  83.73%; FB1:  83.82  920
              ORG: precision:  82.63%; recall:  85.16%; FB1:  83.88  1382
              PER: precision:  93.40%; recall:  95.98%; FB1:  94.67  1893



  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 18 = 2.172713041305542
conlleval:
processed 51578 tokens with 5942 phrases; found: 5990 phrases; correct: 5388.
accuracy:  98.55%; precision:  89.95%; recall:  90.68%; FB1:  90.31
              LOC: precision:  94.74%; recall:  94.07%; FB1:  94.40  1824
             MISC: precision:  83.82%; recall:  83.73%; FB1:  83.78  921
              ORG: precision:  82.82%; recall:  84.49%; FB1:  83.65  1368
              PER: precision:  93.50%; recall:  95.28%; FB1:  94.38  1877



  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 19 = 1.961692452430725
conlleval:
processed 51578 tokens with 5942 phrases; found: 6023 phrases; correct: 5409.
accuracy:  98.54%; precision:  89.81%; recall:  91.03%; FB1:  90.41
              LOC: precision:  95.38%; recall:  93.25%; FB1:  94.30  1796
             MISC: precision:  84.36%; recall:  84.27%; FB1:  84.32  921
              ORG: precision:  80.77%; recall:  87.40%; FB1:  83.95  1451
              PER: precision:  94.18%; recall:  94.84%; FB1:  94.51  1855



In [16]:
#Evaluation on test set
char_lstm.write_predictions(sentences_test, 'test_pred_cnn_lstm.txt')
!wget https://raw.githubusercontent.com/aritter/twitter_nlp/master/data/annotated/wnut16/conlleval.pl
!paste test test_pred_cnn_lstm.txt | perl conlleval.pl -d "\t"

--2023-03-14 02:37:59--  https://raw.githubusercontent.com/aritter/twitter_nlp/master/data/annotated/wnut16/conlleval.pl
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 12754 (12K) [text/plain]
Saving to: ‘conlleval.pl.2’


2023-03-14 02:37:59 (73.5 MB/s) - ‘conlleval.pl.2’ saved [12754/12754]

processed 46666 tokens with 5648 phrases; found: 5768 phrases; correct: 4913.
accuracy:  97.40%; precision:  85.18%; recall:  86.99%; FB1:  86.07
              LOC: precision:  90.41%; recall:  88.19%; FB1:  89.29  1627
             MISC: precision:  72.85%; recall:  74.93%; FB1:  73.88  722
              ORG: precision:  78.17%; recall:  85.79%; FB1:  81.80  1823
              PER: precision:  93.42%; recall:  92.21%; FB1:  92.81  1596


## Conditional Random Fields (15 points)

Now we are ready to add a CRF layer to the `CharacterLSTMTagger`.  To train the model, implement `conditional_log_likelihood`, using the score (unnormalized log probability) of the gold sequence, in addition to the partition function, $Z(X)$, which is computed using the forward algorithm.  Then, you can simply use Pytorch's automatic differentiation to compute gradients by running backpropagation through the computation graph of the dynamic program (this should be very simple, so long as you are able to correctly implement the forward algorithm using a computation graph that is supported by PyTorch).  This approach to computing gradients for CRFs is discussed in Section 7.5.3 of the [Eisenstein Book](https://github.com/jacobeisenstein/gt-nlp-class/blob/master/notes/eisenstein-nlp-notes.pdf)

You will also need to implement the Viterbi algorithm for inference during decoding.

After including CRF training and Viterbi decoding, you should be getting about **92 F1 / 88 F1 on the dev and test set**, respectively.

**IMPORTANT:** Note that training will be substantially slower this time - depending on the efficiency of your implementation, it could take about 5 minutes per epoch (e.g. 50 minutes for 10 iterations).  It is recommended to start out training on a single batch of data (and testing on this same batch), so that you can quickly debug, making sure your model can memorize the labels on a single batch, and then optimize your code.  Once you are fairly confident your code is working properly, then you can train using the full dataset.  We have provided a (commented out) line of code to switch between training on a single batch and the full dataset below.

**Hint #1:** While debugging your implementation of the Forward algorithm it is helpful to look at the loss during training.  The loss should never be less than zero (the log-likelihood should always be negative).

**Hint #2:** To sum log-probabilities in a numerically stable way at the end of the Forward algorithm, you will want to use [`torch.logsumexp`](https://pytorch.org/docs/stable/generated/torch.logsumexp.html).

In [56]:
#For F.max_pool1d()
import torch.nn.functional as F

class LSTM_CRFtagger(CharLSTMtagger):
    def __init__(self, DIM_EMB=10, DIM_CHAR_EMB=30, DIM_HID=10, N_TAGS=max(tag2i.values())+1):
        super(LSTM_CRFtagger, self).__init__(DIM_EMB=DIM_EMB, DIM_HID=DIM_HID, DIM_CHAR_EMB=DIM_CHAR_EMB)

        #TODO: Initialize parameters.

        self.transitionWeights = nn.Parameter(torch.zeros((N_TAGS, N_TAGS), requires_grad=True))
        nn.init.normal_(self.transitionWeights)

    def gold_score(self, lstm_scores, Y):
        #TODO: compute score of gold sequence Y (unnormalized conditional log-probability)
        #sequence_Y = list(range(len(Y)))
        sequence_Y = list(range(len(Y)))
        emits = lstm_scores[sequence_Y, Y]
        transits = self.transitionWeights[Y[:-1], Y[1:]]
        gold_score_Y = torch.sum(emits) + torch.sum(transits)
        return gold_score_Y

    #Forward algorithm for a single sentence
    #Efficiency will eventually be important here.  We recommend you start by 
    #training on a single batch and make sure your code can memorize the 
    #training data.  Then you can go back and re-write the inner loop using 
    #tensor operations to speed things up.
    def forward_algorithm(self, lstm_scores, sLen):
        #TODO: implement forward algorithm.
        dp = lstm_scores[0:1, :].clone()
        for i in range(1, len(lstm_scores)):
          emit = lstm_scores[i:i + 1, :].view(-1, 1)
          cur = dp + self.transitionWeights.T + emit
          dp = torch.logsumexp(cur, dim = 1)
        return torch.logsumexp(dp, dim = 0)

    def conditional_log_likelihood(self, sentences, tags, train=True):
        #Todo: compute conditional log likelihood of Y (use forward_algorithm and gold_score)
        input_tensors = self.sentences2input_tensors(sentences)
        X = input_tensors[0]
        X = X.to('cuda')
        X_char = input_tensors[2]
        X_char = X_char.to('cuda')
        Y_one_hot = prepare_output_onehot(sentences2indices(tags, tag2i))
        Y_one_hot = Y_one_hot.argmax(dim = -1)
        lstm_scores = self.forward(X, X_char)
        Y_conditional_log_likelihood = 0
        for sentence in range(len(sentences)):
            sentence_score = self.forward_algorithm(lstm_scores[sentence], len(sentences[sentence]))
            gold_label_score = self.gold_score(lstm_scores[sentence], Y_one_hot[sentence])
            Y_conditional_log_likelihood += (sentence_score - gold_label_score)
        return Y_conditional_log_likelihood

    def viterbi(self, lstm_scores, sLen):
        #TODO: Implement Viterbi algorithm, soring backpointers to recover the argmax sequence.  Returns the argmax sequence in addition to its unnormalized conditional log-likelihood.
        dp = torch.Tensor(1, 10)
        dp.fill_(-999999.99)
        dp = dp.to('cuda')
        dp[0][tag2i["START"]] = 0.
        back_ptrs = []
        for lstm_score in lstm_scores:
            cur = dp + self.transitionWeights
            back_ptr = torch.argmax(cur, dim = 1)
            back_ptrs.append(back_ptr)
            seq = range(len(back_ptr))
            cur_val = cur[seq, back_ptr]
            dp = cur_val + lstm_score
        best_score = torch.max(dp)
        dp_unflattened = dp.clone().unsqueeze(0)
        most_likely_path = [torch.argmax(dp_unflattened)]
        back_ptrs = reversed(back_ptrs[1:])
        for back_ptr in back_ptrs:
            most_likely_path.append(back_ptr[most_likely_path[-1]])
        most_likely_path = most_likely_path[::-1]

        #unnormalized means no mean, will do mean in train

        return most_likely_path, best_score

    #Computes Viterbi sequences on a batch of data.
    def viterbi_batch(self, sentences):
        viterbiSeqs = []
        (X, X_mask, X_char) = self.sentences2input_tensors(sentences)
        lstm_scores = self.forward(X.cuda(), X_char.cuda())
        for s in range(len(sentences)):
            (viterbiSeq, ll) = self.viterbi(lstm_scores[s], len(sentences[s]))
            viterbiSeqs.append(viterbiSeq)
        return viterbiSeqs

    def forward(self, X, X_char, train=False):
        #TODO: Implement the forward computation.
        X = X.to('cuda')
        X_char = X_char.to('cuda')
        return super().forward(X, X_char, train = True)

    def print_predictions(self, words, tags):
        Y_pred = self.inference(words)
        for i in range(len(words)):
            print("----------------------------")
            print(" ".join([f"{words[i][j]}/{Y_pred[i][j]}/{tags[i][j]}" for j in range(len(words[i]))]))
            print("Predicted:\t", [Y_pred[i][j] for j in range(len(words[i]))])
            print("Gold:\t\t", tags[i])

    #Need to use Viterbi this time.
    def inference(self, sentences, viterbi=True):
        pred = self.viterbi_batch(sentences)
        return [[i2tag[pred[i][j].item()] for j in range(len(sentences[i]))] for i in range(len(sentences))]

lstm_crf = LSTM_CRFtagger(DIM_EMB=300).cuda()

In [57]:
print(lstm_crf.conditional_log_likelihood(sentences_dev[0:3], tags_dev[0:3]))

tensor(108.1975, device='cuda:0', grad_fn=<AddBackward0>)


In [61]:
#CharLSTM-CRF Training. Feel free to change number of epochs, optimizer, learning rate and batch size.
import tqdm
import os
import subprocess
import random

nEpochs = 10

#Get CoNLL evaluation script
os.system('wget https://raw.githubusercontent.com/aritter/twitter_nlp/master/data/annotated/wnut16/conlleval.pl')

def train_crf_lstm(sentences, tags, lstm):
    optimizer = optim.Adadelta(lstm.parameters(), lr=1.0)
    #TODO: initialize optimizer

    batchSize = 10

    for epoch in range(nEpochs):
        totalLoss = 0.0
        lstm.train()

        #Shuffle the sentences
        (sentences_shuffled, tags_shuffled) = shuffle_sentences(sentences, tags)
        for batch in tqdm.notebook.tqdm(range(0, len(sentences), batchSize), leave=False):
            lstm.zero_grad()
            #TODO: take gradient step on a batch of data.
            batch_sentences = sentences_shuffled[batch: batch + batchSize]
            batch_tags = tags_shuffled[batch: batch + batchSize]
            loss = crf_lstm.conditional_log_likelihood(batch_sentences, batch_tags, train = True).mean() #normalizing loss
            totalLoss += loss.item()
            loss.backward()
            optimizer.step() 

        print(f"loss on epoch {epoch} = {totalLoss}")
        lstm.write_predictions(sentences_dev, 'dev_pred')   #Performance on dev set
        print('conlleval:')
        print(subprocess.Popen('paste dev dev_pred | perl conlleval.pl -d "\t"', shell=True, stdout=subprocess.PIPE,stderr=subprocess.STDOUT).communicate()[0].decode('UTF-8'))

        if epoch % 10 == 0:
            lstm.eval()
            s = random.sample(range(50), 5)
            lstm.print_predictions([sentences_train[i] for i in s], [tags_train[i] for i in s])   #Print predictions on train data (useful for debugging)

crf_lstm = LSTM_CRFtagger(DIM_HID=500, DIM_EMB=300, DIM_CHAR_EMB=30).cuda()
crf_lstm.init_glove(GloVe)
train_crf_lstm(sentences_train, tags_train, crf_lstm)             #Train on the full dataset
#train_crf_lstm(sentences_train[0:50], tags_train[0:50], crf_lstm)          #Train only the first batch (use this during development/debugging)

  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 0 = 22595.42800834775
conlleval:
processed 51578 tokens with 5942 phrases; found: 6118 phrases; correct: 5418.
accuracy:  98.27%; precision:  88.56%; recall:  91.18%; FB1:  89.85
              LOC: precision:  94.22%; recall:  94.07%; FB1:  94.14  1834
             MISC: precision:  79.00%; recall:  86.12%; FB1:  82.41  1005
              ORG: precision:  83.44%; recall:  84.94%; FB1:  84.18  1365
              PER: precision:  91.80%; recall:  95.39%; FB1:  93.56  1914

----------------------------
-START-/START/START -DOCSTART-/O/O -END-/END/END
Predicted:	 ['START', 'O', 'END']
Gold:		 ['START', 'O', 'END']
----------------------------
-START-/START/START The/O/O European/I-ORG/I-ORG Commission/I-ORG/I-ORG said/O/O on/O/O Thursday/O/O it/O/O disagreed/O/O with/O/O German/I-MISC/I-MISC advice/O/O to/O/O consumers/O/O to/O/O shun/O/O British/I-MISC/I-MISC lamb/O/O until/O/O scientists/O/O determine/O/O whether/O/O mad/O/O cow/O/O disease/O/O can/O/O be/O/O transmitted/O/

  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 1 = 4823.244229562581
conlleval:
processed 51578 tokens with 5942 phrases; found: 6043 phrases; correct: 5515.
accuracy:  98.59%; precision:  91.26%; recall:  92.81%; FB1:  92.03
              LOC: precision:  95.33%; recall:  95.48%; FB1:  95.40  1840
             MISC: precision:  81.45%; recall:  87.64%; FB1:  84.43  992
              ORG: precision:  89.04%; recall:  87.84%; FB1:  88.44  1323
              PER: precision:  94.01%; recall:  96.36%; FB1:  95.17  1888



  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 2 = 1760.822577316314
conlleval:
processed 51578 tokens with 5942 phrases; found: 5940 phrases; correct: 5518.
accuracy:  98.74%; precision:  92.90%; recall:  92.86%; FB1:  92.88
              LOC: precision:  96.00%; recall:  95.32%; FB1:  95.66  1824
             MISC: precision:  87.77%; recall:  87.20%; FB1:  87.49  916
              ORG: precision:  88.99%; recall:  88.59%; FB1:  88.79  1335
              PER: precision:  95.17%; recall:  96.36%; FB1:  95.76  1865



  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 3 = 594.1388993230648
conlleval:
processed 51578 tokens with 5942 phrases; found: 5937 phrases; correct: 5523.
accuracy:  98.80%; precision:  93.03%; recall:  92.95%; FB1:  92.99
              LOC: precision:  96.39%; recall:  95.81%; FB1:  96.10  1826
             MISC: precision:  88.96%; recall:  86.55%; FB1:  87.74  897
              ORG: precision:  88.76%; recall:  88.89%; FB1:  88.82  1343
              PER: precision:  94.76%; recall:  96.25%; FB1:  95.50  1871



  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 4 = 233.78980530053377
conlleval:
processed 51578 tokens with 5942 phrases; found: 5935 phrases; correct: 5523.
accuracy:  98.82%; precision:  93.06%; recall:  92.95%; FB1:  93.00
              LOC: precision:  95.30%; recall:  96.08%; FB1:  95.69  1852
             MISC: precision:  87.53%; recall:  86.01%; FB1:  86.76  906
              ORG: precision:  90.42%; recall:  88.67%; FB1:  89.53  1315
              PER: precision:  95.38%; recall:  96.42%; FB1:  95.90  1862



  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 5 = 106.46437788091134
conlleval:
processed 51578 tokens with 5942 phrases; found: 5912 phrases; correct: 5516.
accuracy:  98.81%; precision:  93.30%; recall:  92.83%; FB1:  93.07
              LOC: precision:  96.48%; recall:  95.43%; FB1:  95.95  1817
             MISC: precision:  87.28%; recall:  85.57%; FB1:  86.42  904
              ORG: precision:  89.69%; recall:  88.89%; FB1:  89.29  1329
              PER: precision:  95.70%; recall:  96.74%; FB1:  96.22  1862



  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 6 = 69.45996087510139
conlleval:
processed 51578 tokens with 5942 phrases; found: 5962 phrases; correct: 5534.
accuracy:  98.81%; precision:  92.82%; recall:  93.13%; FB1:  92.98
              LOC: precision:  95.37%; recall:  96.35%; FB1:  95.86  1856
             MISC: precision:  88.35%; recall:  85.57%; FB1:  86.94  893
              ORG: precision:  89.35%; recall:  88.22%; FB1:  88.78  1324
              PER: precision:  94.87%; recall:  97.29%; FB1:  96.06  1889



  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 7 = 78.7037621459458
conlleval:
processed 51578 tokens with 5942 phrases; found: 5934 phrases; correct: 5531.
accuracy:  98.83%; precision:  93.21%; recall:  93.08%; FB1:  93.15
              LOC: precision:  95.98%; recall:  96.30%; FB1:  96.14  1843
             MISC: precision:  87.43%; recall:  86.01%; FB1:  86.71  907
              ORG: precision:  90.03%; recall:  88.22%; FB1:  89.11  1314
              PER: precision:  95.51%; recall:  96.96%; FB1:  96.23  1870



  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 8 = 43.9890065358486
conlleval:
processed 51578 tokens with 5942 phrases; found: 5936 phrases; correct: 5531.
accuracy:  98.82%; precision:  93.18%; recall:  93.08%; FB1:  93.13
              LOC: precision:  96.03%; recall:  96.24%; FB1:  96.14  1841
             MISC: precision:  87.86%; recall:  85.57%; FB1:  86.70  898
              ORG: precision:  89.72%; recall:  88.52%; FB1:  89.11  1323
              PER: precision:  95.36%; recall:  97.01%; FB1:  96.18  1874



  0%|          | 0/1499 [00:00<?, ?it/s]

loss on epoch 9 = 50.37612644047476
conlleval:
processed 51578 tokens with 5942 phrases; found: 5938 phrases; correct: 5537.
accuracy:  98.82%; precision:  93.25%; recall:  93.18%; FB1:  93.22
              LOC: precision:  96.09%; recall:  96.24%; FB1:  96.17  1840
             MISC: precision:  88.42%; recall:  86.12%; FB1:  87.25  898
              ORG: precision:  89.80%; recall:  88.59%; FB1:  89.19  1323
              PER: precision:  95.21%; recall:  97.01%; FB1:  96.10  1877



In [62]:
crf_lstm.eval()
crf_lstm.write_predictions(sentences_test, 'test_pred_cnn_lstm_crf.txt')
!wget https://raw.githubusercontent.com/aritter/twitter_nlp/master/data/annotated/wnut16/conlleval.pl
!paste test test_pred_cnn_lstm_crf.txt | perl conlleval.pl -d "\t"

--2023-03-14 05:18:57--  https://raw.githubusercontent.com/aritter/twitter_nlp/master/data/annotated/wnut16/conlleval.pl
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 12754 (12K) [text/plain]
Saving to: ‘conlleval.pl.22’


2023-03-14 05:18:57 (102 MB/s) - ‘conlleval.pl.22’ saved [12754/12754]

processed 46666 tokens with 5648 phrases; found: 5672 phrases; correct: 5015.
accuracy:  97.71%; precision:  88.42%; recall:  88.79%; FB1:  88.60
              LOC: precision:  91.11%; recall:  90.89%; FB1:  91.00  1664
             MISC: precision:  74.32%; recall:  78.35%; FB1:  76.28  740
              ORG: precision:  85.35%; recall:  86.63%; FB1:  85.99  1686
              PER: precision:  95.45%; recall:  93.38%; FB1:  94.40  1582


## Gradescope

Gradescope allows you to add multiple files to your submission. Please submit this notebook along with the test set prediction:
* test_pred_lstm.txt
* test_pred_cnn_lstm.txt
* test_pred_cnn_lstm_crf.txt
* NER_release.ipynb

To download this notebook, go to `File > Download.ipynb`. You can download the predictions from Colab by clicking the folder icon on the left and finding them under Files. 

Please make sure that you name the files as specified above. You will be able to see the test set accuracy for your predictions. However, the final score will be assigned later based on accuracy and implementation. 

When submitting the .ipynb notebook, please make sure that all the cells run when executed in order starting from a fresh session. If the code doesn't take too long to run, you can re-run everything with `Runtime -> Restart and run all`

You can submit multiple times before the deadline and choose the submission which you want to be graded by going to `Submission History` on gradescope.
