In [2]:
import torch
import torch.nn as nn
from torchvision.models import vgg11



In [3]:
class VggConvLayer(nn.Module):
    """A stack of ``layer_num`` (ReflectionPad2d(1) -> 3x3 Conv2d -> ReLU) units.

    Each unit pads by one pixel on every side before the unpadded 3x3
    convolution, so the spatial size of the input is preserved.

    Args:
        in_channel: channel count of the tensor fed to the first convolution.
        out_channel: channel count produced by every convolution in the stack.
        layer_num: number of pad/conv/ReLU units to chain; must be >= 1.

    Raises:
        AssertionError: if ``layer_num`` is smaller than 1.
    """
    def __init__(self,in_channel,out_channel,layer_num=1) -> None:
        super().__init__()
        assert layer_num >= 1
        layers = []
        # Only the first conv sees `in_channel`; every later conv consumes the
        # `out_channel` feature maps produced by the unit before it.
        channels = in_channel
        for _ in range(layer_num):
            layers.append(nn.ReflectionPad2d(1))
            layers.append(nn.Conv2d(channels,out_channel,3))
            layers.append(nn.ReLU())
            channels = out_channel
        # Do NOT call this attribute `modules`: nn.Module already defines a
        # `modules()` method, and attribute lookup would resolve to that method
        # instead of the registered Sequential, breaking forward().
        self.layers = nn.Sequential(*layers)

    def forward(self,x):
        """Run ``x`` through the pad/conv/ReLU stack and return the result."""
        return self.layers(x)
        

class VggEncoder(nn.Module):
    """Convolutional encoder: a 1x1 conv followed by VggConvLayer stages.

    The network takes a 3-channel input and produces 512 channels. It contains
    three ``MaxPool2d(2, 2, ceil_mode=True)`` layers, each halving the spatial
    size (rounding up). The deeper 512-channel stages are left commented out
    below, so the encoder stops after the first 512-channel convolution.

    NOTE(review): the layout resembles a VGG-19 truncated after relu4_1;
    confirm before loading pretrained weights.
    """
    def __init__(self) -> None:
        super().__init__()
        # Stored as `layers`, not `modules`: nn.Module already defines a
        # `modules()` method, and attribute lookup would resolve to that method
        # instead of the registered Sequential, breaking forward().
        self.layers = nn.Sequential(
            nn.Conv2d(3,3,1),
            VggConvLayer(3,64),
            VggConvLayer(64,64),
            nn.MaxPool2d(2,2,ceil_mode=True),
            VggConvLayer(64,128),
            VggConvLayer(128,128),
            nn.MaxPool2d(2,2,ceil_mode=True),
            VggConvLayer(128,256),
            VggConvLayer(256,256,3),
            nn.MaxPool2d(2,2,ceil_mode=True),
            VggConvLayer(256,512),
            # Deeper stages intentionally disabled:
            # VggConvLayer(512,512,3),
            # nn.MaxPool2d(2,2,ceil_mode=True),
            # VggConvLayer(512,512,4)
        )

    def forward(self,x):
        """Encode ``x`` (3 input channels) into a 512-channel feature map."""
        return self.layers(x)

# Instantiate the encoder; the bare `net` as the cell's last expression makes
# the notebook display the module's repr.
net = VggEncoder()
net

VggEncoder(
  (modules): Sequential(
    (0): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1))
    (1): VggConvLayer(
      (modules): Sequential(
        (0): ReflectionPad2d((1, 1, 1, 1))
        (1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
        (2): ReLU()
      )
    )
    (2): VggConvLayer(
      (modules): Sequential(
        (0): ReflectionPad2d((1, 1, 1, 1))
        (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
        (2): ReLU()
      )
    )
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
    (4): VggConvLayer(
      (modules): Sequential(
        (0): ReflectionPad2d((1, 1, 1, 1))
        (1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
        (2): ReLU()
      )
    )
    (5): VggConvLayer(
      (modules): Sequential(
        (0): ReflectionPad2d((1, 1, 1, 1))
        (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (2): ReLU()
      )
    )
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0

In [13]:
# Scratch check: per-channel variance and broadcasting on a tiny random tensor.
a = torch.rand(1, 3, 2, 2)
# Merge the two trailing dims so each of the 3 channels becomes a row of 4 values.
b = a.view(1, 3, -1)
print(f'b {b}')
print(f'b.var {torch.var(b, dim=2)}')

# A (1, 3, 1) tensor of ones broadcasts along the last dim of `b`.
c = torch.ones((1, 3, 1))
d = (b + c) * (c + 2)

b tensor([[[0.6610, 0.4581, 0.0725, 0.1556],
         [0.7572, 0.5708, 0.1523, 0.2944],
         [0.3483, 0.2334, 0.4204, 0.5740]]])
b.var tensor([[0.0742, 0.0739, 0.0203]])
