Try to implement a simple convnet using the same datasets I used for UNet.

In [1]:
from __future__ import print_function
import torch.nn as nn
import os
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from skimage import io, transform
import utils_xy
from torchvision import transforms, utils
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image

In [2]:
gpu_id = 1
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

device = torch.device('cuda')
print (device)

cuda


### I. Data observation
* we have 160 training images and 160 training mask (groundtruth label for every pixel in the training images).
* for training images, the data type is uint8
* for training mask, the data type is uint8
* the image size is near 375x1242x3, some maybe 370x1242x3
* the mask size is near 375z1242x3, some maybe 370x1242x3

since the image size are different in the training set, we need to do the data preprocessing 

In [3]:
# load the training images and the semantic segmentation image (mask image)
img_dir = '/home/xiaoyu/data_semantics/training/train/image/'
mask_dir = '/home/xiaoyu/data_semantics/training/train/mask_rgb/'

img_list = os.listdir(img_dir)
mask_list = os.listdir(mask_dir)

print("Training images numbers for different folder of training images: "+str(len(img_list)))
print("Training mask Images numbers:"+str(len(mask_list)))

Training images numbers for different folder of training images: 160
Training mask Images numbers:160


#### Get the class mapping of the mask. i.e. Get the classes of the images. There are 29 classes in total. For each pixel in the image, there should be a label assigned to this pixel

In [9]:
colors_all = torch.tensor([])
for i in range(len(mask_list)):
    mask_str = mask_list[i]
    mask_arr = io.imread(os.path.join(mask_dir, mask_str))
    mask_tensor = torch.from_numpy(mask_arr)
    mask_tensor = mask_tensor.permute(2,0,1)
    
    mask_tensor = mask_tensor.to(device)
    colors = torch.unique(mask_tensor.view(mask_tensor.size(0), -1),dim=1)
    colors = colors.permute(1,0).type(torch.FloatTensor) 
    colors_all = torch.cat((colors_all, colors))
    colors_unique = torch.unique(colors_all, dim = 0)
print(colors_unique.shape)
print(colors_unique)

torch.Size([29, 3])
tensor([[  0.,   0.,   0.],
        [  0.,   0.,  70.],
        [  0.,   0.,  90.],
        [  0.,   0., 110.],
        [  0.,   0., 142.],
        [  0.,   0., 230.],
        [  0.,  60., 100.],
        [  0.,  80., 100.],
        [ 70.,  70.,  70.],
        [ 70., 130., 180.],
        [ 81.,   0.,  81.],
        [102., 102., 156.],
        [107., 142.,  35.],
        [111.,  74.,   0.],
        [119.,  11.,  32.],
        [128.,  64., 128.],
        [150., 100., 100.],
        [150., 120.,  90.],
        [152., 251., 152.],
        [153., 153., 153.],
        [180., 165., 180.],
        [190., 153., 153.],
        [220.,  20.,  60.],
        [220., 220.,   0.],
        [230., 150., 140.],
        [244.,  35., 232.],
        [250., 170.,  30.],
        [250., 170., 160.],
        [255.,   0.,   0.]])


### II. Preprocess of the data
* construction of the dataset class 
* normalisation of the training image 

### II.I Define a Train Class

