In [None]:
import numpy as np
import pandas as pd
import torch
import torchvision
import matplotlib.pyplot as plt
import PIL.Image as Image

import os

In [None]:
# Training annotations: each row pairs a slice `id` with an organ `class` and its RLE `segmentation`.
data = pd.read_csv('../input/uw-madison-gi-tract-image-segmentation/train.csv')

In [None]:
# Root of the scan tree, laid out as <case>/<case_day>/scans/<slice PNG>.
train_path = '../input/uw-madison-gi-tract-image-segmentation/train/'

In [None]:
# Paths (relative to train_path) of every training scan; filled by the next cell.
train_pics_pathes = []

In [None]:
# Collect "<case>/<case_day>/scans/<file>" paths, relative to train_path, for every scan.
train_pics_pathes.extend(
    os.path.join(case_dir, day_dir, 'scans', slice_file)
    for case_dir in os.listdir(train_path)
    for day_dir in os.listdir(os.path.join(train_path, case_dir))
    for slice_file in os.listdir(os.path.join(train_path, case_dir, day_dir, 'scans'))
)

In [None]:
# Map each slice id (e.g. "case30_day0_slice_0135", the same keys as train.csv's
# `id` column) to the full path of its scan PNG.
id_pics_matching = dict()
for rel_path in train_pics_pathes:
    parts = rel_path.split('/')
    case_name = parts[0]
    day_name = parts[1].split('_')[1]
    slice_tokens = parts[3].split('_')
    slice_key = '_'.join([case_name, day_name, slice_tokens[0], slice_tokens[1]])
    id_pics_matching[slice_key] = os.path.join(train_path, rel_path)

In [None]:
# Keep only rows that carry a mask, re-indexed 0..n-1 (later cells look rows up
# with data[col][i]). drop=True stops reset_index from turning the old index into
# an extra column, which it otherwise does again on every re-run of this cell.
# NOTE(review): dropping NaN masks means slices with no organ never reach
# data_full / training — confirm that is intended.
data = data.dropna().reset_index(drop=True)

In [None]:
# First test slice (case30, day0, slice 0135) as a tensor for the RLE checks below.
img = torchvision.transforms.ToTensor()(
    Image.open('../input/uw-madison-gi-tract-image-segmentation/train/case30/case30_day0/scans/slice_0135_266_266_1.50_1.50.png')
)

In [None]:
def look(img):
    """Plot a (channel, height, width) image tensor, titled with its shape."""
    hwc = img.detach().cpu().permute(1, 2, 0)  # matplotlib wants channels last
    plt.imshow(hwc)
    plt.title(str(img.shape))

In [None]:
def decode_rle(img, seq):
    """Return a copy of ``img`` with the pixels of an RLE mask painted on channel 0.

    :param img: image tensor indexed (channel, row, column), e.g. a ToTensor output
    :param seq: whitespace-separated "start length" pairs; '' paints nothing
    :return: new tensor; covered pixels of channel 0 are set to 65536 (the marker
        value encode_line searches for), all other values are copied unchanged
    :raises ValueError: if ``seq`` has an odd number of tokens
    :raises IndexError: if a run extends past the last pixel of the image

    Each start is a flat row-major offset into the H x W plane (row = start // W,
    column = start % W), and a run may continue onto the following row.
    Starts are treated as 0-based, like encode_line's output.
    NOTE(review): Kaggle-style RLE is commonly decoded as 1-based (decoders
    subtract 1 from each start) — confirm against this competition's train.csv.
    """
    img = img.clone(memory_format=torch.contiguous_format)
    plane = img[0].view(-1)  # flat view of channel 0, so writes land in `img`
    tokens = seq.split()
    if len(tokens) % 2:
        raise ValueError('RLE string must consist of "start length" pairs')
    for start, length in zip(tokens[0::2], tokens[1::2]):
        begin = int(start)
        end = begin + int(length)
        if begin < 0 or end > plane.numel():
            raise IndexError(f'RLE run {start} {length} exceeds image of {plane.numel()} pixels')
        plane[begin:end] = 65536
    return img

<h4>Testing decoder (decode_rle)

In [None]:
# Show the raw test slice.
look(img)

