In [1]:
import math 
import numpy as np
from scipy import linalg
import torch 
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [107]:
def toeplitz_conv2d(inp, w, stride, padding='same'):
    """
    %[Y,X,W]  = conv_lay_mat(x,h,s) : FCN form of a MIMO convolutional layer
    % h: DxCxL kernel matrix 
    % with D: number of output channels
    %      C: number of input channels
    %      L: kernel 2D dimensions
    % x: CxT(1)xT(2)xK input
    % with T: image size
    %      K: number of samples
    % s: stride
    % pad: 'full' for generating outputs of length T' = T-L+1
    %      otherwise default is same length T' = T
    % Y: D x (K T'(1) T'(2)/s^2) output matrix
    % X: (C L(1) L(2)) x (K T'(1) T'(2)/s^2) input matrix
    % W: D x (C L(1) L(2)) matrix 
    """

    batch_size, channels, height, width = inp.shape
    out_size, _, kernel_height, kernel_width = w.shape
    
    W = []
    X = []

    if padding == 'same':
        L1 = math.floor((L-1)/2)
        L2 = math.ceil((L-1)/2)
    elif padding == 'full':
        L1 = [0 0]
        L2 = [0 0]
    else:
        print ('padding undefined')

    for c in range(channels):
        W = [W reshape(h(:,c,:,:),[D L(1)*L(2)])];
        XX = [];
        for k = 1:K
             for t2 = L(2)-L1(2):s:T(2)+L2(2)
    %            for t1 = L(1)-L1(1):T(1)+L2(1)
                    %xxt = [];
                    XXT = [];
                    for tt2 = t2:-1:t2-L(2)+1
                        if tt2 > 0 & tt2 <= T(2)
                            XXT = [XXT;toeplitz([x(c,L(1)-L1(1):-1:1,tt2,k) zeros(1,L1(1))],[x(c,L(1)-L1(1):end,tt2,k) zeros(1,L2(1))])];
                        else
                            XXT = [XXT;zeros(L(1),T(1))];
                        end
                        %xxt = [xxt x(c,t1:-1:t1-L(1)+1,tt2,k)];
                    end
                    %XX = [XX xxt(:)];
                    XX = [XX XXT(:,1:s:end)];
    %            end
             end

            %XXT = toeplitz([x(c,L-L1:-1:1,k) zeros(1,L1)],[x(c,L-L1:end,k) zeros(1,L2)]);
            %XX = [XX XXT(:,1:s:end)];
        end
        X = [X;XX];
    end

    %size(X)
    %size(W)
    Y = W*X;

    end

In [47]:
# k = np.random.randn(8, 3, 3, 3)
# i = np.random.randn(3, 10, 12)

# T = toeplitz_mult_ch(k, i.shape)
# out = T.dot(i.flatten()).reshape((1, 8, 8, 10))

# check correctness of convolution via toeplitz matrix
print(np.sum((out - F.conv2d(torch.tensor(i).view(1,3,10,12), torch.tensor(k)).numpy())**2))

7.271357924035143e-28


In [163]:
# Convolution is equivalent with Unfold + Matrix Multiplication + Fold (or view to output shape)
inp = torch.randn(32, 3, 32, 32)
w = torch.randn(2, 3, 4, 4)
inp_unf = torch.nn.functional.unfold(inp, (4, 4), stride=1, padding=0)
print (inp_unf.shape, w.view(w.size(0), -1).t().shape)
out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
print (out_unf.shape)
out = out_unf.view(32, 2, 29, 29)
torch.sum((F.conv2d(inp, w, stride=1, padding=0) - out)**2)
# tensor(1.9073e-06)

torch.Size([32, 48, 841]) torch.Size([48, 2])
torch.Size([32, 2, 841])


tensor(5.1425e-08)

In [156]:
T = toeplitz_mult_ch(w, inp.squeeze().shape)
out_toe = T.dot(inp.flatten())
print (out_toe.shape)

(2, 29, 29)
(1682,)


In [164]:
z = out_unf 
fc = torch.sum(z**2, axis=0)
fc = torch.mean(fc)
aW_unf = torch.zeros(w.view(w.size(0), -1).shape)

if fc > 0:
    tW = aW_unf
    for k in range(1):
        tW += 2 * (z[k, :] @ inp_unf[k, :].T)
    aW_unf = -fc * tW / torch.linalg.norm(tW)**2
    
print (aW_unf.shape, out_unf.shape, inp_unf.shape)

torch.Size([2, 48]) torch.Size([32, 2, 841]) torch.Size([32, 48, 841])


In [152]:
z = torch.from_numpy(out_toe).view(1, -1).float()
fc = torch.sum(z**2, axis=0)
fc = torch.mean(fc)
aW_toe = torch.zeros(T.shape)

if fc > 0:
    tW = aW_toe
    for k in range(1):
        tW += 2 * (z[k, :] @ inp[k, :].flatten().T)
    aW_toe = -fc * tW / torch.linalg.norm(tW)**2

RuntimeError: inconsistent tensor size, expected tensor [1682] and src [3072] to have the same number of elements, but got 1682 and 3072 elements respectively

In [161]:
out_toe.reshape(1, 2, 29, 29).shape, inp.shape

((1, 2, 29, 29), torch.Size([1, 3, 32, 32]))