<a href="https://colab.research.google.com/github/srpauliscu/nlp-shared-task/blob/main/semeval_values.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Script Setup

In [None]:
# Mount Google Drive at /content/drive so the Drive-hosted requirements file
# and data files used below can be read from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install necessary packages
!pip install -r drive/MyDrive/nlp_sp/env/requirements.txt

In [None]:
# Import block
import torch
import torch.nn as nn
from transformers import BertModel
from transformers import AutoTokenizer
from typing import Dict, List
import random
from tqdm import tqdm
import numpy as np
from numpy import logical_and, sum as t_sum
import pandas as pd
from typing import Dict, List
from sklearn.model_selection import train_test_split


In [None]:
# Device setup for CUDA
#
# Important: every tensor, layer and model must live on the same device, so
# send each one there with .to(device), for example:
#   ten = torch.ones(4, 5).to(device)

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)


cuda


# Hyperparameters

## Define hyperparameters near the top to make changing them easier

In [None]:
# Hyperparameters, collected in one place so they are easy to tweak.

epochs = 100          # number of training epochs (passes over the training batches)
LR = 1e-3             # learning rate; keep it small when using Adam
dropout_prob = 0.0    # dropout probability
batch_size = 32       # training batch size
hidden_size = 1024    # size to project to after BERT

# Data Preprocessing

## Data Format

**Data:** `arguments-training/validation/testing.tsv`
(5220 arguments in the training split)
- Argument ID
- Conclusion 
- Stance (e.g., in favor, against)
- Premise (justification for conclusion)

**Labels:** `labels-training/validation/testing.tsv` 
(20 binary value labels per argument)
- Argument ID
- Self-direction: thought
- Self-direction: action
- Stimulation
- Hedonism
- Achievement
- Power: dominance
- Power: resources
- Face
- Security: personal
- Security: societal
- Tradition
- Conformity: rules
- Conformity: interpersonal
- Humility
- Benevolence: caring
- Benevolence: dependability
- Universalism: concern
- Universalism: nature
- Universalism: tolerance
- Universalism: objectivity

**Access:** https://doi.org/10.5281/zenodo.6814563

## Load Data

In [None]:
# Training arguments: one row per argument (ID, conclusion, stance, premise).
# Point the path at whichever teammate's Drive is mounted:
#   Spencer:  /content/drive/MyDrive/nlp_sp/data/arguments-training.tsv
#   Caroline: /content/drive/MyDrive/csci5832_project/data/arguments-training.tsv  (active)
train_args_df = pd.read_csv(
    '/content/drive/MyDrive/csci5832_project/data/arguments-training.tsv',
    sep='\t',
)
# Inspect columns, dtypes and null counts.
train_args_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5220 entries, 0 to 5219
Data columns (total 4 columns):
 #   Column       Non-Null Count  Dtype 
---  ------       --------------  ----- 
 0   Argument ID  5220 non-null   object
 1   Conclusion   5220 non-null   object
 2   Stance       5220 non-null   object
 3   Premise      5220 non-null   object
dtypes: object(4)
memory usage: 163.2+ KB


In [None]:
# Training labels: the argument ID plus one 0/1 column per human value.
# Point the path at whichever teammate's Drive is mounted:
#   Spencer:  /content/drive/MyDrive/nlp_sp/data/labels-training.tsv
#   Caroline: /content/drive/MyDrive/csci5832_project/data/labels-training.tsv  (active)
train_labs_df = pd.read_csv(
    '/content/drive/MyDrive/csci5832_project/data/labels-training.tsv',
    sep='\t',
)
# Inspect columns, dtypes and null counts.
train_labs_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5220 entries, 0 to 5219
Data columns (total 21 columns):
 #   Column                      Non-Null Count  Dtype 
---  ------                      --------------  ----- 
 0   Argument ID                 5220 non-null   object
 1   Self-direction: thought     5220 non-null   int64 
 2   Self-direction: action      5220 non-null   int64 
 3   Stimulation                 5220 non-null   int64 
 4   Hedonism                    5220 non-null   int64 
 5   Achievement                 5220 non-null   int64 
 6   Power: dominance            5220 non-null   int64 
 7   Power: resources            5220 non-null   int64 
 8   Face                        5220 non-null   int64 
 9   Security: personal          5220 non-null   int64 
 10  Security: societal          5220 non-null   int64 
 11  Tradition                   5220 non-null   int64 
 12  Conformity: rules           5220 non-null   int64 
 13  Conformity: interpersonal   5220 non-null   int6

## Data Prep

In [None]:
# Collapse the per-value 0/1 columns into a single 'labels' column that holds
# one list per argument, in the same column order as the label file.
label_value_cols = train_labs_df.loc[:, 'Self-direction: thought':'Universalism: objectivity']
train_labs_df['labels'] = label_value_cols.to_numpy().tolist()

In [None]:
# Label distribution for the full training data: how many arguments carry each
# value label. Iterating over the value columns replaces 20 hand-written prints
# and keeps the "name = count" output formatting uniform.
label_counts = train_labs_df.loc[:, 'Self-direction: thought':'Universalism: objectivity'].sum()
for value_name, count in label_counts.items():
    print(f'{value_name} = {count}')

print(f'\nTotal number of samples = {len(train_labs_df)}')

Self-direction: thought = 913
Self-direction: action = 1332
Stimulation = 312
Hedonism = 202
Achievement =  1400
Power: dominance = 461
Power: resources = 566
Face = 374
Security: personal = 1961
Security: societal = 1627
Tradition = 598
Conformity: rules = 1222
Conformity: interpersonal = 217
Humility = 438
Benevolence: caring = 1500
Benevolence: dependability = 766
Universalism: concern = 1992
Universalism: nature = 358
Universalism: tolerance = 709
Universalism: objectivity = 937

Total number of samples =  5220


In [None]:
# Combine arguments with their label lists. Only the ID and the combined
# 'labels' column are taken from the label frame, so the per-value columns never
# enter the merged frame and do not need to be dropped afterwards.
# validate='one_to_one' raises if an Argument ID is duplicated in either frame,
# instead of silently multiplying rows.
train_merged_df = pd.merge(
    train_args_df,
    train_labs_df[['Argument ID', 'labels']],
    on='Argument ID',
    validate='one_to_one',
)

In [None]:
# view structure
# Expect five columns: Argument ID, Conclusion, Stance, Premise and labels.
train_merged_df.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 5220 entries, 0 to 5219
Data columns (total 5 columns):
 #   Column       Non-Null Count  Dtype 
---  ------       --------------  ----- 
 0   Argument ID  5220 non-null   object
 1   Conclusion   5220 non-null   object
 2   Stance       5220 non-null   object
 3   Premise      5220 non-null   object
 4   labels       5220 non-null   object
dtypes: object(5)
memory usage: 244.7+ KB


## Train/Val Split

In [None]:
# Hold out 20% of the labelled training arguments as a validation split.
# The fixed random_state keeps the split identical across runs.
train_data, val_data = train_test_split(
    train_merged_df,
    test_size=0.2,
    random_state=4,
)

In [None]:
# Turn each split into a list of per-row dicts (List[Dict]) for the helpers below.
train_data, val_data, full_data = (
    frame.to_dict(orient='records')
    for frame in (train_data, val_data, train_merged_df)
)
# Show the first record of every split.
for split_name, records in (('training', train_data), ('validation', val_data), ('full', full_data)):
    print(f'{split_name} example:\n', records[0])

training example:
 {'Argument ID': 'A07017', 'Conclusion': 'Homeopathy brings more harm than good', 'Stance': 'against', 'Premise': 'homeopathy uses natural remedies that have little to no side affects on the body.', 'labels': [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0]}
validation example:
 {'Argument ID': 'A18174', 'Conclusion': 'The vow of celibacy should be abandoned', 'Stance': 'against', 'Premise': "the vow of celibacy should be promoted as it brings the sense of self control and purity to a person's soul.", 'labels': [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0]}
