# In [None]:
"""
高维数据密度估计 - Digits 数据集
====================================

作业要求：
1. 使用 sklearn.datasets.load_digits() 数据集（64维）
2. 实现4个密度估计模型：DensityForest, Single Gaussian, GMM, KDE
3. 使用两种核函数（SE和IMQ）计算 MMD²
4. 可视化生成的数字
5. 使用 RandomForestClassifier 检查可辨识性
"""

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.mixture import GaussianMixture
from sklearn.neighbors import KernelDensity
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import warnings
warnings.filterwarnings('ignore')

# 假设 density_forest.py 在同一目录下
from density_forest import DensityForest


# ============================================================================
# 1. MMD² 计算函数（两种核函数）
# ============================================================================

def rbf_kernel(X, Y, gamma=None):
    """
    Squared-exponential (SE) kernel, also known as the RBF kernel.

        k(x, y) = exp(-gamma * ||x - y||^2)

    Parameters:
        X: array of shape (n, d)
        Y: array of shape (m, d)
        gamma: bandwidth parameter; when None it defaults to 1 / (2 * d),
            where d is the number of columns of X

    Returns:
        K: (n, m) kernel matrix with K[i, j] = k(X[i], Y[j])
    """
    scale = 1.0 / (2 * X.shape[1]) if gamma is None else gamma

    # Expand ||x - y||^2 = ||x||^2 - 2<x, y> + ||y||^2 so the whole matrix
    # comes from a single matrix product instead of a pairwise loop.
    row_norms = (X**2).sum(axis=1)[:, np.newaxis]  # (n, 1)
    col_norms = (Y**2).sum(axis=1)[np.newaxis, :]  # (1, m)
    sq_dist = row_norms - 2 * (X @ Y.T) + col_norms

    # Floating-point cancellation can leave tiny negatives; clamp them to 0.
    return np.exp(-scale * np.clip(sq_dist, 0, None))


def imq_kernel(X, Y, c=1.0, beta=0.5):
    """
    Inverse multi-quadratic (IMQ) kernel.

        k(x, y) = (c^2 + ||x - y||^2)^(-beta)

    Parameters:
        X: array of shape (n, d)
        Y: array of shape (m, d)
        c: scale parameter
        beta: decay exponent

    Returns:
        K: (n, m) kernel matrix with K[i, j] = k(X[i], Y[j])
    """
    x_sq = (X**2).sum(axis=1)
    y_sq = (Y**2).sum(axis=1)

    # Pairwise squared Euclidean distances via the expanded inner-product form.
    sq_dist = x_sq[:, None] - 2 * (X @ Y.T) + y_sq[None, :]
    # Clamp round-off negatives so the base of the power stays >= c^2.
    sq_dist = np.clip(sq_dist, 0, None)

    base = c**2 + sq_dist
    return base ** (-beta)


def compute_mmd_squared(X, Y, kernel_type='rbf', **kernel_params):
    """
    Unbiased estimate of the squared Maximum Mean Discrepancy (MMD^2).

        MMD^2 = E[k(x, x')] - 2 * E[k(x, y)] + E[k(y, y')]

    with x, x' drawn from the rows of X and y, y' from the rows of Y. Smaller
    values mean the two samples look more alike under the chosen kernel.
    The within-sample terms exclude the diagonal (i == i'), which makes the
    estimate unbiased; it can therefore be slightly negative when the two
    distributions match.

    Parameters:
        X: real samples, shape (n, d)
        Y: generated samples, shape (m, d)
        kernel_type: 'rbf' (SE kernel) or 'imq'
        **kernel_params: forwarded to the kernel function
            (e.g. gamma for 'rbf'; c and beta for 'imq')

    Returns:
        Estimated MMD^2 as a float.

    Raises:
        ValueError: unknown kernel_type; X or Y not 2-D; mismatched feature
            dimension; or fewer than two rows in X or Y (the unbiased
            within-sample terms need at least one off-diagonal pair, and
            otherwise divide by zero).
    """
    if kernel_type not in ('rbf', 'imq'):
        raise ValueError(f"Unknown kernel type: {kernel_type}")

    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    if X.ndim != 2 or Y.ndim != 2:
        raise ValueError(f"X and Y must be 2-D arrays, got shapes {X.shape} and {Y.shape}")
    n, m = X.shape[0], Y.shape[0]
    if n < 2 or m < 2:
        raise ValueError(f"unbiased MMD^2 needs at least 2 samples in X and Y, got n={n}, m={m}")
    if X.shape[1] != Y.shape[1]:
        raise ValueError(f"feature dimensions differ: X has {X.shape[1]}, Y has {Y.shape[1]}")

    # Resolved only after validation, so bad inputs fail with a clear message.
    kernel = rbf_kernel if kernel_type == 'rbf' else imq_kernel
    K_XX = kernel(X, X, **kernel_params)
    K_YY = kernel(Y, Y, **kernel_params)
    K_XY = kernel(X, Y, **kernel_params)

    # E[k(x, x')]: average over the n*(n-1) off-diagonal pairs.
    term_xx = (K_XX.sum() - np.trace(K_XX)) / (n * (n - 1))
    # E[k(y, y')]: average over the m*(m-1) off-diagonal pairs.
    term_yy = (K_YY.sum() - np.trace(K_YY)) / (m * (m - 1))
    # E[k(x, y)]: every cross pair is a distinct pair, so the plain mean.
    term_xy = K_XY.mean()

    return float(term_xx + term_yy - 2 * term_xy)


