<a href="https://colab.research.google.com/github/wangyiyang/RAG-Cookbook-Code/blob/main/ch04/safety_checker.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%pip install numpy transformers torch sentence-transformers

In [4]:
"""
内容安全检查器
实现内容安全检测、输出过滤净化和隐私信息保护
"""

import re
import hashlib
from typing import List, Dict, Any, Optional, Set, Tuple
from dataclasses import dataclass
from enum import Enum


class SafetyLevel(Enum):
    """安全级别枚举"""
    SAFE = "safe"
    WARNING = "warning"
    UNSAFE = "unsafe"
    BLOCKED = "blocked"


@dataclass
class SafetyIssue:
    """安全问题"""
    issue_type: str
    severity: SafetyLevel
    description: str
    confidence: float
    location: Optional[str] = None
    suggestions: Optional[List[str]] = None


@dataclass
class SafetyReport:
    """安全报告"""
    is_safe: bool
    overall_score: float
    issues: List[SafetyIssue]
    processed_content: Optional[str] = None

    def get_issues_by_type(self, issue_type: str) -> List[SafetyIssue]:
        """按类型获取安全问题"""
        return [issue for issue in self.issues if issue.issue_type == issue_type]

    def get_issues_by_severity(self, severity: SafetyLevel) -> List[SafetyIssue]:
        """按严重程度获取安全问题"""
        return [issue for issue in self.issues if issue.severity == severity]


class BiasDetector:
    """偏见检测器"""

    def __init__(self):
        self.bias_patterns = {
            'gender': [
                r'(男|女)(性|人).*?(更|较|比较).*?(适合|擅长|不擅长)',
                r'(男|女)(孩|生).*?(天生|生来).*?',
                r'(男|女)(性|人).*?(应该|不应该|必须|不能).*?'
            ],
            'racial': [
                r'(.*族|.*人).*?(天生|生来|本来).*?(聪明|愚笨|懒惰|勤劳)',
                r'(.*族|.*人).*?(都|全部|一律).*?',
                r'典型的.*?(.*族|.*人).*?'
            ],
            'age': [
                r'(老|年轻)(人|者).*?(应该|不应该|不能|只能).*?',
                r'(.*岁|老年|年轻).*?(肯定|一定|必然).*?'
            ],
            'professional': [
                r'(.*职业|.*工作).*?(低级|高级|下等|上等).*?',
                r'只有.*?(学历|背景).*?才.*?'
            ]
        }

        self.sensitive_terms = {
            '歧视性词汇': [
                '劣等', '低贱', '下等人', '贫民', '暴民'
            ],
            '刻板印象': [
                '天生', '生来就', '注定', '血统', '基因决定'
            ]
        }

    def detect_bias(self, text: str) -> float:
        """检测文本中的偏见"""
        bias_score = 0.0
        detected_issues = []

        text_lower = text.lower()

        # 检测偏见模式
        for bias_type, patterns in self.bias_patterns.items():
            for pattern in patterns:
                matches = re.findall(pattern, text_lower)
                if matches:
                    bias_score += 0.3
                    detected_issues.append({
                        'type': bias_type,
                        'pattern': pattern,
                        'matches': matches
                    })

        # 检测敏感词汇
        for category, terms in self.sensitive_terms.items():
            for term in terms:
                if term in text_lower:
                    bias_score += 0.2
                    detected_issues.append({
                        'type': 'sensitive_term',
                        'category': category,
                        'term': term
                    })

        return min(1.0, bias_score)

    def get_bias_issues(self, text: str) -> List[SafetyIssue]:
        """获取详细的偏见问题列表"""
        issues = []
        text_lower = text.lower()

        # 检测不同类型的偏见
        for bias_type, patterns in self.bias_patterns.items():
            for pattern in patterns:
                matches = re.finditer(pattern, text_lower)
                for match in matches:
                    issues.append(SafetyIssue(
                        issue_type=f"bias_{bias_type}",
                        severity=SafetyLevel.WARNING,
                        description=f"检测到{bias_type}偏见: {match.group()}",
                        confidence=0.8,
                        location=f"位置 {match.start()}-{match.end()}",
                        suggestions=[
                            "避免使用刻板印象的表述",
                            "使用更中性和客观的语言",
                            "考虑个体差异而非群体标签"
                        ]
                    ))

        return issues


