In [77]:
import pandas as pd
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
import torch
from torch import nn
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, mean_squared_error, mean_absolute_error, r2_score, accuracy_score, precision_score
from datasets import Dataset
from transformers import DataCollatorWithPadding, BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments, BertModel, AdamW, get_linear_schedule_with_warmup

In [2]:
arg_qual_dir = "..\\IBM_Debater_arg_quality\\arg_quality_rank_30k.csv"

arg_qual_df = pd.read_csv(arg_qual_dir)

In [6]:
# train test dev
arg_qual_df.head()

Unnamed: 0,argument,topic,set,WA,MACE-P,stance_WA,stance_WA_conf
0,"""marriage"" isn't keeping up with the times. a...",We should abandon marriage,train,0.846165,0.297659,1,1.0
1,.a multi-party system would be too confusing a...,We should adopt a multi-party system,train,0.891271,0.726133,-1,1.0
2,\ero-tolerance policy in schools should not be...,We should adopt a zero-tolerance policy in sch...,dev,0.721192,0.396953,-1,1.0
3,`people reach their limit when it comes to the...,Assisted suicide should be a criminal offence,train,0.730395,0.225212,-1,1.0
4,"100% agree, should they do that, it would be a...",We should abolish safe spaces,train,0.236686,0.004104,1,0.805517


In [7]:
len(arg_qual_df)

30497

In [23]:
arg_qual_df.dropna(inplace=True)
arg_qual_df.drop_duplicates(inplace=True)
arg_qual_df.drop(columns=["MACE-P", "stance_WA", "stance_WA_conf"], inplace=True)

In [24]:
# calc the td-idf score of the text

def get_most_important_word(df, feature):
    vectorizer = TfidfVectorizer()
    X = vectorizer.fit_transform(df[feature])
    feature_names = vectorizer.get_feature_names_out()
    
    most_important_words = []
    
    for row in X:
        row_data = row.toarray().flatten()
        most_important_index = row_data.argmax()
        most_important_word = feature_names[most_important_index]
        
        if most_important_word.isnumeric():
            most_important_words.append("number")
        elif most_important_word.isalpha():
            most_important_words.append(most_important_word)
        else:
            most_important_words.append("na")
    return most_important_words

arg_qual_df['most_important_word'] = get_most_important_word(arg_qual_df, 'argument')

In [25]:
arg_qual_df.head()

Unnamed: 0,argument,topic,set,WA,most_important_word
0,"""marriage"" isn't keeping up with the times. a...",We should abandon marriage,train,0.846165,incorporates
1,.a multi-party system would be too confusing a...,We should adopt a multi-party system,train,0.891271,consensus
2,\ero-tolerance policy in schools should not be...,We should adopt a zero-tolerance policy in sch...,dev,0.721192,nuanced
3,`people reach their limit when it comes to the...,Assisted suicide should be a criminal offence,train,0.730395,suffering
4,"100% agree, should they do that, it would be a...",We should abolish safe spaces,train,0.236686,number


In [36]:

def count_words(sentence):
    sen_normal = sentence.lower().strip()
    for word_idx in range(len(sen_normal)):
        if not sen_normal[word_idx] in " abcdefghijklmnopqrstuvwxyz-',.":
            sen_normal = sen_normal.replace(sen_normal[word_idx], "*")
    sen_normal = sen_normal.replace("*", "")
    return len(sen_normal.split())

arg_qual_df['word_count'] = arg_qual_df['argument'].apply(count_words)


In [37]:
arg_qual_df.head()

Unnamed: 0,argument,topic,set,WA,most_important_word,word_count
0,"""marriage"" isn't keeping up with the times. a...",We should abandon marriage,train,0.846165,incorporates,27
1,.a multi-party system would be too confusing a...,We should adopt a multi-party system,train,0.891271,consensus,18
2,\ero-tolerance policy in schools should not be...,We should adopt a zero-tolerance policy in sch...,dev,0.721192,nuanced,31
3,`people reach their limit when it comes to the...,Assisted suicide should be a criminal offence,train,0.730395,suffering,40
4,"100% agree, should they do that, it would be a...",We should abolish safe spaces,train,0.236686,number,11


