"""
End-to-End 广告审核治理系统 Demo
包含：
1. 规则层 + OCR（模拟）
2. 文本/数值特征建模 + SMOTE/class_weight
3. PR 曲线阈值优化
4. SHAP 模型解释

依赖：
pip install scikit-learn imbalanced-learn lightgbm shap sentence-transformers rapidfuzz pillow pytesseract matplotlib
"""

# In [None]:
import re, unicodedata, os
import numpy as np
import pandas as pd
from rapidfuzz import fuzz

## ========= 1. 模拟数据生成 ==========

# In [None]:
# Synthetic ad dataset. The global NumPy seed is fixed so every run draws the
# same rows; the order of the random calls below is therefore significant.
np.random.seed(42)
N = 2000
_AD_TEMPLATES = [
    "Free iPhone click now",
    "Limited offer discount shoes",
    "Educational course in AI",
    "Normal clothing ad",
    "High quality electronics sale",
    "Get rich quick!!!",
]
texts = np.random.choice(_AD_TEMPLATES, size=N)


def _looks_risky(text):
    """Return True when the ad text contains one of the violation keywords."""
    return "Free" in text or "rich" in text


# Simulated OCR: a risky ad repeats its text in the image about half of the
# time; every other ad has empty OCR text. A random number is drawn only for
# risky ads, in row order.
ocr_variants = [
    (t if np.random.rand() < 0.5 else "") if _looks_risky(t) else ""
    for t in texts
]
ocr = np.array(ocr_variants)

# Numeric side signals, independent of the text.
ctr = np.random.beta(1.2, 10, size=N)
impr_1h = np.random.poisson(50, size=N)
adv_hist_violation = np.random.binomial(1, 0.05, size=N)

# Ground truth follows the keyword rule, then 2% of the rows get their label
# inverted to simulate annotation noise.
labels = np.array([int(_looks_risky(t)) for t in texts])
flip_idx = np.random.choice(N, size=int(0.02 * N), replace=False)
labels[flip_idx] = 1 - labels[flip_idx]

df = pd.DataFrame(
    dict(
        ad_text=texts,
        ocr_text=ocr,
        ctr=ctr,
        impr_1h=impr_1h,
        adv_hist_violation=adv_hist_violation,
        label=labels,
    )
)

## ========= 2. 规则层 + OCR 模拟过滤 ==========

# In [None]:
# Invisible characters that survive NFKC and are commonly inserted inside a
# keyword to defeat substring rules: soft hyphen, zero-width space,
# zero-width non-joiner, zero-width joiner, word joiner, BOM / ZWNBSP.
_INVISIBLE_CHARS = dict.fromkeys(map(ord, "\u00ad\u200b\u200c\u200d\u2060\ufeff"))


def normalize_text(s):
    """Return a canonical lower-case form of *s* for rule matching.

    Steps: missing values become ``""``; the value is converted to ``str``
    and NFKC-normalized (so full-width or compatibility forms collapse to
    their plain equivalents); invisible characters are removed; the result
    is lower-cased and stripped of surrounding whitespace.

    Args:
        s: A scalar text value. ``None``/NaN (anything ``pd.isna`` reports
            as missing) is treated as empty text.

    Returns:
        The normalized string, possibly empty.
    """
    if pd.isna(s):
        return ""
    s = unicodedata.normalize('NFKC', str(s))
    # One translate pass removes every invisible character, not just U+200B.
    return s.translate(_INVISIBLE_CHARS).lower().strip()

# Hard-block patterns, compiled once instead of on every call. The reported
# reason is the pattern's source text.
_RULE_PATTERNS = [
    re.compile(p, flags=re.I)
    for p in (r"100%中奖", r"free", r"get rich", r"免费")
]
_FUZZY_KEYWORD = "free"
_FUZZY_THRESHOLD = 85


