# Keyword Spam Moderation – Colab Notebook

This notebook implements the workflow described in `docs/plan_collab.md`. Run the
cells top-to-bottom on Google Colab (GPU runtime; A100 preferred, T4 works with
smaller batches).

> **Before you start**
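> 1. Make sure `data/train_set.tsv` and `data/test_set.tsv` are available under the
>    repository's `data/` directory (the notebook reads them from `DATA_DIR`; the
>    test file is optional). Upload them there (e.g. from Drive) if the checkout
>    does not include them.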
> 2. Provide a Hugging Face token if the chosen model is gated.
> 3. Optionally enable the cache export flags later if you want to persist
>    downloaded images to Drive or GCS.

## Executive Introduction

This notebook delivers an end-to-end multimodal moderation workflow for detecting keyword/brand spamming in marketplace listings. We predict a strict JSON object with three fields: `is_spam` (boolean), `confidence` (0-1), and a concise `reason`.

Why this approach:
- Problem: keyword/brand stuffing erodes trust and hurts user experience.
- Target: a calibrated binary signal (is_spam) with confidence that maps to policy actions (keep/review/demote).
- Model: Qwen/Qwen3-VL-2B-Instruct via Unsloth QLoRA fits Colab GPUs while retaining strong multimodal reasoning.
- Simplicity: Transformers inference keeps runs reproducible without servers.

We compare against a leakage-free TF-IDF baseline, fine-tune the VLM, evaluate and sweep thresholds, and package artifacts.

## 1. Environment Setup

In [None]:
#@title Install dependencies (pinned to avoid Colab drift)
%%capture
!pip install -U "transformers==4.44.2" "accelerate==0.34.2" "peft==0.12.0" "datasets==2.20.0" \
               unsloth bitsandbytes pillow pandas scikit-learn pyarrow tqdm trl \
               google-cloud-storage ipywidgets seaborn requests
# NOTE(review): only transformers/accelerate/peft/datasets are pinned; unsloth and trl
# install their latest releases, which may require newer versions of those pins.
# Also confirm the pinned transformers release supports CONFIG['model_id'] (Qwen3-VL).

In [None]:
#@title Clone repository and set paths
from pathlib import Path
import os
import subprocess
import sys

# Source repository and the checkout location used on Colab.
REPO_URL = 'https://github.com/rostandk/ml-assessment.git'
DEFAULT_REPO_DIR = Path('/content/ml-assessment')
# Outside Colab the current working directory is used as the repository root.
LOCAL_REPO_DIR = Path.cwd()
# The `google.colab` module is only importable inside a Colab runtime.
IS_COLAB = 'google.colab' in sys.modules

def ensure_repo(url: str, target: Path) -> Path:
    """Clone *url* into *target*, or fast-forward an existing checkout there.

    Args:
        url: Git remote to clone.
        target: Directory that should hold the checkout.

    Returns:
        The resolved checkout path.

    Raises:
        RuntimeError: if *target* exists, is not a git checkout and is not empty
            (git would refuse to clone into it).
        subprocess.CalledProcessError: if the git command itself fails.
    """
    target = Path(target)
    if (target / '.git').is_dir():
        # Check for .git explicitly: `git -C <dir> pull` in a non-repo directory
        # would otherwise search upwards and could pull an enclosing repository.
        subprocess.run(['git', '-C', str(target), 'pull', '--ff-only'], check=True)
    elif target.exists() and any(target.iterdir()):
        raise RuntimeError(
            f"{target} exists but is not a git checkout; remove it or choose another path"
        )
    else:
        # Missing or empty directory (e.g. an interrupted clone): clone into it.
        target.parent.mkdir(parents=True, exist_ok=True)
        subprocess.run(['git', 'clone', '--depth', '1', url, str(target)], check=True)
    return target.resolve()

if IS_COLAB:
    REPO_DIR = ensure_repo(REPO_URL, DEFAULT_REPO_DIR)
    # Work from inside the checkout so relative paths resolve against the repo.
    os.chdir(REPO_DIR)
else:
    REPO_DIR = LOCAL_REPO_DIR

# Put the repository first on the import path so repo-local modules resolve.
if str(REPO_DIR) not in sys.path:
    sys.path.insert(0, str(REPO_DIR))

DATA_DIR = REPO_DIR / 'data'


In [None]:
#@title (Optional) Mount Google Drive
# Mounting is optional: outside Colab the import fails and the cell is a no-op
# instead of aborting a top-to-bottom run.
try:
    from google.colab import drive
except ImportError:
    print('google.colab is unavailable (not running on Colab); skipping Drive mount.')
else:
    drive.mount('/content/drive')

In [None]:
#@title Imports & configuration
import json
import os
import random
from pathlib import Path
from typing import Any, Iterable

try:
    from google.colab import output
except ImportError:
    pass  # Not on Colab: nothing to clear.
else:
    # Clear this cell's output using Colab's output helper.
    output.clear(wait=True)

import numpy as np
import pandas as pd
from IPython.display import HTML, display
import seaborn as sns
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report, precision_recall_fscore_support
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from datasets import load_dataset
from unsloth import FastVisionModel, UnslothVisionDataCollator
from transformers import EarlyStoppingCallback, TrainingArguments
from trl import SFTTrainer

import torch

import sys
import pathlib

REPO_DIR = globals().get('REPO_DIR', pathlib.Path().resolve())
if str(REPO_DIR) not in sys.path:
    sys.path.insert(0, str(REPO_DIR))

import utils

# Seed Python, NumPy, torch (CPU) and all CUDA devices for reproducible runs.
RNG_SEED = 42
random.seed(RNG_SEED)
np.random.seed(RNG_SEED)
torch.manual_seed(RNG_SEED)
torch.cuda.manual_seed_all(RNG_SEED)

# Colab keeps scratch data under /content; elsewhere cache/artifacts are relative
# to the working directory and the "Drive" cache lives under the home directory.
if 'google.colab' in sys.modules:
    DEFAULT_CACHE_DIR = Path('/content/cache/images')
    DEFAULT_ARTIFACTS_DIR = Path('/content/artifacts')
    DEFAULT_DRIVE_CACHE_PATH = Path('/content/drive/MyDrive/keyword_spam_vlm/cache/images')
else:
    DEFAULT_CACHE_DIR = Path('cache/images')
    DEFAULT_ARTIFACTS_DIR = Path('artifacts')
    DEFAULT_DRIVE_CACHE_PATH = Path.home() / 'keyword_spam_vlm' / 'cache' / 'images'

# Reuse DATA_DIR from the clone cell when it ran; otherwise assume <repo>/data.
DATA_DIR = globals().get('DATA_DIR', REPO_DIR / 'data')
TRAIN_TSV_PATH = (DATA_DIR / 'train_set.tsv').resolve()
TEST_TSV_PATH = (DATA_DIR / 'test_set.tsv').resolve()

CONFIG = {
    'model_id': 'Qwen/Qwen3-VL-2B-Instruct',
    # Probe bf16 only when a CUDA device exists: on CPU-only runtimes some torch
    # releases may raise from is_bf16_supported() instead of returning False.
    'dtype': 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16',
    'epochs': 3,
    'learning_rate': 1e-4,
    'max_seq_len': 1024,
    'warmup_ratio': 0.05,
    'batch_size_t4': 4,
    'grad_accum_t4': 4,
    'batch_size_a100': 8,
    'grad_accum_a100': 2,
    # Default policy thresholds over model confidence (re-tuned by the sweep).
    'review_threshold': 0.5,
    'demote_threshold': 0.7,
    'cache_dir': str(DEFAULT_CACHE_DIR),
    'train_tsv': str(TRAIN_TSV_PATH),
    'test_tsv': str(TEST_TSV_PATH),
    'artifacts_dir': str(DEFAULT_ARTIFACTS_DIR),
    'export_cache_to_drive': False,
    'drive_cache_path': str(DEFAULT_DRIVE_CACHE_PATH),
    'export_cache_to_gcs': False,
    'gcs_bucket': 'ml-assesment',
    'gcs_prefix': 'images',
}

Path(CONFIG['cache_dir']).mkdir(parents=True, exist_ok=True)
Path(CONFIG['artifacts_dir']).mkdir(parents=True, exist_ok=True)

GPU_NAME = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'
print(f'Detected accelerator: {GPU_NAME}')


In [None]:
#@title Run plan overview
from IPython.display import HTML

# Pick the accelerator-specific settings once, then render them as a table.
if "A100" in GPU_NAME:
    per_device_batch = CONFIG["batch_size_a100"]
    grad_accum = CONFIG["grad_accum_a100"]
    mode_hint = "Full run recommended"
else:
    per_device_batch = CONFIG["batch_size_t4"]
    grad_accum = CONFIG["grad_accum_t4"]
    mode_hint = "Start with a smaller subset (`max_rows`)"

rows = [
    ("Accelerator", GPU_NAME),
    ("Per-device batch size", per_device_batch),
    ("Gradient accumulation", grad_accum),
    ("Learning rate", CONFIG["learning_rate"]),
    ("Epochs", CONFIG["epochs"]),
    ("Max sequence length", CONFIG["max_seq_len"]),
    ("Suggested mode", mode_hint),
]

html = "".join(
    ["<table><tbody>"]
    + [
        f"<tr><th style='text-align:left;padding-right:12px;'>{label}</th><td>{value}</td></tr>"
        for label, value in rows
    ]
    + ["</tbody></table>"]
)
display(HTML(html))

## 2. Data preparation and Old Way review

In [None]:
#@title Load TSVs, validate schema, compute label confidence
REQUIRED_COLUMNS = ["product_id", "description", "image_url", "label", "yes_count", "no_count"]


def load_dataset_tsv(path: str | Path) -> pd.DataFrame:
    df = pd.read_csv(path, sep="	")
    missing = [col for col in REQUIRED_COLUMNS if col not in df.columns]
    if missing:
        raise ValueError(f"Missing columns: {missing}")
    df = df[REQUIRED_COLUMNS].copy()
    df["product_id"] = df["product_id"].astype(str)
    if df["product_id"].duplicated().any():
        raise ValueError("Duplicate product_id values detected")
    df["label"] = df["label"].astype(int)
    for col in ("yes_count", "no_count"):
        df[col] = df[col].fillna(0).astype(int)
        if (df[col] < 0).any():
            raise ValueError(f"Negative values found in {col}")
    total_votes = df["yes_count"] + df["no_count"]
    with np.errstate(divide="ignore", invalid="ignore"):
        confidence = (df["yes_count"] - df["no_count"]) / np.where(total_votes == 0, np.nan, total_votes)
    df["label_confidence"] = confidence.fillna(0.0)
    return df


train_df = load_dataset_tsv(CONFIG["train_tsv"])
print(f"Loaded {len(train_df)} training rows")

# Stratified 90/10 split keeps the class balance identical in both parts.
train_split, val_split = train_test_split(
    train_df,
    stratify=train_df["label"],
    test_size=0.1,
    random_state=RNG_SEED,
)
print(f"Train rows: {len(train_split)}, Validation rows: {len(val_split)}")

# The test file is optional; without it, use an empty frame with the same columns.
try:
    test_df = load_dataset_tsv(CONFIG["test_tsv"])
except FileNotFoundError:
    test_df = pd.DataFrame(columns=REQUIRED_COLUMNS)
    print("Test TSV not found; skipping test evaluation")
else:
    print(f"Loaded {len(test_df)} test rows")

train_df.head(3)

### Junior notebook review (Old Way)

Common weaknesses observed in the legacy approach:
- Single-modality bias: images are ignored, missing brand/category cues.
- Fragile heuristics: hashtag spam and CTAs (e.g., 'DM for deals') fool simple rules.
- Limited reproducibility: loose splits and few guardrails.
- Low operational tie-in: no calibrated confidence for policy thresholds.

Illustrative examples below (sampled from the training data):

In [None]:
# Illustrative failure taxonomy from the legacy 'Old Way'
import re

def _hashtag_ratio(text: str) -> float:
    tokens = str(text).split()
    return 0.0 if not tokens else sum(1 for t in tokens if t.startswith('#')) / len(tokens)

CTA_RE = re.compile(r"\b(dm|whatsapp|contact|inbox|email|text me)\b", re.I)
BRANDS = {'nike', 'adidas', 'gucci'}

def _cta_present(text: str) -> bool:
    return bool(CTA_RE.search(str(text)))

def _brand_mentioned(text: str) -> bool:
    low = str(text).lower()
    return any(b in low for b in BRANDS)

def _shorten(text: str, n: int = 120) -> str:
    s = str(text).strip().replace('', ' ')
    return s if len(s) <= n else s[: n - 1] + '…'

# Collect up to 9 training rows that a legacy rule fires on. Rules are checked
# in priority order and the first match names the reason.
legacy_rows = []
for row in train_df.itertuples(index=False):
    if _hashtag_ratio(row.description) > 0.2:
        reason = 'hashtag-heavy description'
    elif _cta_present(row.description):
        reason = 'call-to-action present'
    elif _brand_mentioned(row.description):
        reason = 'brand keyword mentioned'
    else:
        reason = None
    if reason is not None:
        legacy_rows.append(dict(
            product_id=row.product_id,
            label=int(row.label),
            matched_reason=reason,
            description_snippet=_shorten(row.description),
        ))
    if len(legacy_rows) >= 9:  # keep it compact
        break
pd.DataFrame(legacy_rows)


### Download all images up front (parallel, cached)

In [None]:
#@title Parallel image download
# Distinct image URLs across train, validation and (when loaded) test rows.
all_urls = (
    pd.concat(
        [
            train_split["image_url"],
            val_split["image_url"],
            test_df.get("image_url", pd.Series([], dtype=str)),
        ]
    )
    .dropna()
    .unique()
)
print(f"Total unique URLs: {len(all_urls)}")

results = utils.download_images(
    all_urls,
    CONFIG["cache_dir"],
    max_workers=12,
    timeout=20,
    max_retries=4,
)

# Persist the per-URL status table next to the other artifacts.
status_df = pd.DataFrame(results)
status_path = Path(CONFIG["artifacts_dir"]) / "image_download_status.csv"
status_df.to_csv(status_path, index=False)
print(status_df["downloaded"].value_counts())
print(f"Saved download status to {status_path}")

### Optional: export cache to Drive or GCS

In [None]:
#@title Export cache (optional)
from shutil import copytree, ignore_patterns

if CONFIG["export_cache_to_drive"]:
    drive_path = Path(CONFIG["drive_cache_path"])
    if not drive_path.exists():
        drive_path.parent.mkdir(parents=True, exist_ok=True)
        copytree(CONFIG["cache_dir"], drive_path, ignore=ignore_patterns("*.zip"))
        print(f"Exported cache to {drive_path}")
    else:
        # Never overwrite an existing export.
        print(f"Drive path {drive_path} already exists – remove it first if you want a fresh copy")

if CONFIG["export_cache_to_gcs"]:
    # NOTE(review): public=True presumably makes the uploaded images publicly
    # readable - confirm that is intended for this bucket.
    utils.sync_images_to_gcs(
        CONFIG["cache_dir"],
        bucket=CONFIG["gcs_bucket"],
        prefix=CONFIG["gcs_prefix"],
        public=True,
    )

## 3. What We Predict

We model `is_spam` (binary) with an associated `confidence` in [0, 1]. From the annotator votes we also compute `label_confidence = (yes - no) / (yes + no)`, a signed agreement score in [-1, 1] (0 when a row has no votes), as a weak indicator of label certainty. Operationally, we apply two thresholds to the model's confidence to map each prediction to an action: keep, review, or demote. We later sweep thresholds on the validation set to choose a sensible operating point.

In [None]:
# Class balance and label confidence distribution
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
# Reindex to [0, 1] so both classes always get a bar (0 when absent) and the
# fixed tick labels below always line up with the right class.
class_counts = train_df['label'].value_counts().reindex([0, 1], fill_value=0)
class_counts.plot(kind='bar', ax=axes[0], title='Class counts (train)')
axes[0].set_xticklabels(['non-spam', 'spam'], rotation=0)
sns.histplot(train_df['label_confidence'], bins=20, ax=axes[1])
axes[1].set_title('Label confidence (train)')
plt.tight_layout()
plt.show()


## 4. Baseline – TF-IDF + logistic regression

In [None]:
#@title Train leakage-free baseline

def run_baseline(train_df: pd.DataFrame, val_df: pd.DataFrame) -> dict[str, Any]:
    """Fit TF-IDF + logistic regression on *train_df* and score it on *val_df*.

    The vectorizer is fit on training descriptions only, so no validation
    vocabulary or IDF statistics leak into the model.

    Returns:
        ``accuracy`` and macro ``precision``/``recall``/``f1`` as floats plus the
        text ``classification_report``. The same dict is written to
        ``<artifacts_dir>/baseline_metrics.json``.
    """
    # NaN descriptions (empty cells) are rejected by TfidfVectorizer.
    train_text = train_df["description"].fillna("").astype(str)
    val_text = val_df["description"].fillna("").astype(str)

    vectorizer = TfidfVectorizer(max_features=10000, stop_words="english")
    X_train = vectorizer.fit_transform(train_text)
    X_val = vectorizer.transform(val_text)

    # class_weight="balanced" compensates for the spam/non-spam imbalance.
    clf = LogisticRegression(max_iter=500, class_weight="balanced")
    clf.fit(X_train, train_df["label"])
    preds = clf.predict(X_val)

    precision, recall, f1, _ = precision_recall_fscore_support(
        val_df["label"], preds, average="macro", zero_division=0
    )
    accuracy = accuracy_score(val_df["label"], preds)

    metrics = {
        "accuracy": float(accuracy),
        "macro_precision": float(precision),
        "macro_recall": float(recall),
        "macro_f1": float(f1),
        "classification_report": classification_report(
            val_df["label"], preds, digits=3, zero_division=0
        ),
    }
    metrics_path = Path(CONFIG["artifacts_dir"]) / "baseline_metrics.json"
    metrics_path.write_text(json.dumps(metrics, indent=2))
    return metrics


baseline_metrics = run_baseline(train_split, val_split)
print(json.dumps(baseline_metrics, indent=2))

## 4b. Build SFT datasets for Unsloth

In [None]:
#@title Construct Unsloth-ready JSONL files
PROMPT_TEMPLATE = "You are a moderator. Respond with JSON containing is_spam (bool), confidence (0-1), reason (short)."


def _annotator_agreement(record: dict[str, Any]) -> float:
    """Return the share of votes on the majority side (0.5-1.0), or 1.0 with no votes.

    Missing/NaN vote counts are treated as 0. The value is rounded to 2 decimals.
    """
    counts = []
    for key in ("yes_count", "no_count"):
        value = record.get(key, 0)
        counts.append(0.0 if pd.isna(value) else float(value))
    yes, no = counts
    total = yes + no
    if total <= 0:
        return 1.0
    return round(max(yes, no) / total, 2)


def build_messages(df: pd.DataFrame) -> list[dict[str, Any]]:
    """Convert labelled rows into chat-format SFT examples.

    Each row becomes ``{"id": product_id, "messages": [system, user, assistant]}``.
    The user turn holds the cached image (when a URL is present and the file
    can be fetched) followed by the prompt and description; the assistant turn
    is the JSON target ``{"is_spam", "confidence", "reason"}``.
    """
    rows: list[dict[str, Any]] = []
    for record in df.to_dict("records"):
        url = record.get("image_url")
        # Missing URLs arrive from pandas as NaN, which is truthy, so require a
        # non-empty string before calling the downloader.
        has_url = isinstance(url, str) and bool(url.strip())
        image_path = utils.download_image(url, CONFIG["cache_dir"]) if has_url else None
        image_ok = image_path is not None and image_path.exists()

        user_content: list[dict[str, str]] = []
        if image_ok:
            user_content.append({"type": "image", "image": str(image_path)})
        message_text = f"{PROMPT_TEMPLATE}\n\nDescription: {record['description']}"
        if not image_ok:
            message_text += "\n\nNote: image unavailable. Base your judgment on the text only."
        user_content.append({"type": "text", "text": message_text})

        # Target confidence = annotator agreement (1.0 when unvoted). A constant
        # 1.0 target teaches the model to always emit 1.0, which would make every
        # review/demote threshold in the sweep equivalent.
        # NOTE(review): the reason target is still a fixed placeholder, so
        # generated reasons carry no information.
        assistant_text = json.dumps(
            {
                "is_spam": bool(record["label"]),
                "confidence": float(_annotator_agreement(record)),
                "reason": "Training label",
            },
            ensure_ascii=False,
        )
        rows.append(
            {
                "id": record["product_id"],
                "messages": [
                    {"role": "system", "content": [{"type": "text", "text": "Respond with strict JSON."}]},
                    {"role": "user", "content": user_content},
                    {"role": "assistant", "content": [{"type": "text", "text": assistant_text}]},
                ],
            }
        )
    return rows


sft_dir = Path(CONFIG["artifacts_dir"]) / "sft"
sft_dir.mkdir(parents=True, exist_ok=True)

train_rows = build_messages(train_split)
val_rows = build_messages(val_split)

train_jsonl = sft_dir / "train.jsonl"
val_jsonl = sft_dir / "val.jsonl"
# JSON Lines needs exactly one record per line: terminate each record with "\n".
# Joining with "" glued every record onto one line that no JSONL reader accepts.
train_jsonl.write_text(
    "".join(json.dumps(r, ensure_ascii=False) + "\n" for r in train_rows), encoding="utf-8"
)
val_jsonl.write_text(
    "".join(json.dumps(r, ensure_ascii=False) + "\n" for r in val_rows), encoding="utf-8"
)

# Validation ground truth in val_split row order.
VAL_TRUE_LABELS = val_split["label"].astype(int).reset_index(drop=True).tolist()
print(f"SFT rows -> train: {len(train_rows)}, val: {len(val_rows)}")

## 5. QLoRA fine-tuning

In [None]:
#@title Fine-tune with Unsloth QLoRA
from trl import SFTConfig

NUM_GPUS = torch.cuda.device_count()
batch_size = CONFIG['batch_size_a100' if 'A100' in GPU_NAME else 'batch_size_t4']
grad_accum = CONFIG['grad_accum_a100' if 'A100' in GPU_NAME else 'grad_accum_t4']

# FastVisionModel.from_pretrained returns two values, (model, processor). The
# processor is passed wherever a "tokenizer" is expected below. dtype is given
# as a torch.dtype rather than the config's string name.
model, tokenizer = FastVisionModel.from_pretrained(
    CONFIG['model_id'],
    max_seq_length=CONFIG['max_seq_len'],
    dtype=getattr(torch, CONFIG['dtype']),
    load_in_4bit=True,
)
model = FastVisionModel.get_peft_model(
    model,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
)
FastVisionModel.for_training(model)

training_args = SFTConfig(
    output_dir=str(Path(CONFIG['artifacts_dir']) / 'trainer'),
    num_train_epochs=CONFIG['epochs'],
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=grad_accum,
    learning_rate=CONFIG['learning_rate'],
    warmup_ratio=CONFIG['warmup_ratio'],
    logging_steps=50,
    eval_strategy='epoch',
    save_strategy='epoch',
    bf16=(CONFIG['dtype'] == 'bfloat16'),
    fp16=(CONFIG['dtype'] == 'float16'),
    load_best_model_at_end=True,
    # SFTTrainer evaluates teacher-forced loss and does not generate text, so
    # checkpoints are chosen on eval_loss. Generation-based macro-F1 is computed
    # by the inference/evaluation cells instead.
    metric_for_best_model='eval_loss',
    greater_is_better=False,
    seed=RNG_SEED,
    report_to='none',
    # Vision SFT with a custom collator: keep the raw "messages" column and
    # skip TRL's text-only dataset preparation.
    remove_unused_columns=False,
    dataset_text_field='',
    dataset_kwargs={'skip_prepare_dataset': True},
)

vision_collator = UnslothVisionDataCollator(model, tokenizer)

# NOTE(review): the in-memory rows from the SFT cell are passed directly.
# Loading them via load_dataset('json') merges every content item into one
# Arrow struct with null "image"/"text" fields, which a vision collator may
# misread as image entries. Confirm before switching back.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_rows,
    eval_dataset=val_rows,
    data_collator=vision_collator,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

train_result = trainer.train()
print(train_result)

merged_dir = Path(CONFIG['artifacts_dir']) / 'merged_model'
merged_dir.mkdir(parents=True, exist_ok=True)
# Merge the LoRA adapters into 16-bit weights and save the processor next to
# them, so the inference cell can load the directory with plain Transformers.
model.save_pretrained_merged(str(merged_dir), tokenizer, save_method='merged_16bit')


## 6. Transformers inference

In [None]:
#@title Deterministic inference on validation (and optional test)
from transformers import AutoModelForVision2Seq, AutoProcessor

merged_path = Path(CONFIG["artifacts_dir"]) / "merged_model"
processor = AutoProcessor.from_pretrained(merged_path)
# torch_dtype="auto" keeps the checkpoint's stored dtype instead of the float32
# default. eval() disables dropout so greedy decoding is deterministic.
merged_model = AutoModelForVision2Seq.from_pretrained(merged_path, torch_dtype="auto")
merged_model = merged_model.to("cuda" if torch.cuda.is_available() else "cpu").eval()


def _parse_model_json(text: str) -> dict[str, Any]:
    """Parse the span from the first '{' to the last '}' of *text* as a JSON object.

    This tolerates JSON wrapped in prose or code fences. Unparseable text, or
    JSON that is not an object, yields the 'malformed' fallback record.
    """
    fallback = {"is_spam": False, "confidence": 0.0, "reason": "malformed"}
    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end < start:
        return fallback
    try:
        payload = json.loads(text[start : end + 1])
    except json.JSONDecodeError:
        return fallback
    return payload if isinstance(payload, dict) else fallback


def _coerce_is_spam(value: Any) -> bool:
    """Interpret a model-emitted is_spam value; the string "false" must map to False."""
    if isinstance(value, str):
        return value.strip().lower() in {"true", "yes", "1"}
    return bool(value)


def _coerce_confidence(value: Any) -> float:
    """Convert a model-emitted confidence to a float in [0, 1] (0.0 if unusable)."""
    try:
        conf = float(value)
    except (TypeError, ValueError):
        return 0.0
    return float(np.clip(np.nan_to_num(conf, nan=0.0), 0.0, 1.0))


def generate_predictions(df: pd.DataFrame) -> pd.DataFrame:
    """Run the merged model over *df* with greedy decoding and parse its JSON replies.

    Prompts mirror the SFT examples: the same system prompt, PROMPT_TEMPLATE
    user text and image-unavailable note. Each output row has ``product_id``,
    ``label``, ``image_available``, ``raw_response`` (the decoded completion
    only), ``is_spam_pred``, ``confidence_pred`` (clipped to [0, 1]) and
    ``reason_pred``.
    """
    outputs = []
    for record in tqdm(df.to_dict("records"), desc="Generating"):
        url = record.get("image_url")
        # NaN (missing URL) is truthy, so require a non-empty string.
        has_url = isinstance(url, str) and bool(url.strip())
        image_path = utils.download_image(url, CONFIG["cache_dir"]) if has_url else None
        image_ok = image_path is not None and image_path.exists()

        user_content = []
        if image_ok:
            user_content.append({"type": "image", "image": str(image_path)})
        message_text = f"{PROMPT_TEMPLATE}\n\nDescription: {record['description']}"
        if not image_ok:
            message_text += "\n\nNote: image unavailable. Base your judgment on the text only."
        user_content.append({"type": "text", "text": message_text})

        messages = [
            # Same system prompt as the training examples, so inference matches SFT.
            {"role": "system", "content": [{"type": "text", "text": "Respond with strict JSON."}]},
            {"role": "user", "content": user_content},
        ]

        # tokenize=True + return_dict=True makes the processor return model-ready
        # tensors instead of a prompt string.
        # NOTE(review): relies on a transformers release whose processor chat
        # template also loads the referenced image files; confirm for the pinned version.
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(merged_model.device)
        with torch.inference_mode():
            output = merged_model.generate(**inputs, max_new_tokens=256, do_sample=False)
        # generate() returns prompt + completion. Decode only the new tokens,
        # otherwise the echoed prompt makes every reply unparseable.
        new_tokens = output[:, inputs["input_ids"].shape[1]:]
        decoded = processor.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
        payload = _parse_model_json(decoded)
        outputs.append(
            {
                "product_id": record["product_id"],
                "label": record.get("label", 0),
                "image_available": image_ok,
                "raw_response": decoded,
                "is_spam_pred": _coerce_is_spam(payload.get("is_spam")),
                "confidence_pred": _coerce_confidence(payload.get("confidence", 0.0)),
                "reason_pred": str(payload.get("reason", "")),
            }
        )
    return pd.DataFrame(outputs)


# Score the validation split and keep the raw predictions as an artifact.
val_predictions = generate_predictions(val_split)
val_predictions.to_parquet(
    Path(CONFIG["artifacts_dir"]) / "validation_predictions.parquet",
    index=False,
)
print(val_predictions.head())

## 7. Evaluation & policy

In [None]:
#@title Threshold sweep, demotion policy, and metrics

def evaluate_predictions(df: pd.DataFrame, review_threshold: float, demote_threshold: float) -> dict[str, Any]:
    """Score predictions, assign keep/review/demote decisions and write artifacts.

    A row is "demote" when predicted spam with ``confidence_pred >= demote_threshold``,
    "review" when predicted spam with ``confidence_pred >= review_threshold``,
    and "keep" otherwise.

    Returns a metrics dict with:
        - raw-prediction metrics (``accuracy``, ``macro_*``), based on ``is_spam_pred``;
        - operating-point metrics (``policy_macro_*``), where review or demote
          counts as flagged;
        - decision counts, the thresholds used, and image present/missing counts.

    Writes ``metrics.json``, ``classification_report.json`` (valid JSON) and
    ``predictions_with_decisions.parquet`` under ``CONFIG["artifacts_dir"]``.
    """
    y_true = df["label"].astype(int).to_numpy()
    y_pred = df["is_spam_pred"].astype(int).to_numpy()

    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="macro", zero_division=0)
    accuracy = accuracy_score(y_true, y_pred)

    is_spam = df["is_spam_pred"].astype(bool).to_numpy()
    conf = df["confidence_pred"].astype(float).to_numpy()
    # np.select takes the first matching condition, so demote wins over review.
    decisions = np.select(
        [is_spam & (conf >= demote_threshold), is_spam & (conf >= review_threshold)],
        ["demote", "review"],
        default="keep",
    )
    df = df.assign(decision=decisions)

    # Metrics at the chosen operating point: anything not kept counts as flagged.
    flagged = (decisions != "keep").astype(int)
    p_precision, p_recall, p_f1, _ = precision_recall_fscore_support(
        y_true, flagged, average="macro", zero_division=0
    )

    image_mask = df["image_available"].astype(bool)
    metrics = {
        "accuracy": float(accuracy),
        "macro_f1": float(f1),
        "macro_precision": float(precision),
        "macro_recall": float(recall),
        "image_present_count": int(image_mask.sum()),
        "image_missing_count": int((~image_mask).sum()),
        "review_threshold": float(review_threshold),
        "demote_threshold": float(demote_threshold),
        "policy_macro_f1": float(p_f1),
        "policy_macro_precision": float(p_precision),
        "policy_macro_recall": float(p_recall),
        "decision_counts": {name: int((decisions == name).sum()) for name in ("keep", "review", "demote")},
    }
    artifacts_dir = Path(CONFIG["artifacts_dir"])
    (artifacts_dir / "metrics.json").write_text(json.dumps(metrics, indent=2))
    # output_dict=True so the .json file really contains JSON (not the text table).
    report = classification_report(y_true, y_pred, digits=3, output_dict=True, zero_division=0)
    (artifacts_dir / "classification_report.json").write_text(json.dumps(report, indent=2))
    df.to_parquet(artifacts_dir / "predictions_with_decisions.parquet", index=False)
    return metrics


def sweep_thresholds(
    df: pd.DataFrame,
    review_grid: Iterable[float],
    demote_grid: Iterable[float],
    *,
    demote_min_precision: float = 0.9,
) -> tuple[float, float, float]:
    """Choose review/demote thresholds from candidate grids on labelled predictions.

    Review threshold: the candidate whose flagging rule (``is_spam_pred`` and
    ``confidence_pred >= review``) gives the highest macro-F1 against ``label``.
    The first best candidate wins ties. Only reviews with some demote candidate
    strictly above them are considered.

    Demote threshold: the lowest candidate above the chosen review whose demoted
    rows reach ``demote_min_precision`` precision. If none does, the highest
    candidate (most conservative) is used.

    Returns:
        ``(review, demote, macro_f1_at_review)``; when no valid pair exists the
        CONFIG defaults are returned with a score of -1.0.
    """
    # Materialize once: a generator grid would be exhausted after one pass.
    demote_values = sorted(float(d) for d in demote_grid)
    y_true = df["label"].astype(int).to_numpy()
    is_spam = df["is_spam_pred"].astype(bool).to_numpy()
    conf = df["confidence_pred"].astype(float).to_numpy()

    best_score = -1.0
    best_review = CONFIG["review_threshold"]
    best_demote = CONFIG["demote_threshold"]
    for review in review_grid:
        if not any(d > review for d in demote_values):
            continue
        # Review and demote both flag a row, so only `review` affects this score.
        flagged = (is_spam & (conf >= review)).astype(int)
        score = precision_recall_fscore_support(y_true, flagged, average="macro", zero_division=0)[2]
        if score > best_score:
            best_score = float(score)
            best_review = float(review)

    if best_score < 0:
        return best_review, best_demote, best_score

    candidates = [d for d in demote_values if d > best_review]
    best_demote = candidates[-1]
    for demote in candidates:
        demoted = is_spam & (conf >= demote)
        if demoted.any() and y_true[demoted].mean() >= demote_min_precision:
            best_demote = demote
            break
    return best_review, best_demote, best_score


# Candidate grids in 0.1 steps: review over 0.1..0.9, demote over 0.2..1.0.
review_candidates = np.linspace(0.1, 0.9, num=9)
demote_candidates = np.linspace(0.2, 1.0, num=9)
best_review, best_demote, best_score = sweep_thresholds(
    val_predictions, review_candidates, demote_candidates
)
print(f"Best thresholds: review={best_review:.2f}, demote={best_demote:.2f}, macro_f1={best_score:.3f}")

# Apply the selected operating point and report/persist the metrics.
metrics = evaluate_predictions(val_predictions, best_review, best_demote)
print(json.dumps(metrics, indent=2))

## 8. Curated gallery

In [None]:
#@title TP/TN/FP/FN examples with images
import base64
import itertools

# Confusion-matrix buckets for the gallery, in display order. Each predicate
# receives a prediction row (with .label and .is_spam_pred) and is truthy when
# the row belongs to that bucket.
CATEGORY_MAP = dict(
    [
        ("True Positive", lambda r: r.label == 1 and r.is_spam_pred),
        ("True Negative", lambda r: r.label == 0 and not r.is_spam_pred),
        ("False Positive", lambda r: r.label == 0 and r.is_spam_pred),
        ("False Negative", lambda r: r.label == 1 and not r.is_spam_pred),
    ]
)

# Maximum number of examples rendered per bucket.
GALLERY_MAX = 4


def image_to_base64(path: Path | None) -> str:
    if path is None or not path.exists():
        return ""
    return base64.b64encode(path.read_bytes()).decode("utf-8")


def display_gallery(df: pd.DataFrame) -> None:
    """Render up to GALLERY_MAX prediction rows per CATEGORY_MAP bucket as an HTML table.

    Image URLs are looked up in ``train_df`` by ``product_id``. Rows with no
    match, or no usable URL, are shown without an image instead of raising.
    The table is rendered with ``escape=False`` so the <img> tags work;
    free-text fields (model reason, product id) are therefore HTML-escaped here.
    """
    import html as _html  # local alias: a global named `html` exists in this notebook

    # product_id uniqueness is enforced by load_dataset_tsv, so this is a 1:1 map.
    url_by_id = train_df.set_index("product_id")["image_url"]

    rows = []
    for category, predicate in CATEGORY_MAP.items():
        matches = (r for r in df.itertuples(index=False) if predicate(r))
        for row in itertools.islice(matches, GALLERY_MAX):
            url = url_by_id.get(row.product_id)
            has_url = isinstance(url, str) and bool(url.strip())
            img_path = utils.download_image(url, CONFIG["cache_dir"]) if has_url else None
            encoded = image_to_base64(img_path)
            img_tag = f"<img src='data:image/jpeg;base64,{encoded}' width='200'>" if encoded else "(no image)"
            rows.append(
                {
                    "Category": category,
                    "Product ID": _html.escape(str(row.product_id)),
                    "True Label": row.label,
                    "Predicted": row.is_spam_pred,
                    "Confidence": f"{row.confidence_pred:.2f}",
                    # Model output is untrusted text; never inject it as raw HTML.
                    "Reason": _html.escape(str(row.reason_pred)),
                    "Image": img_tag,
                }
            )
    if not rows:
        print("No rows to display")
        return
    display(HTML(pd.DataFrame(rows).to_html(escape=False)))


display_gallery(val_predictions)

## 9. Package artifacts

In [None]:
#@title Bundle outputs for download
import shutil

import tempfile

zip_path = Path(CONFIG["artifacts_dir"]) / "keyword_spam_artifacts.zip"
# Remove a stale bundle so it is not archived into the new one.
zip_path.unlink(missing_ok=True)
# Build the archive outside artifacts_dir and then move it in. make_archive walks
# root_dir while writing, so an archive created inside it risks including itself.
with tempfile.TemporaryDirectory() as tmp_dir:
    built = shutil.make_archive(str(Path(tmp_dir) / zip_path.stem), "zip", CONFIG["artifacts_dir"])
    shutil.move(built, zip_path)
print(f"Artifacts packaged at {zip_path}")