# COMS W4705 Spring 24
## Homework 4 - Semantic Role Labelling with BERT

The goal of this assignment is to train and evaluate a PropBank-style semantic role labeling (SRL) system. Following Collobert et al. (2011) and others, we will treat this problem as a sequence-labeling task. For each input token, the system will predict a B-I-O tag, as illustrated in the following example:

|The|judge|scheduled|to|preside|over|his|trial|was|removed|from|the|case|today|.|             
|---|-----|---------|--|-------|----|---|-----|---|-------|----|---|----|-----|-|             
|B-ARG1|I-ARG1|B-V|B-ARG2|I-ARG2|I-ARG2|I-ARG2|I-ARG2|O|O|O|O|O|O|O|
|||schedule.01|||||||||||||

Note that the same sentence may have multiple annotations for different predicates:

|The|judge|scheduled|to|preside|over|his|trial|was|removed|from|the|case|today|.|             
|---|-----|---------|--|-------|----|---|-----|---|-------|----|---|----|-----|-|             
|B-ARG1|I-ARG1|I-ARG1|I-ARG1|I-ARG1|I-ARG1|I-ARG1|I-ARG1|O|B-V|B-ARG2|I-ARG2|I-ARG2|B-ARGM-TMP|O|
||||||||||remove.01||||||

and not all predicates need to be verbs:

|The|judge|scheduled|to|preside|over|his|trial|was|removed|from|the|case|today|.|             
|---|-----|---------|--|-------|----|---|-----|---|-------|----|---|----|-----|-|    
|O|O|O|O|O|O|B-ARG1|B-V|O|O|O|O|O|O|O|
||||||||try.02||||||||

The SRL system will be implemented in [PyTorch](https://pytorch.org/). We will use BERT (in the implementation provided by the [Huggingface transformers](https://huggingface.co/docs/transformers/index) library) to compute contextualized token representations and a custom classification head to predict semantic roles. We will fine-tune the pretrained BERT model on the SRL task. 


### Overview of the Approach

The model we will train is pretty straightforward. Essentially, we will just encode the sentence with BERT, then take the contextualized embedding for each token and feed it into a classifier to predict the corresponding tag. 

Because we are only working on argument identification and labeling (not predicate identification), it is essential that we tell the model where the predicate is. This can be accomplished in various ways. The approach we will choose here repurposes BERT's *segment embeddings*. 

Recall that BERT is trained on pairs of input sentences, separated by [SEP], with a next-sentence-prediction objective (in addition to the masked LM objective). To help BERT tell which sentence a given token belongs to, the original BERT adds a segment embedding: segment A for tokens of the first sentence and segment B for tokens of the second sentence. 
Because we are labeling only a single sentence at a time, we can use the segment embeddings to indicate the predicate position instead: the predicate is labeled as segment B (1) and all other tokens are labeled as segment A (0). 
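As a small illustration, the sketch below builds the segment ids for one sentence in plain Python. It assumes the input is laid out as `[CLS]` + sentence tokens + `[SEP]` (padding omitted), and `segment_ids_for_predicate` is a hypothetical helper, not part of the assignment code. In the model, a list like this is what gets passed to BERT as `token_type_ids`.

```python
def segment_ids_for_predicate(num_tokens, pred_idx):
    """Hypothetical helper: segment ids for [CLS] + num_tokens tokens + [SEP].

    The predicate token (0-based index into the sentence tokens) gets segment B (1);
    every other position, including [CLS] and [SEP], gets segment A (0).
    """
    if not 0 <= pred_idx < num_tokens:
        raise ValueError("pred_idx must point at a sentence token")
    segment_ids = [0] * (num_tokens + 2)  # +2 for [CLS] and [SEP]
    segment_ids[pred_idx + 1] = 1         # +1 because [CLS] occupies position 0
    return segment_ids


# First five tokens of "The judge scheduled to preside ...", predicate "scheduled" (index 2)
print(segment_ids_for_predicate(5, 2))  # [0, 0, 0, 1, 0, 0, 0]
```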

<img src="https://github.com/daniel-bauer/4705-f23-hw5/blob/main/bert_srl_model.png?raw=true" width=400px>

## Setup: GCP, Jupyter, PyTorch, GPU

To make sure that PyTorch is available and can use the GPU, run the following cell, which should return True. If it doesn't, make sure the GPU drivers and CUDA are installed correctly. 

GPU support is required for this assignment -- you will not be able to fine-tune BERT on a CPU. 

In [1]:
import torch
torch.cuda.is_available() 

False

## Dataset: Ontonotes 5.0 English SRL annotations

We will work with the English part of the [Ontonotes 5.0](https://catalog.ldc.upenn.edu/LDC2013T19) data. This is an extension of PropBank, using the same type of annotation. Ontonotes contains annotations other than predicate/argument structures, but we will use the PropBank style SRL annotations only. *Important*: This data set is provided to you for use in COMS 4705 only! Columbia is a subscriber to LDC and is allowed to use the data for educational purposes. However, you may not use the dataset in projects unrelated to Columbia teaching or research.

If you haven't done so already, you can download the data here: 


In [2]:
#! wget https://storage.googleapis.com/4705-bert-srl-data/ontonotes_srl.zip    

In [3]:
#! unzip ontonotes_srl.zip    

The data has been pre-processed in the following format. There are three files: 

`propbank_dev.tsv`	`propbank_test.tsv`	`propbank_train.tsv`

Each of these files is in a tab-separated value format. A single predicate/argument structure annotation consists of four rows. For example:

```
ontonotes/bc/cnn/00/cnn_0000.152.1
The     judge   scheduled       to      preside over    his     trial   was     removed from    the     case    today   /.
                schedule.01
B-ARG1  I-ARG1  B-V     B-ARG2  I-ARG2  I-ARG2  I-ARG2  I-ARG2  O       O       O       O       O       O       O
```

* The first row is a unique identifier (1st annotation of the 152nd sentence in the file ontonotes/bc/cnn/00/cnn_0000).
* The second row contains the tokens of the sentence (tab-separated). 
* The third row contains the PropBank frame name for the predicate (empty field for all other tokens). 
* The fourth row contains the B-I-O tag for each token. 
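The four-row layout can be read with a short generator. This is a minimal sketch, assuming fields are separated by single tab characters (so the mostly empty frame row stays aligned with the tokens when split on tabs); `read_annotations` is a hypothetical helper. The `SrlData` class in Part 1.2 splits on whitespace instead, which works for the token and tag rows because they contain no spaces.

```python
import io

def read_annotations(f):
    """Hypothetical helper: group the lines of a propbank_*.tsv file into
    four-row annotations and yield (identifier, tokens, frames, tags) tuples.

    Assumes single tabs between fields, so empty frame fields keep their position.
    """
    lines = [line.rstrip("\r\n") for line in f]
    for start in range(0, len(lines) - 3, 4):
        identifier = lines[start].strip()
        tokens = lines[start + 1].split("\t")
        frames = lines[start + 2].split("\t")
        tags = lines[start + 3].split("\t")
        yield identifier, tokens, frames, tags


# A shortened version of the example annotation above
sample = io.StringIO(
    "ontonotes/bc/cnn/00/cnn_0000.152.1\n"
    "The\tjudge\tscheduled\n"
    "\t\tschedule.01\n"
    "B-ARG1\tI-ARG1\tB-V\n"
)
for identifier, tokens, frames, tags in read_annotations(sample):
    print(identifier, list(zip(tokens, frames, tags)))
```

With a real file, the same generator would be used as `with open("propbank_dev.tsv") as f: for ident, tokens, frames, tags in read_annotations(f): ...`.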

The file `role_list.txt` contains a list of PropBank BIO labels in the dataset (i.e. the possible output labels). This list has been filtered to contain only roles that appeared more than 1000 times in the training data. 
We will load this list and create mappings from numeric ids to BIO tags and back.

In [4]:
role_to_id = {}
with open("role_list.txt",'r') as f:
    role_list = [x.strip() for x in f.readlines()]
    role_to_id = dict((role, index) for (index, role) in enumerate(role_list))
    role_to_id['[PAD]'] = -100
    
    id_to_role = dict((index, role) for (role, index) in role_to_id.items())
    

Note that we are also mapping the '[PAD]' token to the value -100. This allows the loss function to ignore these tokens during training.

## Part 1 - Data Preparation

Before you can build the SRL model, you first need to preprocess the data. 


### 1.1 - Tokenization 

One challenge is that the pre-trained BERT model uses subword ("WordPiece") tokenization, but the Ontonotes data does not. Fortunately, Huggingface transformers provides a tokenizer that produces BERT's WordPiece tokens. 

In [5]:
from transformers import BertTokenizerFast 
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased', do_lower_case=True)
tokenizer.tokenize("This is an unbelievably boring test sentence.")



['this',
 'is',
 'an',
 'un',
 '##bel',
 '##ie',
 '##va',
 '##bly',
 'boring',
 'test',
 'sentence',
 '.']

**TODO**: 
We need to be able to maintain the correct labels (B-I-O tags) for each of the subwords. 
Complete the following function that takes a list of tokens and a list of B-I-O labels of the same length as parameters, and returns a new token / label pair, as illustrated in the following example. 


```
>>> tokenize_with_labels("the fancyful penguin devoured yummy fish .".split(), "B-ARG0 I-ARG0 I-ARG0 B-V B-ARG1 I-ARG1 O".split(), tokenizer)
(['the',
  'fancy',
  '##ful',
  'penguin',
  'dev',
  '##oured',
  'yu',
  '##mmy',
  'fish',
  '.'],
 ['B-ARG0',
  'I-ARG0',
  'I-ARG0',
  'I-ARG0',
  'B-V',
  'I-V',
  'B-ARG1',
  'I-ARG1',
  'I-ARG1',
  'O'])

```

To approach this problem, iterate through each word/label pair in the sentence. Call the tokenizer on the word, which may result in one or more tokens, and create the correct number of labels to match the number of tokens. Take care not to generate multiple B- labels for a single word: only the first token keeps the B- label, and the remaining tokens get the matching I- label. 


This approach is a bit slower than tokenizing the entire sentence, but is necessary to produce proper input tokenization for the pre-trained BERT model, and the matching target labels. 

In [6]:
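def tokenize_with_labels(sentence, text_labels, tokenizer):
    """
    Word piece tokenization makes it difficult to match word labels
    back up with individual word pieces. This function tokenizes each word
    separately and repeats the word's label for every word piece, turning
    a B- label into an I- label after the first piece.
    """

    tokenized_sentence = []
    labels = []
    
    for word, label in zip(sentence, text_labels):

        w_tokens = tokenizer.tokenize(word)
        if not w_tokens:
            # skip words the tokenizer drops entirely, so tokens and labels stay aligned
            continue
        tokenized_sentence.extend(w_tokens)
        # the first word piece keeps the original label
        labels.append(label)
        # the remaining pieces continue the same span, so B-X becomes I-X
        if label.startswith('B-'):
            label = 'I' + label[1:]
        labels.extend([label] * (len(w_tokens) - 1))

    return tokenized_sentence, labels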

In [7]:
tokenizer.convert_tokens_to_ids(['[CLS]', 'the', 'fancy', '##ful', 'penguin', 'dev',
                                 '##oured', 'yu', '##mmy', 'fish', '.'])

[101, 1996, 11281, 3993, 13987, 16475, 16777, 9805, 18879, 3869, 1012]

In [8]:
tokenize_with_labels("the fancyful penguin devoured yummy fish .".split(), "B-ARG0 I-ARG0 I-ARG0 B-V B-ARG1 I-ARG1 O".split(), tokenizer)

(['the',
  'fancy',
  '##ful',
  'penguin',
  'dev',
  '##oured',
  'yu',
  '##mmy',
  'fish',
  '.'],
 ['B-ARG0',
  'I-ARG0',
  'I-ARG0',
  'I-ARG0',
  'B-V',
  'I-V',
  'B-ARG1',
  'I-ARG1',
  'I-ARG1',
  'O'])

### 1.2 Loading the Dataset

Next, we create a PyTorch [Dataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset) class. This class acts as a container for the training, development, and testing data in memory. You should already be familiar with Datasets and DataLoaders from homework 3. 

1.2.1 **TODO**: Write the \_\_init\_\_(self, filename) method that reads in the data from a data file (specified by the filename).

For each annotation, you start with the tokens in the sentence and the BIO tags. Then you need to:

1. Call the `tokenize_with_labels` function to tokenize the sentence.
2. Add the resulting (tokens, labels) pair to the `self.items` list. 

1.2.2 **TODO**: Write the \_\_len\_\_(self) method that returns the total number of items. 

1.2.3 **TODO**: Write the \_\_getitem\_\_(self, k) method that returns a single item in a format BERT will understand. 
* We need to process the sentence by adding "\[CLS\]" as the first token and "\[SEP\]" as the last token. Then we need to pad the token sequence to 128 tokens using the "\[PAD\]" symbol. This needs to happen both for the inputs (sentence token sequence) and outputs (BIO tag sequence).
* We need to create an *attention mask*, which is a sequence of 128 values indicating the actual input symbols (as a 1) and \[PAD\] symbols (as a 0).
* We need to create a *predicate indicator* mask, which is a sequence of 128 values with at most one 1, in the position of the "B-V" tag. All other entries should be 0. The model will use this information to understand where the predicate is located. 
* Finally, we need to convert the token and tag sequences into numeric indices. For the tokens, this can be done using the `tokenizer.convert_tokens_to_ids` method. For the tags, use the `role_to_id` dictionary. 

Each sequence must be a PyTorch tensor of shape (1,128). You can convert a list of integer values like this: `torch.tensor([token_ids], dtype=torch.long)` (the extra brackets add the leading dimension of size 1).

To keep everything organized, we will return a dictionary in the following format 

```
{'ids': token_tensor,
 'targets': tag_tensor,
 'mask': attention_mask_tensor,
 'pred': predicate_indicator_tensor}
```


(Hint: To debug these, read in the first annotation only / the first few annotations) 


In [36]:
from torch.utils.data import Dataset, DataLoader 

class SrlData(Dataset):
    
    def __init__(self, filename):

        super(SrlData, self).__init__()
        
        self.max_len = 128 # the max number of tokens inputted to the transformer. 
        
        self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased', do_lower_case=True)        
        
        self.items = []
        
        # Each annotation spans four lines: identifier, tokens, frame names, BIO tags.
        with open(filename, 'r') as f:
            ind = 0
            for line in f:
                if ind == 1:
                    sentence = line
                elif ind == 3:
                    labels = line
                    self.items.append(tokenize_with_labels(sentence.split(), labels.split(), self.tokenizer))
                ind = (ind+1)%4
                
        self.num_items = len(self.items)
                                                        
    def __len__(self):
        return self.num_items
    
    def pad_list(self, wlist):
        # [CLS] + sequence (truncated so that [SEP] still fits) + [SEP], then [PAD] up to max_len
        wlist = ['[CLS]'] + wlist[:self.max_len - 2] + ['[SEP]']
        return wlist + ['[PAD]'] * (self.max_len - len(wlist))
    
    def __getitem__(self, k):
        words, labels = self.items[k]
        words = self.pad_list(words)
        labels = self.pad_list(labels)
        
        ids = self.tokenizer.convert_tokens_to_ids(words)
        attn_mask = [1 if w != '[PAD]' else 0 for w in words]
        # Roles that were filtered out of role_list.txt (rare roles) are mapped to 'O'
        targets = [role_to_id[role] if role in role_to_id else role_to_id['O'] for role in labels]
        # Segment B (1) marks the predicate position, segment A (0) everything else
        pred = [1 if role == 'B-V' else 0 for role in labels]
        
        # Each tensor has shape (1, 128)
        return {'ids': torch.tensor([ids], dtype=torch.long),
                'mask': torch.tensor([attn_mask], dtype=torch.long),
                'targets': torch.tensor([targets], dtype=torch.long),
                'pred': torch.tensor([pred], dtype=torch.long),
               }
        

In [37]:
# Reading the training data takes a while for the entire data because we preprocess all data offline
data = SrlData("propbank_train.tsv")

In [38]:
data.items

[(['we',
   'respectful',
   '##ly',
   'invite',
   'you',
   'to',
   'watch',
   'a',
   'special',
   'edition',
   'of',
   'across',
   'china',
   '.'],
  ['B-ARG0',
   'B-ARGM-MNR',
   'I-ARGM-MNR',
   'B-V',
   'B-ARG1',
   'B-ARG2',
   'I-ARG2',
   'I-ARG2',
   'I-ARG2',
   'I-ARG2',
   'I-ARG2',
   'I-ARG2',
   'I-ARG2',
   'O']),
 (['we',
   'respectful',
   '##ly',
   'invite',
   'you',
   'to',
   'watch',
   'a',
   'special',
   'edition',
   'of',
   'across',
   'china',
   '.'],
  ['O',
   'O',
   'O',
   'O',
   'B-ARG0',
   'O',
   'B-V',
   'B-ARG1',
   'I-ARG1',
   'I-ARG1',
   'I-ARG1',
   'I-ARG1',
   'I-ARG1',
   'O']),
 (['standing',
   'tall',
   'on',
   'tai',
   '##hang',
   'mountain',
   'is',
   'the',
   'monument',
   'to',
   'the',
   'hundred',
   'regiments',
   'offensive',
   '.'],
  ['B-V',
   'B-ARGM-ADV',
   'B-ARGM-LOC',
   'I-ARGM-LOC',
   'I-ARGM-LOC',
   'I-ARGM-LOC',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O

## 2. Model Definition

In [26]:
from torch.nn import Module, Linear, CrossEntropyLoss
from transformers import BertModel

We will define the PyTorch model as a subclass of the [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) class. The code for the model is provided for you. It may help to take a look at the documentation to remind you of how Module works. Take a look at how the Huggingface BERT model simply becomes another sub-module. 

In [27]:
class SrlModel(Module):
    
    def __init__(self):
        
        super(SrlModel, self).__init__()
        
        self.encoder = BertModel.from_pretrained("bert-base-uncased")
    
        # The following two lines would freeze the BERT parameters and allow us to train the classifier by itself.
        # We are fine-tuning the model, so you can leave this commented out!
        # for param in self.encoder.parameters():
        #    param.requires_grad = False
        
        # The linear classifier head, see model figure in the introduction. 
        self.classifier = Linear(768, len(role_to_id))
                
        
    def forward(self, input_ids, attn_mask, pred_indicator):
    
        # This defines the flow of data through the model 
    
        # Note the use of the "token type ids" which represents the segment encoding explained in the introduction. 
        # In our segment encoding, 1 indicates the predicate, and 0 indicates everything else. 
        bert_output =  self.encoder(input_ids=input_ids, attention_mask=attn_mask, token_type_ids=pred_indicator)

        enc_tokens = bert_output[0] # the result of encoding the input with BERT
        logits = self.classifier(enc_tokens) #feed into the classification layer to produce scores for each tag.
        
        # Note that we are only interested in the argmax for each token, so we do not have to normalize 
        # to a probability distribution using softmax. The CrossEntropyLoss loss function takes this into account.
        # It essentially computes the softmax first and then computes the negative log-likelihood for the target classes. 
        return logits        

In [14]:
model = SrlModel().to('cuda') # create new model and store weights in GPU memory


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


AssertionError: Torch not compiled with CUDA enabled

In [28]:
model = SrlModel()

Now we are ready to try running the model with just a single input example to check that it is working correctly. The model has not been trained yet, so its predictions will not be meaningful, but we can look at the loss as an initial sanity check. 

**TODO**: 
* Take a single data item from the dev set, as provided by your Dataset class defined above. Obtain the input token ids, attention mask, predicate indicator mask, and target labels. 
* Run the model on the ids, attention mask, and predicate mask like this: 

In [44]:
d4 = data[4]

In [34]:
# pick an item from the dataset. Then run 

outputs = model(d4['ids'], d4['mask'], d4['pred'])

**TODO**: 
Compute the loss on this one item only.

Without training, we would expect all labels for each token (including the target label) to be roughly equally likely, so the negative log probability of the targets, and therefore the loss on a single example, should be approximately $$ -\ln\left(\frac{1}{\text{num\_labels}}\right) = \ln(\text{num\_labels}) $$ This is a good sanity check to run for any multi-class prediction problem. 
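To see where this number comes from, here is a pure-Python sketch of what `CrossEntropyLoss(ignore_index=-100, reduction='mean')` computes: the negative log-softmax score of each target class, averaged over the positions whose target is not -100 (the [PAD] positions). `mean_cross_entropy` is a hypothetical helper, and the 53 labels are an assumption chosen to match the 3.9703 printed by the next cell (that count includes the '[PAD]' entry of `role_to_id`).

```python
import math

def mean_cross_entropy(logits, targets, ignore_index=-100):
    """Hypothetical pure-Python version of CrossEntropyLoss(ignore_index=-100, reduction='mean').

    logits: one list of unnormalized scores per position.
    targets: one class id per position; positions equal to ignore_index are skipped.
    """
    total, count = 0.0, 0
    for scores, target in zip(logits, targets):
        if target == ignore_index:
            continue
        # log of the softmax normalizer, computed stably by subtracting the max score
        m = max(scores)
        log_norm = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_norm - scores[target]  # -log softmax(scores)[target]
        count += 1
    return total / count


num_labels = 53
uniform = [[0.0] * num_labels for _ in range(4)]  # an "untrained" model: equal scores everywhere
print(mean_cross_entropy(uniform, [3, 7, -100, -100]))  # approximately 3.9703, i.e. ln(53)
```

Only the two non-padded positions enter the average, which is why padding does not dilute the loss.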

In [39]:
import math
-math.log(1 / len(role_to_id), math.e)

3.970291913552122

In [152]:
loss_function = CrossEntropyLoss(ignore_index = -100, reduction='mean')

# Note that you still have to provide a (batch_size, input_pos)
# tensor for each parameter, where batch_size = 1.
dd = data[1]

outputs = model(dd['ids'], dd['mask'], dd['pred'])  # shape (1, 128, num_labels)

# CrossEntropyLoss expects the class dimension second, i.e. (1, num_labels, 128).
# transpose() swaps the axes; view() would only reinterpret the memory and scramble the scores.
loss = loss_function(outputs.transpose(1, 2), dd['targets'])
loss.item()   # this should be approximately the score from the previous cell


4.015496253967285

**TODO**: At this point you should also obtain the actual predictions by taking the argmax over each position.
The result should look something like this (values will differ).

```
tensor([[ 1,  4,  4,  4,  4,  4,  5, 29, 29, 29,  4, 28,  6, 32, 32, 32, 32, 32,
         32, 32, 30, 30, 32, 30, 32,  4, 32, 32, 30,  4, 49,  4, 49, 32, 30,  4,
         32,  4, 32, 32,  4,  2,  4,  4, 32,  4, 32, 32, 32, 32, 30, 32, 32, 30,
         32,  4,  4, 49,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  6,  6, 32, 32,
         30, 32, 32, 32, 32, 32, 30, 30, 30, 32, 30, 49, 49, 32, 32, 30,  4,  4,
          4,  4, 29,  4,  4,  4,  4,  4,  4, 32,  4,  4,  4, 32,  4, 30,  4, 32,
         30,  4, 32,  4,  4,  4,  4,  4, 32,  4,  4,  4,  4,  4,  4,  4,  4,  4,
          4,  4]], device='cuda:0')
```

Then use the `id_to_role` dictionary to decode the ids into label strings. 

```
['[CLS]', 'O', 'O', 'O', 'O', 'O', 'B-ARG0', 'I-ARG0', 'I-ARG0', 'I-ARG0', 'O', 'B-V', 'B-ARG1', 'I-ARG2', 'I-ARG2', 'I-ARG2', 'I-ARG2', 'I-ARG2', 'I-ARG2', 'I-ARG2', 'I-ARG1', 'I-ARG1', 'I-ARG2', 'I-ARG1', 'I-ARG2', 'O', 'I-ARG2', 'I-ARG2', 'I-ARG1', 'O', 'I-ARGM-TMP', 'O', 'I-ARGM-TMP', 'I-ARG2', 'I-ARG1', 'O', 'I-ARG2', 'O', 'I-ARG2', 'I-ARG2', 'O', '[SEP]', 'O', 'O', 'I-ARG2', 'O', 'I-ARG2', 'I-ARG2', 'I-ARG2', 'I-ARG2', 'I-ARG1', 'I-ARG2', 'I-ARG2', 'I-ARG1', 'I-ARG2', 'O', 'O', 'I-ARGM-TMP', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-ARG1', 'B-ARG1', 'I-ARG2', 'I-ARG2', 'I-ARG1', 'I-ARG2', 'I-ARG2', 'I-ARG2', 'I-ARG2', 'I-ARG2', 'I-ARG1', 'I-ARG1', 'I-ARG1', 'I-ARG2', 'I-ARG1', 'I-ARGM-TMP', 'I-ARGM-TMP', 'I-ARG2', 'I-ARG2', 'I-ARG1', 'O', 'O', 'O', 'O', 'I-ARG0', 'O', 'O', 'O', 'O', 'O', 'O', 'I-ARG2', 'O', 'O', 'O', 'I-ARG2', 'O', 'I-ARG1', 'O', 'I-ARG2', 'I-ARG1', 'O', 'I-ARG2', 'O', 'O', 'O', 'O', 'O', 'I-ARG2', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
```

For now, just make sure you understand how to do this for a single example. Later, you will write a more formal function to do this once we have trained the model. 

In [156]:
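pred_ids = torch.argmax(outputs, dim=2)  # shape (1, 128): highest-scoring class per position
# Class index i corresponds to the i-th key of id_to_role (the last class is -100, i.e. '[PAD]')
role_ids = list(id_to_role.keys())
str_pred = [id_to_role[role_ids[i]] for i in pred_ids[0].tolist()]
str_pred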

['[SEP]',
 'I-ARGM-CXN',
 'I-ARG2',
 'I-ARGM-LVB',
 '[SEP]',
 'I-ARGM-PRR',
 'B-ARGM-LOC',
 '[PAD]',
 '[SEP]',
 'I-ARGM-EXT',
 '[PAD]',
 'I-ARG4',
 '[SEP]',
 '[SEP]',
 '[PAD]',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 'I-ARGM-LVB',
 '[SEP]',
 '[SEP]',
 '[SEP]',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 '[SEP]',
 '[SEP]',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 '[SEP]',
 'I-ARGM-LVB',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 '[PAD]',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 '[SEP]',
 '[SEP]',
 'I-ARGM-PRD',
 '[SEP]',
 '[SEP]',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 'I-ARGM-EXT',
 '[PAD]',
 'I-ARGM-GOL',
 'I-ARGM-CAU',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 'I-ARGM-PRD',
 'I-ARGM-PRD'

## 3. Training loop

PyTorch provides a DataLoader class that can be wrapped around a Dataset to easily use the dataset for training. The DataLoader lets us adjust the batch size and shuffle the data. 

In [None]:
from torch.utils.data import DataLoader
loader = DataLoader(data, batch_size = 32, shuffle = True)

The following cell contains the main training loop. The code should work as written: it reports the loss after each batch, the cumulative average loss every 100 batches, and the final average loss after the epoch. 

**TODO**: Modify the training loop below so that it also computes the accuracy for each batch and reports the average accuracy after the epoch. 
The accuracy is the number of correctly predicted token labels out of the number of total predictions. 
Make sure you exclude [PAD] tokens, i.e. tokens for which the target label is -100. It's okay to include [CLS] and [SEP] in the accuracy calculation. 

In [None]:
loss_function = CrossEntropyLoss(ignore_index = -100, reduction='mean')

device = 'cuda'
model.to(device)  # move the model parameters to the GPU before creating the optimizer

LEARNING_RATE = 1e-05
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)

def train():
    """
    Train the model for one epoch.
    """
    tr_loss = 0 
    tr_acc = 0
    nb_tr_examples, nb_tr_steps = 0, 0
    # put model in training mode
    model.train()
    
    for idx, batch in enumerate(loader):
        
        # Get the encoded data for this batch and push it to the GPU.
        # Each item is a (1, 128) tensor, so the DataLoader stacks them into
        # (batch_size, 1, 128); reshape to the (batch_size, 128) that BertModel expects.
        batch_size = batch['ids'].size(0)
        ids = batch['ids'].view(batch_size, -1).to(device, dtype = torch.long)
        mask = batch['mask'].view(batch_size, -1).to(device, dtype = torch.long)
        targets = batch['targets'].view(batch_size, -1).to(device, dtype = torch.long)
        pred_mask = batch['pred'].view(batch_size, -1).to(device, dtype = torch.long)

        # Run the forward pass of the model
        logits = model(input_ids=ids, attn_mask=mask, pred_indicator=pred_mask)
        # CrossEntropyLoss expects (batch, num_labels, seq_len), so swap the last two axes
        loss = loss_function(logits.transpose(2,1), targets) 
        tr_loss += loss.item()
        print("Batch loss: ", loss.item()) # can comment out if too verbose.
        
        nb_tr_steps += 1
        nb_tr_examples += targets.size(0)
        
        if idx % 100==0:
            #torch.cuda.empty_cache() # can help if you run into memory issues
            curr_avg_loss = tr_loss/nb_tr_steps
            print(f"Current average loss: {curr_avg_loss}")
                   
        # Compute accuracy for this batch, counting only positions whose target is not [PAD] (-100)
        valid = targets != -100
        matching = torch.sum((torch.argmax(logits, dim=2) == targets) & valid).item()
        predictions = torch.sum(valid).item()
        tr_acc += matching / predictions
                
        # Run the backward pass to update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    epoch_loss = tr_loss / nb_tr_steps
    print(f"Training loss epoch: {epoch_loss}")
    
    epoch_acc = tr_acc / nb_tr_steps
    print(f"Training acc epoch: {epoch_acc}")

Now let's train the model for one epoch. This will take a while (up to a few hours). 

In [None]:
train()

In my experiments, I found that two epochs are needed for good performance. 

In [None]:
train() 

I ended up with a training loss of about 0.19 and a training accuracy of 0.94. Specific values may differ. 

At this point, it's a good idea to save the model (or rather the parameter dictionary) so you can continue evaluating the model without having to retrain. 

In [None]:
torch.save(model.state_dict(), "srl_model_fulltrain_2epoch_finetune_1e-05.pt")

## 4. Decoding

In [None]:
# Optional step: If you stopped working after part 3, first load the trained model 

model = SrlModel().to('cuda') 
model.load_state_dict(torch.load("srl_model_fulltrain_2epoch_finetune_1e-05.pt"))
model = model.to('cuda')

**TODO (this is the fun part)**: Now that we have a trained model, let's try labeling an unseen example sentence. Complete the functions `decode_output` and `label_sentence` below. `decode_output` takes the logits returned by the model, extracts the argmax to obtain the label predictions for each token, and then translates the result into a list of string labels. 

`label_sentence` takes a list of input tokens and a predicate index, prepares the model input, calls the model, and then calls `decode_output` to produce a final result. 

Note that you have already implemented all components necessary (preparing the input data from the token list and predicate index, decoding the model output). But now you are putting it together in one convenient function. 
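Because the model predicts one label per WordPiece while `label_sentence` receives whole words, the predictions have to be mapped back onto the words to get output like the expected result shown below. The assignment does not prescribe how; one common convention, assumed here, is to give each word the label of its first word piece. `word_level_labels` is a hypothetical helper, and the per-word piece counts could be obtained with `len(tokenizer.tokenize(word))`.

```python
def word_level_labels(subword_counts, subword_labels):
    """Hypothetical helper: collapse per-word-piece labels to one label per word.

    subword_counts[i] is the number of word pieces produced for word i (assumed >= 1);
    subword_labels holds one label per word piece, with the [CLS], [SEP] and [PAD]
    positions already removed. Each word takes the label of its first word piece.
    """
    word_labels = []
    position = 0
    for count in subword_counts:
        word_labels.append(subword_labels[position])
        position += count
    return word_labels


# "penguin devoured yummy fish": 'devoured' -> dev ##oured, 'yummy' -> yu ##mmy
counts = [1, 2, 2, 1]
piece_labels = ['I-ARG0', 'B-V', 'I-V', 'B-ARG1', 'I-ARG1', 'I-ARG1']
print(word_level_labels(counts, piece_labels))  # ['I-ARG0', 'B-V', 'B-ARG1', 'I-ARG1']
```

The same bookkeeping also tells you where the predicate word starts among the word pieces, which is where the predicate indicator should be 1 (shifted by one for [CLS]).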

In [None]:
tokens = "A U. N. team spent an hour inside the hospital , where it found evident signs of shelling and gunfire .".split()

In [None]:
def decode_output(logits): # it will be useful to have this in a separate function later on
    """
    Given the model output, return a list of string labels for each token. 
    """
    pass
    

In [None]:
def label_sentence(tokens, pred_idx):
    
    # complete this function to prepare token_ids, attention mask, predicate mask, then call the model. 
    # Decode the output to produce a list of labels. 
    pass
        

In [None]:
# Now you should be able to run

label_test = label_sentence(tokens, 13) # Predicate is "found"
list(zip(tokens, label_test))  # list() so the notebook displays the pairs instead of a zip object

The expected output is something like this: 
```   
 ('A', 'O'),
 ('U.', 'O'),
 ('N.', 'O'),
 ('team', 'O'),
 ('spent', 'O'),
 ('an', 'O'),
 ('hour', 'O'),
 ('inside', 'O'),
 ('the', 'B-ARGM-LOC'),
 ('hospital', 'I-ARGM-LOC'),
 (',', 'O'),
 ('where', 'B-ARGM-LOC'),
 ('it', 'B-ARG0'),
 ('found', 'B-V'),
 ('evident', 'B-ARG1'),
 ('signs', 'I-ARG1'),
 ('of', 'I-ARG1'),
 ('shelling', 'I-ARG1'),
 ('and', 'I-ARG1'),
 ('gunfire', 'I-ARG1'),
 ('.', 'O'),
```


### 5. Evaluation 1: Token-Based Accuracy
We want to evaluate the model on the dev or test set. 

In [None]:
dev_data = SrlData("propbank_dev.tsv") # Takes a while because we preprocess all data offline

In [None]:
from torch.utils.data import DataLoader
loader = DataLoader(dev_data, batch_size = 1, shuffle = False)

In [None]:
# Optional: Load the model again if you stopped working prior to this step. 
# model = SrlModel()
# model.load_state_dict(torch.load("srl_model_fulltrain_2epoch_finetune_1e-05.pt"))
# model = model.to('cuda')

**TODO**: Complete the evaluate_token_accuracy function below. The function should iterate through the items in the data loader (see training loop in part 3). Run the model on each sentence/predicate pair and extract the predictions.

For each sentence, count the correct predictions and the total predictions. Finally, compute the accuracy as #correct_predictions / #total_predictions

Careful: You need to filter out the padded positions ([PAD] target tokens), as well as [CLS] and [SEP]. It's okay to include B-V in the count, though. 
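The filtering rule can be sketched per sentence in plain Python on decoded label strings, for example after mapping target ids through `id_to_role` (where -100 becomes '[PAD]'). `count_correct` is a hypothetical helper; summing its two return values over all sentences and dividing gives the accuracy.

```python
SPECIAL_LABELS = {"[PAD]", "[CLS]", "[SEP]"}

def count_correct(target_labels, predicted_labels, skip=SPECIAL_LABELS):
    """Hypothetical helper: return (correct, total) for one sentence.

    Positions whose *target* label is in `skip` are ignored; B-V positions
    are still counted.
    """
    correct = total = 0
    for gold, pred in zip(target_labels, predicted_labels):
        if gold in skip:
            continue
        total += 1
        if gold == pred:
            correct += 1
    return correct, total


targets = ['[CLS]', 'B-ARG1', 'I-ARG1', 'B-V', 'O', '[SEP]', '[PAD]']
predicted = ['[CLS]', 'B-ARG1', 'O', 'B-V', 'O', 'O', 'B-ARG0']
print(count_correct(targets, predicted))  # (3, 4)
```

Filtering on the target rather than the prediction matters: a prediction of 'O' at a [SEP] position, as in the example, is simply not counted.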

In [None]:
def evaluate_token_accuracy(model, loader):
    
    model.eval() # put model in evaluation mode
    
    # for the accuracy 
    total_correct = 0 # number of correct token label predictions. 
    total_predictions = 0 # number of total predictions = number of tokens in the data. 
    
    # iterate over the data here. 

    acc = total_correct / total_predictions
    print(f"Accuracy: {acc}")
    

### 6. Span-Based evaluation 

While the accuracy score in part 5 is encouraging, an accuracy-based evaluation is problematic for two reasons. First, most of the target labels are actually O. Second, it only tells us that per-token prediction works, but does not directly evaluate the SRL performance. 

Instead, SRL systems are typically evaluated on micro-averaged precision, recall, and F1-score for predicting labeled spans. 

More specifically, for each sentence/predicate input, we run the model, decode the output, and extract a set of labeled spans (from the output and the target labels). These spans are (i,j,label) tuples.  

We then compute the true_positives, false_positives, and false_negatives based on these spans. 

In the end, we can compute 

* Precision: true_positives / (true_positives + false_positives), that is, the number of correct spans out of all predicted spans. 

* Recall: true_positives / (true_positives + false_negatives), that is, the number of correct spans out of all target spans. 

* F1-score:   (2 * precision * recall) / (precision + recall)


For example, consider 

|token|[CLS]|The|judge|scheduled|to|preside|over|his|trial|was|removed|from|the|case|today|.|
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|index|0|1|2|3|4|5|6|7|8|9|10|11|12|13|14|15|
|target|[CLS]|B-ARG1|I-ARG1|B-V|B-ARG2|I-ARG2|I-ARG2|I-ARG2|I-ARG2|O|O|O|O|O|O|O|
|prediction|[CLS]|B-ARG1|I-ARG1|B-V|I-ARG2|I-ARG2|O|O|O|O|O|O|O|O|B-ARGM-TMP|O|

The target spans are (1,2,"ARG1"), and (4,8,"ARG2").

The predicted spans would be (1,2,"ARG1"), (14,14,"ARGM-TMP"). Note that in the prediction, there is no proper ARG2 span because we are missing the B-ARG2 token, so this span should not be created. 

So for this sentence we would get: true_positives: 1, false_positives: 1, false_negatives: 1.
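The comparison in this example can be written down directly on span dictionaries of the form `{(start, end): label}`. The sketch below uses hypothetical helpers `span_counts` and `precision_recall_f1` and reproduces the counts above. For the micro-averaged scores, sum tp, fp and fn over all sentences first and compute precision, recall and F1 once at the end; the zero-denominator guards are an addition, not part of the formulas above.

```python
def span_counts(target_spans, predicted_spans):
    """Hypothetical helper: compare two {(start, end): label} dictionaries.

    A predicted span is a true positive only if the target has the same
    (start, end) with the same label.
    """
    tp = sum(1 for span, label in predicted_spans.items()
             if target_spans.get(span) == label)
    fp = len(predicted_spans) - tp  # predicted spans that are wrong or mislabeled
    fn = len(target_spans) - tp     # target spans that were not predicted exactly
    return tp, fp, fn


def precision_recall_f1(tp, fp, fn):
    # guard against empty denominators (e.g. no predicted spans at all)
    p = tp / (tp + fp) if tp + fp else 0.0
    r = tp / (tp + fn) if tp + fn else 0.0
    f = 2 * p * r / (p + r) if p + r else 0.0
    return p, r, f


target = {(1, 2): "ARG1", (4, 8): "ARG2"}
predicted = {(1, 2): "ARG1", (14, 14): "ARGM-TMP"}
tp, fp, fn = span_counts(target, predicted)
print(tp, fp, fn)                       # 1 1 1
print(precision_recall_f1(tp, fp, fn))  # (0.5, 0.5, 0.5)
```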

**TODO**: Complete the function `evaluate_spans` that performs the span-based evaluation on the given model and data loader. You can use the provided `extract_spans` function, which returns the spans as a dictionary mapping (start, end) positions to labels, for example `{(1,2): "ARG1", (4,8): "ARG2"}`.

In [None]:
def extract_spans(labels):
    """
    Convert a list of B-I-O labels into a dictionary mapping (start, end)
    positions (both inclusive) to role labels. Predicate (V) spans are skipped.
    """
    spans = {} # map (start,end) ids to label
    current_span_start = 0
    current_span_type = ""
    inside = False

    def close_span(end):
        # record the span ending at position `end` (inclusive), unless it is the predicate
        if current_span_type != "V":
            spans[(current_span_start, end)] = current_span_type

    for i, label in enumerate(labels):
        if label.startswith("B-"):
            if inside:
                close_span(i - 1)
            current_span_start = i
            current_span_type = label[2:]
            inside = True
        elif inside and not (label.startswith("I-") and label[2:] == current_span_type):
            # O, [SEP], [PAD], or an I- tag of a different type ends the current span
            close_span(i - 1)
            inside = False
    if inside:  # a span that runs until the end of the sequence
        close_span(len(labels) - 1)
    return spans
                        

In [None]:
def evaluate_spans(model, loader):
    
    
    total_tp = 0
    total_fp = 0
    total_fn = 0
        
    for idx, batch in enumerate(loader):
        
        pass # complete this
    
    
    total_p = total_tp / (total_tp + total_fp)
    total_r = total_tp / (total_tp + total_fn)
    total_f = (2 * total_p *total_r) / (total_p + total_r)
            
    print(f"Overall P: {total_p}  Overall R: {total_r}  Overall F1: {total_f}")
    
evaluate_spans(model, loader)    

In my evaluation, I got an F1 score of 0.82 (which is slightly below the state of the art in 2018).

### OPTIONAL: 

Repeat the span-based evaluation, but print out precision/recall/f1-score for each role separately.
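One way to approach this, sketched under the assumption that you have collected one (target_spans, predicted_spans) pair of `extract_spans` dictionaries per sentence/predicate, is to keep separate true positive, false positive and false negative counters for each role. `per_role_counts` and `print_role_report` are hypothetical names.

```python
from collections import Counter

def per_role_counts(span_pairs):
    """Hypothetical helper: span_pairs yields one (target_spans, predicted_spans)
    pair of {(start, end): label} dictionaries per sentence/predicate.
    Returns three Counters (true positives, false positives, false negatives) keyed by role."""
    tp, fp, fn = Counter(), Counter(), Counter()
    for target_spans, predicted_spans in span_pairs:
        for span, label in predicted_spans.items():
            if target_spans.get(span) == label:
                tp[label] += 1
            else:
                fp[label] += 1  # charged to the *predicted* role
        for span, label in target_spans.items():
            if predicted_spans.get(span) != label:
                fn[label] += 1  # charged to the *target* role
    return tp, fp, fn


def print_role_report(tp, fp, fn):
    """Print precision, recall and F1 for every role that occurs in any counter."""
    for role in sorted(set(tp) | set(fp) | set(fn)):
        p = tp[role] / (tp[role] + fp[role]) if tp[role] + fp[role] else 0.0
        r = tp[role] / (tp[role] + fn[role]) if tp[role] + fn[role] else 0.0
        f = 2 * p * r / (p + r) if p + r else 0.0
        print(f"{role:10s} P: {p:.3f}  R: {r:.3f}  F1: {f:.3f}")


# The example sentence from Part 6
pairs = [({(1, 2): "ARG1", (4, 8): "ARG2"}, {(1, 2): "ARG1", (14, 14): "ARGM-TMP"})]
print_role_report(*per_role_counts(pairs))
```

Summing each Counter over all roles gives back the overall tp, fp and fn used for the micro-averaged scores.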