Skip to content

Commit

Permalink
Add NLTKClassifier wrapper class
Browse files Browse the repository at this point in the history
This makes it easier to add more classifiers from nltk.classify
  • Loading branch information
sloria committed Sep 4, 2013
1 parent 9d9faac commit daac110
Showing 1 changed file with 44 additions and 26 deletions.
70 changes: 44 additions & 26 deletions text/classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,31 @@ def extract_features(self, text):
except TypeError:
return self.feature_extractor(text)


class NLTKClassifier(BaseClassifier):

"""An abstract class that wraps around the nltk.classify module."""

nltk_class = None # This must be a class within nltk.classify

@cached_property
def classifier(self):
    '''The trained classifier, built from ``self.train_set``.

    Features are extracted from every ``(text, label)`` pair in
    ``self.train_set``, stored on ``self.train_features`` and handed to
    ``self.nltk_class.train``.

    :raises ValueError: If ``nltk_class`` has not been defined (it is
        None or has no ``train`` attribute).
    '''
    # Validate up front instead of catching AttributeError around the
    # ``train`` call: an AttributeError raised *inside* training would
    # otherwise be misreported as a missing ``nltk_class``.
    train = getattr(getattr(self, "nltk_class", None), "train", None)
    if train is None:
        raise ValueError("NLTKClassifier must have a nltk_class"
                         " variable that is not None.")
    self.train_features = [(self.extract_features(d), c)
                           for d, c in self.train_set]
    return train(self.train_features)

def classify(self, text):
    '''Return the label the wrapped classifier assigns to ``text``.

    :param text: A string of text.
    '''
    # Extract features first so the (possibly lazy) classifier is only
    # touched once the input has been processed.
    features = self.extract_features(text)
    model = self.classifier
    return model.classify(features)

def accuracy(self, test_set, format=None):
'''Compute the accuracy on a test set.
Expand All @@ -150,8 +175,25 @@ def accuracy(self, test_set, format=None):
test_features = [(self.extract_features(d), c) for d, c in test_data]
return nltk.classify.accuracy(self.classifier, test_features)

def update(self, new_data):
    '''Add new training data and retrain the classifier.

    :param new_data: New data as a list of tuples of the form
        ``(text, label)``.
    :returns: True once the classifier has been retrained.
    :raises ValueError: If ``nltk_class`` has not been defined (it is
        None or has no ``train`` attribute). In that case the training
        set is left untouched.
    '''
    # Validate before mutating any state, and avoid wrapping ``train`` in
    # ``except AttributeError``: errors raised inside training must not be
    # misreported as a missing ``nltk_class``.
    train = getattr(getattr(self, "nltk_class", None), "train", None)
    if train is None:
        raise ValueError("NLTKClassifier must have a nltk_class"
                         " variable that is not None.")
    self.train_set += new_data
    self.train_features = [(self.extract_features(d), c)
                           for d, c in self.train_set]
    self.classifier = train(self.train_features)
    return True


class NaiveBayesClassifier(BaseClassifier):
class NaiveBayesClassifier(NLTKClassifier):

'''A classifier based on the Naive Bayes algorithm, as implemented in
NLTK.
Expand All @@ -164,19 +206,7 @@ class NaiveBayesClassifier(BaseClassifier):
.. versionadded:: 0.6.0
'''

@cached_property
def classifier(self):
    '''The Naive Bayes classifier, trained on ``self.train_set``.'''
    # Build the (features, label) pairs first; they are kept on the
    # instance as ``train_features`` and then passed to NLTK's trainer.
    labeled_features = []
    for document, label in self.train_set:
        labeled_features.append((self.extract_features(document), label))
    self.train_features = labeled_features
    return nltk.classify.NaiveBayesClassifier.train(labeled_features)

def classify(self, text):
    '''Return the label predicted for ``text``.

    :param text: A string of text.
    '''
    # Features are extracted before the classifier attribute is read,
    # matching the original evaluation order.
    features = self.extract_features(text)
    model = self.classifier
    return model.classify(features)
nltk_class = nltk.classify.NaiveBayesClassifier

def prob_classify(self, text):
'''Return the label probability distribution for classifying a string
Expand Down Expand Up @@ -209,15 +239,3 @@ def show_informative_features(self, *args, **kwargs):
'''
return self.classifier.show_most_informative_features(*args, **kwargs)

def update(self, new_data):
    '''Extend the training set with ``new_data`` and retrain the
    Naive Bayes classifier on the combined data.

    :param new_data: New data as a list of tuples of the form
        ``(text, label)``.
    :returns: True after retraining.
    '''
    self.train_set += new_data
    # Re-extract features for the whole training set, not just the
    # newly added items, then retrain from scratch.
    labeled_features = [(self.extract_features(document), label)
                        for document, label in self.train_set]
    self.train_features = labeled_features
    self.classifier = nltk.classify.NaiveBayesClassifier.train(
        labeled_features)
    return True

0 comments on commit daac110

Please sign in to comment.