In [1]:
import numpy as np
import tensorflow as tf

In [2]:
import scipy.sparse as sp
import time

In [3]:
x = np.linspace(0, 1, 4)
xx, yy = np.meshgrid(x, x)

# input_img = np.tanh(np.random.rand(10, 200, 4, 4))
input_img = np.empty(shape=(10, 200, 4, 4))
for i in range(input_img.shape[1]):
    std = np.random.rand()
    c = np.random.rand()+1
    input_img[0, i] = c * np.exp(-0.5*(xx**2+yy**2)/std**2)

for i in range(1, input_img.shape[0]):
    temp = input_img[i-1, :, -1, :].copy()
    input_img[i, :, 1:, :] = input_img[i, :, 0:-1, :]
    input_img[i, :, 0, :] = temp

In [4]:
class periodic_padding(object):
    # only operates on square images
    def __init__(self, elp, erp, **kwargs):
        super().__init__()
        self.elp = elp
        self.erp = erp
        self.M_mat = None
        self.N_mat = None
        
    def build(self, input_shape):
        # super().build(input_shape)
        input_xlen = input_shape[-1]
        input_ylen = input_shape[-2] # must be the same as `input_xlen`
        M_mat = np.zeros(shape=(input_ylen+self.elp+self.erp, input_ylen), dtype=np.float32)
        for i in range(1, self.elp + 1):
            M_mat[self.elp - i, (input_ylen - i) % input_ylen] = 1.
        for i in range(self.elp, self.elp + input_ylen):
            M_mat[i, i-self.elp] = 1.
        for i in range(self.elp + input_ylen, M_mat.shape[0]):
            M_mat[i, (i - (self.elp + input_ylen)) % input_ylen] = 1.
        # B_mat = np.empty(shape=(A_mat.shape[1], A_mat.shape[0]), dtype=A_mat.dtype)
        N_mat = np.transpose(M_mat)
        
        self.M_mat = M_mat
        self.N_mat = N_mat

    def call(self, x, training=None):
        '''
        x has shape --> [batch_size, channels, y_len, x_len]
        '''
        return np.matmul(np.matmul(self.M_mat, x), self.N_mat)

In [5]:
kernel_size = 3
elp = int(0.5*(kernel_size-1))
erp = kernel_size - 1 - elp

pp_class = periodic_padding(elp, erp)
pp_class.build(input_img.shape)

In [6]:
input_img_padded = pp_class.call(input_img)

In [7]:
input_img.shape, input_img_padded.shape

((10, 200, 4, 4), (10, 200, 6, 6))

In [8]:
input_channels, input_ylen, input_xlen = input_img_padded.shape[-3:]
output_channels_og, output_ylen, output_xlen = input_img.shape[-3:]

In [9]:
use_bias = True

In [10]:
output_channels = 1

U_shape_rows_0 = kernel_size*kernel_size*input_channels
U_shape_cols_0 = input_channels*input_ylen*input_xlen

U_shape_cols = U_shape_cols_0
if use_bias == True:
    U_shape_rows_0 += 1
    U_shape_cols += 1
    
U_shape_rows_1 = U_shape_rows_0*output_channels
U_shape_rows = U_shape_rows_1*output_ylen*output_xlen*output_channels

In [11]:
# U = np.zeros(shape=(U_shape_rows, U_shape_cols), dtype=np.bool_)
row_ind = []
col_ind = []

for l0 in range(output_channels):
    begin_row_l0 = l0*int(U_shape_rows/output_channels)
    for i0 in range(output_ylen):
        begin_row_i0 = begin_row_l0 + U_shape_rows_1*output_xlen*i0
        for i1 in range(output_xlen):
            begin_row_i1 = begin_row_i0 + U_shape_rows_1*i1
#             for j0 in range(output_channels):
            begin_row_j0 = begin_row_i1 + U_shape_rows_0*l0
            for j1 in range(input_channels):
                begin_row_j1 = begin_row_j0 + j1*kernel_size*kernel_size
                begin_col_j1 = j1*input_xlen*input_ylen

                for k0 in range(kernel_size):
                    begin_col_k0 = begin_col_j1 + i0*input_xlen + i1 + k0*input_xlen
                    for k1 in range(kernel_size):
#                         print(begin_row_j1+k0*kernel_size+k1, begin_col_k0+k1)
#                         U[begin_row_j1+k0*kernel_size+k1, begin_col_k0+k1] = 1.0
                        row_ind.append(begin_row_j1+k0*kernel_size+k1)
                        col_ind.append(begin_col_k0+k1)
            if use_bias == True:
#                 print(begin_row_j0+input_channels*kernel_size*kernel_size, U_shape_cols-1)
#                 U[begin_row_j0+input_channels*kernel_size*kernel_size, U_shape_cols-1] = 1.0
                row_ind.append(begin_row_j0+input_channels*kernel_size*kernel_size)
                col_ind.append(U_shape_cols-1)
#             print('--')

In [12]:
row_ind = np.array(row_ind)
col_ind = np.array(col_ind)

In [13]:
data = np.ones_like(row_ind)

In [14]:
row_ind.shape[0]*32/(8*1024*1024)

0.10992431640625

In [15]:
U_mat = sp.coo_array((data, (row_ind, col_ind)), shape=(U_shape_rows, U_shape_cols)).tocsr()

In [16]:
# S_mat_lst = []
# S_mat_rows = output_channels*output_xlen*output_ylen
# S_mat_cols = U_shape_rows

# for i0 in range(U_shape_rows_1):
#     row_ind = np.arange(S_mat_rows)
#     col_ind = np.arange(i0, S_mat_cols, U_shape_rows_1)
# #     print(col_ind)
#     data = np.ones_like(row_ind)
#     S_mat_lst.append(
#         sp.coo_array((data, (row_ind, col_ind)), shape=(S_mat_rows, S_mat_cols))
#     )

In [17]:
# S_mat_lst

In [50]:
SU_mat_lst = []
SU_mat_rows = output_xlen*output_ylen*output_channels

U_rows_to_index = np.arange(0, U_shape_rows, U_shape_rows_1)-1
# print(U_shape_rows, U_shape_rows_1, U_rows_to_index.shape, U_rows_to_index)
for i0 in range(U_shape_rows_1):
    iter_time = time.time()
    U_rows_to_index += 1
    SU_mat_ind = []
    SU_mat_indptr = []
    SU_mat_indptr.append(len(SU_mat_ind))
#     if i0 == 0:
#         SU_mat_ind = []
#         SU_mat_indptr = []
#         SU_mat_indptr.append(len(SU_mat_ind))
#     else:
#         SU_mat_indptr[0] = 0
    elems_row_ind = 0
    for i1_iter in range(len(U_rows_to_index)):
        i1 = U_rows_to_index[i1_iter]
        row = U_mat.getrow(i1)
        row_ind = row.indices
        # row_indptr = row.indptr
        SU_mat_ind.extend(row_ind)
        SU_mat_indptr.append(len(SU_mat_ind))
#         if i0 == 0:
#             SU_mat_ind.extend(row_ind)
#             SU_mat_indptr.append(len(SU_mat_ind))
#         else:
#             SU_mat_ind[elems_row_ind:elems_row_ind+len(row_ind)] = row_ind
#             SU_mat_indptr[i1_iter+1] = elems_row_ind+len(row_ind)
        elems_row_ind += len(row_ind)
#     if i0 == 0:
#         SU_mat_ind = np.array(SU_mat_ind)
#         SU_mat_indptr = np.array(SU_mat_indptr)
#         SU_mat_data = np.ones_like(SU_mat_ind)
    SU_mat_ind = np.array(SU_mat_ind)
    SU_mat_indptr = np.array(SU_mat_indptr)
    SU_mat_data = np.ones_like(SU_mat_ind)
#     SU_mat_lst.append(sp.csr_array(
#         (SU_mat_data[0:elems_row_ind], SU_mat_ind[0:elems_row_ind], SU_mat_indptr),
#         shape=(SU_mat_rows, U_shape_cols),
#     ))
    SU_mat_lst.append(sp.csr_array(
        (SU_mat_data, SU_mat_ind, SU_mat_indptr),
        shape=(SU_mat_rows, U_shape_cols),
    ))
    print('iter_time : {} s.'.format(time.time() - iter_time))

