In [1]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from glob import glob
import os
from skimage.io import imread
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
dsb_data_dir = os.path.join('.', 'data')
stage_label = 'data/stage1'

In [2]:
train_labels = pd.read_csv('{}_train_labels.csv'.format(stage_label))
train_labels['EncodedPixels'] = train_labels['EncodedPixels'].map(lambda ep: [int(x) for x in ep.split(' ')])
train_labels.sample(3)

Unnamed: 0,ImageId,EncodedPixels
27487,ed4b8e0d756836be7acb2e2b7799c473b52424e3092a71...,"[8870, 4, 9124, 7, 9379, 8, 9635, 8, 9891, 8, ..."
6128,317832f90f02c5e916b2ac0f3bcb8da9928d8e400b747b...,"[36255, 4, 36510, 7, 36766, 8, 37022, 9, 37278..."
8287,4590d7d47f521df62f3bcb0bf74d1bca861d94ade614d8...,"[7433, 5, 7688, 12, 7943, 15, 8198, 17, 8454, ..."


In [3]:
all_images = glob(os.path.join(dsb_data_dir, 'stage1_*', '*', '*', '*'))
img_df = pd.DataFrame({'path': all_images})
img_id = lambda in_path: in_path.split('/')[-3]
img_type = lambda in_path: in_path.split('/')[-2]
img_group = lambda in_path: in_path.split('/')[-4].split('_')[1]
img_stage = lambda in_path: in_path.split('/')[-4].split('_')[0]
img_df['ImageId'] = img_df['path'].map(img_id)
img_df['ImageType'] = img_df['path'].map(img_type)
img_df['TrainingSplit'] = img_df['path'].map(img_group)
img_df['Stage'] = img_df['path'].map(img_stage)
img_df.sample(2)

Unnamed: 0,path,ImageId,ImageType,TrainingSplit,Stage
14175,./data/stage1_train/a891bbc89143bca7a717386144...,a891bbc89143bca7a717386144eb061ec2d599cba24681...,masks,train,stage1
3116,./data/stage1_train/2a2032c4ed78f3fc64de7e5efd...,2a2032c4ed78f3fc64de7e5efd0bec26a81680b07404ea...,masks,train,stage1


In [4]:
train_df = img_df[img_df.TrainingSplit == 'train']
train_rows = []
group_cols = ['Stage', 'ImageId']
for n_group, n_rows in train_df.groupby(group_cols):
    c_row = {col_name: col_value for col_name, col_value in zip(group_cols, n_group)}
    c_row['masks'] = n_rows[n_rows.ImageType == 'masks']['path'].values.tolist()
    c_row['images'] = n_rows[n_rows.ImageType == 'images']['path'].values.tolist()
    train_rows += [c_row]
train_img_df = pd.DataFrame(train_rows)    

# np.stack to squash multidim-list to ndarray
IMG_CHANNELS = 3
def read_and_stack(in_img_list):
    img = [imread(c_img) for c_img in in_img_list]
    img_stack = np.stack(img, 0)
    img_sum = np.sum(img_stack, 0)
#     return img_sum/255.0
    return img_sum

train_img_df['images'] = train_img_df['images'].map(read_and_stack).map(lambda x: x[:,:,:IMG_CHANNELS])
train_img_df['masks'] = train_img_df['masks'].map(read_and_stack).map(lambda x: x.astype(int))
train_img_df.sample(3)

Unnamed: 0,ImageId,Stage,images,masks
664,fe80a2cf3c93dafad8c364fdd1646b0ba4db056cdb7bdb...,stage1,"[[[151, 151, 151], [144, 144, 144], [136, 136,...","[[255, 255, 255, 255, 255, 255, 255, 255, 255,..."
596,e4fc936ba57a936aaa5941ccc70946ab18fcebcb6e8d85...,stage1,"[[[13, 13, 13], [13, 13, 13], [14, 14, 14], [1...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
458,ad9d305cbf193d4250743ead466bdaefe910835d7e352c...,stage1,"[[[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], ...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."


In [5]:
n_img = 6
fig, m_axs = plt.subplots(2, n_img, figsize = (12, 4))
for (_, c_row), (c_im, c_lab) in zip(train_img_df.sample(n_img).iterrows(), m_axs.T):
    c_im.imshow(c_row['images'])
    c_im.axis('off')
    c_im.set_title('Microscope')
    
    c_lab.imshow(c_row['masks'])
    c_lab.axis('off')
    c_lab.set_title('Labeled')

