In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchsummary import summary
import torch.optim as optim
from torch.autograd import Function
def get_parameter_number(net):
    total_num = sum(p.numel() for p in net.parameters())
    trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}

dtype = torch.float

class RandomDataset(Dataset):

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, *size ,dtype = dtype)
        self.label = torch.rand(1, dtype = dtype)

    def __getitem__(self, index):
        return self.data[index], self.label

    def __len__(self):
        return self.len

        
def Random_DataLoader():
    input_size = [3]

    train_img_loader = DataLoader(dataset=RandomDataset(input_size, length = 100),
                         batch_size=2, shuffle=True)
    val_img_loader =  DataLoader(dataset=RandomDataset(input_size, length = 10),
                         batch_size=2, shuffle=False)

    return train_img_loader, val_img_loader


trainloader, valloader =  Random_DataLoader()

In [125]:
class Linear(nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        super(Linear, self).__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.weight = nn.Parameter(torch.Tensor(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(output_features))
        else:
            self.register_parameter('bias', None)
        self.weight.data.uniform_(-0.1, 0.1)
        if bias is not None:
            self.bias.data.uniform_(-0.1, 0.1)
        
    def forward(self, input):
        # See the autograd section for explanation of what happens here.
        return LinearFunction.apply(input, self.weight, self.bias)

    def extra_repr(self):
        # (Optional)Set the extra information about this module. You can test
        # it by printing an object of this class.
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None)

In [190]:
demo_net = Linear(3,1)

criterion = nn.MSELoss()
optimizer = optim.SGD(demo_net.parameters(), lr=0.001, momentum=0.9)


for epoch in range(1):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        
        inputs, labels = data
        
        optimizer.zero_grad()

        outputs = demo_net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 20 == 19:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0
        break
print('Finished Training')

Finished Training


In [201]:
class LinearFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())

        
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
            
        return output

    @staticmethod
    def backward(ctx, grad_output):
        
        print(grad_output.size())
        input, weight, bias = ctx.saved_tensors
        
        grad_input = grad_weight = grad_bias = None
        
#         print(ctx.needs_input_grad)
        if ctx.needs_input_grad[0]:
            print('weight:',weight.size())
            grad_input = grad_output.mm(weight)

            
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)

                 
        if bias is not None and ctx.needs_input_grad[2]:
            
            grad_bias = grad_output.sum(0).squeeze(0)
            grad_bias = grad_bias.view_as(bias)

            
            
#         print('grad_output',grad_output)
#         print('grad_input',grad_input)
            
        return grad_input, grad_weight, grad_bias
        # there backward dl/dw, dl/db
        # and the grad_output is dl/dy

In [202]:
# test two demo backward

criterion = nn.MSELoss()
in_random = torch.rand([1,3])
label = torch.Tensor([1])

demo_net1 = Linear(3,2)
demo_net2 = Linear(2,1)
out_put = demo_net1(in_random)
out_put = demo_net2(out_put)

loss = criterion(out_put, label)

loss.backward()

torch.Size([1, 1])
weight: torch.Size([1, 2])
torch.Size([1, 2])


In [None]:
### test 2d transform matrix
input_matrix = torch.rand(2, 6)

kernel_size = 3
input_channels = 2
output_channels = 2
ks = kernel_size // 2
per_col = len(ls) - ks -1
transform_maxtrix = torch.zeros([per_col * 2 , kernel_size])
fil  = torch.rand([3,output_channels])


In [None]:
for index in range(len(input_matrix)):
    ls = input_matrix[index,:]
    for i in range(ks, len(ls) - ks):
        transform_maxtrix[i-1 + index * per_col,:] = ls[i-ks: i+ks+1]


output_matrix = transform_maxtrix.mm(fil)
output = output.T
output.size()

In [None]:
### test reshape order
a = np.array([[[111,112],[121,122]],[[211,212],[221,222]]])
a.reshape(2,-1, order = 'F')

