<a href="https://colab.research.google.com/github/willychangx/facade-segmentation/blob/main/predict.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
# Fetch the project bundle (requirements.txt, dataset module, model weights) and unpack it
# under /content.
# -O pins the download location so re-running the cell still works after the %cd below has
#    changed the working directory; the URL is quoted so the shell does not glob on '?'.
# -o lets unzip overwrite existing files instead of waiting on an interactive prompt
#    that %%capture would hide.
!gdown -O /content/segmentation_predict.zip "https://drive.google.com/uc?id=1Knr0YRtWKf06a9tCHh1HTcgEA8AQy3Cl"
!unzip -o /content/segmentation_predict.zip -d /content
%rm -rf /content/segmentation_predict.zip

%cd /content/segmentation_predict

In [None]:
%%capture
# %pip (unlike !pip) installs into the environment of the running kernel.
%pip install -r requirements.txt

In [None]:
import os
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import png
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import random
from PIL import Image
from colormap.colors import Color, hex2rgb
from sklearn.metrics import average_precision_score as ap_score
from torch.utils.data import ConcatDataset, DataLoader
from torchvision import datasets, models, transforms
from tqdm import tqdm

from dataset import FacadeDataset

# Number of segmentation classes the network outputs. save_label only has colors for
# 5 class indices (facade, others, pillar, window, balcony), so this must not exceed 5.
N_CLASS=5

In [None]:
class Net(nn.Module):
    '''
    U-Net-style encoder/decoder that predicts a per-pixel class score map.

    layer1-layer4 each apply two (3x3 conv -> BatchNorm -> ReLU) units and are followed
    by 2x2 max pooling. layer5 is the bottleneck. layer5-layer8 each end with a 2x
    bilinear upsample, and their output is concatenated along the channel axis with the
    encoder feature map of matching resolution before the next layer. layer9 ends with a
    1x1 conv that produces one score map per class.

    Args:
        n_class: number of output classes. ``None`` (the default) uses the module-level
            N_CLASS, so ``Net()`` behaves as before.

    Input is (batch, 3, H, W) with H and W at least 16 (four 2x2 poolings). Output is
    (batch, n_class, H, W). When H or W is not divisible by 16, the upsampled map is
    resized to its skip connection's size before concatenation instead of crashing.
    Submodule names and Sequential indices match the original layout, so previously
    saved state_dicts still load.
    '''

    def __init__(self, n_class=None):
        super(Net, self).__init__()
        self.n_class = N_CLASS if n_class is None else n_class

        def conv_unit(in_ch, out_ch):
            # 3x3 conv with padding=1 (keeps H, W) -> BatchNorm -> ReLU
            return [
                nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]

        def upsample():
            # Doubles H and W with bilinear interpolation (no learned weights).
            return nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Encoder
        self.layer1 = nn.Sequential(*conv_unit(3, 64), *conv_unit(64, 64))
        self.layer2 = nn.Sequential(*conv_unit(64, 128), *conv_unit(128, 128))
        self.layer3 = nn.Sequential(*conv_unit(128, 256), *conv_unit(256, 256))
        self.layer4 = nn.Sequential(*conv_unit(256, 512), *conv_unit(512, 512))
        # Bottleneck
        self.layer5 = nn.Sequential(*conv_unit(512, 1024), *conv_unit(1024, 512), upsample())
        # Decoder: input channels = skip channels + upsampled channels (e.g. 512 + 512)
        self.layer6 = nn.Sequential(*conv_unit(1024, 512), *conv_unit(512, 256), upsample())
        self.layer7 = nn.Sequential(*conv_unit(512, 256), *conv_unit(256, 128), upsample())
        self.layer8 = nn.Sequential(*conv_unit(256, 128), *conv_unit(128, 64), upsample())
        self.layer9 = nn.Sequential(
            *conv_unit(128, 64),
            *conv_unit(64, 64),
            nn.Conv2d(in_channels=64, out_channels=self.n_class, kernel_size=1),
        )
        # Parameter-free, so it adds no state_dict entries; built once instead of per forward.
        self.down_pooling = nn.MaxPool2d(kernel_size=2, stride=2)

    @staticmethod
    def _concat_skip(skip, x):
        # Concatenate encoder map `skip` with upsampled decoder map `x` along channels.
        # Odd sizes at some level make pool-then-upsample come back one pixel short,
        # so `x` is first resized to match `skip`.
        if x.shape[-2:] != skip.shape[-2:]:
            x = nn.functional.interpolate(
                x, size=tuple(skip.shape[-2:]), mode='bilinear', align_corners=True)
        return torch.cat([skip, x], dim=1)

    def forward(self, x):
        x1 = self.layer1(x)
        x2 = self.layer2(self.down_pooling(x1))
        x3 = self.layer3(self.down_pooling(x2))
        x4 = self.layer4(self.down_pooling(x3))
        x = self.layer5(self.down_pooling(x4))  # bottleneck, upsampled back toward x4's scale
        x = self.layer6(self._concat_skip(x4, x))
        x = self.layer7(self._concat_skip(x3, x))
        x = self.layer8(self._concat_skip(x2, x))
        return self.layer9(self._concat_skip(x1, x))  # per-class scores at input resolution