full example:
 {'Argument ID': 'A01001', 'Conclusion': 'Entrapment should be legalized', 'Stance': 'in favor of', 'Premise': "if entrapment can serve to more easily capture wanted criminals, then why shouldn't it be legal?", 'labels': [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}


In [None]:
# Label density: the fraction of all (argument, value-label) cells that are 1.
# The number of labels per argument is read from the data rather than
# hard-coded as 20.
numerator = sum(sum(sample['labels']) for sample in full_data)
total_cells = sum(len(sample['labels']) for sample in full_data)
density = numerator / total_cells

# Decision threshold applied to the per-label probabilities: twice the label
# density (0.5 is the conventional alternative).
threshold = 2*density
#threshold = 0.5

print(threshold)

0.34262452107279695


## Tokenization

In [None]:
# function to load samples from HuggingFace dataset to be batched and encoded

class BatchTokenizer:
    """Tokenizes, pairs and pads a batch of (conclusion+stance, premise) inputs.

    Wraps a HuggingFace tokenizer so that a whole batch is encoded in one call.
    HuggingFace docs: https://huggingface.co/transformers/v3.0.2/preprocessing.html
    """

    def __init__(self, model_name: str = "prajjwal1/bert-small"):
        """Loads the pretrained tokenizer.

        Args:
            model_name (str, optional): HuggingFace checkpoint whose tokenizer
                is loaded. Defaults to "prajjwal1/bert-small", which is the
                checkpoint NLIClassifier loads its BERT encoder from. Keep the
                two in sync so the vocabularies match.
        """
        self.hf_tokenizer = AutoTokenizer.from_pretrained(model_name)

    def get_sep_token(self) -> str:
        """Returns the tokenizer's separator token string (e.g. "[SEP]" for BERT)."""
        return self.hf_tokenizer.sep_token

    def __call__(self, con_stan_batch: List[str], prem_batch: List[str]) -> "transformers.BatchEncoding":
        """Tokenizes and pads a batch of sentence pairs.

        The HF tokenizer accepts only a *pair* of text inputs, but each argument
        has three parts (conclusion, stance, premise). As a workaround, the
        caller joins conclusion and stance with a literal " [SEP] " string and
        passes the result as the first sequence of the pair.

        Args:
            con_stan_batch (List[str]): First sequence of each pair
                (conclusion + " [SEP] " + stance).
            prem_batch (List[str]): Second sequence of each pair (the premise),
                index-aligned with ``con_stan_batch``.

        Returns:
            BatchEncoding: Dict-like object with ``input_ids`` and
            ``attention_mask`` PyTorch tensors, padded to the longest pair in
            the batch.
        """
        # The tokenizer pads for us and joins the pair as
        # [CLS] con_stan [SEP] prem [SEP].
        return self.hf_tokenizer(
            con_stan_batch,
            prem_batch,
            padding=True,
            # token_type_ids are deliberately omitted because of the
            # conclusion/stance hack; only input_ids and attention_mask are returned.
            return_token_type_ids=False,
            return_tensors='pt',
        )

In [None]:
# define tokenizer
# Loads the HuggingFace tokenizer once; it is reused for every batch below.
tokenizer = BatchTokenizer()

In [None]:
# Example of the batch tokenizer without the triplet hack: only two inputs
# (here conclusion and premise) can be paired.
example_conclusions = ['this is the conclusion with more words', 'this is also a conclusion']
example_premises = ['this is the premise', 'this is the second premise']
token_ex = tokenizer(example_conclusions, example_premises)
print(f"{token_ex}\n")
tokenizer.hf_tokenizer.batch_decode(token_ex['input_ids'])

{'input_ids': tensor([[  101,  2023,  2003,  1996,  7091,  2007,  2062,  2616,   102,  2023,
          2003,  1996, 18458,   102],
        [  101,  2023,  2003,  2036,  1037,  7091,   102,  2023,  2003,  1996,
          2117, 18458,   102,     0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]])}



['[CLS] this is the conclusion with more words [SEP] this is the premise [SEP]',
 '[CLS] this is also a conclusion [SEP] this is the second premise [SEP] [PAD]']

In [None]:
# Example of the batch tokenizer with the triplet hack: the stance is appended
# to the conclusion behind a literal [SEP], so three parts fit into two inputs.
example_conclusions_stances = [
    'this is the conclusion with more words [SEP] and a stance against',
    'this is also a conclusion [SEP] with another stance that is in favor of',
]
example_premises_2 = ['this is the premise', 'this is the second premise']
token_ex2 = tokenizer(example_conclusions_stances, example_premises_2)
print(f"{token_ex2}\n")
tokenizer.hf_tokenizer.batch_decode(token_ex2['input_ids'])

{'input_ids': tensor([[  101,  2023,  2003,  1996,  7091,  2007,  2062,  2616,   102,  1998,
          1037, 11032,  2114,   102,  2023,  2003,  1996, 18458,   102,     0,
             0,     0],
        [  101,  2023,  2003,  2036,  1037,  7091,   102,  2007,  2178, 11032,
          2008,  2003,  1999,  5684,  1997,   102,  2023,  2003,  1996,  2117,
         18458,   102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}



['[CLS] this is the conclusion with more words [SEP] and a stance against [SEP] this is the premise [SEP] [PAD] [PAD] [PAD]',
 '[CLS] this is also a conclusion [SEP] with another stance that is in favor of [SEP] this is the second premise [SEP]']

## Batch

In [None]:
# function to generate triple-wise inputs

def generate_triplewise_input(dataset: List[Dict]) -> (List[str], List[str], List[str], List[str], List[List[int]]):
    """Split argument records into parallel, index-aligned per-field lists.

    Each record is a dict with the keys 'Argument ID', 'Conclusion', 'Stance',
    'Premise' and 'labels', as produced by ``to_dict(orient='records')`` on the
    merged argument/label frame. Fields are looked up by key rather than by
    position, so the result does not silently change if the column order of
    the source frame changes.

    Args:
        dataset (List[Dict]): The argument records.

    Returns:
        A 5-tuple of lists aligned by index:
        - argument ids,
        - conclusions,
        - stances, each prefixed with ' [SEP] ' so it can be appended directly
          to its conclusion,
        - premises,
        - label lists (one list of 0/1 value labels per argument).
    """
    id_lst = [record['Argument ID'] for record in dataset]
    conclusion_lst = [record['Conclusion'] for record in dataset]
    # Prefix every stance with a literal [SEP] so that conclusion + stance can
    # later form the first sequence of the tokenizer's input pair.
    stance_lst = [' [SEP] ' + record['Stance'] for record in dataset]
    premise_lst = [record['Premise'] for record in dataset]
    label_lst = [record['labels'] for record in dataset]

    return id_lst, conclusion_lst, stance_lst, premise_lst, label_lst

In [None]:
# apply function to generate triple-wise inputs and labels for batching
# Each call returns five index-aligned lists: ids, conclusions,
# ' [SEP] '-prefixed stances, premises and label lists.

# training data
train_ids, train_conclusions, train_stances, train_premises, train_labels = generate_triplewise_input(train_data)

# validation data
val_ids, val_conclusions, val_stances, val_premises, val_labels = generate_triplewise_input(val_data)

# full data
full_ids, full_conclusions, full_stances, full_premises, full_labels = generate_triplewise_input(full_data)

In [None]:
# Join each conclusion with its ' [SEP] '-prefixed stance so the two fit into
# the tokenizer's first input sequence (the "triplet hack"). zip pairs the
# index-aligned lists directly, instead of three copy-pasted index loops.
train_conclusions_stances = [c + s for c, s in zip(train_conclusions, train_stances)]
val_conclusions_stances = [c + s for c, s in zip(val_conclusions, val_stances)]
full_conclusions_stances = [c + s for c, s in zip(full_conclusions, full_stances)]

In [None]:
# define functions to chunk data for batches

# for train labels
def chunk(lst, n):
    """Lazily split ``lst`` into consecutive slices of at most ``n`` items.

    Used for the label lists. The last slice holds the remainder and may be
    shorter than ``n``.
    """
    yield from (lst[start:start + n] for start in range(0, len(lst), n))


def chunk_multi(lst1, lst2, n):
    """Lazily split two index-aligned sequences into matching ``n``-sized slices.

    Used for the paired input features. Yields ``(slice_of_lst1, slice_of_lst2)``
    tuples; the slice offsets are driven by ``len(lst1)``.
    """
    for start in range(0, len(lst1), n):
        stop = start + n
        yield lst1[start:stop], lst2[start:stop]

In [None]:
# Batch the (conclusion+stance, premise) pairs, then tokenize + encode each batch
# with the HuggingFace tokenizer and move it to `device`.
# The evaluation sets (validation / full) use a batch size of 1.
val_size = 1
full_size = 1


def _encode_input_batches(first_seqs, second_seqs, size):
    """Chunk paired sequences into batches of `size`, tokenize each batch, and move it to `device`."""
    return [tokenizer(*pair).to(device) for pair in chunk_multi(first_seqs, second_seqs, size)]


train_input_batches = _encode_input_batches(train_conclusions_stances, train_premises, batch_size)
val_input_batches = _encode_input_batches(val_conclusions_stances, val_premises, val_size)
full_input_batches = _encode_input_batches(full_conclusions_stances, full_premises, full_size)

In [None]:
# Sanity-check the first training batch: show the encoded tensors, then decode
# its first sequence back to text to inspect the [CLS]/[SEP]/[PAD] layout.
first_train_batch = train_input_batches[0]
print(first_train_batch)
encoded_tst = tokenizer.hf_tokenizer.batch_decode(first_train_batch['input_ids'])
encoded_tst[0]

{'input_ids': tensor([[  101,  2188, 29477,  ...,     0,     0,     0],
        [  101,  2057,  2323,  ...,     0,     0,     0],
        [  101,  2057,  2323,  ...,     0,     0,     0],
        ...,
        [  101,  2057,  2323,  ...,     0,     0,     0],
        [  101,  2057,  2323,  ...,     0,     0,     0],
        [  101,  2057,  2323,  ...,     0,     0,     0]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0')}


'[CLS] homeopathy brings more harm than good [SEP] against [SEP] homeopathy uses natural remedies that have little to no side affects on the body. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [None]:
# define function to batch class labels
# a single observation's label is a list of 20 labels

def encode_labels(labels: List[List[int]]) -> torch.Tensor:
    """Turn a batch of multi-hot label lists into an integer tensor.

    Args:
        labels (List[List[int]]): One list of 0/1 value labels per argument.

    Returns:
        torch.Tensor: An int64 (``torch.long``) tensor of shape
        ``(len(labels), len(labels[0]))``. Note that it is *not* a float
        tensor; the training loop calls ``.float()`` on it before the loss.
    """
    return torch.tensor(labels, dtype=torch.long)


In [None]:
# Batch the label lists with the same order and batch sizes as the inputs, then
# encode each batch as a tensor on `device`.


def _encode_label_batches(labels, size):
    """Chunk label lists into batches of `size`, encode each batch as a tensor, and move it to `device`."""
    return [encode_labels(batch).to(device) for batch in chunk(labels, size)]


train_label_batches = _encode_label_batches(train_labels, batch_size)
val_label_batches = _encode_label_batches(val_labels, val_size)
full_label_batches = _encode_label_batches(full_labels, full_size)

# Model

Below is the code to define our model as well as the training loop.

## Functions to Make Predictions

In [None]:
def make_prediction(logits: torch.Tensor, cutoff=None) -> torch.Tensor:
    """Binarize per-label scores into 0.0/1.0 predictions.

    Args:
        logits (torch.Tensor): Per-label scores. The chain classifiers end in a
            Sigmoid, so despite the name these are probabilities in [0, 1].
        cutoff (float, optional): Scores strictly greater than this become 1.0.
            Defaults to the notebook-level ``threshold``.

    Returns:
        torch.Tensor: Float tensor with the same shape as ``logits``.
    """
    if cutoff is None:
        cutoff = threshold
    return (logits > cutoff).float()


def predict(model: torch.nn.Module, sents: Dict, cutoff=None) -> torch.Tensor:
    """Run ``model`` on one tokenized batch and return 0/1 predictions.

    Args:
        model (torch.nn.Module): Model mapping a tokenized batch to per-label scores.
        sents (Dict): Tokenized batch (HuggingFace token specification).
        cutoff (float, optional): Forwarded to ``make_prediction``; ``None``
            uses the global ``threshold``.

    Returns:
        torch.Tensor: CPU float tensor of 0.0/1.0 predictions with the same
        shape as the model output.
    """
    # Inference only: do not build an autograd graph.
    with torch.no_grad():
        scores = model(sents)
    return make_prediction(scores.cpu(), cutoff)

## Model Definition

In [None]:

# Function to initialize weights for the chain classifiers
def init_weights(layer):
    """Xavier-normal-initialize the weight matrix of ``nn.Linear`` layers.

    Intended for ``Module.apply``. Any other module type is left untouched,
    and a Linear layer's bias keeps its current values.
    """
    if not isinstance(layer, nn.Linear):
        return
    torch.nn.init.xavier_normal_(layer.weight)

class NLIClassifier(torch.nn.Module):
    """Multi-label value classifier: frozen BERT -> 3-layer MLP -> classifier chain.

    The BERT pooler output is projected through three Linear+ReLU layers down to
    ``hidden_size`` features. These features go to a chain of ``output_size``
    binary classifiers, where classifier ``i`` sees the MLP features plus the
    0/1 predictions of classifiers ``0..i-1``.
    For details on classifier chains, see:
    https://en.wikipedia.org/wiki/Multi-label_classification
    """

    def __init__(self, output_size: int, hidden_size: int, dropout_prob: float):
        """Builds the model and moves every layer to the global ``device``.

        Args:
            output_size (int): Number of labels, i.e. classifiers in the chain.
            hidden_size (int): Width of the last MLP layer, which feeds the chain.
            dropout_prob (float): Dropout probability applied before each chain
                classifier's linear layer.
        """
        super().__init__()
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.dropout_prob = dropout_prob

        # BERT is used instead of a single embedding layer.
        self.bert = BertModel.from_pretrained("prajjwal1/bert-small").to(device)

        # Freeze BERT; comment out these lines to fine-tune its weights as well.
        for param in self.bert.parameters():
            param.requires_grad = False

        # BERT's hidden dimension (the size of the pooler output).
        self.bert_hidden_dimension = self.bert.config.hidden_size

        # MLP: bert_hidden -> 4h -> 2h -> h. The third layer is required because
        # forward() applies it and the chain classifiers expect `hidden_size`
        # input features (plus the earlier predictions).
        self.hidden_layer = torch.nn.Linear(self.bert_hidden_dimension, self.hidden_size * 4).to(device)
        self.hidden_layer_2 = torch.nn.Linear(self.hidden_size * 4, self.hidden_size * 2).to(device)
        self.hidden_layer_3 = torch.nn.Linear(self.hidden_size * 2, self.hidden_size).to(device)

        # ReLU activation between the MLP layers. TODO: could try other activations.
        self.relu = torch.nn.ReLU()

        # Classifier chain: one binary classifier per label, run sequentially in
        # the same order as the training label columns:
        #   Self-direction: thought, Self-direction: action, Stimulation, Hedonism,
        #   Achievement, Power: dominance, Power: resources, Face,
        #   Security: personal, Security: societal, Tradition, Conformity: rules,
        #   Conformity: interpersonal, Humility, Benevolence: caring,
        #   Benevolence: dependability, Universalism: concern,
        #   Universalism: nature, Universalism: tolerance, Universalism: objectivity
        #
        # nn.ModuleList (not a plain Python list) registers the classifiers as
        # submodules. Their parameters then appear in model.parameters() (and so
        # are trained by the optimizer) and in state_dict(), and the classifiers
        # follow .to()/.train()/.eval().
        #
        # TODOs: more layers per classifier, a bigger BERT, unfreezing BERT,
        # hyperparameter tuning, and a threshold based on label cardinality.
        self.chain = nn.ModuleList()
        for i in range(self.output_size):
            # Classifier i gets the MLP features plus the i previous predictions.
            classifier = nn.Sequential(
                nn.Dropout(p=self.dropout_prob),
                nn.Linear(in_features=self.hidden_size + i, out_features=1),
                nn.Sigmoid(),
            )
            self.chain.append(classifier.to(device))

        # Initialize every chain classifier exactly once, after the chain is built.
        self.chain.apply(init_weights)

    def encode_text(self, symbols: Dict) -> torch.Tensor:
        """Encode a tokenized batch with BERT and return its pooled [CLS] embedding.

        Args:
            symbols (Dict): The Dict of token specifications provided by the
                HuggingFace tokenizer.

        Returns:
            torch.Tensor: BERT's ``pooler_output`` with a singleton middle
            dimension, i.e. shape (batch_size, 1, bert_hidden_dimension).
        """
        encoded_sequence = self.bert(**symbols)
        # pooler_output is (batch_size, bert_hidden_dimension). Add dim 1 so that
        # forward() can concatenate chain predictions along dim 2.
        return torch.unsqueeze(encoded_sequence['pooler_output'], dim=1)

    def forward(self, symbols: Dict) -> torch.Tensor:
        """Run BERT, the MLP and the classifier chain on a tokenized batch.

        Args:
            symbols (Dict): The Dict of token specifications provided by the
                HuggingFace tokenizer.

        Returns:
            torch.Tensor: Sigmoid probabilities of shape (batch_size, output_size),
            one column per label in chain order. They are named "logits" below
            but are already squashed into [0, 1], which is why training uses
            BCELoss rather than BCEWithLogitsLoss.
        """
        encoded_sents = self.encode_text(symbols)
        output = self.relu(self.hidden_layer(encoded_sents))
        output = self.relu(self.hidden_layer_2(output))
        output = self.relu(self.hidden_layer_3(output))
        # output: (batch_size, 1, hidden_size)

        cur_input = output
        logits = []
        for classifier in self.chain:
            # Score for the next label: (batch_size, 1, 1).
            o = classifier(cur_input)
            logits.append(o)
            # Append this label's hard 0/1 prediction as an extra input feature
            # for the next classifier. TODO: could append the raw scores instead.
            pred = make_prediction(o)
            cur_input = torch.cat([cur_input, pred], dim=2)

        # output_size tensors of (batch_size, 1, 1) -> (batch_size, output_size)
        return torch.cat(logits, dim=-1).squeeze(dim=1)

## Evaluation

### Metric Functions

In [None]:
def precision(predicted_labels, true_labels):
    """
    Micro-averaged precision over every (sample, label) cell:
    true positives / all positive predictions.

    Each element of ``predicted_labels`` / ``true_labels`` is one sample's row of
    label values. Returns 0. when nothing was predicted positive.
    """
    predicted_positive = 0
    correct_positive = 0
    for row, pred_row in enumerate(predicted_labels):
        gold_row = true_labels[row]
        # Summing a 0/1 prediction row counts its positive predictions (TP + FP).
        predicted_positive += sum(pred_row)
        correct_positive += sum(
            1 for col, value in enumerate(pred_row)
            if value == 1 and value == gold_row[col]
        )

    if not predicted_positive:
        return 0.
    return correct_positive / predicted_positive


def recall(predicted_labels, true_labels, which_label=1):
    """
    Micro-averaged recall over every (sample, label) cell:
    true positives / all gold positives.

    Args:
        predicted_labels: One row of predicted label values per sample.
        true_labels: One row of gold label values per sample, index-aligned
            with ``predicted_labels``.
        which_label: The label value treated as "positive" (default 1). A gold
            cell equal to it is a true positive when the prediction matches,
            and a false negative otherwise.

    Returns:
        float: Recall, or 0. when there are no gold positives.
    """
    true_pos = 0
    false_neg = 0
    for i in range(len(predicted_labels)):
        cur_pred = predicted_labels[i]
        cur_true = true_labels[i]
        for j in range(len(cur_pred)):
            # Only gold positives contribute to recall.
            if cur_true[j] != which_label:
                continue
            if cur_pred[j] == which_label:
                true_pos += 1
            else:
                false_neg += 1

    denom = false_neg + true_pos
    if denom:
        return true_pos/denom
    return 0.

def f1_score(
    predicted_labels: List[List[int]],
    true_labels: List[List[int]]
):
    """
    Micro-averaged F1: the harmonic mean of ``precision`` and ``recall``.

    Args:
        predicted_labels: One row of 0/1 predictions per sample (e.g. the
            per-sample arrays collected during dev evaluation).
        true_labels: One row of 0/1 gold labels per sample, index-aligned with
            ``predicted_labels``.

    Returns:
        float: The F1 score, or 0. if precision or recall is zero.
    """
    P = precision(predicted_labels, true_labels)
    R = recall(predicted_labels, true_labels)
    if not (P and R):
        return 0.
    return 2*P*R/(P+R)


## Training Loop

In [None]:
def training_loop(
    num_epochs,
    train_features,
    train_labels,
    dev_features,
    dev_labels,
    optimizer,
    model,
    possible_labels,
    restore_best=True
):
    """
    Train ``model`` with BCE loss and report micro-F1 on the dev set each epoch.

    Args:
        num_epochs: number of passes over the training batches.
        train_features: list of pre-built model inputs, one entry per batch.
        train_labels: list of label tensors aligned with ``train_features``.
        dev_features: list of dev-set model inputs, one entry per batch.
        dev_labels: list of dev label tensors aligned with ``dev_features``.
        optimizer: torch optimizer over ``model``'s parameters.
        model: module whose forward output is fed straight into ``BCELoss``,
            so it must already produce probabilities in [0, 1].
        possible_labels: not used by this function; kept for call compatibility.
        restore_best: if True, reload the weights from the epoch with the
            highest dev F1 before returning. The model is updated in place.

    Returns:
        The trained ``model`` (the same object that was passed in), left in
        training mode.
    """
    print("Training...")
    dev_f1_scores = []
    best_f1 = None
    best_epoch = None
    best_state = None
    loss_func = torch.nn.BCELoss()

    batches = list(zip(train_features, train_labels))
    for i in range(num_epochs):
        # Reshuffle every epoch. Shuffling once up front would repeat the same
        # batch order in every epoch.
        random.shuffle(batches)

        # Dev evaluation below switches to eval mode, so switch back here.
        model.train()
        losses = []
        for features, labels in tqdm(batches):
            # Clear gradients accumulated by the previous step
            optimizer.zero_grad()
            preds = model(features)
            loss = loss_func(preds, labels.float())

            # Backpropagate the loss through our model
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

        mean_loss = sum(losses) / len(losses) if losses else float("nan")
        print(f"epoch {i}, loss: {mean_loss}")

        # Estimate the f1 score for the development set
        print("Evaluating dev...")
        all_preds = []
        all_labels = []
        # eval() turns off any dropout layers; no_grad() skips building the
        # autograd graph during evaluation.
        model.eval()
        with torch.no_grad():
            for sents, labels in tqdm(zip(dev_features, dev_labels), total=len(dev_features)):
                pred = predict(model, sents)
                all_preds.extend(pred.cpu().detach().numpy())
                all_labels.extend(list(labels.cpu().numpy()))
        dev_f1 = f1_score(all_preds, all_labels)
        print(f"Dev F1 {dev_f1}")
        dev_f1_scores.append(dev_f1)

        # A strict ">" keeps the first best epoch, matching np.argmax below.
        if best_f1 is None or dev_f1 > best_f1:
            best_f1 = dev_f1
            best_epoch = i
            # Keep a detached CPU copy so later optimizer steps don't overwrite
            # it and it doesn't double device memory.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    # Print the best dev_f1 score for result reporting
    if dev_f1_scores:
        print(f"Best dev F1 score: {np.max(dev_f1_scores)}")
        print(f"Best iteration: {np.argmax(dev_f1_scores)}")
    else:
        print("No epochs were run; no dev F1 to report.")

    if restore_best and best_state is not None:
        # load_state_dict copies the values into the existing parameters,
        # so they stay on whatever device the model already uses.
        model.load_state_dict(best_state)
        print(f"Restored model weights from best iteration {best_epoch}")

    model.train()
    return model

# Training Phase

## Setup

In [None]:
# Seed every RNG this notebook touches (initialisation of freshly created
# layers, batch shuffling in training_loop) so a re-run reproduces training.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Number of labels (should be 20)
possible_labels = len(train_labels[0])
if possible_labels != 20:
    raise RuntimeError(f"Instead of 20 possible labels, we found {possible_labels}.")

# Initialize the model and make sure all of its parameters live on `device`.
# Module.to() works in place and returns the module, so this is harmless if
# the class already moves itself.
model = NLIClassifier(output_size=possible_labels, hidden_size=hidden_size, dropout_prob=dropout_prob).to(device)
model.train()

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

Some weights of the model checkpoint at prajjwal1/bert-small were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Train the Model

In [None]:
# Kick off training. Every argument is passed by keyword so the mapping from
# notebook variables to training_loop parameters is explicit.
trained_model = training_loop(
    num_epochs=epochs,
    train_features=train_input_batches,
    train_labels=train_label_batches,
    dev_features=val_input_batches,
    dev_labels=val_label_batches,
    optimizer=optimizer,
    model=model,
    possible_labels=list(range(possible_labels)),
)


Training...


100%|██████████| 131/131 [00:03<00:00, 40.24it/s]


epoch 0, loss: 0.4154371694299101
Evaluating dev...


100%|██████████| 1044/1044 [00:13<00:00, 74.72it/s]


Dev F1 0.40482737549648634


100%|██████████| 131/131 [00:04<00:00, 27.89it/s]


epoch 1, loss: 0.38724998390401594
Evaluating dev...


100%|██████████| 1044/1044 [00:12<00:00, 85.07it/s] 


Dev F1 0.44411326378539495


100%|██████████| 131/131 [00:02<00:00, 44.23it/s]


epoch 2, loss: 0.3755833135761377
Evaluating dev...


100%|██████████| 1044/1044 [00:07<00:00, 130.58it/s]


Dev F1 0.4629518072289157


100%|██████████| 131/131 [00:02<00:00, 44.65it/s]


epoch 3, loss: 0.3671064178907234
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 128.59it/s]


Dev F1 0.47992836890016416


100%|██████████| 131/131 [00:02<00:00, 44.92it/s]


epoch 4, loss: 0.3607089251052332
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 118.79it/s]


Dev F1 0.4817649707339037


100%|██████████| 131/131 [00:03<00:00, 41.09it/s]


epoch 5, loss: 0.35491062302625814
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 127.73it/s]


Dev F1 0.48995159160921226


100%|██████████| 131/131 [00:02<00:00, 44.39it/s]


epoch 6, loss: 0.3517393542610052
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 127.07it/s]


Dev F1 0.4916302216860202


100%|██████████| 131/131 [00:03<00:00, 33.16it/s]


epoch 7, loss: 0.3477695795870919
Evaluating dev...


100%|██████████| 1044/1044 [00:09<00:00, 109.08it/s]


Dev F1 0.5072004608294931


100%|██████████| 131/131 [00:02<00:00, 44.57it/s]


epoch 8, loss: 0.3433566942014767
Evaluating dev...


100%|██████████| 1044/1044 [00:09<00:00, 109.45it/s]


Dev F1 0.5017880131597768


100%|██████████| 131/131 [00:02<00:00, 43.97it/s]


epoch 9, loss: 0.33880056110957196
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 126.14it/s]


Dev F1 0.512449575740715


100%|██████████| 131/131 [00:02<00:00, 44.82it/s]


epoch 10, loss: 0.3353435326623553
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 129.95it/s]


Dev F1 0.48707280832095096


100%|██████████| 131/131 [00:02<00:00, 44.19it/s]


epoch 11, loss: 0.33288778853780443
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 126.65it/s]


