In [1]:

import matplotlib.pyplot as plt
import numpy as np
import time
import random
from scipy.spatial.distance import cdist


In [2]:


def forward_layers(Qs, lambdas):
    """Return the lambda-weighted sum of per-set pairwise distance matrices.

    Parameters
    ----------
    Qs : sequence
        Indexable collection in which ``Qs[k][0]`` and ``Qs[k][1]`` are the
        two 2-D point arrays passed to ``cdist`` for set ``k``. Every set must
        produce a distance matrix of the same shape.
    lambdas : numpy.ndarray
        One weight per set; any shape holding exactly one element per set
        (e.g. ``(n,)`` or ``(n, 1)``).

    Returns
    -------
    numpy.ndarray
        ``sum_k lambdas[k] * cdist(Qs[k][0], Qs[k][1])``.
    """
    lambdas = np.asarray(lambdas)
    # The number of sets is taken from the weights instead of the notebook
    # global ``num_Qs`` (which is only defined in a later cell). The reshape
    # below already required ``lambdas.size == num_Qs``, so results are
    # unchanged wherever the previous version worked.
    n_sets = lambdas.size
    dist_Qs = np.array([cdist(Qs[k][0], Qs[k][1]) for k in range(n_sets)])
    # Broadcast one scalar weight over each (rows, cols) distance matrix.
    return np.sum(dist_Qs * lambdas.reshape(n_sets, 1, 1), axis=0)

def rel_position(cur_q, nex_q):
    """Return every pairwise difference vector between two point sets.

    Parameters
    ----------
    cur_q : numpy.ndarray
        ``(n, d)`` matrix of points.
    nex_q : numpy.ndarray
        ``(m, d)`` matrix of points.

    Returns
    -------
    numpy.ndarray
        ``(n, m, d)`` array whose ``[i, j]`` entry is ``cur_q[i] - nex_q[j]``.
    """
    cur_q = np.asarray(cur_q)
    # Inserting a middle axis lets broadcasting subtract every row of nex_q
    # from every row of cur_q in one step. The vector dimension is read from
    # the inputs, so the function no longer relies on the notebook global
    # ``Q_dim`` being defined (and matching) before it is called.
    return cur_q[:, np.newaxis, :] - nex_q

def _update_params_horiz(cur_coords, nex_coords, lambdas, t_mat, learning_rate=0.002, mask=None):
    """Apply one in-place update step to the coordinates and the weights.

    The model reconstructs the target as
    ``sum_k lambdas[k] * 5 * (cdist(cur_coords[k], nex_coords[k]) - 0.5)``
    and the masked residual against ``t_mat`` drives the update.

    Parameters
    ----------
    cur_coords : numpy.ndarray
        ``(n_sets, n_rows, dim)`` float array; updated in place.
    nex_coords : numpy.ndarray
        ``(n_sets, n_cols, dim)`` float array; updated in place.
    lambdas : numpy.ndarray
        One weight per set (e.g. shape ``(n_sets, 1)``); updated in place.
    t_mat : numpy.ndarray
        ``(n_rows, n_cols)`` target matrix.
    learning_rate : float
        Step size applied to all three updates.
    mask : numpy.ndarray, optional
        ``(n_rows, n_cols)`` multiplier for the residual (1 = entry counts,
        0 = entry ignored). When omitted, the notebook-level ``indicator_T``
        is used, which is what the function always did before.

    Returns
    -------
    tuple
        ``(cur_coords, nex_coords, lambdas, total_abs_error)`` where the
        first three are the same (mutated) objects that were passed in and
        the error is the sum of absolute masked residuals measured *before*
        this step was applied.
    """
    if mask is None:
        # Backward-compatible default: fall back to the notebook global.
        mask = indicator_T

    # All sizes come from the arguments rather than the globals num_Qs and
    # Q_dim. The reshape of lambdas already forced lambdas.size == num_Qs,
    # so behaviour is unchanged wherever the previous version worked.
    n_sets = lambdas.size
    # Basic slices are views: the in-place updates below write through to
    # the caller's arrays exactly as the per-set loop used to.
    cur = cur_coords[:n_sets]
    nex = nex_coords[:n_sets]

    rel_dist = np.array([5*(cdist(cur[k], nex[k]) - 0.5) for k in range(n_sets)])
    # rel_vect[k, i, j] = cur[k, i] - nex[k, j]; a fresh array, so it is not
    # affected by the in-place coordinate updates further down.
    rel_vect = cur[:, :, np.newaxis, :] - nex[:, np.newaxis, :, :]

    res_mat = np.sum(rel_dist * lambdas.reshape(n_sets, 1, 1), axis=0)
    res_error = (res_mat - t_mat) * mask

    # NOTE(review): the coordinate step uses the raw difference vectors, not
    # the unit vectors scaled by 5 that the exact derivative of
    # 5*(||x-h|| - 0.5) would give. The rule is kept as it was; confirm that
    # this is intended.
    delta = (2 * lambdas.reshape(n_sets, 1, 1, 1) * rel_vect
             * res_error[np.newaxis, :, :, np.newaxis])
    lambda_step = np.sum(rel_dist * res_error, axis=(1, 2))

    # Every quantity above was computed from the pre-update parameters, so
    # applying the three updates together matches the former sequential loop.
    cur -= learning_rate * np.sum(delta, axis=2)
    nex += learning_rate * np.sum(delta, axis=1)
    lambdas -= learning_rate * lambda_step.reshape(lambdas.shape)

    return cur_coords, nex_coords, lambdas, np.sum(abs(res_error))




