In [79]:
import polars as pl
from sklearn.model_selection import train_test_split

f_name = "amazon_cells_labelled.txt"
df = pl.read_csv(f_name, has_header=False, sep="\t", new_columns=["text", "class"])
df = df.with_columns(pl.col("text").str.replace_all("[^\s\w\d]", "").str.to_lowercase())
df_neg, df_pos = df.partition_by(groups="class", as_dict=False) # split data into its respective classes

# todo: use these
neg_train, neg_test = train_test_split(df_neg, test_size=0.2)
pos_train, pos_test = train_test_split(df_pos, test_size=0.2)

In [72]:
# count the occurrence of each distinct word
def count_words(df):
    return df.select(pl.col("text").str.split(" ").alias("words").flatten()).to_series().value_counts()

VOCAB_SIZE = len(count_words(df)) # number of distinct words in the entire dataset
neg_prior, pos_prior = len(df_neg)/len(df), len(df_pos)/len(df)

pos_wc = count_words(df_pos).get_column("counts").sum() # number of distinct words in positive class
pos_dict = dict(count_words(df_pos).rows()) # dictionary of (word, frequency) tuples in pos class

neg_wc = count_words(df_neg).get_column("counts").sum() # number of distinct words in negative class
neg_dict = dict(count_words(df_neg).rows()) # dictionary of (word, frequency) tuples in neg class

In [74]:
from typing import List, Dict

def conditional_prob(word_class_count, word_count):
    return (word_count + 1) / (word_class_count + VOCAB_SIZE)

def sentence_prob(sentence: List[str], class_dict, word_class_count):
    prob = 1.0
    for word in sentence:
        word_count = class_dict.get(word, 0)
        prob *= conditional_prob(word_class_count, word_count)
    return prob

def nb(sentence: List[str], class_dicts: List[Dict], word_class_counts: List[Dict], priors):
    best_idx, best = 0, -1
    for idx, (word_class_count, class_dict, prior) in enumerate(zip(word_class_counts, class_dicts, priors)):
        prob = prior * sentence_prob(sentence, class_dict, word_class_count)
        if prob > best:
            best_idx, best = idx, prob
    return best_idx, best

nb(["best", "product", "ever"], [neg_dict, pos_dict], [neg_wc, pos_wc], [neg_prior, pos_prior])

(1, 8.098930004075396e-09)