Dev F1 0.5150226757369615


100%|██████████| 131/131 [00:02<00:00, 43.94it/s]


epoch 12, loss: 0.32915396476519926
Evaluating dev...


100%|██████████| 1044/1044 [00:09<00:00, 109.41it/s]


Dev F1 0.5134336756224503


100%|██████████| 131/131 [00:02<00:00, 44.08it/s]


epoch 13, loss: 0.3261782397295683
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 126.60it/s]


Dev F1 0.5223126578725793


100%|██████████| 131/131 [00:02<00:00, 44.17it/s]


epoch 14, loss: 0.3218373917896329
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 123.12it/s]


Dev F1 0.5201506066099568


100%|██████████| 131/131 [00:02<00:00, 44.03it/s]


epoch 15, loss: 0.31757387018385735
Evaluating dev...


100%|██████████| 1044/1044 [00:10<00:00, 104.10it/s]


Dev F1 0.5290705898995615


100%|██████████| 131/131 [00:03<00:00, 40.06it/s]


epoch 16, loss: 0.3164207960358103
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 126.87it/s]


Dev F1 0.517804775869292


100%|██████████| 131/131 [00:03<00:00, 43.65it/s]


epoch 17, loss: 0.3141324629310433
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 125.84it/s]


Dev F1 0.5249197263716321