In [None]:
# Paint one RLE mask onto the test slice.
# NOTE(review): row 33909 is assumed to belong to case30_day0_slice_0135 (the slice loaded above) — confirm.
decoded_img = decode_rle(img, data['segmentation'][33909])

In [None]:
# Painted pixels are set to 65536, presumably above any 16-bit scan value, so they should render brightest.
look(decoded_img)

In [None]:
def encode_line(img, i):
    """Run-length-encode the marked pixels (== 65536) in row ``i`` of channel 0.

    :param img: image tensor indexed (channel, row, column)
    :param i: row index
    :return: flat list of strings [start, length, ...]; each start is the 0-based
        row-major offset ``i * W + column`` (W = image width), as decode_rle reads it
    """
    width = img.shape[2]
    marked = (img[0][i] == 65536).tolist()  # plain bools: avoids per-pixel tensor indexing
    rle = []
    col = 0
    while col < width:
        if marked[col]:
            start = col
            while col < width and marked[col]:
                col += 1
            rle += [str(i * width + start), str(col - start)]
        col += 1  # the pixel just after a run is unmarked, so it can be skipped
    return rle

def encode_rle(img):
    """Encode all marked pixels of channel 0 as a space-separated RLE string.

    Runs are emitted row by row (a run never continues across rows); an image
    with no marked pixels gives ''.
    """
    return ' '.join(token for row in range(img.shape[1]) for token in encode_line(img, row))

<h4>Testing encoder (encode_rle, round trip through decode_rle)

In [None]:
# Re-encode the painted image; decoding this string should reproduce decoded_img.
seq_rle = encode_rle(decoded_img)

In [None]:
# Visual round-trip check: should match the decoded_img plot above.
look(decode_rle(img, seq_rle))

<h4>Plot segmentation

In [None]:
def look_seg(img, segmentations):
    """Show a grayscale slice with up to three RLE masks overlaid in solid colour.

    :param img: image tensor indexed (channel, row, column); channel 0 is the background.
        If any value exceeds 1 the background is divided by 2**16 (raw 16-bit scan).
    :param segmentations: sequence of three RLE strings ('' for none), painted red,
        green and blue respectively; a later mask overwrites an earlier one on overlap.
        Starts are 0-based flat row-major offsets, so runs may wrap onto the next row.
    :raises ValueError: if an RLE string has an odd number of tokens
    :raises IndexError: if a run extends past the last pixel of the image
    """
    height, width = img.shape[1], img.shape[2]
    img_seg = torch.zeros([3, height, width])
    img_seg[:] = img[0]  # same grayscale background in R, G and B

    if img_seg.max() > 1:
        img_seg /= (2**16)

    colors = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
    pixels = img_seg.view(3, -1)  # (RGB, H*W) view: writes land in img_seg

    for seg_channel in range(3):
        tokens = segmentations[seg_channel].split()
        if len(tokens) % 2:
            raise ValueError('RLE string must consist of "start length" pairs')
        color = torch.tensor(colors[seg_channel], dtype=img_seg.dtype).view(3, 1)
        for start, length in zip(tokens[0::2], tokens[1::2]):
            begin = int(start)
            end = begin + int(length)
            if begin < 0 or end > height * width:
                raise IndexError(f'RLE run {start} {length} exceeds image of {height * width} pixels')
            pixels[:, begin:end] = color
    plt.imshow(img_seg.permute(1, 2, 0))

In [None]:
# Second test slice (0137 of the same case/day) for the multi-mask overlay below.
img_test = torchvision.transforms.ToTensor()(Image.open('../input/uw-madison-gi-tract-image-segmentation/train/case30/case30_day0/scans/slice_0137_266_266_1.50_1.50.png'))

In [None]:
# Overlay two masks (red and blue slots) on slice 0137, loaded into img_test by the previous cell.
# NOTE(review): rows 33911 / 33912 are assumed to be masks of case30_day0_slice_0137 — confirm in `data`.
look_seg(img_test, [data['segmentation'][33911], '', data['segmentation'][33912]])

In [None]:
# Overlay the mask from the decoder test (row 33909) on the first slice, in the red slot.
look_seg(img, [data['segmentation'][33909], '', ''])

