I need to understand how the inputs are passed in here, i.e., what form they take, and check whether I can run this code in the comparison notebook.

In [1]:
from time import time
import numpy as np


def sinkhorn_bicausal_markov(mu_list, nu_list, cost_list, n_list, m_list, eps_stop=10**-4, max_iter=10**4,
                             reshape=True, outputflag=0):
    # Only for MARKOV - MARKOV marginals, bicausal!
    """Solve entropic bicausal optimal transport between two Markov marginals with nested Sinkhorn sweeps.

    Input layout, as used by the indexing below:
      * mu_list[0] = [support_indices, weights] for the first time step;
      * for t >= 1 and i in range(n_list[t-1]), mu_list[t][i] = [support_indices, weights] is the
        distribution at time t conditional on state i at time t-1;
      * nu_list has the same layout, with m_list in place of n_list.
    Support indices select rows (mu) / columns (nu) of cost_list[t].

    :param mu_list: as output by get_meas_for_sinkhorn (layout above). If ``reshape`` is True, the weight
                    arrays are reshaped IN PLACE to shape (a, 1) and their logarithm is appended to each entry.
    :param nu_list: as output by get_meas_for_sinkhorn (layout above). If ``reshape`` is True, the weight
                    arrays are reshaped IN PLACE to shape (1, b) and their logarithm is appended to each entry.
    :param cost_list: list of matrices, one for each time point (markov case). Notably, the cost functions should
                    already be kernelized, i.e., values are exp(-c) instead of c
    :param n_list: n_list[t] is the number of support points (states) of mu at time t
    :param m_list: m_list[t] is the number of support points (states) of nu at time t
    :param eps_stop: iterations stop once the dual value changes by at most eps_stop between two sweeps
    :param max_iter: maximal number of Sinkhorn sweeps
    :param reshape: whether to bring the weight arrays into the shapes described above
    :param outputflag: if truthy, print timing and progress information
    :return: transport cost of the computed entropic bicausal coupling, i.e. the integral of c = -log(kernel)
             against that coupling, without the entropy term
    """
    t_max = len(mu_list)

    def _block(mat, rows, cols):
        # Sub-matrix of `mat` restricted to the given row and column index arrays.
        return mat[rows, :][:, cols]

    # initializing dual functions. We specify them in a multiplicative way, i.e., compared to the paper, we store values
    # of exp(f_t) and exp(g_t) instead of f_t and g_t, which is in line with standard implementations of Sinkhorn's
    tinit = time()
    # The first-period duals are sized by the supports of mu_list[0] / nu_list[0] (not by n_list[0] / m_list[0]).
    # They are broadcast against the support-restricted kernel, so a size mismatch could raise.
    f_list = [np.ones((np.size(mu_list[0][1]), 1))]
    g_list = [np.ones((1, np.size(nu_list[0][1])))]
    # const_*_list[t] (t >= 1) holds exp(-normalizing constant) for every conditional problem (i, j) at step t.
    # Index 0 is a placeholder and is never read.
    const_f_list = [0]
    const_g_list = [0]
    for t in range(1, t_max):
        n_prev, m_prev = n_list[t - 1], m_list[t - 1]
        f_list.append([[np.ones((np.size(mu_list[t][i][1]), 1)) for _ in range(m_prev)] for i in range(n_prev)])
        g_list.append([[np.ones((1, np.size(nu_list[t][j][1]))) for j in range(m_prev)] for _ in range(n_prev)])
        const_f_list.append(np.ones((n_prev, m_prev)))
        const_g_list.append(np.ones((n_prev, m_prev)))
    if outputflag:
        print('Initializing took ' + str(time()-tinit) + ' seconds')

    # Define update iterations:
    t_funs = time()
    def update_f_t(mut, nut, gt, ct):
        """

        :param mut: should be of shape (a, 1)
        :param nut: should be of shape (1, b)
        :param gt: should be of shape (1, b)
        :param ct: should be of shape (a, b)
        :return: array of shape (a, 1) representing f_t (normalized so that sum(log(f_t) * mut) == 0),
                 and the normalizing constant
        """
        at = 1. / np.matmul(ct, (gt*nut).T)
        cth = np.sum(np.log(at) * mut)
        return at/np.exp(cth), cth

    def update_g_t(mut, nut, ft, ct):
        """

        :param mut: should be of shape (a, 1)
        :param nut: should be of shape (1, b)
        :param ft: should be of shape (a, 1)
        :param ct: should be of shape (a, b)
        :return: array of shape (1, b) representing g_t (normalized so that sum(log(g_t) * nut) == 0),
                 and the normalizing constant
        """
        bt = 1. / np.matmul(ct.T, ft*mut).T
        cth = np.sum(np.log(bt) * nut)
        return bt/np.exp(cth), cth

    def update_f_1(mut, nut, gt, ct):
        # inputs as for update_f_t; the first-period dual is not normalized
        at = 1. / np.sum(gt * ct * nut, axis=1, keepdims=True)
        return at, np.sum(np.log(at) * mut)

    def update_g_1(mut, nut, ft, ct):
        # inputs as for update_g_t; the first-period dual is not normalized
        bt = 1. / np.sum(ft * ct * mut, axis=0, keepdims=True)
        return bt, np.sum(np.log(bt) * nut)

    def _backward_sweep(update_t, update_1, duals, other_duals, consts):
        """Update one family of duals (f or g) from the last time step back to the first.

        For t >= 1, the (i, j)-th dual is updated with the kernel restricted to the supports of
        mu_list[t][i] and nu_list[t][j]. Below the last step, this kernel is multiplied by the exp(-constant)
        matrix produced at step t + 1. That matrix for step t is stored in consts[t].

        :return: value of the first-period dual, as returned by update_1
        """
        cvh = None  # exp(-constants) from step t + 1; None at the last step (and whenever t_max == 1)
        for t in range(t_max - 1, 0, -1):
            cvnew = np.ones([n_list[t - 1], m_list[t - 1]])
            for i in range(n_list[t - 1]):
                mu_idx, mu_w = mu_list[t][i][0], mu_list[t][i][1]
                for j in range(m_list[t - 1]):
                    nu_idx, nu_w = nu_list[t][j][0], nu_list[t][j][1]
                    kernel = _block(cost_list[t], mu_idx, nu_idx)
                    if cvh is not None:
                        kernel = kernel * _block(cvh, mu_idx, nu_idx)
                    duals[t][i][j], cvnew[i, j] = update_t(mu_w, nu_w, other_duals[t][i][j], kernel)
            cvh = np.exp(-cvnew)
            consts[t] = cvh.copy()
        kernel = _block(cost_list[0], mu_list[0][0], nu_list[0][0])
        if cvh is not None:
            kernel = kernel * _block(cvh, mu_list[0][0], nu_list[0][0])
        duals[0], value = update_1(mu_list[0][1], nu_list[0][1], other_duals[0], kernel)
        return value

    def full_update_f_list():
        return _backward_sweep(update_f_t, update_f_1, f_list, g_list, const_f_list)

    def full_update_g_list():
        return _backward_sweep(update_g_t, update_g_1, g_list, f_list, const_g_list)

    if outputflag:
        print('Defining update functions took ' + str(time()-t_funs) + ' seconds')

    if reshape:
        # reshape inputs
        # we want mu_list[t][i][1] to be shaped (a, 1) and nu_list[t][j][1] to be shaped (1, b) for some a and b that may
        # depend on i and j
        t_reshape = time()
        for t in range(t_max):
            if t == 0:
                if len(mu_list[t][1].shape) == 1:
                    mu_list[t][1] = np.expand_dims(mu_list[t][1], 1)
                if len(nu_list[t][1].shape) == 1:
                    nu_list[t][1] = np.expand_dims(nu_list[t][1], 0)
                if len(mu_list[t]) == 2:
                    mu_list[t].append(np.log(mu_list[t][1]))
                if len(nu_list[t]) == 2:
                    nu_list[t].append(np.log(nu_list[t][1]))
            else:
                for i in range(n_list[t-1]):
                    if len(mu_list[t][i][1].shape) == 1:
                        mu_list[t][i][1] = np.expand_dims(mu_list[t][i][1], 1)
                    if len(mu_list[t][i]) == 2:
                        mu_list[t][i].append(np.log(mu_list[t][i][1]))

                for j in range(m_list[t-1]):
                    if len(nu_list[t][j][1].shape) == 1:
                        nu_list[t][j][1] = np.expand_dims(nu_list[t][j][1], 0)
                    if len(nu_list[t][j]) == 2:
                        nu_list[t][j].append(np.log(nu_list[t][j][1]))

        if outputflag:
            print('Reshaping input took ' + str(time()-t_reshape) + ' seconds')

    t_solve = time()
    prev_val = -10**8
    value_f = -100
    value_g = -100
    iter_h = 0
    while iter_h < max_iter and np.abs(prev_val - value_f - value_g) > eps_stop:
        if iter_h % 10 == 0 and outputflag:
            print('Current iteration:', iter_h, 'Current value:', value_f+value_g, 'Current time:', time()-t_solve)
        iter_h += 1
        prev_val = value_f + value_g
        value_f = full_update_f_list()
        value_g = full_update_g_list()
    if outputflag:
        print('Solving took ' + str(time()-t_solve) + ' seconds')

    # get value without entropy, by backward induction over time.
    # For t >= 1 and conditioning states (i, j) at t-1, the cost is integrated against
    #     f * g * kernel * mu * nu / const_g_list[t][i, j].
    # Below the last step, the kernel also carries const_g_list[t+1].
    # update_g_t returned g = b / exp(cth) and const_g_list[t][i, j] = exp(-cth), so dividing by this scalar
    # recovers the unnormalized b. At t == 0, g is not normalized, so there is no such division.
    v_next = None
    for t in range(t_max - 1, 0, -1):
        v_t = np.zeros([n_list[t - 1], m_list[t - 1]])
        for i in range(n_list[t - 1]):
            mu_idx, mu_w = mu_list[t][i][0], mu_list[t][i][1]
            for j in range(m_list[t - 1]):
                nu_idx, nu_w = nu_list[t][j][0], nu_list[t][j][1]
                kernel = _block(cost_list[t], mu_idx, nu_idx)
                integrand = -np.log(kernel)
                plan = f_list[t][i][j] * g_list[t][i][j] * kernel
                if v_next is not None:
                    integrand = integrand + _block(v_next, mu_idx, nu_idx)
                    plan = plan * _block(const_g_list[t + 1], mu_idx, nu_idx)
                # One normalizing scalar per conditional problem: index by (i, j), never the whole matrix.
                plan = plan * (1. / const_g_list[t][i, j]) * mu_w * nu_w
                v_t[i, j] = np.sum(integrand * plan)
        v_next = v_t
    mu_idx, nu_idx = mu_list[0][0], nu_list[0][0]
    kernel = _block(cost_list[0], mu_idx, nu_idx)
    integrand = -np.log(kernel)
    plan = f_list[0] * g_list[0] * kernel
    if v_next is not None:
        integrand = integrand + _block(v_next, mu_idx, nu_idx)
        plan = plan * _block(const_g_list[1], mu_idx, nu_idx)
    value = np.sum(integrand * plan * mu_list[0][1] * nu_list[0][1])
    return value

In [None]:
# NOTE(review): `sinkhorn_causal_markov` is not defined in this notebook. The solver defined above is
# `sinkhorn_bicausal_markov`, whose signature has no `ind_tot`, `ind_next_l` or `nu_joint_prob` parameters,
# so this cell needs that function (and `mu_list`, `nu_list`, `cost_mats`, `n_list`, `m_list`, `ind_tot`,
# `ind_next_l`, `nu_joint_prob`) to come from elsewhere, e.g. an import or an earlier cell — TODO confirm.
vs3 = sinkhorn_causal_markov(mu_list, nu_list, cost_mats, n_list, m_list, ind_tot, ind_next_l, nu_joint_prob, eps_stop=10**-4)