In [10]:
from model import trainfunc
import torch
import torch.utils

In [11]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [12]:
# Synthetic stand-in data drawn from a standard normal distribution.
# Both tensors hold 320 samples of 256x256 slices with a single channel;
# the inputs carry 3 slices along the depth axis, the targets carry 1.
x = torch.randn(320, 1, 3, 256, 256)
y = torch.randn(320, 1, 1, 256, 256)

In [13]:
# Pair every input volume with its target, then serve them in order
# (no shuffling) in mini-batches of 32.
dataset = torch.utils.data.TensorDataset(x, y)
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=32,
)

Reference — signature of `torch.nn.Conv3d`: `torch.nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')`

In [14]:
class toynet(torch.nn.Module):
    """Small U-Net-style 3-D encoder/decoder with skip connections.

    Six strided ``Conv3d`` stages halve the height and width each time while
    keeping the depth axis unchanged; six ``ConvTranspose3d`` stages double
    them back.  After every up-sampling stage the result is concatenated on
    the channel axis with the encoder feature map of the same resolution
    (and, at full resolution, with the network input itself).  A last
    ``Conv3d`` with no padding along depth shrinks the depth axis by 2.

    Input:  ``(N, 1, D, H, W)`` with ``D >= 3`` and ``H``, ``W`` multiples of
            64 (six exact halvings are needed for the skip connections to
            line up; other sizes fail inside ``torch.cat``).
    Output: ``(N, 1, D - 2, H, W)``.
    """

    def __init__(self):
        super().__init__()
        # Shared geometry of every down/up stage: a 3x2x2 kernel that leaves
        # depth untouched (padding 1, stride 1) and resamples H/W by 2.
        resample = dict(kernel_size=(3, 2, 2), padding=(1, 0, 0), stride=(1, 2, 2))

        self.act = torch.nn.LeakyReLU(0.1, inplace=True)

        # Encoder: (channels, D, H, W) after each stage.
        self.down1 = torch.nn.Conv3d(1, 32, **resample)     # 32,  D, H/2,  W/2
        self.down2 = torch.nn.Conv3d(32, 64, **resample)    # 64,  D, H/4,  W/4
        self.down3 = torch.nn.Conv3d(64, 128, **resample)   # 128, D, H/8,  W/8
        self.down4 = torch.nn.Conv3d(128, 256, **resample)  # 256, D, H/16, W/16
        self.down5 = torch.nn.Conv3d(256, 256, **resample)  # 256, D, H/32, W/32

        self.down6 = torch.nn.Conv3d(256, 512, **resample)  # 512, D, H/64, W/64

        # Decoder: input channels are "previous up output + matching skip".
        self.up5 = torch.nn.ConvTranspose3d(512, 256, **resample)        # 256, D, H/32, W/32
        self.up4 = torch.nn.ConvTranspose3d(256 + 256, 128, **resample)  # 128, D, H/16, W/16
        self.up3 = torch.nn.ConvTranspose3d(128 + 256, 64, **resample)   # 64,  D, H/8,  W/8
        self.up2 = torch.nn.ConvTranspose3d(64 + 128, 32, **resample)    # 32,  D, H/4,  W/4
        self.up1 = torch.nn.ConvTranspose3d(32 + 64, 16, **resample)     # 16,  D, H/2,  W/2
        self.up0 = torch.nn.ConvTranspose3d(16 + 32, 8, **resample)      # 8,   D, H,    W

        # 8 decoder channels + the 1-channel input; no depth padding, so D -> D - 2.
        self.final = torch.nn.Conv3d(8 + 1, 1, kernel_size=3, padding=(0, 1, 1))

    def forward(self, x):
        """Run the encoder/decoder and return the ``(N, 1, D - 2, H, W)`` result.

        The input is used directly (no detach, no copy): nothing below
        modifies it in place, and keeping it in the autograd graph lets
        gradients reach the caller's tensor through both the first
        convolution and the full-resolution skip connection.
        """
        x0 = x

        # Encoder path; each feature map is kept for its skip connection.
        # The in-place activation only overwrites the fresh conv output.
        xdown1 = self.act(self.down1(x0))
        xdown2 = self.act(self.down2(xdown1))
        xdown3 = self.act(self.down3(xdown2))
        xdown4 = self.act(self.down4(xdown3))
        xdown5 = self.act(self.down5(xdown4))
        xdown6 = self.act(self.down6(xdown5))

        # Decoder path; dim 1 is the channel axis.
        xup5 = torch.cat((self.act(self.up5(xdown6)), xdown5), 1)
        xup4 = torch.cat((self.act(self.up4(xup5)), xdown4), 1)
        xup3 = torch.cat((self.act(self.up3(xup4)), xdown3), 1)
        xup2 = torch.cat((self.act(self.up2(xup3)), xdown2), 1)
        xup1 = torch.cat((self.act(self.up1(xup2)), xdown1), 1)
        xup0 = torch.cat((self.act(self.up0(xup1)), x0), 1)

        # Locals are released when the function returns, so no explicit
        # `del` bookkeeping is needed here.
        return self.final(xup0)