Error in callback <function install_repl_displayhook.<locals>.post_execute at 0x7f67276c8400> (for post_execute):


ValueError: 3-dimensional arrays must be of dtype unsigned byte, unsigned short, float32 or float64

ValueError: 3-dimensional arrays must be of dtype unsigned byte, unsigned short, float32 or float64

<matplotlib.figure.Figure at 0x7f671fd71390>

In [6]:
#data loader
import torch
import torchvision.transforms as transforms
import torch.utils.data as data
from PIL import Image
import random

class NucleusDataset(data.Dataset):
    def __init__(self, df, transform=None, target_transform=None):
        self.df = df
        self.transform = transform
        self.target_transform = target_transform
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        image = self.df.iloc[index]['images']
        mask = self.df.iloc[index]['masks']
        
        image = Image.fromarray(np.uint8(image))
        mask = Image.fromarray(np.uint8(mask))
        
        seed = np.random.randint(123456) 
        if self.transform:
            random.seed(seed)
            image = self.transform(image)
        if self.target_transform:
            random.seed(seed)
            mask = self.target_transform(mask)
            
        return image, mask
    

In [19]:
CROP_SIZE = 224
NUM_WORKERS = 2
BATCH_SIZE = 16
SHUFFLE = True
LEARNING_RATE = 0.001
EPOCH = 30

transform = transforms.Compose([ 
        transforms.RandomCrop(CROP_SIZE),
        transforms.RandomHorizontalFlip(), 
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(), 
        transforms.Normalize((0.5, 0.5, 0.5), 
                             (0.5, 0.5, 0.5))])

target_transform = transforms.Compose([ 
        transforms.RandomCrop(CROP_SIZE),
        transforms.RandomHorizontalFlip(), 
        transforms.RandomVerticalFlip(),
        transforms.ToTensor()])


nucleus_dataset = NucleusDataset(train_img_df, transform, target_transform)

nucleus_dataloader = torch.utils.data.DataLoader(dataset=nucleus_dataset, batch_size=BATCH_SIZE,
                                                shuffle=SHUFFLE, num_workers=NUM_WORKERS)


In [8]:
i,m =nucleus_dataset[0]
pilTrans = transforms.ToPILImage()

fig, axes = plt.subplots(1, 2)
axes[0].imshow(pilTrans(i))
axes[1].imshow(pilTrans(m))


<matplotlib.image.AxesImage at 0x7f66663f69e8>

Error in callback <function install_repl_displayhook.<locals>.post_execute at 0x7f67276c8400> (for post_execute):


AttributeError: 'numpy.ndarray' object has no attribute 'mask'

AttributeError: 'numpy.ndarray' object has no attribute 'mask'

<matplotlib.figure.Figure at 0x7f6718d9a748>

In [9]:
image, mask = iter(nucleus_dataloader).next()
print(image.shape, mask.shape)


torch.Size([16, 3, 224, 224]) torch.Size([16, 1, 224, 224])


In [10]:
#UNET

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from collections import OrderedDict
from torch.nn import init
import numpy as np

def conv3x3(in_channels, out_channels, stride=1, 
            padding=1, bias=True, groups=1):    
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=padding,
        bias=bias,
        groups=groups)

def upconv2x2(in_channels, out_channels, mode='transpose'):
    if mode == 'transpose':
        return nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=2,
            stride=2)
    else:
        # out_channels is always going to be the same
        # as in_channels
        return nn.Sequential(
            nn.Upsample(mode='bilinear', scale_factor=2),
            conv1x1(in_channels, out_channels))

def conv1x1(in_channels, out_channels, groups=1):
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=1,
        groups=groups,
        stride=1)


