In [None]:
# Task List 8 - Dirichlet-multinomial model
import math

import matplotlib.pyplot as plt
import matplotlib.tri as tri
import numpy as np
from scipy.stats import dirichlet, multinomial

In [None]:
def all_args_to_np_array(f):
    """Decorator that turns every ``list`` argument of ``f`` into a NumPy array.

    Positional and keyword arguments are both inspected. Only values for which
    ``isinstance(value, list)`` is true are converted with ``np.array``; every
    other value (tuples, scalars, existing arrays, ...) is passed through
    unchanged.

    The wrapper is created with ``functools.wraps`` so the decorated function
    keeps its own ``__name__`` and docstring instead of exposing ``inner``.

    Args:
        f: The callable to wrap.

    Returns:
        A callable with the same signature as ``f`` that performs the
        conversion and then delegates to ``f``.
    """
    # Local import keeps this cell self-contained.
    from functools import wraps

    @wraps(f)
    def inner(*args, **kwargs):
        converted_args = [np.array(arg) if isinstance(arg, list) else arg
                          for arg in args]
        converted_kwargs = {name: np.array(val) if isinstance(val, list) else val
                            for name, val in kwargs.items()}
        return f(*converted_args, **converted_kwargs)

    return inner

def norm(l):
    """Rescale ``l`` so that its entries add up to one.

    Args:
        l: A sequence of numbers (anything ``np.array`` accepts).

    Returns:
        ``np.array(l)`` divided by the built-in ``sum`` of its entries.
    """
    values = np.array(l)
    total = sum(values)
    return values / total

In [None]:
# Dirichlet visualization helpers
def xy2bc(xy, corners, tol=1.e-10):
    '''Converts a 2D Cartesian point to barycentric coordinates.

    For each of the three corners, the point is projected onto the segment
    running from the midpoint of the opposite edge to that corner. The
    projection is divided by 0.75, which is the squared length of that
    segment for a unit-side equilateral triangle.

    Args:
        xy: The 2D point.
        corners: Array holding the three triangle corners, one per row.
        tol: Margin used to clip every coordinate into [tol, 1 - tol].

    Returns:
        NumPy array with the three clipped barycentric coordinates.
    '''
    bary = []
    for i in range(3):
        # Midpoint of the edge opposite to corner i.
        opposite_mid = (corners[(i + 1) % 3] + corners[(i + 2) % 3]) / 2.0
        bary.append((corners[i] - opposite_mid).dot(xy - opposite_mid) / 0.75)

    return np.clip(bary, tol, 1.0 - tol)


def draw_pdf_contours(calc_fn, ax):
    """Draw filled contours of ``calc_fn`` over a triangle on ``ax``.

    A unit-side equilateral triangle is triangulated and uniformly refined
    (``subdiv=4``). Every mesh vertex is converted to barycentric coordinates
    with ``xy2bc`` and passed to ``calc_fn``; the resulting values are drawn
    with 100 contour levels using the ``'hot'`` colormap.

    Args:
        calc_fn: Callable taking a barycentric coordinate array and returning
            the value to plot at that point.
        ax: Matplotlib axes to draw on; its axis decorations are turned off.
    """
    height = 0.75**0.5
    corners = np.array([[0, 0], [1, 0], [0.5, height]])
    base_triangle = tri.Triangulation(corners[:, 0], corners[:, 1])

    refined = tri.UniformTriRefiner(base_triangle).refine_triangulation(subdiv=4)

    pvals = []
    for point in zip(refined.x, refined.y):
        pvals.append(calc_fn(xy2bc(point, corners)))

    ax.tricontourf(refined, pvals, 100, cmap='hot')
    ax.axis('equal')
    ax.set_xlim(0, 1)
    ax.set_ylim(0, height)
    ax.axis('off')

### Exercise 1 - Implement a Dirichlet-multinomial model for the dice-tossing problem

### Exercise 2 - Implement a method for the posterior predictive distribution of a future observation

