# Benchmarking [gemma-2-2b-it](https://huggingface.co/google/gemma-2-2b-it)

## Libraries

In [2]:
# --- Libraries ---
import pandas as pd
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
from pathlib import Path
import os
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from rich.console import Console
from rich.table import Table
import warnings
from tqdm import tqdm
import torch

# Suppress warnings for a cleaner output
# NOTE(review): no category or module is given, so every warning raised from this
# point on is hidden; consider narrowing the filter if a warning ever needs debugging.
warnings.filterwarnings("ignore")

from rich.panel import Panel
from rich.text import Text

# --- Performance Settings ---
# Set to 'high' for better performance on Ampere GPUs as suggested by the warning.
torch.set_float32_matmul_precision('high')


## Global variables

In [3]:
# Project layout, resolved relative to the notebook's working directory.
ROOT = Path("../..")
DATA_DIR = ROOT / "data"
# Input benchmark and the CSV the predictions are written to at the end of the notebook.
BENCHMARK_PATH = DATA_DIR / "benchmark_jigsaw" / "benchmark_jigsaw.csv"
output_path = DATA_DIR / "benchmark_jigsaw" / "eng-google-gemma-2-2b-it.csv"
# Shared rich console used by every cell for formatted output.
console = Console()
# Prompt fragments are read from the current directory. They are decoded explicitly as
# UTF-8 (matching how the benchmark CSV is read) instead of relying on the platform's
# default encoding, which differs between operating systems.
# NOTE(review): assumes both files are saved as UTF-8 -- confirm.
system_prompt = (Path(".") / "API_SYSTEM_PROMPT.txt").read_text(encoding="utf-8").strip()
prompt = (Path(".") / "API_PROMPT.txt").read_text(encoding="utf-8").strip()

In [4]:
# Point both proxy environment variables at the same local SOCKS endpoint.
# NOTE(review): the address is hard-coded; it only works where such a proxy is running.
os.environ.update(
    HTTP_PROXY="socks5h://127.0.0.1:1080",
    HTTPS_PROXY="socks5h://127.0.0.1:1080",
)

In [5]:
# NOTE(review): this repeats the identical call already made at the end of the
# libraries cell; one of the two can likely be removed.
torch.set_float32_matmul_precision('high')

## Load dataset

In [6]:
# --- Load Dataset ---
console.print(Panel("[bold cyan]Step 1: Loading Dataset[/bold cyan]"))
try:
    df = pd.read_csv(BENCHMARK_PATH, encoding="utf-8")
except FileNotFoundError:
    console.print(f"[bold red]Error: The file was not found at {BENCHMARK_PATH}. Please check the path.[/bold red]")
    # Re-raise instead of calling exit(): inside a notebook exit() takes the kernel
    # down with it, whereas an exception stops "Run All" here and keeps the traceback.
    raise

# Keep only rows that have both a text and a label, then store the labels as integers.
df = df.dropna(subset=["content", "label"])
df["label"] = df["label"].astype(int)

# Display dataset info: total row count plus one row per label value.
label_counts = df["label"].value_counts().reset_index()
label_counts.columns = ["label", "count"]
table = Table(title="Dataset Overview", show_lines=True)
table.add_column("Description", justify="center", style="cyan")
table.add_column("Value", justify="center", style="yellow")
table.add_row("Total Rows", str(len(df)))
for _, row in label_counts.iterrows():
    # Any label other than 0 is displayed as the toxic class.
    label_name = "Non-Toxic (0)" if row["label"] == 0 else "Toxic (1)"
    table.add_row(f"Label: {label_name}", str(row["count"]))
console.print(table, justify="left")


## Load model

In [7]:
# --- Load Model and Create Pipeline ---
console.print(Panel("[bold cyan]Step 2: Loading Model and Creating Inference Pipeline[/bold cyan]"))
model_name = "google/gemma-2-2b-it"