In [None]:
### test conv1d kernel and input shape
# input size = [conv_num, kernel_size * in_channels]
# kernel matirx shape = [kernel_size * in_channels, out_channels]

# test no batch_size
a = torch.rand([3,4,5])
# a shape = [in_channels, row, col]
a = a.reshape([3, -1])
# a shape = [in_channels, row * col]

# 


In [None]:
### test 4d transform matrix

a = torch.rand([2,3,4,5])
b = a.reshape(-1,a.size(3))
len_row = a.size(3)
# b.size() = [batch_size * x * y, z]
# print(a)
# print(b)

In [None]:
### test 4d transform matrix
half_kernel_size = kernel_size // 2
kernel_size = 3
index_start = half_kernel_size
index_end = len_row - half_kernel_size

res = torch.empty([0])
for i in b:
    for index in range(index_start, index_end):
        temp  = i[index-half_kernel_size : index+half_kernel_size+1]
        res = torch.cat([res,temp],dim = 0)  
    # there is copy, don't warry about editing other torch

res = res.reshape(-1,kernel_size)
res.size()

In [None]:
### test 4d transform matrix
kernel = torch.rand([3,1])
out = res.mm(kernel)
print(out.size())
c = out.reshape(2,3,4,-1)
print(c.size())

In [None]:
#test in_channels conv1
a = torch.rand([2,3,4,5])
# a.size() = batch_size, in_channels, row, col
# a.size() = batch_size, row, col, in_channels
a =a.permute(0, 2, 3, 1)
a.size()
len_row = a.size(3)
b = a.reshape(-1,len_row)


In [None]:
#test in_channels conv1

kernel_size = 3
half_kernel_size = kernel_size // 2
index_start = half_kernel_size
index_end = len_row - half_kernel_size

res = torch.empty([0])
for i in b:
    for index in range(index_start, index_end):
        temp  = i[index-half_kernel_size : index+half_kernel_size+1]
        res = torch.cat([res,temp],dim = 0)  
    # there is copy, don't warry about editing other torch

res = res.reshape(-1,kernel_size)
res.size()

In [None]:
#test in_channels conv1

kernel = torch.rand([3,1])
out = res.mm(kernel)


In [None]:
#test in_channels conv1

c = out.reshape(2,4,5,-1)
c =c.permute(0, 3, 1, 2)

# c.size() = batch, col, row, out_channels
print(c.size())
# print(out)
# print(c)
# out.size = [batch_size, out_channels, row, col]

In [None]:
#test row conv1

c =c.permute(0, 1, 3, 2)
# a.size() = batch_size, col, , out_channels

print(c.size())
len_row = c.size(3)
b = c.reshape(-1,len_row)


In [None]:
### test 4d transform matrix
def matrix3d_matrix1d(input_matrix, kernel_size ):
    
    input_row = input_matrix.size(0)
    input_col = input_matrix.size(1)
    
    ks = kernel_size // 2
    trans_subrow = len(input_matrix[0]) - ks -1
    trans_row = trans_subrow * input_row
    trans_col = kernel_size
    transform_maxtrix = torch.zeros([trans_row , trans_col])
    
    for index in range(input_row):
        
        ls = input_matrix[index,:]
        
        for i in range(ks, len(ls) - ks):
            transform_maxtrix[i-1 + index * trans_subrow,:] = ls[i-ks: i+ks+1]
    return transform_maxtrix

def myconv1d(input_m, kernel,kernel_size):
    # batch_size, in_channels, in_row, in_col
    
    temp_m = input_m.reshape(-1,input_m.size(1))
    res = matrix3d_matrix1d(temp_m, kernel_size )
    
    res = res.mm(kernel)
    res = res.reshape(input_m.size(0),-1,input_m.size(2),kernel.size(1))
    
    return res 


In [None]:
#test row conv1


