In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import models

In [4]:
class SegUNetConvBlock(nn.Module):

    def __init__(self, in_channels, out_channels, num_layers):
        super(SegUNetConvBlock, self).__init__()

        layers = [
            nn.Conv2d(in_channels, in_channels // 2, 3, padding=1),
            nn.BatchNorm2d(in_channels // 2,  momentum=0.1),
            nn.LeakyReLU(inplace=True),
        ]
        layers += [
            nn.Conv2d(in_channels // 2, in_channels // 2, 3, padding=1),
            nn.BatchNorm2d(in_channels // 2,  momentum=0.1),
            nn.LeakyReLU(inplace=True),
        ] * num_layers
        layers += [
            nn.Conv2d(in_channels // 2, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels,  momentum=0.1),
            nn.LeakyReLU(inplace=True),
        ]
        self.decode = nn.Sequential(*layers)

    def forward(self, x):
        return self.decode(x)

In [None]:
class SegUNet(nn.Module):
    """SegNet/U-Net hybrid with a VGG16-BN style encoder.

    Each encoder stage is followed by 2x2 max-pooling whose indices are kept. Each
    decoder stage max-unpools with the indices of the matching encoder pool,
    concatenates that encoder stage's feature map (skip connection), and refines the
    result with a SegUNetConvBlock. A final 3x3 conv produces ``label_nbr`` channels.

    Args:
        input_nbr: number of input image channels.
        label_nbr: number of output classes (channels of the returned logits).
        pretrained: if True, the encoder layers are taken from torchvision's
            ``vgg16_bn(pretrained=True)``. When ``input_nbr != 3`` the first conv and
            its BatchNorm are replaced by fresh layers, because the pretrained kernel
            expects 3 input channels.
    """

    def __init__(self, input_nbr=3, label_nbr=2, pretrained=True):
        super(SegUNet, self).__init__()
        self.input_nbr = input_nbr
        self.label_nbr = label_nbr
        batchNorm_momentum = 0.1

        if pretrained:
            vgg16 = list(models.vgg16_bn(pretrained=True).features.children())
            if input_nbr != 3:
                vgg16[0] = nn.Conv2d(input_nbr, 64, kernel_size=3, padding=1)
                vgg16[1] = nn.BatchNorm2d(64, momentum=batchNorm_momentum)
            # The slices skip the MaxPool2d entries (indices 6, 13, 23, 33, 43). Pooling
            # is done in forward() so the indices can be kept for unpooling.
            self.down1 = nn.Sequential(*vgg16[:6])     # H x W x 64 (before pooling)
            self.down2 = nn.Sequential(*vgg16[7:13])   # H/2 x W/2 x 128
            self.down3 = nn.Sequential(*vgg16[14:23])  # H/4 x W/4 x 256
            self.down4 = nn.Sequential(*vgg16[24:33])  # H/8 x W/8 x 512
            self.down5 = nn.Sequential(*vgg16[34:43])  # H/16 x W/16 x 512
        else:
            def vgg_stage(c_in, c_out, n_convs):
                """Return n_convs x (Conv3x3 -> BatchNorm -> ReLU); the first conv maps c_in -> c_out."""
                layers = []
                for i in range(n_convs):
                    layers += [nn.Conv2d(c_in if i == 0 else c_out, c_out, kernel_size=3, padding=1),
                               nn.BatchNorm2d(c_out, momentum=batchNorm_momentum),
                               nn.ReLU()]
                return nn.Sequential(*layers)

            self.down1 = vgg_stage(input_nbr, 64, 2)
            self.down2 = vgg_stage(64, 128, 2)
            self.down3 = vgg_stage(128, 256, 3)
            self.down4 = vgg_stage(256, 512, 3)
            self.down5 = vgg_stage(512, 512, 3)

        # Decoder: the input channel count of up4..up1 is (skip channels + unpooled
        # channels). Each block's output matches the channel count of the next pool indices.
        self.up5 = SegUNetConvBlock(512, 512, 1)   # 512 -> 512
        self.up4 = SegUNetConvBlock(1024, 256, 1)  # 512 + 512 -> 256
        self.up3 = SegUNetConvBlock(512, 128, 1)   # 256 + 256 -> 128
        self.up2 = SegUNetConvBlock(256, 64, 1)    # 128 + 128 -> 64
        self.up1 = SegUNetConvBlock(128, 64, 1)    # 64 + 64 -> 64

        self.last = nn.Conv2d(64, label_nbr, 3, padding=1)

    def forward(self, x):
        """Return per-pixel class logits.

        Args:
            x: input batch, assumed NCHW with ``input_nbr`` channels.

        Returns:
            Tensor with ``label_nbr`` channels and the same spatial size as ``x``.
        """
        def unpool(t, indices, skip):
            # output_size restores the exact pre-pool size, so inputs whose H/W are not
            # divisible by 32 still line up with the skip tensors in torch.cat.
            return F.max_unpool2d(t, indices, kernel_size=2, stride=2, output_size=skip.size()[2:])

        # -------------------------------- Encoder --------------------------------
        # Keep each stage's pre-pool output (used as skip connection) and pool indices.
        down1 = self.down1(x)
        down1_pool, idx1 = F.max_pool2d(down1, kernel_size=2, stride=2, return_indices=True)

        down2 = self.down2(down1_pool)
        down2_pool, idx2 = F.max_pool2d(down2, kernel_size=2, stride=2, return_indices=True)

        down3 = self.down3(down2_pool)
        down3_pool, idx3 = F.max_pool2d(down3, kernel_size=2, stride=2, return_indices=True)

        down4 = self.down4(down3_pool)
        down4_pool, idx4 = F.max_pool2d(down4, kernel_size=2, stride=2, return_indices=True)

        down5 = self.down5(down4_pool)
        down5_pool, idx5 = F.max_pool2d(down5, kernel_size=2, stride=2, return_indices=True)

        # -------------------------------- Decoder --------------------------------
        up5 = self.up5(unpool(down5_pool, idx5, down5))                  # 512
        up4 = self.up4(torch.cat([down4, unpool(up5, idx4, down4)], 1))  # 512 + 512 -> 256
        up3 = self.up3(torch.cat([down3, unpool(up4, idx3, down3)], 1))  # 256 + 256 -> 128
        up2 = self.up2(torch.cat([down2, unpool(up3, idx2, down2)], 1))  # 128 + 128 -> 64
        up1 = self.up1(torch.cat([down1, unpool(up2, idx1, down1)], 1))  # 64 + 64 -> 64

        # up1 already has x's spatial size (unpooling uses output_size, convs use padding=1).
        return self.last(up1)

            
        
        

In [None]:
class SegUNet(nn.Module):
    """SegNet-style encoder/decoder with U-Net skip concatenation (per-layer attributes).

    Encoder: 13 Conv3x3 -> BatchNorm -> ReLU layers in 5 stages (VGG16 layout,
    attributes ``conv11``..``conv53`` / ``bn11``..``bn53``). Each stage is followed by
    2x2 max-pooling whose indices are kept. Decoder (``conv53d``..``conv12d`` /
    ``bn53d``..``bn12d``, then ``conv11d``): each stage max-unpools with the matching
    indices. Stages 4d..1d concatenate the matching encoder output before their first
    conv. ``self.last`` (1x1 conv) produces ``label_nbr`` channels.

    Args:
        input_nbr: number of input image channels.
        label_nbr: number of output classes (channels of the returned logits).
        pretrained: if True, encoder convs and their BatchNorm layers are taken from
            torchvision's ``vgg16_bn(pretrained=True)``. conv11/bn11 are freshly
            initialised when ``input_nbr != 3``.
    """

    def __init__(self, input_nbr=3, label_nbr=2, pretrained=True):
        super(SegUNet, self).__init__()

        self.input_nbr = input_nbr
        self.label_nbr = label_nbr
        batchNorm_momentum = 0.1

        vgg16 = list(models.vgg16_bn(pretrained=True).features.children()) if pretrained else None

        # Encoder layers: (attribute suffix, in_channels, out_channels, index of the
        # matching Conv2d in vgg16_bn().features). Its BatchNorm2d sits at index + 1.
        encoder_spec = [
            ('11', input_nbr, 64, 0), ('12', 64, 64, 3),
            ('21', 64, 128, 7), ('22', 128, 128, 10),
            ('31', 128, 256, 14), ('32', 256, 256, 17), ('33', 256, 256, 20),
            ('41', 256, 512, 24), ('42', 512, 512, 27), ('43', 512, 512, 30),
            ('51', 512, 512, 34), ('52', 512, 512, 37), ('53', 512, 512, 40),
        ]
        for suffix, c_in, c_out, vgg_idx in encoder_spec:
            if vgg16 is not None and vgg16[vgg_idx].in_channels == c_in:
                # Reuse the pretrained conv together with its pretrained BatchNorm: the
                # conv weights were trained against those normalisation statistics.
                conv, bn = vgg16[vgg_idx], vgg16[vgg_idx + 1]
            else:
                conv = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)
                bn = nn.BatchNorm2d(c_out, momentum=batchNorm_momentum)
            setattr(self, 'conv' + suffix, conv)
            setattr(self, 'bn' + suffix, bn)

        # Decoder layers: (attribute suffix, in_channels, out_channels).
        # - The first conv of stages 4d, 3d, 2d and 1d receives [encoder skip, unpooled]
        #   concatenated, hence the doubled in_channels.
        # - The last conv of each stage shrinks to the channel count of the next
        #   pooling indices (512 -> 256 -> 128 -> 64), as max_unpool2d requires.
        decoder_spec = [
            ('53d', 512, 512), ('52d', 512, 512), ('51d', 512, 512),
            ('43d', 1024, 512), ('42d', 512, 512), ('41d', 512, 256),
            ('33d', 512, 256), ('32d', 256, 256), ('31d', 256, 128),
            ('22d', 256, 128), ('21d', 128, 64),
            ('12d', 128, 64),
        ]
        for suffix, c_in, c_out in decoder_spec:
            setattr(self, 'conv' + suffix, nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            setattr(self, 'bn' + suffix, nn.BatchNorm2d(c_out, momentum=batchNorm_momentum))
        self.conv11d = nn.Conv2d(64, 64, kernel_size=3, padding=1)

        self.last = nn.Conv2d(64, label_nbr, kernel_size=1)

    def forward(self, x):
        """Return per-pixel class logits.

        Args:
            x: input batch, assumed NCHW with ``input_nbr`` channels.

        Returns:
            Tensor with ``label_nbr`` channels and the same spatial size as ``x``.
        """
        def unpool(t, indices, skip):
            # output_size restores the exact pre-pool size, so odd intermediate sizes
            # (input H/W not divisible by 32) still match the skip tensors.
            return F.max_unpool2d(t, indices, kernel_size=2, stride=2, output_size=skip.size()[2:])

        # Stage 1
        x11 = F.relu(self.bn11(self.conv11(x)))
        x12 = F.relu(self.bn12(self.conv12(x11)))
        x1p, id1 = F.max_pool2d(x12, kernel_size=2, stride=2, return_indices=True)

        # Stage 2
        x21 = F.relu(self.bn21(self.conv21(x1p)))
        x22 = F.relu(self.bn22(self.conv22(x21)))
        x2p, id2 = F.max_pool2d(x22, kernel_size=2, stride=2, return_indices=True)

        # Stage 3
        x31 = F.relu(self.bn31(self.conv31(x2p)))
        x32 = F.relu(self.bn32(self.conv32(x31)))
        x33 = F.relu(self.bn33(self.conv33(x32)))
        x3p, id3 = F.max_pool2d(x33, kernel_size=2, stride=2, return_indices=True)

        # Stage 4
        x41 = F.relu(self.bn41(self.conv41(x3p)))
        x42 = F.relu(self.bn42(self.conv42(x41)))
        x43 = F.relu(self.bn43(self.conv43(x42)))
        x4p, id4 = F.max_pool2d(x43, kernel_size=2, stride=2, return_indices=True)

        # Stage 5
        x51 = F.relu(self.bn51(self.conv51(x4p)))
        x52 = F.relu(self.bn52(self.conv52(x51)))
        x53 = F.relu(self.bn53(self.conv53(x52)))
        x5p, id5 = F.max_pool2d(x53, kernel_size=2, stride=2, return_indices=True)

        # Stage 5d: unpool only (no skip), 512 channels.
        x5d = unpool(x5p, id5, x53)
        x53d = F.relu(self.bn53d(self.conv53d(x5d)))
        x52d = F.relu(self.bn52d(self.conv52d(x53d)))
        x51d = F.relu(self.bn51d(self.conv51d(x52d)))

        # Stage 4d: [x43 (512), unpooled (512)] -> 1024 -> 512 -> 512 -> 256
        x4d = torch.cat([x43, unpool(x51d, id4, x43)], 1)
        x43d = F.relu(self.bn43d(self.conv43d(x4d)))
        x42d = F.relu(self.bn42d(self.conv42d(x43d)))
        x41d = F.relu(self.bn41d(self.conv41d(x42d)))

        # Stage 3d: [x33 (256), unpooled (256)] -> 512 -> 256 -> 256 -> 128
        x3d = torch.cat([x33, unpool(x41d, id3, x33)], 1)
        x33d = F.relu(self.bn33d(self.conv33d(x3d)))
        x32d = F.relu(self.bn32d(self.conv32d(x33d)))
        x31d = F.relu(self.bn31d(self.conv31d(x32d)))

        # Stage 2d: [x22 (128), unpooled (128)] -> 256 -> 128 -> 64
        x2d = torch.cat([x22, unpool(x31d, id2, x22)], 1)
        x22d = F.relu(self.bn22d(self.conv22d(x2d)))
        x21d = F.relu(self.bn21d(self.conv21d(x22d)))

        # Stage 1d: [x12 (64), unpooled (64)] -> 128 -> 64 -> 64
        x1d = torch.cat([x12, unpool(x21d, id1, x12)], 1)
        x12d = F.relu(self.bn12d(self.conv12d(x1d)))
        x11d = self.conv11d(x12d)

        # Project the 64 decoder channels to label_nbr class logits.
        return self.last(x11d)



In [None]:

from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

from torch.utils import model_zoo
from torchvision import models

# In[7]:


class SegNetDec(nn.Module):
    """Decoder stage: stack of 3x3 Conv2d -> BatchNorm2d -> LeakyReLU layers.

    The first layer maps ``in_channels -> in_channels // 2``, ``num_layers``
    intermediate layers keep ``in_channels // 2`` channels, and the last layer maps
    to ``out_channels``. Every conv uses padding=1, so spatial size is preserved
    (upsampling is done by the caller via max-unpooling).

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output feature channels.
        num_layers: number of intermediate ``in_channels // 2 -> in_channels // 2`` layers.
    """

    def __init__(self, in_channels, out_channels, num_layers):
        super(SegNetDec, self).__init__()
        mid_channels = in_channels // 2

        layers = self._conv_bn_act(in_channels, mid_channels)
        # Build fresh modules for every intermediate layer. ``[...] * num_layers``
        # would reuse the same module objects and tie their weights/BN statistics.
        for _ in range(num_layers):
            layers += self._conv_bn_act(mid_channels, mid_channels)
        layers += self._conv_bn_act(mid_channels, out_channels)
        self.decode = nn.Sequential(*layers)

    @staticmethod
    def _conv_bn_act(c_in, c_out):
        """Return a new [Conv2d(3x3, padding=1), BatchNorm2d, LeakyReLU] triple as a list."""
        return [
            nn.Conv2d(c_in, c_out, 3, padding=1),
            nn.BatchNorm2d(c_out, momentum=0.1),
            nn.LeakyReLU(inplace=True),
        ]

    def forward(self, x):
        """Apply the conv stack; the result has ``out_channels`` channels and x's spatial size."""
        return self.decode(x)


# In[9]:


class SegNet_U(nn.Module):
    """SegNet/U-Net hybrid with a torchvision VGG16-BN encoder (3-channel input).

    Each encoder stage is followed by 2x2 max-pooling whose indices are kept. Each
    decoder stage max-unpools with the matching indices, concatenates the encoder
    stage output (skip connection), and applies a SegNetDec block. A final 3x3 conv
    produces ``num_classes`` channels.

    Args:
        num_classes: number of output classes (channels of the returned logits).
        pretrained: forwarded to ``torchvision.models.vgg16_bn``. True (default)
            loads ImageNet weights; False builds a randomly initialised encoder.
    """

    def __init__(self, num_classes, pretrained=True):
        super(SegNet_U, self).__init__()

        encoders = list(models.vgg16_bn(pretrained=pretrained).features.children())

        # The slices skip the MaxPool2d entries (indices 6, 13, 23, 33, 43). Pooling is
        # done in forward() so the indices can be kept for unpooling.
        self.enc1 = nn.Sequential(*encoders[:6])     # H x W x 64 (before pooling)
        self.enc2 = nn.Sequential(*encoders[7:13])   # H/2 x W/2 x 128
        self.enc3 = nn.Sequential(*encoders[14:23])  # H/4 x W/4 x 256
        self.enc4 = nn.Sequential(*encoders[24:33])  # H/8 x W/8 x 512
        self.enc5 = nn.Sequential(*encoders[34:43])  # H/16 x W/16 x 512

        # Input channels of dec4..dec1 = skip channels + unpooled channels.
        self.dec5 = SegNetDec(512, 512, 1)   # 512 -> 512
        self.dec4 = SegNetDec(1024, 256, 1)  # 512 + 512 -> 256
        self.dec3 = SegNetDec(512, 128, 1)   # 256 + 256 -> 128
        self.dec2 = SegNetDec(256, 64, 1)    # 128 + 128 -> 64
        self.dec1 = SegNetDec(128, 64, 1)    # 64 + 64 -> 64
        self.final = nn.Conv2d(64, num_classes, 3, padding=1)

    def forward(self, x):
        """Return per-pixel class logits.

        Args:
            x: input batch, assumed NCHW RGB (3 channels, as required by vgg16_bn).

        Returns:
            Tensor with ``num_classes`` channels and the same spatial size as ``x``.
        """
        def unpool(t, indices, skip):
            # output_size restores the exact pre-pool size, so inputs whose H/W are not
            # divisible by 32 still line up with the skip tensors in torch.cat.
            return F.max_unpool2d(t, indices, kernel_size=2, stride=2, output_size=skip.size()[2:])

        # -------------------------------- Encoder --------------------------------
        enc1 = self.enc1(x)
        enc1_pool, id1 = F.max_pool2d(enc1, kernel_size=2, stride=2, return_indices=True)

        enc2 = self.enc2(enc1_pool)
        enc2_pool, id2 = F.max_pool2d(enc2, kernel_size=2, stride=2, return_indices=True)

        enc3 = self.enc3(enc2_pool)
        enc3_pool, id3 = F.max_pool2d(enc3, kernel_size=2, stride=2, return_indices=True)

        enc4 = self.enc4(enc3_pool)
        enc4_pool, id4 = F.max_pool2d(enc4, kernel_size=2, stride=2, return_indices=True)

        enc5 = self.enc5(enc4_pool)
        enc5_pool, id5 = F.max_pool2d(enc5, kernel_size=2, stride=2, return_indices=True)

        # -------------------------------- Decoder --------------------------------
        dec5 = self.dec5(unpool(enc5_pool, id5, enc5))                   # 512
        dec4 = self.dec4(torch.cat([enc4, unpool(dec5, id4, enc4)], 1))  # 512 + 512 -> 256
        dec3 = self.dec3(torch.cat([enc3, unpool(dec4, id3, enc3)], 1))  # 256 + 256 -> 128
        dec2 = self.dec2(torch.cat([enc2, unpool(dec3, id2, enc2)], 1))  # 128 + 128 -> 64
        dec1 = self.dec1(torch.cat([enc1, unpool(dec2, id1, enc1)], 1))  # 64 + 64 -> 64

        # dec1 already has x's spatial size, so no final resampling is needed.
        return self.final(dec1)

In [2]:
# Peek at torchvision's vgg16_bn feature extractor to look up the layer indices used
# when slicing the encoder stages above (pretrained=True loads the ImageNet weights).
vgg16 = [layer for layer in models.vgg16_bn(pretrained=True).features.children()]

In [3]:
# Display the layer list: each Conv2d is followed by its BatchNorm2d and ReLU, with MaxPool2d at indices 6, 13, ...
vgg16

[Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 ReLU(inplace),
 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 ReLU(inplace),
 MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
 Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 ReLU(inplace),
 Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 ReLU(inplace),
 MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
 Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 ReLU(inplace),
 Conv2d(2