In [None]:
def calc_lk(theta, trials):
    """Return the product of ``theta_k ** N_k`` over paired entries.

    This is the multinomial likelihood without its combinatorial coefficient.

    Args:
        theta: Iterable of category probabilities.
        trials: Iterable of observed counts, paired with ``theta`` by position.

    Returns:
        The accumulated product, starting from ``1.0``.
    """
    likelihood = 1.0
    factors = (prob ** count for prob, count in zip(theta, trials))

    for factor in factors:
        likelihood *= factor

    return likelihood


@all_args_to_np_array
def calc_map(alpha, results, dim=None):
    """MAP estimate of the category probabilities under a Dirichlet prior.

    Computes ``(N_k + alpha_k - 1) / (N + alpha_0 - K)`` where ``N`` is the
    total count, ``alpha_0`` the sum of the prior parameters and ``K`` the
    number of categories.

    Args:
        alpha: Dirichlet prior parameters, one per category.
        results: Observed counts, one per category.
        dim: Number of categories ``K``. When omitted it is taken from
            ``len(alpha)``, so the estimate is correct for any number of
            categories (previously the default was fixed at 3).

    Returns:
        NumPy array of the estimates rounded to 4 decimal places.
    """
    if dim is None:
        dim = len(alpha)

    _map = (results + alpha - 1) / (sum(results) + sum(alpha) - dim)
    return np.round(_map, 4)
    
    
@all_args_to_np_array
def calc_mle(results):
    """Maximum-likelihood estimate of the category probabilities.

    Args:
        results: Observed counts, one per category.

    Returns:
        Each count divided by the total count, rounded to 4 decimal places.
    """
    total = sum(results)
    frequencies = results / total
    return np.round(frequencies, 4)


@all_args_to_np_array
def calc_posterior_predictive(alpha, results):
    """Posterior predictive distribution of the next observation.

    Each category gets probability
    ``(alpha_k + N_k) / (sum(alpha) + sum(results))``.

    Args:
        alpha: Dirichlet prior parameters, one per category.
        results: Observed counts, one per category.

    Returns:
        A pair ``(index, probabilities)``: the index of the most probable
        category and the list of all probabilities rounded to 4 decimals.
    """
    posterior_counts = alpha + results
    normaliser = sum(alpha) + sum(results)
    probs = posterior_counts / normaliser

    most_likely = np.argmax(probs)
    return most_likely, list(np.round(probs, 4))
    

@all_args_to_np_array
def dir_mult_model(prior_alpha, results, simplex_x):
    """Evaluate the Dirichlet-multinomial posterior at one simplex point.

    Args:
        prior_alpha: Dirichlet prior parameters.
        results: Observed counts, one per category.
        simplex_x: Point on the probability simplex at which to evaluate.

    Returns:
        A pair ``(posterior, exact_posterior)``. The first is the prior
        density times ``calc_lk`` (not normalised); the second is the density
        of ``Dirichlet(prior_alpha + results)`` at the same point.
    """
    prior_density = dirichlet.pdf(x=simplex_x, alpha=prior_alpha)
    unnormalised = prior_density * calc_lk(simplex_x, results)

    posterior_alpha = prior_alpha + results
    exact_posterior = dirichlet.pdf(x=simplex_x, alpha=posterior_alpha)

    return unnormalised, exact_posterior


