In [1]:
import os
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import f1_score, classification_report

import torch
#from torch.utils.data import Dataset
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, pipeline
from peft import LoraConfig
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

  from .autonotebook import tqdm as notebook_tqdm


## Load Data & Create Prompt

**Prompt Example**:

You are an expert in hate speech detection. Offensive tweets are defined as tweets containing profane words, sarcastic remarks, insults, slanders or slurs. These can have a potentially harmful effect on a given target. Classify the following input tweet as Offensive or Non-Offensive.

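`### Input: <tweet>`

`### Response: Offensive`

(When building training prompts the gold label — `Offensive` or `Non-Offensive` — is appended after `### Response: `; for inference the response is left empty for the model to complete.)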

In [2]:
# Instruction placed at the start of every prompt (train and inference).
system_prompt = (
    "You are an expert in hate speech detection. "
    "Offensive tweets are defined as tweets containing profane words, "
    "sarcastic remarks, insults, slanders or slurs. "
    "These can have a potentially harmful effect on a given target. "
    "Classify the following input tweet as Offensive or Non-Offensive."
)
# Gold `offense` value -> label text the model is trained to emit.
label_map = {
    1: "Offensive",
    0: "Non-Offensive",
}

In [3]:
def prepare_prompt(row, train=True):
    """Build the instruction-style prompt for a single dataframe row.

    Layout follows https://huggingface.co/datasets/vicgalle/alpaca-gpt4?row=0 :
    the system prompt, an ``### Input:`` section holding the tweet, and an
    ``### Response:`` section.

    Args:
        row: a row (e.g. pandas Series or dict) with a ``"tweet"`` entry and,
            when ``train`` is True, an ``"offense"`` entry that is a key of
            ``label_map``.
        train: when True the gold label text is appended after
            ``### Response: ``; when False the response is left empty so the
            model can complete it.

    Returns:
        The prompt string.
    """
    pieces = [system_prompt, "\n\n ### Input: ", row["tweet"], "\n\n ### Response: "]
    if train:
        pieces.append(label_map[row["offense"]])  # gold label as target text
    return "".join(pieces)

## Inference

In [4]:
# Use the current CUDA device when available, otherwise the CPU. Always a
# torch.device (never a bare string) so the type does not depend on hardware.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
# Fine-tuned checkpoint to evaluate (other paths noted by the author:
# "results/checkpoint-5109/", llama-test/).
trained_checkpoint = "results_bloomz/checkpoint-5109/"

tokenizer = AutoTokenizer.from_pretrained(trained_checkpoint)
model = AutoModelForCausalLM.from_pretrained(trained_checkpoint).to(device)
model.eval()

BloomForCausalLM(
  (transformer): BloomModel(
    (word_embeddings): Embedding(250880, 4096)
    (word_embeddings_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
    (h): ModuleList(
      (0-29): 30 x BloomBlock(
        (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (self_attention): BloomAttention(
          (query_key_value): Linear(
            in_features=4096, out_features=12288, bias=True
            (lora_dropout): ModuleDict(
              (default): Dropout(p=0.1, inplace=False)
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=4096, out_features=64, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=64, out_features=12288, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
          )
          (dense): Linear(in_features=4096, out_features=4096, bias=True)

In [6]:
# Evaluation set; later cells read its `tweet` (input text) and `offense` (gold 0/1) columns.
test_df = pd.read_csv(filepath_or_buffer="data/cm_hate_combined.csv")

In [7]:
# Inference prompts: same template as training, with the response left empty.
test_df["text"] = test_df.apply(prepare_prompt, axis=1, train=False)
test_df["text"].to_numpy()[:2]

array(['You are an expert in hate speech detection. Offensive tweets are defined as tweets containing profane words, sarcastic remarks, insults, slanders or slurs. These can have a potentially harmful effect on a given target. Classify the following input tweet as Offensive or Non-Offensive.\n\n ### Input: @user @user @user @user @user Matlab sirf ladki ke character baat ithae tab bologe 0ar ladke ke upar wo bhi khud ke fd se karoge to chup rahoge.\n\n ### Response: ',
       'You are an expert in hate speech detection. Offensive tweets are defined as tweets containing profane words, sarcastic remarks, insults, slanders or slurs. These can have a potentially harmful effect on a given target. Classify the following input tweet as Offensive or Non-Offensive.\n\n ### Input: Pehle main bahut loyal tha tab mujhse koi ladki nahi pat rahi thi phir ek din....\n \n\n Phir kya abhi bhi koi nahi pat rahi\n (Kyuki abhi bhi loyal hi hu)\n\n ### Response: '],
      dtype=object)

In [8]:
def inference_responses(df, max_length=300, max_new_tokens=20):
    """Generate a completion for every prompt in ``df["text"]``.

    Prompts are processed one at a time with the module-level ``model``,
    ``tokenizer`` and ``device``.

    Args:
        df: DataFrame with a ``"text"`` column of prompts (built by
            ``prepare_prompt(..., train=False)``).
        max_length: maximum number of prompt tokens; longer prompts are
            truncated by the tokenizer.
        max_new_tokens: number of tokens the model may generate after the
            prompt. It is independent of ``max_length``, so a long prompt
            still leaves room for the answer. Label texts are short, and
            ``get_labels`` inspects only the first 15 characters of the
            response.

    Returns:
        List of decoded strings, one per row. Each contains the prompt
        followed by the generated continuation; ``get_labels`` relies on the
        prompt's ``### Response: `` marker being present.
    """
    model.eval()
    responses = []

    # Iterate over values rather than index labels: df["text"][i] is a
    # label lookup and breaks when the index is not a 0..n-1 RangeIndex.
    for prompt in tqdm(df["text"].tolist()):
        inputs = tokenizer(prompt, truncation=True, max_length=max_length,
                           return_tensors="pt").to(device)
        with torch.no_grad():
            # max_new_tokens bounds only the generated part. generate(max_length=...)
            # would count the prompt too, leaving no room when the prompt is long.
            generate_ids = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs.get("attention_mask"),
                max_new_tokens=max_new_tokens,
            )

        response = tokenizer.decode(generate_ids[0], skip_special_tokens=True,
                                    clean_up_tokenization_spaces=False)
        responses.append(response)
    return responses

In [9]:
def get_labels(responses):
    """Convert decoded model outputs into binary labels (1 = Offensive, 0 = Non-Offensive).

    Each response is split on ``"### Response: "``. Only the first 15
    characters after the first occurrence are searched, and
    ``"Non-Offensive"`` is checked before ``"Offensive"`` because the
    latter is a substring of the former.

    Responses without the marker, or without a recognisable label, fall back
    to 0 (the original author notes this as the majority class). The number
    of each fallback is printed.

    Args:
        responses: iterable of decoded strings (prompt + generation).

    Returns:
        list[int] with one label per response.
    """
    marker = "### Response: "
    n_missing_marker = 0
    n_unrecognised = 0
    labels = []

    for text in responses:
        parts = text.split(marker)
        if len(parts) < 2:
            # No marker at all; the summary below attributes this to max_length.
            n_missing_marker += 1
            labels.append(0)
            continue

        head = parts[1][:15]
        if "Non-Offensive" in head:
            labels.append(0)
        elif "Offensive" in head:
            labels.append(1)
        else:
            n_unrecognised += 1
            labels.append(0)  # fallback to class 0

    print(f"{n_missing_marker} responses trimmed due to max_length")
    print(f"{n_unrecognised} labels absent \n")
    return labels

In [13]:
def print_metrics(labels, df):
    """Print the binary F1 score and a per-class classification report.

    Args:
        labels: predicted 0/1 labels, aligned with the rows of ``df``.
        df: DataFrame whose ``"offense"`` column holds the gold 0/1 labels.
    """
    gold = df["offense"].tolist()
    print("F1 score = ", f1_score(gold, labels))
    report = classification_report(
        gold,
        labels,
        target_names=["Non-Offensive (0)", "Offensive (1)"],
        digits=4,
    )
    print(report)

In [14]:
# Generate -> parse -> score. `test_labels` is persisted by a later cell.
test_responses = inference_responses(test_df)
test_labels = get_labels(responses=test_responses)
print_metrics(labels=test_labels, df=test_df)

100%|██████████| 641/641 [02:50<00:00,  3.76it/s]

0 responses trimmed due to max_length
0 labels absent 

F1 score =  0.593103448275862
                   precision    recall  f1-score   support

Non-Offensive (0)     0.6853    0.6436    0.6638       362
    Offensive (1)     0.5714    0.6165    0.5931       279

         accuracy                         0.6318       641
        macro avg     0.6284    0.6301    0.6285       641
     weighted avg     0.6357    0.6318    0.6330       641






In [15]:
# Persist the predicted labels. open(..., "wb") fails if the target folder is
# missing, so create it first (no-op when it already exists).
predictions_path = "data/predictions/bloomz-ft-completion_custom_data.pickle"
os.makedirs(os.path.dirname(predictions_path), exist_ok=True)
with open(predictions_path, 'wb') as f:
    pickle.dump(test_labels, f, protocol=pickle.HIGHEST_PROTOCOL)