class ToxicContentDetector:
    """有毒内容检测器"""

    def __init__(self):
        self.toxic_categories = {
            'hate_speech': {
                'patterns': [
                    r'(仇恨|憎恨|厌恶).*?(民族|种族|宗教)',
                    r'(消灭|铲除|清理).*?(.*族|.*教)',
                    r'(他们|这些人).*?(都应该|必须|应当).*?(死|滚|消失)'
                ],
                'keywords': ['仇杀', '种族清洗', '屠杀', '灭绝']
            },
            'harassment': {
                'patterns': [
                    r'(你|他|她).*?(白痴|傻子|蠢货|废物)',
                    r'(去死|死开|滚开).*?',
                    r'(威胁|要杀|要打).*?'
                ],
                'keywords': ['骚扰', '威胁', '恐吓', '辱骂']
            },
            'violence': {
                'patterns': [
                    r'(如何|怎样).*?(制造|制作).*?(炸弹|武器)',
                    r'(伤害|攻击|袭击).*?(方法|技巧)',
                    r'(杀死|杀害).*?(步骤|过程)'
                ],
                'keywords': ['暴力', '攻击', '伤害', '武器制造']
            }
        }

    def detect_toxicity(self, text: str) -> float:
        """检测有毒内容"""
        toxicity_score = 0.0
        text_lower = text.lower()

        for category, config in self.toxic_categories.items():
            # 检测模式
            for pattern in config['patterns']:
                if re.search(pattern, text_lower):
                    toxicity_score += 0.4

            # 检测关键词
            for keyword in config['keywords']:
                if keyword in text_lower:
                    toxicity_score += 0.2

        return min(1.0, toxicity_score)

    def get_toxicity_issues(self, text: str) -> List[SafetyIssue]:
        """获取有毒内容问题列表"""
        issues = []
        text_lower = text.lower()

        for category, config in self.toxic_categories.items():
            # 检测模式匹配
            for pattern in config['patterns']:
                matches = re.finditer(pattern, text_lower)
                for match in matches:
                    severity = SafetyLevel.UNSAFE if category == 'hate_speech' else SafetyLevel.WARNING

                    issues.append(SafetyIssue(
                        issue_type=f"toxic_{category}",
                        severity=severity,
                        description=f"检测到{category}内容: {match.group()}",
                        confidence=0.9,
                        location=f"位置 {match.start()}-{match.end()}",
                        suggestions=[
                            "移除有害内容",
                            "使用积极正面的表述",
                            "避免可能引起争议的内容"
                        ]
                    ))

        return issues


class PrivacyChecker:
    """隐私检查器"""

    def __init__(self):
        self.privacy_patterns = {
            'phone': {
                'pattern': r'1[3-9]\d{9}',
                'description': '手机号码'
            },
            'email': {
                'pattern': r'[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}',
                'description': '邮箱地址'
            },
            'id_card': {
                'pattern': r'\b\d{17}[\dX]\b',
                'description': '身份证号'
            },
            'credit_card': {
                'pattern': r'\b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b',
                'description': '信用卡号'
            },
            'bank_account': {
                'pattern': r'\b\d{16,19}\b',
                'description': '银行账号'
            },
            'address': {
                'pattern': r'(.*省.*市.*区.*路.*号|.*街.*弄.*号)',
                'description': '详细地址'
            }
        }

    def check_privacy(self, text: str) -> List[SafetyIssue]:
        """检查隐私信息泄露"""
        privacy_violations = []

        for privacy_type, config in self.privacy_patterns.items():
            pattern = config['pattern']
            description = config['description']

            matches = re.finditer(pattern, text)
            for match in matches:
                # 对敏感信息进行脱敏处理
                masked_content = self._mask_sensitive_info(match.group(), privacy_type)

                privacy_violations.append(SafetyIssue(
                    issue_type=f"privacy_{privacy_type}",
                    severity=SafetyLevel.WARNING,
                    description=f"检测到{description}: {masked_content}",
                    confidence=0.95,
                    location=f"位置 {match.start()}-{match.end()}",
                    suggestions=[
                        "移除或脱敏个人信息",
                        "使用示例数据替代真实信息",
                        "添加隐私声明"
                    ]
                ))

        return privacy_violations

    def _mask_sensitive_info(self, content: str, info_type: str) -> str:
        """脱敏处理敏感信息"""
        if info_type == 'phone':
            return content[:3] + '****' + content[-4:]
        elif info_type == 'email':
            parts = content.split('@')
            if len(parts) == 2:
                username = parts[0]
                if len(username) <= 2:
                    return '*' + '@' + parts[1]
                else:
                    return username[:2] + '***@' + parts[1]
            return content
        elif info_type == 'id_card':
            return content[:6] + '********' + content[-4:]
        elif info_type == 'credit_card':
            return '**** **** **** ' + content[-4:]
        else:
            return '***敏感信息***'