In [11]:
class TrainDataset(Dataset):
    
    def __init__(self, img_dir, mask_dir, transform=None):
      
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        
        self.mapping = {
        torch.tensor([  0,   0,   0], dtype=torch.uint8):0,
        torch.tensor([  0,   0,  70], dtype=torch.uint8):1,
        torch.tensor([  0,   0,  90], dtype=torch.uint8):2,
        torch.tensor([  0,   0, 110], dtype=torch.uint8):3,
        torch.tensor([  0,   0, 142], dtype=torch.uint8):4,
        torch.tensor([  0,   0, 230], dtype=torch.uint8):5,
        torch.tensor([  0,  60, 100], dtype=torch.uint8):6,
        torch.tensor([  0,  80, 100.], dtype=torch.uint8):7,
        torch.tensor([ 70,  70,  70], dtype=torch.uint8):8,
        torch.tensor([ 70, 130, 180], dtype=torch.uint8):9,
        torch.tensor([ 81,   0,  81], dtype=torch.uint8):10,
        torch.tensor([102, 102, 156], dtype=torch.uint8):11,
        torch.tensor([107, 142,  35], dtype=torch.uint8):12,
        torch.tensor([111,  74,   0], dtype=torch.uint8):13,
        torch.tensor([119,  11,  32], dtype=torch.uint8):14,
        torch.tensor([128,  64, 128], dtype=torch.uint8):15,
        torch.tensor([150, 100, 100], dtype=torch.uint8):16,
        torch.tensor([150, 120,  90], dtype=torch.uint8):17,
        torch.tensor([152, 251, 152], dtype=torch.uint8):18,
        torch.tensor([153, 153, 153], dtype=torch.uint8):19,
        torch.tensor([180, 165, 180], dtype=torch.uint8):20,
        torch.tensor([190, 153, 153], dtype=torch.uint8):21,
        torch.tensor([220,  20,  60], dtype=torch.uint8):22,
        torch.tensor([220, 220,   0], dtype=torch.uint8):23,
        torch.tensor([230, 150, 140], dtype=torch.uint8):24,
        torch.tensor([244,  35, 232], dtype=torch.uint8):25,
        torch.tensor([250, 170,  30], dtype=torch.uint8):26,
        torch.tensor([250, 170, 160], dtype=torch.uint8):27,
        torch.tensor([255,   0,   0], dtype=torch.uint8):28
        }
        
    def __len__(self):
        return len(os.listdir(self.img_dir))
    
    def mask_to_class(self, mask):
        for k in self.mapping:
            print(k.dtype)
            print(mask.dtype)
#             mask[mask==k] = self.mapping[k]
        return mask
    
    def __getitem__(self, idx):
        img_list = os.listdir(img_dir)
        mask_list = os.listdir(mask_dir)
        
        img_str = img_list[idx]
        img_arr = io.imread(os.path.join(img_dir, img_str))
        img_tensor = torch.from_numpy(img_arr)
        img_tensor = img_tensor.permute(2,0,1)
        
        mask_str = mask_list[idx]
        mask_arr = io.imread(os.path.join(mask_dir, mask_str))
        mask_tensor = torch.from_numpy(mask_arr)
        mask_tensor = mask_tensor.permute(2,0,1)
#         mask_tensor = mask_tensor.view(mask_tensor.size(0), -1).permute(1,0)
        print(mask_tensor)
        print(mask_tensor.shape)
  
        mask_tensor = self.mask_to_class(mask_tensor)
    
        print(mask_tensor.shape)
        sample = {'image':img_tensor, 'mask':mask_tensor}
        
        if self.transform:
            sample = self.transform(sample)
        return sample


In [12]:
traindata = TrainDataset(img_dir = img_dir, mask_dir = mask_dir)
print(traindata[0]['mask'])

tensor([[[ 70,  70,  70,  ...,  70,  70,  70],
         [ 70,  70,  70,  ...,  70,  70,  70],
         [ 70,  70,  70,  ...,  70,  70,  70],
         ...,
         [128, 128, 128,  ..., 244, 244, 244],
         [128, 128, 128,  ..., 244, 244, 244],
         [128, 128, 128,  ..., 244, 244, 244]],

        [[130, 130, 130,  ..., 130, 130, 130],
         [130, 130, 130,  ..., 130, 130, 130],
         [130, 130, 130,  ..., 130, 130, 130],
         ...,
         [ 64,  64,  64,  ...,  35,  35,  35],
         [ 64,  64,  64,  ...,  35,  35,  35],
         [ 64,  64,  64,  ...,  35,  35,  35]],

        [[180, 180, 180,  ..., 180, 180, 180],
         [180, 180, 180,  ..., 180, 180, 180],
         [180, 180, 180,  ..., 180, 180, 180],
         ...,
         [128, 128, 128,  ..., 232, 232, 232],
         [128, 128, 128,  ..., 232, 232, 232],
         [128, 128, 128,  ..., 232, 232, 232]]], dtype=torch.uint8)
torch.Size([3, 375, 1242])
torch.uint8
torch.uint8
torch.uint8
torch.uint8
torch.uint8


