In [1]:


import numpy as np
import time
from numba import jit
import torch
from scipy.spatial import distance_matrix
import os
import h5py
import matplotlib.pyplot as plt




Convert an arbitrary matrix into the weighted sum of distance matrices.


In [2]:

import torch
import torch.nn as nn
from datetime import datetime

# Choose the compute device once for the whole notebook: the first CUDA GPU
# when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)



cpu


In [3]:


def compute_stress(_model, tar_Wio):
    """Return the mean squared error between the model's matrix and a target.

    Args:
        _model: object whose ``forward()`` (called without arguments) returns
            the current approximation as a tensor.
        tar_Wio: target tensor; it is subtracted elementwise from the
            approximation, so it must be broadcast-compatible with it.

    Returns:
        A scalar tensor: the sum of squared differences, each divided by the
        number of elements of the approximation.
    """
    approx = _model.forward()
    squared_err = (approx - tar_Wio) ** 2
    # Normalise each term before summing (same operation order as a per-element mean).
    return torch.sum(squared_err / approx.numel())


def update_DyN(_model, _optim, tar_Wio, dyn_epochs, meanL_thres=0.01,
               query_vecs=None, query_targets=None):
    """Fit ``_model`` to ``tar_Wio`` by gradient descent on ``compute_stress``.

    Roughly ten times over the run (and always on the first epoch) the current
    max / mean absolute error is measured and printed, together with a
    normalised query error when query data is available. Training stops early
    once the most recently measured mean absolute error drops below
    ``meanL_thres``.

    Args:
        _model: module whose ``forward()`` returns the approximated matrix.
        _optim: optimizer already bound to the model's parameters.
        tar_Wio: target matrix tensor.
        dyn_epochs: maximum number of optimisation steps.
        meanL_thres: early-stopping threshold on the mean absolute error.
        query_vecs: optional batch of row vectors multiplied with the
            approximated matrix for the query-error report. When omitted, the
            module-level ``rand_vec`` is used if it exists.
        query_targets: optional reference products for ``query_vecs``. When
            omitted, the module-level ``tar_vecs`` is used if it exists.

    Returns:
        ``(_model, max_loss, mean_loss)`` where the two losses are the values
        from the most recent report (both 0 if no report was ever made).
    """
    # Backward compatibility: earlier versions read these two tensors straight
    # from notebook globals. Fall back to them, but skip the query report
    # instead of raising NameError when they are not defined.
    if query_vecs is None:
        query_vecs = globals().get('rand_vec')
    if query_targets is None:
        query_targets = globals().get('tar_vecs')
    report_query = query_vecs is not None and query_targets is not None

    # Report about ten times per run. The max() guard prevents a modulo by
    # zero when dyn_epochs < 10.
    log_every = max(1, dyn_epochs // 10)

    max_loss, mean_loss = 0, 0
    for _ep in range(dyn_epochs):
        loss_W = compute_stress(_model, tar_Wio)
        _optim.zero_grad()
        loss_W.backward()
        _optim.step()

        if _ep % log_every == 0:
            # Diagnostics only: a single forward pass, no autograd graph.
            with torch.no_grad():
                W_now = _model.forward()
                abs_err = torch.abs(W_now - tar_Wio)
                mean_loss = round(torch.sum(abs_err).item()/tar_Wio.numel(), 8)
                max_loss = round(torch.max(abs_err).item(), 8)

                if report_query:
                    # Print the current query loss
                    dyn_vecs = query_vecs@W_now
                    res_norms = torch.norm(query_targets-dyn_vecs, dim=1)
                    norm_loss = torch.sum(res_norms)/(dyn_vecs.shape[0]*dyn_vecs.shape[1])
                    print(norm_loss/torch.std(query_targets))

            # loss_W is the stress measured *before* this epoch's step; the
            # max / mean values are measured after it.
            print('---', _ep, '- DyN Loss:', loss_W.item(), '- Max Loss:', max_loss, '- Mean Loss:', mean_loss,'- Time:', datetime.now().time())

        if mean_loss < meanL_thres:
            return _model, max_loss, mean_loss

    return _model, max_loss, mean_loss


class dynMat(nn.Module):
    """Matrix parameterised as a weighted sum of pairwise-distance matrices.

    The represented matrix has shape ``(num_input, num_output)`` and equals
    ``sum_h lambdas_io[h] * _scale * cdist(input_Qs[h], output_Qs[h], p)``.

    Args:
        num_input: number of points on the row side of the matrix.
        num_output: number of points on the column side of the matrix.
        num_Qs: number of distance matrices that are summed (H).
        q_dim: coordinate dimension of every point.
        p: order of the norm passed to ``torch.cdist``.
        _scale: constant factor applied to every distance.
        device: where to allocate the parameters. When omitted, a module-level
            ``device`` variable is used if one exists; otherwise torch's
            default device is used.
    """

    def __init__(self, num_input, num_output, num_Qs, q_dim, p=1, _scale=5, device=None):
        super().__init__()

        if device is None:
            # Keep the notebook behaviour (global `device`) without making the
            # class unusable where that global has not been defined.
            device = globals().get('device')

        self.num_input = num_input
        self.num_output = num_output
        self.num_Qs = num_Qs
        self.q_dim = q_dim
        self.norm_p = p
        self._scale = _scale

        # num_input = number of points Y; coordinates start uniform in [0, 1)
        self.input_Qs = torch.nn.Parameter(torch.rand(num_Qs, num_input, q_dim, device=device))

        # num_output = number of points X; coordinates start uniform in [0, 1)
        self.output_Qs = torch.nn.Parameter(torch.rand(num_Qs, num_output, q_dim, device=device))

        # num_Qs = H; one scalar weight per distance matrix, shaped to
        # broadcast over the (num_Qs, num_input, num_output) distances
        self.lambdas_io = torch.nn.Parameter(torch.randn(num_Qs, 1, 1, device=device))

    @staticmethod
    def _quantize(coords, _prec):
        """Round ``coords`` down to the nearest multiple of ``_prec``."""
        return _prec*(torch.div(coords, _prec, rounding_mode='floor'))

    def forward(self, _prec=-1):
        """Return the ``(num_input, num_output)`` matrix currently represented.

        Args:
            _prec: ``-1`` (default) uses the coordinates as stored. Any other
                value is a grid step: coordinates are floored to a multiple of
                it before the distances are computed.
        """
        if _prec != -1:
            input_Qs = self._quantize(self.input_Qs, _prec)
            output_Qs = self._quantize(self.output_Qs, _prec)
        else:
            input_Qs = self.input_Qs
            output_Qs = self.output_Qs

        # Batched pairwise distances: (num_Qs, num_input, num_output)
        dist_io = self._scale*(torch.cdist(input_Qs, output_Qs, p=self.norm_p))
        # Weighted sum over the num_Qs distance matrices
        W_io = torch.sum(dist_io*self.lambdas_io, 0)
        return W_io
    


In [4]:

# Target matrix to approximate. Note that the matrix in this script is the
# transpose of the matrix in the paper.
num_rowQ = 500
num_colQ = 500
tar_mat = torch.randn((num_rowQ, num_colQ), device=device)

# Random query batch and its exact products with the target.
# "z = Ay" in paper is equivalent to "z = y@tar_mat" here.
rand_vec = torch.rand((1000, num_rowQ), device=device)
tar_vecs = torch.matmul(rand_vec, tar_mat)
print(tar_vecs.mean(), tar_vecs.std())


tensor(-0.3595) tensor(12.7227)


In [5]:

# Stopping criteria handed to update_DyN in the training cell below.
meanL_thres = 1e-10     # stop early once the mean absolute error falls below this
max_dynEpoch = 50_000   # upper bound on the number of optimisation steps



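The number of parameters required to reconstruct a matrix of shape a*b is q_dim*H_num*(X_num+Y_num), where X_num and Y_num are the numbers of points on the two sides of the matrix (its two dimensions a and b) and H_num is the number of distance matrices being summed (plus H_num scalar weights, one per distance matrix).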


In [6]:

# Each (q_dim, H_num) pair is one experiment.
# (5, 5) needs 5*5*(500+500) = 25000 = 0.1*500*500 coordinate parameters,
# i.e. one-tenth the parameters of the original matrix.
for q_dim, H_num in [(5, 5)]:
    print('------')
    configs = dict(H_num=H_num, q_dim=q_dim, norm_p=1, _scale=1)
    print(configs)

    DyMat_model = dynMat(
        num_rowQ,
        num_colQ,
        configs['H_num'],
        configs['q_dim'],
        p=configs['norm_p'],
        _scale=configs['_scale'],
    )
    DyMat_optim = torch.optim.Adam(DyMat_model.parameters(), lr=1e-4)

    DyMat_model, _max, _mean = update_DyN(
        DyMat_model, DyMat_optim, tar_mat, max_dynEpoch, meanL_thres=meanL_thres
    )

    # Final query error: summed per-query residual norm divided by the number
    # of entries of the product, relative to the spread of the exact products.
    dyn_vecs = torch.matmul(rand_vec, DyMat_model.forward(_prec=-1))
    res_norms = torch.norm(tar_vecs - dyn_vecs, dim=1)
    norm_loss = torch.sum(res_norms) / (dyn_vecs.shape[0] * dyn_vecs.shape[1])
    print(norm_loss / torch.std(tar_vecs))


------
{'H_num': 5, 'q_dim': 5, 'norm_p': 1, '_scale': 1}
tensor(5.7826, grad_fn=<DivBackward0>)
--- 0 - DyN Loss: 47.49249267578125 - Max Loss: 16.1501236 - Mean Loss: 6.5507715 - Time: 23:14:23.319784
tensor(0.0963, grad_fn=<DivBackward0>)
--- 5000 - DyN Loss: 1.960894227027893 - Max Loss: 6.60639238 - Mean Loss: 1.11402925 - Time: 23:19:48.717743
tensor(0.0285, grad_fn=<DivBackward0>)
--- 10000 - DyN Loss: 0.9692800045013428 - Max Loss: 4.73323345 - Mean Loss: 0.78511456 - Time: 23:25:20.959070
tensor(0.0240, grad_fn=<DivBackward0>)
--- 15000 - DyN Loss: 0.8706965446472168 - Max Loss: 4.39426851 - Mean Loss: 0.74406669 - Time: 23:31:10.191984
tensor(0.0217, grad_fn=<DivBackward0>)
--- 20000 - DyN Loss: 0.8403550982475281 - Max Loss: 4.47598696 - Mean Loss: 0.73072825 - Time: 23:37:00.727550


KeyboardInterrupt: 

In [7]:

# Re-evaluate the normalised query error of the model as it currently stands.
dyn_vecs = torch.matmul(rand_vec, DyMat_model.forward(_prec=-1))
res_norms = torch.norm(tar_vecs - dyn_vecs, dim=1)
n_entries = dyn_vecs.shape[0] * dyn_vecs.shape[1]
norm_loss = torch.sum(res_norms) / n_entries
print(norm_loss / torch.std(tar_vecs))


tensor(0.0217, grad_fn=<DivBackward0>)



Feed-forward through an FC layer whose weights can be represented by a single distance matrix.
On a general-purpose CPU there is a fixed warm-up overhead, so the speed improvement will be insignificant for smaller matrices (e.g., dim < 1e+4).


In [8]:

def preprocess(X):
    """Pre-compute per-column sort orders and ranks of ``X``.

    Args:
        X: array that is sorted independently along axis 0 (per column for a
            2-D input).

    Returns:
        ``(order, rank)``, both integer arrays shaped like ``X``:
        ``order[s, i]`` is the row index of the s-th smallest value in column
        ``i``; ``rank[j, i]`` is the position of row ``j`` in that sorted
        column (the inverse permutation of ``order[:, i]``).
    """
    # Sort once and reuse the result; the rank is the argsort of the order.
    order = np.argsort(X, axis=0)
    rank = np.argsort(order, axis=0)
    return order, rank

@jit(nopython=True)
def inner_loop(X,order,B,C,n,d,r_dim):
    """Combine prefix sums into one output value per row, for the first ``r_dim`` rows.

    For row ``k`` and column ``i`` the term
    ``X[k,i]*(2*C[q,i] - C[n-1,i]) + B[n-1,i] - 2*B[q,i]`` with
    ``q = order[k,i]`` is added up over the ``d`` columns.

    Args:
        X: 2-D coordinate array; only rows ``0..r_dim-1`` are read.
        order: integer array; ``order[k, i]`` is used as a row index into
            ``B`` and ``C`` (presumably the rank of ``X[k, i]`` within column
            ``i`` -- verify against the caller).
        B, C: 2-D arrays of running sums; row ``n-1`` is read as the total.
        n: number of rows of ``B`` / ``C``.
        d: number of columns to accumulate over.
        r_dim: number of output entries.

    Returns:
        1-D float array of length ``r_dim``.
    """
    last = n - 1
    z = np.zeros(r_dim)
    for k in range(r_dim):
        acc = 0.0
        for i in range(d):
            q = order[k, i]
            acc += X[k, i]*(2*C[q, i] - C[last, i]) + B[last, i] - 2*B[q, i]
        z[k] = acc
    return z

def query(X, order1, order2, y, r_dim):
    """Multiply an L1 distance matrix by ``y`` without materialising the matrix.

    When ``order1, order2 = preprocess(X)``, entry ``k`` of the result equals
    ``sum_i sum_j |X[k, i] - X[j, i]| * y[j]`` for the first ``r_dim`` rows of
    ``X``.

    Args:
        X: ``(n, d)`` coordinate array.
        order1: ``(n, d)`` per-column sort order of ``X``.
        order2: ``(n, d)`` per-column rank of each entry of ``X``.
        y: length-``n`` NumPy array of weights, one per row of ``X``.
        r_dim: number of leading rows of ``X`` to produce outputs for.

    Returns:
        1-D array of length ``r_dim``.
    """
    n, d = X.shape
    # The prefix sums must keep all d coordinate columns. Earlier code sliced
    # the coordinate axis with [:r_dim], which only worked because d <= r_dim.
    # B[s, i]: running sum of X[j, i] * y[j] over rows j in ascending order of column i.
    B = np.take_along_axis(X * y[:, None], order1, axis=0).cumsum(axis=0)
    # C[s, i]: running sum of y[j] in the same order.
    C = y[order1].cumsum(axis=0)
    res = inner_loop(X, order2, B, C, n, d, r_dim)
    return res


In [17]:

# Benchmark problem size: r_dim output rows, c_dim input columns, and
# q_dim-dimensional coordinates for every row / column point.
r_dim = 32 * 1024
c_dim = 36 * 1024
q_dim = 3

# Vector that the (r_dim, c_dim) weight matrix is multiplied with.
query_vec = np.random.standard_normal(c_dim)

Row_coords = np.random.standard_normal((r_dim, q_dim))
Col_coords = np.random.standard_normal((c_dim, q_dim))

# Row points first, column points after them.
RC_coords = np.vstack([Row_coords, Col_coords])

print(Row_coords.shape, Col_coords.shape, RC_coords.shape)



(32768, 3) (36864, 3) (69632, 3)


In [18]:

# Naive Method: obtain the weights by computing the full distance matrix with the L1 norm.
# NOTE: the result is a dense (r_dim, c_dim) array; at the sizes used in this
# notebook that is roughly 9 GiB if stored as float64 (presumably the dtype here).
# perf_counter is the monotonic, high-resolution clock intended for timing.
ss = time.perf_counter()
FC_weights = distance_matrix(Row_coords, Col_coords, p=1)#.astype(np.float32)
ee = time.perf_counter()
print('TC:', ee-ss)
FC_weights.shape


TC: 45.93584108352661


(32768, 36864)

In [19]:

# Naive Method: Compute the matrix-vector multiplication
naiveT_list = []
for _ in range(11):
    # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
    ss = time.perf_counter()
    naive_output = FC_weights@query_vec
    ee = time.perf_counter()
    naiveT_list.append(ee-ss)

print(naiveT_list)
# The first run is treated as warm-up and excluded; the divisor follows the
# list length so it stays correct if the number of runs changes.
print('Average Time Cost:', sum(naiveT_list[1:])/(len(naiveT_list)-1))
naive_output[:10]


[0.4056999683380127, 0.5634472370147705, 0.6175868511199951, 0.5821371078491211, 0.5890731811523438, 0.5718629360198975, 0.5798320770263672, 0.5889270305633545, 0.5879559516906738, 0.5862200260162354, 0.547415018081665]
Average Time Cost: 0.5814457416534424


array([327.02502683, 341.96596203, 360.42816509, 198.03636183,
       167.27976265, 432.80154072, 683.53682971, 706.63540331,
       376.82721384, 404.68833259])

In [20]:

# Faster Method: Feedforward the FC layer using the sub-models' states directly without computing the distance matrix.
RC_order1, RC_order2 = preprocess(RC_coords)

# The query weights only apply to the column points; the row points (stored
# first in RC_coords) get weight zero.
exp_query_vec = np.zeros(r_dim + c_dim)
exp_query_vec[-c_dim:] = query_vec

fastT_list = []
for i in range(11):
    # Millisecond-scale measurement: use the high-resolution monotonic clock.
    ss = time.perf_counter()
    fast_output = query(RC_coords, RC_order1, RC_order2, exp_query_vec, r_dim)
    ee = time.perf_counter()
    fastT_list.append(ee-ss)

print(fastT_list)
# The first run is treated as warm-up and excluded; the divisor follows the
# list length so it stays correct if the number of runs changes.
print('Average Time Cost:', sum(fastT_list[1:])/(len(fastT_list)-1))
fast_output[:10]


[0.005308866500854492, 0.0060770511627197266, 0.0044858455657958984, 0.005629777908325195, 0.004824161529541016, 0.00467681884765625, 0.004562854766845703, 0.004745006561279297, 0.005925178527832031, 0.0048139095306396484, 0.005073070526123047]
Average Time Cost: 0.005081367492675781


array([327.02502683, 341.96596203, 360.42816509, 198.03636183,
       167.27976265, 432.80154072, 683.53682971, 706.63540331,
       376.82721384, 404.68833259])