100%|██████████| 131/131 [00:02<00:00, 44.04it/s]


epoch 18, loss: 0.31025453696724115
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 125.90it/s]


Dev F1 0.5153168275283255


100%|██████████| 131/131 [00:02<00:00, 44.06it/s]


epoch 19, loss: 0.3097610908155223
Evaluating dev...


100%|██████████| 1044/1044 [00:09<00:00, 106.91it/s]


Dev F1 0.523783185840708


100%|██████████| 131/131 [00:02<00:00, 44.05it/s]


epoch 20, loss: 0.3061353536962553
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 125.26it/s]


Dev F1 0.5339710226473484


100%|██████████| 131/131 [00:03<00:00, 43.15it/s]


epoch 21, loss: 0.3011023966410688
Evaluating dev...


100%|██████████| 1044/1044 [00:10<00:00, 95.98it/s]


Dev F1 0.5221301110018266


100%|██████████| 131/131 [00:03<00:00, 37.76it/s]


epoch 22, loss: 0.29938683277778044
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 126.59it/s]


Dev F1 0.5279685966633956


100%|██████████| 131/131 [00:02<00:00, 44.10it/s]


epoch 23, loss: 0.2971792061820285
Evaluating dev...


100%|██████████| 1044/1044 [00:11<00:00, 89.51it/s] 


