# Naive Bayes Classifier

## Imports and Initializations

We need to import some common packages and library such as `numpy`, `matplotlib`, etc.

In [1]:
import os, pprint, pickle, math
from collections import Counter
import matplotlib.pyplot as plt
import numpy as np
from copy import deepcopy
import math as m

pp = pprint.PrettyPrinter(indent=2)

TRAIN_FILE_PATH = '/Users/sounak/Documents/clg/nlp/nlp-projects/data/gram_error/train.txt'
TEST_FILE_PATH = '/Users/sounak/Documents/clg/nlp/nlp-projects/data/gram_error/dev.txt'

## Utility Functions

These functions are for saving the probability and count dicts as pickle files for faster loading when the script is run later.

In [2]:
def save_obj(obj, name):
    if 'obj' not in os.listdir():
        os.mkdir('obj')
    with open('obj/'+ name + '.pkl', 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)

def load_obj(name):
    try:
        with open('obj/' + name + '.pkl', 'rb') as f:
            return pickle.load(f)
    except:
        return None

## Tokenizing Training Data

In the training data words are already tokenized. We only need to figure out which words contribute to grammatical errors and the classes that the grammatical error belongs to.

In [3]:
lines_train = load_obj('lines_train')
lines_test = load_obj('lines_test')

if not (lines_train and lines_test):
    f = open(TRAIN_FILE_PATH, 'r')
    text = f.read()
    f.close()
    raw_data_train = text.split('\n')
    lines_train = []
    temp = []
    for l in raw_data_train:
        if l == '':
            lines_train.append(temp)
            temp = []
        else:
            temp.append(l)           
    save_obj(lines_train, 'lines_train')

    f = open(TEST_FILE_PATH, 'r')
    text = f.read()
    f.close()    
    raw_data_test = text.split('\n')
    lines_test = []
    temp = []
    for l in raw_data_test:
        if l == '':
            lines_test.append(temp)
            temp = []
        else:
            temp.append(l)           
    save_obj(lines_test, 'lines_test')
    
print('lines_train and lines_test have been loaded')

lines_train and lines_test have been loaded


## Classifer

Here, we are creating the `Counter` with all the counts of the error grams and the correct grams classified according to the training data.

In [4]:
def get_grams(sent, i, n=1):
    grams = tuple()
    for j in range(i-n+1, i+1):
        try:
            if j < 0:
                grams += (' ', )
            else:
                grams += (sent[j].split('   ')[0], )
        except:
            grams += (' ', )
    return grams
        
def find_counts(data, n):
    counts_ = {}
    for sent in data:
        for i, word in enumerate(sent):
            tokens = word.split('   ')
            if len(tokens) == 3:
                try:
                    counts_[tokens[2]].update([get_grams(sent, i, n)])
                except KeyError:
                    counts_[tokens[2]] = Counter([get_grams(sent, i, n)])
                counts_['correct'].update([get_grams(sent, i, n)[:-1] + (tokens[1],)])
            else:
                try:
                    counts_['correct'].update([get_grams(sent, i, n)])
                except:
                    counts_['correct'] = Counter([get_grams(sent, i, n)])
    return counts_
    
counts = load_obj('counts')

if not counts:
    counts = {}
    for i in range(1, 4):
        counts[i] = find_counts(lines_train, i)
    save_obj(counts, 'counts')
print('counts have been loaded')

counts have been loaded


In [5]:
counts[2]['correct']

