In [None]:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

In [None]:
class MultinomialNB(object):
  """Multinomial naive Bayes classifier for count-valued features.

  Attributes set by ``fit``:
    classes: sorted unique labels seen in ``y`` (from ``np.unique``).
    w: (n_classes, n_features) array of additively smoothed per-class
      feature probabilities; each row sums to 1.
    w_prior: (n_classes,) array of class priors, smoothed with the same
      ``alpha`` as the feature probabilities.
  """

  def fit(self, X, y, alpha=1, verbose=True):
    """Estimate smoothed feature probabilities and class priors.

    Args:
      X: array-like of shape (n_samples, n_features); presumably
        non-negative counts (not validated).
      y: array-like of shape (n_samples,) with class labels.
      alpha: additive (Laplace/Lidstone) smoothing constant.
      verbose: if True, print the fitted ``w`` and ``w_prior``.
    """
    X = np.asarray(X)
    y = np.asarray(y)
    n_samples, n_features = X.shape
    self.classes = np.unique(y)
    n_classes = len(self.classes)

    self.w = np.zeros((n_classes, n_features), dtype=np.float64)
    self.w_prior = np.zeros(n_classes, dtype=np.float64)

    for idx, c in enumerate(self.classes):
      X_c = X[y == c]
      feature_counts = X_c.sum(axis=0)
      total_count = feature_counts.sum()
      # Smoothed so an unseen feature never gets probability 0 (log(0) = -inf).
      self.w[idx, :] = (feature_counts + alpha) / (total_count + n_features * alpha)
      # The class prior is additively smoothed too, not the raw class frequency.
      self.w_prior[idx] = (X_c.shape[0] + alpha) / (float(n_samples) + alpha * n_classes)

    if verbose:
      print('W: ', self.w)
      print('Prior: ', self.w_prior)

  def log_likelyhood_prior_prod(self, X):
    """Return unnormalized log posteriors, shape (n_samples, n_classes).

    Computes ``X @ log(w).T + log(w_prior)``, i.e. the log of
    ``prior_c * prod_j w[c, j] ** X[:, j]`` for each class c. The multinomial
    coefficient is omitted because it is identical across classes.
    """
    return np.asarray(X) @ np.log(self.w).T + np.log(self.w_prior)

  def predict(self, X):
    """Return the most probable class label for each row of X.

    Labels are looked up in ``self.classes``, so labels other than 0..k-1
    (e.g. strings or {-1, 1}) are returned as-is rather than as column indices.
    """
    llp = self.log_likelyhood_prior_prod(X)
    return self.classes[np.argmax(llp, axis=1)]

  def predict_probability(self, X):
    """Return posterior class probabilities, shape (n_samples, n_classes).

    Columns follow the order of ``self.classes``. The row-wise maximum is
    subtracted before exponentiating (log-sum-exp shift); without it, large
    counts underflow every class to 0 and the normalization yields NaN.
    """
    q = self.log_likelyhood_prior_prod(X)
    q = q - q.max(axis=1, keepdims=True)
    unnormalized = np.exp(q)
    return unnormalized / unnormalized.sum(axis=1, keepdims=True)

In [None]:
# Synthetic data: 1000 samples with 5 count features in [0, 5) and binary labels.
# X and y are drawn independently, so ~50% accuracy is the expected baseline.
SEED = 1
rng = np.random.RandomState(SEED)
X = rng.randint(5, size=(1000, 5))
y = rng.randint(2, size=(1000,))
# Seed the split as well, so Restart & Run All reproduces the same partition.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=SEED)
print(X_train.shape)
print(y_train.shape)

(800, 5)
(800,)


In [None]:
# Train with the default Laplace smoothing (alpha=1); fit() prints W and Prior.
multinomial_nb = MultinomialNB()
multinomial_nb.fit(X=X_train, y=y_train, alpha=1)

W:  [[0.21320706 0.18454057 0.19119529 0.20220118 0.2088559 ]
 [0.20838352 0.20067454 0.1992291  0.1989882  0.19272464]]
Prior:  [0.48877805 0.51122195]


In [None]:
# Held-out evaluation: per-class precision / recall / F1 on the test split.
print(
  classification_report(
    y_true=y_test,
    y_pred=multinomial_nb.predict(X_test),
  )
)

              precision    recall  f1-score   support

           0       0.46      0.35      0.40        94
           1       0.53      0.64      0.58       106

    accuracy                           0.51       200
   macro avg       0.50      0.50      0.49       200
weighted avg       0.50      0.51      0.49       200