Dev F1 0.5260676650027731


100%|██████████| 131/131 [00:02<00:00, 44.05it/s]


epoch 24, loss: 0.29518546141285934
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 124.51it/s]


Dev F1 0.5333882934872217


100%|██████████| 131/131 [00:03<00:00, 43.56it/s]


epoch 25, loss: 0.29288141167800846
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 127.25it/s]


Dev F1 0.5367535512343125


100%|██████████| 131/131 [00:02<00:00, 44.23it/s]


epoch 26, loss: 0.2915513982982126
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 126.39it/s]


Dev F1 0.5274360746371803


100%|██████████| 131/131 [00:03<00:00, 43.46it/s]


epoch 27, loss: 0.28900965092746356
Evaluating dev...


100%|██████████| 1044/1044 [00:09<00:00, 115.65it/s]


Dev F1 0.5333704115684094


100%|██████████| 131/131 [00:03<00:00, 41.03it/s]


epoch 28, loss: 0.286917103269628
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 125.38it/s]


Dev F1 0.5369183616606297


100%|██████████| 131/131 [00:02<00:00, 43.95it/s]


epoch 29, loss: 0.2839874719617931
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 123.88it/s]


Dev F1 0.5321876704855428


100%|██████████| 131/131 [00:02<00:00, 44.04it/s]


epoch 30, loss: 0.2829232176069085
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 127.12it/s]


Dev F1 0.5254260395364689


100%|██████████| 131/131 [00:02<00:00, 43.75it/s]


epoch 31, loss: 0.28158858953086474
Evaluating dev...