kernel_size = 3
half_kernel_size = kernel_size // 2
index_start = half_kernel_size
index_end = len_row - half_kernel_size

res = torch.empty([0])
for i in b:
    for index in range(index_start, index_end):
        temp  = i[index-half_kernel_size : index+half_kernel_size+1]
        res = torch.cat([res,temp],dim = 0)  
    # there is copy, don't warry about editing other torch

res = res.reshape(-1,kernel_size)
res.size()

In [None]:
#test row conv1

kernel = torch.rand([3,1])
out = res.mm(kernel)


In [None]:
#test in_channels conv1

d = out.reshape(2, 1, 5,-1)
d =d.permute(0, 1, 3, 2)

# c.size() = batch, col, row, out_channels
# print(c.size())
# print(out)
# print(c)
# out.size = [batch_size, out_channels, row, col]
d.size()

In [None]:
#test col conv1

e =d
# a.size() = batch_size, col, , out_channels

print(e.size())
len_row = e.size(3)
b = e.reshape(-1,len_row)

kernel_size = 3
half_kernel_size = kernel_size // 2
index_start = half_kernel_size
index_end = len_row - half_kernel_size

res = torch.empty([0])
for i in b:
    for index in range(index_start, index_end):
        temp  = i[index-half_kernel_size : index+half_kernel_size+1]
        res = torch.cat([res,temp],dim = 0)  
    # there is copy, don't warry about editing other torch

res = res.reshape(-1,kernel_size)
res.size()

In [None]:
#test col conv1

kernel = torch.rand([3,1])
out = res.mm(kernel)

f = out.reshape(2, 1, 2,-1)

# c.size() = batch, col, row, out_channels
# print(c.size())
# print(out)
# print(c)
# out.size = [batch_size, out_channels, row, col]
f.size()

In [None]:
#test mul_out_channels, basis configure
batch_size = 2
in_channels = 3
out_channels = 2
row = 4
col = 5
my_input = torch.rand([batch_size,in_channels,row,col])
L_wight_size = 3
# a.size() = batch_size, row, col, in_channels

In [None]:
#test mul_out_channels, in_channels part

in_L =my_input.permute(0, 2, 3, 1)
print('in_L:',my_input.size())

len_row = in_L.size(3)
b = in_L.reshape(-1,len_row)

L_wight_size = 3
half_L_ws = L_wight_size // 2
index_start = half_L_ws
index_end = len_row - half_L_ws

res = []

for i in b:
    for index in range(index_start, index_end):
        temp  = i[index-half_L_ws : index+half_L_ws+1]
        res.append(temp)
        
res = torch.cat(res)  
res = res.reshape(-1,L_wight_size)
print('res:', res.size())



L_weight = torch.rand([L_wight_size, out_channels])
out_L = res.mm(L_weight)


out_L = out_L.reshape(2,4,5,-1)
out_L =out_L.permute(0, 3, 1, 2)

print('out_L:',out_L.size())


In [None]:
in_H

In [None]:
#test mul_out_channels, row part

in_H =out_L.permute(1, 0, 3, 2)
# in_H shape = out_channels, batch_size, col, row

print('in_H:', out_L.size())
print('permute in_H:', in_H.size())

len_row = in_H.size(3)

H_wight_size = 3
half_H_ws = H_wight_size // 2
index_start = half_H_ws
index_end = len_row - half_H_ws


res = []
for per_channel in in_H:
    per_channel = per_channel.reshape(-1,len_row)
    # per_channel size = batch_size, col, row
    # per_channel size = batch_size * col, row
    temp_res = []
    
    for i in per_channel:
        for index in range(index_start, index_end):
            temp  = i[index-half_H_ws : index+half_H_ws+1]
            temp_res.append(temp)
            
    temp_res = torch.cat(temp_res,dim = 0)
    temp_res = temp_res.reshape(-1, H_wight_size)
    res.append(temp_res)

    
