In [11]:
import string
from collections import OrderedDict
from functools import lru_cache
from random import shuffle

import nltk
import numpy as np
from nltk import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer
from sklearn.datasets import fetch_20newsgroups

In [16]:
def get_categories():
    """Return the names of the 20 Newsgroups categories.

    A brand-new list is built on every call, so callers may mutate it freely.
    Names appear in alphabetical order; a name's position in this list is the
    integer label assigned to that category by ``get_dataset``.
    """
    categories = [
        # alt.*
        'alt.atheism',
        # comp.*
        'comp.graphics',
        'comp.os.ms-windows.misc',
        'comp.sys.ibm.pc.hardware',
        'comp.sys.mac.hardware',
        'comp.windows.x',
        # misc.*
        'misc.forsale',
        # rec.*
        'rec.autos',
        'rec.motorcycles',
        'rec.sport.baseball',
        'rec.sport.hockey',
        # sci.*
        'sci.crypt',
        'sci.electronics',
        'sci.med',
        'sci.space',
        # soc.*
        'soc.religion.christian',
        # talk.*
        'talk.politics.guns',
        'talk.politics.mideast',
        'talk.politics.misc',
        'talk.religion.misc',
    ]
    return categories

In [14]:
def get_dataset(n_samples_per_cat: int, seed=None):
    """Fetch up to `n_samples_per_cat` posts per 20 Newsgroups category, shuffled.

    Parameters
    ----------
    n_samples_per_cat : int
        Maximum number of posts taken from each category. Posts are taken in
        the order returned by ``fetch_20newsgroups(shuffle=False)``, so the
        same posts are selected on every call.
    seed : int or None, optional
        Seed for the final shuffle. ``None`` (default) draws from NumPy's
        global random state, as before; pass an int for a reproducible order.

    Returns
    -------
    data : np.ndarray
        Post texts, in shuffled order.
    labels : np.ndarray of int
        ``labels[i]`` is the index in ``get_categories()`` of the category
        that ``data[i]`` came from.
    """
    # Accumulate in Python lists and convert once: repeated np.append copies
    # the whole array each time, and starting from np.array([]) (float64)
    # produced float labels and relied on float/str dtype promotion.
    texts = []
    targets = []
    for label, category in enumerate(get_categories()):
        ng_category = fetch_20newsgroups(subset='all', shuffle=False, categories=[category])
        contents = ng_category.data[:n_samples_per_cat]
        texts.extend(contents)
        targets.extend([label] * len(contents))

    data = np.array(texts)
    labels = np.array(targets, dtype=int)

    # Shuffle data and labels with the same permutation so pairs stay aligned.
    if seed is None:
        indices = np.random.permutation(len(data))
    else:
        indices = np.random.default_rng(seed).permutation(len(data))

    return data[indices], labels[indices]

In [13]:
@lru_cache(maxsize=None)
def _english_stopwords():
    """Load NLTK's English stopword list once and return it as a frozenset."""
    # pre_process calls clean_doc once per document; caching avoids reloading
    # the list per document and gives O(1) membership tests.
    return frozenset(stopwords.words('english'))


def clean_doc(doc_str, stem=False, rem_punc=False, stop=False, lemmatize=False):
    """Tokenize `doc_str` and apply the requested cleaning steps.

    Steps run in this order: punctuation removal, tokenization,
    lowercasing + stopword removal, lemmatization, stemming.

    Parameters
    ----------
    doc_str : str
        Raw document text.
    stem : bool
        Apply NLTK's Porter stemmer to every token.
    rem_punc : bool
        Delete every ASCII punctuation character (``string.punctuation``)
        before tokenizing.
    stop : bool
        Lowercase all tokens, then drop NLTK English stopwords. Tokens are
        lowercased only when this flag is set.
    lemmatize : bool
        Lemmatize tokens with NLTK's WordNetLemmatizer (default noun POS).
        Requires the NLTK ``wordnet`` corpus to be available.

    Returns
    -------
    list of str
        The cleaned tokens, in document order.
    """
    if rem_punc:
        doc_str = doc_str.translate(str.maketrans('', '', string.punctuation))

    tokens = word_tokenize(doc_str)

    if stop:
        # Lowercase *before* filtering: NLTK's English stopword list is
        # lowercase, so filtering first lets capitalized stopwords ("The") through.
        english_stopwords = _english_stopwords()
        tokens = [token.lower() for token in tokens]
        tokens = [token for token in tokens if token not in english_stopwords]

    if lemmatize:
        # Previously this flag was accepted but silently ignored.
        lemmatizer = nltk.stem.WordNetLemmatizer()
        tokens = [lemmatizer.lemmatize(token) for token in tokens]

    if stem:
        stemmer = PorterStemmer()
        tokens = [stemmer.stem(token) for token in tokens]

    return tokens

# return list of pre_processed documents
def pre_process(document_list, parameters):
    preprocessed_document_list = []
    
    for document in document_list:
        cleaned_document = clean_doc(document, **parameters)
        preprocessed_document_list.append(cleaned_document)
    
    return preprocessed_document_list