100%|██████████| 1044/1044 [00:11<00:00, 92.07it/s] 


Dev F1 0.5434079441066115


100%|██████████| 131/131 [00:03<00:00, 43.51it/s]


epoch 32, loss: 0.27707777473762746
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 127.22it/s]


Dev F1 0.5438058407787704


100%|██████████| 131/131 [00:02<00:00, 43.92it/s]


epoch 33, loss: 0.27409105810500284
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 127.17it/s]


Dev F1 0.5327270221854761


100%|██████████| 131/131 [00:02<00:00, 44.05it/s]


epoch 34, loss: 0.2719031510917285
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 125.64it/s]


Dev F1 0.5365524683032102


100%|██████████| 131/131 [00:02<00:00, 44.40it/s]


epoch 35, loss: 0.271244932449501
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 125.16it/s]


Dev F1 0.532153229152261


100%|██████████| 131/131 [00:03<00:00, 40.42it/s]


epoch 36, loss: 0.2667283623955632
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 126.25it/s]


Dev F1 0.5308839190628328


100%|██████████| 131/131 [00:02<00:00, 44.71it/s]


epoch 37, loss: 0.26606133204835064
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 126.46it/s]


Dev F1 0.5243404949687245


100%|██████████| 131/131 [00:02<00:00, 43.84it/s]


epoch 38, loss: 0.2646480650847195
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 126.72it/s]


Dev F1 0.5203408466190215


100%|██████████| 131/131 [00:02<00:00, 44.43it/s]


epoch 39, loss: 0.26187098435773193
Evaluating dev...


100%|██████████| 1044/1044 [00:09<00:00, 110.18it/s]


Dev F1 0.5199565571544935


100%|██████████| 131/131 [00:02<00:00, 44.24it/s]


epoch 40, loss: 0.25914471226794117
Evaluating dev...


100%|██████████| 1044/1044 [00:09<00:00, 109.83it/s]


Dev F1 0.5247829739568749


100%|██████████| 131/131 [00:02<00:00, 43.68it/s]


epoch 41, loss: 0.2570255677436144
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 126.58it/s]


Dev F1 0.5203744493392071


100%|██████████| 131/131 [00:02<00:00, 44.10it/s]


epoch 42, loss: 0.25503725063709815
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 126.43it/s]


Dev F1 0.5068068472840006


100%|██████████| 131/131 [00:02<00:00, 43.78it/s]


epoch 43, loss: 0.25317987776894607
Evaluating dev...


100%|██████████| 1044/1044 [00:09<00:00, 109.38it/s]


Dev F1 0.5189131592967501


100%|██████████| 131/131 [00:02<00:00, 43.82it/s]


epoch 44, loss: 0.25338777706368276
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 127.58it/s]


Dev F1 0.5267750892502975


100%|██████████| 131/131 [00:02<00:00, 44.02it/s]


epoch 45, loss: 0.24897098291011258
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 126.93it/s]


Dev F1 0.5118358758548133


100%|██████████| 131/131 [00:02<00:00, 44.33it/s]


epoch 46, loss: 0.24752998636424087
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 127.55it/s]


Dev F1 0.5239608154620069


100%|██████████| 131/131 [00:03<00:00, 41.80it/s]


epoch 47, loss: 0.2448996808237702
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 116.36it/s]


Dev F1 0.5193021019370793


100%|██████████| 131/131 [00:02<00:00, 44.65it/s]


epoch 48, loss: 0.24290464579604054
Evaluating dev...


100%|██████████| 1044/1044 [00:09<00:00, 108.06it/s]


Dev F1 0.5238158418468887


100%|██████████| 131/131 [00:02<00:00, 44.15it/s]


epoch 49, loss: 0.2420847733284681
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 126.76it/s]


Dev F1 0.5093286122838077


100%|██████████| 131/131 [00:02<00:00, 43.89it/s]


epoch 50, loss: 0.2421200967017021
Evaluating dev...


100%|██████████| 1044/1044 [00:09<00:00, 107.95it/s]


Dev F1 0.5171460176991151


100%|██████████| 131/131 [00:02<00:00, 44.57it/s]


epoch 51, loss: 0.23822504566370986
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 128.37it/s]


Dev F1 0.5114879649890591


100%|██████████| 131/131 [00:02<00:00, 44.12it/s]


epoch 52, loss: 0.23850060840144413
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 127.55it/s]


Dev F1 0.5059782608695652


100%|██████████| 131/131 [00:02<00:00, 43.95it/s]


epoch 53, loss: 0.2349952890445258
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 127.03it/s]


Dev F1 0.524261317430198


100%|██████████| 131/131 [00:02<00:00, 44.04it/s]


epoch 54, loss: 0.23482007311500666
Evaluating dev...


100%|██████████| 1044/1044 [00:09<00:00, 106.39it/s]


Dev F1 0.5157694399129962


100%|██████████| 131/131 [00:02<00:00, 43.72it/s]


epoch 55, loss: 0.23321060005945105
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 124.41it/s]


Dev F1 0.5141678129298487


100%|██████████| 131/131 [00:02<00:00, 43.96it/s]


epoch 56, loss: 0.2306307554244995
Evaluating dev...


100%|██████████| 1044/1044 [00:09<00:00, 106.99it/s]


Dev F1 0.5021160409556314


100%|██████████| 131/131 [00:02<00:00, 44.21it/s]


epoch 57, loss: 0.22936987967891548
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 122.09it/s]


Dev F1 0.5154413774255261


100%|██████████| 131/131 [00:03<00:00, 38.91it/s]


epoch 58, loss: 0.22789861682717127
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 125.99it/s]


Dev F1 0.5104818949087938


100%|██████████| 131/131 [00:02<00:00, 43.86it/s]


epoch 59, loss: 0.22885856726242385
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 123.67it/s]


Dev F1 0.4988558352402746


100%|██████████| 131/131 [00:02<00:00, 43.89it/s]


epoch 60, loss: 0.22816718125161323
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 123.95it/s]


Dev F1 0.514872334824954


100%|██████████| 131/131 [00:02<00:00, 43.68it/s]


epoch 61, loss: 0.22859140104464903
Evaluating dev...


100%|██████████| 1044/1044 [00:09<00:00, 106.13it/s]


Dev F1 0.506378802747792


100%|██████████| 131/131 [00:02<00:00, 44.36it/s]


epoch 62, loss: 0.22635813808168165
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 125.95it/s]


Dev F1 0.5116215848851434


100%|██████████| 131/131 [00:03<00:00, 43.41it/s]


epoch 63, loss: 0.22486095949438692
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 124.77it/s]


Dev F1 0.5134876078323977


100%|██████████| 131/131 [00:02<00:00, 43.77it/s]


epoch 64, loss: 0.22242121239199894
Evaluating dev...


100%|██████████| 1044/1044 [00:09<00:00, 105.54it/s]


Dev F1 0.5087209302325582


100%|██████████| 131/131 [00:03<00:00, 42.66it/s]


epoch 65, loss: 0.22331988822867851
Evaluating dev...


100%|██████████| 1044/1044 [00:09<00:00, 111.42it/s]


Dev F1 0.5122143420015761


100%|██████████| 131/131 [00:02<00:00, 43.81it/s]


epoch 66, loss: 0.220057746722498
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 125.48it/s]


Dev F1 0.5105820105820106


100%|██████████| 131/131 [00:02<00:00, 44.05it/s]


epoch 67, loss: 0.219449745789739
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 124.36it/s]


Dev F1 0.5107442348008386


100%|██████████| 131/131 [00:02<00:00, 44.10it/s]


epoch 68, loss: 0.2173777150971289
Evaluating dev...


100%|██████████| 1044/1044 [00:09<00:00, 114.29it/s]


Dev F1 0.5194123819517313


100%|██████████| 131/131 [00:03<00:00, 40.90it/s]


epoch 69, loss: 0.21815186636593506
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 123.35it/s]


Dev F1 0.5099399599733155


100%|██████████| 131/131 [00:02<00:00, 43.72it/s]