def rule_check_text(text):
    """Check a piece of text against the hard-block keyword rules.

    The text is first passed through ``normalize_text``. It is then tested
    against each regex rule in order, and finally against a fuzzy match on
    the keyword "free".

    Args:
        text: Raw ad or OCR text; missing values are accepted.

    Returns:
        ``(True, reason)`` when a rule fires, where ``reason`` is the
        matching regex source or ``"fuzzy:free"``; ``(False, None)``
        otherwise.
    """
    text = normalize_text(text)
    if not text:
        # Nothing to match; also keeps empty OCR text away from the fuzzy scorer.
        return False, None
    for pattern in _RULE_PATTERNS:
        if pattern.search(text):
            return True, pattern.pattern
    # partial_ratio aligns the SHORTER string inside the longer one, so a text
    # shorter than the keyword (e.g. "re" or "e") would score 100 against
    # "free". Only run the fuzzy check when the text can contain the keyword.
    if (len(text) >= len(_FUZZY_KEYWORD)
            and fuzz.partial_ratio(_FUZZY_KEYWORD, text) > _FUZZY_THRESHOLD):
        return True, "fuzzy:free"
    return False, None

# A row is rule-blocked when either its ad copy or its OCR text trips a rule.
ad_hits = df['ad_text'].apply(lambda s: rule_check_text(s)[0])
ocr_hits = df['ocr_text'].apply(lambda s: rule_check_text(s)[0])
df['rule_block'] = ad_hits | ocr_hits
print("Rule-blocked samples:", df['rule_block'].sum())

## ========= 3. 特征提取（文本嵌入 + 数值特征） ==========

# In [None]:
from sentence_transformers import SentenceTransformer
# Sentence-embedding model; it is reused by predict_ad at scoring time.
text_model = SentenceTransformer('all-MiniLM-L6-v2')

# One document per row: ad copy and OCR text joined by a single space, with
# missing values replaced by the empty string.
ad_col = df['ad_text'].fillna('')
ocr_col = df['ocr_text'].fillna('')
docs = (ad_col + " " + ocr_col).tolist()
text_emb = text_model.encode(docs, show_progress_bar=True, batch_size=64)

# Feature matrix: embedding columns first, then the three numeric signals.
num_feats = df[['ctr', 'impr_1h', 'adv_hist_violation']].to_numpy()
X = np.concatenate([text_emb, num_feats], axis=1)
y = df['label'].to_numpy()

## ========= 4. 训练集/验证集划分 ==========

# In [None]:
from sklearn.model_selection import train_test_split
# Hold out 20% of the rows for validation; stratifying on y keeps the label
# ratio the same in both splits, and the fixed seed makes the split repeatable.
X_train, X_val, y_train, y_val = train_test_split(
    X,
    y,
    test_size=0.2,
    stratify=y,
    random_state=42,
)

## ========= 5. 不平衡处理 + 模型训练 ==========

# In [None]:
import lightgbm as lgb
from sklearn.metrics import average_precision_score, roc_auc_score

# Class imbalance: weight positives by the negative/positive ratio, but never
# below 1 (max(1, n_pos) guards the division when there are no positives).
n_pos = y_train.sum()
n_neg = len(y_train) - n_pos
neg_pos_ratio = n_neg / max(1, n_pos)
scale_pos_weight = neg_pos_ratio if neg_pos_ratio > 1.0 else 1.0

clf = lgb.LGBMClassifier(n_estimators=200, scale_pos_weight=scale_pos_weight)
clf.fit(X_train, y_train)

# Column 1 of predict_proba is the probability of the positive class.
probs_val = clf.predict_proba(X_val)[:, 1]
val_ap = average_precision_score(y_val, probs_val)
val_auc = roc_auc_score(y_val, probs_val)
print("AP:", val_ap, "AUC:", val_auc)

## ========= 6. PR 曲线阈值选择 ==========

# In [None]:
from sklearn.metrics import precision_recall_curve, confusion_matrix
# PR-curve points for the validation set. The sweep below does not read them;
# it evaluates a fixed grid of thresholds instead.
precision, recall, thresholds = precision_recall_curve(y_val, probs_val)