iter_time : 0.0014030933380126953 s.
iter_time : 0.0012629032135009766 s.
iter_time : 0.0013272762298583984 s.
iter_time : 0.0013225078582763672 s.
iter_time : 0.0013272762298583984 s.
iter_time : 0.0012955665588378906 s.
iter_time : 0.0012848377227783203 s.
iter_time : 0.0013175010681152344 s.
iter_time : 0.002578258514404297 s.
iter_time : 0.003119945526123047 s.
iter_time : 0.003099203109741211 s.
iter_time : 0.0029311180114746094 s.
iter_time : 0.0029337406158447266 s.
iter_time : 0.002917051315307617 s.
iter_time : 0.002893209457397461 s.
iter_time : 0.002569437026977539 s.
iter_time : 0.0025391578674316406 s.
iter_time : 0.0026977062225341797 s.
iter_time : 0.0025217533111572266 s.
iter_time : 0.001527547836303711 s.
iter_time : 0.0011970996856689453 s.
iter_time : 0.0011906623840332031 s.
iter_time : 0.0012049674987792969 s.
iter_time : 0.00179290771484375 s.
iter_time : 0.002491474151611328 s.
iter_time : 0.002476215362548828 s.
iter_time : 0.0025539398193359375 s.
iter_time : 

iter_time : 0.0013391971588134766 s.
iter_time : 0.0008940696716308594 s.
iter_time : 0.0008509159088134766 s.
iter_time : 0.0008466243743896484 s.
iter_time : 0.0008456707000732422 s.
iter_time : 0.0010120868682861328 s.
iter_time : 0.0009002685546875 s.
iter_time : 0.0008428096771240234 s.
iter_time : 0.0008349418640136719 s.
iter_time : 0.0008389949798583984 s.
iter_time : 0.0008683204650878906 s.
iter_time : 0.0008389949798583984 s.
iter_time : 0.000858306884765625 s.
iter_time : 0.0008401870727539062 s.
iter_time : 0.0017802715301513672 s.
iter_time : 0.0018854141235351562 s.
iter_time : 0.00168609619140625 s.
iter_time : 0.0016624927520751953 s.
iter_time : 0.0016024112701416016 s.
iter_time : 0.0015108585357666016 s.
iter_time : 0.0017104148864746094 s.
iter_time : 0.001867055892944336 s.
iter_time : 0.0018508434295654297 s.
iter_time : 0.0018105506896972656 s.
iter_time : 0.0018401145935058594 s.
iter_time : 0.0017578601837158203 s.
iter_time : 0.0018451213836669922 s.
iter_tim

iter_time : 0.13026165962219238 s.
iter_time : 0.0008037090301513672 s.
iter_time : 0.000789642333984375 s.
iter_time : 0.0007777214050292969 s.
iter_time : 0.0007703304290771484 s.
iter_time : 0.0009152889251708984 s.
iter_time : 0.0009016990661621094 s.
iter_time : 0.0009012222290039062 s.
iter_time : 0.0008995532989501953 s.
iter_time : 0.000911712646484375 s.
iter_time : 0.0008995532989501953 s.
iter_time : 0.0008962154388427734 s.
iter_time : 0.0011577606201171875 s.
iter_time : 0.0009546279907226562 s.
iter_time : 0.0009021759033203125 s.
iter_time : 0.000896453857421875 s.
iter_time : 0.0008931159973144531 s.
iter_time : 0.0008902549743652344 s.
iter_time : 0.0009021759033203125 s.
iter_time : 0.0008962154388427734 s.
iter_time : 0.0008974075317382812 s.
iter_time : 0.00089263916015625 s.
iter_time : 0.0009109973907470703 s.
iter_time : 0.0008962154388427734 s.
iter_time : 0.0008897781372070312 s.
iter_time : 0.0008935928344726562 s.
iter_time : 0.0009095668792724609 s.
iter_tim

In [51]:
A_mat = np.zeros(shape=(U_shape_rows_1, U_shape_rows_1))
B_vec = np.zeros(shape=U_shape_rows_1)
lambda_reg = 1e-5

cols = input_img_padded.shape[-3]*input_img_padded.shape[-2]*input_img_padded.shape[-1]
input_x_mat = input_img_padded.reshape((input_img_padded.shape[0], cols))
input_x_mat = np.concatenate(
    (
        input_x_mat,
        np.ones(shape=(input_x_mat.shape[0], 1))
    ),
    axis=1,     
)
kernels_lst = []
kernels_bias_lst = []

SU_x_lst = []
input_x = input_x_mat[:].transpose()
for i1 in range(U_shape_rows_1):
    SUi1_x = SU_mat_lst[i1].dot(input_x)
    SU_x_lst.append(SUi1_x)
    for i2 in range(i1):
        # SUi2_x = SU_mat_lst[i2].dot(input_x)
        SUi2_x = SU_x_lst[i2]
        A_mat[i1, i2] += np.sum(SUi2_x*SUi1_x)
    A_mat[i1, i1] += np.sum(SUi1_x*SUi1_x)