In [38]:
def avg_qual_score(df, feature):
    df_len = len(df)
    df_sum = df[feature].sum()
    return df_sum / df_len

avg_qual_score(arg_qual_df, "WA")

0.7913285945189035

In [65]:
def amount_high_low_qual(df, feature, threshold):
    high_qual = df[df[feature] > threshold]
    low_qual = df[df[feature] < threshold]
    return f"high qual: {len(high_qual)}, low qual: {len(low_qual)}"

In [46]:
lower_qual_df = arg_qual_df[arg_qual_df["WA"] < 0.7]
higher_qual_df = arg_qual_df[arg_qual_df["WA"] > 0.7]

In [48]:
balanced_df = pd.concat([lower_qual_df, higher_qual_df.sample(n=len(lower_qual_df), random_state=32)], ignore_index=True)

In [50]:
balanced_df.head()

Unnamed: 0,argument,topic,set,WA,most_important_word,word_count
0,"100% agree, should they do that, it would be a...",We should abolish safe spaces,train,0.236686,number,11
1,a bad score in an intelligence test is a blow ...,Intelligence tests bring more harm than good,dev,0.638716,blow,25
2,"A ban would be inffective, people who want a b...",Surrogacy should be banned,train,0.467092,foresaken,32
3,A Blockade is the perfect way to create stagna...,Blockade of the Gaza Strip should be ended,dev,0.555555,stagnation,12
4,A blockade is what you do when you want any ch...,Blockade of the Gaza Strip should be ended,dev,0.634427,you,24


In [None]:
def amount_high_low_qual(df, feature, threshold):
    high_qual = df[df[feature] > threshold]
    low_qual = df[df[feature] < threshold]
    return f"high qual: {len(high_qual)}, low qual: {len(low_qual)}"

In [52]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')



In [61]:
def tokenize_column(text):
    return tokenizer(text, padding='max_length', truncation=True, max_length=256, return_tensors='pt')

balanced_df['argument_input_ids'] = balanced_df['argument'].apply(lambda x: tokenize_column(x)['input_ids'][0])
balanced_df['argument_attention_mask'] = balanced_df['argument'].apply(lambda x: tokenize_column(x)['attention_mask'][0])

balanced_df['topic_input_ids'] = balanced_df['topic'].apply(lambda x: tokenize_column(x)['input_ids'][0])
balanced_df['topic_attention_mask'] = balanced_df['topic'].apply(lambda x: tokenize_column(x)['attention_mask'][0])

balanced_df['most_important_word_input_ids'] = balanced_df['most_important_word'].apply(lambda x: tokenize_column(x)['input_ids'][0])
balanced_df['most_important_word_attention_mask'] = balanced_df['most_important_word'].apply(lambda x: tokenize_column(x)['attention_mask'][0])


In [62]:
balanced_df.head()

Unnamed: 0,argument,topic,set,WA,most_important_word,word_count,argument_input_ids,argument_attention_mask,topic_input_ids,topic_attention_mask,most_important_word_input_ids,most_important_word_attention_mask
0,"100% agree, should they do that, it would be a...",We should abolish safe spaces,train,0.236686,number,11,"[tensor(101), tensor(2531), tensor(1003), tens...","[tensor(1), tensor(1), tensor(1), tensor(1), t...","[tensor(101), tensor(2057), tensor(2323), tens...","[tensor(1), tensor(1), tensor(1), tensor(1), t...","[tensor(101), tensor(2193), tensor(102), tenso...","[tensor(1), tensor(1), tensor(1), tensor(0), t..."
1,a bad score in an intelligence test is a blow ...,Intelligence tests bring more harm than good,dev,0.638716,blow,25,"[tensor(101), tensor(1037), tensor(2919), tens...","[tensor(1), tensor(1), tensor(1), tensor(1), t...","[tensor(101), tensor(4454), tensor(5852), tens...","[tensor(1), tensor(1), tensor(1), tensor(1), t...","[tensor(101), tensor(6271), tensor(102), tenso...","[tensor(1), tensor(1), tensor(1), tensor(0), t..."
2,"A ban would be inffective, people who want a b...",Surrogacy should be banned,train,0.467092,foresaken,32,"[tensor(101), tensor(1037), tensor(7221), tens...","[tensor(1), tensor(1), tensor(1), tensor(1), t...","[tensor(101), tensor(7505), tensor(3217), tens...","[tensor(1), tensor(1), tensor(1), tensor(1), t...","[tensor(101), tensor(18921), tensor(3736), ten...","[tensor(1), tensor(1), tensor(1), tensor(1), t..."
3,A Blockade is the perfect way to create stagna...,Blockade of the Gaza Strip should be ended,dev,0.555555,stagnation,12,"[tensor(101), tensor(1037), tensor(15823), ten...","[tensor(1), tensor(1), tensor(1), tensor(1), t...","[tensor(101), tensor(15823), tensor(1997), ten...","[tensor(1), tensor(1), tensor(1), tensor(1), t...","[tensor(101), tensor(2358), tensor(8490), tens...","[tensor(1), tensor(1), tensor(1), tensor(1), t..."
4,A blockade is what you do when you want any ch...,Blockade of the Gaza Strip should be ended,dev,0.634427,you,24,"[tensor(101), tensor(1037), tensor(15823), ten...","[tensor(1), tensor(1), tensor(1), tensor(1), t...","[tensor(101), tensor(15823), tensor(1997), ten...","[tensor(1), tensor(1), tensor(1), tensor(1), t...","[tensor(101), tensor(2017), tensor(102), tenso...","[tensor(1), tensor(1), tensor(1), tensor(0), t..."