# ============================================================================
# 2. 可视化函数
# ============================================================================

def visualize_digits(samples, n_rows=2, n_cols=5, title="Generated Digits"):
    """
    Plot flattened 8x8 digit images on an n_rows x n_cols grid.

    Parameters:
        samples: (n, 64) array; each row is a flattened 8x8 image
        n_rows: number of grid rows
        n_cols: number of grid columns
        title: figure title

    Only the first n_rows * n_cols samples are drawn; unused grid cells are
    left blank. Pixel values are clipped to the digits range [0, 16] so
    out-of-range generated values do not distort the grayscale.
    """
    n_slots = n_rows * n_cols
    n_samples = min(n_slots, len(samples))

    # squeeze=False always returns a 2-D array of Axes; without it a 1x1 grid
    # yields a bare Axes object and .ravel() raises AttributeError.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(n_cols * 1.5, n_rows * 1.5), squeeze=False
    )
    axes = axes.ravel()

    for ax, sample in zip(axes, samples[:n_samples]):
        # Reshape the 64-vector back into an 8x8 image, then clip to 0-16.
        digit = np.clip(np.reshape(sample, (8, 8)), 0, 16)
        ax.imshow(digit, cmap='gray_r', vmin=0, vmax=16)
        ax.axis('off')

    # Hide the remaining, unused cells.
    for ax in axes[n_samples:]:
        ax.axis('off')

    fig.suptitle(title, fontsize=14, y=1.02)
    plt.tight_layout()
    plt.show()


def compare_real_vs_generated(real_data, generated_data, model_name, n_show=5):
    """
    Show real digits (top row) above generated digits (bottom row).

    Parameters:
        real_data: real samples; each row is a flattened 8x8 image
        generated_data: generated samples in the same format
        model_name: model name, used in the bottom row's title
        n_show: maximum number of columns (default 5). It is reduced
            automatically when either input has fewer rows.

    Raises:
        ValueError: if there is nothing to show (n_show < 1 or an empty input).
    """
    n_cols = min(n_show, len(real_data), len(generated_data))
    if n_cols < 1:
        raise ValueError("need at least one real and one generated sample to compare")

    # squeeze=False keeps axes 2-D even when n_cols == 1.
    _, axes = plt.subplots(2, n_cols, figsize=(2.4 * n_cols, 5), squeeze=False)
    title_col = n_cols // 2  # put each row's label above its middle image

    rows = (
        (real_data, 'Real Data', False),
        (generated_data, f'Generated ({model_name})', True),
    )
    for row, (data, label, clip) in enumerate(rows):
        for col in range(n_cols):
            digit = np.reshape(data[col], (8, 8))
            if clip:
                # Generated values may fall outside the dataset's 0-16 range.
                digit = np.clip(digit, 0, 16)
            axes[row, col].imshow(digit, cmap='gray_r', vmin=0, vmax=16)
            axes[row, col].axis('off')
        axes[row, title_col].set_title(label, fontsize=12, pad=10)

    plt.tight_layout()
    plt.show()


