In [None]:
## Library setup for the Grace OnDemand Jupyter notebook

import numpy as np
import pandas as pd
import ctypes
import os
ctypes.CDLL('/lib64/libsnappy.so.1')
import pyarrow
from datasets import Dataset
from collections import Counter
import random
import evaluate
random.seed(666)
from transformers import BertTokenizer, BertModel, AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer

import torch
from torch import nn
# Seed NumPy and PyTorch with the same value passed to random.seed above, so that
# the randomly initialised classifier head is reproducible across kernel restarts.
np.random.seed(666)
torch.manual_seed(666)

# When baseline_only is True the pretrained transformer is frozen and only the
# classifier head added after its last hidden layer output is trained.
baseline_only = True

# model_name has to be "BERT" or "Llama2-7B"
model_name = "Llama2-7B"

print("Libraries load finished")

In [None]:
## This code block adapts the NLP model structure for the classification task; the models and tokenizers are loaded in the following cell

class bertClassifier(nn.Module):
    """BERT-style encoder topped with a linear classification layer.

    Args:
        model: encoder module exposing ``config.hidden_size`` and returning an
            object with a ``pooler_output`` attribute when called.
        num_labels: number of output classes of the linear head.
        only_classifier: when truthy, every parameter of ``model`` is frozen so
            that only the linear head can be trained.
    """

    def __init__(self, model, num_labels=2,only_classifier = False):
        super().__init__()
        self.only_classifier = only_classifier
        if self.only_classifier:
            # Baseline mode: switch off gradients for the whole encoder.
            for weight in model.parameters():
                weight.requires_grad_(False)
        self.bert = model
        self.classifier = nn.Linear(model.config.hidden_size, num_labels)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, input_ids,labels, attention_mask=None):
        """Return ``{"loss", "logits"}`` for a batch of token ids and labels."""
        encoded = self.bert(input_ids, attention_mask=attention_mask)
        track_grad = not self.only_classifier
        pooled = encoded.pooler_output.requires_grad_(track_grad)
        # The head is fed bfloat16 activations regardless of the encoder dtype.
        scores = self.classifier(pooled.bfloat16())
        return {"loss": self.loss_fn(scores, labels), "logits": scores}


class LlamaClassifier(nn.Module):
    """Llama backbone with a linear classification head on the last real token.

    Args:
        model: backbone module exposing ``config.hidden_size`` and returning an
            object with ``hidden_states`` when called (the backbone must be
            configured to output hidden states).
        num_labels: number of output classes of the linear head.
        only_classifier: when truthy, every backbone parameter is frozen and
            only the head is trained; otherwise the backbone is wrapped with
            LoRA adapters through ``peft``.
    """

    def __init__(self, model, num_labels=2 , only_classifier=False):
        super(LlamaClassifier, self).__init__()
        self.only_classifier = only_classifier
        if only_classifier:
            for param in model.parameters():
                param.requires_grad = False
        else:
            # Imported locally: peft is only needed on the LoRA path, and the
            # names must be in scope here rather than in some other function.
            from peft import LoraConfig, get_peft_model
            lora_config = LoraConfig(task_type="SEQ_CLS", r=16, lora_alpha=4, lora_dropout=0.1)
            model = get_peft_model(model, lora_config)
            model.print_trainable_parameters()
        self.llama = model
        self.classifier = nn.Linear(self.llama.config.hidden_size, num_labels)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, input_ids, labels, attention_mask=None):
        """Return ``{"loss", "logits"}`` for a batch of token ids and labels.

        The sequence representation is the last-layer hidden state at index
        ``sum(attention_mask) - 1`` of each row, i.e. the last attended token
        when the batch is right-padded. Without an attention mask the final
        position of every row is used.
        """
        outputs = self.llama(input_ids, attention_mask=attention_mask)
        last_hidden = outputs.hidden_states[-1]
        last_hidden = last_hidden.requires_grad_(not self.only_classifier)
        batch_size, seq_len = last_hidden.shape[0], last_hidden.shape[1]
        device = last_hidden.device
        if attention_mask is None:
            # No mask supplied: treat every position as a real token.
            last_token = torch.full((batch_size,), seq_len - 1, dtype=torch.long, device=device)
        else:
            last_token = attention_mask.to(device).long().sum(dim=1) - 1
        # One gather for the whole batch instead of a Python loop over rows.
        pooled = last_hidden[torch.arange(batch_size, device=device), last_token]
        logits = self.classifier(pooled.bfloat16())
        return {"loss": self.loss_fn(logits, labels), "logits": logits}



