### 第十八章 概率潜在语义分析

In [19]:
import numpy as np
from typing import Tuple
def _safe_divide(numerator: np.ndarray, denominator: np.ndarray) -> np.ndarray:
    """Element-wise ``numerator / denominator`` (broadcast), with 0 wherever the denominator is 0."""
    # Unlike adding a small epsilon, this is exact for every positive denominator,
    # however tiny, and still avoids division by zero.
    return np.divide(numerator, denominator,
                     out=np.zeros_like(numerator, dtype=float),
                     where=denominator > 0)


def _normalize_columns(matrix: np.ndarray) -> np.ndarray:
    """Scale each column of a non-negative matrix to sum to 1; all-zero columns stay all zero."""
    return _safe_divide(matrix, matrix.sum(axis=0, keepdims=True))


def PLSA(X: np.ndarray, k: int, max_iter: int = 100,
         random_state: "int | np.random.Generator | None" = None
         ) -> Tuple[np.ndarray, np.ndarray]:
    """
    Fit a probabilistic latent semantic analysis (PLSA) model with the EM algorithm.

    :param X: non-negative word-document count matrix of shape (m, n); row i is
              word w_i, column j is document d_j. Any array-like is accepted.
    :param k: number of topics z (must be >= 1).
    :param max_iter: number of EM iterations to run. This is a fixed count; there
                     is no convergence test.
    :param random_state: optional seed or ``np.random.Generator`` used for the
                         random initialisation. ``None`` (the default) draws from
                         NumPy's global random state.
    :return: ``(p_w_z, p_z_d)``. ``p_w_z`` has shape (m, k) with
             ``p_w_z[i, l] = P(w_i | z_l)``. ``p_z_d`` has shape (k, n) with
             ``p_z_d[l, j] = P(z_l | d_j)``. Every column sums to 1, except that a
             document with no words keeps an all-zero column in ``p_z_d``.
    :raises ValueError: if X is not 2-D, contains negative entries, or k < 1.
    """
    X = np.asarray(X, dtype=float)
    if X.ndim != 2:
        raise ValueError(f"X must be a 2-D matrix, got {X.ndim} dimension(s)")
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")
    if (X < 0).any():
        raise ValueError("X must contain non-negative counts")
    m, n = X.shape

    # With no seed, use the global RNG so unseeded calls draw the same stream as before.
    rng = np.random if random_state is None else np.random.default_rng(random_state)

    # Random rather than uniform start. From a uniform start every topic receives
    # identical updates, so EM would never separate the topics.
    p_w_z = _normalize_columns(rng.random((m, k)))  # P(w|z), shape (m, k)
    p_z_d = _normalize_columns(rng.random((k, n)))  # P(z|d), shape (k, n)

    for _ in range(max_iter):
        # E-step: posterior P(z|w,d) ∝ P(w|z) P(z|d), normalised over the topic axis.
        joint = p_w_z[:, :, np.newaxis] * p_z_d[np.newaxis, :, :]            # (m, k, n)
        p_z_wd = _safe_divide(joint, joint.sum(axis=1, keepdims=True))        # (m, k, n)

        # M-step: expected counts n(w,d) P(z|w,d). Computed once and reused for
        # both parameter updates.
        expected = X[:, np.newaxis, :] * p_z_wd                               # (m, k, n)
        p_w_z = _normalize_columns(expected.sum(axis=2))                      # (m, k)
        p_z_d = _normalize_columns(expected.sum(axis=0))                      # (k, n)

    return p_w_z, p_z_d

# Example word-document count matrix: 11 words (rows) x 9 documents (columns).
X = np.array([[0,0,1,1,0,0,0,0,0],
              [0,0,0,0,0,1,0,0,1],
              [0,1,0,0,0,0,0,1,0],
              [0,0,0,0,0,0,1,0,1],
              [1,0,0,0,0,1,0,0,0],
              [1,1,1,1,1,1,1,1,1],
              [1,0,1,0,0,0,0,0,0],
              [0,0,0,0,0,0,1,0,1],
              [0,0,0,0,0,2,0,0,1],
              [1,0,1,0,0,0,0,1,0],
              [0,0,0,1,1,0,0,0,0]])

# PLSA initialises its parameters from NumPy's global RNG, and EM converges to an
# initialisation-dependent local optimum. Seed first so repeated runs give the same topics.
np.random.seed(0)
PLSA(X, k=4, max_iter=100)

(array([[0.00000000e+000, 0.00000000e+000, 0.00000000e+000,
         1.69429943e-001],
        [0.00000000e+000, 8.54447077e-002, 1.88922108e-001,
         0.00000000e+000],
        [4.00000000e-001, 0.00000000e+000, 0.00000000e+000,
         0.00000000e+000],
        [0.00000000e+000, 3.03504563e-001, 8.33062553e-114,
         0.00000000e+000],
        [0.00000000e+000, 0.00000000e+000, 2.62949463e-001,
         8.15815746e-029],
        [4.00000000e-001, 2.73359907e-001, 1.83322435e-001,
         3.22280229e-001],
        [0.00000000e+000, 0.00000000e+000, 2.17942553e-037,
         1.69429943e-001],
        [0.00000000e+000, 3.03504563e-001, 6.78126378e-114,
         0.00000000e+000],
        [0.00000000e+000, 3.41862603e-002, 3.64805994e-001,
         0.00000000e+000],
        [2.00000000e-001, 0.00000000e+000, 3.59899419e-040,
         1.69429943e-001],
        [0.00000000e+000, 0.00000000e+000, 0.00000000e+000,
         1.69429943e-001]]),
 array([[5.78038404e-10, 1.00000000e+00, 