In [1]:
import pandas as pd
import numpy as np
from collections import Counter
import json
from tqdm import tqdm
import random
import pickle
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import SequentialSampler, TensorDataset, RandomSampler
from torch.cuda.amp import GradScaler
from torch.cuda.amp import autocast
from transformers import RobertaTokenizer, RobertaConfig, RobertaModel
from sklearn.metrics import classification_report, confusion_matrix, f1_score, accuracy_score
import torch
import torch.nn as nn
from datasets import load_dataset
import time
from torch.utils.data import DataLoader

2023-02-20 09:17:21.575575: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-02-20 09:17:21.730542: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-02-20 09:17:22.322627: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-11.6/lib64:/usr/local/cuda-11.6/lib64
2023-02-20 09:17:22.322701: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: 

In [2]:
import os

# Expose only physical GPU 3 to this process; the code below addresses it as cuda:0.
# NOTE(review): this only takes effect if CUDA has not been initialised yet — torch is
# already imported in the first cell, so confirm nothing touched the GPU before here.
os.environ.update({"CUDA_VISIBLE_DEVICES": "3"})

In [7]:
# Label taxonomy: the keys of value-categories.json are used as the fine-grained label names.
with open("../data/value-categories.json", encoding="utf-8") as fh:
    val_cats = json.load(fh)
tags = ["training", "validation"]  # dataset splits processed by the cells below
data_dict = {}
ratio_hard = 0.5  # NOTE(review): not referenced anywhere else in this notebook

In [4]:
# Fine-grained labels are the taxonomy keys; the "reduced" labels are the coarse
# categories formed by the text before the first ":" in each key.
all_labels = sorted(val_cats)
all_labels_reduced = sorted({label.split(":")[0] for label in val_cats})

# Column index -> label name, used to decode multi-hot vectors.
id_2_class = dict(enumerate(all_labels))
id_2_class_reduced = dict(enumerate(all_labels_reduced))

len(all_labels), len(all_labels_reduced)

(20, 12)

In [5]:
# NOTE: pickle.load can execute arbitrary code; only load artifacts produced by this project.
with open("../data/data_dict_raw.pkl", "rb") as fh:
    data_dict = pickle.load(fh)
len(data_dict["training"].keys())

5393

In [7]:
# For every split, turn each raw entry into
# [argument_id, sentence, multi_hot_over_all_labels, multi_hot_over_reduced_labels].
example_dict = {}
for tag in tags:
    rows = example_dict.setdefault(tag, [])
    for arg_id, entry in data_dict[tag].items():
        full_hot = [1 if label in entry["labels"] else 0 for label in all_labels]
        coarse_present = {label.split(":")[0] for label in entry["labels"]}
        reduced_hot = [1 if label in coarse_present else 0 for label in all_labels_reduced]
        rows.append([arg_id, entry["sent"], full_hot, reduced_hot])

In [10]:
# Persist the encoded examples; the context manager guarantees the file is flushed and closed.
with open("../data/example_dict_standard_raw.pkl", "wb") as fh:
    pickle.dump(example_dict, fh)

### Training

In [6]:
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
# NOTE: pickle.load can execute arbitrary code; only load artifacts produced by this project.
with open("../data/example_dict_standard_raw.pkl", "rb") as fh:
    example_dict = pickle.load(fh)

In [15]:
def get_encodings(lst, tokenizer_fn=None):
    """Tokenize a list of examples and stack their label vectors into tensors.

    Args:
        lst: examples laid out as [id, text, full_multi_hot, reduced_multi_hot];
            only index 1 and the last two entries are read.
        tokenizer_fn: callable with the Hugging Face tokenizer call signature
            (returns an object with ``input_ids`` and ``attention_mask``).
            Defaults to the notebook-level ``tokenizer``.

    Returns:
        Tuple ``(input_ids, attention_mask, full_labels, reduced_labels)`` of tensors.
        Texts are padded to the longest one in ``lst`` (``padding=True``).
    """
    tok = tokenizer if tokenizer_fn is None else tokenizer_fn
    encoded = tok([example[1] for example in lst], padding=True)
    return (
        torch.tensor(encoded.input_ids),
        torch.tensor(encoded.attention_mask),
        torch.tensor([example[-2] for example in lst]),
        torch.tensor([example[-1] for example in lst]),
    )

In [16]:
train_input_ids, train_attention_mask, train_labels, train_red_labels = get_encodings(example_dict["training"])
valid_input_ids, valid_attention_mask, valid_labels, valid_red_labels = get_encodings(example_dict["validation"])

# One line per split: shapes of input_ids, attention_mask, full labels, reduced labels.
for split_tensors in (
    (train_input_ids, train_attention_mask, train_labels, train_red_labels),
    (valid_input_ids, valid_attention_mask, valid_labels, valid_red_labels),
):
    print(*(tensor.shape for tensor in split_tensors))

torch.Size([5393, 166]) torch.Size([5393, 166]) torch.Size([5393, 20]) torch.Size([5393, 12])
torch.Size([1896, 159]) torch.Size([1896, 159]) torch.Size([1896, 20]) torch.Size([1896, 12])


In [5]:
class BaselineModel(nn.Module):
    """Encoder plus a single linear layer applied to the encoder's ``pooler_output``.

    The linear layer's input width is read from ``base_model.config.hidden_size``
    when the encoder exposes a config, and falls back to 768 otherwise. This keeps
    the original behaviour for roberta-base while also fitting wider encoders.
    """

    def __init__(self, base_model, n_classes):
        super().__init__()
        self.base_model = base_model
        config = getattr(base_model, "config", None)
        hidden_size = getattr(config, "hidden_size", 768)
        self.ff = nn.Linear(hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        """Return unnormalised logits from the linear head, one per class."""
        op = self.base_model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        return self.ff(op["pooler_output"])

In [18]:
def epoch_time(start_time, end_time):
    """Split the span between two ``time.time()`` stamps into (minutes, seconds).

    Both parts are truncated toward zero with ``int()``; fractional seconds are dropped.
    """
    duration = end_time - start_time
    whole_minutes = int(duration / 60)
    remainder = duration - whole_minutes * 60
    return whole_minutes, int(remainder)

def train(model, optimizer, dataloader, device, n_classes=None):
    """Run one training epoch and return the mean per-batch BCE-with-logits loss.

    Each batch is ``(input_ids, attention_mask, full_labels, reduced_labels)``.
    The full labels are used when the class count matches their width; otherwise
    the reduced labels are used.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=...)``
            that returns logits.
        optimizer: optimizer over ``model``'s parameters.
        dataloader: iterable of 4-tuples of tensors as described above.
        device: target device for every batch tensor.
        n_classes: number of model outputs. When None, it is read from the
            last dimension of the logits.

    Returns:
        Mean loss over the batches, as a float.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.train()
    loss_fct = nn.BCEWithLogitsLoss()
    total_loss, n_batches = 0.0, 0

    for batch in tqdm(dataloader):
        input_ids, attention_mask, labels_all, labels_red = (t.to(device) for t in batch)
        optimizer.zero_grad()

        logits = model(input_ids=input_ids, attention_mask=attention_mask)
        width = logits.shape[-1] if n_classes is None else n_classes
        labels = labels_all if width == labels_all.shape[-1] else labels_red

        # Flatten so the loss is averaged over every (example, label) pair.
        loss = loss_fct(logits.reshape(-1), labels.float().reshape(-1))
        loss.backward()
        optimizer.step()

        n_batches += 1
        total_loss += loss.item()

    if n_batches == 0:
        raise ValueError("train() received a dataloader that yielded no batches")
    return total_loss / n_batches

def evaluate(model, dataloader, device, n_classes=20, threshold=0.5):
    """Score ``model`` on ``dataloader``, print classification reports, return the mean loss.

    Each batch is ``(input_ids, attention_mask, full_labels, reduced_labels)``.
    The full labels are scored when ``n_classes`` equals their width; otherwise
    the reduced labels are scored.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=...)``
            that returns logits.
        dataloader: iterable of 4-tuples of tensors as described above.
        device: target device for every batch tensor.
        n_classes: number of model outputs (20 for the full label set by default).
        threshold: sigmoid probability at or above which a label is predicted.

    Returns:
        Mean per-batch BCE-with-logits loss, as a float.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()
    loss_fct = nn.BCEWithLogitsLoss()
    total_loss, n_batches = 0.0, 0
    pred_chunks, label_chunks = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader):
            input_ids, attention_mask, labels_all, labels_red = (t.to(device) for t in batch)
            labels = labels_all if n_classes == labels_all.shape[-1] else labels_red
            logits = model(input_ids=input_ids, attention_mask=attention_mask)

            total_loss += loss_fct(logits.reshape(-1), labels.float().reshape(-1)).item()
            n_batches += 1
            pred_chunks.append((torch.sigmoid(logits) >= threshold).long().cpu())
            label_chunks.append(labels.long().cpu())

    if n_batches == 0:
        raise ValueError("evaluate() received a dataloader that yielded no batches")

    # Concatenate once instead of converting every class column of every batch to lists.
    preds = torch.cat(pred_chunks)   # (num_examples, n_classes)
    actual = torch.cat(label_chunks)

    print("VALIDATION STATS:\n", classification_report(actual.reshape(-1).tolist(), preds.reshape(-1).tolist(), zero_division=0))
    print("-------------------------------")
    print("CLASS WISE VALIDATION STATS:\n")
    mappn = id_2_class if n_classes == len(id_2_class) else id_2_class_reduced
    for k in range(n_classes):
        print("CLASS", mappn[k], ":\n", classification_report(actual[:, k].tolist(), preds[:, k].tolist(), zero_division=0),"\n")
    print("====================================================")
    return total_loss / n_batches

In [19]:
batch_size = 64  # examples per batch; used for both train_dl and valid_dl

In [20]:
train_data = TensorDataset(train_input_ids, train_attention_mask, train_labels, train_red_labels)
valid_data = TensorDataset(valid_input_ids, valid_attention_mask, valid_labels, valid_red_labels)

# shuffle=True / shuffle=False are DataLoader shorthand for RandomSampler / SequentialSampler.
train_dl = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=2)
valid_dl = DataLoader(valid_data, batch_size=batch_size, shuffle=False, num_workers=2)

In [21]:
# Training configuration. Set REDUCED = False to train the full-label model instead;
# the checkpoint name and head size follow from it.
REDUCED = True
SEED = 42
N_EPOCHS = 30
best_valid_loss = float('inf')
early_stopping = 4  # stop after this many consecutive epochs without val-loss improvement

if REDUCED:
    model_name = "roberta_baseline_model_reduced_v1.pt"
    n_classes = len(all_labels_reduced)
else:
    model_name = "roberta_baseline_model_v1.pt"
    n_classes = len(all_labels)

# Seed every RNG in play so that classifier-head init and training shuffles are reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = BaselineModel(RobertaModel.from_pretrained("roberta-base"), n_classes).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [22]:
# Display the head size selected in the training configuration cell.
n_classes

12

In [14]:
# Epoch loop. Each epoch trains, then validates. The checkpoint is saved whenever the
# validation loss ties or beats the best so far. Training stops once the last
# `early_stopping` epochs all failed to improve.
t_loss, v_loss, early_stopping_marker = [], [], []

for epoch in range(N_EPOCHS):
    print(f"Epoch: {epoch}, Training ...\n")
    start_time = time.time()

    tr_l = train(model, optimizer, train_dl, device)
    t_loss.append(tr_l)

    print(f"Epoch: {epoch}, Evaluating ...\n")
    vl_l = evaluate(model, valid_dl, device, n_classes=n_classes)
    v_loss.append(vl_l)
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    improved = vl_l <= best_valid_loss
    if improved:
        best_valid_loss = vl_l
        print("FOUND BEST MODEL!")
        print("SAVING BEST MODEL!")
        torch.save(model.state_dict(), model_name)
    # True marks an epoch in which validation loss did not improve.
    early_stopping_marker.append(not improved)

    print(f'Epoch: {epoch} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Total Loss: {tr_l:.3f} | Val Total Loss: {vl_l:.3f}')

    stale_window = early_stopping_marker[-early_stopping:]
    if len(early_stopping_marker) >= early_stopping and all(stale_window):
        print(f"Early stopping training as the Validation loss did NOT improve for last {early_stopping} iterations.")
        break
    

Epoch: 0, Training ...



85it [00:24,  3.51it/s]

Epoch: 0, Evaluating ...




30it [00:02, 10.31it/s]


VALIDATION STATS:
               precision    recall  f1-score   support

           0       0.83      0.92      0.87     17108
           1       0.63      0.42      0.50      5644

    accuracy                           0.79     22752
   macro avg       0.73      0.67      0.69     22752
weighted avg       0.78      0.79      0.78     22752

-------------------------------
CLASS WISE VALIDATION STATS:

CLASS Achievement :
               precision    recall  f1-score   support

           0       0.72      0.99      0.84      1321
           1       0.88      0.12      0.21       575

    accuracy                           0.73      1896
   macro avg       0.80      0.56      0.52      1896
weighted avg       0.77      0.73      0.65      1896
 

CLASS Benevolence :
               precision    recall  f1-score   support

           0       0.58      1.00      0.73      1091
           1       0.00      0.00      0.00       805

    accuracy                           0.58      1896
   

85it [00:24,  3.47it/s]

Epoch: 1, Evaluating ...




30it [00:02, 10.28it/s]


VALIDATION STATS:
               precision    recall  f1-score   support

           0       0.85      0.92      0.88     17108
           1       0.67      0.53      0.59      5644

    accuracy                           0.82     22752
   macro avg       0.76      0.72      0.74     22752
weighted avg       0.81      0.82      0.81     22752

-------------------------------
CLASS WISE VALIDATION STATS:

CLASS Achievement :
               precision    recall  f1-score   support

           0       0.75      0.99      0.85      1321
           1       0.88      0.25      0.39       575

    accuracy                           0.76      1896
   macro avg       0.82      0.62      0.62      1896
weighted avg       0.79      0.76      0.71      1896
 

CLASS Benevolence :
               precision    recall  f1-score   support

           0       0.71      0.67      0.69      1091
           1       0.58      0.63      0.61       805

    accuracy                           0.65      1896
   

85it [00:24,  3.47it/s]

Epoch: 2, Evaluating ...




30it [00:02, 10.13it/s]


VALIDATION STATS:
               precision    recall  f1-score   support

           0       0.86      0.92      0.89     17108
           1       0.69      0.56      0.62      5644

    accuracy                           0.83     22752
   macro avg       0.78      0.74      0.75     22752
weighted avg       0.82      0.83      0.82     22752

-------------------------------
CLASS WISE VALIDATION STATS:

CLASS Achievement :
               precision    recall  f1-score   support

           0       0.83      0.90      0.86      1321
           1       0.72      0.58      0.64       575

    accuracy                           0.80      1896
   macro avg       0.77      0.74      0.75      1896
weighted avg       0.80      0.80      0.80      1896
 

CLASS Benevolence :
               precision    recall  f1-score   support

           0       0.71      0.79      0.75      1091
           1       0.66      0.56      0.61       805

    accuracy                           0.69      1896
   

85it [00:24,  3.45it/s]

Epoch: 3, Evaluating ...




30it [00:02, 10.13it/s]


VALIDATION STATS:
               precision    recall  f1-score   support

           0       0.87      0.92      0.89     17108
           1       0.70      0.57      0.63      5644

    accuracy                           0.83     22752
   macro avg       0.78      0.74      0.76     22752
weighted avg       0.82      0.83      0.83     22752

-------------------------------
CLASS WISE VALIDATION STATS:

CLASS Achievement :
               precision    recall  f1-score   support

           0       0.82      0.93      0.87      1321
           1       0.77      0.53      0.63       575

    accuracy                           0.81      1896
   macro avg       0.80      0.73      0.75      1896
weighted avg       0.81      0.81      0.80      1896
 

CLASS Benevolence :
               precision    recall  f1-score   support

           0       0.72      0.77      0.74      1091
           1       0.65      0.59      0.62       805

    accuracy                           0.69      1896
   

85it [00:24,  3.44it/s]

Epoch: 4, Evaluating ...




30it [00:02, 10.16it/s]


VALIDATION STATS:
               precision    recall  f1-score   support

           0       0.86      0.92      0.89     17108
           1       0.70      0.56      0.62      5644

    accuracy                           0.83     22752
   macro avg       0.78      0.74      0.76     22752
weighted avg       0.82      0.83      0.82     22752

-------------------------------
CLASS WISE VALIDATION STATS:

CLASS Achievement :
               precision    recall  f1-score   support

           0       0.83      0.92      0.87      1321
           1       0.75      0.56      0.64       575

    accuracy                           0.81      1896
   macro avg       0.79      0.74      0.76      1896
weighted avg       0.80      0.81      0.80      1896
 

CLASS Benevolence :
               precision    recall  f1-score   support

           0       0.70      0.81      0.75      1091
           1       0.68      0.54      0.60       805

    accuracy                           0.69      1896
   

85it [00:24,  3.44it/s]

Epoch: 5, Evaluating ...




30it [00:02, 10.16it/s]


VALIDATION STATS:
               precision    recall  f1-score   support

           0       0.87      0.91      0.89     17108
           1       0.68      0.58      0.63      5644

    accuracy                           0.83     22752
   macro avg       0.77      0.75      0.76     22752
weighted avg       0.82      0.83      0.82     22752

-------------------------------
CLASS WISE VALIDATION STATS:

CLASS Achievement :
               precision    recall  f1-score   support

           0       0.84      0.89      0.86      1321
           1       0.70      0.60      0.65       575

    accuracy                           0.80      1896
   macro avg       0.77      0.75      0.76      1896
weighted avg       0.80      0.80      0.80      1896
 

CLASS Benevolence :
               precision    recall  f1-score   support

           0       0.71      0.75      0.73      1091
           1       0.63      0.59      0.61       805

    accuracy                           0.68      1896
   

85it [00:24,  3.45it/s]

Epoch: 6, Evaluating ...




30it [00:02, 10.09it/s]


VALIDATION STATS:
               precision    recall  f1-score   support

           0       0.87      0.90      0.89     17108
           1       0.67      0.60      0.63      5644

    accuracy                           0.83     22752
   macro avg       0.77      0.75      0.76     22752
weighted avg       0.82      0.83      0.82     22752

-------------------------------
CLASS WISE VALIDATION STATS:

CLASS Achievement :
               precision    recall  f1-score   support

           0       0.83      0.92      0.87      1321
           1       0.74      0.56      0.64       575

    accuracy                           0.81      1896
   macro avg       0.78      0.74      0.75      1896
weighted avg       0.80      0.81      0.80      1896
 

CLASS Benevolence :
               precision    recall  f1-score   support

           0       0.71      0.78      0.74      1091
           1       0.66      0.57      0.61       805

    accuracy                           0.69      1896
   

85it [00:24,  3.44it/s]

Epoch: 7, Evaluating ...




30it [00:02, 10.12it/s]


VALIDATION STATS:
               precision    recall  f1-score   support

           0       0.87      0.89      0.88     17108
           1       0.64      0.60      0.62      5644

    accuracy                           0.82     22752
   macro avg       0.76      0.75      0.75     22752
weighted avg       0.81      0.82      0.82     22752

-------------------------------
CLASS WISE VALIDATION STATS:

CLASS Achievement :
               precision    recall  f1-score   support

           0       0.84      0.88      0.86      1321
           1       0.68      0.61      0.65       575

    accuracy                           0.80      1896
   macro avg       0.76      0.74      0.75      1896
weighted avg       0.79      0.80      0.79      1896
 

CLASS Benevolence :
               precision    recall  f1-score   support

           0       0.77      0.58      0.66      1091
           1       0.57      0.76      0.65       805

    accuracy                           0.66      1896
   

### Predict on Test Set

In [6]:
# Always a torch.device (previously either a torch.device or the plain string "cpu").
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
# Encoder handed to BaselineModel by get_model() below.
roberta = RobertaModel.from_pretrained("roberta-base")

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [7]:
def get_model(reduced=False):
    """Build a BaselineModel, load its fine-tuned checkpoint and switch it to eval mode.

    Args:
        reduced: if True, build the head over ``all_labels_reduced`` and load
            "roberta_baseline_model_reduced_v1.pt". Otherwise build it over
            ``all_labels`` and load "roberta_baseline_model_v1.pt".

    Returns:
        The loaded model, on ``device``, in eval mode.
    """
    import copy

    if reduced:
        model_name = "roberta_baseline_model_reduced_v1.pt"
        n_classes = len(all_labels_reduced)
    else:
        model_name = "roberta_baseline_model_v1.pt"
        n_classes = len(all_labels)

    # Give each model its own encoder. Passing the single global `roberta` to every
    # model made them share parameters, so loading a second checkpoint silently
    # overwrote the encoder weights of the model returned first.
    model = BaselineModel(copy.deepcopy(roberta), n_classes).to(device)
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    state_dict = torch.load(model_name, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    print("Model loaded!")
    return model

In [8]:
# Load the coarse-label model first, then the full-label model.
reduced_model = get_model(reduced=True)
full_model = get_model(reduced=False)

Model loaded!
Model loaded!


In [10]:
# Which arguments-<split>.tsv file to predict on; also part of the output pickle's name.
split = "test"

In [11]:
arg_df = pd.read_csv("../data/arguments-"+split+".tsv", sep="\t")
# The prediction cell below reads these columns; fail fast if the file layout differs.
required_cols = {"Argument ID", "Premise", "Stance", "Conclusion"}
missing_cols = required_cols - set(arg_df.columns)
assert not missing_cols, f"arguments file is missing columns: {sorted(missing_cols)}"
arg_df.shape

(1576, 4)

In [12]:
import math  # NOTE(review): not used in this cell

batch_size = 64
threshold = 0.5  # sigmoid probability at or above which a label is predicted

all_res = {}

# Build one input string per argument: premise, a stance phrase, then the conclusion.
sents = []
for ix, row in tqdm(arg_df.iterrows(), total=len(arg_df)):
    stance = " against. " if row["Stance"] == "against" else " in favor of. "
    sents.append(row["Premise"] + stance + row["Conclusion"])

# truncation=True caps each sequence at the tokenizer's model max length, so an
# unusually long argument cannot exceed the encoder's input limit.
all_toks = tokenizer(sents, padding=True, truncation=True)
input_ids, attention_mask = torch.tensor(all_toks.input_ids), torch.tensor(all_toks.attention_mask)
test_data = TensorDataset(input_ids, attention_mask)
test_dl = DataLoader(test_data, batch_size=batch_size, sampler=SequentialSampler(test_data), num_workers=2)

preds_full, preds_red = [], []
with torch.no_grad():
    for batch in tqdm(test_dl):
        input_ids, attention_mask = (t.to(device) for t in batch)
        op_full = full_model(input_ids=input_ids, attention_mask=attention_mask)
        op_red = reduced_model(input_ids=input_ids, attention_mask=attention_mask)
        preds_full.extend((torch.sigmoid(op_full) >= threshold).long().tolist())
        preds_red.extend((torch.sigmoid(op_red) >= threshold).long().tolist())

# Decode the multi-hot predictions into label names, keyed by Argument ID (row order
# matches arg_df because the loader is sequential).
for arg_id, pred_full, pred_red in zip(arg_df["Argument ID"], preds_full, preds_red):
    all_res[arg_id] = {
        "full": [id_2_class[ix] for ix, flag in enumerate(pred_full) if flag == 1],
        "reduced": [id_2_class_reduced[ix] for ix, flag in enumerate(pred_red) if flag == 1],
    }

# NOTE: despite "logit" in the file name, the payload is the decoded label names.
with open("../data/"+split+"_prediction_logit_dict_baseline_v1.pkl", "wb") as fh:
    pickle.dump(all_res, fh)

1576it [00:00, 25186.79it/s]
25it [00:05,  4.91it/s]


In [14]:
# all_res["A07099"]

### Submission Formatting

In [21]:
# Split whose prediction pickles are loaded and whose submission TSVs are written.
split = "test"
# Name of the thresholded prediction set in the fine-tuned model's per-argument dicts.
thresh = "thresh_8"

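Best-performing combinations (each is one setting of get_format below):
    1. thresh_8 — the fine-tuned model's `thresh_8` predictions alone.
    2. thresh_8 + base full — union of (1) with the baseline model's full-label predictions.
    3. thresh_8 + base reduced filter — (1), keeping only labels whose coarse category (text before ":") the baseline reduced model also predicts.
    4. thresh_8 + base full + reduced filter — union of (1) with the baseline full-label predictions, after filtering those by the baseline reduced predictions.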
    

In [22]:
# NOTE: pickle.load can execute arbitrary code; only load artifacts produced by this project.
# Fine-tuned predictions: per argument id, a dict keyed by threshold name (e.g. "thresh_8").
with open("../data/"+split+"_prediction_label_dict_v1.pkl", "rb") as fh:
    res_lbl_dct_fine = pickle.load(fh)
# Baseline predictions: per argument id, {"full": [...], "reduced": [...]} label names.
with open("../data/"+split+"_prediction_logit_dict_baseline_v1.pkl", "rb") as fh:
    res_lbl_dct_base_all = pickle.load(fh)
# Only the header row is needed: it fixes the submission's column order.
col_names = list(pd.read_csv("../data/labels-validation.tsv", sep="\t", nrows=0).columns)

In [23]:
def get_format(valu=None, thresh_key=None, fine=None, base=None):
    """Combine fine-tuned and baseline predictions into one label list per argument.

    Args:
        valu: which combination to build.
            1 -> the fine-tuned labels at ``thresh_key`` only.
            2 -> union of (1) with the baseline "full" labels.
            3 -> (1), keeping labels whose coarse category (text before ":")
                 is in the baseline "reduced" predictions.
            4 -> union of (1) with the baseline "full" labels, after filtering
                 those by the baseline "reduced" predictions.
            Any other value returns an empty dict.
        thresh_key: key into each fine-tuned entry. Defaults to the notebook-level ``thresh``.
        fine: ``{arg_id: {thresh_key: [label, ...]}}``. Defaults to ``res_lbl_dct_fine``.
        base: ``{arg_id: {"full": [...], "reduced": [...]}}``. Defaults to ``res_lbl_dct_base_all``.

    Returns:
        Dict mapping every argument id in ``fine`` to a list of label names. Unions
        (settings 2 and 4) are sorted so the output does not depend on set ordering.
    """
    if valu not in (1, 2, 3, 4):
        return {}

    key = thresh if thresh_key is None else thresh_key
    fine = res_lbl_dct_fine if fine is None else fine
    base = res_lbl_dct_base_all if base is None else base

    def _coarse_filter(labels, reduced):
        # Keep labels whose coarse category (prefix before ":") appears in `reduced`.
        return [lbl for lbl in labels if lbl.split(":")[0] in reduced]

    if valu == 1:
        return {k: v[key] for k, v in fine.items()}
    if valu == 2:
        return {k: sorted(set(v[key] + base[k]["full"])) for k, v in fine.items()}
    if valu == 3:
        return {k: _coarse_filter(v[key], base[k]["reduced"]) for k, v in fine.items()}
    # valu == 4
    return {
        k: sorted(set(v[key] + _coarse_filter(base[k]["full"], base[k]["reduced"])))
        for k, v in fine.items()
    }

In [24]:
# Write one submission TSV per combination setting.
for setting in range(1, 5):
    print("For Setting",setting)
    res_lbl_dct = get_format(setting)
    print(len(res_lbl_dct))

    # One row per argument: its id, then a 0/1 flag for every label column in header order.
    op_lst = [
        [arg_id] + [int(label in labels) for label in col_names[1:]]
        for arg_id, labels in res_lbl_dct.items()
    ]
    op_df = pd.DataFrame(op_lst, columns=col_names)
    print(op_df.shape)
    op_name = "./" + split + "_setting_" + str(setting) + ".tsv"
    op_df.to_csv(op_name, sep="\t", index=False)
    print("Saved to Location: ",op_name,"\n")

For Setting 1
1896
(1896, 21)
Saved to Location:  ./validation_setting_1.tsv 

For Setting 2
1896
(1896, 21)
Saved to Location:  ./validation_setting_2.tsv 

For Setting 3
1896
(1896, 21)
Saved to Location:  ./validation_setting_3.tsv 

For Setting 4
1896
(1896, 21)
Saved to Location:  ./validation_setting_4.tsv 