<h4>Preparing data

In [None]:
# Organ class -> channel of the per-slice mask list (and of the model output).
CLASS_CHANNEL = {'stomach': 0, 'small_bowel': 1, 'large_bowel': 2}

# slice id -> [stomach RLE, small bowel RLE, large bowel RLE]; '' when absent.
data_full = {slice_id: ['', '', ''] for slice_id in data['id']}

for slice_id, organ, rle in zip(data['id'], data['class'], data['segmentation']):
    channel = CLASS_CHANNEL.get(organ)
    if channel is not None:
        data_full[slice_id][channel] += rle

`data_full` maps each slice `id` to `[stomach_rle, small_bowel_rle, large_bowel_rle]`: three RLE strings, with `''` when that organ has no mask.

<h4>Working with device

In [None]:
def move_to(data, device):
    """
    recursively move a tensor/module, or (nested) lists/tuples of them, to a device
    :param data: object with a ``.to`` method, or a list/tuple of such objects
    :param device: target device
    :return: the moved object; lists and tuples both come back as lists
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    moved = []
    for item in data:
        moved.append(move_to(item, device))
    return moved

In [None]:
# Train on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

<h4>Define Model

In [None]:
class Downsampler(torch.nn.Module):
    """Encoder stage: three unpadded 3x3 conv + ReLU layers, then an optional 2x2 max-pool.

    Each conv trims 2 pixels per spatial dim, so H and W shrink by 6 before pooling.

    :param in_channels: channels of the incoming feature map
    :param out_channels: channels produced by every conv in the stage
    :param pooling: if False the max-pool is skipped (UNet's bottleneck stage)
    """

    def __init__(self, in_channels, out_channels, pooling = True):
        super().__init__()
        self.pooling = pooling

        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3)
        self.conv3 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3)
        self.act = torch.nn.ReLU()
        self.pool = torch.nn.MaxPool2d(kernel_size=2)

    def forward(self, X):
        for conv in (self.conv1, self.conv2, self.conv3):
            X = self.act(conv(X))
        return self.pool(X) if self.pooling else X

In [None]:
class Upsampler(torch.nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, concat with a skip tensor, two unpadded 3x3 conv + ReLU.

    :param in_channels: channels of the incoming map; the deconv halves them and the
        skip tensor must supply the other ``in_channels // 2`` so ``conv1`` sees
        ``in_channels`` again
    """

    def __init__(self, in_channels):
        super().__init__()
        half = in_channels // 2
        self.deconv = torch.nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
        self.conv1 = torch.nn.Conv2d(in_channels, half, kernel_size=3)
        self.conv2 = torch.nn.Conv2d(half, half, kernel_size=3)
        self.act = torch.nn.ReLU()

    def forward(self, X, X_cat):
        upsampled = self.act(self.deconv(X))
        merged = torch.cat([upsampled, X_cat], dim=1)
        return self.act(self.conv2(self.act(self.conv1(merged))))

In [None]:
class UNet(torch.nn.Module):
    """U-Net-style network mapping a 1-channel slice to 3 sigmoid maps.

    Output channels follow data_full's order: stomach, small bowel, large bowel.
    All convolutions are unpadded, so a 572x572 input gives a 322x322 output
    (the training loop resizes it back to 572x572).
    """

    def __init__(self):
        super().__init__()
        self.downsampler1 = Downsampler(1, 64)
        self.downsampler2 = Downsampler(64, 128)
        self.downsampler3 = Downsampler(128, 256)
        self.downsampler4 = Downsampler(256, 512)
        self.downsampler5 = Downsampler(512, 1024, pooling = False)

        self.upsampler1 = Upsampler(1024)
        self.upsampler2 = Upsampler(512)
        self.upsampler3 = Upsampler(256)
        self.upsampler4 = Upsampler(128)

        self.final_conv = torch.nn.Conv2d(64, 3, 3)  # one map each: stomach, small bowel, large bowel
        self.final_act = torch.nn.Sigmoid()

    def copy_crop(self, X, shape):
        """Center-crop the last two dims of ``X`` to ``shape x shape`` (zero-pads if ``X`` is smaller)."""
        top = (X.shape[-2] - shape) // 2
        left = (X.shape[-1] - shape) // 2
        return torchvision.transforms.functional.crop(X, top, left, shape, shape).clone()

    @staticmethod
    def _encode(block, X):
        """Run a Downsampler, returning (stage output, pre-pool features for the skip)."""
        # Mirrors Downsampler.forward but also hands back the features before pooling.
        features = block.act(block.conv1(X))
        features = block.act(block.conv2(features))
        features = block.act(block.conv3(features))
        pooled = block.pool(features) if block.pooling else features
        return pooled, features

    def forward(self, X):
        # Skips use the pre-pool features. The pooled tensors are smaller than the
        # crop targets (e.g. 30x30 vs 48x48 for a 572x572 input), so cropping them
        # zero-padded half-resolution features into the decoder.
        X, X_1 = self._encode(self.downsampler1, X)
        X, X_2 = self._encode(self.downsampler2, X)
        X, X_3 = self._encode(self.downsampler3, X)
        X, X_4 = self._encode(self.downsampler4, X)
        X = self.downsampler5(X)

        # Each Upsampler's ConvTranspose2d(kernel 2, stride 2) doubles the spatial
        # size, so the skip is cropped to twice the current size
        # (48, 88, 168, 328 for a 572x572 input).
        X = self.upsampler1(X, self.copy_crop(X_4, 2 * X.shape[-1]))
        X = self.upsampler2(X, self.copy_crop(X_3, 2 * X.shape[-1]))
        X = self.upsampler3(X, self.copy_crop(X_2, 2 * X.shape[-1]))
        X = self.upsampler4(X, self.copy_crop(X_1, 2 * X.shape[-1]))

        X = self.final_conv(X)
        return self.final_act(X)

In [None]:
def batch_loader(data, batch_size, id_pics_matching):
    """Yield ``(X, y)`` training batches for the slice ids of the global ``data_full``.

    :param data: unused; ids and masks come from the module-level ``data_full`` dict.
        Kept so existing ``batch_loader(data, ...)`` calls still work.
    :param batch_size: slices per batch; a trailing partial batch is dropped
    :param id_pics_matching: slice id -> path of its scan PNG
    :yield: ``X`` of shape (batch_size, 1, 572, 572), intensities scaled by 1/2**16,
        and ``y`` of shape (batch_size, 3, 572, 572), one mask per channel in
        data_full's order (stomach, small bowel, large bowel). Resize interpolates,
        so mask borders take fractional values.
    """
    to_tensor = torchvision.transforms.ToTensor()
    resize = torchvision.transforms.Resize([572, 572])
    data_keys = list(data_full.keys())
    # Start offsets of every *full* batch, including one that ends exactly at the last id.
    for ind in range(0, len(data_keys) - batch_size + 1, batch_size):
        X = torch.zeros([batch_size, 1, 572, 572])
        y = torch.zeros([batch_size, 3, 572, 572])

        for i in range(batch_size):
            slice_id = data_keys[ind + i]
            X_new = to_tensor(Image.open(id_pics_matching[slice_id])).type(torch.float32) / (2**16)
            X[i] = resize(X_new)

            # Paint masks on a blank canvas so targets are 0 off-mask, not image intensity.
            blank = torch.zeros_like(X_new)
            for j in range(3):
                y_new = decode_rle(blank, data_full[slice_id][j]).type(torch.float32) / (2**16)
                y[i][j] = resize(y_new)

        yield X, y

In [None]:
def batch_weights(y, n=1):
    """Element-wise BCE weights ``log(n * sqrt(y) + 1) + 0.2``.

    A zero target gets weight 0.2, and the weight grows with the target value and ``n``.
    """
    scaled_root = n * torch.sqrt(y)
    return torch.log(scaled_root + 1) + 0.2

In [None]:
# Release cached, currently unused GPU memory before building the model.
torch.cuda.empty_cache()

In [None]:
batch_size = 2
epochs = 0  # the training loop runs epochs*10 passes, so 0 skips training
lr = 2e-4

torch.manual_seed(0)  # reproducible weight initialisation
model = move_to(UNet(), device)
opt = torch.optim.Adam(model.parameters(), lr=lr)

In [None]:
# Lower the learning rate. Adam reads lr from opt.param_groups; assigning
# `opt.lr` only creates an unused attribute and would leave lr at 2e-4.
for param_group in opt.param_groups:
    param_group['lr'] = 0.5e-4

In [None]:
# Model snapshots taken by the training loop, and the batch loss at each snapshot.
backups = []
backup_losses = []
backup_rate = 500  # snapshot every `backup_rate` iterations

In [None]:
av_loss = 0
n_iters_loss = 100  # report the mean loss over this many iterations
cnt = 0  # completed training iterations


for epoch in range(epochs*10):  # NOTE(review): 10 passes per unit of `epochs` — confirm intended
    for X, y in batch_loader(data, batch_size, id_pics_matching):
        X, y = move_to(X, device), move_to(y, device)
        out = torchvision.transforms.Resize([572, 572])(model(X))

        # y is already on `device`, so the weights computed from it are too.
        loss = torch.nn.BCELoss(weight = batch_weights(y, 3))(out, y)
        opt.zero_grad()
        loss.backward()
        opt.step()

        cnt += 1
        av_loss += loss.item()
        if cnt % n_iters_loss == 0:
            print(cnt, av_loss/n_iters_loss)
            av_loss = 0

        if cnt % backup_rate == 0:
            # state_dict() returns references to the live parameters; clone them
            # (on CPU, to spare GPU memory) so each backup is a real snapshot.
            backups.append({name: tensor.detach().cpu().clone()
                            for name, tensor in model.state_dict().items()})
            backup_losses.append(loss.item())
            print(">Here d made backup", len(backups)-1)
        torch.cuda.empty_cache()

In [None]:
# Number of slice ids with at least one organ mask (the training set size).
len(data_full)

In [None]:
#look(model(torchvision.transforms.Resize([572, 572])(move_to(img, device)).view(1, 1, 572, 572).type(torch.float32)/(2**16))[0])

In [None]:
#look(out[0][0:1])

In [None]:
class Downsampler_XL(torch.nn.Module):
    """Encoder stage for UNet_XL: three unpadded 3x3 conv + ReLU layers, then an optional 2x2 max-pool.

    NOTE(review): ``conv4`` is constructed, so it shows up in ``state_dict`` and
    ``n_params``, but ``forward`` never calls it. It is presumably reserved for the
    planned skip connections; use it or remove it.

    :param in_channels: channels of the incoming feature map
    :param out_channels: channels produced by every conv in the stage
    :param pooling: if False the max-pool is skipped (bottleneck stage)
    """

    def __init__(self, in_channels, out_channels, pooling = True):
        super().__init__()
        self.pooling = pooling

        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3)
        self.conv3 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3)
        self.conv4 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3)
        self.act = torch.nn.ReLU()
        self.pool = torch.nn.MaxPool2d(kernel_size=2)

    def forward(self, X):
        for conv in (self.conv1, self.conv2, self.conv3):
            X = self.act(conv(X))
        return self.pool(X) if self.pooling else X