res = torch.cat(res,dim = 0)
res = res.reshape(out_channels, -1, H_wight_size)

print('res:', res.size())

out_H = []
H_weight = torch.rand([out_channels, H_wight_size]) #out_channels,1])

for i in range(out_channels):
    sub_weight = H_weight[i,:]
    sub_weight.unsqueeze_(1)
    sub_feature = res[i,:,:]
    temp = sub_feature.mm(sub_weight)
    out_H.append(temp)

out_H = torch.cat(out_H,dim=0)
# print('out:', out.size())
out_H = out_H.reshape(out_channels, batch_size, col,-1)
out_H =out_H.permute(1, 0, 3, 2)
print('out_H:', out_H.size())



In [None]:
#test mul_out_channels, row part

in_H =out_L.permute(1, 0, 3, 2)
# in_H shape = out_channels, batch_size, col, row

print('in_H:', out_L.size())
print('permute in_H:', in_H.size())

len_row = in_H.size(3)

H_wight_size = 3
half_H_ws = H_wight_size // 2
index_start = half_H_ws
index_end = len_row - half_H_ws


res = []
for per_channel in in_H:
    per_channel = per_channel.reshape(-1,len_row)
    # per_channel size = batch_size, col, row
    # per_channel size = batch_size * col, row
    temp_res = []
    
    for i in per_channel:
        for index in range(index_start, index_end):
            temp  = i[index-half_H_ws : index+half_H_ws+1]
            temp_res.append(temp)
            
    temp_res = torch.cat(temp_res,dim = 0)
    temp_res = temp_res.reshape(-1, H_wight_size)
    res.append(temp_res)

    
res = torch.cat(res,dim = 0)
res = res.reshape(out_channels, -1, H_wight_size)

print('res:', res.size())

out_H = []
H_weight = torch.rand([out_channels, H_wight_size]) #out_channels,1])

for i in range(out_channels):
    sub_weight = H_weight[i,:]
    sub_weight.unsqueeze_(1)
    sub_feature = res[i,:,:]
    temp = sub_feature.mm(sub_weight)
    out_H.append(temp)

out_H = torch.cat(out_H,dim=0)
# print('out_H:', out_H.size())
out_H = out_H.reshape(out_channels, batch_size, col,-1)
out_H =out_H.permute(1, 0, 3, 2)
print('out_H:', out_H.size())



In [None]:
#test mul_out_channels, col part
new_row = out_H.size(2)
in_V =out_H.permute(1, 0, 2, 3)
# in_V shape = out_channels, batch_size, row, col 

print('in_V:', out_H.size())
print('permute in_V:', in_V.size())

len_row = in_V.size(3)

V_wight_size = 3
half_V_ws = V_wight_size // 2
index_start = half_V_ws
index_end = len_row - half_V_ws


res = []
for per_channel in in_V:
    per_channel = per_channel.reshape(-1,len_row)
    # per_channel size = batch_size, row, col
    # per_channel size = batch_size * row, col
    temp_res = []
    
    for i in per_channel:
        for index in range(index_start, index_end):
            temp  = i[index-half_V_ws : index+half_V_ws+1]
            temp_res.append(temp)
            
    temp_res = torch.cat(temp_res,dim = 0)
    temp_res = temp_res.reshape(-1, V_wight_size)
    res.append(temp_res)

    
res = torch.cat(res,dim = 0)
res = res.reshape(out_channels, -1, V_wight_size)

print('res:', res.size())

out_V = []
V_weight = torch.rand([out_channels, V_wight_size]) #out_channels,1])

for i in range(out_channels):
    sub_weight = V_weight[i,:]
    sub_weight.unsqueeze_(1)
    sub_feature = res[i,:,:]
    temp = sub_feature.mm(sub_weight)
    out_V.append(temp)