def visualize_dirichlet_multinomial(prior, results):
    """Plot prior, likelihood and both posteriors on a 2x2 grid and print estimates.

    Panels:
        1. Dirichlet prior density.
        2. Likelihood ``prod_k x_k ** results_k`` from ``calc_lk``.
        3. Prior times likelihood (first value of ``dir_mult_model``).
        4. Closed-form posterior (second value of ``dir_mult_model``).

    After plotting, the MAP estimate, the MLE and the posterior predictive
    distribution are printed.

    Args:
        prior: Dirichlet prior parameters, one per category.
        results: Observed counts, one per category.
    """
    posterior_alpha = np.array(results) + np.array(prior)
    fig = plt.figure(figsize=(10, 10))

    # The 1-tuple keeps '%' formatting valid even if a tuple is passed in.
    ax1 = fig.add_subplot(221)
    ax1.set_title('Prior\n(dirichlet%s)' % (prior,))
    draw_pdf_contours(lambda x: dirichlet.pdf(x=x, alpha=prior), ax1)

    ax2 = fig.add_subplot(222)
    ax2.set_title('Likelihood\n(multinomial%s)' % (results,))
    # calc_lk expects (theta, counts): the simplex point goes first. The
    # previous call had the two arguments swapped and plotted results ** x.
    draw_pdf_contours(lambda x: calc_lk(x, results), ax2)

    ax3 = fig.add_subplot(223)
    ax3.set_title('Posterior\n(dir_mult%s)' % (posterior_alpha,))
    draw_pdf_contours(lambda x: dir_mult_model(prior, results, x)[0], ax3)

    ax4 = fig.add_subplot(224)
    ax4.set_title('Posterior\n(dir_mult_ext%s)' % (posterior_alpha,))
    draw_pdf_contours(lambda x: dir_mult_model(prior, results, x)[1], ax4)

    print('MAP:', calc_map(prior, results))
    print('MLE:', calc_mle(results))
    print('Posterior predictive:', calc_posterior_predictive(prior, results))

# Example run: Dirichlet(2, 2, 2) prior combined with observed counts (200, 100, 100).
visualize_dirichlet_multinomial(prior=[2, 2, 2],
                                results=[200, 100, 100])

### Exercise 3 - Implement a prediction mechanism for the next word in a text using your model

In [None]:
from pprint import pprint

import matplotlib.animation as anim
from matplotlib import rc
from IPython.display import HTML

import seaborn as sns
import spacy
# Shared spaCy pipeline used by get_lemmas.
# NOTE(review): the 'en' shortcut name appears to be resolved only by spaCy 2.x;
# newer releases expect a full package name such as 'en_core_web_sm' - confirm
# against the installed spaCy version.
nlp = spacy.load('en')

In [None]:
def load_document(filepath, encoding=None):
    """Read a text file and return its content as a single line.

    Newline characters are replaced with spaces and leading/trailing
    whitespace is stripped.

    Args:
        filepath: Path of the text file to read.
        encoding: Text encoding passed to ``open``. The default ``None`` keeps
            the platform-dependent default; pass e.g. ``'utf-8'`` to read the
            file the same way on every platform.

    Returns:
        The file content as one string.
    """
    with open(filepath, 'r', encoding=encoding) as f:
        return f.read().replace('\n', ' ').strip()

    
def get_lemmas(doc):
    """Return the lemmas of ``doc`` with punctuation, pronouns and stop words removed.

    The text is processed with the module-level ``nlp`` pipeline. A token is
    dropped when it is punctuation, when its lemma is the ``'-PRON-'``
    placeholder, or when it is flagged as a stop word.

    Args:
        doc: Text to process.

    Returns:
        List of lemma strings in document order.
    """
    lemmas = []

    for tok in nlp(doc):
        if tok.is_punct:
            continue
        lemma = tok.lemma_
        if lemma == '-PRON-' or tok.is_stop:
            continue
        lemmas.append(lemma)

    return lemmas


def to_bag_of_words(vocabulary, words):
    """Count how often each vocabulary entry occurs in ``words``.

    The words are tallied once with ``collections.Counter`` and then looked
    up per vocabulary entry, instead of rescanning ``words`` for every entry.

    Args:
        vocabulary: Iterable of words defining the order of the output.
        words: List of (hashable) words to count.

    Returns:
        List of integer counts aligned with ``vocabulary``; entries that do
        not occur in ``words`` get 0.
    """
    # Local import keeps this cell self-contained.
    from collections import Counter

    tally = Counter(words)
    return [tally[w] for w in vocabulary]


