In [275]:
from mnist import MNIST
from scipy.stats import norm
import numpy as np
import pandas as pd
import math
from tqdm import tqdm
tqdm.pandas()

## Read train and test data

In [276]:
# MNIST dataset source: http://yann.lecun.com/exdb/mnist/

# Point the loader at the local directory that holds the MNIST files.
mndata = MNIST('data/mnist_data_files')
# Presumably makes the loader read the gzip-compressed files - TODO confirm
# against the loader's documentation.
mndata.gz=True
train_images, train_labels = mndata.load_training()

### Process data and threshold

In [277]:
# Convert what the loader returned into NumPy arrays before building DataFrames.
train_images = np.array(train_images)
train_labels = np.array(train_labels)

In [278]:
# Binarize the pixels: intensities above 127 become 1, everything else 0.
# `np.int` was deprecated in NumPy 1.20 and removed in 1.24 (it was only an
# alias of the builtin), so the builtin `int` is used instead.
# NOTE: this rebinds `train_images`, so the cell is not idempotent - re-running
# it would threshold the already-binarized 0/1 values and zero every pixel.
train_images = (pd.DataFrame(train_images) > 127).astype(int)
train_labels = pd.DataFrame(train_labels, columns=["label"])

In [279]:
# Expect 60000 rows; each image row is a flattened 28*28 = 784 pixel vector.
print(train_images.shape) 
print(train_labels.shape)

(60000, 784)
(60000, 1)


In [280]:
# Class priors: relative frequency of every label in the training set,
# stored as a two-column table (label, probability).
value_counts = train_labels["label"].value_counts(normalize=True)
p_train_labels = pd.DataFrame(
    {
        'label': value_counts.index,
        'probability': value_counts.values,
    }
)

In [281]:
# Display the prior table (one row per label).
p_train_labels.head(10)

Unnamed: 0,label,probability
0,1,0.112367
1,7,0.104417
2,3,0.102183
3,2,0.0993
4,9,0.09915
5,0,0.098717
6,6,0.098633
7,8,0.097517
8,4,0.097367
9,5,0.09035


In [282]:
# Load the test split through the same loader object used for the training data.
test_images, test_labels = mndata.load_testing()

In [283]:
# Convert what the loader returned into NumPy arrays before building DataFrames.
test_images = np.array(test_images)
test_labels = np.array(test_labels)

In [284]:
# Binarize the test pixels with the same > 127 rule applied to the training set.
# `np.int` was deprecated in NumPy 1.20 and removed in 1.24 (it was only an
# alias of the builtin), so the builtin `int` is used instead.
# NOTE: this rebinds `test_images`, so the cell is not idempotent - re-running
# it would threshold the already-binarized 0/1 values and zero every pixel.
test_images = (pd.DataFrame(test_images) > 127).astype(int)
test_labels = pd.DataFrame(test_labels, columns=["label"])

In [285]:
# Sanity-check the test set size and peek at a few labels.
print(test_images.shape)
# NOTE: the slice 1:10 prints rows 1..9 and therefore skips row 0.
print(test_labels[1:10])

(10000, 784)
   label
1      2
2      1
3      0
4      4
5      1
6      4
7      9
8      5
9      9


In [286]:
# Attach the label column to the pixel columns; DataFrame.join aligns rows on
# the index, which both frames got by default when they were constructed.
train_df = train_images.join(train_labels)
test_df = test_images.join(test_labels)
test_df.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,775,776,777,778,779,780,781,782,783,label
0,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,7
1,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,2
2,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
3,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,4


In [287]:
def get_params(label_group):
    """Fit an independent normal distribution to every feature column of one class.

    Parameters
    ----------
    label_group : pandas.DataFrame
        Rows that belong to a single label. A ``label`` column is dropped when
        present; every remaining column is treated as a feature.

    Returns
    -------
    pandas.DataFrame
        One column per feature. Row 0 holds the fitted mean and row 1 the
        fitted standard deviation, in the order returned by
        ``scipy.stats.norm.fit``.
    """
    # errors='ignore': newer pandas versions stop passing the grouping column
    # to the function given to ``groupby(...).apply``, so ``label`` may already
    # be absent; a hard drop would raise KeyError there.
    images_df = label_group.drop(columns=['label'], errors='ignore')
    return images_df.apply(lambda pixel: np.asarray(norm.fit(pixel)), axis=0)