epoch 70, loss: 0.2132463783935736
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 127.20it/s]


Dev F1 0.5057354065765995


100%|██████████| 131/131 [00:02<00:00, 43.76it/s]


epoch 71, loss: 0.21254878576475245
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 124.75it/s]


Dev F1 0.5079281876556153


100%|██████████| 131/131 [00:02<00:00, 44.55it/s]


epoch 72, loss: 0.21490926949577477
Evaluating dev...


100%|██████████| 1044/1044 [00:10<00:00, 96.29it/s]


Dev F1 0.5103610061253747


100%|██████████| 131/131 [00:03<00:00, 42.57it/s]


epoch 73, loss: 0.2134998714878359
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 124.17it/s]


Dev F1 0.4998016660055533


100%|██████████| 131/131 [00:02<00:00, 44.00it/s]


epoch 74, loss: 0.20928401926546605
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 124.81it/s]


Dev F1 0.5045281533009581


100%|██████████| 131/131 [00:02<00:00, 43.91it/s]


epoch 75, loss: 0.20761414907360806
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 128.02it/s]


Dev F1 0.5098141847683851


100%|██████████| 131/131 [00:02<00:00, 43.70it/s]


epoch 76, loss: 0.20570634095040896
Evaluating dev...


100%|██████████| 1044/1044 [00:09<00:00, 112.21it/s]


Dev F1 0.50336012649888


100%|██████████| 131/131 [00:02<00:00, 43.92it/s]


epoch 77, loss: 0.2061884141605319
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 127.98it/s]


Dev F1 0.5008649367930804


100%|██████████| 131/131 [00:02<00:00, 44.18it/s]


epoch 78, loss: 0.2060104486823992
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 125.10it/s]


Dev F1 0.5027653410587306


100%|██████████| 131/131 [00:02<00:00, 44.02it/s]


epoch 79, loss: 0.20369348805824308
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 124.22it/s]


Dev F1 0.5


100%|██████████| 131/131 [00:02<00:00, 43.86it/s]


epoch 80, loss: 0.20300785173441618
Evaluating dev...


100%|██████████| 1044/1044 [00:11<00:00, 90.90it/s]


Dev F1 0.5001316829075585


100%|██████████| 131/131 [00:02<00:00, 43.76it/s]


epoch 81, loss: 0.20044886929388264
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 122.91it/s]


Dev F1 0.503132915611252


100%|██████████| 131/131 [00:02<00:00, 44.06it/s]


epoch 82, loss: 0.19994705540078286
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 125.17it/s]


Dev F1 0.4995411039727284


100%|██████████| 131/131 [00:02<00:00, 43.95it/s]


epoch 83, loss: 0.1974352420741365
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 121.95it/s]


Dev F1 0.49271212909942735


100%|██████████| 131/131 [00:03<00:00, 43.49it/s]


epoch 84, loss: 0.19905833956849484
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 119.53it/s]


Dev F1 0.5043033889187736


100%|██████████| 131/131 [00:03<00:00, 40.19it/s]


epoch 85, loss: 0.19856211332646945
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 117.93it/s]


Dev F1 0.5026567481402763


100%|██████████| 131/131 [00:02<00:00, 44.03it/s]


epoch 86, loss: 0.1970047934819724
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 123.50it/s]


Dev F1 0.49240924092409244


100%|██████████| 131/131 [00:02<00:00, 44.08it/s]


epoch 87, loss: 0.19783678707730679
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 125.12it/s]


Dev F1 0.49812030075187974


100%|██████████| 131/131 [00:02<00:00, 43.80it/s]


epoch 88, loss: 0.19627168063671535
Evaluating dev...


100%|██████████| 1044/1044 [00:09<00:00, 105.58it/s]


Dev F1 0.49456594659868514


100%|██████████| 131/131 [00:03<00:00, 42.21it/s]


epoch 89, loss: 0.19283691964304175
Evaluating dev...


100%|██████████| 1044/1044 [00:09<00:00, 110.85it/s]


Dev F1 0.49642464246424645


100%|██████████| 131/131 [00:02<00:00, 43.80it/s]


epoch 90, loss: 0.192386059986271
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 124.23it/s]


Dev F1 0.49780380673499275


100%|██████████| 131/131 [00:02<00:00, 44.64it/s]


epoch 91, loss: 0.19247474123275918
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 123.63it/s]


Dev F1 0.48854961832061067


100%|██████████| 131/131 [00:03<00:00, 43.62it/s]


epoch 92, loss: 0.19399393514822458
Evaluating dev...


100%|██████████| 1044/1044 [00:09<00:00, 106.62it/s]


Dev F1 0.4980524539080758


100%|██████████| 131/131 [00:02<00:00, 43.79it/s]


epoch 93, loss: 0.19178887399780842
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 123.12it/s]


Dev F1 0.4969193678006965


100%|██████████| 131/131 [00:02<00:00, 43.97it/s]


epoch 94, loss: 0.19033352808870432
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 123.87it/s]


Dev F1 0.5036428666048482


100%|██████████| 131/131 [00:02<00:00, 43.95it/s]


epoch 95, loss: 0.18720071142866412
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 124.46it/s]


Dev F1 0.4892885480031738


100%|██████████| 131/131 [00:03<00:00, 40.53it/s]


epoch 96, loss: 0.189245545568357
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 119.71it/s]


Dev F1 0.49765635462702557


100%|██████████| 131/131 [00:03<00:00, 43.63it/s]


epoch 97, loss: 0.18501822218185163
Evaluating dev...


100%|██████████| 1044/1044 [00:09<00:00, 108.69it/s]


Dev F1 0.49617111169791395


100%|██████████| 131/131 [00:02<00:00, 43.94it/s]


epoch 98, loss: 0.18337730460494528
Evaluating dev...


100%|██████████| 1044/1044 [00:08<00:00, 128.67it/s]


Dev F1 0.4939695857367593


100%|██████████| 131/131 [00:02<00:00, 43.70it/s]


epoch 99, loss: 0.18248631276247154
Evaluating dev...


100%|██████████| 1044/1044 [00:09<00:00, 108.67it/s]


Dev F1 0.4938205265986029
Best dev F1 score: 0.5438058407787704
Best iteration: 32


## Run Model on Entire Dataset for Evaluation

In [None]:
# Output columns: every training label column except the helper 'labels'
# column. 'Argument ID' becomes the index rather than a column.
cols = [c for c in train_labs_df.columns if c != 'labels']
label_cols = [c for c in cols if c != 'Argument ID']

# Put the model in evaluation mode
model.eval()

# Collect one prediction row per argument and build the DataFrame once at the
# end. Filling an empty frame row by row with .loc is slow and leaves every
# column with dtype object.
pred_rows = []
with torch.no_grad():
    for sents in tqdm(full_input_batches, total=len(full_input_batches)):
        # Batches pair one-to-one with IDs, so row 0 is this argument's prediction.
        pred_rows.append(predict(model, sents).cpu().detach().numpy()[0])

# Fail here rather than leave NaN rows that break the later astype(int).
assert len(pred_rows) == len(full_ids), (
    f"Got {len(pred_rows)} predictions for {len(full_ids)} argument IDs."
)

out_df = pd.DataFrame(
    pred_rows,
    index=pd.Index(full_ids, name='Argument ID'),
    columns=label_cols,
)

# Print out for error checking
out_df.info()

100%|██████████| 5220/5220 [01:13<00:00, 70.59it/s]

<class 'pandas.core.frame.DataFrame'>
Index: 5220 entries, A01001 to D27100
Data columns (total 20 columns):
 #   Column                      Non-Null Count  Dtype 
---  ------                      --------------  ----- 
 0   Self-direction: thought     5220 non-null   object
 1   Self-direction: action      5220 non-null   object
 2   Stimulation                 5220 non-null   object
 3   Hedonism                    5220 non-null   object
 4   Achievement                 5220 non-null   object
 5   Power: dominance            5220 non-null   object
 6   Power: resources            5220 non-null   object
 7   Face                        5220 non-null   object
 8   Security: personal          5220 non-null   object
 9   Security: societal          5220 non-null   object
 10  Tradition                   5220 non-null   object
 11  Conformity: rules           5220 non-null   object
 12  Conformity: interpersonal   5220 non-null   object
 13  Humility                    5220 non-null   ob