In [None]:
def load_tokenizer_and_model(model_name, baseline_only):
    """Load a pretrained backbone with its tokenizer and wrap it for classification.

    Args:
        model_name: ``"BERT"`` (bert-base-cased) or ``"Llama2-7B"``
            (NousResearch/llama-2-7b-chat-hf).
        baseline_only: forwarded as ``only_classifier`` to the wrapper class.

    Returns:
        ``(model, tokenizer)`` where ``model`` is the bfloat16 classifier wrapper.

    Raises:
        ValueError: if ``model_name`` is not one of the two supported names.
    """
    if model_name == "BERT":
        tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
        model = BertModel.from_pretrained("bert-base-cased",device_map = 'cuda')
        model = bertClassifier(num_labels=2,only_classifier = baseline_only,model=model)
        model = model.bfloat16()
    elif model_name == "Llama2-7B":
        tokenizer = AutoTokenizer.from_pretrained("NousResearch/llama-2-7b-chat-hf",padding_side='right')
        if tokenizer.pad_token is None:
            # Batches are padded during tokenization, which needs a pad token.
            tokenizer.pad_token = tokenizer.eos_token
        # Load the weights directly in bfloat16: materialising the 7B model in
        # float32 on the GPU first needs roughly twice the memory.
        model = AutoModelForCausalLM.from_pretrained(
            "NousResearch/llama-2-7b-chat-hf",
            device_map='cuda',
            torch_dtype=torch.bfloat16,
        )
        model.config.use_cache = False
        # The classifier wrapper reads outputs.hidden_states[-1].
        model.config.output_hidden_states = True
        model = LlamaClassifier(num_labels=2,only_classifier = baseline_only,model=model)
        # Still required so that the newly created classifier head is bfloat16 too.
        model = model.bfloat16()
    else:
        raise ValueError('In this project we only finetuned "BERT" or "Llama2-7B" models')
    print("load "+model_name)
    return model, tokenizer


# Build the classifier and tokenizer selected by model_name / baseline_only;
# both names are used as globals by the cells below.
model, tokenizer = load_tokenizer_and_model(model_name, baseline_only)

In [None]:
## This code block reads the training and testing datasets directly from the GitHub data source

_SARCASM_DATA_URL = 'https://raw.githubusercontent.com/EducationalTestingService/sarcasm/master'


def _read_sarcasm_split(source, split):
    """Download one JSON-lines file of the sarcasm shared task as a DataFrame.

    ``source`` is 'twitter' or 'reddit'; ``split`` is 'training' or 'testing'.
    """
    url = f'{_SARCASM_DATA_URL}/{source}/sarcasm_detection_shared_task_{source}_{split}.jsonl'
    return pd.read_json(url, lines=True)


twitter_training = _read_sarcasm_split('twitter', 'training')
twitter_testing = _read_sarcasm_split('twitter', 'testing')
reddit_training = _read_sarcasm_split('reddit', 'training')
reddit_testing = _read_sarcasm_split('reddit', 'testing')

# Reddit rows come first, then Twitter rows, with a fresh 0..n-1 index.
combined_training = pd.concat([reddit_training, twitter_training], ignore_index=True)
combined_testing = pd.concat([reddit_testing, twitter_testing], ignore_index=True)