# Fit the per-pixel normal parameters separately for every digit class; each
# label contributes two rows (mean, std) to the stacked result.
label_params = train_df.groupby(['label']).apply(get_params)

In [288]:
# 10 labels x 2 parameter rows (mean, std) = 20 rows, one column per pixel.
assert label_params.shape == (20, 784)
label_params.shape

(20, 784)

In [289]:
# label_params.loc[0, :].loc[0]

## Naive Bayes - normal distribution - untouched

In [290]:
# params_ = np.apply_along_axis(lambda x: norm.fit(x), 1, train_images.T)

In [291]:
def calculate_likelihood_for_each_label(p_label, feature_vec, params):
    """Score one image against one class with Gaussian naive Bayes.

    Parameters
    ----------
    p_label : mapping-like
        Must provide ``'label'`` and ``'probability'`` (the class prior).
    feature_vec : array-like
        Feature values of the image, one per column of ``params``.
    params : pandas.DataFrame
        Row 0 holds the per-feature means and row 1 the per-feature standard
        deviations of this class.

    Returns
    -------
    numpy.ndarray
        ``[label, score]``, where ``score`` is the sum of the per-feature
        normal log-densities plus the log prior (an unnormalised
        log-posterior). The label is cast to float by the array packing.
    """
    means = params.loc[0]
    stds = params.loc[1]
    # A feature that is constant within the class has std == 0, for which the
    # log-density evaluates to NaN. ``nansum`` below drops those terms, so the
    # floating-point warnings they trigger are expected noise; silence them
    # locally instead of flooding the output on every call.
    with np.errstate(divide='ignore', invalid='ignore'):
        log_densities = norm.logpdf(feature_vec, means, stds)
    log_posterior = np.nansum(log_densities) + np.log(p_label['probability'])
    return np.array([p_label['label'], log_posterior])
    
def get_predict(likelihoods):
    """Pick the label whose score is highest.

    ``likelihoods`` is an iterable of ``[label, score]`` pairs. The first pair
    with the strictly greatest score wins; if no score exceeds negative
    infinity (including an empty input), ``float("-inf")`` is returned.
    """
    best_label = float("-inf")
    best_score = float("-inf")
    for candidate in likelihoods:
        score = candidate[1]
        if score > best_score:
            best_label, best_score = candidate[0], score
    return best_label

### Evaluate

In [292]:
def predict(image):
    """Classify one image by scoring it against every label.

    Reads the notebook-level ``p_train_labels`` (label priors) and
    ``label_params`` (per-label feature parameters), scores ``image`` for each
    label and returns the label chosen by ``get_predict``.
    """
    scored_labels = [
        calculate_likelihood_for_each_label(
            prior_row,
            image,
            label_params.loc[prior_row['label'], :],
        )
        for _, prior_row in p_train_labels.iterrows()
    ]
    return get_predict(np.array(scored_labels))

# predict(test_images.loc[1])


