In [1]:
from scipy.stats import norm
import numpy as np
from collections import Counter


class GaussianNB:
    """Gaussian naive Bayes classifier.

    Every feature is modelled, per class, as an independent normal
    distribution whose mean and standard deviation are estimated from the
    training rows of that class.

    Parameters
    ----------
    priors : sequence of float, dict or None
        Class prior probabilities. A sequence is matched to the class labels
        in sorted order (the order produced by ``np.unique``); a dict maps
        label -> probability. ``None`` estimates the priors from the class
        frequencies seen in ``fit``.

    Attributes set by ``fit``
    -------------------------
    priors : dict
        Mapping label -> prior probability actually used by ``predict``.
    labels : list
        Sorted class labels; column ``k`` of the score matrix built in
        ``predict`` belongs to ``labels[k]``.
    groups : dict
        Mapping label -> list of the training rows carrying that label.
    proba_cal : dict
        Mapping label -> ``[feature_means, feature_stds]``.
    """

    # Lower bound for each per-feature log-density, i.e. log(1e-32). A single
    # extreme feature can therefore contribute at most this penalty.
    _LOG_PDF_FLOOR = np.log(1e-32)
    # Stand-in for a zero standard deviation (feature constant within a
    # class); a zero scale would make the normal density NaN.
    _MIN_STD = 1e-9

    def __init__(self, priors=None):
        self.priors = priors
        # ``fit`` replaces ``self.priors`` with a label -> probability dict.
        # The caller's specification is remembered separately so the
        # estimator can be fitted again without reading back its own output.
        self._prior_spec = priors
        self._fitted_priors = None

    def fit(self, X, y):
        """Estimate per-class feature means/stds and the class priors.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
        y : array-like of shape (n_samples,)

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``X`` and ``y`` have different lengths, or the supplied priors
            do not cover exactly the classes found in ``y``.
        """
        X = np.asarray(X)
        y = np.asarray(y)  # a plain list would make ``y == label`` a scalar
        if X.shape[0] != y.shape[0]:
            raise ValueError(
                "X and y have inconsistent lengths: %d != %d"
                % (X.shape[0], y.shape[0])
            )

        # If ``priors`` was reassigned by the caller since the last fit,
        # treat the new value as the specification.
        if self.priors is not self._fitted_priors:
            self._prior_spec = self.priors
        spec = self._prior_spec

        # np.unique returns the labels sorted, which gives a deterministic
        # label <-> prior pairing (set iteration order is not guaranteed).
        classes = np.unique(y)
        if spec is None:
            priors = self.get_priors(y)
        elif isinstance(spec, dict):
            missing = [label for label in classes if label not in spec]
            if missing:
                raise ValueError("priors missing for classes: %r" % (missing,))
            priors = {label: spec[label] for label in classes}
        else:
            spec = list(spec)
            if len(spec) != len(classes):
                raise ValueError(
                    "got %d priors for %d classes" % (len(spec), len(classes))
                )
            priors = dict(zip(classes, spec))
        self.priors = priors
        self._fitted_priors = priors

        self.labels = list(classes)
        self.groups = {}
        self.proba_cal = {}
        for label in self.labels:
            rows = X[y == label]
            self.groups[label] = list(rows)
            self.proba_cal[label] = [np.mean(rows, axis=0), np.std(rows, axis=0)]
        return self

    def predict(self, X):
        """Return the most probable class label for every row of ``X``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)

        Returns
        -------
        numpy.ndarray of shape (n_samples,)
            Predicted class labels (taken from ``self.labels``).
        """
        X = np.asarray(X)
        m, _ = X.shape
        log_proba = np.empty((m, len(self.labels)))
        for col, label in enumerate(self.labels):
            mean, std = self.proba_cal[label]
            log_proba[:, col] = (
                self.cal_likehood_gaussian_prob(X, mean, std)
                + np.log(self.priors[label])
            )
        # argmax yields column positions; translate them back to labels.
        best = np.argmax(log_proba, axis=1)
        return np.asarray(self.labels)[best]

    @staticmethod
    def get_priors(y):
        """Return the empirical class frequencies of ``y`` as label -> fraction."""
        total = len(y)
        return {label: count / total for label, count in Counter(y).items()}

    @staticmethod
    def cal_likehood_gaussian_prob(data, means, stds):
        """Sum of per-feature Gaussian log-densities.

        Parameters
        ----------
        data : array-like of shape (n_features,) or (n_samples, n_features)
        means, stds : array-like of shape (n_features,)

        Returns
        -------
        float or numpy.ndarray
            A scalar for a single sample, otherwise one value per row.
            Each per-feature term is floored at ``log(1e-32)``.
        """
        data = np.asarray(data, dtype=float)
        means = np.asarray(means, dtype=float)
        stds = np.asarray(stds, dtype=float)
        # Guard against a zero scale, which would produce NaN densities.
        stds = np.where(stds > 0, stds, GaussianNB._MIN_STD)
        log_pdf = norm.logpdf(data, loc=means, scale=stds)
        # Same floor as clipping the density at 1e-32 before taking the log.
        log_pdf = np.maximum(log_pdf, GaussianNB._LOG_PDF_FLOOR)
        return log_pdf.sum(axis=-1)

In [2]:
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
# Fixed seed so the train/test split, and therefore the reported accuracy,
# is identical on every Restart & Run All.
RANDOM_STATE = 42

# Load the breast-cancer data as (features, labels) arrays and hold out 30%
# of the samples for evaluation.
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=RANDOM_STATE
)

In [3]:
# Build the classifier with explicitly supplied class priors (instead of
# estimating them from y_train), then fit it on the training split.
nb = GaussianNB(priors=[0.4, 0.6])
nb = nb.fit(X_train, y_train)

In [4]:
from sklearn.metrics import accuracy_score
# Score the fitted model on the held-out split; the bare last expression is
# rendered as the cell output.
y_pred = nb.predict(X_test)
accuracy_score(y_test, y_pred)

0.9122807017543859