In [1]:
['hi'] |*> print

hi


In [214]:
from jax import grad, jit, vmap
import jax.numpy as np
from jax.lax import dynamic_slice

In [212]:


# i,j = (np.ones((N,N)) 
#  |> np.triu 
#  |> np.tile$(?,(N,1,1))
#  |> np.transpose$
#  |> map$(?, [(2,0,1), (2,1,0)])
# )


def QR_mask(N, n):
    """censored random-walk mask
    
    N: states in transition matrix
    n: length of walk 
    """
    i,j = (np.ones((N,N))   # ones, like P
     |> np.triu             # make upper-triangular
     |> np.tile$(?,(N,1,1)) # make into upper-tri cube
     |> np.transpose$       # rotate cube
     |> map$(?, [(2,0,1), (2,1,0)])  # left->right, top->down
    )
    q = (i*j)      # growing square (upper left)
    r = ((1-i)*j)  # shrinking rectangle (upper right)
    return  q[:n],r[:n]  # oly need up to walk length


N = 20
qrmask = QR_mask$(N)  # hold transition matrix constant

normalize = A -> A/A.sum(axis=1)[:,None]

def likelihood(P, m):
    n = len(m) - 1  # walk length
    N = P.shape[0]
#     q_ma, r_ma = qrmask(n)
    
    # now we have to sort P to make masks work
    # --> sort by RW-order + remaining nodes
    # index = np.concatenate([m, np.delete(np.arange(N, ))])
    index = list(m) + list(set(range(10)) - set(m))
    
    q, r = (P[index][:index]
        |> np.exp
        |> normalize
        |> map((*)$, qrmask(n))
    )
    p = (np.eye(N) - q
        |>np.linalg.inv$(?, r)
    )
    lik = (p
        .diagonal()
        .diagonal(offset=-1)
        .log().sum()
    )
    return lik



# n = 4

# q_mask = (i*j)[:n]
# r_mask = ((1-i)*j)[:n] #-\
#     np.fliplr((i*j)[n])*(j[:n])

# qrmask(4)

In [213]:
T = np.array([[0.9, 0.1],
              [0.,  1.0]])
m = [0,1]
likelihood(T, m)


TypeError: <class 'list'> is not a valid Jax type

In [143]:
# ((1-i)*j)[:n] - np.fliplr((i*j)[n])*j[:n]

import time

t = time.time()
qrmask(10)
print(time.time() - t)
t = time.time()
qrmask(11)
print(time.time() - t)


# qrmask(10)
# QR_mask(10)

0.0065882205963134766
0.003113985061645508


In [203]:

np.triu(np.ones(5))/(np.triu(np.ones(5))).sum(axis=1)[:,None]

DeviceArray([[0.2       , 0.2       , 0.2       , 0.2       , 0.2       ],
             [0.        , 0.25      , 0.25      , 0.25      , 0.25      ],
             [0.        , 0.        , 0.33333334, 0.33333334, 0.33333334],
             [0.        , 0.        , 0.        , 0.5       , 0.5       ],
             [0.        , 0.        , 0.        , 0.        , 1.        ]],
            dtype=float32)

In [None]:
def P_i(T, a, idx):
    """ Probability of absorbtions given an observed chain:

    We need to partition the transition matrix
        $$ T =
        \begin{pmatrix}
         Q & R \\
         0 & I
        \end{pmatrix}
        $$
    where:
        $Q$: the non-absorbing transitions,
        $R$: non-absorbing to absorbing transitions
    Then, probability of being absorbed is given as
        $$P = (I-Q)^{-1} R$$

    In this case, we only want the probability of transitioning from the
    most recent state to the current absorbing state.
    """

    a_trans = np.array(a)[0:idx]  # visited
    a_absrb = np.array(a)[idx:len(a) - idx]  # not visited

    Q = T[a_trans, :][:, a_trans]
    R = T[a_trans, :][:, a_absrb]
    I = np.identity(Q.shape[0])

    P = np.dot(np.linalg.pinv(I-Q), R)

    return P[-1, 0]  # ...from previous state (P[-1,:] by construction) into next


def P_a(T, a):
    """ Calculate the log-likelihood of a transition matrix $T$, given censored
    observed INVITE sequence $a$.
    """

    frontload_a = list(a) + list(set(range(T.shape[0])) - set(a))
    like_i = partial(P_i, T, frontload_a)

    return -1*np.sum([np.log(like_i(idx)) for idx in range(1, len(a)-1)])


def _symmetrize(a):  # assumes 0-diags

    bott = np.tril(a) + np.tril(a).T
    top = np.triu(a) + np.triu(a).T
    # return (bott+top)/2. + infs
    return np.fmax(bott, top)


def _softmax(a, axis=None):
    a = a - a.max(axis=axis, keepdims=True)
    infs = np.diag(a.shape[0] * [-np.inf])

    y = np.exp(a + infs)
    return y / y.sum(axis=axis, keepdims=True)


def loss_i(m, idx, A, reg=1e-2):
    """Per-iteration objective function for use in ASGD"""

    T = _softmax(_symmetrize(A), axis=1)

    # for it in range(1, 4):  # 2x stochastic; Sinkhorn, 1964
    #     T = _softmax(T + np.diag(A.shape[0]*[-np.inf]), axis=it % 2)

    like = P_a(T, m[idx])
    # penalty = (1. / len(m)) * np.linalg.norm(A)  # Frob-norm
    penalty = (1. / len(m)) * np.abs(A).sum(axis=0).max()  # L1-norm

    return like + reg*penalty
