In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data
import torchvision

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image

import pandas as pd
import numpy as np
import scipy.io
import skimage.io

from PIL import Image, ImageFilter

In [2]:
class Encoder(nn.Module):
    """One encoder stage: 3x3 convolution, batch norm, ReLU, then 2x2 max pooling.

    ``forward`` returns the pooled feature map together with the pooling
    indices produced by ``max_pool2d(..., return_indices=True)``.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()

        # A 3x3 kernel with padding=1 keeps the spatial size unchanged.
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channel)

    def forward(self, x):
        """Return ``(pooled, indices)`` for the input batch ``x``."""
        activated = F.relu(self.bn(self.conv(x)))
        # stride=2 with kernel_size=2 halves height and width.
        return F.max_pool2d(activated, kernel_size=2, stride=2, return_indices=True)

In [3]:
class Decoder(nn.Module):
    """One decoder stage: 3x3 convolution, batch norm, then ReLU.

    The spatial size of the input is preserved (3x3 kernel with padding=1);
    only the channel count changes from ``in_channel`` to ``out_channel``.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()

        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channel)

    def forward(self, x):
        """Return the activated feature map for the input batch ``x``."""
        return F.relu(self.bn(self.conv(x)))

In [4]:
class SegNetBasic(nn.Module):
    """
        SegNet Basic, a smaller version of SegNet.

        Four encoder stages (each ending in 2x2 max pooling) are followed by
        four decoder stages. Before every decoder stage the feature map is
        un-pooled with the indices recorded by the mirrored encoder stage.

        Reference implementation:
        https://github.com/0bserver07/Keras-SegNet-Basic

        NOTE(review): the last decoder stage also ends in batch norm + ReLU,
        so the returned per-class scores are never negative.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()

        # Channel widths grow through the encoder ...
        self.encoder1 = Encoder(in_channel, 64)
        self.encoder2 = Encoder(64, 80)
        self.encoder3 = Encoder(80, 96)
        self.encoder4 = Encoder(96, 128)

        # ... and shrink symmetrically through the decoder.
        self.decoder1 = Decoder(128, 96)
        self.decoder2 = Decoder(96, 80)
        self.decoder3 = Decoder(80, 64)
        self.decoder4 = Decoder(64, out_channel)

    def forward(self, x):
        """Return a score map with ``out_channel`` channels at the input resolution."""
        encoders = (self.encoder1, self.encoder2, self.encoder3, self.encoder4)
        decoders = (self.decoder1, self.decoder2, self.decoder3, self.decoder4)

        # Remember, for every pooling step, the indices and the size before
        # pooling; the decoder consumes them in reverse (last in, first out).
        pooling_trace = []
        for encode in encoders:
            size_before_pool = x.size()
            x, pool_idx = encode(x)
            pooling_trace.append((pool_idx, size_before_pool))

        for decode in decoders:
            pool_idx, size_before_pool = pooling_trace.pop()
            x = F.max_unpool2d(x, pool_idx, kernel_size=2, stride=2,
                               output_size=size_before_pool)
            x = decode(x)

        return x

In [5]:
class PartAffordanceDataset(Dataset):
    """Part Affordance Dataset.

    Samples are listed in a CSV file: column 0 of a row holds the path of an
    image file and column 1 the path of a MATLAB ``.mat`` file whose
    ``"gt_label"`` entry is the label array for that image.

    Each item is a dict ``{'image': ..., 'class': ...}``, passed through
    ``transform`` when one is given.
    """

    def __init__(self, csv_file, transform=None):
        super().__init__()

        self.transform = transform
        self.image_class_path = pd.read_csv(csv_file)

    def __len__(self):
        """Number of rows in the CSV listing."""
        return len(self.image_class_path)

    def __getitem__(self, idx):
        """Load the image / label pair listed in row ``idx``."""
        listing = self.image_class_path

        sample = {
            'image': skimage.io.imread(listing.iloc[idx, 0]),   # numpy array
            'class': scipy.io.loadmat(listing.iloc[idx, 1])["gt_label"],
        }

        return self.transform(sample) if self.transform else sample

In [6]:
def crop_center_numpy(array, crop_height, crop_weight):
    """Return the central ``crop_height`` x ``crop_weight`` window of ``array``.

    The window offset is ``(size - crop) // 2`` on each axis, the same rule
    ``crop_center_pil_image`` uses, so an image cropped with PIL and a label
    map cropped here stay pixel-aligned for odd sizes as well.

    Args:
        array: numpy array whose first two axes are (height, width); any
            trailing axes (e.g. channels) are kept untouched.
        crop_height: number of rows to keep.
        crop_weight: number of columns to keep (i.e. the crop *width*; the
            parameter name is kept for backward compatibility).

    Returns:
        A view of ``array`` with shape ``(crop_height, crop_weight, ...)``.

    Raises:
        ValueError: if the requested crop is negative or larger than the array.
    """
    h, w = array.shape[:2]
    if not (0 <= crop_height <= h and 0 <= crop_weight <= w):
        # A negative start index would silently wrap around and return a
        # wrongly sized slice, so reject the request explicitly.
        raise ValueError(
            "crop size ({}, {}) does not fit into array of size ({}, {})".format(
                crop_height, crop_weight, h, w))

    top = (h - crop_height) // 2
    left = (w - crop_weight) // 2
    # Slicing by start + length returns exactly the requested size, also for
    # odd crop sizes.
    return array[top: top + crop_height, left: left + crop_weight]

In [7]:
def crop_center_pil_image(pil_img, crop_width, crop_height):
    """Crop the central ``crop_width`` x ``crop_height`` region out of ``pil_img``.

    ``pil_img`` must expose ``size`` as ``(width, height)`` and a ``crop``
    method taking a ``(left, upper, right, lower)`` box; the result of that
    ``crop`` call is returned.
    """
    img_width, img_height = pil_img.size

    left = (img_width - crop_width) // 2
    upper = (img_height - crop_height) // 2
    right = (img_width + crop_width) // 2
    lower = (img_height + crop_height) // 2

    return pil_img.crop((left, upper, right, lower))

In [8]:
class CenterCrop(object):
    """Center-crop the image and the label map of a sample to the same size.

    Args:
        crop_width: width of the crop in pixels (default 320).
        crop_height: height of the crop in pixels (default 240).

    Calling the transform with a ``{'image': ..., 'class': ...}`` dict returns
    a new dict with both entries cropped around their centre.
    """

    def __init__(self, crop_width=320, crop_height=240):
        self.crop_width = crop_width
        self.crop_height = crop_height

    def __call__(self, sample):
        image, cls = sample['image'], sample['class']

        image = Image.fromarray(np.uint8(image))

        # Mind the argument order: the PIL helper takes (width, height),
        # the numpy helper takes (height, width).
        image = crop_center_pil_image(image, self.crop_width, self.crop_height)
        cls = crop_center_numpy(cls, self.crop_height, self.crop_width)

        # np.array copies the pixel data. np.asarray on a PIL image gives a
        # read-only array, which makes torch.from_numpy emit a
        # "not writable" warning further down the pipeline.
        image = np.array(image)

        return {'image': image, 'class': cls}

In [9]:
class ToTensor(object):
    """Convert the numpy arrays of a sample into torch tensors.

    The image is reordered from (H, W, C) to (C, H, W) and returned as a
    float tensor; the label array is returned as a long tensor.
    """

    def __call__(self, sample):
        # Channels-last -> channels-first.
        chw_image = np.transpose(sample['image'], (2, 0, 1))

        image_tensor = torch.from_numpy(chw_image).float()
        class_tensor = torch.from_numpy(sample['class']).long()

        return {'image': image_tensor, 'class': class_tensor}

In [10]:
# Per-channel mean and standard deviation consumed by the Normalize transform
# (and by reverse_normalize when saving the original images).
# NOTE(review): the magnitudes suggest a 0-255 pixel scale; presumably these
# were computed over the training images — TODO confirm.
mean=[55.8630, 59.9099, 91.7419]
std=[31.6852, 29.8496, 19.0835]

In [11]:
class Normalize(object):
    """Standardise the image of a sample with the module-level ``mean`` / ``std``.

    The label entry is passed through unchanged.
    """

    def __call__(self, sample):
        normalized_image = transforms.functional.normalize(sample['image'], mean, std)

        return {'image': normalized_image, 'class': sample['class']}

In [12]:
# Training split: centre-crop, convert to tensors, then standardise the image.
train_data = PartAffordanceDataset(
    csv_file='train.csv',
    transform=transforms.Compose([CenterCrop(), ToTensor(), Normalize()]),
)

In [13]:
# Test split, preprocessed with the same pipeline as the training split.
test_data = PartAffordanceDataset(
    csv_file='test.csv',
    transform=transforms.Compose([CenterCrop(), ToTensor(), Normalize()]),
)

In [14]:
# Batches of 10 samples; only the training split is shuffled.
train_loader = DataLoader(dataset=train_data, batch_size=10, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=10, shuffle=False)

In [15]:
# Per-class counts; index 0 is by far the largest entry.
class_num = torch.tensor([
    2078085712,
    34078992,
    15921090,
    12433420,
    38473752,
    6773528,
    9273826,
    20102080,
])

# Weight of class i = count of class 0 / (100 * count of class i),
# so class 0 itself gets a weight of 0.01.
class_weight = class_num[0].float() / (class_num.float() * 100.0)

In [16]:
from tensorboardX import SummaryWriter
import tqdm

In [17]:
# NOTE: this cell is entirely commented out, so eval_model is NOT defined
# when the notebook runs. It is kept as a record of how IoU was evaluated:
# for every class it accumulates the intersection and the union of predicted
# and ground-truth pixels over test_loader and returns the per-class IoU.
# NOTE(review): a class whose union stays 0 would produce 0/0 in the final
# division.
# def eval_model(model, test_loader, device='cpu'):
#     model.eval()
    
#     intersection = torch.zeros(8)   # the dataset has 8 classes including background
#     union = torch.zeros(8)
    
#     for sample in test_loader:
#         x, y = sample['image'], sample['class']
        
#         x = x.to(device)
#         y = y.to(device)
        
#         with torch.no_grad():
#             _, ypred = model(x).max(1)    # y_pred.shape => (N, 240, 320)
        
#         for i in range(8):
#             y_i = (y == i)           
#             ypred_i = (ypred == i)   
            
#             inter = (y_i.byte() & ypred_i.byte()).float().sum().to('cpu')
#             intersection[i] += inter
#             union[i] += (y_i.float().sum() + ypred_i.float().sum()).to('cpu') - inter
    
#     """ iou[i] is the IoU of class i """
#     iou = intersection / union
    
#     return iou

In [18]:
# NOTE: this cell is entirely commented out, so train_model is NOT defined
# when the notebook runs. It is kept as a record of the training loop.
# NOTE(review): `running_loss / i` divides by the last batch index
# (number of batches - 1), not by the number of batches, and would divide by
# zero for a single-batch loader.
# NOTE(review): checkpoints are written below
# "./SegNet_with_class_weight(median)_results/", whereas the loading cell
# further down reads "./SegNet_with_class_weight_results/".
# NOTE(review): the default criterion is a plain nn.CrossEntropyLoss();
# `class_weight` is not referenced anywhere in this function.
# def train_model(model, train_loader, test_loader, optimizer_cls=optim.Adam, 
#                 criterion=nn.CrossEntropyLoss(), max_epoch=200, device='cpu', writer=None):
    
#     model.to(device)
    
#     train_losses = []
#     val_iou = []
#     mean_iou = []
#     best_iou = 0.0
    
#     optimizer = optimizer_cls(model.parameters(), lr=0.01)
    
#     for epoch in range(max_epoch):
#         model.train()
#         running_loss = 0.0
        
#         for i, sample in tqdm.tqdm(enumerate(train_loader), total=len(train_loader)):
#             optimizer.zero_grad()
            
#             x, y = sample['image'], sample['class']
            
#             x = x.to(device)
#             y = y.to(device)

#             h = model(x)
#             loss = criterion(h, y)
#             loss.backward()
#             optimizer.step()
            
#             running_loss += loss.item()

#         train_losses.append(running_loss / i)
        
#         val_iou.append(eval_model(model, test_loader, device))
#         mean_iou.append(val_iou[-1].mean().item())
        
#         if best_iou < mean_iou[-1]:
#             best_iou = mean_iou[-1]
#             torch.save(model.state_dict(), "./SegNet_with_class_weight(median)_results/best_iou_model.prm")
        
#         if writer is not None:
#             writer.add_scalar("train_loss", train_losses[-1], epoch)
#             writer.add_scalar("mean_IoU", mean_iou[-1], epoch)
#             writer.add_scalars("class_IoU", {'iou of class 0': val_iou[-1][0],
#                                            'iou of class 1': val_iou[-1][1],
#                                            'iou of class 2': val_iou[-1][2],
#                                            'iou of class 3': val_iou[-1][3],
#                                            'iou of class 4': val_iou[-1][4],
#                                            'iou of class 5': val_iou[-1][5],
#                                            'iou of class 6': val_iou[-1][6],
#                                            'iou of class 7': val_iou[-1][7]}, epoch)
            
#         print(epoch, train_losses[-1], mean_iou[-1])
        
#     torch.save(model.state_dict(), "./SegNet_with_class_weight(median)_results/final_model.prm")

In [19]:
# RGB colour (0-255) used to draw each class; row i is the colour of class i.
colors = torch.tensor([
    (0, 0, 0),          # 0: background  -> black
    (255, 0, 0),        # 1: grasp       -> red
    (255, 255, 0),      # 2: cut         -> yellow
    (0, 255, 0),        # 3: scoop       -> green
    (0, 255, 255),      # 4: contain     -> sky blue
    (0, 0, 255),        # 5: pound       -> blue
    (255, 0, 255),      # 6: support     -> purple
    (255, 255, 255),    # 7: wrap grasp  -> white
])

In [20]:
def class_to_mask(cls, palette=None):
    """Turn a batch of class-index maps into RGB colour masks.

    Args:
        cls: integer tensor of shape (N, H, W) holding class indices.
        palette: optional (num_classes, 3) tensor of colours; row ``i`` is the
            colour of class ``i``. Defaults to the module-level ``colors``.

    Returns:
        Tensor of shape (N, 3, H, W) with the palette's dtype, located on the
        same device as ``cls``.
    """
    if palette is None:
        palette = colors

    # Indexing requires the indexed tensor and the index tensor to be
    # compatible device-wise; move the (small) palette to wherever cls lives.
    palette = palette.to(cls.device)

    # palette[cls] is (N, H, W, 3); bring the colour axis to position 1.
    mask = palette[cls].permute(0, 3, 1, 2)

    return mask

In [21]:
def predict(model, sample, device='cpu'):
    """Run ``model`` on one batch and save ground-truth and predicted masks.

    Args:
        model: segmentation network returning per-class scores of shape
            (N, num_classes, H, W).
        sample: dict with the image batch under ``'image'`` and the label
            batch under ``'class'``.
        device: device on which the forward pass is executed.

    Side effects:
        Puts ``model`` into eval mode, moves it to ``device`` and writes two
        JPEG files (true mask, predicted mask) below
        ``./SegNet_with_class_weight_results/``.
    """
    model.eval()
    model.to(device)

    x, y = sample['image'], sample['class']

    # Only the input has to live on the compute device; the labels are used
    # solely for drawing the ground-truth mask.
    x = x.to(device)

    with torch.no_grad():
        _, y_pred = model(x).max(1)    # y_pred.shape => (N, 240, 320)

    # Draw the masks on the CPU and hand save_image float values in [0, 1]:
    # class_to_mask yields integer colours in 0-255, which save_image would
    # scale by 255 again (and recent torchvision versions reject integer
    # tensors in that step).
    true_mask = class_to_mask(y.to('cpu')).float() / 255
    pred_mask = class_to_mask(y_pred.to('cpu')).float() / 255

    save_image(true_mask, "./SegNet_with_class_weight_results/true_mask_with_SegNet_with_class_weight.jpg")
    save_image(pred_mask, "./SegNet_with_class_weight_results/pred_mask_with_SegNet_with_class_weight.jpg")

In [25]:
# Rebuild the network (3 input channels, 8 classes) and restore the saved
# weights, mapping every stored tensor onto the CPU.
trained_model = SegNetBasic(in_channel=3, out_channel=8)
trained_model.load_state_dict(
    torch.load("./SegNet_with_class_weight_results/best_iou_model.prm",
               map_location=torch.device('cpu'))
)

In [26]:
# Evaluation split, preprocessed with the same pipeline as train / test.
eval_data = PartAffordanceDataset(
    csv_file='eval.csv',
    transform=transforms.Compose([CenterCrop(), ToTensor(), Normalize()]),
)

In [27]:
def reverse_normalize(x, mean, std):
    """Undo a per-channel standardisation: ``x * std + mean``.

    Args:
        x: tensor of shape (N, C, H, W).
        mean: sequence of C per-channel means.
        std: sequence of C per-channel standard deviations.

    Returns:
        A new tensor of the same shape as ``x``. Unlike the earlier in-place
        version, ``x`` itself is left untouched, so the function can be
        called repeatedly on the same batch without compounding the
        transformation.
    """
    # Keep x's dtype when it is floating point; fall back to the default
    # float dtype so that integer inputs do not truncate mean / std.
    dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()

    # Shape (1, C, 1, 1) broadcasts over batch and spatial axes.
    mean_t = torch.as_tensor(mean, dtype=dtype, device=x.device).view(1, -1, 1, 1)
    std_t = torch.as_tensor(std, dtype=dtype, device=x.device).view(1, -1, 1, 1)

    return x * std_t + mean_t

In [28]:
# Evaluation batches of 8 samples, served in file order.
eval_loader = DataLoader(dataset=eval_data, batch_size=8, shuffle=False)

In [29]:
# Per-channel statistics, repeated here with the same values as in the
# earlier cell; they are needed below to undo the normalisation.
mean=[55.8630, 59.9099, 91.7419]
std=[31.6852, 29.8496, 19.0835]

# Only the first batch of eval_loader is processed (note the `break`).
for sample in eval_loader:
    trained_model.eval()
    
    # Writes the ground-truth and predicted masks of this batch to disk.
    predict(trained_model, sample)
    
    # De-normalise only after predict() has consumed the batch, so the model
    # never sees de-normalised pixels.
    x = sample["image"]
    x = reverse_normalize(x, mean, std)
    # Pixels are back on the 0-255 scale; divide by 255 before saving.
    save_image(x/255, "./SegNet_with_class_weight_results/original_img_with_SegNet_with_class_weight.jpg")
    
    break