class FactualVerifier:
    """事实核查器"""

    def __init__(self):
        self.known_facts = {
            'technology': {
                'rag': {
                    'definition': 'RAG是检索增强生成技术',
                    'components': ['检索器', '生成器'],
                    'applications': ['问答系统', '知识助手']
                }
            }
        }

        self.fact_error_patterns = [
            r'100%.*?(准确|正确|无误)',  # 过于绝对的表述
            r'永远不会.*?(错误|失败)',  # 不现实的声明
            r'所有.*?都.*?',            # 过于绝对的概括
            r'从来没有.*?',            # 绝对否定
            r'绝对.*?(安全|可靠)'       # 过度保证
        ]

    def verify_facts(self, content: str, context: str) -> List[SafetyIssue]:
        """验证事实准确性"""
        factual_errors = []

        # 检测过于绝对的表述
        for pattern in self.fact_error_patterns:
            matches = re.finditer(pattern, content)
            for match in matches:
                factual_errors.append(SafetyIssue(
                    issue_type="factual_absolute_claim",
                    severity=SafetyLevel.WARNING,
                    description=f"检测到过于绝对的表述: {match.group()}",
                    confidence=0.7,
                    location=f"位置 {match.start()}-{match.end()}",
                    suggestions=[
                        "使用更谨慎的表述",
                        "添加适当的限定词",
                        "承认可能存在的例外情况"
                    ]
                ))

        # 检测与上下文的不一致
        inconsistencies = self._check_context_consistency(content, context)
        factual_errors.extend(inconsistencies)

        return factual_errors

    def _check_context_consistency(self, content: str, context: str) -> List[SafetyIssue]:
        """检查与上下文的一致性"""
        issues = []

        # 简化的一致性检查
        # 实际应用中需要更复杂的NLP技术

        return issues


class OutputSanitizer:
    """输出净化器"""

    def __init__(self):
        self.replacement_rules = [
            {
                'pattern': r'1[3-9]\d{9}',
                'replacement': '[手机号码]',
                'description': '手机号码脱敏'
            },
            {
                'pattern': r'[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}',
                'replacement': '[邮箱地址]',
                'description': '邮箱地址脱敏'
            },
            {
                'pattern': r'\b\d{17}[\dX]\b',
                'replacement': '[身份证号]',
                'description': '身份证号脱敏'
            }
        ]

        self.content_filters = {
            'inappropriate_language': [
                '傻逼', '操蛋', '他妈的', '狗屎'
            ],
            'spam_indicators': [
                '点击这里', '立即购买', '限时优惠', '马上行动'
            ]
        }

    def sanitize_output(self, content: str) -> Tuple[str, List[str]]:
        """输出内容净化"""
        sanitized_content = content
        applied_rules = []

        # 1. 敏感信息脱敏
        for rule in self.replacement_rules:
            if re.search(rule['pattern'], sanitized_content):
                sanitized_content = re.sub(
                    rule['pattern'],
                    rule['replacement'],
                    sanitized_content
                )
                applied_rules.append(rule['description'])

        # 2. 不当内容过滤
        for category, terms in self.content_filters.items():
            for term in terms:
                if term in sanitized_content:
                    sanitized_content = sanitized_content.replace(
                        term,
                        '[已过滤]'
                    )
                    applied_rules.append(f'{category}过滤')

        # 3. 格式规范化
        sanitized_content = self.normalize_format(sanitized_content)

        return sanitized_content, applied_rules

    def normalize_format(self, content: str) -> str:
        """格式规范化"""
        # 移除多余空行
        content = re.sub(r'\n\s*\n\s*\n', '\n\n', content)

        # 统一标点符号
        content = content.replace('，，', '，')
        content = content.replace('。。', '。')

        # 移除首尾空白
        content = content.strip()

        return content


