In [1]:
# =========================
# Setup
# =========================
!pip install -q transformers datasets accelerate torch torchvision scikit-learn

import torch
import numpy as np
from datasets import load_dataset
from transformers import (
    BertTokenizer,
    BertForSequenceClassification,
    TrainingArguments,
    Trainer,
)
from sklearn.metrics import accuracy_score, f1_score

# =========================
# Load dataset (Civil Comments)
# =========================
dataset = load_dataset("civil_comments")

# Reset any stale formatting
dataset.reset_format()

# =========================
# Label definitions
# =========================
LABELS = [
    "toxicity",
    "hate_speech",
    "political",
    "adult",
    "spam",
    "safe",
]

def map_labels(example):
    example["labels"] = [
        float(example["toxicity"] > 0.5),          # toxicity
        float(example["identity_attack"] > 0.5),   # hate_speech
        0.0,                                       # political (not in dataset)
        float(example["sexual_explicit"] > 0.5),   # adult
        0.0,                                       # spam (not in dataset)
        float(example["toxicity"] <= 0.5),         # safe
    ]
    return example

dataset = dataset.map(map_labels, batched=False)

# =========================
# Subsample for Colab
# =========================
dataset["train"] = dataset["train"].shuffle(seed=42).select(range(160_000))
dataset["test"] = dataset["test"].shuffle(seed=42).select(range(20_000))

# =========================
# Tokenization
# =========================
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

def tokenize(batch):
    return tokenizer(
        batch["text"],
        truncation=True,
        padding="max_length",
        max_length=128,
    )

dataset = dataset.map(tokenize, batched=True)
dataset.set_format("torch")

# =========================
# Model
# =========================
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=len(LABELS),
    problem_type="multi_label_classification",
)

# =========================
# Training arguments (NEW API)
# =========================
training_args = TrainingArguments(
    output_dir="./brand_safety_model",
    eval_strategy="epoch",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    num_train_epochs=2,
    logging_steps=100,
    save_strategy="epoch",
    load_best_model_at_end=True,
    report_to="none",
)

# =========================
# Trainer (NEW API)
# =========================
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
)

# =========================
# Train
# =========================
trainer.train()

# =========================
# Evaluation
# =========================
preds = trainer.predict(dataset["test"])

y_true = preds.label_ids
y_pred = (preds.predictions > 0).astype(int)

accuracy = accuracy_score(y_true.flatten(), y_pred.flatten())
f1 = f1_score(y_true, y_pred, average="macro")

print("Accuracy:", accuracy)
print("F1 score:", f1)


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00002.parquet:   0%|          | 0.00/194M [00:00<?, ?B/s]

data/train-00001-of-00002.parquet:   0%|          | 0.00/187M [00:00<?, ?B/s]

