In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


In [4]:
# Pick the compute device once for the whole notebook: use the GPU when a
# CUDA runtime is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [6]:
class Vgg11Fcn(nn.Module):
    """FCN-style segmentation head built on a convolutional feature extractor.

    The backbone is run layer by layer and the activations after the child
    modules with indices 2, 5, 10, 15 and 20 are captured as ``pool1`` ..
    ``pool4`` and ``output`` (these indices are the max-pool layers of
    ``torchvision.models.vgg11(...).features``).  The deepest feature map is
    scored, upsampled and fused with the scored ``pool3`` / ``pool4`` maps,
    then upsampled back to the input resolution and squashed with a sigmoid.

    Args:
        pre_trained_model: module whose ``children()`` are applied in order to
            the input (e.g. the ``features`` part of VGG11).  Its final output
            must have 512 channels because ``conv1`` expects 512 inputs.
        num_class: number of output channels (one score map per class).
    """

    # Input-channel counts seen by the 1x1 scoring convolutions in ``forward``:
    # 4096 (after conv2), 256 (pool3) and 512 (pool4).
    _SCORE_IN_CHANNELS = (4096, 256, 512)

    def __init__(self, pre_trained_model, num_class):
        super(Vgg11Fcn, self).__init__()
        self.pre_trained_model = pre_trained_model  # original note: output size = 512, 8, 8
        self.num_class = num_class

        # 1x1 "fully connected" convolutions applied to the backbone output.
        # NOTE(review): older comments in this file mention groups=2, but the
        # code uses groups=4; the code is kept as-is.
        self.conv1 = nn.Sequential(nn.Conv2d(512, 4096, 1, groups=4),
                                   nn.ReLU(),
                                   nn.Dropout2d(p=0.2))
        self.conv2 = nn.Sequential(nn.Conv2d(4096, 4096, 1, groups=4),
                                   nn.ReLU(),
                                   nn.Dropout2d(p=0.2))

        # Learnable x2 / x8 upsampling (kernel == stride, so no overlap and no
        # padding is required).
        # NOTE(review): upconvX2 is shared between the score branch (applied
        # twice) and the pool4 branch, so those paths share weights - confirm
        # this is intended.
        self.upconvX2 = nn.Sequential(
            nn.ConvTranspose2d(self.num_class, self.num_class, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.BatchNorm2d(self.num_class))
        self.upconvX8 = nn.Sequential(
            nn.ConvTranspose2d(self.num_class, self.num_class, kernel_size=8, stride=8),
            nn.ReLU(),
            nn.BatchNorm2d(self.num_class))

        # 1x1 convolutions mapping a feature map to ``num_class`` channels,
        # keyed by the number of input channels.  They are registered here so
        # that their weights are trained, moved by ``.to()`` and saved in the
        # state dict (previously a fresh random conv was built on every call).
        self.score_convs = nn.ModuleDict({
            str(channels): nn.Conv2d(channels, self.num_class, 1)
            for channels in self._SCORE_IN_CHANNELS
        })

        # child index in the backbone -> name under which its output is stored
        self.feature = {'2':'pool1', 
                '5':'pool2',
                '10':'pool3',
                '15':'pool4',
                '20':'output'}

    def get_features(self, x, model, layers):
        """Run ``x`` through ``model.children()`` and collect selected outputs.

        Args:
            x: input tensor fed to the first child of ``model``.
            model: module whose children are applied sequentially.
            layers: dict mapping a child index (as a string) to the key under
                which that child's output is stored.

        Returns:
            dict mapping the names from ``layers`` to the captured tensors.
        """
        features = {}
        for index, layer in enumerate(model.children()):
            x = layer(x)
            key = str(index)
            if key in layers:
                features[layers[key]] = x
        return features

    def conv_num_class(self, x):
        """Project ``x`` (B, C, H, W) to ``num_class`` channels with a 1x1 conv.

        The convolution is selected by the channel count ``C``.  The counts
        used by ``forward`` are created in ``__init__``; any other count gets a
        convolution created and registered on first use (such a late layer is
        not known to an optimizer that was constructed earlier).
        """
        in_channels = x.shape[1]
        key = str(in_channels)
        if key not in self.score_convs:
            self.score_convs[key] = nn.Conv2d(in_channels, self.num_class, 1).to(x.device)
        return self.score_convs[key](x)

    def Padding(self, x, n):
        """Zero-pad ``x`` on the left/top so H and W become multiples of ``n``.

        Args:
            x: tensor whose last two dimensions are (H, W).
            n: positive integer the spatial sizes should be divisible by.

        Returns:
            The padded tensor; unchanged in size when H and W are already
            multiples of ``n``.

        Raises:
            ValueError: if ``n`` is not positive.
        """
        if n <= 0:
            raise ValueError("n must be a positive integer, got {}".format(n))
        h, w = x.shape[2], x.shape[3]
        pad_w = (-w) % n  # amount missing to reach the next multiple of n
        pad_h = (-h) % n
        # F.pad lists the last dimension first: (left, right, top, bottom).
        return F.pad(x, (pad_w, 0, pad_h, 0))

    def sum_(self, score, pool3, pool4):
        """Add the three score maps, resizing pool3/pool4 to ``score`` if needed.

        Tensors whose spatial size differs from ``score`` are resampled with
        bicubic interpolation before the element-wise sum.
        """
        target = score.shape[2:]
        if pool3.shape[2:] != target:
            pool3 = F.interpolate(pool3, size=tuple(target), mode='bicubic', align_corners=False)
        if pool4.shape[2:] != target:
            pool4 = F.interpolate(pool4, size=tuple(target), mode='bicubic', align_corners=False)
        return score + pool3 + pool4

    def forward(self, x):
        """Compute per-class sigmoid maps with the same H x W as the input.

        Args:
            x: image batch of shape (B, C, H, W); it is moved to the device the
                model's weights live on.

        Returns:
            Tensor of shape (B, num_class, H, W) with values in [0, 1].
        """
        x_h = x.shape[-2]
        x_w = x.shape[-1]
        # Follow the model's own device instead of a notebook-level global.
        x = x.to(self.conv1[0].weight.device)
        vgg11_output = self.get_features(x, self.pre_trained_model, self.feature)

        # Deepest features -> 4096-channel head -> class scores -> x4 upsample.
        score = self.conv1(vgg11_output['output'])
        score = self.conv2(score)
        score = self.conv_num_class(score)
        score = self.upconvX2(self.upconvX2(score))

        # Skip connections: pool3 is scored as-is, pool4 is scored and x2 upsampled.
        pool3 = self.conv_num_class(vgg11_output['pool3'])
        pool4 = self.conv_num_class(vgg11_output['pool4'])
        pool4 = self.upconvX2(pool4)

        out = self.sum_(score, pool3, pool4)
        out = self.upconvX8(out)
        # Nearest-neighbour resize to the exact input size (same default mode
        # as nn.Upsample), then squash to [0, 1].
        out = F.interpolate(out, size=(x_h, x_w), mode='nearest')
        return torch.sigmoid(out)

In [112]:
# pre_trained_model = models.vgg11(pretrained=True).features.to(device)  # keep only VGG's convolutional layers (the `features` module)
# for param in pre_trained_model.parameters():
#      param.requires_grad_(False)

In [113]:
#print(pre_trained_model)

In [114]:
# Vgg11_FCN = Vgg11Fcn(pre_trained_model, 150).to(device)

In [115]:
#print(Vgg11_FCN)

In [116]:
# import torchsummary
# torchsummary.summary(Vgg11_FCN, (3,149,199))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 149, 199]           1,792
              ReLU-2         [-1, 64, 149, 199]               0
         MaxPool2d-3           [-1, 64, 74, 99]               0
            Conv2d-4          [-1, 128, 74, 99]          73,856
              ReLU-5          [-1, 128, 74, 99]               0
         MaxPool2d-6          [-1, 128, 37, 49]               0
            Conv2d-7          [-1, 256, 37, 49]         295,168
              ReLU-8          [-1, 256, 37, 49]               0
            Conv2d-9          [-1, 256, 37, 49]         590,080
             ReLU-10          [-1, 256, 37, 49]               0
        MaxPool2d-11          [-1, 256, 18, 24]               0
           Conv2d-12          [-1, 512, 18, 24]       1,180,160
             ReLU-13          [-1, 512, 18, 24]               0
           Conv2d-14          [-1, 512,

In [117]:
# s1 = torch.rand(1,3,127,124).to(device)

In [118]:
# Vgg11_FCN(s1).shape

In [None]:
11,157,176