In [3]:

# Problem size and masking configuration.
num_inputs = 20
num_hidden = 20
density = 0.5       # probability that an entry of the target is kept
sparse_mat = True   # False -> every entry of the target is kept

# Random target matrix to be reconstructed (no RNG seed is set, so the
# values differ from run to run).
target_T = np.random.randn(num_inputs, num_hidden)

# indicator_T multiplies the residual during training: 1 keeps an entry,
# 0 masks it out.
if sparse_mat:
    # One random.random() draw per entry in row-major order, as before.
    indicator_T = np.array(
        [[1.0 if random.random() < density else 0.0 for _ in range(num_hidden)]
         for _ in range(num_inputs)]
    )
else:
    # Previously the mask stayed all zeros in the dense case, which masked
    # out every residual so that nothing was ever fitted.
    indicator_T = np.ones((num_inputs, num_hidden))


In [5]:

# Embedding configuration: num_Qs coordinate sets, each placing every
# input/hidden unit at a Q_dim-dimensional point.
Q_dim = 3
num_Qs = 5
step_size = 0.0001

# Uniform random starting coordinates in [0, 1) for both sides.
x_coords = np.random.rand(num_Qs, num_inputs, Q_dim)
h_coords = np.random.rand(num_Qs, num_hidden, Q_dim)

# One weight per coordinate set, drawn from [0, 1) and scaled by 1/(0.5*num_Qs).
lambdas_xh = np.random.rand(num_Qs, 1) / (0.5 * num_Qs)

# Training loop: one update per iteration; the error reported is the one
# returned by the update step.
for ep in range(300000):
    x_coords, h_coords, lambdas_xh, error = _update_params_horiz(
        x_coords, h_coords, lambdas_xh, target_T, learning_rate=step_size
    )

    # Progress report every 2000 iterations.
    if ep % 2000 == 0:
        print(ep, ':', error)




0 : 187.7353901328517
2000 : 127.29095421033509
4000 : 95.83444883872421
6000 : 63.76997636848253
8000 : 41.220902964116924
10000 : 27.796919909493145
12000 : 19.78156978747509
14000 : 14.566381009944646
16000 : 11.052066116438777
18000 : 8.55147316194629
20000 : 6.798326132931509
22000 : 5.5177793569801175
24000 : 4.573955118338496
26000 : 3.8394312507097226
28000 : 3.2568669443520726
30000 : 2.7820128087057725
32000 : 2.390748388944211
34000 : 2.075127544362104
36000 : 1.8084635778678653
38000 : 1.581003562184453
40000 : 1.387697137946039
42000 : 1.222003393608245
44000 : 1.0772882437862064
46000 : 0.9532981007779956
48000 : 0.8461019259427287
50000 : 0.7514334912162988