In [22]:
mapping = {
        torch.tensor([  0,   0,   0], dtype=torch.uint8):torch.tensor([0,0,0]),
        torch.tensor([  0,   0,  70], dtype=torch.uint8):torch.tensor([1,1,1]),
        torch.tensor([  0,   0,  90], dtype=torch.uint8):torch.tensor([2,2,2]),
        torch.tensor([  0,   0, 110], dtype=torch.uint8):torch.tensor([3,3,3]),
        torch.tensor([  0,   0, 142], dtype=torch.uint8):torch.tensor([4,4,4]),
        torch.tensor([  0,   0, 230], dtype=torch.uint8):torch.tensor([5,5,5]),
        torch.tensor([  0,  60, 100], dtype=torch.uint8):torch.tensor([6,6,6]),
        torch.tensor([  0,  80, 100.], dtype=torch.uint8):torch.tensor([7,7,7]),
        torch.tensor([ 70,  70,  70], dtype=torch.uint8):torch.tensor([8,8,8]),
        torch.tensor([ 70, 130, 180], dtype=torch.uint8):torch.tensor([9,9,9]),
        torch.tensor([ 81,   0,  81], dtype=torch.uint8):torch.tensor([10,10,10]),
        torch.tensor([102, 102, 156], dtype=torch.uint8):torch.tensor([11,11,11]),
        torch.tensor([107, 142,  35], dtype=torch.uint8):torch.tensor([12,12,12]),
        torch.tensor([111,  74,   0], dtype=torch.uint8):torch.tensor([13,13,13]),
        torch.tensor([119,  11,  32], dtype=torch.uint8):torch.tensor([14,14,14]),
        torch.tensor([128,  64, 128], dtype=torch.uint8):torch.tensor([15,15,15]),
        torch.tensor([150, 100, 100], dtype=torch.uint8):torch.tensor([16,16,16]),
        torch.tensor([150, 120,  90], dtype=torch.uint8):torch.tensor([17,17,17]),
        torch.tensor([152, 251, 152], dtype=torch.uint8):torch.tensor([18,18,18]),
        torch.tensor([153, 153, 153], dtype=torch.uint8):torch.tensor([19,19,19]),
        torch.tensor([180, 165, 180], dtype=torch.uint8):torch.tensor([20,20,20]),
        torch.tensor([190, 153, 153], dtype=torch.uint8):torch.tensor([21,21,21]),
        torch.tensor([220,  20,  60], dtype=torch.uint8):torch.tensor([22,22,22]),
        torch.tensor([220, 220,   0], dtype=torch.uint8):torch.tensor([23,23,23]),
        torch.tensor([230, 150, 140], dtype=torch.uint8):torch.tensor([24,24,24]),
        torch.tensor([244,  35, 232], dtype=torch.uint8):torch.tensor([25,25,25]),
        torch.tensor([250, 170,  30], dtype=torch.uint8):torch.tensor([26,26,26]),
        torch.tensor([250, 170, 160], dtype=torch.uint8):torch.tensor([27,27,27]),
        torch.tensor([255,   0,   0], dtype=torch.uint8):torch.tensor([28,28,28])
        }

In [50]:
mask_ex = mask_list[0]
mask_arr = io.imread(os.path.join(mask_dir, mask_ex))
mask_tensor = torch.from_numpy(mask_arr)
print(mask_tensor.size())
print(mask_tensor[0,20])

torch.Size([375, 1242, 3])
tensor([ 70, 130, 180], dtype=torch.uint8)


In [52]:
for k in mapping:
    print(k)
    if any([(k == mask_tensor_).all() for mask_tensor_ in mask_tensor]):
        print('a in c')
   


tensor([0, 0, 0], dtype=torch.uint8)
tensor([ 0,  0, 70], dtype=torch.uint8)
tensor([ 0,  0, 90], dtype=torch.uint8)
tensor([  0,   0, 110], dtype=torch.uint8)
tensor([  0,   0, 142], dtype=torch.uint8)
tensor([  0,   0, 230], dtype=torch.uint8)
tensor([  0,  60, 100], dtype=torch.uint8)
tensor([  0,  80, 100], dtype=torch.uint8)
tensor([70, 70, 70], dtype=torch.uint8)
tensor([ 70, 130, 180], dtype=torch.uint8)
tensor([81,  0, 81], dtype=torch.uint8)
tensor([102, 102, 156], dtype=torch.uint8)
tensor([107, 142,  35], dtype=torch.uint8)
tensor([111,  74,   0], dtype=torch.uint8)
tensor([119,  11,  32], dtype=torch.uint8)
tensor([128,  64, 128], dtype=torch.uint8)
tensor([150, 100, 100], dtype=torch.uint8)
tensor([150, 120,  90], dtype=torch.uint8)
tensor([152, 251, 152], dtype=torch.uint8)
tensor([153, 153, 153], dtype=torch.uint8)
tensor([180, 165, 180], dtype=torch.uint8)
tensor([190, 153, 153], dtype=torch.uint8)
tensor([220,  20,  60], dtype=torch.uint8)
tensor([220, 220,   0], dtype