# 1. Load Libraries

In [2]:
import pandas as pd
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
from huggingface_hub import hf_hub_download
import json
import onnxruntime as rt

In [3]:
reddit_df = pd.read_csv('./data/combined_cleaned_500k.csv',  lineterminator='\n', encoding='utf8')

In [4]:
reddit_df.head()

Unnamed: 0,text,timestamp,username,link,link_id,parent_id,id,subreddit_id,moderation\r
0,i think most singaporeans dont give a damn who...,2020-04-11 15:49:23,invigo79,/r/singapore/comments/fz7vtl/im_quite_interest...,t3_fz7vtl,t3_fz7vtl,fn3gbrg,t5_2qh8c,"{'removal_reason': None, 'collapsed': False, '..."
1,fair point the secrecy aspect of it slipped my...,2020-04-03 09:59:08,potatetoe_tractor,/r/singapore/comments/fu3axm/government_to_tab...,t3_fu3axm,t1_fmasya5,fmau5k3,t5_2qh8c,"{'removal_reason': None, 'collapsed': False, '..."
2,range,2020-02-15 15:07:03,CrossfittJesus,/r/singapore/comments/f4ac70/what_is_ps_defens...,t3_f4ac70,t3_f4ac70,fhp05xc,t5_2qh8c,"{'removal_reason': None, 'collapsed': False, '..."
3,gt this is binary thinking because you think t...,2020-06-04 07:07:39,nomad80,/r/singapore/comments/gw55cx/notoracism/fsu4fyd/,t3_gw55cx,t1_fsu3dsf,fsu4fyd,t5_2qh8c,"{'removal_reason': None, 'collapsed': False, '..."
4,boo boo poor u lmao,2020-10-31 13:52:12,pirorok,/r/singapore/comments/jl6abo/rsingapore_random...,t3_jl6abo,t1_gap4e9y,gap4vkl,t5_2qh8c,"{'removal_reason': None, 'collapsed': False, '..."


# 2. Test on the first 5 rows of data

In [5]:
small_reddit_df = reddit_df[0:5]

In [6]:
small_reddit_df

Unnamed: 0,text,timestamp,username,link,link_id,parent_id,id,subreddit_id,moderation\r
0,i think most singaporeans dont give a damn who...,2020-04-11 15:49:23,invigo79,/r/singapore/comments/fz7vtl/im_quite_interest...,t3_fz7vtl,t3_fz7vtl,fn3gbrg,t5_2qh8c,"{'removal_reason': None, 'collapsed': False, '..."
1,fair point the secrecy aspect of it slipped my...,2020-04-03 09:59:08,potatetoe_tractor,/r/singapore/comments/fu3axm/government_to_tab...,t3_fu3axm,t1_fmasya5,fmau5k3,t5_2qh8c,"{'removal_reason': None, 'collapsed': False, '..."
2,range,2020-02-15 15:07:03,CrossfittJesus,/r/singapore/comments/f4ac70/what_is_ps_defens...,t3_f4ac70,t3_f4ac70,fhp05xc,t5_2qh8c,"{'removal_reason': None, 'collapsed': False, '..."
3,gt this is binary thinking because you think t...,2020-06-04 07:07:39,nomad80,/r/singapore/comments/gw55cx/notoracism/fsu4fyd/,t3_gw55cx,t1_fsu3dsf,fsu4fyd,t5_2qh8c,"{'removal_reason': None, 'collapsed': False, '..."
4,boo boo poor u lmao,2020-10-31 13:52:12,pirorok,/r/singapore/comments/jl6abo/rsingapore_random...,t3_jl6abo,t1_gap4e9y,gap4vkl,t5_2qh8c,"{'removal_reason': None, 'collapsed': False, '..."


# 3. Model Text Classification

In [10]:
# Download model config
repo_path = "govtech/lionguard-v1"
config_path = hf_hub_download(repo_id=repo_path, filename="config.json")
with open(config_path, 'r') as f:
    config = json.load(f)

In [9]:
#print(config)

