In [1]:
import numpy as np
from scipy.optimize import minimize
np.set_printoptions(precision=4)

In [2]:
def interpolate(P, QQ, seed=1):
    rnd = np.random.RandomState(seed)
    
    M  = len(QQ)
    w  = rnd.uniform(size=M)
    w /= w.sum()
    v  = np.log(w)
    l  = 0.0
    
    obj = lambda x: interpolate_obj(P, QQ, x[:-1], x[-1])
    jac = lambda x: interpolate_jac(P, QQ, x[:-1], x[-1])
    sol = minimize(obj, np.r_[v, l], jac=jac, method='BFGS')
    
    w = np.exp(sol.x[:-1])
    
    return w
    

def interpolate_obj(P, QQ, v, l):
    return perplexity(P, QQ, v) + l * (1 - np.exp(v).sum())


def interpolate_jac(P, QQ, v, l):
    g = np.zeros(v.size + 1)
    g[:-1] = perplexity_jac(P, QQ, v)
    g[-1]  = 1 - np.exp(v).sum()
    return g


def perplexity(P, QQ, v):
    Q = mixture(QQ, v)
    return - np.sum(P * np.log(Q))


def perplexity_jac(P, QQ, v):
    Q = mixture(QQ, v)
    g = np.zeros_like(v)
    
    for i, vi in enumerate(v):
        wi   = np.exp(vi)
        g[i] = wi * np.sum(P * QQ[i] / Q)
        
    return g


def mixture(QQ, v):
    w = np.exp(v)
    Q = np.zeros_like(QQ[0])
    for wi, Qi in zip(w, QQ):
        Q += wi * Qi
    return Q

In [3]:
rnd = np.random.RandomState(0)

N, K = 100, 5
w    = np.array([0.7, 0.3])
v    = np.log(w)
Q1   = rnd.uniform(size=(N, K))
Q1  /= Q1.sum(axis=1)[:, np.newaxis]
Q2   = rnd.uniform(size=(N, K))
Q2  /= Q2.sum(axis=1)[:, np.newaxis]
P    = w[0] * Q1 + w[1] * Q2
QQ   = [Q1, Q2]

In [7]:
W = [interpolate(P, QQ, s) for s in range(100)]

In [8]:
perplexity(P, QQ, np.log(w))

151.56766495234976

In [9]:
min(perplexity(P, QQ, np.log(w_hat)) for w_hat in W)

151.56791969837786