# 朴素贝叶斯

In [1]:
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

In [14]:
class GaussianNB:
    def fit(self, X, y):
        self.classes = np.unique(y)  # 所有类别
        self.class_count = {c : np.sum(y == c) for c in self.classes}  # 各类别样本数
        # 计算各类别的均值、方差
        self.mean = {c : X[y == c].mean(axis = 0) for c in self.classes}
        self.var = {c: X[y == c].var(axis = 0) + 1e-6 for c in self.classes}

    def _gaussian_prob(self, x, mean, var):
        x = np.exp(-(x - mean) ** 2 / (2 * var))
        x = x / np.sqrt(2 * np.pi * var)
        return x

    def predict(self, X):
        predictions = []
        for x in X:
            # 遍历一条样本
            log_probs = {}
            # 计算各类别的后验概率
            for c in self.classes:
                # 计算先验概率
                prior = np.log(self.class_count[c] / sum(self.class_count.values()))
                likelihood = np.sum(np.log(self._gaussian_prob(x, self.mean[c], self.var[c])))
                log_probs[c] = prior + likelihood
            # 选择后验概率最大的类
            predictions.append(self.classes[np.argmax(log_probs)])

        return np.array(predictions)

In [16]:
if __name__ == '__main__':
    # 加载数据集
    data = load_iris()
    X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, test_size=0.3, random_state=42)
    # 训练与预测
    model = GaussianNB()
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    # 评估准确率
    print(f'朴素贝叶斯模型分类准确率{accuracy_score(y_test, y_pred)}')

朴素贝叶斯模型分类准确率0.4222222222222222


> example

In [2]:
X = np.array([[100, 5.0], [110, 5.2], [95, 4.8], [200, 3.0], [190, 2.8], [210, 3.2]])
y = np.array(['苹果', '苹果', '苹果', '橙子', '橙子', '橙子'])

In [3]:
y == '苹果'

array([ True,  True,  True, False, False, False])

In [5]:
select = X[y == '苹果']
type(select)

numpy.ndarray

In [6]:
select

array([[100. ,   5. ],
       [110. ,   5.2],
       [ 95. ,   4.8]])

In [7]:
select.mean(axis=0)

array([101.66666667,   5.        ])

In [9]:
np.unique(y)[0]

'橙子'