# ============================================================================
# 3. 主要实验流程
# ============================================================================

def _median_heuristic_gamma(X, max_points=500, random_state=42):
    """
    Choose the SE/RBF kernel gamma with the median heuristic.

    gamma = 1 / (2 * median_dist**2), where median_dist is the median pairwise
    Euclidean distance among at most ``max_points`` randomly chosen rows of X.
    Falls back to 1 / (2 * d) when the median is 0 or there are fewer than
    two rows to pair.

    Parameters:
        X: (n, d) data array
        max_points: cap on the subset size (the pairwise matrix is size^2)
        random_state: seed for the subset choice, so gamma is reproducible

    Returns:
        gamma (float)
    """
    fallback = 1.0 / (2 * X.shape[1])
    size = min(max_points, len(X))
    if size < 2:
        return fallback

    rng = np.random.default_rng(random_state)
    subset = X[rng.choice(len(X), size, replace=False)]

    # Vectorized replacement for a Python double loop of np.linalg.norm calls:
    # ||a - b||^2 = ||a||^2 - 2<a, b> + ||b||^2, clamped against round-off.
    sq_norms = np.sum(subset**2, axis=1)
    sq_dist = sq_norms[:, None] - 2 * (subset @ subset.T) + sq_norms[None, :]
    sq_dist = np.maximum(sq_dist, 0)
    # The strict upper triangle holds each unordered pair (i < j) exactly once.
    upper_i, upper_j = np.triu_indices(size, k=1)
    median_dist = np.median(np.sqrt(sq_dist[upper_i, upper_j]))

    return 1.0 / (2 * median_dist**2) if median_dist > 0 else fallback


def _prediction_stats(pred_labels, pred_proba, classes):
    """
    Summarize classifier predictions on a set of generated samples.

    Parameters:
        pred_labels: (n,) predicted labels
        pred_proba: (n, k) predicted class probabilities
        classes: the k class labels known to the classifier

    Returns:
        counts: (k,) int array; counts[i] is the number of samples predicted
            as classes[i]. Classes that are never predicted are counted as 0.
        uniformity: 1 - std(counts) / (n / k), clipped at 0, so it lies in
            [0, 1]; 1 means a perfectly even split across classes.
        confidence: mean of the per-sample maximum predicted probability.
    """
    counts = np.array([int(np.sum(pred_labels == c)) for c in classes])
    expected_count = len(pred_labels) / len(classes)
    # Zero-count classes must be part of the std. Dropping them made a model
    # that only ever produces one digit look "perfectly uniform", because the
    # std of a single value is 0.
    uniformity = max(0.0, float(1.0 - np.std(counts) / expected_count))
    confidence = float(pred_proba.max(axis=1).mean())
    return counts, uniformity, confidence