def prepare_df(df):
    """Convert a raw sarcasm DataFrame into a tokenized ``Dataset``.

    The input frame is left untouched, so the function can be called again on
    the same DataFrame.

    Args:
        df: DataFrame with ``label`` ('NOT_SARCASM' / 'SARCASM'), ``context``
            and ``response`` columns.

    Returns:
        A ``Dataset`` with ``input_ids``, ``attention_mask`` and ``labels``
        columns. Sequences are padded to the longest one and truncated to 512
        tokens, using the module-level ``tokenizer``.

    Raises:
        ValueError: if ``label`` contains a value other than the two known ones.
    """
    label_ids = {'NOT_SARCASM': 0, 'SARCASM': 1}
    mapped = df['label'].map(label_ids)
    unmapped = mapped.isna()
    if unmapped.any():
        # Fail early: NaN labels would otherwise turn the label tensor into floats.
        unknown = sorted(set(df.loc[unmapped, 'label'].astype(str)))
        raise ValueError(f"Unexpected label values: {unknown}")

    full_seq = "context: " + df['context'].astype(str) + " response: " + df['response'].astype(str)
    encoded_input = tokenizer(
        full_seq.tolist(),
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=512,
    )
    return Dataset.from_dict({
        "input_ids": encoded_input["input_ids"],
        "attention_mask": encoded_input["attention_mask"],
        "labels": torch.tensor(mapped.astype(int).tolist()),
    })

# Tokenize both splits. This rebinds the two names from pandas DataFrames to the
# Dataset objects returned by prepare_df, so this cell expects freshly loaded DataFrames.
combined_training, combined_testing = prepare_df(combined_training), prepare_df(combined_testing)



In [None]:
## This code block initializes the training setup and hyperparameters

# Learning rate per model: the larger value is used when only the classifier
# head is trained (baseline_only), the smaller one when the transformer is tuned.
if model_name == "BERT":
    lr = 1e-3 if baseline_only else 2e-5
else:
    lr = 1e-4 if baseline_only else 5e-5

metric = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Score one evaluation pass with the module-level accuracy ``metric``.

    ``eval_pred`` unpacks into ``(logits, labels)``; the predicted class is the
    arg-max over the last axis of the logits.
    """
    scores, references = eval_pred
    predicted = np.argmax(scores, axis=-1)
    return metric.compute(predictions=predicted, references=references)


training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=10,              # number of training epochs
    optim = "paged_adamw_32bit",
    per_device_train_batch_size=5,  # batch size for training
    warmup_steps=100,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',           # directory for storing logs
    # NOTE(review): newer transformers releases name this argument
    # `eval_strategy` - confirm against the installed version.
    evaluation_strategy='steps',
    learning_rate=lr,
    max_grad_norm = 0.5,
    gradient_accumulation_steps = 5,  # 5 x 5 = 25 samples per optimizer step and device
    logging_steps=50)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=combined_training,
    eval_dataset=combined_testing,
    compute_metrics=compute_metrics,)
print("start training")
trainer.train()

# Save the weights of the trained classifier. baseline_only is a bool, so it has
# to be converted explicitly before being concatenated into the file name.
torch.save({'model': model.state_dict()}, model_name+'_'+str(baseline_only)+'_.pth')

In [None]:
import matplotlib.pyplot as plt

hist = trainer.state.log_history

# Pick log entries by key instead of by position, so the curves do not depend on
# training and evaluation records strictly alternating. The first logged point
# of each curve is skipped, as before.
train_loss = [entry["loss"] for entry in hist if "loss" in entry][1:]
eval_entries = [entry for entry in hist if "eval_loss" in entry][1:]
eval_loss = [entry["eval_loss"] for entry in eval_entries]
eval_accuracy = [entry["eval_accuracy"] for entry in eval_entries]

# baseline_only is a bool and must be converted before string concatenation.
run_tag = model_name + '_' + str(baseline_only)

fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(16, 5))

# Subplot 1: Loss
loss_ax.plot(train_loss, label='Training Loss', color='blue')
loss_ax.plot(eval_loss, label='Eval/Test Loss', color='red')
loss_ax.set_title(run_tag + '_Training and Eval/Test Loss')
loss_ax.set_xlabel('Stages')
loss_ax.set_ylabel('Loss')
loss_ax.legend()
loss_ax.grid(True)
loss_ax.set_ylim(0.4, 0.8)

# Subplot 2: Accuracy
acc_ax.plot(eval_accuracy, label='Eval/Test Accuracy', color='green')
acc_ax.set_title(run_tag + '_Eval/Test Accuracy')
acc_ax.set_xlabel('Stages')
acc_ax.set_ylabel('Accuracy')
acc_ax.legend()
acc_ax.grid(True)
acc_ax.set_ylim(0.5, 0.8)

fig.savefig(model_name+'_training_error_'+str(baseline_only)+'.png', dpi=500)
# Show the plots
plt.show()