class ContentSafetyChecker:
    """内容安全检查器主类"""

    def __init__(self):
        self.bias_detector = BiasDetector()
        self.toxic_detector = ToxicContentDetector()
        self.privacy_checker = PrivacyChecker()
        self.factual_verifier = FactualVerifier()
        self.output_sanitizer = OutputSanitizer()

        # 安全阈值配置
        self.safety_thresholds = {
            'bias_score': 0.2,      # 降低偏见容忍度
            'toxicity_score': 0.1,  # 降低有毒内容容忍度
            'overall_safety': 0.90  # 进一步提高总体安全要求
        }

    def comprehensive_safety_check(
        self,
        generated_content: str,
        context: str = ""
    ) -> SafetyReport:
        """综合安全检查"""
        all_issues = []

        # 1. 偏见检测
        bias_score = self.bias_detector.detect_bias(generated_content)
        if bias_score > self.safety_thresholds['bias_score']:
            bias_issues = self.bias_detector.get_bias_issues(generated_content)
            all_issues.extend(bias_issues)

        # 2. 有毒内容检测
        toxicity_score = self.toxic_detector.detect_toxicity(generated_content)
        if toxicity_score > self.safety_thresholds['toxicity_score']:
            toxic_issues = self.toxic_detector.get_toxicity_issues(generated_content)
            all_issues.extend(toxic_issues)

        # 3. 隐私信息检查 - 总是检查，不设阈值
        privacy_issues = self.privacy_checker.check_privacy(generated_content)
        all_issues.extend(privacy_issues)

        # 4. 事实核查
        factual_issues = self.factual_verifier.verify_facts(
            generated_content, context
        )
        all_issues.extend(factual_issues)

        # 5. 内容净化
        sanitized_content, applied_rules = self.output_sanitizer.sanitize_output(
            generated_content
        )

        # 计算总体安全分数
        overall_score = self._calculate_safety_score(
            bias_score, toxicity_score, len(privacy_issues), len(factual_issues)
        )

        # 判断是否安全
        is_safe = (
            overall_score >= self.safety_thresholds['overall_safety'] and
            not any(issue.severity == SafetyLevel.UNSAFE for issue in all_issues)
        )

        return SafetyReport(
            is_safe=is_safe,
            overall_score=overall_score,
            issues=all_issues,
            processed_content=sanitized_content
        )

    def _calculate_safety_score(
        self,
        bias_score: float,
        toxicity_score: float,
        privacy_violations: int,
        factual_errors: int
    ) -> float:
        """计算总体安全分数"""
        base_score = 1.0

        # 偏见扣分 - 增加权重
        base_score -= bias_score * 0.5

        # 有毒内容扣分
        base_score -= toxicity_score * 0.4

        # 隐私泄露扣分
        base_score -= privacy_violations * 0.1

        # 事实错误扣分
        base_score -= factual_errors * 0.05

        return max(0.0, base_score)

    def quick_safety_check(self, content: str) -> bool:
        """快速安全检查"""
        # 简化的快速检查
        bias_score = self.bias_detector.detect_bias(content)
        toxicity_score = self.toxic_detector.detect_toxicity(content)

        return (bias_score <= self.safety_thresholds['bias_score'] and
                toxicity_score <= self.safety_thresholds['toxicity_score'])

    def get_safety_suggestions(self, issues: List[SafetyIssue]) -> List[str]:
        """获取安全建议"""
        suggestions = set()

        for issue in issues:
            if issue.suggestions:
                suggestions.update(issue.suggestions)

        return list(suggestions)

    def debug_privacy_check(self, content: str) -> Dict[str, Any]:
        """调试隐私检查 - 显示所有匹配结果"""
        debug_info = {}

        for privacy_type, config in self.privacy_checker.privacy_patterns.items():
            pattern = config['pattern']
            description = config['description']

            matches = list(re.finditer(pattern, content))
            debug_info[privacy_type] = {
                'pattern': pattern,
                'description': description,
                'matches': [match.group() for match in matches],
                'positions': [(match.start(), match.end()) for match in matches]
            }

        return debug_info

    def test_regex_patterns(self, test_text: str) -> None:
        """测试正则表达式模式"""
        print(f"测试文本: {test_text}")

        # 直接测试邮箱正则表达式
        email_pattern = r'[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}'
        email_matches = re.findall(email_pattern, test_text)
        print(f"邮箱正则表达式直接测试: {email_matches}")

        # 测试其他模式
        phone_pattern = r'1[3-9]\d{9}'
        phone_matches = re.findall(phone_pattern, test_text)
        print(f"手机正则表达式直接测试: {phone_matches}")