In [73]:
balanced_df.drop(columns=["argument", "topic", "most_important_word"], inplace=True)

KeyError: "['argument', 'topic', 'most_important_word'] not found in axis"

In [74]:
train_df = balanced_df[balanced_df["set"] == "train"]
test_df = balanced_df[balanced_df["set"] == "test"]
val_df = balanced_df[balanced_df["set"] == "dev"]
print("train: ", len(train_df))
print("test: ", len(test_df))
print("val: ", len(val_df))

train:  11263
test:  3484
val:  1735


In [72]:
amount_high_low_qual(val_df, "labels", 0.7)

'high qual: 908, low qual: 827'

In [59]:
# custom dataset that inherits from the dataset class
class ArgumentQualityDataset(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe
        
    def __len__(self):
        return len(self.dataframe)
    
    def __getitem__(self, idx):
        item = self.dataframe.iloc[idx]
        return {
            'argument_input_ids': torch.tensor(item['argument_input_ids'], dtype=torch.long),
            'argument_attention_mask': torch.tensor(item['argument_attention_mask'], dtype=torch.long),
            'topic_input_ids': torch.tensor(item['topic_input_ids'], dtype=torch.long),
            'topic_attention_mask': torch.tensor(item['topic_attention_mask'], dtype=torch.long),
            'most_important_word_input_ids': torch.tensor(item['most_important_word_input_ids'], dtype=torch.long),
            'most_important_word_attention_mask': torch.tensor(item['most_important_word_attention_mask'], dtype=torch.long),
            'labels': torch.tensor(item['WA'], dtype=torch.float)
        }

# custom bert model that inherits from the nn.Module class
class CustomBertModel(nn.Module):
    def __init__(self, model_name="bert-base-uncased"):
        super(CustomBertModel, self).__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.classifier = nn.Linear(self.bert.config.hidden_size * 3 + 1, 1)  

    def forward(self, argument_input_ids, argument_attention_mask, topic_input_ids, topic_attention_mask, most_important_word_input_ids, most_important_word_attention_mask, word_count):
        argument_outputs = self.bert(input_ids=argument_input_ids, attention_mask=argument_attention_mask).pooler_output
        topic_outputs = self.bert(input_ids=topic_input_ids, attention_mask=topic_attention_mask).pooler_output
        word_outputs = self.bert(input_ids=most_important_word_input_ids, attention_mask=most_important_word_attention_mask).pooler_output
        
        concatenated_outputs = torch.cat((argument_outputs, topic_outputs, word_outputs, word_count.unsqueeze(1)), dim=1)
        logits = self.classifier(concatenated_outputs)
        return logits


In [60]:
model = CustomBertModel()



In [69]:
train_dataset = ArgumentQualityDataset(train_df)
test_dataset = ArgumentQualityDataset(test_df)
val_dataset = ArgumentQualityDataset(val_df)

In [76]:
def train_model(model, train_dataset, val_dataset, test_dataset, batch_size=8, num_epochs=3, learning_rate=2e-5):
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    # test_loader = DataLoader(test_dataset, batch_size=batch_size)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    total_steps = len(train_loader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, 
                                                num_warmup_steps=0, 
                                                num_training_steps=total_steps)
    loss_fn = torch.nn.MSELoss()

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for batch in train_loader:
            optimizer.zero_grad()
            input_ids = batch['argument_input_ids'].to(device)
            attention_mask = batch['argument_attention_mask'].to(device)
            topic_ids = batch['topic_input_ids'].to(device)
            topic_mask = batch['topic_attention_mask'].to(device)
            word_ids = batch['most_important_word_input_ids'].to(device)
            word_mask = batch['most_important_word_attention_mask'].to(device)
            word_count = batch['word_count'].to(device)
            labels = batch['labels'].to(device)
            
            outputs = model(argument_input_ids=input_ids,
                            argument_attention_mask=attention_mask,
                            topic_input_ids=topic_ids,
                            topic_attention_mask=topic_mask,
                            most_important_word_input_ids=word_ids,
                            most_important_word_attention_mask=word_mask,
                            word_count=word_count)
            
            loss = loss_fn(outputs.squeeze(), labels)
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
            scheduler.step()
        
        avg_train_loss = total_loss / len(train_loader)
        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {avg_train_loss}')
        
        model.eval()
        val_preds, val_labels = [], []
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch['argument_input_ids'].to(device)
                attention_mask = batch['argument_attention_mask'].to(device)
                topic_ids = batch['topic_input_ids'].to(device)
                topic_mask = batch['topic_attention_mask'].to(device)
                word_ids = batch['most_important_word_input_ids'].to(device)
                word_mask = batch['most_important_word_attention_mask'].to(device)
                word_count = batch['word_count'].to(device)
                labels = batch['labels'].to(device) 
                
                outputs = model(argument_input_ids=input_ids,
                                argument_attention_mask=attention_mask,
                                topic_input_ids=topic_ids,
                                topic_attention_mask=topic_mask,
                                most_important_word_input_ids=word_ids,
                                most_important_word_attention_mask=word_mask,
                                word_count=word_count)
                
                val_preds.extend(outputs.squeeze().cpu().numpy())
                val_labels.extend(labels.cpu().numpy())
        
        val_mse = mean_squared_error(val_labels, val_preds)
        print(f'Validation MSE: {val_mse}')