Counter({(' ', 'so'): 77,
         ('so', 'vhtr'): 1,
         ('vhtr', 'is'): 88,
         ('is', 'environment'): 2,
         ('environment', 'friendly'): 3,
         ('friendly', 'and'): 5,
         ('and', 'the'): 684,
         ('the', 'most'): 195,
         ('most', 'optimized'): 1,
         ('optimized', 'solution'): 1,
         ('solution', 'towards'): 2,
         ('towards', 'green'): 2,
         ('green', 'house'): 17,
         ('house', 'errors'): 1,
         ('effect', '.'): 10,
         (' ', 'retrieved'): 20,
         ('retrieved', 'on'): 9,
         ('on', '5/09/2009.http'): 1,
         ('5/09/2009.http', ':'): 1,
         (':', 'transportation'): 1,
         ('//www.purdue.edu/uns/x/2007a/070314agrawalbiomass.html', ')'): 1,
         (' ', 'for'): 476,
         ('for', 'instance'): 166,
         ('instance', ','): 152,
         (',', 'an'): 75,
         ('an', 'insurance'): 9,
         ('insurance', 'company'): 8,
         ('company', 'will'): 7,
         ('will', 'refuse

## Probabilities

Here, we are calculating the probabilities of the ngrams that are seen in the corpus.

In [35]:
def get_probability(group, error_class, n):
    return counts[n][error_class][group] / sum(counts[n][error_class].values())

## Smoothing

But what about the unseen ngrams? We deal with them using the add-k smoothing where we take the value of `k` as `0.1`.

In [7]:
error_classes = counts[1].keys()
N = {}
S = {}
V = {}

vocabulary = {}
for i in range(1, 4):
    vocabulary[i] = set()
    N[i] = dict()
    for e in error_classes:
        vocabulary[i] = vocabulary[i] | set(counts[i][e].keys())
        N[i][e] = sum(counts[i][e].values())
    V[i] = len(vocabulary[i])

for i in range(1, 4):
    S[i] = {}
    for e in error_classes:
        S[i][e] = sum([counts[i][e][_] for _ in vocabulary[i]]) + V[i]

In [8]:
def get_count(group, error_class, n):
    return counts[n][error_class][group]
    
def get_smoothed_prob(group, error_class, n, factor=0.1):
    return (get_count(group, error_class, n) + factor) / S[n][error_class]

probs = load_obj('probs')
if not probs:
    probs = deepcopy(counts)
    for i in range(1, 4):
        for cls in probs[i].keys():
            for group in probs[i][cls].keys():
                probs[i][cls][group] = get_smoothed_prob(group, cls, i)
        save_obj(probs, 'probs')
print('probs have been loaded')

def get_prob(group, error_class, n):
    if group in probs[n][error_class].keys():
        return probs[n][error_class][group]
    else:
        return get_smoothed_prob(group, error_class, n)

probs have been loaded


## Calculating `P(gram|class)`

We define functions which will return the conditional probability of a gram given a class and the function which returns the sorted probabilites of the classes given the gram.

In [9]:
def prob_class(groups, error_class, n, log=True):
    probs_grams = [get_prob(g, error_class, n) for g in groups]
    product = np.product(probs_grams)
    final_prob = product * (N[n][error_class] / sum(N[n].values()))
    if log:
        return np.log(final_prob)
    return final_prob

def naive_classifier(groups, n):
    temp = [(e, prob_class(groups, e, n)) for e in error_classes]
    return sorted(temp, key=lambda k: k[1], reverse=True)

In [28]:
sent = lines_test[2]
sent_grams = lambda sent, n: [get_grams(sent, i, n) for i in range(len(sent))]

In [36]:
def get_errors(sent):
    tokens = sent_grams(sent, 2)
    tokens = [sent_grams(_, 1) for _ in tokens]
    for token in tokens:
        classes = [(token, naive_classifier(token, 1)[:4]) for token in tokens]
        pp.pprint(classes)

get_errors(sent)

[ ( [(' ',), ('one',)],
    [ ('correct', -21.66288323759064),
      ('Rloc-', -23.77210245289838),
      ('Um', -24.495255181817317),
      ('Wci', -24.513283587630752)]),
  ( [('one',), ('cause',)],
    [ ('correct', -14.003240283025958),
      ('Rloc-', -20.05853038619407),
      ('Wci', -20.118834432958316),
      ('Um', -20.78168311511301)]),
  ( [('cause',), ('for',)],
    [ ('correct', -12.526439875002708),
      ('Wci', -18.320733788391813),
      ('Um', -18.785129233238944),
      ('Rloc-', -18.92340817548617)]),
  ( [('for',), ('this',)],
    [ ('correct', -10.265688996964904),
      ('Um', -16.582499237335814),
      ('Rloc-', -16.748102283857598),
      ('Wci', -16.913064567687186)]),
  ( [('this',), ('is',)],
    [ ('correct', -9.923593730785608),
      ('Um', -16.343986188690504),
      ('Rloc-', -16.377819957581163),
      ('Wci', -17.204546518048744)]),
  ( [('is',), ('attributed',)],
    [ ('correct', -15.13445736207576),
      ('Wci', -19.962142455702384),
      ('Um'

      ('Um', -21.515769880927195),
      ('Wci', -21.78707857624817)]),
  ( [('of',), ('baby',)],
    [ ('correct', -13.361292108275743),
      ('Um', -18.471247443203772),
      ('Wci', -19.3891833034498),
      ('WOinc', -20.52962279973486)]),
  ( [('baby',), ('boomers',)],
    [ ('correct', -20.64130036115993),
      ('Um', -23.848628016892267),
      ('Wci', -26.990585638033533),
      ('WOinc', -27.51526461737407)]),
  ( [('boomers',), (',',)],
    [ ('correct', -14.298997966318908),
      ('Um', -18.958995221079366),
      ('Rloc-', -21.2263368520603),
      ('Mec', -22.121579332080508)]),
  ( [(',',), ('who',)],
    [ ('correct', -10.00718648845957),
      ('Um', -16.21522693737508),
      ('Rloc-', -16.277576961682133),
      ('Wci', -18.846275964725116)]),
  ( [('who',), ('are',)],
    [ ('correct', -11.75053737018663),
      ('Um', -17.68635271399073),
      ('Rloc-', -17.701107128088665),
      ('Wci', -19.445724569375077)]),
  ( [('are',), ('now',)],
    [ ('correct', -12.7

[ ( [(' ',), ('one',)],
    [ ('correct', -21.66288323759064),
      ('Rloc-', -23.77210245289838),
      ('Um', -24.495255181817317),
      ('Wci', -24.513283587630752)]),
  ( [('one',), ('cause',)],
    [ ('correct', -14.003240283025958),
      ('Rloc-', -20.05853038619407),
      ('Wci', -20.118834432958316),
      ('Um', -20.78168311511301)]),
  ( [('cause',), ('for',)],
    [ ('correct', -12.526439875002708),
      ('Wci', -18.320733788391813),
      ('Um', -18.785129233238944),
      ('Rloc-', -18.92340817548617)]),
  ( [('for',), ('this',)],
    [ ('correct', -10.265688996964904),
      ('Um', -16.582499237335814),
      ('Rloc-', -16.748102283857598),
      ('Wci', -16.913064567687186)]),
  ( [('this',), ('is',)],
    [ ('correct', -9.923593730785608),
      ('Um', -16.343986188690504),
      ('Rloc-', -16.377819957581163),
      ('Wci', -17.204546518048744)]),
  ( [('is',), ('attributed',)],
    [ ('correct', -15.13445736207576),
      ('Wci', -19.962142455702384),
      ('Um'

      ('Wci', -23.47227884822447)])]
[ ( [(' ',), ('one',)],
    [ ('correct', -21.66288323759064),
      ('Rloc-', -23.77210245289838),
      ('Um', -24.495255181817317),
      ('Wci', -24.513283587630752)]),
  ( [('one',), ('cause',)],
    [ ('correct', -14.003240283025958),
      ('Rloc-', -20.05853038619407),
      ('Wci', -20.118834432958316),
      ('Um', -20.78168311511301)]),
  ( [('cause',), ('for',)],
    [ ('correct', -12.526439875002708),
      ('Wci', -18.320733788391813),
      ('Um', -18.785129233238944),
      ('Rloc-', -18.92340817548617)]),
  ( [('for',), ('this',)],
    [ ('correct', -10.265688996964904),
      ('Um', -16.582499237335814),
      ('Rloc-', -16.748102283857598),
      ('Wci', -16.913064567687186)]),
  ( [('this',), ('is',)],
    [ ('correct', -9.923593730785608),
      ('Um', -16.343986188690504),
      ('Rloc-', -16.377819957581163),
      ('Wci', -17.204546518048744)]),
  ( [('is',), ('attributed',)],
    [ ('correct', -15.13445736207576),
      ('Wc

      ('Rloc-', -23.142697298013747)]),
  ( [('late',), ('fifties',)],
    [ ('correct', -22.96687765564222),
      ('Wci', -26.343958473108483),
      ('Um', -26.89315045461569),
      ('Rloc-', -29.297555392030166)]),
  ( [('fifties',), ('or',)],
    [ ('correct', -19.094504892831367),
      ('Rloc-', -23.164157349033516),
      ('Um', -23.806248793923405),
      ('Wci', -24.994031756159465)]),
  ( [('or',), ('early',)],
    [ ('correct', -14.794132111514326),
      ('Um', -20.761726356199983),
      ('Rloc-', -20.766262076235144),
      ('Wci', -21.28045968945516)]),
  ( [('early',), ('sixties',)],
    [ ('correct', -24.108922584846972),
      ('Wci', -25.674908844127597),
      ('Um', -26.246523289690636),
      ('Rloc-', -26.899660119231793)]),
  ( [('sixties',), ('.',)],
    [ ('correct', -18.775644014524993),
      ('Um', -21.806115444124398),
      ('Rloc-', -22.19835364847707),
      ('Wci', -23.47227884822447)])]
[ ( [(' ',), ('one',)],
    [ ('correct', -21.66288323759064),


[ ( [(' ',), ('one',)],
    [ ('correct', -21.66288323759064),
      ('Rloc-', -23.77210245289838),
      ('Um', -24.495255181817317),
      ('Wci', -24.513283587630752)]),
  ( [('one',), ('cause',)],
    [ ('correct', -14.003240283025958),
      ('Rloc-', -20.05853038619407),
      ('Wci', -20.118834432958316),
      ('Um', -20.78168311511301)]),
  ( [('cause',), ('for',)],
    [ ('correct', -12.526439875002708),
      ('Wci', -18.320733788391813),
      ('Um', -18.785129233238944),
      ('Rloc-', -18.92340817548617)]),
  ( [('for',), ('this',)],
    [ ('correct', -10.265688996964904),
      ('Um', -16.582499237335814),
      ('Rloc-', -16.748102283857598),
      ('Wci', -16.913064567687186)]),
  ( [('this',), ('is',)],
    [ ('correct', -9.923593730785608),
      ('Um', -16.343986188690504),
      ('Rloc-', -16.377819957581163),
      ('Wci', -17.204546518048744)]),
  ( [('is',), ('attributed',)],
    [ ('correct', -15.13445736207576),
      ('Wci', -19.962142455702384),
      ('Um'

[ ( [(' ',), ('one',)],
    [ ('correct', -21.66288323759064),
      ('Rloc-', -23.77210245289838),
      ('Um', -24.495255181817317),
      ('Wci', -24.513283587630752)]),
  ( [('one',), ('cause',)],
    [ ('correct', -14.003240283025958),
      ('Rloc-', -20.05853038619407),
      ('Wci', -20.118834432958316),
      ('Um', -20.78168311511301)]),
  ( [('cause',), ('for',)],
    [ ('correct', -12.526439875002708),
      ('Wci', -18.320733788391813),
      ('Um', -18.785129233238944),
      ('Rloc-', -18.92340817548617)]),
  ( [('for',), ('this',)],
    [ ('correct', -10.265688996964904),
      ('Um', -16.582499237335814),
      ('Rloc-', -16.748102283857598),
      ('Wci', -16.913064567687186)]),
  ( [('this',), ('is',)],
    [ ('correct', -9.923593730785608),
      ('Um', -16.343986188690504),
      ('Rloc-', -16.377819957581163),
      ('Wci', -17.204546518048744)]),
  ( [('is',), ('attributed',)],
    [ ('correct', -15.13445736207576),
      ('Wci', -19.962142455702384),
      ('Um'

  ( [('baby',), ('boomers',)],
    [ ('correct', -20.64130036115993),
      ('Um', -23.848628016892267),
      ('Wci', -26.990585638033533),
      ('WOinc', -27.51526461737407)]),
  ( [('boomers',), (',',)],
    [ ('correct', -14.298997966318908),
      ('Um', -18.958995221079366),
      ('Rloc-', -21.2263368520603),
      ('Mec', -22.121579332080508)]),
  ( [(',',), ('who',)],
    [ ('correct', -10.00718648845957),
      ('Um', -16.21522693737508),
      ('Rloc-', -16.277576961682133),
      ('Wci', -18.846275964725116)]),
  ( [('who',), ('are',)],
    [ ('correct', -11.75053737018663),
      ('Um', -17.68635271399073),
      ('Rloc-', -17.701107128088665),
      ('Wci', -19.445724569375077)]),
  ( [('are',), ('now',)],
    [ ('correct', -12.79619417632477),
      ('Rloc-', -17.9403368171545),
      ('Vt', -19.104985123156474),
      ('Um', -19.11444420378908)]),
  ( [('now',), ('mostly',)],
    [ ('correct', -18.024923256066092),
      ('Rloc-', -21.543502752994407),
      ('WOinc', 