# Asymmetric misclassification costs: a missed violation (false negative) is
# charged 50x as much as a wrongly blocked ad (false positive).
C_FP, C_FN = 1.0, 50.0
best_t, best_cost = 0.5, float('inf')
for t in np.linspace(0.01, 0.99, 50):
    y_pred = (probs_val >= t).astype(int)
    # labels=[0, 1] pins the matrix to 2x2. Without it, a fold where y_val and
    # y_pred together contain a single class yields a 1x1 matrix and the
    # four-way unpack raises ValueError.
    tn, fp, fn, tp = confusion_matrix(y_val, y_pred, labels=[0, 1]).ravel()
    cost = C_FP*fp + C_FN*fn
    # Strict '<' keeps the lowest threshold among equal-cost candidates.
    if cost < best_cost:
        best_cost, best_t = cost, t
print("Best threshold:", best_t, "with cost:", best_cost)

## ========= 7. SHAP 模型解释 ==========

# In [None]:
# Best-effort explanation step: SHAP and matplotlib are imported lazily and any
# failure is reported instead of aborting the rest of the pipeline.
try:
    import shap
    import matplotlib.pyplot as plt
    explainer = shap.TreeExplainer(clf)
    shap_values = explainer.shap_values(X_val)
    if isinstance(shap_values, list):
        # Per-class list output: keep the positive-class attributions.
        shap_values = shap_values[1]
    elif getattr(shap_values, "ndim", 2) == 3:
        # NOTE(review): newer SHAP releases may return one array with a
        # trailing per-output axis instead of a list; take the positive class
        # there as well - confirm against the installed SHAP version.
        shap_values = shap_values[:, :, 1]
    # show=False leaves the figure open so it can be written to disk.
    shap.summary_plot(shap_values, X_val, show=False)
    plt.savefig("shap_summary.png")
    print("Saved SHAP summary plot to shap_summary.png")
except Exception as e:
    print("SHAP failed:", e)

## ========= 8. 测试接口示例 ==========

# In [None]:
def predict_ad(ad_text, ocr_text="", ctr=0.05, impr_1h=10, adv_hist_violation=0):
    """Score a single ad with the rule layer followed by the classifier.

    Args:
        ad_text: Ad copy. ``None``/NaN is treated as empty text.
        ocr_text: Text extracted from the creative. ``None``/NaN is treated
            as empty text.
        ctr: Click-through-rate feature.
        impr_1h: Impressions-in-the-last-hour feature.
        adv_hist_violation: Advertiser violation-history flag.

    Returns:
        ``{"block": True, "reason": ...}`` when a keyword rule fires; the
        reason is prefixed with ``"ocr:"`` when the hit came from the OCR
        text. Otherwise ``{"block", "prob", "threshold"}``, where ``block``
        is ``prob >= best_t``.
    """
    # Rule layer: screen the ad copy and the OCR text, the same way the
    # dataset's rule_block column is computed. Checking only ad_text would let
    # a violation that appears only in the image skip the hard block.
    rule_block, reason = rule_check_text(ad_text)
    if rule_block:
        return {"block": True, "reason": reason}
    ocr_block, ocr_reason = rule_check_text(ocr_text)
    if ocr_block:
        return {"block": True, "reason": f"ocr:{ocr_reason}"}
    # Features: mirror the training-time fillna('') so that a missing text
    # field does not crash the string concatenation.
    ad_part = "" if pd.isna(ad_text) else str(ad_text)
    ocr_part = "" if pd.isna(ocr_text) else str(ocr_text)
    doc = ad_part + " " + ocr_part
    emb = text_model.encode([doc])
    # Same column order as the training matrix: embedding, then numeric signals.
    feats = np.hstack([emb, [[ctr, impr_1h, adv_hist_violation]]])
    prob = clf.predict_proba(feats)[:,1][0]
    decision = (prob >= best_t)
    return {"block": bool(decision), "prob": float(prob), "threshold": float(best_t)}

# Smoke test: one ad the keyword rules should catch and one benign ad.
for sample_text in ("Free bitcoin win now", "Educational AI course"):
    print("Test ad:", predict_ad(sample_text))