### Implement Naive Bayes 

Following https://towardsdatascience.com/text-classification-using-naive-bayes-theory-a-working-example-2ef4b7eb7d5a

In [1]:
import pickle
import numpy as np, pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
# from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline
import sklearn.metrics as metrics
sns.set() # use seaborn plotting style


In [2]:
train_dict = pickle.load(open('data/train.pkl', 'rb'))
test_dict = pickle.load(open('data/test.pkl', 'rb'))

In [3]:
train_data = [train_dict[i]['input'] for i in train_dict]
train_target = [train_dict[i]['label'] for i in train_dict]

test_data = [test_dict[i]['input'] for i in test_dict]
test_target = [test_dict[i]['label'] for i in test_dict]

In [4]:
# Build the model
model = make_pipeline(TfidfVectorizer(), MultinomialNB())

# Train the model using the training data
model.fit(train_data, train_target)

# Predict the categories of the test data
predicted_categories = model.predict(test_target)


In [5]:
# Generate confusion matrix

print("The accuracy is {}".format(metrics.accuracy_score(test_target, predicted_categories)))
print("The recall is {}".format(metrics.recall_score(test_target, predicted_categories, average='micro')))
print("The F1 score is {}".format(metrics.f1_score(test_target, predicted_categories, average='micro')))
print("The precision is {}".format(metrics.precision_score(test_target, predicted_categories, average='micro')))


The accuracy is 0.5249008373732922
The recall is 0.5249008373732922
The F1 score is 0.5249008373732922
The precision is 0.5249008373732922


In [8]:
print(metrics.classification_report(test_target, predicted_categories))

                                                       precision    recall  f1-score   support

                                                            0.00      0.00      0.00       773
                                          Case Report       1.00      1.00      1.00       297
                                            Diagnosis       1.00      1.00      1.00       374
                                  Diagnosis;Mechanism       0.00      0.00      0.00        17
          Diagnosis;Mechanism;Prevention;Transmission       0.00      0.00      0.00         1
                        Diagnosis;Mechanism;Treatment       0.00      0.00      0.00         1
                                 Diagnosis;Prevention       0.00      0.00      0.00        22
                       Diagnosis;Prevention;Mechanism       0.00      0.00      0.00         1
                    Diagnosis;Prevention;Transmission       0.00      0.00      0.00         2
                       Diagnosis;Prevention;Treat

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [7]:
# %%time

# # plot the confusion matrix

# mat = metrics.confusion_matrix(test_target, predicted_categories)
# sns.heatmap(mat.T, square = True, annot=True, fmt = "d", xticklabels=train_target,yticklabels=train_target)
# plt.xlabel("true labels")
# plt.ylabel("predicted label")
# plt.show()