# Practice 2 - From Fully-Connected to Fully-Convolutional Networks

- Reference code
  - https://github.com/bodokaiser/piwise

In [None]:
# libraries for plot
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torch.nn.functional as F

### Examine VGG-16 architecture

In [None]:
# load VGG16 provided by torchvision, with pretrained weights
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in favour of
# `weights=...`; confirm against the installed version.
vgg16 = models.vgg16(pretrained=True)
# Last expression of the cell: Jupyter displays the convolutional `features` part.
vgg16.features

In [None]:
# Display the classifier part of VGG-16 (the layers that follow `features`).
vgg16.classifier

#### Check the output size of VGG-16's feature extractor

In [None]:
# Push a dummy 224x224 image through VGG-16's convolutional part and compare the
# number of output elements with the input size of the first classifier layer.
# VGG-16 ends with 512 channels at stride 32, i.e. 512 x 7 x 7 = 25088 for a
# 224x224 input (256*6*6 is AlexNet's size, not VGG-16's).
with torch.no_grad():  # shape check only; no autograd graph needed
    feat = vgg16.features(torch.zeros(1, 3, 224, 224))
print(feat.size(), 512 * 7 * 7)

### Exercise 2. Implement FCN-32s with a VGG-16 backbone

- `nn.Conv2d`
  - https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
- `Tensor.view`: reshapes a tensor into a different shape while keeping the same number of elements (useful for turning fully-connected weights into convolution kernels)
  - https://pytorch.org/docs/stable/generated/torch.Tensor.view.html