data/validation-00000-of-00001.parquet:   0%|          | 0.00/21.0M [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/20.8M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1804874 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/97320 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/97320 [00:00<?, ? examples/s]

Map:   0%|          | 0/1804874 [00:00<?, ? examples/s]

Map:   0%|          | 0/97320 [00:00<?, ? examples/s]

Map:   0%|          | 0/97320 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Map:   0%|          | 0/160000 [00:00<?, ? examples/s]

Map:   0%|          | 0/97320 [00:00<?, ? examples/s]

Map:   0%|          | 0/20000 [00:00<?, ? examples/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

BertForSequenceClassification LOAD REPORT from: bert-base-uncased
Key                                        | Status     | 
-------------------------------------------+------------+-
cls.predictions.transform.dense.bias       | UNEXPECTED | 
cls.predictions.transform.LayerNorm.weight | UNEXPECTED | 
cls.seq_relationship.bias                  | UNEXPECTED | 
cls.predictions.transform.LayerNorm.bias   | UNEXPECTED | 
cls.seq_relationship.weight                | UNEXPECTED | 
cls.predictions.bias                       | UNEXPECTED | 
cls.predictions.transform.dense.weight     | UNEXPECTED | 
classifier.bias                            | MISSING    | 
classifier.weight                          | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.


Epoch,Training Loss,Validation Loss
1,0.036256,0.037884
2,0.034424,0.038966


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

There were missing keys in the checkpoint model loaded: ['bert.embeddings.LayerNorm.weight', 'bert.embeddings.LayerNorm.bias', 'bert.encoder.layer.0.attention.output.LayerNorm.weight', 'bert.encoder.layer.0.attention.output.LayerNorm.bias', 'bert.encoder.layer.0.output.LayerNorm.weight', 'bert.encoder.layer.0.output.LayerNorm.bias', 'bert.encoder.layer.1.attention.output.LayerNorm.weight', 'bert.encoder.layer.1.attention.output.LayerNorm.bias', 'bert.encoder.layer.1.output.LayerNorm.weight', 'bert.encoder.layer.1.output.LayerNorm.bias', 'bert.encoder.layer.2.attention.output.LayerNorm.weight', 'bert.encoder.layer.2.attention.output.LayerNorm.bias', 'bert.encoder.layer.2.output.LayerNorm.weight', 'bert.encoder.layer.2.output.LayerNorm.bias', 'bert.encoder.layer.3.attention.output.LayerNorm.weight', 'bert.encoder.layer.3.attention.output.LayerNorm.bias', 'bert.encoder.layer.3.output.LayerNorm.weight', 'bert.encoder.layer.3.output.LayerNorm.bias', 'bert.encoder.layer.4.attention.output.La

Accuracy: 0.9870916666666667
F1 score: 0.2799533408095713


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [5]:
# =========================
# SHAP Explainability (CPU-safe)
# =========================
!pip install -q shap

import shap

# Move model to CPU for SHAP
model_cpu = model.to("cpu")
model_cpu.eval()

def shap_predict(texts):
    if isinstance(texts, str):
        texts = [texts]
    texts = [str(t) for t in texts]

    inputs = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    )

    # Ensure inputs are on CPU
    inputs = {k: v.to("cpu") for k, v in inputs.items()}

    with torch.no_grad():
        logits = model_cpu(**inputs).logits

    return logits.sigmoid().numpy()

# Proper SHAP text masker
masker = shap.maskers.Text(tokenizer)

explainer = shap.Explainer(
    shap_predict,
    masker,
    output_names=LABELS,
)

# Example explanation
sample_text = "This ad is disgusting and offensive"
shap_values = explainer([sample_text])

shap.plots.text(shap_values[0])


In [6]:
# =========================
# CLIP Image Safety
# =========================
from transformers import CLIPProcessor, CLIPModel
from PIL import Image

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

IMAGE_LABELS = [
    "safe advertisement",
    "violent content",
    "adult content",
    "spam advertisement",
    "hate symbols",
]

def analyze_image(image: Image.Image):
    inputs = clip_processor(
        text=IMAGE_LABELS,
        images=image,
        return_tensors="pt",
        padding=True,
    )
    with torch.no_grad():
        probs = clip_model(**inputs).logits_per_image.softmax(dim=1)[0]

    return dict(zip(IMAGE_LABELS, probs.tolist()))

# Example usage
# image = Image.open("example.jpg")
# analyze_image(image)


config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

Loading weights:   0%|          | 0/398 [00:00<?, ?it/s]

CLIPModel LOAD REPORT from: openai/clip-vit-base-patch32
Key                                  | Status     |  | 
-------------------------------------+------------+--+-
text_model.embeddings.position_ids   | UNEXPECTED |  | 
vision_model.embeddings.position_ids | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


model.safetensors:   0%|          | 0.00/605M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

The image processor of type `CLIPImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. 


tokenizer_config.json:   0%|          | 0.00/592 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

In [8]:
# =========================
# FastAPI-style Inference Layer
# =========================
from typing import Dict
def predict_proba(texts, model_ref=None, device=None):
    """
    Returns sigmoid probabilities for each label.
    Safe for CPU or GPU depending on model_ref/device.
    """
    if isinstance(texts, str):
        texts = [texts]
    texts = [str(t) for t in texts]

    if model_ref is None:
        model_ref = model
    if device is None:
        device = next(model_ref.parameters()).device

    inputs = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        logits = model_ref(**inputs).logits

    return logits.sigmoid().cpu().numpy()

def policy_decision(scores: Dict[str, float]) -> str:
    if scores["toxicity"] > 0.8 or scores["hate_speech"] > 0.7:
        return "REJECT"
    if scores["toxicity"] > 0.4:
        return "REVIEW"
    return "APPROVE"

def analyze_text_api(text: str):
    probs = predict_proba([text])[0]

    scores = dict(zip(LABELS, probs.tolist()))

    return {
        "verdict": policy_decision(scores),
        "risk_score": max(scores.values()),
        "categories": scores,
    }

# Example call
analyze_text_api("You are a disgusting person and should be banned")


{'verdict': 'REJECT',
 'risk_score': 0.9150559306144714,
 'categories': {'toxicity': 0.9150559306144714,
  'hate_speech': 0.004216074477881193,
  'political': 0.00021889600611757487,
  'adult': 0.0021778480149805546,
  'spam': 0.0002219238376710564,
  'safe': 0.08615574240684509}}

In [9]:
model.save_pretrained("brand_safety_model")
tokenizer.save_pretrained("brand_safety_model")


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

('brand_safety_model/tokenizer_config.json',
 'brand_safety_model/tokenizer.json')

In [10]:
!zip -r brand_safety_model.zip brand_safety_model


  adding: brand_safety_model/ (stored 0%)
  adding: brand_safety_model/checkpoint-20000/ (stored 0%)
  adding: brand_safety_model/checkpoint-20000/config.json (deflated 56%)
  adding: brand_safety_model/checkpoint-20000/trainer_state.json (deflated 80%)
  adding: brand_safety_model/checkpoint-20000/model.safetensors (deflated 7%)
  adding: brand_safety_model/checkpoint-20000/scheduler.pt (deflated 61%)
  adding: brand_safety_model/checkpoint-20000/training_args.bin (deflated 53%)
  adding: brand_safety_model/checkpoint-20000/rng_state.pth (deflated 26%)
  adding: brand_safety_model/checkpoint-20000/optimizer.pt (deflated 13%)
  adding: brand_safety_model/config.json (deflated 56%)
  adding: brand_safety_model/model.safetensors (deflated 7%)
  adding: brand_safety_model/checkpoint-10000/ (stored 0%)
  adding: brand_safety_model/checkpoint-10000/config.json (deflated 56%)
  adding: brand_safety_model/checkpoint-10000/trainer_state.json (deflated 79%)
  adding: brand_safety_model/checkpoi