In [None]:
class Upsampler_XL(torch.nn.Module):
    """Decoder stage for UNet_XL: 2x transposed-conv upsampling, concat with a skip, two unpadded 3x3 conv + ReLU.

    NOTE(review): ``conv3`` and ``conv4`` are constructed but ``forward`` never calls
    them (they still count toward ``n_params``). They are presumably reserved for the
    planned skip connections; use them or remove them.

    :param in_channels: channels of the incoming map; the skip must supply
        ``in_channels // 2`` channels so ``conv1`` sees ``in_channels``
    """

    def __init__(self, in_channels):
        super().__init__()
        self.deconv = torch.nn.ConvTranspose2d(in_channels, in_channels//2, 2, 2)
        self.conv1 = torch.nn.Conv2d(in_channels, in_channels//2, 3)
        self.conv2 = torch.nn.Conv2d(in_channels//2, in_channels//2, 3)
        self.conv3 = torch.nn.Conv2d(in_channels//2, in_channels//2, 3)
        self.conv4 = torch.nn.Conv2d(in_channels//2, in_channels//2, 3)
        self.act = torch.nn.ReLU()

    def forward(self, X, X_cat):
        X = self.act(self.deconv(X))
        X = torch.cat([X, X_cat], dim=1)
        X = self.act(self.conv1(X))
        X = self.act(self.conv2(X))
        return X

UNET XL<br>
Больше каналов у сверток: основание 100 вместо 64 (пока не реализовано — в коде ниже основание 64)<br>
На один Downsampler и Upsampler больше<br>
Добавляем skip connection (для сверток через одну; пока не реализовано)

In [None]:
class UNet_XL(torch.nn.Module):
    """Deeper U-Net variant: five pooling encoder stages, a bottleneck, and five decoder stages.

    Channel plan: 64 -> 128 -> 256 -> 512 -> 1024, bottleneck 2048. Each
    Upsampler_XL(c) halves c and concatenates a c/2-channel skip, so upsampler1
    needs a 2048-channel bottleneck to match the 1024-channel skip from
    downsampler5. A 572x572 input gives a 66x66 output.
    Output channels: stomach, small bowel, large bowel (data_full order).
    NOTE(review): the planned base width of 100 and in-block skip connections are
    not implemented; this version uses base width 64.
    """

    def __init__(self):
        super().__init__()
        self.downsampler1 = Downsampler_XL(1, 64)
        self.downsampler2 = Downsampler_XL(64, 128)
        self.downsampler3 = Downsampler_XL(128, 256)
        self.downsampler4 = Downsampler_XL(256, 512)
        self.downsampler5 = Downsampler_XL(512, 1024)
        self.downsampler6 = Downsampler_XL(1024, 2048, pooling = False)

        self.upsampler1 = Upsampler_XL(2048)
        self.upsampler2 = Upsampler_XL(1024)
        self.upsampler3 = Upsampler_XL(512)
        self.upsampler4 = Upsampler_XL(256)
        self.upsampler5 = Upsampler_XL(128)

        self.final_conv = torch.nn.Conv2d(64, 3, 3)  # one map each: stomach, small bowel, large bowel
        self.final_act = torch.nn.Sigmoid()

    def copy_crop(self, X, shape):
        """Center-crop the last two dims of ``X`` to ``shape x shape`` (zero-pads if ``X`` is smaller)."""
        top = (X.shape[-2] - shape) // 2
        left = (X.shape[-1] - shape) // 2
        return torchvision.transforms.functional.crop(X, top, left, shape, shape).clone()

    @staticmethod
    def _encode(block, X):
        """Run a Downsampler_XL, returning (stage output, pre-pool features for the skip)."""
        # Mirrors Downsampler_XL.forward but also hands back the features before pooling.
        features = block.act(block.conv1(X))
        features = block.act(block.conv2(features))
        features = block.act(block.conv3(features))
        pooled = block.pool(features) if block.pooling else features
        return pooled, features

    def forward(self, X):
        X, X_1 = self._encode(self.downsampler1, X)
        X, X_2 = self._encode(self.downsampler2, X)
        X, X_3 = self._encode(self.downsampler3, X)
        X, X_4 = self._encode(self.downsampler4, X)
        X, X_5 = self._encode(self.downsampler5, X)
        X = self.downsampler6(X)

        # The transposed conv (kernel 2, stride 2) doubles the spatial size, so
        # each skip is center-cropped to twice the current size.
        X = self.upsampler1(X, self.copy_crop(X_5, 2 * X.shape[-1]))
        X = self.upsampler2(X, self.copy_crop(X_4, 2 * X.shape[-1]))
        X = self.upsampler3(X, self.copy_crop(X_3, 2 * X.shape[-1]))
        X = self.upsampler4(X, self.copy_crop(X_2, 2 * X.shape[-1]))
        X = self.upsampler5(X, self.copy_crop(X_1, 2 * X.shape[-1]))

        X = self.final_conv(X)
        return self.final_act(X)

In [None]:
def n_params(model):
    """Count the scalar parameters of ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [None]:
# Trainable parameter count of the UNet built for training.
n_params(model)

In [None]:
# Trainable parameter count of the XL variant, for comparison (builds a throwaway instance).
n_params(UNet_XL())

# 