def main():
    """
    Run the full density-estimation experiment on the 64-D digits dataset.

    Steps:
      1. Load the data.
      2. Define four density models: DensityForest, a single Gaussian,
         a 10-component GMM, and a Gaussian KDE.
      3. Fit each model and draw len(X) samples from it.
      4. Score the samples with MMD^2 under the SE and IMQ kernels.
      5. Plot the generated digits.
      6. Check recognizability with a RandomForestClassifier trained on
         real digits.
      7. Print a summary.

    Returns:
        dict mapping model name -> {'mmd_se', 'mmd_imq', 'generated',
        'recognizability', 'uniformity', 'confidence', 'class_distribution'}
    """
    # Seeded generator for the sample shuffling in step 3 (reproducible runs).
    rng = np.random.default_rng(42)

    print("=" * 80)
    print("高维数据密度估计实验 - Digits Dataset (64D)")
    print("=" * 80)
    print()

    # ------------------------------------------------------------------------
    # Step 1: load the data
    # ------------------------------------------------------------------------
    print("步骤1: 加载 Digits 数据集")
    print("-" * 80)

    digits = load_digits()
    X = digits.data  # (1797, 64)
    y = digits.target  # (1797,)

    print(f"数据形状: {X.shape}")
    print(f"样本数: {X.shape[0]}")
    print(f"特征维度: {X.shape[1]}")
    print(f"类别数: {len(np.unique(y))}")
    print(f"数值范围: [{X.min():.1f}, {X.max():.1f}]")
    print()

    # Show a few real samples for reference.
    print("可视化真实数据样本:")
    visualize_digits(X[:10], n_rows=2, n_cols=5, title="Real Digits from Dataset")

    # ------------------------------------------------------------------------
    # Step 2: define the density-estimation models
    # ------------------------------------------------------------------------
    print("\n步骤2: 定义密度估计模型")
    print("-" * 80)

    models = {}

    # Model 1: DensityForest
    print("初始化 DensityForest...")
    models['DensityForest'] = DensityForest(
        n_trees=50,      # number of trees
        n_min=10         # minimum number of samples per leaf
    )

    # Model 2: single Gaussian
    print("初始化 Single Gaussian...")
    models['Single_Gaussian'] = GaussianMixture(
        n_components=1,           # a single Gaussian component
        covariance_type='full',   # full covariance matrix
        random_state=42
    )

    # Model 3: Gaussian mixture model
    print("初始化 GMM (K=10)...")
    models['GMM_K10'] = GaussianMixture(
        n_components=10,          # 10 Gaussian components
        covariance_type='full',
        random_state=42,
        max_iter=200
    )

    # Model 4: kernel density estimation
    print("初始化 KDE...")
    models['KDE'] = KernelDensity(
        bandwidth=2.0,            # kernel bandwidth
        kernel='gaussian'
    )

    print(f"\n共定义 {len(models)} 个模型")
    print()

    # ------------------------------------------------------------------------
    # Step 3: fit the models and draw samples
    # ------------------------------------------------------------------------
    print("\n步骤3: 训练模型并生成样本")
    print("-" * 80)

    results = {}
    generated_samples = {}

    for name, model in models.items():
        print(f"\n处理模型: {name}")
        print("  训练中...")

        model.fit(X)

        print("  生成样本中...")
        n_samples = len(X)  # as many synthetic samples as real ones

        if name == 'DensityForest':
            # DensityForest provides its own sample method.
            generated = model.sample(n_samples=n_samples)
        elif name == 'KDE':
            # Seed KernelDensity.sample so the draw is reproducible.
            generated = model.sample(n_samples=n_samples, random_state=42)
        else:
            # GaussianMixture.sample returns (samples, component_labels).
            generated, _ = model.sample(n_samples=n_samples)

        # GaussianMixture.sample emits samples grouped by component, so the
        # first rows would all come from one component. Shuffling makes the
        # "first 10" shown in step 5 a representative draw. Row order does not
        # affect MMD or the class statistics.
        generated = np.asarray(generated)[rng.permutation(len(generated))]
        generated_samples[name] = generated

        print(f"  生成样本形状: {generated.shape}")
        print(f"  生成样本范围: [{generated.min():.2f}, {generated.max():.2f}]")

    print("\n所有模型训练和采样完成！")
    print()

    # ------------------------------------------------------------------------
    # Step 4: MMD^2 under two kernels
    # ------------------------------------------------------------------------
    print("\n步骤4: 计算 MMD²")
    print("-" * 80)

    # RBF bandwidth from the median heuristic on a seeded subset of X.
    gamma_rbf = _median_heuristic_gamma(X)

    print(f"RBF核参数 gamma = {gamma_rbf:.6f}")
    print("IMQ核参数 c = 1.0, beta = 0.5")
    print()

    for name, generated in generated_samples.items():
        print(f"\n{name}:")

        # SE (RBF) kernel
        mmd_se = compute_mmd_squared(X, generated, kernel_type='rbf', gamma=gamma_rbf)
        print(f"  MMD² (SE核):  {mmd_se:.6f}")

        # IMQ kernel
        mmd_imq = compute_mmd_squared(X, generated, kernel_type='imq', c=1.0, beta=0.5)
        print(f"  MMD² (IMQ核): {mmd_imq:.6f}")

        results[name] = {
            'mmd_se': mmd_se,
            'mmd_imq': mmd_imq,
            'generated': generated
        }

    print()

    # ------------------------------------------------------------------------
    # Step 5: visualize the generated digits
    # ------------------------------------------------------------------------
    print("\n步骤5: 可视化生成的数字")
    print("-" * 80)

    for name, generated in generated_samples.items():
        print(f"\n可视化 {name} 生成的数字:")
        visualize_digits(generated[:10], n_rows=2, n_cols=5,
                         title=f"Generated Digits - {name}")

        # Side-by-side comparison with real digits.
        compare_real_vs_generated(X[:5], generated[:5], name)

    # ------------------------------------------------------------------------
    # Step 6: recognizability check with a RandomForestClassifier
    # ------------------------------------------------------------------------
    print("\n步骤6: 可辨识性检查")
    print("-" * 80)
    print()

    print("训练 RandomForestClassifier 用于数字识别...")

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )

    clf = RandomForestClassifier(
        n_estimators=100,
        max_depth=20,
        random_state=42,
        n_jobs=-1
    )
    clf.fit(X_train, y_train)

    train_acc = clf.score(X_train, y_train)
    test_acc = clf.score(X_test, y_test)
    print(f"分类器在训练集上的准确率: {train_acc:.4f}")
    print(f"分类器在测试集上的准确率: {test_acc:.4f}")
    print()

    print("=" * 80)
    print("可辨识性测试说明:")
    print("=" * 80)
    print("我们使用训练好的分类器来评估生成数字的质量：")
    print("1. 对每个生成模型，用分类器预测生成样本的类别")
    print("2. 计算预测的置信度和类别分布的均匀性")
    print("3. 如果生成的数字质量高，分类器应该能：")
    print("   - 以较高置信度识别出数字")
    print("   - 生成的10个数字类别分布应该相对均匀")
    print("=" * 80)
    print()

    classes = clf.classes_

    for name, generated in generated_samples.items():
        print(f"\n{name} 的可辨识性分析:")
        print("-" * 50)

        pred_labels = clf.predict(generated)
        pred_proba = clf.predict_proba(generated)

        counts, uniformity, avg_confidence = _prediction_stats(
            pred_labels, pred_proba, classes
        )

        print("\n生成样本的类别分布:")
        for digit, count in zip(classes, counts):
            percentage = count / len(generated) * 100
            print(f"  数字 {digit}: {int(count):4d} 个 ({percentage:5.1f}%)")

        print(f"\n分布均匀性: {uniformity:.4f} (1.0 = 完全均匀)")
        print(f"平均预测置信度: {avg_confidence:.4f} (1.0 = 完全确定)")

        # Combined score: plain average of uniformity and confidence.
        recognizability = (uniformity + avg_confidence) / 2

        print(f"综合可辨识性分数: {recognizability:.4f}")

        # Keep only classes that were actually predicted (as before).
        class_dist = {c: int(k) for c, k in zip(classes, counts) if k > 0}

        results[name]['recognizability'] = recognizability
        results[name]['uniformity'] = uniformity
        results[name]['confidence'] = avg_confidence
        results[name]['class_distribution'] = class_dist

    # ------------------------------------------------------------------------
    # Step 7: summary
    # ------------------------------------------------------------------------
    print("\n" + "=" * 80)
    print("最终结果总结")
    print("=" * 80)
    print()

    print(f"{'模型':<20} {'MMD²(SE)':<12} {'MMD²(IMQ)':<12} {'可辨识性':<12} {'置信度':<12}")
    print("-" * 80)

    for name in models:
        row = results[name]
        print(f"{name:<20} {row['mmd_se']:<12.6f} {row['mmd_imq']:<12.6f} "
              f"{row['recognizability']:<12.4f} {row['confidence']:<12.4f}")

    print()

    best_mmd_se = min(results.items(), key=lambda x: x[1]['mmd_se'])
    best_mmd_imq = min(results.items(), key=lambda x: x[1]['mmd_imq'])
    best_recog = max(results.items(), key=lambda x: x[1]['recognizability'])

    print("最佳模型：")
    print(f"  最低 MMD² (SE核):  {best_mmd_se[0]} ({best_mmd_se[1]['mmd_se']:.6f})")
    print(f"  最低 MMD² (IMQ核): {best_mmd_imq[0]} ({best_mmd_imq[1]['mmd_imq']:.6f})")
    print(f"  最高可辨识性:      {best_recog[0]} ({best_recog[1]['recognizability']:.4f})")

    print()
    print("=" * 80)
    print("实验完成！")
    print("=" * 80)

    return results


if __name__ == "__main__":
    # Run the whole experiment and keep the returned per-model metrics dict
    # bound at module level for later inspection.
    results = main()