out_V = torch.cat(out_V,dim=0)
# print('out_V:', out_V.size())
out_V = out_V.reshape(out_channels, batch_size, new_row,-1)
out_V =out_V.permute(1, 0, 2, 3)
print('out_V:', out_V.size())


In [16]:
#test mul_out_channels class

class test_mul_out_channels():
    def __init__(self, out_channels):
        self.row = 0
        self.col = 0
        self.out_channels = out_channels
#         self.in_channels = in_channels
        self.batch_size =0
        
    def full_conv(self,input):
        self.batch_size = input.size(0)
        self.in_channels = input.size(1)
        self.row = input.size(2)
        self.col = input.size(3)
        
        in_L =input.permute(0, 2, 3, 1)
        out_L = self.conv1d_L(in_L, wight_size = self.in_channels)
        out_L =out_L.permute(0, 3, 1, 2)
        print('out_L:', out_L.size())
        
        
        
        in_H =out_L.permute(1, 0, 3, 2)
        out_H = self.conv1d_VH(in_H, wight_size =3 )
        out_H =out_H.permute(1, 0, 3, 2)
        print('out_H:', out_H.size())

        in_V =out_H.permute(1, 0, 2, 3)
        out_V = self.conv1d_VH(in_V, wight_size =3 )
        out_V =out_V.permute(1, 0, 2, 3)
        print('out_V:', out_V.size())
        return out_V
        
        

    
    def conv1d_L(self, in_matrix, wight_size):
        

#         print('in_L:',my_input.size())

        len_row = in_matrix.size(3)
        b = in_matrix.reshape(-1,len_row)

        L_wight_size = wight_size
        half_L_ws = L_wight_size // 2
        index_start = half_L_ws
        index_end = len_row - half_L_ws

        res = []

        for i in b:
            for index in range(index_start, index_end):
                temp  = i[index-half_L_ws : index+half_L_ws+1]
                res.append(temp)

        res = torch.cat(res)  
        res = res.reshape(-1,L_wight_size)

        print(res.size())

        L_weight = torch.rand([L_wight_size, self.out_channels])
        out_L = res.mm(L_weight)
        out_L = out_L.reshape(self.batch_size , self.row , self.col, -1)

        return out_L
        
        
        
    def conv1d_VH(self, in_matrix, wight_size):

        out_shape = list(in_matrix.size())
        out_shape = out_shape[:-1] + [-1]
        in_V =in_matrix

        len_row = in_V.size(3)
        
        V_wight_size = wight_size
        half_V_ws = V_wight_size // 2
        index_start = half_V_ws
        index_end = len_row - half_V_ws


        res = []
        for per_channel in in_V:
            per_channel = per_channel.reshape(-1,len_row)
            # per_channel size = batch_size, row, col
            # per_channel size = batch_size * row, col
            temp_res = []

            for i in per_channel:
                for index in range(index_start, index_end):
                    temp  = i[index-half_V_ws : index+half_V_ws+1]
                    temp_res.append(temp)

            temp_res = torch.cat(temp_res,dim = 0)
            temp_res = temp_res.reshape(-1, V_wight_size)
            res.append(temp_res)


        res = torch.cat(res,dim = 0)
        res = res.reshape(self.out_channels, -1, V_wight_size)

#         print('res:', res.size())

        out_V = []
        V_weight = torch.rand([self.out_channels, V_wight_size]) #out_channels,1])

        for i in range(out_channels):
            sub_weight = V_weight[i,:]
            sub_weight.unsqueeze_(1)
            sub_feature = res[i,:,:]
            temp = sub_feature.mm(sub_weight)
            out_V.append(temp)

        out_V = torch.cat(out_V,dim=0)
        # print('out_V:', out_V.size())
        out_V = out_V.reshape(*out_shape)
        return out_V


In [153]:

#test mul_out_channels class

batch_size = 2
in_channels = 3
out_channels = 2
row = 4
col = 5
my_input = torch.rand([batch_size,in_channels,row,col])
print(my_input.size())


