In [20]:
import numpy as np
from scipy import linalg
import torch 
import torch.nn.functional as F

In [107]:
def toeplitz_1_ch(kernel, input_size):
    # shapes
    k_h, k_w = kernel.shape
    i_h, i_w = input_size
    o_h, o_w = i_h-k_h+1, i_w-k_w+1

    # construct 1d conv toeplitz matrices for each row of the kernel
    toeplitz = []
    for r in range(k_h):
        toeplitz.append(linalg.toeplitz(c=(kernel[r,0], *np.zeros(i_w-k_w)), r=(*kernel[r], *np.zeros(i_w-k_w))) ) 

    # construct toeplitz matrix of toeplitz matrices (just for padding=0)
    h_blocks, w_blocks = o_h, i_h
    h_block, w_block = toeplitz[0].shape

    W_conv = np.zeros((h_blocks, h_block, w_blocks, w_block))

    for i, B in enumerate(toeplitz):
        for j in range(o_h):
            W_conv[j, :, i+j, :] = B

    W_conv.shape = (h_blocks*h_block, w_blocks*w_block)

    return W_conv

def toeplitz_mult_ch(kernel, input_size):
    """Compute toeplitz matrix for 2d conv with multiple in and out channels.
    Args:
        kernel: shape=(n_out, n_in, H_k, W_k)
        input_size: (n_in, H_i, W_i)"""

    kernel_size = kernel.shape
    output_size = (kernel_size[0], input_size[1] - (kernel_size[2] - 1), input_size[2] - (kernel_size[3] - 1))
    print (output_size)
    T = np.zeros((output_size[0], int(np.prod(output_size[1:])), input_size[0], int(np.prod(input_size[1:]))))

    for i, ks in enumerate(kernel):  # loop over output channel
        for j, k in enumerate(ks):  # loop over input channel
            T_k = toeplitz_1_ch(k, input_size[1:])
            T[i, :, j, :] = T_k

    T.shape = (np.prod(output_size), np.prod(input_size))

    return T

In [47]:
# k = np.random.randn(8, 3, 3, 3)
# i = np.random.randn(3, 10, 12)

# T = toeplitz_mult_ch(k, i.shape)
# out = T.dot(i.flatten()).reshape((1, 8, 8, 10))

# check correctness of convolution via toeplitz matrix
print(np.sum((out - F.conv2d(torch.tensor(i).view(1,3,10,12), torch.tensor(k)).numpy())**2))

7.271357924035143e-28


In [163]:
# Convolution is equivalent with Unfold + Matrix Multiplication + Fold (or view to output shape)
inp = torch.randn(32, 3, 32, 32)
w = torch.randn(2, 3, 4, 4)
inp_unf = torch.nn.functional.unfold(inp, (4, 4), stride=1, padding=0)
print (inp_unf.shape, w.view(w.size(0), -1).t().shape)
out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
print (out_unf.shape)
out = out_unf.view(32, 2, 29, 29)
torch.sum((F.conv2d(inp, w, stride=1, padding=0) - out)**2)
# tensor(1.9073e-06)

torch.Size([32, 48, 841]) torch.Size([48, 2])
torch.Size([32, 2, 841])


tensor(5.1425e-08)

In [156]:
T = toeplitz_mult_ch(w, inp.squeeze().shape)
out_toe = T.dot(inp.flatten())
print (out_toe.shape)

(2, 29, 29)
(1682,)


In [164]:
z = out_unf 
fc = torch.sum(z**2, axis=0)
fc = torch.mean(fc)
aW_unf = torch.zeros(w.view(w.size(0), -1).shape)

if fc > 0:
    tW = aW_unf
    for k in range(1):
        tW += 2 * (z[k, :] @ inp_unf[k, :].T)
    aW_unf = -fc * tW / torch.linalg.norm(tW)**2
    
print (aW_unf.shape, out_unf.shape, inp_unf.shape)

torch.Size([2, 48]) torch.Size([32, 2, 841]) torch.Size([32, 48, 841])


In [152]:
z = torch.from_numpy(out_toe).view(1, -1).float()
fc = torch.sum(z**2, axis=0)
fc = torch.mean(fc)
aW_toe = torch.zeros(T.shape)

if fc > 0:
    tW = aW_toe
    for k in range(1):
        tW += 2 * (z[k, :] @ inp[k, :].flatten().T)
    aW_toe = -fc * tW / torch.linalg.norm(tW)**2

RuntimeError: inconsistent tensor size, expected tensor [1682] and src [3072] to have the same number of elements, but got 1682 and 3072 elements respectively

In [161]:
out_toe.reshape(1, 2, 29, 29).shape, inp.shape

((1, 2, 29, 29), torch.Size([1, 3, 32, 32]))