# 使用示例
if __name__ == "__main__":
    # 初始化安全检查器
    safety_checker = ContentSafetyChecker()

    # 测试内容
    test_contents = [
        "RAG是一种先进的AI技术，它结合了检索和生成的优势。",
        "男性天生就比女性更适合编程工作。",  # 偏见内容
        "我的手机号是13812345678，邮箱是test@example.com。",  # 隐私信息
        "RAG技术100%准确，永远不会出错。"  # 过于绝对的表述
    ]

    for i, content in enumerate(test_contents, 1):
        print(f"\n测试内容 {i}: {content}")
        print("=" * 50)

        # 综合安全检查
        safety_report = safety_checker.comprehensive_safety_check(content)

        print(f"安全状态: {'✅ 安全' if safety_report.is_safe else '❌ 不安全'}")
        print(f"安全分数: {safety_report.overall_score:.3f}")

        if safety_report.issues:
            print(f"检测到 {len(safety_report.issues)} 个问题:")
            for issue in safety_report.issues:
                print(f"  - {issue.issue_type}: {issue.description}")

        if safety_report.processed_content != content:
            print(f"净化后内容: {safety_report.processed_content}")

        # 添加隐私检查调试信息
        if "手机号" in content or "@" in content:
            debug_info = safety_checker.debug_privacy_check(content)
            print("隐私检查调试信息:")
            for privacy_type, info in debug_info.items():
                if info['matches']:
                    print(f"  {privacy_type}: 匹配 {info['matches']}")
                else:
                    print(f"  {privacy_type}: 无匹配 (模式: {info['pattern']})")

            # 添加直接正则测试
            print("正则表达式直接测试:")
            safety_checker.test_regex_patterns(content)

        print("=" * 50)


测试内容 1: RAG是一种先进的AI技术，它结合了检索和生成的优势。
安全状态: ✅ 安全
安全分数: 1.000

测试内容 2: 男性天生就比女性更适合编程工作。
安全状态: ❌ 不安全
安全分数: 0.750
检测到 1 个问题:
  - bias_gender: 检测到gender偏见: 男性天生就比女性更适合

测试内容 3: 我的手机号是13812345678，邮箱是test@example.com。
安全状态: ❌ 不安全
安全分数: 0.800
检测到 2 个问题:
  - privacy_phone: 检测到手机号码: 138****5678
  - privacy_email: 检测到邮箱地址: te***@example.com
净化后内容: 我的手机号是[手机号码]，邮箱是[邮箱地址]。
隐私检查调试信息:
  phone: 匹配 ['13812345678']
  email: 匹配 ['test@example.com']
  id_card: 无匹配 (模式: \b\d{17}[\dX]\b)
  credit_card: 无匹配 (模式: \b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b)
  bank_account: 无匹配 (模式: \b\d{16,19}\b)
  address: 无匹配 (模式: (.*省.*市.*区.*路.*号|.*街.*弄.*号))
正则表达式直接测试:
测试文本: 我的手机号是13812345678，邮箱是test@example.com。
邮箱正则表达式直接测试: ['test@example.com']
手机正则表达式直接测试: ['13812345678']

测试内容 4: RAG技术100%准确，永远不会出错。
安全状态: ✅ 安全
安全分数: 0.950
检测到 1 个问题:
  - factual_absolute_claim: 检测到过于绝对的表述: 100%准确