S = test_mul_out_channels(out_channels =2)
c = S.full_conv(my_input )
c.size()

torch.Size([2, 3, 4, 5])
torch.Size([40, 3])
out_L: torch.Size([2, 2, 4, 5])
out_H: torch.Size([2, 2, 2, 5])
out_V: torch.Size([2, 2, 2, 3])


torch.Size([2, 2, 2, 3])

In [154]:
# test Module class
class flatten_conv2d(nn.Module):
    def __init__(self, in_channels, out_channels, in_row, in_col,
                kernel_size, stride=1, padding=0, 
                dilation=1, groups=1, bias=True, padding_mode='zeros'):

        super(flatten_conv2d, self).__init__()
        
#         self.in_row = in_row
#         self.in_col = in_col
        self.in_channels = in_channels
        self.out_channels = out_channels
#         self.kernel_size = kernel_size
        wight_size = self.in_channels
        self.L1_weight = nn.Parameter(torch.Tensor(in_channels, out_channels))
        self.L1_bias = nn.Parameter(torch.Tensor(out_channels))
        
        self.V1_weight = nn.Parameter(torch.Tensor(kernel_size, out_channels))
        self.V1_bias = nn.Parameter(torch.Tensor(out_channels))
        
        self.H1_weight = nn.Parameter(torch.Tensor(kernel_size, out_channels))
        self.H1_bias = nn.Parameter(torch.Tensor(out_channels))

        
        self.L1_weight.data.uniform_(-0.1, 0.1)
        self.L1_bias.data.uniform_(-0.1, 0.1)
        
        self.V1_weight.data.uniform_(-0.1, 0.1)
        self.V1_bias.data.uniform_(-0.1, 0.1)
               
        self.H1_weight.data.uniform_(-0.1, 0.1)
        self.H1_bias.data.uniform_(-0.1, 0.1)
         
    def forward(self, input):
#         x = self.matrix3d_matrix1d(input, self.kernel_size )
        # See the autograd section for explanation of what happens here.
#         res = test_fun.apply(input, self.L1_weight, self.L1_bias)
        
        out_L = conv1d_L.apply(input, self.L1_weight, self.L1_bias)
        print('out_L:', out_L.size())
        
        out_V = conv1d_V.apply(out_L, self.V1_weight, self.V1_bias)
        print('out_V:', out_V.size())
        
        out_H = conv1d_H.apply(out_V, self.H1_weight, self.H1_bias)
        print('out_H:', out_H.size())
        
        return out_H

#     def extra_repr(self):
#         # (Optional)Set the extra information about this module. You can test
#         # it by printing an object of this class.
#         return 'in_features={}, out_features={}, bias={}'.format(
#             self.in_features, self.out_features, self.bias is not None)

In [167]:

#test mul_out_channels class

batch_size = 2
in_channels = 3
out_channels = 3
row = 4
col = 5
kernel_size = 3
my_input = torch.rand([batch_size,in_channels,row,col])

S = flatten_conv2d(in_channels= in_channels, 
                   out_channels= out_channels , 
                   in_row=row, in_col=col, kernel_size=kernel_size)

my_output = S(my_input )
my_output.size()




out_L: torch.Size([2, 3, 4, 5])
out_V: torch.Size([2, 3, 2, 5])
out_H: torch.Size([2, 3, 2, 3])


torch.Size([2, 3, 2, 3])

