In [1]:
import sys
sys.path.append('../')
import os
import glob
import pandas as pd
import argparse
import configparser
import numpy as np
import torch
import matplotlib
matplotlib.use( 'tkagg' )
%matplotlib inline
import matplotlib.pyplot as plt
from collections import OrderedDict

from model import *
from lib.preprocessing import *
from lib.dataloading import *
from lib.loss_functions import *
from lib.evaluation import *
from torchvision import transforms
import torchvision.models as models
from torch import optim, nn
import torch.nn.functional as F

In [8]:
transform = get_transformer_norm()
dataset = SSIDataset(img_file = '../data/ssi.csv', transform= transform['val'], inpaint = True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size = 1)
dataiter = iter(dataloader)

In [9]:
generator = CENet()

In [10]:
generator

CENet(
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256, kernel_size=(3, 3

In [46]:
# For VGG11 UNet    
class DecoderBlock(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels, padding = 1, output_padding = 1):
        super().__init__()

        self.block = nn.Sequential(
            ConvBatchRelu(in_channels, middle_channels),
            nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=3, stride=2, padding= padding, output_padding= output_padding),            
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.block(x)

class ConvBatchRelu(nn.Module):
    def __init__(self, in_, out):
        super().__init__()
        self.conv = nn.Conv2d(in_, out, 3, padding=1)
        self.batchnorm = nn.BatchNorm2d(out)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.batchnorm(x)
        x = self.activation(x)
        return x



In [12]:
pretrained_dict = torch.load('../results/baseline2/epoch92.pth')
pretrained_dict.keys()

odict_keys(['encoder.0.weight', 'encoder.0.bias', 'encoder.1.weight', 'encoder.1.bias', 'encoder.1.running_mean', 'encoder.1.running_var', 'encoder.1.num_batches_tracked', 'encoder.3.weight', 'encoder.3.bias', 'encoder.4.weight', 'encoder.4.bias', 'encoder.4.running_mean', 'encoder.4.running_var', 'encoder.4.num_batches_tracked', 'encoder.7.weight', 'encoder.7.bias', 'encoder.8.weight', 'encoder.8.bias', 'encoder.8.running_mean', 'encoder.8.running_var', 'encoder.8.num_batches_tracked', 'encoder.10.weight', 'encoder.10.bias', 'encoder.11.weight', 'encoder.11.bias', 'encoder.11.running_mean', 'encoder.11.running_var', 'encoder.11.num_batches_tracked', 'encoder.14.weight', 'encoder.14.bias', 'encoder.15.weight', 'encoder.15.bias', 'encoder.15.running_mean', 'encoder.15.running_var', 'encoder.15.num_batches_tracked', 'encoder.17.weight', 'encoder.17.bias', 'encoder.18.weight', 'encoder.18.bias', 'encoder.18.running_mean', 'encoder.18.running_var', 'encoder.18.num_batches_tracked', 'encode

In [79]:
pretrained_dict = torch.load('../results/baseline2/epoch92.pth')
model_dict = vgg.encoder.state_dict()

# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict) 
# 3. load the new state dict
vgg.encoder.load_state_dict(model_dict)

<All keys matched successfully>

In [81]:
model_dict.keys()

odict_keys(['0.weight', '0.bias', '1.weight', '1.bias', '1.running_mean', '1.running_var', '1.num_batches_tracked', '3.weight', '3.bias', '4.weight', '4.bias', '4.running_mean', '4.running_var', '4.num_batches_tracked', '7.weight', '7.bias', '8.weight', '8.bias', '8.running_mean', '8.running_var', '8.num_batches_tracked', '10.weight', '10.bias', '11.weight', '11.bias', '11.running_mean', '11.running_var', '11.num_batches_tracked', '14.weight', '14.bias', '15.weight', '15.bias', '15.running_mean', '15.running_var', '15.num_batches_tracked', '17.weight', '17.bias', '18.weight', '18.bias', '18.running_mean', '18.running_var', '18.num_batches_tracked', '20.weight', '20.bias', '21.weight', '21.bias', '21.running_mean', '21.running_var', '21.num_batches_tracked', '24.weight', '24.bias', '25.weight', '25.bias', '25.running_mean', '25.running_var', '25.num_batches_tracked', '27.weight', '27.bias', '28.weight', '28.bias', '28.running_mean', '28.running_var', '28.num_batches_tracked', '30.weight

In [199]:
weight_path = '../results/baseline2/epoch92.pth'
pretrained_dict = torch.load(weight_path)

In [203]:
pretrained_dict['encoder.0.weight'][0]

tensor([[[-0.0350, -0.0227,  0.0022],
         [ 0.0654,  0.0476, -0.0699],
         [ 0.0331, -0.1387, -0.0050]],

        [[-0.1026,  0.0149,  0.0538],
         [ 0.0293, -0.0288,  0.0078],
         [-0.0041, -0.1956,  0.1362]],

        [[ 0.0335,  0.0263, -0.0362],
         [-0.0766, -0.0320, -0.0965],
         [ 0.1025, -0.1313, -0.0382]]], device='cuda:0')

In [214]:
class VGG16UNet(nn.Module):
    def __init__(self, num_classes=1, num_filters=32, pretrained=False, self_trained = False, freeze = False, is_deconv=False):
        """
        :param num_classes:
        :param num_filters:
        :param pretrained:
            False - no pre-trained network used
            True - encoder pre-trained with VGG16
        :is_deconv:
            False: bilinear interpolation is used in decoder
            True: deconvolution is used in decoder
        """
        super().__init__()
        self.num_classes = num_classes
        
        self.encoder = models.vgg16_bn(pretrained=pretrained).features
                
        if self_trained:
            weight_path = '../results/baseline2/epoch92.pth'
            print(f"Load weights from {weight_path}")
            pretrained_dict = torch.load(weight_path)            
            model_dict = self.state_dict()            
            # 1. filter out unnecessary keys
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
            # 2. overwrite entries in the existing state dict
            model_dict.update(pretrained_dict)             
            
            # 3. load the new state dict0            
            self.load_state_dict(model_dict)                    
        
        if freeze:
            for param in self.encoder.parameters():
                param.require_grad = False
                
        self.relu = nn.ReLU(inplace=True)                
        self.pool = nn.MaxPool2d(2, 2)
                
        self.conv1 = nn.Sequential(self.encoder[0],
                                   self.encoder[1],
                                   self.relu,
                                   self.encoder[3],
                                   self.encoder[4],
                                   self.relu)

        self.conv2 = nn.Sequential(self.encoder[7],
                                   self.encoder[8],
                                   self.relu,
                                   self.encoder[10],
                                   self.encoder[11],
                                   self.relu)

        self.conv3 = nn.Sequential(self.encoder[14],
                                   self.encoder[15],
                                   self.relu,
                                   self.encoder[17],
                                   self.encoder[18],
                                   self.relu,
                                   self.encoder[20],
                                   self.encoder[21],
                                   self.relu)

        self.conv4 = nn.Sequential(self.encoder[24],
                                   self.encoder[25],
                                   self.relu,
                                   self.encoder[27],
                                   self.encoder[28],
                                   self.relu,
                                   self.encoder[30],
                                   self.encoder[31],
                                   self.relu)

        self.conv5 = nn.Sequential(self.encoder[34],
                                   self.encoder[35],
                                   self.relu,
                                   self.encoder[37],
                                   self.encoder[38],
                                   self.relu,
                                   self.encoder[40],
                                   self.encoder[41],
                                   self.relu)
    
        self.center = DecoderBlock(512, num_filters * 8 * 2, num_filters * 8, padding = (1,0), output_padding = (1,0))

        self.dec5 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8)
        self.dec4 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8)
        self.dec3 = DecoderBlock(256 + num_filters * 8, num_filters * 4 * 2, num_filters * 2)
        self.dec2 = DecoderBlock(128 + num_filters * 2, num_filters * 2 * 2, num_filters)
        self.dec1 = ConvBatchRelu(64 + num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):        
        conv1 = self.conv1(x)
        conv2 = self.conv2(self.pool(conv1))
        conv3 = self.conv3(self.pool(conv2))
        conv4 = self.conv4(self.pool(conv3))
        conv5 = self.conv5(self.pool(conv4))

        center = self.center(self.pool(conv5))

        dec5 = self.dec5(torch.cat([center, conv5], 1))

        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(torch.cat([dec2, conv1], 1))

        if self.num_classes > 1:
            x_out = F.log_softmax(self.final(dec1), dim=1)
        else:
            x_out = self.final(dec1)

        return F.sigmoid(x_out)


In [221]:
vgg = VGG16UNet(num_classes=2, pretrained=False, self_trained=True)

Load weights from ../results/baseline2/epoch92.pth


In [220]:
list(vgg.conv1[0].weight.data)[0]

tensor([[[ 0.0080, -0.0645,  0.0364],
         [ 0.0237,  0.0199,  0.0402],
         [-0.0632,  0.0091,  0.0989]],

        [[ 0.0625, -0.0922, -0.0725],
         [ 0.0359,  0.0855, -0.0374],
         [-0.0504,  0.0304,  0.0501]],

        [[ 0.0345,  0.0371, -0.0076],
         [ 0.0128,  0.0458, -0.0714],
         [ 0.0350, -0.0945, -0.0777]]])

In [209]:
vgg.encoder.state_dict()['0.weight'][0]

tensor([[[-0.0350, -0.0227,  0.0022],
         [ 0.0654,  0.0476, -0.0699],
         [ 0.0331, -0.1387, -0.0050]],

        [[-0.1026,  0.0149,  0.0538],
         [ 0.0293, -0.0288,  0.0078],
         [-0.0041, -0.1956,  0.1362]],

        [[ 0.0335,  0.0263, -0.0362],
         [-0.0766, -0.0320, -0.0965],
         [ 0.1025, -0.1313, -0.0382]]])

In [86]:
out = vgg(img)



In [69]:
out.size()

torch.Size([1, 2, 256, 400])