In [None]:
# train the model here

In [None]:
def evaluate_model(model, test_dataset, batch_size=8, tolerance=0.1):
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()
    
    test_preds, test_labels = [], []
    
    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['argument_input_ids'].to(device)
            attention_mask = batch['argument_attention_mask'].to(device)
            topic_ids = batch['topic_input_ids'].to(device)
            topic_mask = batch['topic_attention_mask'].to(device)
            word_ids = batch['most_important_word_input_ids'].to(device)
            word_mask = batch['most_important_word_attention_mask'].to(device)
            word_count = batch['word_count'].to(device)
            labels = batch['labels'].to(device)
            
            outputs = model(argument_input_ids=input_ids,
                            argument_attention_mask=attention_mask,
                            topic_input_ids=topic_ids,
                            topic_attention_mask=topic_mask,
                            most_important_word_input_ids=word_ids,
                            most_important_word_attention_mask=word_mask,
                            word_count=word_count)
            
            test_preds.extend(outputs.squeeze().cpu().numpy())
            test_labels.extend(labels.cpu().numpy())
        
    test_preds = np.array(test_preds)
    test_labels = np.array(test_labels)
    
    acc_within_tolerance = np.abs(test_preds - test_labels) <= tolerance
    acc_within_tolerance = np.mean(acc_within_tolerance)
    
    test_mse = mean_squared_error(test_labels, test_preds)
    test_mae = mean_absolute_error(test_labels, test_preds)
    test_r2 = r2_score(test_labels, test_preds)
    
    print(f'Test MSE: {test_mse}')
    print(f'Test MAE: {test_mae}')
    print(f'Test R²: {test_r2}')
    print(f'Accuracy within {tolerance} tolerance: {acc_within_tolerance}')