for i1 in range(U_shape_rows_1):
    for i2 in range(i1+1, U_shape_rows_1):
        A_mat[i1, i2] = A_mat[i2, i1]
A_mat_solver = A_mat + lambda_reg*np.eye(A_mat.shape[0])

for k0 in range(output_channels_og):
    outchannel_time = time.time()
    # A_mat[:, :] = 0.0
    B_vec[:] = 0.
    # for i0 in range(input_img_padded.shape[0]):
    #     single_image_time = time.time()
    #     input_x = input_x_mat[i0]
    #     output_y = input_img[i0, k0].flatten()
    #     for i1 in range(U_shape_rows_1):
    #         SUi1_x = SU_mat_lst[i1].dot(input_x)
    #         B_vec[i1] += np.dot(output_y, SUi1_x)
    #         for i2 in range(i1):
    #             SUi2_x = SU_mat_lst[i2].dot(input_x)
    #             A_mat[i1, i2] += np.dot(SUi2_x, SUi1_x)
    #         A_mat[i1, i1] += np.dot(SUi1_x, SUi1_x)
    #     print('single_image_time : {} s.'.format(time.time() - single_image_time))
#     single_image_time = time.time()
#     # input_x = input_x_mat[:].transpose()
#     output_y = input_img[:, k0].reshape(input_img.shape[0], -1).transpose()
#     for i1 in range(U_shape_rows_1):
#         SUi1_x = SU_mat_lst[i1].dot(input_x)
#         B_vec[i1] += np.sum(output_y*SUi1_x)
#         for i2 in range(i1):
#             SUi2_x = SU_mat_lst[i2].dot(input_x)
#             A_mat[i1, i2] += np.sum(SUi2_x*SUi1_x)
#         A_mat[i1, i1] += np.sum(SUi1_x*SUi1_x)
#     print('single_image_time : {} s.'.format(time.time() - single_image_time))

#     for i1 in range(U_shape_rows_1):
#         for i2 in range(i1+1, U_shape_rows_1):
#             A_mat[i1, i2] = A_mat[i2, i1]

#     A_mat_solver = A_mat + lambda_reg*np.eye(A_mat.shape[0])

    output_y = input_img[:, k0].reshape(input_img.shape[0], -1).transpose()
    for i1 in range(U_shape_rows_1):
        B_vec[i1] += np.sum(output_y*SU_x_lst[i1])

    kernels = np.linalg.inv(A_mat_solver) @ B_vec.reshape(B_vec.shape[0], 1)
    
    kernels_lst.append(kernels[0:-1].flatten().reshape(input_channels, kernel_size, kernel_size))
    kernels_bias_lst.append(kernels.flatten()[-1])
    print('outchannel_time : {} s.'.format(time.time() - outchannel_time))
    

outchannel_time : 0.2398090362548828 s.
outchannel_time : 0.3603050708770752 s.
outchannel_time : 0.46735191345214844 s.
outchannel_time : 0.5430130958557129 s.
outchannel_time : 0.25498175621032715 s.
outchannel_time : 0.23081707954406738 s.
outchannel_time : 0.26042914390563965 s.
outchannel_time : 0.23119020462036133 s.
outchannel_time : 0.24800515174865723 s.
outchannel_time : 0.2476210594177246 s.
outchannel_time : 0.22965049743652344 s.
outchannel_time : 0.2948012351989746 s.
outchannel_time : 0.47194957733154297 s.
outchannel_time : 0.2622413635253906 s.
outchannel_time : 0.25153255462646484 s.
outchannel_time : 0.33835482597351074 s.
outchannel_time : 0.29275012016296387 s.
outchannel_time : 0.24633264541625977 s.
outchannel_time : 0.2480001449584961 s.
outchannel_time : 0.2643718719482422 s.
outchannel_time : 0.5719211101531982 s.
outchannel_time : 0.2683892250061035 s.
outchannel_time : 0.24035334587097168 s.
outchannel_time : 0.2419590950012207 s.
outchannel_time : 0.2417726

In [52]:
kernels_bias_lst

