# ナイーブベイズ推定器の評価関数
（正規化項は省略、事前分布に一様分布を仮定）

$$
\begin{align}
    P(y|x_1, \dots, x_m)\
        &\propto P(x_1, \dots, x_m|y)\\
        &= \prod_{i=1}^{m} P(x_i|y) \\
    \log P(y|x_1, \dots, x_m)\
        &\propto \sum_{i=1}^{m} \log P(x_i|y)\\
        &=\sum_{i \in \text{categorical}}\log P(x_i|y)\
            + \sum_{i \in \text{numerical}}\log P(x_i|y) \\
        &=\sum_{i \in \text{categorical}} \left\{ \
                \log c(x_i|y) - \log c(y) \
            \right\} \
            + \sum_{i \in \text{numerical}}\log X(x_i|y)
\end{align}
$$
$c$は出現回数、$X$は正規分布に従う

In [133]:
import seaborn
import pandas as pd
import pandas_ml
import numpy as np
from scipy.stats import norm

from sklearn.metrics import confusion_matrix

pd.options.display.max_rows = 8

In [79]:
df = seaborn.load_dataset('titanic')
mf = pandas_ml.ModelFrame(df, target='survived')
train, test = mf.model_selection.train_test_split(test_size=0.3)

In [53]:
mf.target.unique()

array([0, 1])

In [70]:
train

Unnamed: 0,survived,pclass,sex,age,sibsp,parch,fare,embarked,class,who,adult_male,deck,embark_town,alive,alone
115,0,3,male,21,0,0,7.925,S,Third,man,True,,Southampton,no,True
19,1,3,female,,0,0,7.225,C,Third,woman,False,,Cherbourg,yes,True
509,1,3,male,26,0,0,56.4958,S,Third,man,True,,Southampton,yes,True
704,0,3,male,26,1,0,7.8542,S,Third,man,True,,Southampton,no,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
547,1,2,male,,0,0,13.8625,C,Second,man,True,,Cherbourg,yes,True
154,0,3,male,,0,0,7.3125,S,Third,man,True,,Southampton,no,True
690,1,1,male,31,1,0,57,S,First,man,True,B,Southampton,yes,False
701,1,1,male,35,0,0,26.2875,S,First,man,True,E,Southampton,yes,True


In [143]:
class HbkNaiveBayes:
    def __init__(self, model_frame, numerical=None):
        if not isinstance(model_frame, pandas_ml.core.frame.ModelFrame):
            raise ValueError()

        self._target_name = model_frame.target.name
        self._class_labels = model_frame.target.unique()
        
        self._numerical = numerical
            
        if not numerical is None:
            self._categorical = set(mf.data.columns) - set(numerical) 
        else:
            self._categorical = set(mf.data.columns)
            
    def train(self, train_mf):
        if not isinstance(train_mf, pandas_ml.core.frame.ModelFrame):
            raise ValueError()
            
        self._train_mf = train_mf
        
    def _condition(self, col_name, value): # -> list of boolean
        return self._train_mf[col_name] == value
        
    def _estimate_by_single_row(self, row):
        def eval_func(class_label, row):
            cl = class_label
            
            eval_value = 0.0
            for c in self._categorical:
                eval_value += np.log(
                    (self._condition(self._target_name, cl) & self._condition(c, row[c])).sum() + 1
                )
                eval_value -= np.log(
                    (self._condition(self._target_name, cl)).sum() + 1
                )
                
            for n in self._numerical:
                cond = self._condition(self._target_name, cl)
                mu = self._train_mf[cond][n].mean()
                sigma = self._train_mf[cond][n].std()
                eval_value += norm.logpdf(row[n], loc=mu, scale=sigma)
                
            return eval_value
            
        eval_values = np.array([eval_func(i, row) for i in self._class_labels])
        return self._class_labels[eval_values.argmax()]
    
    def estimate(self, test):
        return test.apply(self._estimate_by_single_row, axis=1)      
        
hbk = HbkNaiveBayes(mf, numerical=['age', 'fare'])
hbk.train(train)

In [144]:
est = hbk.estimate(test)
est

596    0
171    0
349    0
823    1
      ..
527    0
190    1
567    0
343    0
Length: 268, dtype: int64

In [145]:
test['survived']

596    1
171    0
349    0
823    1
      ..
527    0
190    1
567    0
343    0
Name: survived, Length: 268, dtype: int64

In [146]:
confusion_matrix(list(test['survived']), list(est))

array([[172,   3],
       [ 16,  77]])