try:
    # Using a pipeline is much more efficient for inference as it handles batching automatically.
    # NOTE(review): trust_remote_code=True allows code shipped with the model repository
    # to run locally; confirm it is actually required for this model before keeping it.
    pipe = pipeline(
        "text-generation",
        model=model_name,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device=0, # Use device=0 to force it to the first GPU
        trust_remote_code=True
    )
    console.print(f"[green]Successfully created pipeline for model: {model_name}[/green]")
except Exception as e:
    console.print(f"[bold red]Error creating the pipeline: {e}[/bold red]")
    console.print("[yellow]Please ensure you have accepted the model's license on Hugging Face and are logged in via `huggingface-cli login`.[/yellow]")
    # Re-raise instead of calling exit(): inside a notebook exit() takes the kernel
    # down with it, whereas an exception stops "Run All" here and keeps the traceback.
    raise


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Device set to use cuda:0


## Define prediction function

In [10]:
# --- Run Prediction (Optimized with Pipeline) ---
console.print(Panel("[bold cyan]Step 3: Running Predictions in Batches[/bold cyan]"))

def generate_prompt_text(text):
    """Build the chat-formatted prompt string for a single message.

    The user turn is the system prompt, the task prompt, the message wrapped in
    « » and the closing question, concatenated without any extra separator.
    The closing question is also the marker that the answer-extraction step
    splits on, so the two must stay in sync.

    Args:
        text: The message to classify.

    Returns:
        The prompt as a string (not token ids), rendered with the pipeline
        tokenizer's chat template and ending with the generation prompt.
    """
    user_message = "".join(
        [system_prompt, prompt, f"« {text} »", "\n Is this message toxic ?\n"]
    )
    conversation = [{"role": "user", "content": user_message}]
    # The tokenizer's chat template adds the special tokens (<bos>, etc.)
    return pipe.tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )

## Run prediction

In [11]:
# Build one prompt per row up front so the pipeline can batch over them.
# NOTE(review): rows were filtered on "content" when the dataset was loaded, but the
# prompts are built from "original_text" -- confirm this is the intended column.
prompts = list(map(generate_prompt_text, df["original_text"]))

In [12]:
# Run the pipeline over all prompts in batches and collect one output per prompt.
# batch_size controls how many prompts are processed together (and thus memory usage).
#
# The prompts are handed over as a generator rather than as the list itself: given a
# list, the pipeline runs everything and returns a finished list, so the progress bar
# would only start moving once all the work is already done. Given a generator it
# yields results as they are produced, so the bar reflects real progress.
raw_outputs = []
# NOTE: this rebinds the name `tqdm` (imported in the libraries cell) to the
# rich-based variant for the rest of the notebook.
from tqdm.rich import tqdm
prediction_stream = pipe((p for p in prompts), max_new_tokens=10, batch_size=16)
for output in tqdm(prediction_stream, total=len(prompts), desc="Predicting Toxicity"):
    raw_outputs.append(output)


Output()

In [14]:
# Pull the model's answer out of the full generated text
def extract_classification(output):
    """Return whatever follows the last "Is this message toxic ?" in the output.

    Args:
        output: One pipeline result, i.e. a list whose first element is a dict
            with a 'generated_text' entry, e.g. [{'generated_text': '...'}].

    Returns:
        The text after the last occurrence of the question, stripped of
        surrounding whitespace. If the question does not occur, the whole
        text is returned stripped.
    """
    generated = output[0]['generated_text']
    # rpartition keeps everything after the LAST occurrence of the marker and
    # hands back the entire string when the marker is missing.
    _, _, answer = generated.rpartition("Is this message toxic ?")
    return answer.strip()

# One raw answer per pipeline output, stored alongside the source rows.
df["toxicity_score_raw"] = list(map(extract_classification, raw_outputs))