KeyboardInterrupt: 

In [10]:

# Scratch check of torch.cdist with the L1 (p=1) metric on random points.
# torch is imported here because no earlier cell of the notebook imports it,
# so this cell raised NameError after a kernel restart.
import torch

xx = torch.rand(20,10)
torch.cdist(xx,xx,p=1)


tensor([[0.0000, 2.8315, 3.5018, 2.6684, 3.2536, 3.2745, 4.2990, 3.9686, 3.6265,
         3.0105, 4.6562, 4.3451, 2.7067, 4.4415, 3.4724, 4.1897, 4.0682, 4.2699,
         3.4365, 4.2265],
        [2.8315, 0.0000, 2.9557, 2.8128, 2.5230, 2.5453, 2.4577, 3.1909, 1.4871,
         2.3451, 2.9022, 2.8215, 1.5530, 2.8251, 2.0857, 3.5528, 2.4442, 2.8691,
         2.9126, 3.5958],
        [3.5018, 2.9557, 0.0000, 4.1629, 3.7989, 3.8064, 2.6796, 2.2937, 3.5774,
         3.8142, 4.7277, 3.2025, 2.5605, 3.9284, 2.1662, 3.8005, 4.5362, 2.9942,
         3.1079, 4.0490],
        [2.6684, 2.8128, 4.1629, 0.0000, 2.9031, 3.4671, 4.3401, 4.9041, 3.7018,
         2.2703, 3.7564, 3.9020, 2.6459, 4.0999, 3.9731, 3.8318, 3.9549, 4.4751,
         3.5647, 4.0755],
        [3.2536, 2.5230, 3.7989, 2.9031, 0.0000, 2.3741, 2.6033, 3.7959, 2.4903,
         2.8895, 3.9061, 3.0171, 2.7073, 3.1136, 2.6069, 3.5222, 3.1668, 2.7425,
         1.9616, 2.2344],
        [3.2745, 2.5453, 3.8064, 3.4671, 2.3741, 0.0000, 3.0

In [None]:

# Inspect a 10x10 window of the masked target matrix.
# target_T is num_inputs x num_hidden (20 x 20), so the former row slice
# 20:30 was out of range and always displayed an empty array; rows 10:20
# are shown instead. For the unmasked values use target_T[10:20,10:20].
(target_T*indicator_T)[10:20,10:20]


In [None]:

def recover_mat(cur_coords, nex_coords, lambdas):
    """Rebuild the matrix represented by the coordinates and weights.

    Parameters
    ----------
    cur_coords : numpy.ndarray
        ``(n_sets, n_rows, dim)`` coordinates.
    nex_coords : numpy.ndarray
        ``(n_sets, n_cols, dim)`` coordinates.
    lambdas : numpy.ndarray
        One weight per set (e.g. shape ``(n_sets, 1)``).

    Returns
    -------
    numpy.ndarray
        ``(n_rows, n_cols)`` matrix
        ``sum_k lambdas[k] * 5 * (cdist(cur_coords[k], nex_coords[k]) - 0.5)``.
    """
    # The set count comes from the weights rather than the global num_Qs; the
    # reshape below already required the two to agree.
    n_sets = lambdas.size
    # The pairwise difference vectors that used to be computed here were
    # never used, so that (comparatively expensive) step has been dropped.
    rel_dist_list = np.array(
        [5*(cdist(cur_coords[k], nex_coords[k]) - 0.5) for k in range(n_sets)]
    )
    return np.sum(rel_dist_list * lambdas.reshape(n_sets, 1, 1), axis=0)


rec_mat = recover_mat(x_coords, h_coords, lambdas_xh)

# Inspect a 10x10 window of the masked reconstruction.
# The matrices are num_inputs x num_hidden (20 x 20), so the former row slice
# 20:30 was out of range and always displayed an empty array; rows 10:20
# are shown instead. For the unmasked values use rec_mat[10:20,10:20].
(rec_mat*indicator_T)[10:20,10:20]
