In [181]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
checkpoint_path = "checkpoints/toxic-trainer-distilbert/checkpoint-1400"

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint_path)

In [182]:
# Sample input text
text = "Congratulations! You are a total asshole now! I just hope that you die soon!"

# Tokenize the input
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)


In [183]:
tokenizer.decode(inputs["input_ids"][0])

'[CLS] congratulations! you are a total asshole now! i just hope that you die soon! [SEP]'

In [184]:
model.config.id2label

{0: 'toxic',
 1: 'severe_toxic',
 2: 'obscene',
 3: 'threat',
 4: 'insult',
 5: 'identity_hate'}

In [185]:
import torch
# Function to classify text
def classify_text(text):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    print(probs)
    if all(prob < 0.5 for prob in probs[0]):
        return "NON TOXIC"
    else:
        toxic_labels = [model.config.id2label[i] for i, prob in enumerate(probs[0]) if prob >= 0.1]
        return f"{', '.join(toxic_labels)}"

In [186]:
classify_text(text)

tensor([[7.4756e-01, 4.5906e-03, 1.5195e-01, 7.3086e-04, 9.3976e-02, 1.1852e-03]])


'toxic, obscene'