In [15]:
# Process the raw predictions to get a clean label
def clean_prediction(raw_score):
    """Map a raw model answer to "toxique", "non-toxique" or "unknown".

    Matching is case-insensitive and substring based. Negated forms are checked
    first because they contain the positive keyword. Besides "non-toxic",
    "not toxic" and "non toxic", the negated pattern also accepts spellings
    without a separator or with an underscore ("nontoxic", "non_toxic"), which
    would otherwise fall through to the positive branch, and the French
    spellings "toxique" / "non-toxique".

    Args:
        raw_score: The raw answer; any value is accepted and converted with str().

    Returns:
        "non-toxique" for a negated answer, "toxique" for a positive one, and
        "unknown" when neither keyword is found.
    """
    import re

    answer = str(raw_score).lower()
    # "non"/"not", optionally followed by spaces, hyphens or underscores, then the keyword.
    if re.search(r"(?:non|not)[\s_-]*toxi(?:c|que)", answer):
        return "non-toxique"
    if re.search(r"toxi(?:c|que)", answer):
        return "toxique"
    return "unknown" # Handle cases where the model output is not clear

# Normalise every raw answer into one of the three label strings.
df["toxicity_score"] = df["toxicity_score_raw"].map(clean_prediction)

In [16]:
# Show a few examples (fixed seed, so the same rows every run) with their predicted
# and true labels. The sample size is capped so small datasets do not raise.
for i, row in df.sample(min(5, len(df)), random_state=42).iterrows():
    # The panel body is assembled from Text objects instead of being formatted into one
    # markup string: that keeps the bold style on the message and ensures square
    # brackets inside the dataset text are shown literally rather than parsed as markup.
    content = Text(str(row['content']), style="bold")
    toxicity = Text.from_markup(f"[yellow]Toxicity Score:[/yellow] [bold]{row['toxicity_score']}[/bold]")
    label = Text.from_markup(f"[cyan]Label:[/cyan] [bold]{row['label']}[/bold]")
    panel = Panel.fit(
        Text.assemble(content, "\n\n", toxicity, "\n", label),
        title=f"Exemple {i+1}",
        border_style="magenta"
    )
    console.print(panel)

## Metrics & Report        

| Metric                     | Formula                                           | Interpretation                                                                                                       |
| -------------------------- | ------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------- |
| **Precision**              | `TP / (TP + FP)`                                  | Of the samples predicted **toxic**, how many were **actually toxic**? <br>→ High precision = **low false positives** |
| **Recall** *(Sensitivity)* | `TP / (TP + FN)`                                  | Of the **actual toxic** samples, how many did we **correctly identify**? <br>→ High recall = **low false negatives** |
| **F1-score**               | `2 * (Precision * Recall) / (Precision + Recall)` | Harmonic mean of precision and recall. <br>→ Best when **balance** is needed                                         |
| **Accuracy**               | `(TP + TN) / (TP + TN + FP + FN)`                 | Fraction of all correct predictions (toxic and non-toxic). <br>→ Can be misleading on imbalanced data                |
| **ROC AUC**                | Area under the ROC Curve                          | Measures the **ranking ability** of the classifier. <br>→ Higher = better separation of toxic vs. non-toxic          |


In [17]:
console.print(Panel("[bold cyan]Step 4: Calculating Metrics & Report[/bold cyan]"))

# Filter out unknown predictions before calculating metrics
eval_df = df[df['toxicity_score'] != 'unknown'].copy()

# Report how many rows were left out: the metrics below only describe the rows the
# model answered clearly, so a large number here means they overstate performance.
n_unknown = len(df) - len(eval_df)
if n_unknown:
    console.print(
        f"[yellow]{n_unknown} of {len(df)} rows had an 'unknown' prediction "
        "and are excluded from the metrics below.[/yellow]"
    )

# Ground truth and binary predictions (1 = toxic, 0 = non-toxic), sharing eval_df's index.
y_true = eval_df["label"]
y_pred = (eval_df["toxicity_score"].apply(lambda x: x.lower() == "toxique")).astype(int)