In [None]:
# Cast the predictions to int and save them as a tab-separated file.
# Each teammate writes to their own Drive folder; keep the matching line active.
preds_path = '/content/drive/MyDrive/csci5832_project/results/model-preds.tsv'  # Caroline
#preds_path = '/content/drive/MyDrive/nlp_sp/data/model-preds.tsv'               # Spencer
out_df = out_df.astype(int)
out_df.to_csv(preds_path, sep="\t")
print("Finished run!")


Finished run!


# Test phase

## Test set data preprocessing

### Load data

In [None]:
# Load the test-set arguments; keep the line matching your Drive layout active.
# Both alternatives must assign test_args_df. Assigning train_args_df here
# would overwrite the training data and leave test_args_df undefined.
#test_args_df = pd.read_csv('/content/drive/MyDrive/nlp_sp/data/arguments-test.tsv', sep='\t')           # Spencer
test_args_df = pd.read_csv('/content/drive/MyDrive/csci5832_project/data/arguments-test.tsv', sep='\t')  # Caroline
# view structure
test_args_df.info()

### Data prep

In [None]:
# One dict per test argument, keyed by the TSV column names -> List[Dict]
test_data = test_args_df.to_dict('records')
# Show the first record as a sanity check
print('test example:\n', test_data[0])

### Tokenization

In [None]:
# Tokenizer wrapper that encodes (conclusion + stance, premise) pairs in batches.
# Same definition as in the train/dev section, repeated so the test section
# can be run on its own.

class BatchTokenizer:
    """Tokenizes and pads a batch of sentence pairs with a HuggingFace tokenizer.

    HuggingFace docs: https://huggingface.co/transformers/v3.0.2/preprocessing.html
    """

    def __init__(self, model_name: str = "prajjwal1/bert-small"):
        """Loads the pretrained tokenizer.

        Args:
            model_name: HuggingFace model id whose tokenizer to load. Defaults
                to "prajjwal1/bert-small"; it should match the encoder the
                classifier uses.
        """
        self.hf_tokenizer = AutoTokenizer.from_pretrained(model_name)

    def get_sep_token(self):
        """Returns the underlying tokenizer's separator token string."""
        return self.hf_tokenizer.sep_token

    def __call__(self, con_stan_batch: List[str], prem_batch: List[str]) -> Dict:
        """Tokenizes, pads and encodes a batch of (conclusion+stance, premise) pairs.

        The HF tokenizer only accepts a pair of text inputs, but each argument
        has three parts (conclusion, stance, premise). As a workaround, callers
        pre-join conclusion and stance (separated by a literal " [SEP] ") and
        pass that string as the first element of the pair.

        Args:
            con_stan_batch: conclusion-plus-stance strings, one per argument.
            prem_batch: premise strings aligned with ``con_stan_batch``.

        Returns:
            The HuggingFace encoding: a dict-like mapping of PyTorch tensors,
            padded to the longest pair in the batch, without token_type_ids.
        """
        # The HF tokenizer pads for us and joins each pair with its [SEP] token.
        enc = self.hf_tokenizer(
            con_stan_batch,
            prem_batch,
            padding=True,
            return_token_type_ids=False,  # segment ids deliberately omitted for the 3-part workaround
            return_tensors='pt'
        )

        return enc

In [None]:
# Build the tokenizer once; it is reused below to encode every test batch.
tokenizer = BatchTokenizer()

### Batching (batch size 1)

In [None]:
# Test-set counterpart of the train/dev input builder: same splitting, no labels.

def generate_triplewise_input_test(dataset: List[Dict]) -> (List[str], List[str], List[str], List[str]):
    """
    Split test records into parallel lists of ids, conclusions, stances and premises.

    Fields are read by position, not by key name: each record's values are
    taken in order as (id, conclusion, stance, premise). Every stance is
    prefixed with ' [SEP] ' so it can be appended directly to its conclusion.
    """
    # Each record's values, in column order
    rows = [list(record.values()) for record in dataset]

    ids = [row[0] for row in rows]
    conclusions = [row[1] for row in rows]
    stances = [row[2] for row in rows]
    premises = [row[3] for row in rows]

    # Prepend the separator so "conclusion + stance" reads "conclusion [SEP] stance"
    stances = [' [SEP] ' + stance for stance in stances]

    return ids, conclusions, stances, premises

In [None]:
# Split the test records into the four parallel input lists
# (ids, conclusions, stances, premises). The test set has no labels.
(test_ids, test_conclusions, test_stances,
 test_premises) = generate_triplewise_input_test(test_data)

In [None]:
# Append each " [SEP] <stance>" suffix to its conclusion. The tokenizer only
# accepts text pairs, so this yields (conclusion [SEP] stance, premise) pairs.
test_conclusions_stances = [
    test_conclusions[i] + test_stances[i] for i in range(len(test_conclusions))
]

In [None]:
# Batching helpers: same definitions as in the train/dev section, repeated so
# the test section can run on its own.

# Split a single list (e.g. labels) into fixed-size pieces.
def chunk(lst, n):
    """Yield consecutive slices of ``lst`` holding at most ``n`` items each."""
    yield from (lst[start:start + n] for start in range(0, len(lst), n))

# Split two parallel lists (e.g. features) into aligned fixed-size pieces.
def chunk_multi(lst1, lst2, n):
    """Yield aligned ``(lst1_slice, lst2_slice)`` pairs of at most ``n`` items, stepping through ``lst1``."""
    for start in range(0, len(lst1), n):
        stop = start + n
        yield lst1[start:stop], lst2[start:stop]

In [None]:
# Batch the test inputs one argument at a time, then tokenize + encode each
# batch (HuggingFace does both) and move the tensors to the compute device.
test_size = 1
test_input_batches = [
    tokenizer(*pair_batch).to(device)
    for pair_batch in chunk_multi(test_conclusions_stances, test_premises, test_size)
]

In [None]:
# Sanity check: print the first encoded batch and decode its ids back to text.
first_test_batch = test_input_batches[0]
print(first_test_batch)
encoded_test_tst = tokenizer.hf_tokenizer.batch_decode(first_test_batch['input_ids'])
encoded_test_tst[0]

## Test set predictions 
Obtain the predictions on the test set using the trained model 

In [None]:
# Output columns: every training label column except the helper 'labels'
# column. 'Argument ID' becomes the index rather than a column.
cols = [c for c in train_labs_df.columns if c != 'labels']
label_cols = [c for c in cols if c != 'Argument ID']

# Put the model in evaluation mode
model.eval()

# Collect one prediction row per test argument and build the DataFrame once.
# Filling an empty frame row by row with .loc is slow and leaves every column
# with dtype object.
pred_rows = []
with torch.no_grad():
    for sents in tqdm(test_input_batches, total=len(test_input_batches)):
        # Batches pair one-to-one with IDs (batch size 1), so take row 0.
        pred_rows.append(predict(model, sents).cpu().detach().numpy()[0])

# Fail here rather than leave NaN rows that break the later astype(int).
assert len(pred_rows) == len(test_ids), (
    f"Got {len(pred_rows)} predictions for {len(test_ids)} argument IDs."
)

out_df = pd.DataFrame(
    pred_rows,
    index=pd.Index(test_ids, name='Argument ID'),
    columns=label_cols,
)

# Print out for error checking
out_df.info()

In [None]:
# Cast the test-set predictions to int and save them as a tab-separated file.
# Each teammate writes to their own Drive folder; keep the matching line active.
test_preds_path = '/content/drive/MyDrive/csci5832_project/results/model-preds-test.tsv'  # Caroline
#test_preds_path = '/content/drive/MyDrive/nlp_sp/data/model-preds-test.tsv'               # Spencer
out_df = out_df.astype(int)
out_df.to_csv(test_preds_path, sep="\t")
print("Finished run!")