In [15]:
# Build the network, then move its parameters onto the selected device.
model = toynet()
model = model.to(device)

In [16]:
# Print a per-layer table of output shapes and parameter counts for one
# sample of shape (channels=1, depth=3, height=256, width=256); the batch
# dimension is supplied by torchsummary itself.
import torchsummary
torchsummary.summary(model, input_size=(1, 3, 256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1      [-1, 32, 3, 128, 128]             416
         LeakyReLU-2      [-1, 32, 3, 128, 128]               0
            Conv3d-3        [-1, 64, 3, 64, 64]          24,640
         LeakyReLU-4        [-1, 64, 3, 64, 64]               0
            Conv3d-5       [-1, 128, 3, 32, 32]          98,432
         LeakyReLU-6       [-1, 128, 3, 32, 32]               0
            Conv3d-7       [-1, 256, 3, 16, 16]         393,472
         LeakyReLU-8       [-1, 256, 3, 16, 16]               0
            Conv3d-9         [-1, 256, 3, 8, 8]         786,688
        LeakyReLU-10         [-1, 256, 3, 8, 8]               0
           Conv3d-11         [-1, 512, 3, 4, 4]       1,573,376
        LeakyReLU-12         [-1, 512, 3, 4, 4]               0
  ConvTranspose3d-13         [-1, 256, 3, 8, 8]       1,573,120
        LeakyReLU-14         [-1, 256, 

In [17]:
# Smoke test: push every mini-batch through the model and print the input
# and output shapes.
# - The loop variable is named `batch` so it does not overwrite the
#   notebook-level dataset tensor `x`, which would make earlier cells
#   non-re-runnable.
# - Only shapes are inspected, so autograd is disabled to avoid building
#   (and holding memory for) a graph on every forward pass.
with torch.no_grad():
    for batch, _ in dataloader:
        print(batch.shape)
        batch = batch.to(device)
        print(model(batch).shape)

torch.Size([32, 1, 3, 256, 256])
torch.Size([32, 1, 1, 256, 256])
torch.Size([32, 1, 3, 256, 256])
torch.Size([32, 1, 1, 256, 256])
torch.Size([32, 1, 3, 256, 256])
torch.Size([32, 1, 1, 256, 256])
torch.Size([32, 1, 3, 256, 256])
torch.Size([32, 1, 1, 256, 256])
torch.Size([32, 1, 3, 256, 256])
torch.Size([32, 1, 1, 256, 256])
torch.Size([32, 1, 3, 256, 256])
torch.Size([32, 1, 1, 256, 256])
torch.Size([32, 1, 3, 256, 256])
torch.Size([32, 1, 1, 256, 256])
torch.Size([32, 1, 3, 256, 256])
torch.Size([32, 1, 1, 256, 256])
torch.Size([32, 1, 3, 256, 256])
torch.Size([32, 1, 1, 256, 256])
torch.Size([32, 1, 3, 256, 256])
torch.Size([32, 1, 1, 256, 256])