{'description': 'Binary classifier on harmful text in Singapore context', 'embedding': {'tokenizer': 'BAAI/bge-large-en-v1.5', 'model': 'BAAI/bge-large-en-v1.5', 'max_length': 512, 'batch_size': 32}, 'classifier': {'binary': {'calibrated': True, 'threshold': {'high_recall': 0.2, 'balanced': 0.5, 'high_precision': 0.8}, 'model_type': 'ridge_classifier', 'model_fp': 'models/lionguard-binary.onnx'}, 'hateful': {'calibrated': False, 'threshold': {'high_recall': -0.341, 'balanced': -0.186, 'high_precision': -0.008}, 'model_type': 'ridge_classifier', 'model_fp': 'models/lionguard-hateful.onnx'}, 'harassment': {'calibrated': False, 'threshold': {'high_recall': -0.571, 'balanced': -0.471, 'high_precision': -0.471}, 'model_type': 'ridge_classifier', 'model_fp': 'models/lionguard-harassment.onnx'}, 'public_harm': {'calibrated': False, 'threshold': {'high_recall': -0.713, 'balanced': -0.632, 'high_precision': -0.576}, 'model_type': 'ridge_classifier', 'model_fp': 'models/lionguard-public_harm.onn

In [11]:
def get_embeddings(device, data):
    # Load the model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config['embedding']['tokenizer'])
    model = AutoModel.from_pretrained(config['embedding']['model'])
    model.eval()
    model.to(device)

    # Generate the embeddings
    batch_size = config['embedding']['batch_size']
    num_batches = int(np.ceil(len(data)/batch_size))
    output = []
    for i in range(num_batches):
        sentences = data[i*batch_size:(i+1)*batch_size]
        encoded_input = tokenizer(sentences, max_length=config['embedding']['max_length'], padding=True, truncation=True, return_tensors='pt')
        encoded_input.to(device)
        with torch.no_grad():
            model_output = model(**encoded_input)
            sentence_embeddings = model_output[0][:, 0]
        sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
        output.extend(sentence_embeddings.cpu().numpy())
    
    return np.array(output)

In [12]:
def predict(batch_text):
    device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
    embeddings = get_embeddings(device, batch_text)
    embeddings_df = pd.DataFrame(embeddings)

    # Prepare input data
    X_input = np.array(embeddings_df, dtype=np.float32)

    # Load the classifiers
    results = {}
    for category, details in config['classifier'].items():
        # Download the classifier from HuggingFace hub
        local_model_fp = hf_hub_download(repo_id=repo_path, filename=config['classifier'][category]['model_fp'])

        # Run the inference
        session = rt.InferenceSession(local_model_fp)
        input_name = session.get_inputs()[0].name
        outputs = session.run(None, {input_name: X_input})

        # If calibrated, return only the prediction for the unsafe class
        if config['classifier'][category]['calibrated']: 
            scores = [output[1] for output in outputs[1]]
        else:
            scores = outputs[1].flatten()
        
        # Generate the predictions depending on the recommended threshold score
        results[category] = {
            'scores': scores,
            'predictions': {
                'high_recall': [1 if score >= config['classifier'][category]['threshold']['high_recall'] else 0 for score in scores],
                'balanced': [1 if score >= config['classifier'][category]['threshold']['balanced'] else 0 for score in scores],
                'high_precision': [1 if score >= config['classifier'][category]['threshold']['high_precision'] else 0 for score in scores]
            }
        }

    return results

In [14]:
# Extract the text data from the DataFrame (assuming the text is in a column named 'text')
batch_text = small_reddit_df['text'].tolist()

# Generate the scores and predictions
results = predict(batch_text)

# Prepare results for DataFrame
output_data = []
for i in range(len(batch_text)):
    output_row = {
        'Text': batch_text[i],
    }
    for category in results.keys():
        output_row[f'{category} Score'] = results[category]['scores'][i]
        output_row[f'{category} HR'] = results[category]['predictions']['high_recall'][i]
        output_row[f'{category} B'] = results[category]['predictions']['balanced'][i]
        output_row[f'{category} HP'] = results[category]['predictions']['high_precision'][i]
    output_data.append(output_row)

# Create a DataFrame from the results
small_results_df = pd.DataFrame(output_data)

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


In [15]:
print(small_results_df)

                                                Text  binary Score  binary HR  \
0  i think most singaporeans dont give a damn who...      0.008994          0   
1  fair point the secrecy aspect of it slipped my...      0.000000          0   
2                                              range      0.004988          0   
3  gt this is binary thinking because you think t...      1.000000          1   
4                                boo boo poor u lmao      1.000000          1   

   binary B  binary HP  hateful Score  hateful HR  hateful B  hateful HP  \
0         0          0      -0.582897           0          0           0   
1         0          0      -1.116735           0          0           0   
2         0          0      -1.027191           0          0           0   
3         1          1      -0.419287           0          0           0   
4         1          1      -0.952112           0          0           0   

   harassment Score  ...  sexual B  sexual HP  toxic Sco