In [None]:
def save_label(label, path):
    '''
    Write a 2-D class-index map to `path` as a 4-bit paletted PNG.

    Each class index is drawn with a fixed color:
        0 black  (facade)
        1 blue   (others)
        2 green  (pillar)
        3 orange (window)
        4 red    (balcony)

    Args:
        label: 2-D (H, W) array-like of non-negative integer class indices.
        path: destination file path; the file is created or overwritten.

    Raises:
        ValueError: if `label` is not a non-empty 2-D array, or contains an index
            outside the palette.
    '''
    colormap = [
        '#000000', # black, facade
        '#0080FF', # blue, others
        '#80FF80', # green, pillar
        '#FF8000', # orange, window
        '#FF0000', # red, balcony
    ]
    label = np.asarray(label)
    # Validate before opening the file, so a bad call does not leave a broken PNG on disk.
    # A (1, H, W) batch would otherwise be written with width=H, height=1.
    if label.ndim != 2 or label.size == 0:
        raise ValueError(
            'label must be a non-empty 2-D (H, W) array, got shape {}'.format(label.shape))
    if label.min() < 0 or label.max() >= len(colormap):
        raise ValueError('label values must be in [0, {}), got range [{}, {}]'.format(
            len(colormap), label.min(), label.max()))
    colors = [hex2rgb(color, normalise=False) for color in colormap]
    # png.Writer takes (width, height) = (columns, rows)
    w = png.Writer(label.shape[1], label.shape[0], palette=colors, bitdepth=4)
    with open(path, 'wb') as f:
        w.write(f, label)

In [None]:
def get_result(testloader, net, device, folder='output_train'):
    '''
    Run `net` over every batch in `testloader` and write per-sample outputs to `folder`.

    Samples are numbered consecutively from 0 across batches. For sample n this writes:
      - y{n}.png  : predicted label map (via save_label)
      - gt{n}.png : ground-truth label map (via save_label)
      - x{n}.jpg  : the input image, mapped back to 0-255 via (x + 1) * 128 and clipped
                    (presumably the dataset normalised images to about [-1, 1] -- confirm)

    Args:
        testloader: iterable of (images, labels) batches. images are (B, 3, H, W)
            tensors; labels are presumably (B, H, W) class-index tensors
            (TODO confirm against FacadeDataset with onehot=False).
        net: model returning (B, N_CLASS, H, W) scores.
        device: torch device the images are moved to for inference.
        folder: output directory; created if it does not exist.

    Returns:
        list of predicted (H, W) uint8 label maps, in the same order as the files.
    '''
    os.makedirs(folder, exist_ok=True)
    result = []
    cnt = 0
    net = net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            outputs = net(images.to(device)).cpu().numpy()
            for b in range(outputs.shape[0]):
                output = outputs[b]
                c, h, w = output.shape
                assert c == N_CLASS
                # Pixels with no score > 0.5 stay class 0; later classes override earlier ones.
                # NOTE(review): if the model was trained with CrossEntropyLoss on raw logits,
                # output.argmax(axis=0) is likely the intended decoding -- confirm.
                y = np.zeros((h, w), dtype=np.uint8)
                for i in range(N_CLASS):
                    y[output[i] > 0.5] = i
                # Index the batch dimension away so save_label gets a 2-D (H, W) map.
                gt = labels[b].cpu().numpy().astype(np.uint8)
                # Clip before the uint8 cast: (1 + 1) * 128 = 256 would otherwise wrap to 0.
                image = np.clip((images[b].cpu().numpy() + 1) * 128, 0, 255)
                image = image.astype(np.uint8).transpose(1, 2, 0)
                save_label(y, os.path.join(folder, 'y{}.png'.format(cnt)))
                save_label(gt, os.path.join(folder, 'gt{}.png'.format(cnt)))
                plt.imsave(os.path.join(folder, 'x{}.jpg'.format(cnt)), image)
                result.append(y)
                cnt += 1
    return result

In [None]:
def main(data_dir='./uh_test/', model_path='./models/model_starter_net.pth',
         folder='output_train'):
    '''
    Predict segmentation labels for the UH test image with the pretrained starter net.

    Args:
        data_dir: directory holding the test image(s); per the original note they are
            named like 'uh_test_0.jpg'.
        model_path: path to the saved Net state_dict.
        folder: output directory passed to get_result.

    Returns:
        whatever get_result returns for the test loader.
    '''
    torch.cuda.empty_cache()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load testing data, data must be in folder /uh_test/ and named 'uh_test_0.jpg'
    uh_data = FacadeDataset(flag='uh_test', dataDir=data_dir, data_range=(0,1), onehot=False)
    uh_loader = DataLoader(uh_data, batch_size=1)

    net = Net().to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only runtime.
    net.load_state_dict(torch.load(model_path, map_location=device))

    print('\nPredicting on UH building image')
    return get_result(uh_loader, net, device, folder=folder)

In [None]:
# Entry point: run the prediction pipeline. Inside a notebook kernel __name__ is
# "__main__", so executing this cell calls main().
if __name__ == "__main__":
    main()