class DownConv(nn.Module):
    """
    A helper Module that performs 2 convolutions and 1 MaxPool.
    A ReLU activation follows each convolution.
    """
    def __init__(self, in_channels, out_channels, pooling=True):
        super(DownConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pooling = pooling

        self.conv1 = conv3x3(self.in_channels, self.out_channels)
        self.conv2 = conv3x3(self.out_channels, self.out_channels)

        if self.pooling:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        before_pool = x
        if self.pooling:
            x = self.pool(x)
        return x, before_pool


class UpConv(nn.Module):
    """
    A helper Module that performs 2 convolutions and 1 UpConvolution.
    A ReLU activation follows each convolution.
    """
    def __init__(self, in_channels, out_channels, 
                 merge_mode='concat', up_mode='transpose'):
        super(UpConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.merge_mode = merge_mode
        self.up_mode = up_mode

        self.upconv = upconv2x2(self.in_channels, self.out_channels, 
            mode=self.up_mode)

        if self.merge_mode == 'concat':
            self.conv1 = conv3x3(
                2*self.out_channels, self.out_channels)
        else:
            # num of input channels to conv2 is same
            self.conv1 = conv3x3(self.out_channels, self.out_channels)
        self.conv2 = conv3x3(self.out_channels, self.out_channels)


    def forward(self, from_down, from_up):
        """ Forward pass
        Arguments:
            from_down: tensor from the encoder pathway
            from_up: upconv'd tensor from the decoder pathway
        """
        from_up = self.upconv(from_up)
        if self.merge_mode == 'concat':
            x = torch.cat((from_up, from_down), 1)
        else:
            x = from_up + from_down
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return x


class UNet(nn.Module):
    """ `UNet` class is based on https://arxiv.org/abs/1505.04597
    The U-Net is a convolutional encoder-decoder neural network.
    Contextual spatial information (from the decoding,
    expansive pathway) about an input tensor is merged with
    information representing the localization of details
    (from the encoding, compressive pathway).
    Modifications to the original paper:
    (1) padding is used in 3x3 convolutions to prevent loss
        of border pixels
    (2) merging outputs does not require cropping due to (1)
    (3) residual connections can be used by specifying
        UNet(merge_mode='add')
    (4) if non-parametric upsampling is used in the decoder
        pathway (specified by upmode='upsample'), then an
        additional 1x1 2d convolution occurs after upsampling
        to reduce channel dimensionality by a factor of 2.
        This channel halving happens with the convolution in
        the tranpose convolution (specified by upmode='transpose')
    """

    def __init__(self, num_classes, in_channels=3, depth=5, 
                 start_filts=64, up_mode='transpose', 
                 merge_mode='concat'):
        """
        Arguments:
            in_channels: int, number of channels in the input tensor.
                Default is 3 for RGB images.
            depth: int, number of MaxPools in the U-Net.
            start_filts: int, number of convolutional filters for the 
                first conv.
            up_mode: string, type of upconvolution. Choices: 'transpose'
                for transpose convolution or 'upsample' for nearest neighbour
                upsampling.
        """
        super(UNet, self).__init__()

        if up_mode in ('transpose', 'upsample'):
            self.up_mode = up_mode
        else:
            raise ValueError("\"{}\" is not a valid mode for "
                             "upsampling. Only \"transpose\" and "
                             "\"upsample\" are allowed.".format(up_mode))
    
        if merge_mode in ('concat', 'add'):
            self.merge_mode = merge_mode
        else:
            raise ValueError("\"{}\" is not a valid mode for"
                             "merging up and down paths. "
                             "Only \"concat\" and "
                             "\"add\" are allowed.".format(up_mode))

        # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
        if self.up_mode == 'upsample' and self.merge_mode == 'add':
            raise ValueError("up_mode \"upsample\" is incompatible "
                             "with merge_mode \"add\" at the moment "
                             "because it doesn't make sense to use "
                             "nearest neighbour to reduce "
                             "depth channels (by half).")

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.start_filts = start_filts
        self.depth = depth

        self.down_convs = []
        self.up_convs = []

        # create the encoder pathway and add to a list
        for i in range(depth):
            ins = self.in_channels if i == 0 else outs
            outs = self.start_filts*(2**i)
            pooling = True if i < depth-1 else False

            down_conv = DownConv(ins, outs, pooling=pooling)
            self.down_convs.append(down_conv)

        # create the decoder pathway and add to a list
        # - careful! decoding only requires depth-1 blocks
        for i in range(depth-1):
            ins = outs
            outs = ins // 2
            up_conv = UpConv(ins, outs, up_mode=up_mode,
                merge_mode=merge_mode)
            self.up_convs.append(up_conv)

        self.conv_final = conv1x1(outs, self.num_classes)

        # add the list of modules to current module
        self.down_convs = nn.ModuleList(self.down_convs)
        self.up_convs = nn.ModuleList(self.up_convs)

        self.reset_params()

    @staticmethod
    def weight_init(m):
        if isinstance(m, nn.Conv2d):
            init.xavier_normal(m.weight)
            init.constant(m.bias, 0)


    def reset_params(self):
        for i, m in enumerate(self.modules()):
            self.weight_init(m)


    def forward(self, x):
        encoder_outs = []
         
        # encoder pathway, save outputs for merging
        for i, module in enumerate(self.down_convs):
            x, before_pool = module(x)
            encoder_outs.append(before_pool)

        for i, module in enumerate(self.up_convs):
            before_pool = encoder_outs[-(i+2)]
            x = module(before_pool, x)
        
        # No softmax is used. This means you need to use
        # nn.CrossEntropyLoss is your training script,
        # as this module includes a softmax already.
        x = self.conv_final(x)
        return x


In [11]:
#Loss
class CrossEntropyLoss2d(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(CrossEntropyLoss2d, self).__init__()
        self.nll_loss = nn.NLLLoss2d(weight, size_average)

    def forward(self, inputs, targets):
        return self.nll_loss(F.log_softmax(inputs), targets)

In [12]:
def to_var(x, volatile=False):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x, volatile=volatile)

In [13]:
WEIGHT = torch.ones(22)
WEIGHT[21] = 0

use_cuda = torch.cuda.is_available()

unet = UNet(2, depth=5, merge_mode='concat', up_mode='transpose')
optimizer = torch.optim.Adam(unet.parameters(), lr=LEARNING_RATE)
criterion = CrossEntropyLoss2d()

if use_cuda:
    unet.cuda()

In [21]:
MODEL_PATH = '.'
LOG_STEP=10
SAVE_STEP=42
total_step = len(nucleus_dataloader)
for epoch in range(EPOCH):
    for i, data in enumerate(nucleus_dataloader):
        image, mask = data

        image = to_var(image)
        mask = to_var(mask.type(torch.LongTensor))

        optimizer.zero_grad()
        out = unet(image)
        
        loss = criterion(out, mask.squeeze(1))
        loss.backward()
        optimizer.step()
        
        # Print log info
        if i % LOG_STEP == 0:
            print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f'%(epoch, EPOCH, i, total_step, loss.data[0])) 

        # Save the models
        if (i+1) % SAVE_STEP == 0:
            torch.save(unet.state_dict(), os.path.join(MODEL_PATH, 'unet-%d-%d.pkl' %(epoch+1, i+1)))



Epoch [0/30], Step [0/42], Loss: 0.2217
Epoch [0/30], Step [10/42], Loss: 0.1552
Epoch [0/30], Step [20/42], Loss: 0.0958
Epoch [0/30], Step [30/42], Loss: 0.1728
Epoch [0/30], Step [40/42], Loss: 0.0935
Epoch [1/30], Step [0/42], Loss: 0.1253
Epoch [1/30], Step [10/42], Loss: 0.1363
Epoch [1/30], Step [20/42], Loss: 0.1811
Epoch [1/30], Step [30/42], Loss: 0.1888
Epoch [1/30], Step [40/42], Loss: 0.1305
Epoch [2/30], Step [0/42], Loss: 0.2426
Epoch [2/30], Step [10/42], Loss: 0.1476
Epoch [2/30], Step [20/42], Loss: 0.0872
Epoch [2/30], Step [30/42], Loss: 0.0645
Epoch [2/30], Step [40/42], Loss: 0.1556
Epoch [3/30], Step [0/42], Loss: 0.0785
Epoch [3/30], Step [10/42], Loss: 0.1781
Epoch [3/30], Step [20/42], Loss: 0.1315
Epoch [3/30], Step [30/42], Loss: 0.1060
Epoch [3/30], Step [40/42], Loss: 0.1832
Epoch [4/30], Step [0/42], Loss: 0.2352
Epoch [4/30], Step [10/42], Loss: 0.2052
Epoch [4/30], Step [20/42], Loss: 0.3035
Epoch [4/30], Step [30/42], Loss: 0.1760
Epoch [4/30], Step [4

In [None]:
# save model weight


In [160]:
data = iter(nucleus_dataloader).next()

image, mask = data
mask = mask.type(torch.LongTensor)

image = Variable(image)
mask = Variable(mask)

out = unet(image)

loss = criterion(out, mask.squeeze(1))
print(loss)

Variable containing:
 0.7111
[torch.FloatTensor of size 1]





In [26]:
out.shape

torch.Size([2, 1, 224, 224])

In [61]:
mask.squeeze(1).shape

torch.Size([2, 224, 224])

In [159]:
criterion(out, mask.squeeze(1))



Variable containing:
 0.7111
[torch.FloatTensor of size 1]