In [45]:
import torch
import torch.nn as nn
# Per-layer hyper-parameters of the DeepFold conv stack (consumed by
# get_convBlocks): layer i uses kernel size kernel_List[i] and produces
# channel_List[i] feature maps. With stride 2 and padding k/2 - 1, each layer
# halves an even spatial size, so the six layers take 256x256 down to 4x4.
kernel_List = [12] + [4] * 5  # one wide stem kernel, then five 4x4 kernels
channel_List = [
    128,  # layer 1
    256,  # layer 2
    512,  # layer 3
    512,  # layer 4
    512,  # layer 5
    400,  # layer 6: width of the vector DeepFold returns
]

class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU -> Dropout, the repeating unit of DeepFold.

    Args:
        in_channel: number of input feature maps.
        out_channel: number of output feature maps.
        kernel_sz: convolution kernel size.
        padding: zero padding added on every spatial side.
        stride: convolution stride (default 2).
        dropout_p: element-wise dropout probability. The default of 0.5 is
            the ``nn.Dropout`` default that this block used to hard-code.
    """

    def __init__(self, in_channel, out_channel, kernel_sz, padding, stride = 2, dropout_p = 0.5) -> None:
        super().__init__()
        # Keyword arguments make the (stride, padding) order explicit; Conv2d takes
        # stride before padding positionally, while this constructor takes padding first.
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_sz, stride=stride, padding=padding)
        self.bn = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout_p)

    def forward(self, x):
        """Apply conv, batch-norm, ReLU and dropout to a (N, C, H, W) batch."""
        return self.dropout(self.relu(self.bn(self.conv(x))))

def get_convBlocks(in_channel, kernels=None, channels=None):
    """Build the list of ConvBlocks that forms the DeepFold backbone.

    Layer i has kernel size ``kernels[i]``, padding ``int(k / 2 - 1)`` and the
    ConvBlock default stride of 2. It maps ``channels[i - 1]`` feature maps
    (``in_channel`` for the first layer) to ``channels[i]``.

    Args:
        in_channel: number of channels of the network input.
        kernels: per-layer kernel sizes; defaults to the module-level ``kernel_List``.
        channels: per-layer output channels; defaults to the module-level ``channel_List``.

    Returns:
        list of ConvBlock, one per layer, ready for ``nn.Sequential(*blocks)``.

    Raises:
        ValueError: if ``kernels`` and ``channels`` have different lengths.
    """
    # Resolve the defaults at call time so later edits to the module lists are honoured.
    if kernels is None:
        kernels = kernel_List
    if channels is None:
        channels = channel_List
    if len(kernels) != len(channels):
        raise ValueError(
            f"kernels and channels must have the same length, got {len(kernels)} and {len(channels)}"
        )

    # Each layer consumes the previous layer's output channels.
    in_channels = [in_channel, *channels[:-1]]
    # For an even kernel k and stride 2, padding k/2 - 1 halves an even spatial size.
    return [
        ConvBlock(c_in, c_out, k, int(k / 2 - 1))
        for c_in, c_out, k in zip(in_channels, channels, kernels)
    ]

class DeepFold(nn.Module):
    """Convolutional encoder that maps each input sample to one unit-length vector.

    The backbone is the ConvBlock stack built by ``get_convBlocks``. With the
    default layer lists, a ``[batch_size, 3, 256, 256]`` input yields
    ``[batch_size, 400, 4, 4]`` feature maps. ``forward`` reduces these to an
    L2-normalised ``[batch_size, 400]`` descriptor.
    """

    def __init__(self, in_channel) -> None:
        super().__init__()
        self.convLayer = nn.Sequential(*get_convBlocks(in_channel))

    def forward(self, x):
        """Encode ``x`` into one L2-normalised row per sample.

        Args:
            x: input batch, e.g. ``[batch_size, 3, 256, 256]``.

        Returns:
            Tensor of shape ``[batch_size, C]`` (``C = 400`` with the default
            layers). Each row has unit L2 norm; an all-zero row stays zero
            instead of becoming NaN.
        """
        # [batch_size, C, H, W], e.g. [batch_size, 400, 4, 4]
        x = self.convLayer(x)
        # main diagonal of every feature map: [batch_size, C, min(H, W)]
        x = torch.diagonal(x, dim1=2, dim2=3)
        # average along the diagonal: [batch_size, C]
        x = torch.mean(x, dim=2)
        # Normalise each sample along the feature axis: x / max(||x||_2, eps).
        # Reshaping [batch_size, C] to [C, batch_size] is NOT a transpose. It would
        # divide elements by another sample's norm and leave rows non-unit-length.
        return nn.functional.normalize(x, p=2.0, dim=1)

    # def hook(self, layer: nn.Module, input: torch.tensor, output)

# Smoke test: push a random batch through a freshly initialised DeepFold.
# Random input, weight init and dropout all draw from the global RNG, so seed it
# to make the run repeatable.
torch.manual_seed(0)

x = torch.rand(2, 3, 256, 256)  # [batch_size, in_channel, H, W]
model = DeepFold(3)
y = model(x)

# One descriptor per input sample; 400 is the last entry of channel_List.
assert y.shape == (2, 400), y.shape

torch.Size([400, 2])
torch.Size([2, 400])
norm:
tensor([1.0223, 0.9785], grad_fn=<CopyBackwards>)


In [28]:
# torch.diagonal on a 2-D tensor returns its main diagonal, i.e. mat[i, i].
# Use fresh names so this demo does not overwrite the model input `x`.
mat = torch.randn(4, 4)
print(mat)

diag = torch.diagonal(mat)
print(diag)

# Check the invariant instead of hard-coding one run's random numbers.
idx = torch.arange(mat.shape[0])
assert torch.equal(diag, mat[idx, idx])


tensor([[ 1.0252, -0.6609,  0.1759,  0.4572],
        [ 2.6635,  0.6376, -0.6880,  0.1564],
        [ 0.8290,  3.0119, -0.3659, -0.1646],
        [ 0.0706,  0.8429, -0.0488, -1.3583]])
tensor([ 1.0252,  0.6376, -0.3659, -1.3583])