In [168]:
class conv1d_L(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, weight, bias):
        ctx.save_for_backward(input, weight, bias)
        
        in_L =input.permute(0, 2, 3, 1)
        
        out_shape = list(in_L.size())
        out_shape = out_shape[:-1] + [-1]
        
        
        len_row = in_L.size(3)
        L_wight_size = weight.size(0)
        half_L_ws = L_wight_size // 2
        index_start = half_L_ws
        index_end = len_row - half_L_ws
        
        b = in_L.reshape(-1,len_row)
        res = []

        for i in b:
            for index in range(index_start, index_end):
                temp  = i[index-half_L_ws : index+half_L_ws+1]
                res.append(temp)

        res = torch.cat(res)  
        res = res.reshape(-1,L_wight_size)

    
        out_L = res.mm(weight)
        out_L = out_L + bias
        out_L = out_L.reshape(*out_shape)
        out_L = out_L.permute(0, 3, 1, 2)

        return out_L



    @staticmethod
    def backward(ctx, grad_output):
        print('conv1d_L:',ctx.needs_input_grad)

        input, weight, bias = ctx.saved_tensors
        
        grad_input = grad_weight = grad_bias = None
        
        return grad_input, grad_weight, grad_bias





In [183]:
class conv1d_V(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, weight, bias):
        ctx.save_for_backward(input, weight, bias)
        in_V =input.permute(1, 0, 3, 2)
        
        
        out_channels = in_V.size(0)
        
        out_shape = list(in_V.size())
        out_shape = out_shape[:-1] + [-1]
        
        
        len_row = in_V.size(3)
        wight_size = weight.size(0)
        half_ws = wight_size // 2
        index_start = half_ws
        index_end = len_row - half_ws
        


        res = []
        for per_channel in in_V:
            per_channel = per_channel.reshape(-1,len_row)
            
            temp_res = []

            for i in per_channel:
                for index in range(index_start, index_end):
                    temp  = i[index-half_ws : index+half_ws+1]
                    temp_res.append(temp)

            temp_res = torch.cat(temp_res,dim = 0)
            temp_res = temp_res.reshape(-1, wight_size)
            res.append(temp_res)


        res = torch.cat(res,dim = 0)
        res = res.reshape(out_channels, -1, wight_size)

        out_V = []

        for i in range(out_channels):
            sub_weight = weight[:,i]
            sub_weight.unsqueeze_(1)
            sub_feature = res[i,:,:]
            sub_bias = bias[i]
            
            temp = sub_feature.mm(sub_weight)
            temp +=  sub_bias

            
            out_V.append(temp)

        out_V = torch.cat(out_V,dim=0)
        # print('out_V:', out_V.size())
        out_V = out_V.reshape(*out_shape)
        
        out_V =out_V.permute(1, 0, 3, 2)
        
        return out_V



    @staticmethod
    def backward(ctx, grad_output):
        print('conv1d_V:',ctx.needs_input_grad)
        
#         print(grad_output)
        input, weight, bias = ctx.saved_tensors
        
        grad_input = grad_weight = grad_bias = None
        
        return grad_input, grad_weight, grad_bias





In [211]:
class conv1d_H(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, weight, bias):
        ctx.save_for_backward(input, weight, bias)
        
        in_matrix =input.permute(1, 0, 2, 3)
        
        
        out_channels = in_matrix.size(0)
        
        out_shape = list(in_matrix.size())
        out_shape = out_shape[:-1] + [-1]
        
        
        len_row = in_matrix.size(3)
        wight_size = weight.size(0)
        half_ws = wight_size // 2
        index_start = half_ws
        index_end = len_row - half_ws
        


        res = []
        for per_channel in in_matrix:
            per_channel = per_channel.reshape(-1,len_row)
            
            temp_res = []

            for i in per_channel:
                for index in range(index_start, index_end):
                    temp  = i[index-half_ws : index+half_ws+1]
                    temp_res.append(temp)

            temp_res = torch.cat(temp_res,dim = 0)
            temp_res = temp_res.reshape(-1, wight_size)
            res.append(temp_res)


        res = torch.cat(res,dim = 0)
        res = res.reshape(out_channels, -1, wight_size)

        out_matrix = []
        

        for i in range(out_channels):
            sub_weight = weight[:,i]
            sub_weight.unsqueeze_(1)
            sub_feature = res[i,:,:]
            sub_bias = bias[i]
            print('sub_feature,',sub_feature.size())
            
            temp = sub_feature.mm(sub_weight)
            temp +=  sub_bias

            
            out_matrix.append(temp)

        out_matrix = torch.cat(out_matrix,dim=0)
        # print('out_V:', out_V.size())
        out_matrix = out_matrix.reshape(*out_shape)
        
        out_matrix =out_matrix.permute(1, 0, 2, 3)
        
        return out_matrix



    @staticmethod
    def backward(ctx, grad_output):
        print('conv1d_H:',ctx.needs_input_grad)
        print('H.grad_output:,',grad_output.size())
        
        
