A high level EDA --> dataloaders --> model --> loss function --> training loop

In [None]:
from fastai.vision.all import *
from torch.utils.data import Dataset, DataLoader
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [None]:
path = Path('../input/sartorius-cell-instance-segmentation')
files = get_image_files(path/'train')
files

Let's quickly take a look at what's been provided for our training data

So we have 606 images to use for training and validation of our models

We can easily plot one of these image files as seen below

In [None]:
Image.open(files[0]).resize((256,256))

Since this is a segmentation problem, we'll need to find the reference annotations -- these are in our train.csv file

In [None]:
df = pd.read_csv(path/'train.csv')
df.shape

In [None]:
df.head(1)

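Looks like the annotation column holds our segmentation masks as "run-length encoded pixels". What does that mean? Each annotation is a string of (start, length) pairs that mark runs of mask pixels in the flattened image.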

In [None]:
df.width.unique(), df.height.unique()

Looks like all of our images are standardized to one size -- 704 pixels wide and 520 pixels in height

In addition, we have information on the following:<br>
1) cell_type: the cell line <br>
2) plate_time: the time the plate was created <br>
3) sample_date: the date the sample was created <br>
4) sample_id: sample identifier -- not sure what the utility of having this is <br>
5) elapsed_timedelta: time since the first image of the sample was taken<br>

***None of the metadata is provided for test set***

In [None]:
#df.cell_type.unique()
#df.plate_time.unique()
#df.sample_date.unique()
#len(df.sample_id.unique())
#df.elapsed_timedelta.unique()

In [None]:
def get_image(path):
    """Load the image stored at `path` as an RGB array of shape (H, W, 3).

    PIL already decodes colour files in RGB order, so the former
    ``cv2.cvtColor(image, cv2.COLOR_BGR2RGB)`` call actually reversed the
    channels of RGB files, and it cannot be applied to a single-channel
    (grayscale) array at all. ``convert('RGB')`` yields three channels for
    either kind of file without any channel swap.
    """
    return np.array(Image.open(path).convert('RGB'))

def get_annot(img_id):
    """Return the `annotation` values of every row in the global `df` whose id equals `img_id`."""
    matching_rows = df.loc[df.id == img_id]
    return matching_rows.annotation.values

def get_mask(img_annotations, colors, shape=(520, 704)):
    '''
    Combine all instance annotations of one image into a single 3-channel mask.

    img_annotations: iterable of run-length encoded strings, one per instance
    colors: scalar or length-3 array written onto every annotated pixel
    shape: (height, width) of the image. Defaults to the 520 x 704 size the
           training images share; the previous hard-coded height of 502
           silently dropped the bottom 18 rows of every mask.
    Returns a float array of shape (height, width, 3) clipped to [0, 1]

    '''
    full_shape = (shape[0], shape[1], 3)
    mask = np.zeros(full_shape)
    for annot in img_annotations:
        mask += rle_decode(annot, shape=full_shape, color=colors)
    # overlapping instances add up, so clamp the sum back into [0, 1]
    return mask.clip(0, 1)

In [None]:
#https://www.kaggle.com/inversion/run-length-decoding-quick-start
def rle_decode(mask_rle, shape, color=1):
    '''
    Decode a run-length encoded mask.

    mask_rle: run-length as string formatted "start length start length ..."
              with 1-based starts into the row-major flattened image
    shape: (height, width) or (height, width, channels) of array to return
    color: value written on mask pixels; a scalar, or (when `shape` has a
           channel axis) an array with one entry per channel
    Returns float32 numpy array of the requested shape, `color` - mask, 0 - background
    Raises ValueError when the string does not hold complete (start, length) pairs

    '''
    tokens = mask_rle.split()
    if len(tokens) % 2:
        # an odd count could otherwise broadcast one length over several starts
        raise ValueError('run-length string must contain (start, length) pairs')
    starts = np.asarray(tokens[0::2], dtype=int) - 1  # RLE starts are 1-based
    lengths = np.asarray(tokens[1::2], dtype=int)
    ends = starts + lengths
    # flat pixel axis first; keep the channel axis only if the caller asked for one
    img = np.zeros((shape[0] * shape[1], *shape[2:]), dtype=np.float32)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = color
    return img.reshape(shape)

In [None]:
colors = np.random.rand(3)
colors

In [None]:
plt.imshow(get_mask(get_annot(files[0].name[:-4]), colors));

Let's build our dataset object with the functions shown above -- we're not going to want to do that mask calculation every time we want a sample in our real training loop, but it will serve its purpose for now. It will be better to extract the masks once and save them as their own files. This will reduce the training time per epoch... allowing more time for trying out different ideas

In [None]:
class masked_ds(Dataset):
    """Dataset that yields (image, mask) pairs, decoding the mask from RLE on every access.

    files: indexable collection of image paths (`pathlib.Path`-like objects)
    df: DataFrame with `id` and `annotation` columns
    colors: value(s) painted onto mask pixels, forwarded to `get_mask`
    img_transforms / mask_transforms: optional callables invoked as
        ``t(image=array)['image']``; the result must support ``.float()``
    """
    def __init__(self, files, df, colors,
                 img_transforms=None, mask_transforms=None):
        self.files = files
        self.df = df
        self.colors = colors
        self.img_transforms = img_transforms
        self.mask_transforms = mask_transforms

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img_path = self.files[idx]
        # stem drops the extension whatever its length (name[:-4] assumed 3 characters)
        img_id = img_path.stem
        # look up in the frame given to the constructor, not the notebook-global `df`
        annotations = self.df[self.df.id == img_id].annotation.values
        mask = get_mask(annotations, self.colors)
        img = get_image(img_path)
        # NOTE(review): image and mask go through two separate pipelines; if both
        # contain random ops their samples are drawn independently, so the pair may
        # end up spatially misaligned -- confirm and consider one joint transform.
        if self.img_transforms:
            img = self.img_transforms(image=img)['image'].float()
        if self.mask_transforms:
            mask = self.mask_transforms(image=mask)['image'].float()
        return img, mask

In [None]:
# Image pipeline: random resized crop to 448x448, Normalize with the
# mean/std listed below, then ToTensorV2.
_img_steps = [
    A.RandomResizedCrop(448, 448),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
]
img_transforms = A.Compose(_img_steps)

# Mask pipeline: random resized crop to 448x448, then ToTensorV2 (no normalisation).
# NOTE(review): this pipeline is called separately from the image one, so its
# random crop is presumably sampled independently of the image's -- confirm.
_mask_steps = [
    A.RandomResizedCrop(448, 448),
    ToTensorV2(),
]
mask_transforms = A.Compose(_mask_steps)

In [None]:
# First 80% of the files (in listing order, no shuffling) train, the rest validate.
split = int(len(files) * 0.8)
train_files, valid_files = files[:split], files[split:]

# NOTE(review): the validation set receives the same random-crop pipelines as the
# training set, so validation samples change between evaluations.
train_ds, valid_ds = (
    masked_ds(subset, df, colors, img_transforms, mask_transforms)
    for subset in (train_files, valid_files)
)

# Settings shared by both loaders; only the training loader shuffles.
_loader_opts = dict(batch_size=8, num_workers=2, pin_memory=True)
train_dl = DataLoader(train_ds, shuffle=True, **_loader_opts)
valid_dl = DataLoader(valid_ds, shuffle=False, **_loader_opts)
dls = DataLoaders(train_dl, valid_dl)

In [None]:
batch = next(iter(dls.train))
batch[0].shape, batch[1].shape

Model, loss function, etc. brought to you by @iafoss: https://www.kaggle.com/iafoss/hubmap-pytorch-fast-ai-starter?scriptVersionId=56448549

In [None]:
class FPN(nn.Module):
    """Per-level conv branches whose upsampled outputs are concatenated with a final feature map.

    input_channels / output_channels: paired lists, one entry per level. Each
    level gets Conv3x3(in -> 2*out) -> ReLU -> BatchNorm -> Conv3x3(2*out -> out).
    """
    def __init__(self, input_channels:list, output_channels:list):
        super().__init__()
        branches = []
        for in_ch, out_ch in zip(input_channels, output_channels):
            hidden = out_ch*2
            branches.append(nn.Sequential(
                nn.Conv2d(in_ch, hidden, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(hidden),
                nn.Conv2d(hidden, out_ch, kernel_size=3, padding=1),
            ))
        self.convs = nn.ModuleList(branches)

    def forward(self, xs:list, last_layer):
        # level i is enlarged by 2**(n_levels - i), i.e. earlier entries are upsampled more
        n_levels = len(self.convs)
        hcs = []
        for i, (conv, x) in enumerate(zip(self.convs, xs)):
            hcs.append(F.interpolate(conv(x), scale_factor=2**(n_levels-i), mode='bilinear'))
        hcs.append(last_layer)
        # channel-wise concatenation of every level plus `last_layer`
        return torch.cat(hcs, dim=1)

class UnetBlock(Module):
    """Decoder step: `up_in` goes through PixelShuffle_ICNR, is concatenated with the
    batch-normed skip input `left_in`, and the result passes through two ConvLayers.

    up_in_c: channels of `up_in` (PixelShuffle_ICNR is built with up_in_c//2 outputs)
    x_in_c: channels of the skip input
    nf: output channels; defaults to max(up_in_c//2, 32)
    self_attention: when true, a SelfAttention(nf) is passed as `xtra` to the second ConvLayer
    """
    def __init__(self, up_in_c:int, x_in_c:int, nf:int=None, blur:bool=False,
                 self_attention:bool=False, **kwargs):
        super().__init__()
        half_c = up_in_c//2
        self.shuf = PixelShuffle_ICNR(up_in_c, half_c, blur=blur, **kwargs)
        self.bn = nn.BatchNorm2d(x_in_c)
        if nf is None:
            nf = max(half_c, 32)
        self.conv1 = ConvLayer(half_c + x_in_c, nf, norm_type=None, **kwargs)
        # built after conv1 so layers are created in the same order as before
        attention = SelfAttention(nf) if self_attention else None
        self.conv2 = ConvLayer(nf, nf, norm_type=None, xtra=attention, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, up_in:Tensor, left_in:Tensor) -> Tensor:
        upsampled = self.shuf(up_in)
        skip = self.bn(left_in)
        merged = self.relu(torch.cat([upsampled, skip], dim=1))
        return self.conv2(self.conv1(merged))
        
class _ASPPModule(nn.Module):
    def __init__(self, inplanes, planes, kernel_size, padding, dilation, groups=1):
        super().__init__()
        self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                stride=1, padding=padding, dilation=dilation, bias=False, groups=groups)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()

        self._init_weight()

    def forward(self, x):
        x = self.atrous_conv(x)
        x = self.bn(x)

        return self.relu(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

class ASPP(nn.Module):
    """Parallel conv branches plus a global max-pool branch, fused by a 1x1 conv.

    inplanes: input channels
    mid_c: channels produced by each branch
    dilations: dilation of each 3x3 grouped (groups=4) branch; any sequence works.
               The default is a tuple so no mutable object is shared between calls.
    out_c: output channels, defaults to `mid_c`
    """
    def __init__(self, inplanes=512, mid_c=256, dilations=(6, 12, 18, 24), out_c=None):
        super().__init__()
        # one 1x1 branch followed by one dilated 3x3 branch per entry of `dilations`
        self.aspps = [_ASPPModule(inplanes, mid_c, 1, padding=0, dilation=1)] + \
            [_ASPPModule(inplanes, mid_c, 3, padding=d, dilation=d,groups=4) for d in dilations]
        self.aspps = nn.ModuleList(self.aspps)
        self.global_pool = nn.Sequential(nn.AdaptiveMaxPool2d((1, 1)),
                        nn.Conv2d(inplanes, mid_c, 1, stride=1, bias=False),
                        nn.BatchNorm2d(mid_c), nn.ReLU())
        out_c = out_c if out_c is not None else mid_c
        # concatenated input: pooled branch + 1x1 branch + len(dilations) dilated branches
        self.out_conv = nn.Sequential(nn.Conv2d(mid_c*(2+len(dilations)), out_c, 1, bias=False),
                                    nn.BatchNorm2d(out_c), nn.ReLU(inplace=True))
        # NOTE(review): conv1 is never used in forward(); it is kept so the set of
        # parameters / state-dict keys stays the same as before.
        self.conv1 = nn.Conv2d(mid_c*(2+len(dilations)), out_c, 1, bias=False)
        self._init_weight()

    def forward(self, x):
        x0 = self.global_pool(x)
        xs = [aspp(x) for aspp in self.aspps]
        # stretch the 1x1 pooled map back to the spatial size of the branch outputs
        x0 = F.interpolate(x0, size=xs[0].size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat([x0] + xs, dim=1)
        return self.out_conv(x)

    def _init_weight(self):
        # Kaiming-normal for every conv, identity start for every batch norm
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

class UneXt50(nn.Module):
    """U-Net style segmentation network on a ResNeXt50 encoder with an ASPP
    bottleneck and an FPN head.

    stride: multiplier applied to the base ASPP dilations 1, 2, 3, 4
    **kwargs: accepted but not used by this class

    forward(x) returns a 3-channel map (see `final_conv`) after a final
    bilinear x2 upsampling.
    """
    def __init__(self, stride=1, **kwargs):
        super().__init__()
        #encoder
        # backbone is fetched through torch.hub (needs network access or a local hub cache)
        m = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models',
                           'resnext50_32x4d_ssl')
        # stem: the backbone's first conv + batch norm followed by a fresh ReLU
        self.enc0 = nn.Sequential(m.conv1, m.bn1, nn.ReLU(inplace=True))
        # trailing numbers on the encoder stages are their output channel counts
        self.enc1 = nn.Sequential(nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1),
                            m.layer1) #256
        self.enc2 = m.layer2 #512
        self.enc3 = m.layer3 #1024
        self.enc4 = m.layer4 #2048
        #aspp with customized dilations
        self.aspp = ASPP(2048,256,out_c=512,dilations=[stride*1,stride*2,stride*3,stride*4])
        self.drop_aspp = nn.Dropout2d(0.5)
        #decoder
        # UnetBlock(up_in_c, x_in_c, nf): channels from below, skip channels, output channels
        self.dec4 = UnetBlock(512,1024,256)
        self.dec3 = UnetBlock(256,512,128)
        self.dec2 = UnetBlock(128,256,64)
        self.dec1 = UnetBlock(64,64,32)
        # FPN reads the ASPP output and the first three decoder outputs, 16 channels each
        self.fpn = FPN([512,256,128,64],[16]*4)
        self.drop = nn.Dropout2d(0.1)
        # 32 channels from dec1 + 4*16 from the FPN branches -> 3 output channels
        self.final_conv = ConvLayer(32+16*4, 3, ks=1, norm_type=None, act_cls=None)
        
    def forward(self, x):
        # encoder path
        enc0 = self.enc0(x)
        enc1 = self.enc1(enc0)
        enc2 = self.enc2(enc1)
        enc3 = self.enc3(enc2)
        enc4 = self.enc4(enc3)
        enc5 = self.aspp(enc4)
        # decoder path: each block merges the deeper feature with one encoder skip
        dec3 = self.dec4(self.drop_aspp(enc5),enc3)
        dec2 = self.dec3(dec3,enc2)
        dec1 = self.dec2(dec2,enc1)
        dec0 = self.dec1(dec1,enc0)
        # fuse the multi-scale decoder features with the last decoder output
        x = self.fpn([enc5, dec3, dec2, dec1], dec0)
        x = self.final_conv(self.drop(x))
        x = F.interpolate(x,scale_factor=2,mode='bilinear')
        return x

#split the model into encoder and decoder parameter groups for fast.ai
def split_layers(m):
    """Return [encoder_params, decoder_params] as two flat lists of parameters of `m`."""
    encoder_parts = (m.enc0, m.enc1, m.enc2, m.enc3, m.enc4)
    decoder_parts = (m.aspp, m.dec4, m.dec3, m.dec2, m.dec1, m.fpn, m.final_conv)
    return [
        [p for part in encoder_parts for p in part.parameters()],
        [p for part in decoder_parts for p in part.parameters()],
    ]

In [None]:
model = UneXt50()

In [None]:
out = model(batch[0])
out.shape

In [None]:
plt.imshow(out.detach().cpu().numpy()[0].transpose(1,2,0));

At least it looks cool?

In [None]:
plt.imshow(batch[1][0].cpu().numpy().transpose(1,2,0));

Just a reminder of what our label mask should look like

In [None]:
def dice_loss(inp, target):
    """Soft Dice loss between raw logits and a target mask.

    inp: model output logits; a sigmoid is applied inside this function
    target: tensor with the same number of elements as `inp`
    Returns a scalar tensor equal to 1 - Dice coefficient (smoothing term 1.0),
    so a perfect overlap gives ~0 and lower values are better.

    The previous version returned the Dice coefficient itself; used as a loss
    that made the optimizer push the overlap down instead of up.
    """
    probs = torch.sigmoid(inp)
    smooth = 1.0

    # reshape rather than view so non-contiguous tensors are accepted as well
    iflat = probs.reshape(-1)
    tflat = target.reshape(-1)
    intersection = (iflat * tflat).sum()

    dice = (2.0 * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth)
    return 1.0 - dice

def IoU(pred, targs):
    """Intersection over union of the thresholded logits (`pred > 0`) and `targs`.

    A constant 1.0 is added to the denominator.
    """
    binary = (pred > 0).float()
    overlap = (binary * targs).sum()
    combined = (binary + targs).sum()
    # union = |pred| + |targs| - overlap
    return overlap / (combined - overlap + 1.0)

Let's just double check that the dice_loss from iafoss works. No idea if this is the appropriate one to use at the moment, but it can serve the purpose of confirming our pipeline's feasibility

In [None]:
dice_loss(out,batch[1])

In [None]:
# Bundle the dataloaders, model, loss and metric into a fastai Learner.
# NOTE(review): `split_layers` is defined earlier but is not passed as `splitter=`
# here -- confirm whether separate encoder/decoder parameter groups were intended.
learner = Learner(dls, model, loss_func=dice_loss, metrics=IoU)

In [None]:
learner.fit_one_cycle(3)

Ok, so looks like that isn't working too great. But that's not a problem, we have the starting point to go from data to training and this should allow us to try out different ideas moving forward.

This whole thing needs a ton more work... stay tuned