[-1.4366654131069224e-09,
 -1.4453679520835237e-09,
 8.001215162423781e-10,
 4.790737817210622e-10,
 5.78585138861708e-11,
 4.57039849591754e-10,
 -1.546363094728719e-09,
 -3.2403504980085097e-10,
 -1.5702420248359994e-09,
 -1.0062675404272515e-09,
 -8.743405464941082e-10,
 9.20670609412309e-11,
 2.0203197297000624e-09,
 -8.847063092991933e-10,
 -7.855381732286776e-10,
 -1.2305094142258035e-09,
 -8.964458511728927e-10,
 -1.6072046534053506e-09,
 1.8118623689518468e-09,
 -1.277767563140272e-09,
 -1.2537813421643813e-09,
 -9.760470010510025e-10,
 -1.2655439200332037e-09,
 1.5697751758912927e-09,
 -4.2709869868261737e-10,
 5.419703618100757e-10,
 8.189877058446407e-10,
 1.3498242409934949e-09,
 3.492745694591408e-10,
 -7.511573792886852e-10,
 -8.835925432694843e-10,
 -1.5847774981283574e-10,
 7.199032759577285e-10,
 3.9460137253529257e-10,
 -7.957107986481237e-10,
 -1.2197756740524451e-09,
 -9.507695523598405e-10,
 1.1913909685878407e-11,
 -1.5988589693890098e-09,
 -8.958470225410098e-10,

In [68]:
kernels_lst[-1][-2:]

array([[[ 2.84818149e-04, -3.98885197e-04,  4.10040759e-04],
        [-4.73547028e-04,  7.48047260e-03, -1.22115988e-03],
        [-1.16140110e-04,  5.12788559e-05,  1.69421973e-05]],

       [[-4.26416147e-05,  1.35178464e-04, -6.36524558e-05],
        [ 8.38984363e-06,  9.77448000e-03,  1.12372887e-05],
        [ 5.43550825e-05,  2.78848096e-05,  1.85565632e-05]]])

In [54]:
kernels_lst[0].shape

(200, 3, 3)

In [55]:
A_mat_solver

array([[3.92419584e+00, 1.55843294e-17, 1.95224253e-69, ...,
        1.89266724e+00, 1.48407465e+00, 1.98095579e+00],
       [1.55843294e-17, 3.92419584e+00, 1.55843294e-17, ...,
        2.05248711e+00, 1.89266724e+00, 1.98095579e+00],
       [1.95224253e-69, 1.55843294e-17, 3.92419584e+00, ...,
        9.89520701e-01, 2.05248711e+00, 1.98095579e+00],
       ...,
       [1.89266724e+00, 2.05248711e+00, 9.89520701e-01, ...,
        1.51837124e+01, 1.40999151e+01, 1.61747187e+01],
       [1.48407465e+00, 1.89266724e+00, 2.05248711e+00, ...,
        1.40999151e+01, 1.51837124e+01, 1.61747187e+01],
       [1.98095579e+00, 1.98095579e+00, 1.98095579e+00, ...,
        1.61747187e+01, 1.61747187e+01, 1.60000010e+02]])

In [56]:
B_vec

array([ 2.41375276,  2.61757393,  1.26195364, ..., 12.94515374,
       12.02115031, 16.17471871])

In [57]:
SU_x_lst

[array([[1.19610338e-313, 0.00000000e+000, 0.00000000e+000,
         0.00000000e+000, 0.00000000e+000, 0.00000000e+000,
         0.00000000e+000, 0.00000000e+000, 0.00000000e+000,
         0.00000000e+000],
        [4.86767698e-157, 0.00000000e+000, 0.00000000e+000,
         0.00000000e+000, 0.00000000e+000, 0.00000000e+000,
         0.00000000e+000, 0.00000000e+000, 0.00000000e+000,
         0.00000000e+000],
        [1.93312663e-174, 0.00000000e+000, 0.00000000e+000,
         0.00000000e+000, 0.00000000e+000, 0.00000000e+000,
         0.00000000e+000, 0.00000000e+000, 0.00000000e+000,
         0.00000000e+000],
        [1.21080989e-226, 0.00000000e+000, 0.00000000e+000,
         0.00000000e+000, 0.00000000e+000, 0.00000000e+000,
         0.00000000e+000, 0.00000000e+000, 0.00000000e+000,
         0.00000000e+000],
        [4.86767698e-157, 1.19610338e-313, 0.00000000e+000,
         0.00000000e+000, 0.00000000e+000, 0.00000000e+000,
         0.00000000e+000, 0.00000000e+000, 0.0000000

In [62]:
SU_mat_lst[10].indices

array([37, 38, 39, 40, 43, 44, 45, 46, 49, 50, 51, 52, 55, 56, 57, 58],
      dtype=int32)