def split_train_test(doc, train_size=0.8):
    """Split a document into a leading training part and a trailing test part.

    The text is split on single spaces; the first
    ``int(len(tokens) * train_size)`` tokens form the training part and the
    rest the test part.

    Args:
        doc: The text to split.
        train_size: Fraction of tokens assigned to the training part.

    Returns:
        A pair ``(train_text, test_text)``, each re-joined with single spaces.
    """
    tokens = doc.split(' ')
    cut = int(len(tokens) * train_size)

    train_text = ' '.join(tokens[:cut])
    test_text = ' '.join(tokens[cut:])
    return train_text, test_text


In [None]:
def make_word_occurrences_barplot(words, occurrences, ax):
    """Redraw ``ax`` as a bar plot of occurrence counts per word.

    The axes are cleared first, then one bar per word is drawn and the x tick
    labels are rotated by 90 degrees.

    Args:
        words: Labels for the x axis.
        occurrences: Bar heights, aligned with ``words``.
        ax: Matplotlib axes to draw on.
    """
    ax.cla()
    # x and y are passed by keyword: recent seaborn releases no longer accept
    # them positionally.
    g = sns.barplot(x=words, y=occurrences, ax=ax)
    for item in g.get_xticklabels():
        item.set_rotation(90)
        
        
def visualize_word_occurrences(prior, vocabulary, test_lemmas):
    """Animate how word counts evolve as test lemmas are observed.

    The top panel shows the prior counts; the bottom panel shows the prior
    plus the counts of the first ``i`` test lemmas for frame ``i``. Frame 0
    draws nothing.

    Args:
        prior: Prior word counts aligned with ``vocabulary``.
        vocabulary: Ordered list of words used as bar labels.
        test_lemmas: Sequence of lemmas revealed one per frame.

    Returns:
        The ``FuncAnimation`` with ``len(test_lemmas) - 1`` frames and a
        50 ms interval.
    """
    fig = plt.figure(figsize=(15, 7))
    prior_ax = fig.add_subplot(211)
    posterior_ax = fig.add_subplot(212)

    def animate(i, ax1, ax2):
        if i != 0:
            # Top panel: prior counts.
            make_word_occurrences_barplot(vocabulary, prior, ax1)

            # Bottom panel: prior plus everything observed so far.
            seen_counts = to_bag_of_words(vocabulary, test_lemmas[:i])
            updated = np.array(prior) + np.array(seen_counts)
            make_word_occurrences_barplot(vocabulary, updated, ax2)

    frame_ids = list(range(len(test_lemmas) - 1))
    return anim.FuncAnimation(fig, animate,
                              frames=frame_ids,
                              fargs=(prior_ax, posterior_ax), interval=50)

In [None]:
def run_visualization_word_occurrences(filepath='lyrics.txt', train_size=0.2):
    """Build the word-occurrence animation for one document.

    The vocabulary is the sorted set of lemmas of the whole document. The
    leading ``train_size`` fraction of the text supplies the prior counts and
    the remaining text supplies the lemmas revealed frame by frame.

    Args:
        filepath: Path of the text file to analyse (defaults to the file
            that was previously hard-coded).
        train_size: Fraction of the text used to build the prior.

    Returns:
        The animation returned by ``visualize_word_occurrences``.
    """
    doc = load_document(filepath)
    vocabulary = sorted(set(get_lemmas(doc)))
    train_doc, test_doc = split_train_test(doc, train_size=train_size)

    prior = to_bag_of_words(vocabulary, get_lemmas(train_doc))
    test_lemmas = get_lemmas(test_doc)

    return visualize_word_occurrences(prior, vocabulary, test_lemmas)
        
    
# Build the animation and render it in the cell output via its JavaScript/HTML form.
func_anim = run_visualization_word_occurrences()
HTML(func_anim.to_jshtml())