In [293]:
# Classify every test image (one row each). progress_apply is the tqdm-enabled
# variant of DataFrame.apply registered by tqdm.pandas() in the imports cell.
predicts = test_images.progress_apply(predict, axis=1)


  x = np.asarray((x - loc)/scale, dtype=dtyp)
  x = np.asarray((x - loc)/scale, dtype=dtyp)
  return (self.a <= x) & (x <= self.b)
  return (self.a <= x) & (x <= self.b)

  0%|          | 2/10000 [00:00<08:43, 19.10it/s][A
  0%|          | 11/10000 [00:00<06:41, 24.88it/s][A
  0%|          | 20/10000 [00:00<05:14, 31.70it/s][A
  0%|          | 29/10000 [00:00<04:13, 39.30it/s][A
  0%|          | 38/10000 [00:00<03:30, 47.23it/s][A
  0%|          | 47/10000 [00:00<03:02, 54.41it/s][A
  1%|          | 56/10000 [00:00<02:42, 61.24it/s][A
  1%|          | 65/10000 [00:00<02:27, 67.46it/s][A
  1%|          | 74/10000 [00:00<02:16, 72.91it/s][A
  1%|          | 83/10000 [00:01<02:09, 76.46it/s][A
  1%|          | 93/10000 [00:01<02:01, 81.27it/s][A
  1%|          | 102/10000 [00:01<02:02, 80.74it/s][A
  1%|          | 112/10000 [00:01<01:57, 83.81it/s][A
  1%|          | 121/10000 [00:01<01:55, 85.40it/s][A
  1%|▏         | 130/10000 [00:01<01:54, 86.15it/s][A
  1%|▏         

 12%|█▏        | 1197/10000 [00:14<01:40, 87.54it/s][A
 12%|█▏        | 1206/10000 [00:14<01:43, 85.18it/s][A
 12%|█▏        | 1216/10000 [00:14<01:39, 87.85it/s][A
 12%|█▏        | 1226/10000 [00:14<01:38, 89.15it/s][A
 12%|█▏        | 1235/10000 [00:14<01:39, 88.41it/s][A
 12%|█▏        | 1245/10000 [00:14<01:37, 89.63it/s][A
 13%|█▎        | 1255/10000 [00:14<01:35, 91.72it/s][A
 13%|█▎        | 1265/10000 [00:14<01:34, 92.56it/s][A
 13%|█▎        | 1275/10000 [00:15<01:34, 92.65it/s][A
 13%|█▎        | 1285/10000 [00:15<01:34, 92.57it/s][A
 13%|█▎        | 1295/10000 [00:15<01:34, 92.60it/s][A
 13%|█▎        | 1305/10000 [00:15<01:32, 93.71it/s][A
 13%|█▎        | 1315/10000 [00:15<01:33, 93.22it/s][A
 13%|█▎        | 1325/10000 [00:15<01:33, 92.35it/s][A
 13%|█▎        | 1335/10000 [00:15<01:33, 92.53it/s][A
 13%|█▎        | 1345/10000 [00:15<01:32, 93.77it/s][A
 14%|█▎        | 1355/10000 [00:15<01:32, 93.42it/s][A
 14%|█▎        | 1365/10000 [00:16<01:32, 93.04i

 26%|██▌       | 2617/10000 [00:30<01:27, 84.52it/s][A
 26%|██▋       | 2626/10000 [00:30<01:26, 85.39it/s][A
 26%|██▋       | 2635/10000 [00:30<01:25, 86.62it/s][A
 26%|██▋       | 2645/10000 [00:30<01:22, 88.70it/s][A
 27%|██▋       | 2654/10000 [00:30<01:22, 88.67it/s][A
 27%|██▋       | 2663/10000 [00:30<01:23, 87.82it/s][A
 27%|██▋       | 2672/10000 [00:30<01:25, 85.81it/s][A
 27%|██▋       | 2681/10000 [00:30<01:26, 84.53it/s][A
 27%|██▋       | 2690/10000 [00:30<01:25, 85.97it/s][A
 27%|██▋       | 2700/10000 [00:31<01:22, 88.60it/s][A
 27%|██▋       | 2710/10000 [00:31<01:20, 90.14it/s][A
 27%|██▋       | 2720/10000 [00:31<01:23, 87.24it/s][A
 27%|██▋       | 2729/10000 [00:31<01:27, 83.28it/s][A
 27%|██▋       | 2738/10000 [00:31<01:31, 79.61it/s][A
 27%|██▋       | 2747/10000 [00:31<01:30, 80.55it/s][A
 28%|██▊       | 2756/10000 [00:31<01:32, 78.06it/s][A
 28%|██▊       | 2766/10000 [00:31<01:27, 82.74it/s][A
 28%|██▊       | 2775/10000 [00:31<01:28, 81.53i

 41%|████      | 4072/10000 [00:45<00:57, 103.60it/s][A
 41%|████      | 4083/10000 [00:46<00:57, 103.76it/s][A
 41%|████      | 4094/10000 [00:46<00:56, 103.98it/s][A
 41%|████      | 4105/10000 [00:46<00:56, 104.27it/s][A
 41%|████      | 4116/10000 [00:46<00:56, 104.31it/s][A
 41%|████▏     | 4127/10000 [00:46<00:56, 104.14it/s][A
 41%|████▏     | 4138/10000 [00:46<00:56, 104.00it/s][A
 41%|████▏     | 4149/10000 [00:46<00:56, 103.62it/s][A
 42%|████▏     | 4160/10000 [00:46<00:56, 103.91it/s][A
 42%|████▏     | 4171/10000 [00:46<00:55, 104.13it/s][A
 42%|████▏     | 4182/10000 [00:47<00:55, 104.40it/s][A
 42%|████▏     | 4193/10000 [00:47<00:55, 104.44it/s][A
 42%|████▏     | 4204/10000 [00:47<00:55, 104.56it/s][A
 42%|████▏     | 4215/10000 [00:47<00:55, 104.36it/s][A
 42%|████▏     | 4226/10000 [00:47<00:55, 104.26it/s][A
 42%|████▏     | 4237/10000 [00:47<00:55, 104.27it/s][A
 42%|████▏     | 4248/10000 [00:47<00:55, 103.27it/s][A
 43%|████▎     | 4259/10000 [00

 56%|█████▋    | 5629/10000 [01:01<00:43, 100.94it/s][A
 56%|█████▋    | 5640/10000 [01:01<00:43, 100.60it/s][A
 57%|█████▋    | 5651/10000 [01:01<00:43, 100.67it/s][A
 57%|█████▋    | 5662/10000 [01:01<00:43, 100.41it/s][A
 57%|█████▋    | 5673/10000 [01:01<00:42, 100.80it/s][A
 57%|█████▋    | 5684/10000 [01:01<00:42, 100.47it/s][A
 57%|█████▋    | 5695/10000 [01:02<00:42, 100.73it/s][A
 57%|█████▋    | 5706/10000 [01:02<00:42, 101.12it/s][A
 57%|█████▋    | 5717/10000 [01:02<00:42, 101.17it/s][A
 57%|█████▋    | 5728/10000 [01:02<00:42, 101.14it/s][A
 57%|█████▋    | 5739/10000 [01:02<00:42, 100.64it/s][A
 57%|█████▊    | 5750/10000 [01:02<00:42, 100.31it/s][A
 58%|█████▊    | 5761/10000 [01:02<00:42, 100.17it/s][A
 58%|█████▊    | 5772/10000 [01:02<00:42, 100.09it/s][A
 58%|█████▊    | 5783/10000 [01:02<00:42, 99.97it/s] [A
 58%|█████▊    | 5793/10000 [01:03<00:42, 99.93it/s][A
 58%|█████▊    | 5804/10000 [01:03<00:41, 100.68it/s][A
 58%|█████▊    | 5815/10000 [01:

 72%|███████▏  | 7197/10000 [01:16<00:27, 103.14it/s][A
 72%|███████▏  | 7208/10000 [01:16<00:26, 103.65it/s][A
 72%|███████▏  | 7219/10000 [01:17<00:27, 102.99it/s][A
 72%|███████▏  | 7230/10000 [01:17<00:26, 102.96it/s][A
 72%|███████▏  | 7241/10000 [01:17<00:26, 103.63it/s][A
 73%|███████▎  | 7252/10000 [01:17<00:26, 104.07it/s][A
 73%|███████▎  | 7263/10000 [01:17<00:26, 103.82it/s][A
 73%|███████▎  | 7274/10000 [01:17<00:26, 103.66it/s][A
 73%|███████▎  | 7285/10000 [01:17<00:26, 103.13it/s][A
 73%|███████▎  | 7296/10000 [01:17<00:26, 102.99it/s][A
 73%|███████▎  | 7307/10000 [01:17<00:26, 103.47it/s][A
 73%|███████▎  | 7318/10000 [01:18<00:27, 98.58it/s] [A
 73%|███████▎  | 7329/10000 [01:18<00:26, 100.27it/s][A
 73%|███████▎  | 7340/10000 [01:18<00:26, 101.61it/s][A
 74%|███████▎  | 7351/10000 [01:18<00:26, 101.39it/s][A
 74%|███████▎  | 7362/10000 [01:18<00:26, 101.19it/s][A
 74%|███████▎  | 7373/10000 [01:18<00:26, 97.51it/s] [A
 74%|███████▍  | 7383/10000 [01

 88%|████████▊ | 8769/10000 [01:32<00:12, 101.87it/s][A
 88%|████████▊ | 8780/10000 [01:32<00:11, 102.10it/s][A
 88%|████████▊ | 8791/10000 [01:32<00:11, 101.31it/s][A
 88%|████████▊ | 8802/10000 [01:32<00:11, 101.46it/s][A
 88%|████████▊ | 8813/10000 [01:32<00:11, 102.39it/s][A
 88%|████████▊ | 8824/10000 [01:32<00:11, 102.42it/s][A
 88%|████████▊ | 8835/10000 [01:32<00:11, 102.82it/s][A
 88%|████████▊ | 8846/10000 [01:32<00:11, 103.20it/s][A
 89%|████████▊ | 8857/10000 [01:32<00:11, 103.08it/s][A
 89%|████████▊ | 8868/10000 [01:33<00:10, 103.10it/s][A
 89%|████████▉ | 8879/10000 [01:33<00:10, 103.45it/s][A
 89%|████████▉ | 8890/10000 [01:33<00:10, 104.00it/s][A
 89%|████████▉ | 8901/10000 [01:33<00:10, 104.15it/s][A
 89%|████████▉ | 8912/10000 [01:33<00:10, 104.08it/s][A
 89%|████████▉ | 8923/10000 [01:33<00:10, 104.11it/s][A
 89%|████████▉ | 8934/10000 [01:33<00:10, 103.90it/s][A
 89%|████████▉ | 8945/10000 [01:33<00:10, 103.68it/s][A
 90%|████████▉ | 8956/10000 [01

In [294]:
# Predicted labels come back as floats because
# calculate_likelihood_for_each_label packs label and score into one float array.
predicts.head()

0    7.0
1    2.0
2    1.0
3    0.0
4    4.0
dtype: float64

In [295]:
def calculate_accuracy(actual, predicts):
    """Return the fraction of predictions that equal the true labels.

    Parameters
    ----------
    actual : array-like
        True labels. The input is flattened, so an ``(n, 1)`` column is
        accepted as well as a 1-D sequence.
    predicts : array-like
        Predicted labels. Flattened the same way; must contain as many entries
        as ``actual``.

    Returns
    -------
    float
        Number of matching positions divided by the number of samples.

    Raises
    ------
    ValueError
        If the inputs are empty or do not contain the same number of entries.
    """
    # Flatten both sides so a column vector and a 1-D array compare
    # element-by-element instead of relying on single-element truthiness.
    actual_flat = np.asarray(actual).ravel()
    predicts_flat = np.asarray(predicts).ravel()
    if actual_flat.size != predicts_flat.size:
        raise ValueError(
            "actual and predicts must have the same number of entries "
            f"({actual_flat.size} != {predicts_flat.size})"
        )
    if actual_flat.size == 0:
        raise ValueError("cannot compute accuracy of an empty label set")
    num_correct = np.count_nonzero(actual_flat == predicts_flat)
    return num_correct / actual_flat.size

In [296]:
# test_labels is a one-column DataFrame, so np.array(test_labels) has shape
# (n, 1), while np.array(predicts) is 1-D.
calculate_accuracy(np.array(test_labels), np.array(predicts))

0.7824