# --- Classification Report ---
# Per-class precision / recall / F1 / support, followed by an overall accuracy row.
report = classification_report(y_true, y_pred, digits=3, output_dict=True)

report_table = Table(title="Classification Report (google/gemma-2-2b-it)", show_lines=True)
report_table.add_column("Class", style="cyan", justify="center")
for heading in ("Precision", "Recall", "F1-score", "Support"):
    report_table.add_column(heading, justify="center")

# The report dict mixes per-class entries with aggregate ones; only the former get a row here.
aggregate_keys = ("accuracy", "macro avg", "weighted avg")
for label, metrics in report.items():
    if label not in aggregate_keys:
        # Report keys are strings; anything other than "0" is shown as the toxic class.
        class_name = "Toxic (1)" if label != "0" else "Non-Toxic (0)"
        cells = [f"{metrics[key]:.3f}" for key in ("precision", "recall", "f1-score")]
        report_table.add_row(class_name, *cells, f"{int(metrics['support'])}")

# Accuracy is a single number, so it is placed under the F1-score column with the
# total support next to it.
report_table.add_section()
accuracy_cell = f"[bold]{report['accuracy']:.3f}[/bold]"
total_support = f"{int(report['macro avg']['support'])}"
report_table.add_row("[bold yellow]Accuracy[/bold yellow]", "-", "-", accuracy_cell, total_support)
console.print(report_table)

# --- Confusion Matrix ---
# Rows are the actual class, columns the predicted class. Passing labels=[0, 1] pins
# the matrix to 2x2 even when one class is missing from both y_true and y_pred;
# without it the matrix shrinks and the fixed indices below would raise an IndexError.
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
cm_table = Table(title="Confusion Matrix", show_lines=True)
cm_table.add_column(" ", style="bold")
cm_table.add_column("Predicted Non-Toxic", justify="center")
cm_table.add_column("Predicted Toxic", justify="center")
# Correct predictions (diagonal) are shown in green, errors in red.
cm_table.add_row("[cyan]Actual Non-Toxic[/cyan]", f"[green]{cm[0, 0]}[/green]", f"[red]{cm[0, 1]}[/red]")
cm_table.add_row("[cyan]Actual Toxic[/cyan]", f"[red]{cm[1, 0]}[/red]", f"[green]{cm[1, 1]}[/green]")
console.print(cm_table)

# --- ROC AUC Score ---
# NOTE(review): y_pred holds hard 0/1 labels rather than scores or probabilities, so
# this value is derived from a single decision threshold; treat it accordingly.
# Any failure while computing the score is shown in the panel instead of stopping the cell.
try:
    auc = roc_auc_score(y_true, y_pred)
    console.print(Panel(f"[bold green]ROC AUC Score: {auc:.3f}[/bold green]", title="ROC AUC"))
except Exception as e:
    console.print(Panel(f"[bold red]Could not calculate ROC AUC: {e}[/bold red]", title="ROC AUC"))

# --- Save the predictions ---
console.print(Panel("[bold cyan]Step 5: Saving Predictions[/bold cyan]"))
# y_pred only covers rows with a clear prediction. It is aligned to df's index and
# stored with the nullable Int64 dtype so that excluded rows become empty cells while
# the others stay 0/1; a plain assignment would turn the whole column into floats
# (0.0 / 1.0 / NaN) as soon as a single row is missing.
df['prediction'] = y_pred.reindex(df.index).astype("Int64")
try:
    df.to_csv(output_path, index=False, encoding="utf-8")
    console.print(f"[green]Predictions successfully saved to:[/] [bold]{output_path}[/bold]")
except Exception as e:
    # Saving is best-effort: a failure is reported but does not discard the results in memory.
    console.print(f"[red]Failed to save predictions: {e}[/red]")