# SHAP 可解釋性分析 - CyberPuppy 毒性偵測模型

本筆記使用 SHAP (SHapley Additive exPlanations) Partition explainer 對 Transformer 文本分類模型進行可解釋性分析。

## 參考資料
- [SHAP Documentation](https://shap.readthedocs.io/)
- [SHAP Partition Explainer](https://shap.readthedocs.io/en/latest/generated/shap.explainers.Partition.html)
- [SHAP Text Plots](https://shap.readthedocs.io/en/latest/example_notebooks/text_examples/text_classification/Explaining%20Sentiment%20Classification%20DistilBERT.html)

## 1. 環境設置與套件導入

In [None]:
import os
import sys
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Tuple, Any, Optional, Union

# 添加專案路徑
sys.path.append('../src')

# SHAP 相關套件
import shap
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# PyTorch
import torch
import torch.nn.functional as F

# CyberPuppy 模組
from cyberpuppy.models.baselines import MultiTaskBertModel
from cyberpuppy.config import MODEL_CONFIG

# Prefer fonts that carry CJK glyphs so Chinese labels render in matplotlib,
# and use the ASCII hyphen for minus signs.
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'Microsoft YaHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})

# Silence every warning category for the rest of the session.
warnings.filterwarnings('ignore')

# Report the versions of the key libraries in use.
for _name, _lib in (("SHAP", shap), ("Transformers", transformers), ("PyTorch", torch)):
    print(f"{_name} version: {_lib.__version__}")

## 2. 載入模型與數據

In [None]:
# Pick the compute device: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用設備: {device}")

# Load the tokenizer that matches the configured base model.
model_name = MODEL_CONFIG['model']['name']
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Build the multi-task model exactly once so that both the "checkpoint found"
# and the "checkpoint missing" paths share the same head configuration.
model_path = '../models/baseline_model.pt'
model = MultiTaskBertModel(
    model_name=model_name,
    num_toxicity_labels=3,
    num_emotion_labels=3,
    num_bullying_labels=4,
    num_role_labels=4
).to(device)

if os.path.exists(model_path):
    # NOTE(review): torch.load unpickles the file; only load checkpoints that
    # come from a trusted source.
    model.load_state_dict(torch.load(model_path, map_location=device))
    print(f"模型已載入: {model_path}")
else:
    # No fine-tuned weights on disk: the model keeps whatever weights
    # MultiTaskBertModel initialises. Presumably the task heads are untrained
    # in this case, so predictions are only useful as a smoke test - verify.
    print(f"警告: 模型文件不存在 {model_path}，將使用預訓練模型")

# Evaluation mode for inference.
model.eval()
print("模型設為評估模式")

## 3. 準備測試文本數據

In [None]:
# Sample sentences used throughout the notebook, grouped by intent.
_positive_samples = [
    "這個產品真的很棒，我很喜歡！",
    "謝謝你的幫助，你人真好。",
    "今天天氣很不錯，心情也很好。",
]
_neutral_samples = [
    "我今天去了超市買東西。",
    "會議將在下午三點開始。",
    "這是一個關於技術的討論。",
]
_toxic_samples = [
    "你這個白癡，什麼都不懂！",
    "滾開，不要來煩我！",
    "你們這些垃圾，都給我閉嘴！",
]
_bullying_samples = [
    "沒人喜歡你，你就是個失敗者。",
    "你應該消失，世界會更美好。",
    "大家都在笑話你，你不知道嗎？",
]

# Flat list in the order: positive, neutral, toxic, bullying.
test_texts = [
    *_positive_samples,
    *_neutral_samples,
    *_toxic_samples,
    *_bullying_samples,
]

# Label names per task. toxicity_labels is indexed by the predicted class id
# later in the notebook; presumably the other lists play the same role for
# their heads - confirm against the training configuration.
toxicity_labels = ['none', 'toxic', 'severe']
emotion_labels = ['negative', 'neutral', 'positive']
bullying_labels = ['none', 'harassment', 'threat', 'victim']
role_labels = ['none', 'perpetrator', 'victim', 'bystander']

# Show how many samples were prepared and preview the first five.
print(f"準備了 {len(test_texts)} 個測試文本")
print("\n".join(
    f"{rank}: {sentence}"
    for rank, sentence in enumerate(test_texts[:5], start=1)
))
print("...")

## 4. 創建 SHAP Explainer 封裝類

In [None]:
class TransformerSHAPExplainer:
    """Explain a Transformer toxicity classifier with SHAP's Partition explainer.

    The wrapped model is called as ``model(input_ids, attention_mask)`` and is
    expected to return a mapping whose ``'toxicity'`` entry holds the logits of
    the toxicity head; those logits are turned into probabilities with softmax.

    References:
    - SHAP Partition explainer: https://shap.readthedocs.io/en/latest/generated/shap.explainers.Partition.html
    - Transformer text explanations: https://shap.readthedocs.io/en/latest/example_notebooks/text_examples/
    """

    def __init__(self, model, tokenizer, device='cpu', max_length=512, batch_size=16):
        """Store the model/tokenizer pair and build the prediction function.

        Args:
            model: Callable model returning a mapping with a ``'toxicity'`` entry.
            tokenizer: Tokenizer used both for model inputs and for SHAP masking.
            device: Device the input tensors are moved to before the forward pass.
            max_length: Maximum token length; longer inputs are truncated.
            batch_size: Number of texts sent through the model per forward pass.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.max_length = max_length
        self.batch_size = max(1, int(batch_size))

        # Text -> probability function handed to SHAP.
        self.predict_fn = self._create_predict_function()

        # Created lazily by setup_explainer().
        self.explainer = None

    def _create_predict_function(self):
        """Return a function mapping texts to toxicity class probabilities.

        The returned function accepts a single string or any iterable of
        strings and returns an array with one row of probabilities per text.
        """
        def predict(texts):
            if isinstance(texts, str):
                texts = [texts]

            # SHAP may hand over a numpy array of masked strings; the
            # tokenizer is given a plain list of str.
            texts = [str(item) for item in texts]
            if not texts:
                return np.array([])

            chunks = []
            with torch.no_grad():
                # Mini-batches keep the number of forward passes low; padding
                # to the longest text in the batch avoids always padding to
                # max_length. Padded positions are masked by attention_mask.
                for start in range(0, len(texts), self.batch_size):
                    batch = texts[start:start + self.batch_size]
                    encoding = self.tokenizer(
                        batch,
                        truncation=True,
                        padding=True,
                        max_length=self.max_length,
                        return_tensors='pt'
                    )

                    input_ids = encoding['input_ids'].to(self.device)
                    attention_mask = encoding['attention_mask'].to(self.device)

                    outputs = self.model(input_ids, attention_mask)
                    probs = F.softmax(outputs['toxicity'], dim=-1)
                    chunks.append(probs.cpu().numpy())

            return np.concatenate(chunks, axis=0)

        return predict

    def setup_explainer(self, background_texts, max_evals=500):
        """Create the SHAP Partition explainer.

        Args:
            background_texts: Accepted for interface compatibility; not used,
                because masking is done by a tokenizer-based text masker.
            max_evals: Default evaluation budget per explained text.
        """
        print("設置 SHAP Partition explainer...")

        # Partition requires a masker; the Text masker hides tokens produced
        # by the same tokenizer the model uses.
        masker = shap.maskers.Text(self.tokenizer)

        self.explainer = shap.explainers.Partition(
            self.predict_fn,
            masker,
            max_evals=max_evals,
            silent=True
        )

        print(f"Explainer 設置完成，最大評估次數: {max_evals}")

    def explain_text(self, text, output_class=1):
        """Explain a single text.

        Args:
            text: The text to explain.
            output_class: Accepted for interface compatibility; not used. The
                explainer output is returned unsliced.

        Returns:
            Whatever the explainer returns for ``[text]``.

        Raises:
            ValueError: If setup_explainer() has not been called yet.
        """
        if self.explainer is None:
            raise ValueError("請先調用 setup_explainer() 設置解釋器")

        return self.explainer([text])

    def explain_batch(self, texts, output_class=1):
        """Explain several texts in one call.

        Args:
            texts: Iterable of texts to explain.
            output_class: Accepted for interface compatibility; not used.

        Returns:
            Whatever the explainer returns for the list of texts.

        Raises:
            ValueError: If setup_explainer() has not been called yet.
        """
        if self.explainer is None:
            raise ValueError("請先調用 setup_explainer() 設置解釋器")

        return self.explainer(list(texts))

# Wrap the loaded model and tokenizer in the SHAP helper; max_length=256 is
# forwarded to the tokenizer as the truncation length.
_explainer_kwargs = dict(
    model=model,
    tokenizer=tokenizer,
    device=device,
    max_length=256,
)
shap_explainer = TransformerSHAPExplainer(**_explainer_kwargs)

print("SHAP 解釋器創建完成")

## 5. 設置 SHAP Explainer

In [None]:
# The first six test sentences (the positive and neutral groups) serve as
# the background samples.
background_texts = test_texts[:6]

# Build the Partition explainer with a budget of 100 model evaluations.
shap_explainer.setup_explainer(background_texts, max_evals=100)

print("SHAP Partition explainer 設置完成")

## 6. 單句分析與可視化

In [None]:
# Pick one toxic sentence for a detailed walkthrough.
target_text = "你這個白癡，什麼都不懂！"
print(f"分析文本: {target_text}")

# Model prediction for the target sentence: class probabilities, the arg-max
# class and its probability.
prediction = shap_explainer.predict_fn([target_text])[0]
predicted_class = np.argmax(prediction)
confidence = prediction[predicted_class]

print("預測結果:")
print(f"  預測類別: {toxicity_labels[predicted_class]} (索引: {predicted_class})")
print(f"  信心分數: {confidence:.4f}")
print(f"  所有類別機率: {prediction}")

# Compute SHAP values for the sentence.
print("\n計算 SHAP 值...")
shap_values = shap_explainer.explain_text(target_text, output_class=predicted_class)

print("SHAP 值計算完成")
print(f"SHAP values shape: {shap_values.values.shape}")
# .data may be a container without a .shape attribute (e.g. a tuple of token
# arrays), so fall back to 'N/A' instead of raising AttributeError.
print(f"Data shape: {getattr(getattr(shap_values, 'data', None), 'shape', 'N/A')}")

## 7. SHAP 文本可視化 (shap.plots.text)

In [None]:
# Visualise the explanation with shap.plots.text.
# Reference: https://shap.readthedocs.io/en/latest/generated/shap.plots.text.html
#
# shap.plots.text renders HTML, not a matplotlib figure, so the plot is saved
# as an HTML file and displayed inline instead of going through plt.savefig
# (which would only save an empty figure).

try:
    print("創建 SHAP 文本可視化...")

    # With display=False the plot is returned as an HTML string.
    text_plot_html = shap.plots.text(shap_values[0], display=False)
    if isinstance(text_plot_html, str):
        os.makedirs('../data/processed', exist_ok=True)
        with open('../data/processed/shap_text_plot.html', 'w', encoding='utf-8') as html_file:
            html_file.write(text_plot_html)

    # Default display=True renders the plot inline in the notebook.
    shap.plots.text(shap_values[0])

    print("SHAP plots.text 可視化完成")

except Exception as e:
    # Best-effort fallback: draw a bar chart by hand when the HTML plot fails.
    print(f"SHAP plots.text 出現錯誤: {e}")
    print("改用手動可視化...")

    # Per-token SHAP values of the first (only) sample; for multi-output
    # explanations keep the column of the predicted class.
    sample_explanation = shap_values[0]
    target_shap_values = np.asarray(sample_explanation.values)
    if target_shap_values.ndim > 1:
        target_shap_values = target_shap_values[:, predicted_class]

    # SHAP values are aligned with the explainer's own tokens, not with raw
    # characters, so label the bars with those tokens.
    tokens = [str(tok) for tok in sample_explanation.data]

    if len(tokens) != len(target_shap_values):
        # Token information unusable: fall back to characters, truncated to
        # the shorter of the two sequences.
        tokens = list(target_text)
        min_len = min(len(target_shap_values), len(tokens))
        target_shap_values = target_shap_values[:min_len]
        tokens = tokens[:min_len]

    plt.figure(figsize=(15, 8))
    # Negative values red, non-negative values green.
    colors = ['red' if val < 0 else 'green' for val in target_shap_values]
    plt.barh(range(len(tokens)), target_shap_values, color=colors, alpha=0.7)
    plt.yticks(range(len(tokens)), tokens)
    plt.xlabel('SHAP Value')
    plt.title(f'字符級 SHAP 重要性分析\n文本: "{target_text}"\n預測: {toxicity_labels[predicted_class]} ({confidence:.3f})')
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    os.makedirs('../data/processed', exist_ok=True)
    plt.savefig('../data/processed/shap_single_analysis.png', dpi=300, bbox_inches='tight')
    plt.show()

    print("單句 SHAP 分析可視化完成")

## 8. 批次對比分析

In [None]:
# A few sentences of different types for a side-by-side comparison.
comparison_texts = [
    "謝謝你的幫助，你人真好。",
    "我今天去了超市買東西。",
    "你這個白癡，什麼都不懂！",
    "沒人喜歡你，你就是個失敗者。"
]

print("批次對比分析...")
print(f"分析 {len(comparison_texts)} 個文本")

# Class probabilities and arg-max class for every sentence.
batch_predictions = shap_explainer.predict_fn(comparison_texts)
predicted_classes = np.argmax(batch_predictions, axis=1)

# Tabulate the predictions.
results_df = pd.DataFrame({
    'Text': comparison_texts,
    'Predicted_Class': [toxicity_labels[cls] for cls in predicted_classes],
    'Confidence': [probs[cls] for probs, cls in zip(batch_predictions, predicted_classes)],
    'None_Prob': batch_predictions[:, 0],
    'Toxic_Prob': batch_predictions[:, 1],
    'Severe_Prob': batch_predictions[:, 2]
})

print("\n批次預測結果:")
print(results_df.round(4))

# Compute SHAP values for the whole batch.
print("\n計算批次 SHAP 值...")
try:
    batch_shap_values = shap_explainer.explain_batch(comparison_texts)
    print("批次 SHAP 值計算完成")

    for i, text in enumerate(comparison_texts):
        class_id = predicted_classes[i]
        caption = f'文本 {i+1}: {text}\n預測: {toxicity_labels[class_id]} ({batch_predictions[i][class_id]:.3f})'
        sample_explanation = batch_shap_values[i]

        try:
            # shap.plots.text renders HTML inline, so the caption is printed
            # rather than set as a matplotlib title.
            print(caption)
            shap.plots.text(sample_explanation)
        except Exception as plot_error:
            # Best-effort fallback: hand-drawn bar chart for this sentence.
            print(f"SHAP plots.text 出現錯誤: {plot_error}")
            print("改用手動可視化...")

            # Slice per sample first: texts of different token lengths make
            # the batch-level .values ragged.
            current_shap_values = np.asarray(sample_explanation.values)
            if current_shap_values.ndim > 1:
                current_shap_values = current_shap_values[:, class_id]

            # Label bars with the explainer's tokens; fall back to raw
            # characters only if the lengths do not line up.
            tokens = [str(tok) for tok in sample_explanation.data]
            if len(tokens) != len(current_shap_values):
                tokens = list(text)
                min_len = min(len(current_shap_values), len(tokens))
                current_shap_values = current_shap_values[:min_len]
                tokens = tokens[:min_len]

            plt.figure(figsize=(12, 3))
            colors = ['red' if val < 0 else 'green' for val in current_shap_values]
            plt.barh(range(len(tokens)), current_shap_values, color=colors, alpha=0.7)
            plt.yticks(range(len(tokens)), tokens)
            plt.xlabel('SHAP Value')
            plt.title(f'文本 {i+1}: "{text}"\n預測: {toxicity_labels[class_id]} ({batch_predictions[i][class_id]:.3f})')
            plt.tight_layout()
            plt.show()

    print("批次對比分析可視化完成")

except Exception as e:
    # Keep the notebook running when the batch explanation itself fails.
    print(f"批次 SHAP 分析出現錯誤: {e}")

## 9. 統計分析與模式發現

In [None]:
# Compare the model's predictions across categories of text.
print("統計分析與模式發現...")

text_categories = {
    'positive': ["謝謝你的幫助，你人真好。", "這個產品真的很棒，我很喜歡！", "今天天氣很不錯，心情也很好。"],
    'neutral': ["我今天去了超市買東西。", "會議將在下午三點開始。", "這是一個關於技術的討論。"],
    'toxic': ["你這個白癡，什麼都不懂！", "滾開，不要來煩我！", "你們這些垃圾，都給我閉嘴！"]
}

analysis_results = {}

for category, texts in text_categories.items():
    print(f"\n分析 {category} 類文本...")

    predictions = shap_explainer.predict_fn(texts)
    # Kept under its own name so the batch-comparison cell's
    # predicted_classes is not overwritten.
    category_classes = np.argmax(predictions, axis=1)

    analysis_results[category] = {
        'avg_toxic_prob': np.mean(predictions[:, 1]),
        'avg_severe_prob': np.mean(predictions[:, 2]),
        'predicted_classes': category_classes,
        # Count of predictions per class, always of length 3.
        'class_distribution': np.bincount(category_classes, minlength=3)
    }

    print(f"  平均毒性機率: {analysis_results[category]['avg_toxic_prob']:.4f}")
    print(f"  平均嚴重機率: {analysis_results[category]['avg_severe_prob']:.4f}")
    print(f"  類別分布: {dict(zip(toxicity_labels, analysis_results[category]['class_distribution']))}")

# Plot the statistics. Only two panels are drawn, so use a 1x2 grid rather
# than a 2x2 grid with an empty bottom row.
fig, (prob_ax, heat_ax) = plt.subplots(1, 2, figsize=(15, 5))

categories = list(text_categories.keys())
toxic_probs = [analysis_results[cat]['avg_toxic_prob'] for cat in categories]
severe_probs = [analysis_results[cat]['avg_severe_prob'] for cat in categories]

x = np.arange(len(categories))
width = 0.35

# Left panel: grouped bars of the average toxic / severe probability.
prob_ax.bar(x - width/2, toxic_probs, width, label='Toxic', alpha=0.7, color='orange')
prob_ax.bar(x + width/2, severe_probs, width, label='Severe', alpha=0.7, color='red')
prob_ax.set_xlabel('Text Category')
prob_ax.set_ylabel('Average Probability')
prob_ax.set_title('不同文本類型的平均毒性機率')
prob_ax.set_xticks(x)
prob_ax.set_xticklabels(categories)
prob_ax.legend()
prob_ax.grid(True, alpha=0.3)

# Right panel: heat map of predicted-class counts (rows = category,
# columns = toxicity class).
class_matrix = np.array([analysis_results[cat]['class_distribution'] for cat in categories])
im = heat_ax.imshow(class_matrix, cmap='Blues', aspect='auto')
heat_ax.set_xticks(range(len(toxicity_labels)))
heat_ax.set_xticklabels(toxicity_labels)
heat_ax.set_yticks(range(len(categories)))
heat_ax.set_yticklabels(categories)
heat_ax.set_title('類別預測分布熱力圖')

# Annotate each cell with its count; white text on dark cells for contrast.
for i in range(len(categories)):
    for j in range(len(toxicity_labels)):
        heat_ax.text(j, i, class_matrix[i, j], ha='center', va='center',
                     color='white' if class_matrix[i, j] > class_matrix.max()/2 else 'black')

plt.colorbar(im, ax=heat_ax)

plt.tight_layout()
plt.savefig('../data/processed/shap_statistical_analysis.png', dpi=300, bbox_inches='tight')
plt.show()

print("\n統計分析可視化完成")

## 10. 誤判案例診斷

In [None]:
# Diagnose misclassified edge cases.
print("誤判案例診斷分析...")

# Borderline sentences.
edge_cases = [
    "哇，你真是太聰明了！",  # sarcasm
    "這個白癡程式碼又出錯了",  # insult aimed at a thing, not a person
    "垃圾郵件太多了，很煩人",  # "garbage" refers to spam mail
    "我真的很生氣這個決定！",  # strong emotion but not toxic
    "你的想法我不太認同",  # mildly negative but not toxic
    "這個政策確實有爭議"  # sensitive topic but not an attack
]

expected_labels = [1, 0, 0, 0, 0, 0]  # manual annotation

edge_predictions = shap_explainer.predict_fn(edge_cases)
edge_predicted_classes = np.argmax(edge_predictions, axis=1)

# A case is misclassified when the toxic / non-toxic decision disagrees with
# the annotation (class 0 = non-toxic, any class > 0 = toxic).
expected_array = np.array(expected_labels)
expected_toxic = expected_array > 0
predicted_toxic = edge_predicted_classes > 0
misclassified_indices = [int(idx) for idx in np.flatnonzero(expected_toxic != predicted_toxic)]

print(f"發現 {len(misclassified_indices)} 個潛在誤判案例")

if misclassified_indices:
    print("\n詳細誤判分析:")

    misclass_data = [
        {
            'Text': edge_cases[idx],
            'Expected': toxicity_labels[expected_labels[idx]],
            'Predicted': toxicity_labels[edge_predicted_classes[idx]],
            'Confidence': edge_predictions[idx][edge_predicted_classes[idx]],
            'Error_Type': 'False Positive' if expected_labels[idx] == 0 else 'False Negative'
        }
        for idx in misclassified_indices
    ]

    misclass_df = pd.DataFrame(misclass_data)
    print(misclass_df.round(4))

    # Explain every misclassified case. Local names are used so that the
    # shap_values / predicted_class / confidence of earlier cells survive.
    for case_no, (idx, record) in enumerate(zip(misclassified_indices, misclass_data), start=1):
        case_text = edge_cases[idx]
        case_class = edge_predicted_classes[idx]

        try:
            case_explanation = shap_explainer.explain_text(case_text, case_class)[0]

            # Per-token values; for multi-output explanations keep the
            # column of the predicted class.
            case_values = np.asarray(case_explanation.values)
            if case_values.ndim > 1:
                case_values = case_values[:, case_class]

            # Label bars with the explainer's tokens; fall back to raw
            # characters only if the lengths do not line up.
            tokens = [str(tok) for tok in case_explanation.data]
            if len(tokens) != len(case_values):
                tokens = list(case_text)
                min_len = min(len(case_values), len(tokens))
                case_values = case_values[:min_len]
                tokens = tokens[:min_len]

            plt.figure(figsize=(15, 4))
            colors = ['red' if val < 0 else 'green' for val in case_values]
            plt.barh(range(len(tokens)), case_values, color=colors, alpha=0.7)
            plt.yticks(range(len(tokens)), tokens)
            plt.xlabel('SHAP Value')

            error_type = record['Error_Type']
            expected_name = record['Expected']
            predicted_name = record['Predicted']

            plt.title(f'誤判案例 {case_no}: {error_type}\n"{case_text}"\n期望: {expected_name} → 預測: {predicted_name}')
            plt.grid(True, alpha=0.3)
            plt.tight_layout()
            plt.show()

        except Exception as e:
            # Keep going with the remaining cases if one explanation fails.
            print(f"分析誤判案例 {idx} 時出錯: {e}")

    # Save the misclassification report.
    misclass_df.to_csv('../data/processed/misclassification_report.csv', index=False, encoding='utf-8')
    print("\n誤判分析報告已保存")

else:
    print("在測試案例中未發現明顯誤判")

# Summary statistics.
# Accuracy compares exact class ids; the two error rates use the binary
# toxic / non-toxic decision and are normalised by the number of actual
# negatives / positives (0.0 when that group is empty).
overall_accuracy = np.mean(expected_array == edge_predicted_classes)
negative_count = int(np.sum(~expected_toxic))
positive_count = int(np.sum(expected_toxic))
false_positive_rate = (
    np.sum(~expected_toxic & predicted_toxic) / negative_count if negative_count else 0.0
)
false_negative_rate = (
    np.sum(expected_toxic & ~predicted_toxic) / positive_count if positive_count else 0.0
)

print(f"\n整體準確率: {overall_accuracy:.4f}")
print(f"誤報率: {false_positive_rate:.4f}")
print(f"漏報率: {false_negative_rate:.4f}")

print("\n誤判案例診斷完成")

## 11. 總結報告與建議

In [None]:
import json

# Banner for the summary report.
banner_rule = "=" * 60
print(banner_rule)
print("SHAP 可解釋性分析總結報告")
print(banner_rule)

# Assemble the report section by section; key order matches the JSON layout.
analysis_date = pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')
model_info = {
    'base_model': model_name,
    'task_types': ['toxicity', 'emotion', 'bullying', 'role'],
    'num_parameters': sum(p.numel() for p in model.parameters()),
    'device': str(device)
}
shap_config = {
    'explainer_type': 'Partition',
    'max_evals': 100,
    'background_samples': len(background_texts)
}
analysis_summary = {
    'total_texts_analyzed': len(test_texts),
    'edge_cases_tested': len(edge_cases),
    # Zero when the diagnosis cell has not defined misclassified_indices.
    'misclassified_cases': len(locals().get('misclassified_indices', ()))
}
report_data = {
    'analysis_date': analysis_date,
    'model_info': model_info,
    'shap_config': shap_config,
    'analysis_results': analysis_summary
}

# Overview of the run.
print("\n📊 分析概況:")
print(f"  分析時間: {analysis_date}")
print(f"  基礎模型: {model_info['base_model']}")
print(f"  模型參數量: {model_info['num_parameters']:,}")
print(f"  計算設備: {model_info['device']}")

# SHAP configuration.
print("\n🔍 SHAP 配置:")
print(f"  解釋器類型: {shap_config['explainer_type']} Explainer")
print(f"  最大評估次數: {shap_config['max_evals']}")
print(f"  背景樣本數: {shap_config['background_samples']}")

# Fixed list of findings.
print("\n🎯 主要發現:")
for finding in (
    "SHAP Partition explainer 成功應用於中文毒性偵測",
    "shap.plots.text 提供直觀的文本可視化",
    "批次對比分析揭示不同文本類型的模式差異",
    "誤判案例診斷有助於模型改進",
):
    print(f"  ✅ {finding}")

# Fixed list of recommendations.
print("\n💡 使用建議:")
for suggestion in (
    "使用 shap.plots.text() 進行標準化文本可視化",
    "結合 Integrated Gradients 進行交叉驗證",
    "定期更新背景樣本以提高解釋品質",
    "將 SHAP 分析納入模型監控流程",
):
    print(f"  • {suggestion}")

# Persist the report as JSON; values json cannot encode are stringified.
report_path = '../data/processed/shap_analysis_report.json'
with open(report_path, 'w', encoding='utf-8') as f:
    json.dump(report_data, f, ensure_ascii=False, indent=2, default=str)

print(f"\n📋 完整報告已保存: {report_path}")

# Pointers to further resources.
print("\n參考資源:")
for resource in (
    "SHAP 官方文檔: https://shap.readthedocs.io/",
    "Transformer 解釋教程: https://shap.readthedocs.io/en/latest/example_notebooks/text_examples/",
    "CyberPuppy 專案: ../src/cyberpuppy/",
):
    print(f"  • {resource}")

print(f"\n{banner_rule}")
print("SHAP 可解釋性分析完成！")
print(banner_rule)