#         print(grad_output)
#         print(grad_output)
        input, weight, bias = ctx.saved_tensors
        
        grad_input = grad_weight = grad_bias = None
        
        
        
        ############
        if ctx.needs_input_grad[1]:
            
#             grad_weight = grad_output.t().mm(input)

#             print(grad_output)
            print(grad_output.size())
            print(weight.size())
#             grad_input = grad_output.mm(weight)


        ##########
        
        return grad_input, grad_weight, grad_bias

    def transform_matrix(input):
                
        in_matrix =input.permute(1, 0, 2, 3)    
        out_channels = in_matrix.size(0)
        
        out_shape = list(in_matrix.size())
        out_shape = out_shape[:-1] + [-1]
        
        
        len_row = in_matrix.size(3)
        wight_size = weight.size(0)
        half_ws = wight_size // 2
        index_start = half_ws
        index_end = len_row - half_ws
        


        res = []
        for per_channel in in_matrix:
            per_channel = per_channel.reshape(-1,len_row)
            
            temp_res = []

            for i in per_channel:
                for index in range(index_start, index_end):
                    temp  = i[index-half_ws : index+half_ws+1]
                    temp_res.append(temp)

            temp_res = torch.cat(temp_res,dim = 0)
            temp_res = temp_res.reshape(-1, wight_size)
            res.append(temp_res)


        res = torch.cat(res,dim = 0)
        res = res.reshape(out_channels, -1, wight_size)



In [213]:
#test mul_out_channels backward
# ctx.needs_input_grad

criterion = nn.MSELoss()
batch_size = 1
in_channels = 1
out_channels = 1
row = 4
col = 5
kernel_size = 3
my_input = torch.rand([batch_size,in_channels,row,col])

S = flatten_conv2d(in_channels= in_channels, 
                   out_channels= out_channels , 
                   in_row=row, in_col=col, kernel_size=kernel_size)

my_output = S(my_input )
label = torch.ones_like(my_output)


loss = criterion(my_output, label)

loss.backward()

out_L: torch.Size([1, 1, 4, 5])
out_V: torch.Size([1, 1, 2, 5])
sub_feature, torch.Size([6, 3])
out_H: torch.Size([1, 1, 2, 3])
conv1d_H: (True, True, True)
H.grad_output:, torch.Size([1, 1, 2, 3])
torch.Size([1, 1, 2, 3])
torch.Size([3, 1])
conv1d_V: (True, True, True)
conv1d_L: (False, True, True)


In [210]:
# test
conv1 = nn.Conv1d(2,3,kernel_size)
input_matrix = torch.rand(2, 2, 6)
conv1(input_matrix).size()

torch.Size([2, 3, 4])

In [None]:
        print(ctx.needs_input_grad)
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)

            
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)

                 
        if bias is not None and ctx.needs_input_grad[2]:
            
            grad_bias = grad_output.sum(0).squeeze(0)
            grad_bias = grad_bias.view_as(bias)


In [None]:
c = c.reshape(2,-1,3)
c

In [None]:
d = matrix3d_matrix1d(c,kernel_size)
myconv1d(d, kernel)

In [None]:
myconv1d(trans_m, kernel)