In [411]:
import numpy as np


In [412]:


def sigmoid(x):
    return 1. / (1. + np.exp(-x))

def softmax(x, axis=-1):
    x -= np.max(x, axis=axis, keepdims=True)
    if x.dtype == np.float32 or x.dtype == np.float64:
        np.exp(x, out=x)
    else:
        x = np.exp(x)
        x /= np.sum(x, axis=axis, keepdims=True)
    return x

def parse_mdn(name, outs, in_N=0, out_N=1, out_shape=None):
    raw = outs[name]
    raw = raw.reshape((raw.shape[0], max(in_N, 1), -1))

    pred_mu = raw[:,:,:(raw.shape[2] - out_N)//2]
    n_values = (raw.shape[2] - out_N)//2
    pred_mu = raw[:,:,:n_values]
    pred_std = np.exp(raw[:,:,n_values: 2*n_values])
    #print("pred_mu1", pred_mu.shape, pred_mu)

    if in_N > 1:
        weights = np.zeros((raw.shape[0], in_N, out_N), dtype=raw.dtype)
        for i in range(out_N):
            # print("weights index shape", raw[:,:,i - out_N].shape)
            weights[:,:,i - out_N] = softmax(raw[:,:,i - out_N], axis=-1)

        if out_N == 1:
            for fidx in range(weights.shape[0]):
                idxs = np.argsort(weights[fidx][:,0])[::-1]
                weights[fidx] = weights[fidx][idxs]
                pred_mu[fidx] = pred_mu[fidx][idxs]
                pred_std[fidx] = pred_std[fidx][idxs]
        full_shape = tuple([raw.shape[0], in_N] + list(out_shape))
        outs[name + '_weights'] = weights
        outs[name + '_hypotheses'] = pred_mu.reshape(full_shape)
        outs[name + '_stds_hypotheses'] = pred_std.reshape(full_shape)

        pred_mu_final = np.zeros((raw.shape[0], out_N, n_values), dtype=raw.dtype)
        pred_std_final = np.zeros((raw.shape[0], out_N, n_values), dtype=raw.dtype)
        for fidx in range(weights.shape[0]):
            for hidx in range(out_N):
                idxs = np.argsort(weights[fidx,:,hidx])[::-1]
                pred_mu_final[fidx, hidx] = pred_mu[fidx, idxs[0]]
                pred_std_final[fidx, hidx] = pred_std[fidx, idxs[0]]
        # print("pred_mu2", pred_mu_final.shape, pred_mu_final)
    else:
        pred_mu_final = pred_mu
        pred_std_final = pred_std

    if out_N > 1:
        final_shape = tuple([raw.shape[0], out_N] + list(out_shape))
    else:
        final_shape = tuple([raw.shape[0],] + list(out_shape))
    outs[name] = pred_mu_final.reshape(final_shape)
    outs[name + '_stds'] = pred_std_final.reshape(final_shape)
    return outs


In [413]:
def new_parse_mdn(input_tensor, in_N=0, out_N=1, out_shape=None):
    raw = input_tensor
    raw = raw.reshape((raw.shape[0], max(in_N, 1), -1))
    #print("raw1", raw.shape, raw)

    n_values = (raw.shape[2] - out_N)//2
    pred_mu = raw[:,:,:n_values]
    pred_std = np.exp(raw[:,:,n_values: 2*n_values])
    # print("pred_mu1", pred_mu.shape, pred_mu)

    if in_N > 1:
        weights = np.zeros((raw.shape[0], in_N, out_N), dtype=raw.dtype)
        for i in range(out_N):
            # print("weights index shape", raw[:,:,i - out_N].shape)
            weights[:,:,i - out_N] = softmax(raw[:,:,i - out_N], axis=-1)

        if out_N == 1:
            for fidx in range(weights.shape[0]):
                idxs = np.argsort(weights[fidx][:,0])[::-1]
                weights[fidx] = weights[fidx][idxs]
                pred_mu[fidx] = pred_mu[fidx][idxs]
                pred_std[fidx] = pred_std[fidx][idxs]

        pred_mu_final = np.zeros((raw.shape[0], out_N, n_values), dtype=raw.dtype)
        pred_std_final = np.zeros((raw.shape[0], out_N, n_values), dtype=raw.dtype)
        for fidx in range(weights.shape[0]):
            for hidx in range(out_N):
                idxs = np.argsort(weights[fidx,:,hidx])[::-1]
                pred_mu_final[fidx, hidx] = pred_mu[fidx, idxs[0]]
                pred_std_final[fidx, hidx] = pred_std[fidx, idxs[0]]
        # print("pred_mu2", pred_mu_final.shape, pred_mu_final)
    else:
        pred_mu_final = pred_mu
        pred_std_final = pred_std

    if out_N > 1:
        final_shape = tuple([raw.shape[0], out_N] + list(out_shape))
    else:
        final_shape = tuple([raw.shape[0],] + list(out_shape))
    final = pred_mu_final.reshape(final_shape)
    final_stds = pred_std_final.reshape(final_shape)
    return final, final_stds

In [414]:
def compare(name, rand, in_N, out_N, out_shape):
    rand = np.random.rand(*rand)
    print("all", name, rand, in_N, out_N, out_shape)
    print("rand", rand[:2, :2])
    outs_orig = {}
    outs_orig[name] = np.copy(rand)
    outs_orig = parse_mdn(name, outs_orig, in_N, out_N, out_shape)
    orig_out = outs_orig[name]
    orig_stds = outs_orig[name + '_stds']
    print("all", name, rand, in_N, out_N, out_shape)

    # new

    final, stds = new_parse_mdn(np.copy(rand), in_N, out_N, out_shape)

    print("all", name, rand, in_N, out_N, out_shape)

    if np.allclose(orig_out, final, atol=0.001) and orig_out.shape == final.shape:
        print("Output matches!")
    else:
        print("not matching", orig_out[:2, :2, :2],"\n", final[:2, :2, :2])
    if np.allclose(orig_stds, stds) and orig_stds.shape == stds.shape:
        print("Stds match!")

    else:
        print("It doesn't match!")


In [415]:
outs = {}
outs["plan"] = np.random.rand(1, 4955)
parse_mdn('plan', outs, in_N=5, out_N=1, out_shape=(33 ,15))
outs['plan'].shape
""

''

In [416]:
compare("plan", (1, 4955), 5, 1, (33, 15))

all plan [[0.48743325 0.56121467 0.09961541 ... 0.22013133 0.20395704 0.75284165]] 5 1 (33, 15)
rand [[0.48743325 0.56121467]]
all plan [[0.48743325 0.56121467 0.09961541 ... 0.22013133 0.20395704 0.75284165]] 5 1 (33, 15)
all plan [[0.48743325 0.56121467 0.09961541 ... 0.22013133 0.20395704 0.75284165]] 5 1 (33, 15)
Output matches!
Stds match!


In [417]:

outs = {}
outs["lead"] = np.random.rand(1, 102)
parse_mdn('lead', outs, in_N=2, out_N=3, out_shape=(6 ,4 ))
""

''

In [418]:
compare("lead", (1, 102), 2, 3, (6, 4))

all lead [[0.26682139 0.71964466 0.02787892 0.67291071 0.40037167 0.36446549
  0.41055723 0.12974938 0.12993392 0.77385801 0.39646749 0.20018192
  0.01344558 0.67531722 0.0946575  0.59514798 0.8475552  0.77191966
  0.48688467 0.98612799 0.41440056 0.68404347 0.18090945 0.91551586
  0.01068158 0.29665883 0.22518282 0.77340035 0.53225784 0.71508046
  0.84995425 0.70788866 0.70880783 0.37611446 0.03190937 0.6027609
  0.63789596 0.98470403 0.65126685 0.77574072 0.09857331 0.04705282
  0.89849166 0.97035705 0.56022714 0.41224445 0.95050002 0.78116782
  0.47887082 0.43779115 0.74643279 0.49026343 0.50155099 0.23999204
  0.7865971  0.11203179 0.25172103 0.79561071 0.03476969 0.49118885
  0.56329064 0.45173026 0.48054116 0.01114354 0.8269745  0.97293342
  0.63512192 0.68553129 0.83883242 0.72811961 0.21840604 0.72170723
  0.93555554 0.19954617 0.4021347  0.11536183 0.9030587  0.85162222
  0.07333316 0.45486776 0.35950759 0.8550457  0.78795931 0.63490911
  0.92610087 0.98443235 0.67320828 0.615