In [None]:
class FCN32(nn.Module):
    """FCN-32s: VGG-16 with its fully-connected layers recast as convolutions.

    The convolutional ``features`` of a pretrained VGG-16 are reused as-is.
    The first two layers of VGG-16's classifier (``classifier[0]`` and
    ``classifier[3]``) are converted into an equivalent 7x7 convolution and a
    1x1 convolution by reshaping their pretrained weights. The network
    therefore accepts inputs of any size and produces a spatial score map.
    A new 1x1 convolution produces ``num_classes`` score channels. The coarse
    score map is then bilinearly upsampled back to the input's spatial size.

    Args:
        num_classes: number of output channels of the score layer (default 21).
    """

    def __init__(self, num_classes=21):
        super().__init__()
        vgg16 = models.vgg16(pretrained=True)

        # Convolutional backbone (512 output channels, output stride 32).
        self.feats = vgg16.features

        # fc6 -> 7x7 conv and fc7 -> 1x1 conv. padding=3 keeps the feature-map
        # size unchanged, so the coarse score map is exactly 1/32 of the input
        # (for sides that are multiples of 32) instead of shrinking by 6 pixels.
        self.fconn = nn.Sequential(
            nn.Conv2d(512, 4096, kernel_size=7, padding=3),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Conv2d(4096, 4096, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Dropout(),
        )

        # Copy the pretrained fully-connected parameters into the convolutions.
        # A Linear weight is (out, in), where `in` is the flattened (C, H, W)
        # input. A plain reshape therefore gives the matching (out, C, kH, kW)
        # convolution kernel.
        fc6, fc7 = vgg16.classifier[0], vgg16.classifier[3]
        with torch.no_grad():
            self.fconn[0].weight.copy_(fc6.weight.view_as(self.fconn[0].weight))
            self.fconn[0].bias.copy_(fc6.bias)
            self.fconn[3].weight.copy_(fc7.weight.view_as(self.fconn[3].weight))
            self.fconn[3].bias.copy_(fc7.bias)

        # new score layer (randomly initialised)
        self.score = nn.Conv2d(4096, num_classes, 1)

    def forward(self, x):
        """Return class scores of shape (N, num_classes, H, W) for an (N, C, H, W) input x."""
        feats = self.feats(x)
        fconn = self.fconn(feats)
        score = self.score(fconn)
        # Resize to the exact input size. This equals scale_factor=32 when H and W
        # are multiples of 32, and still matches the labels otherwise.
        # F.interpolate replaces the deprecated F.upsample.
        return F.interpolate(score, size=x.shape[-2:], mode='bilinear', align_corners=True)

In [None]:
fcn = FCN32()
# Shape check only: skip building an autograd graph for a full VGG-16 forward pass.
# Calling the module (rather than .forward) also runs any registered hooks.
with torch.no_grad():
    out = fcn(torch.zeros(1, 3, 512, 512))
# The score map needs to be N x num_classes x 512 x 512 to line up with the labels.
out.size()

### Dataset definition

In [None]:
import numpy as np
import os
from PIL import Image
from torch.utils.data import Dataset

# Image and label-map file extensions.
# NOTE(review): not referenced by any of the code visible in this notebook.
EXTENSIONS = ['.jpg', '.png']

def load_image(file):
    """Open ``file`` (a path or binary file object) with PIL and return the image.

    Per Pillow's documentation, ``Image.open`` reads pixel data lazily. Callers
    that pass an open file object must therefore force decoding (for example
    with ``.convert``) before closing it.
    """
    opened = Image.open(file)
    return opened

def image_path(root, basename, extension):
    """Join ``root`` with the file name formed by ``basename`` + ``extension``."""
    file_name = '{}{}'.format(basename, extension)
    return os.path.join(root, file_name)


class VOC12(Dataset):
    """PASCAL VOC 2012 segmentation dataset read from a directory tree.

    Layout expected under ``root``:
    ``JPEGImages/<name>.jpg``, ``SegmentationClass/<name>.png`` and
    ``ImageSets/Segmentation/<split>.txt`` (one image name per line).

    Args:
        root: dataset root directory.
        split: name of the split file to read (e.g. ``'train'`` or ``'val'``).
        input_transform: optional callable applied to the RGB PIL image.
        target_transform: optional callable applied to the 'P'-mode label image.
    """

    def __init__(self, root, split='train', input_transform=None, target_transform=None):
        self.images_root = os.path.join(root, 'JPEGImages')
        self.labels_root = os.path.join(root, 'SegmentationClass')

        split_file = os.path.join(root, 'ImageSets', 'Segmentation', '%s.txt' % split)
        with open(split_file) as f:
            # strip() instead of dropping the last character: the final line may
            # have no trailing newline, and CRLF files would otherwise keep '\r'.
            # Blank lines are skipped.
            self.filenames = [name for name in (line.strip() for line in f) if name]

        self.input_transform = input_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``, after the optional transforms."""
        filename = self.filenames[index]

        # convert() forces the pixel data to be read while the file is still open.
        with open(image_path(self.images_root, filename, '.jpg'), 'rb') as f:
            image = load_image(f).convert('RGB')
        with open(image_path(self.labels_root, filename, '.png'), 'rb') as f:
            label = load_image(f).convert('P')

        if self.input_transform is not None:
            image = self.input_transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label

    def __len__(self):
        """Number of images listed in the split file."""
        return len(self.filenames)

### Transforms

In [None]:
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, CenterCrop, Normalize
from torchvision.transforms import ToTensor, ToPILImage

class Relabel:
    """Replace every ``olabel`` value in an int64 label tensor with ``nlabel``.

    The tensor is modified in place and also returned, so the transform can be
    chained inside ``Compose``.
    """

    def __init__(self, olabel, nlabel):
        self.olabel = olabel
        self.nlabel = nlabel

    def __call__(self, tensor):
        # Check the dtype instead of isinstance(tensor, torch.LongTensor): the
        # latter only matches CPU tensors. Raise instead of assert because
        # assert statements are stripped under `python -O`.
        if not torch.is_tensor(tensor) or tensor.dtype != torch.long:
            raise TypeError('tensor needs to be LongTensor')
        tensor[tensor == self.olabel] = self.nlabel
        return tensor

class ToLabel:
    """Convert a label image (or array-like) into a ``1 x H x W`` int64 tensor."""

    def __call__(self, image):
        label_array = np.array(image)
        label_tensor = torch.from_numpy(label_array).long()
        # add a leading channel dimension
        return label_tensor[None]

# Input images: center-crop to 512 x 512, convert to a float tensor, then
# normalize each channel with the given mean / std.
# NOTE(review): these look like the ImageNet statistics expected by torchvision's
# pretrained VGG-16; confirm.
input_transform = Compose([ 
    CenterCrop(512), 
    ToTensor(), 
    Normalize([.485, .456, .406], [.229, .224, .225]), 
])
# Label maps: the same 512 x 512 center crop, then a 1 x H x W int64 tensor, and
# label 255 remapped to 21 (the value the loss is told to ignore).
# NOTE(review): per torchvision's docs, CenterCrop pads images smaller than the
# crop size with 0, so padded label pixels become class 0 instead of the ignore
# label. Confirm this is intended.
target_transform = Compose([ 
    CenterCrop(512), 
    ToLabel(), 
    Relabel(255, 21), # ignore label 255 >> 21
])

### Exercise 3. Train for a single epoch

In [None]:
DATAROOT = './VOC2012/'

net = FCN32()
# Move the model to the GPU *before* constructing the optimizer, as the
# torch.optim documentation recommends.
net.cuda()
optimizer = optim.Adam(net.parameters(), lr=1e-4)
# 21 is the value target_transform's Relabel(255, 21) assigns to ignored pixels.
loss_fn = nn.CrossEntropyLoss(ignore_index=21)

In [None]:
loader = DataLoader(VOC12(DATAROOT, 'train', input_transform, target_transform),
        num_workers=1, batch_size=10, shuffle=True)

net.train()
all_loss = 0
for i, data in enumerate(loader):
    image = data[0].cuda()
    # Labels arrive as N x 1 x H x W. Drop only the channel dim so that a batch of
    # size 1 keeps its batch dimension (a bare squeeze() would remove it too).
    label = data[1].squeeze(1).cuda()

    # forward pass, loss, and one optimizer step
    pred = net(image)
    loss = loss_fn(pred, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # .item() yields a Python float, so the running sum holds no GPU tensor
    all_loss = all_loss + loss.item()

    if (i % 50 == 0) or (i == len(loader) - 1):
        print("[{:4d}/{:4d}] loss:{:.3f}".format(i, len(loader), all_loss / (i + 1)))

### Evaluate on the validation set for a single epoch

In [None]:
loader = DataLoader(VOC12(DATAROOT, 'val', input_transform, target_transform),
        num_workers=1, batch_size=10, shuffle=False)

net.eval()
all_loss_val = 0
# Evaluation needs no gradients. Disabling autograd avoids keeping every
# activation alive for a backward pass that never happens.
with torch.no_grad():
    for i, data in enumerate(loader):
        image = data[0].cuda()
        # keep the batch dimension even for a final batch of size 1
        label = data[1].squeeze(1).cuda()

        pred = net(image)
        loss = loss_fn(pred, label)

        all_loss_val = all_loss_val + loss.item()

        if (i % 50 == 0) or (i == len(loader) - 1):
            print("[{:4d}/{:4d}] loss:{:.3f}".format(i, len(loader), all_loss_val / (i + 1)))

### Visualize inference results

In [None]:
def get_voc_colormap():
    """Build the 256-entry PASCAL VOC color palette as a ``(256, 3)`` uint8 array.

    Each index is consumed three bits at a time. Bit 0 of each group goes to
    red, bit 1 to green and bit 2 to blue. Each channel is filled from its most
    significant bit (7) downwards.
    """
    num_colors = 256  # number of colormap entries
    palette = np.zeros([num_colors, 3], dtype=np.uint8)
    for label in range(num_colors):
        rgb = [0, 0, 0]
        bits = label
        for position in range(7, -1, -1):
            for channel in range(3):
                if (bits >> channel) & 1:
                    rgb[channel] |= 1 << position
            bits >>= 3
        palette[label] = rgb
    return palette

def return_pascal_segmentation(input_im):
    """Wrap a 2-D label array as a 'P'-mode PIL image colored with the VOC palette.

    ``input_im`` is expected to be a uint8 array of class indices.
    """
    seg_image = Image.fromarray(input_im, mode='P')
    # putpalette wants a flat R, G, B, R, G, B, ... sequence (256 * 3 = 768 values)
    flat_palette = get_voc_colormap().reshape(-1)
    seg_image.putpalette(flat_palette)
    return seg_image


im_idx = 0 # image index (within the last validation batch) for visualization

# ground truth: channel 0 of the 1 x H x W label tensor
gt = data[1][im_idx, 0, :, :].numpy().astype(np.uint8)

# prediction: index of the highest score along the class dimension
val, pred_seg = pred.detach().cpu().max(1)
pred_seg = pred_seg[im_idx].numpy().astype(np.uint8)

# Undo the Normalize() step of input_transform so the image is shown in [0, 1].
# These values must match the mean/std passed to Normalize in input_transform.
mean = torch.tensor([.485, .456, .406]).view(3, 1, 1)
std = torch.tensor([.229, .224, .225]).view(3, 1, 1)
image_vis = (data[0][im_idx] * std + mean).clamp(0, 1)

# plot the result
plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.imshow(image_vis.permute(1, 2, 0)) # C x H x W --> H x W x C
plt.title('input')

plt.subplot(1, 3, 2)
plt.imshow(return_pascal_segmentation(gt))
plt.title('ground truth')

plt.subplot(1, 3, 3)
plt.imshow(return_pascal_segmentation(pred_seg))
plt.title('prediction')

### Exercise 4. Compute mean IU (mean intersection over union) via a confusion matrix
- PASCAL VOC evaluation code https://github.com/npinto/VOCdevkit/blob/master/VOCcode/VOCevalseg.m

In [None]:
# Accumulator for a 21 x 21 confusion matrix (one row/column per class index 0..20).
# NOTE(review): which axis holds ground truth vs. prediction is not visible here; confirm.
conf_mtx = torch.zeros(21,21)