In [None]:
def run_word_prediction(filepath, lemmatize_fn, train_size=0.2):
    """Measure next-word prediction accuracy on the tail of a document.

    The vocabulary is the sorted set of tokens of the whole document. Counts
    from the leading ``train_size`` fraction act as the Dirichlet prior. The
    remaining tokens are then walked in order: at step ``i`` the token
    ``test_lemmas[i]`` is added to the observed counts and the most probable
    word under the posterior predictive is compared with
    ``test_lemmas[i + 1]``.

    Args:
        filepath: Path of the text file to analyse.
        lemmatize_fn: Callable turning a text into a list of tokens.
        train_size: Fraction of the text used to build the prior.

    Returns:
        Dict with keys ``'correct'`` and ``'false'`` holding percentages
        rounded to 2 decimals. Both are ``0.0`` when the test part has fewer
        than two tokens, since no prediction can be made.
    """
    from time import time

    doc = load_document(filepath)
    vocabulary = sorted(set(lemmatize_fn(doc)))
    train_doc, test_doc = split_train_test(doc, train_size=train_size)

    prior = to_bag_of_words(vocabulary, lemmatize_fn(train_doc))
    test_lemmas = lemmatize_fn(test_doc)

    prediction_results = {
        'correct': 0,
        'false': 0
    }

    n_predictions = len(test_lemmas) - 1
    if n_predictions <= 0:
        # Nothing to predict; avoid dividing by zero below.
        return {outcome: 0.0 for outcome in prediction_results}

    # O(1) lookups instead of vocabulary.index() on every step.
    word_to_idx = {word: idx for idx, word in enumerate(vocabulary)}
    report_every = max(1, int(n_predictions / 10))

    last_ts = time()
    bow = [0] * len(vocabulary)

    for i in range(n_predictions):
        if i % report_every == 0:
            curr_time = time()
            interval = np.round((curr_time - last_ts) * 1000, 2)
            last_ts = curr_time
            print('Step %d of %d' % (i, len(test_lemmas) - 1), 'time: ', interval, '(ms)')

        # Observe the current word BEFORE predicting the next one, so the
        # prediction uses every word seen so far. Tokens missing from the
        # vocabulary (possible when the test part is tokenised separately
        # from the full document) cannot be counted and are skipped.
        seen_idx = word_to_idx.get(test_lemmas[i])
        if seen_idx is not None:
            bow[seen_idx] += 1

        pred_idx, _ = calc_posterior_predictive(prior, bow)

        if vocabulary[pred_idx] == test_lemmas[i + 1]:
            prediction_results['correct'] += 1
        else:
            prediction_results['false'] += 1

    return {outcome: np.round(100 * count / n_predictions, 2)
            for outcome, count in prediction_results.items()}


def run_all_predictions():
    """Run next-word prediction for every text/tokenizer pair and plot the results.

    Each of the three text files is evaluated with plain whitespace splitting
    and with ``get_lemmas``, using 80% of the text as prior. The raw result
    dict is pretty-printed, then one bar plot per combination shows the
    percentage of correct and false predictions with the values annotated
    above the bars.
    """
    filenames = ['lyrics.txt', 'poem.txt', 'story.txt']
    lemmatize_fns = [
        ('No lemmatization', lambda doc: doc.split(' ')),
        ('Spacy lemmatize', get_lemmas),
    ]
    results = {}

    for f in filenames:
        for fn_name, fn in lemmatize_fns:
            print('Prediction for text:', f, 'using', fn_name)
            results[(f, fn_name)] = run_word_prediction(f, fn, train_size=0.8)

    pprint(results)

    fig = plt.figure(figsize=(15, 15))
    idx = 1

    for f in filenames:
        for fn_name, _ in lemmatize_fns:
            # One row per file, one column per tokenizer.
            ax = fig.add_subplot(len(filenames), len(lemmatize_fns), idx)
            res = results[(f, fn_name)]
            # x and y are passed by keyword: recent seaborn releases no
            # longer accept them positionally.
            g = sns.barplot(x=list(res.keys()), y=list(res.values()), ax=ax)
            g.text(0, res['correct'] + 0.5, res['correct'], color='black', ha="center")
            g.text(1, res['false'] + 0.5, res['false'], color='black', ha="center")
            ax.set_title('%s - %s' % (f, fn_name))

            idx += 1
    
    
# Evaluate every text/tokenizer combination and draw the accuracy bar plots.
run_all_predictions()