In [2]:
from numba import jit
import numpy as np

def compute_log_lik(T, e_list, f_list):
    cur_log_lik = 0
    log_sum = None
    e = None
    f = None
    n = len(e_list)
    # e list = f _list len
    for i in range(n):
        e = e_list[i]
        f = f_list[i]
        for e_t in e:
            log_sum = 0
            for f_t in f:
                log_sum += T[e_t, f_t]
            cur_log_lik += np.log(log_sum)
    return cur_log_lik

In [4]:
@jit(nopython=True)
def _inner_log_lik(e, f, T):
    cur_log_lik = 0
    for e_t in e:
        log_sum = 0
        for f_t in f:
            log_sum += T[e_t, f_t]
        cur_log_lik += np.log(log_sum)
    return cur_log_lik


def compute_log_lik(T, pairs):
    cur_log_lik = 0
    # e list = f _list len
    for e, f in pairs:
        cur_log_lik += _inner_log_lik(e, f, T)
    return cur_log_lik

In [30]:
%timeit compute_log_lik(np.ones((100,100), dtype=np.float32), [[j for j in range(i)] for i in range(1, 100)], [[j for j in range(i)] for i in range(1, 100)])

124 ms ± 2.92 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [29]:
%timeit compute_log_lik_numba(np.ones((100,100), dtype=np.float32), [[j for j in range(i)] for i in range(1, 100)], [[j for j in range(i)] for i in range(1, 100)])

1.67 ms ± 108 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [21]:
import time

def _estimate(T, pairs, start_log_lik, tol):
    prev_log_lik = -np.inf
    counts = np.zeros(T.shape)
    total = np.zeros(T.shape[1])
    it_counter = 0
    cur_log_lik = start_log_lik
    while (cur_log_lik - prev_log_lik) > tol:
        start_it_time = time.clock()
#         print("current log likelihood: {cur_log_lik}".format(cur_log_lik=cur_log_lik))
        for e, f in pairs:
            # compute normalization
            s_total = np.zeros(T.shape[0])
            for e_t in e:
                for f_t in f:
                    s_total[e_t] += T[e_t, f_t]

            # counts
            for e_t in e:
                for f_t in f:
                    counts[e_t, f_t] += T[e_t, f_t] / s_total[e_t]
                    total[f_t] += T[e_t, f_t] / s_total[e_t]

        for e_t in range(T.shape[0]):
            for f_t in range(T.shape[1]):
                T[e_t, f_t] = counts[e_t, f_t] / total[f_t]

        prev_log_lik = cur_log_lik
        cur_log_lik = compute_log_lik(T, pairs)

#         print("time taken for loop {it}: {time}".format(it=str(it_counter),
#             time=str(time.clock() - start_it_time)))
        # serialize
#         pickle.dump(T, open('T_'+str(it_counter)+'.pkl', 'wb+'))
        it_counter += 1
    return T


In [64]:
import time

@jit(nopython=True)
def _estimate_numba_inner(e, f, s_total, counts, total):
    for e_t in e:
        for f_t in f:
            s_total[e_t] += T[e_t, f_t]

    # counts
    for e_t in e:
        for f_t in f:
            counts[e_t, f_t] += T[e_t, f_t] / s_total[e_t]
            total[f_t] += T[e_t, f_t] / s_total[e_t]

@jit(nopython=True)
def _fill_T(T, counts, total):
    for e_t in range(T.shape[0]):
        for f_t in range(T.shape[1]):
            T[e_t, f_t] = counts[e_t, f_t] / (total[f_t] + 0.00001)


def _estimate(T, pairs, start_log_lik, tol):
    prev_log_lik = -np.inf
    counts = np.zeros(T.shape)
    total = np.zeros(T.shape[1])
    it_counter = 0
    cur_log_lik = start_log_lik
    while (cur_log_lik - prev_log_lik) > tol:
        start_it_time = time.clock()
        print("current log likelihood: {cur_log_lik}".format(cur_log_lik=cur_log_lik))
        for e, f in pairs:
            # compute normalization
            s_total = np.zeros(T.shape[0])
            _estimate_numba_inner(e, f, s_total, counts, total)

        _fill_T(T, counts, total)

        prev_log_lik = cur_log_lik
        cur_log_lik = compute_log_lik(T, pairs)

        print("time taken for loop {it}: {time}".format(it=str(it_counter),
            time=str(time.clock() - start_it_time)))
        # serialize
        pickle.dump(T, open('T_'+str(it_counter)+'.pkl', 'wb+'))
        it_counter += 1
    return T



In [58]:
n = 100
T = np.ones((n,n), dtype=np.float32) * 1 / n

e = [[j-1 for j in range(n)] for i in range(1, n)]
f = [[j-1 for j in range(n)] for i in range(1, n)]
pairs = zip(e, f)
start_log_lik = compute_log_lik(T, pairs)
tol = 100

In [47]:
%timeit _estimate(T, pairs, start_log_lik, tol)



41.1 ms ± 1.14 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [63]:
%timeit _estimate_numba(T, pairs, start_log_lik, tol)

37.2 µs ± 908 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [3]:
@jit(nopython=True)
def _inner_log_lik(e, f, T):
    cur_log_lik = 0
    for e_t in e:
        log_sum = 0
        for f_t in f:
            log_sum += T[e_t, f_t]
        cur_log_lik += np.log(log_sum)
    return cur_log_lik


def compute_log_lik(T, pairs):
    cur_log_lik = 0
    # e list = f _list len
    for e, f in pairs:
        cur_log_lik += _inner_log_lik(e, f, T)
    return cur_log_lik

In [11]:
import pickle
t0 = pickle.load(open('T_0.pkl', 'rb'))
t0[:,8].sum()

1.0000000000004343

In [16]:
t1 = pickle.load(open('T_1.pkl', 'rb'))

t1[:,2].sum()

1.0000000000002183

In [22]:
(np.abs(t0-t1)).max()

1.2530698700885523e-11

In [26]:
t1.max(axis=1)

array([1.38398727e-04, 1.08643000e-02, 1.38398727e-04, 2.07598090e-03,
       2.56037644e-03, 6.91993634e-05, 3.45996817e-04, 2.07598090e-04,
       6.91993634e-05, 4.15196180e-04, 7.61192997e-04, 2.76797453e-04,
       6.91993634e-05, 9.27271469e-03, 6.91993634e-05, 6.91993634e-05,
       1.38398727e-04, 2.76797453e-04, 1.45318663e-03, 1.38398727e-04,
       6.91993634e-05, 7.81952806e-03, 6.91993634e-05, 6.91993634e-05,
       1.38398727e-04, 6.22794270e-04, 2.76797453e-04, 1.38398727e-04,
       2.14518026e-03, 2.56037644e-03, 6.91993634e-05, 2.56037644e-03,
       9.68791087e-04, 6.91993634e-05, 6.91993634e-05, 6.91993634e-05,
       6.91993634e-05, 2.76797453e-04, 6.91993634e-05, 6.91993634e-05,
       1.45318663e-03, 6.91993634e-05, 6.91993634e-05, 1.38398727e-04,
       6.91993634e-05, 6.91993634e-05, 6.91993634e-05, 6.91993634e-05,
       6.91993634e-05, 6.22794270e-04, 6.91993634e-05, 1.38398727e-04,
       2.69877517e-03, 6.91993634e-05, 6.91993634e-05, 6.91993634e-05,
      