In [1]:
import os
import collections
import random
import math

In [2]:
DATA_DIRPATH = os.path.join('..', 'data', 'dane_pozytywistyczne')

CORP_FILEPATHS = {
    'P': os.path.join(DATA_DIRPATH, 'korpus_prusa.txt'),
    'O': os.path.join(DATA_DIRPATH, 'korpus_orzeszkowej.txt'),
    'S': os.path.join(DATA_DIRPATH, 'korpus_sienkiewicza.txt'),
}

In [3]:
TARGETS = list('POS')

UNIGRAMS = {
    target: collections.defaultdict(lambda: 1)
    for target in TARGETS
}

PUNCTUATION = {
    target: collections.defaultdict(lambda: 1)
    for target in TARGETS
}

In [4]:
LETTERS_PL = set('aąbcćdeęfghijklłmnńoópqrsśtuvwxyzżź')
PUNCT = set('.,;?!')

def extract_features(text):
    # words only for now
    buf = []
    words = []
    punctuation = []
    for ch in text:
        if ch in LETTERS_PL:
            buf.append(ch)
        elif buf:
            words.append(''.join(buf))
            buf.clear()
        if ch in PUNCT:
            punctuation.append(ch)
            
    return words, punctuation

def train(target, text):
    words, punctuation = extract_features(text)
    for word in words:
        UNIGRAMS[target][word] += 1
    for pun in punctuation:
        PUNCTUATION[target][pun] += 1

In [5]:
for target, filepath in CORP_FILEPATHS.items():
    with open(filepath) as f:
        for line in f:
            train(target, line)

In [6]:
def classify(text):
    words, punctuation = extract_features(text)
    probs = {}
    for target in TARGETS:
        all_ugs_count = sum(UNIGRAMS[target].values())
        words_prob = sum(math.log(UNIGRAMS[target][word]) for word in words)
        words_prob -= len(words) * math.log(all_ugs_count)
        
        all_pun_count = sum(PUNCTUATION[target].values())
        punct_prob = sum(math.log(PUNCTUATION[target][pun]) for pun in punctuation)
        punct_prob -= len(punctuation) * math.log(all_pun_count)

        probs[target] = words_prob + punct_prob
    pred = sorted(probs.items(), key=lambda p: p[1], reverse=True)[0][0]
    return pred

In [7]:
def test(target, text):
    pred = classify(text)
    return pred == target

In [8]:
# test on tests data
TESTS_DIRPATH = os.path.join(DATA_DIRPATH, 'testy1')

test_filenames = os.listdir(TESTS_DIRPATH)
correct = 0
preds = collections.defaultdict(lambda: collections.defaultdict(int))

for test_filename in test_filenames:
    if 'prus' in test_filename:
        target = 'P'
    elif 'orze' in test_filename:
        target = 'O'
    else:
        target = 'S'
    with open(os.path.join(TESTS_DIRPATH, test_filename)) as f:
        text = f.read()
    pred = classify(text)
    preds[target][pred] += 1
    if pred == target:
        correct += 1
print(correct / len(test_filenames))
for t in TARGETS:
    print(t)
    for p in preds:
        print(f'    {p} -> {preds[t][p]}')

0.7
P
    P -> 19
    S -> 2
    O -> 0
O
    P -> 0
    S -> 9
    O -> 3
S
    P -> 7
    S -> 20
    O -> 0
