In [None]:
import os
import random

import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from medpy.io import load
from scipy.ndimage import zoom
from torchvision import transforms
from torch.utils.data import random_split
import matplotlib.pyplot as plt
import tifffile as tiff
import torchgeometry as tgm
import torch.nn as nn
from scipy import ndimage
from PIL import Image
from skimage import transform
from sklearn.model_selection import train_test_split

In [None]:
class config:
    """Notebook-wide settings: paths, hyperparameters and RNG seeding.

    The seeding statements at the bottom run once, when this class body executes.
    """
    seed = 1717
    base_path = "./hubmap-organ-segmentation"
    #base_path = "../input/hubmap-organ-segmentation"  # UNCOMMENT FOR KAGGLE
    batch_size = 4

    val_size = 0.25  # validation hold-out fraction

    epoch = 30  # number of training epochs
    lr = 3e-3
    device = "cuda" if torch.cuda.is_available() else "cpu"

    rescaling_factor = 3

    crop_size = 256  # side length of the square random training crops
    crop_amount = 8  # random crops served per training image per epoch

    train_metadata = os.path.join(base_path, "train.csv")
    test_metadata = os.path.join(base_path, "test.csv")
    train_images = os.path.join(base_path, "train_images")
    test_images = os.path.join(base_path, "test_images")

    model_path = "model"  # directory for per-epoch checkpoints

    # Seed every RNG the notebook draws from. RandomCrop takes its offsets from
    # np.random, so numpy must be seeded too for reproducible crops.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    

In [None]:
# https://www.kaggle.com/paulorzp/rle-functions-run-length-encode-decode
def mask2rle(img):
    """Encode a binary mask as a run-length string.

    The mask is traversed in column-major order (via ``img.T``). Output is
    "start length start length ..." with 1-based starts; an all-zero mask
    gives "".

    Args:
        img: numpy array, 1 = mask, 0 = background.
    """
    # Zero padding on both sides guarantees every run has a rising and a falling edge.
    padded = np.concatenate(([0], img.T.flatten(), [0]))
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Even slots hold run starts; turn each following end position into a length.
    edges[1::2] = edges[1::2] - edges[0::2]
    return " ".join(map(str, edges))
 
def rle2mask(mask_rle, shape=(1600,256)):
    """Decode a run-length string into a uint8 mask (1 = mask, 0 = background).

    Args:
        mask_rle: "start length start length ..." with 1-based starts, counted
            in column-major order.
        shape: (width, height) of the encoded image. The flat buffer is reshaped
            to ``shape`` and transposed, so the result has shape
            (shape[1], shape[0]).
    """
    tokens = mask_rle.split()
    starts = np.asarray(tokens[0::2], dtype=int) - 1  # convert to 0-based offsets
    lengths = np.asarray(tokens[1::2], dtype=int)
    ends = starts + lengths

    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for begin, end in zip(starts, ends):
        flat[begin:end] = 1
    return flat.reshape(shape).T

In [None]:
class RandomCrop(object):
    """Cut the same random window out of an image and its mask.

    ``data`` must be a dict with "image" and "mask" entries indexed as
    (channel, height, width); the spatial size is read from ``image.shape[1:]``.
    """

    def __init__(self, output_size):
        """
        Args:
            output_size: int for a square crop, or a (height, width) tuple.
        """
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, data):
        """Return ``{"image": ..., "mask": ...}`` cropped to ``output_size``.

        Raises:
            ValueError: if the crop is larger than the image.
        """
        image, mask = data["image"], data["mask"]

        h, w = image.shape[1:]
        new_h, new_w = self.output_size
        if new_h > h or new_w > w:
            raise ValueError(
                f"crop size {(new_h, new_w)} exceeds image size {(h, w)}"
            )

        # randint's upper bound is exclusive: +1 makes the bottom/right-most
        # window reachable and allows a crop equal to the full image size.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[:, top: top + new_h, left: left + new_w]
        mask = mask[:, top: top + new_h, left: left + new_w]

        return {"image": image, "mask": mask}

In [None]:
class BaselineDataset(Dataset):
    """Images from ``root_dir`` paired with masks decoded from ``metadata``.

    Only files whose numeric stem ("10044.tiff" -> 10044) appears in
    ``metadata["id"]`` are used. Each item is rescaled by its row's
    ``pixel_size``, converted to PIL images and passed through the optional
    transforms.

    Args:
        root_dir: directory holding the image files.
        metadata: DataFrame with ``id``, ``pixel_size`` and ``tissue_thickness``
            columns. Labelled data also has ``rle``, ``img_width`` and
            ``img_height``. Rows without an ``rle`` string get mask None.
        image_transforms, mask_transforms: callables applied to the PIL image
            and the PIL mask.
        rescaling_factor: only used to build ``new_voxel_spacing``.
        crop_size: int or (h, w) random-crop size, used when ``val`` is False.
        crop_amount: random crops served per image when ``val`` is False.
        val: if True, return each full rescaled image once, without cropping.
    """
    def __init__(
        self, root_dir, metadata,
        image_transforms=None, mask_transforms=None,
        rescaling_factor = config.rescaling_factor,
        crop_size = 1, crop_amount = 1, val = False
    ):
        self.root_dir = root_dir
        self.image_transforms = image_transforms
        self.mask_transforms = mask_transforms
        self.crop = RandomCrop(crop_size)
        self.crop_amount = crop_amount
        self.val = val
        # NOTE(review): not read anywhere in this class; kept for callers that may use it.
        self.new_voxel_spacing = np.array(
            [
                3,
                0.171 * rescaling_factor,
                0.171 * rescaling_factor
            ]
        )
        cur_ids = set(metadata["id"])
        files = [
            name for name in os.listdir(self.root_dir)
            if self._file_id(name) in cur_ids
        ]
        self.idx2name = {
            i: os.path.join(self.root_dir, name)
            for i, name in enumerate(files)
        }
        # The id comes from the file's base name. Splitting the full path on "."
        # breaks for roots such as "../input/..." that contain dots themselves.
        self.name2idx = {
            self._file_id(path): i
            for i, path in self.idx2name.items()
        }
        self.idx2meta = {}
        for _, row in metadata.iterrows():
            rle = row.get("rle")
            self.idx2meta[self.name2idx[row["id"]]] = {
                # Missing column -> None; missing value -> NaN (a float). Both mean "no mask".
                "mask": (
                    rle2mask(rle, (row["img_width"], row["img_height"]))
                    if isinstance(rle, str) else None
                ),
                "pixel_size": row["pixel_size"],
                "tissue_thickness": row["tissue_thickness"],
            }

    @staticmethod
    def _file_id(path):
        """Integer id taken from the part of the base name before the first "."."""
        return int(os.path.basename(path).split(".")[0])

    def get_voxel_size(self, idx):
        """Voxel spacing that medpy reports for the file at ``idx``."""
        image = load(self.idx2name[idx])
        return np.array(image[1].get_voxel_spacing())

    def zoom(self, data, zoom_scale):
        """Resize the image, and the mask if present, by ``zoom_scale``.

        The image uses skimage's default resize (float output). The mask uses
        nearest-neighbour interpolation with ``preserve_range=True`` so it stays
        0/1. With the defaults, skimage converts the uint8 0/1 mask to float by
        dividing by 255, which shrinks the targets to ~1/255.
        """
        image, mask = data['image'], data['mask']

        h, w = image.shape[:2]
        # resize needs an integer output shape; pixel_size is fractional.
        out_shape = (
            max(1, int(round(h * zoom_scale))),
            max(1, int(round(w * zoom_scale))),
        )
        img = transform.resize(image, out_shape)
        if mask is not None:
            mask = transform.resize(
                mask, out_shape, order=0, preserve_range=True, anti_aliasing=False
            )
        return {'image': img, 'mask': mask}

    def to_image(self, data):
        """Convert float arrays in [0, 1] to 8-bit PIL images. A None mask stays None."""
        mask = data['mask']
        return {
            'image': Image.fromarray((data['image'] * 255).astype(np.uint8)),
            'mask': None if mask is None else Image.fromarray((mask * 255).astype(np.uint8)),
        }

    def __len__(self):
        # Training: each image backs `crop_amount` indices. Validation: one index per image.
        return len(self.idx2name) if self.val else len(self.idx2name) * self.crop_amount

    def __getitem__(self, idx_):
        # Consecutive training indices share an image; each draws a fresh random crop.
        idx = idx_ if self.val else idx_ // self.crop_amount
        image = tiff.imread(self.idx2name[idx])
        meta = self.idx2meta[idx]
        # TODO: tissue_thickness is not used yet.
        ans = {"image": image, "mask": meta["mask"]}
        ans = self.zoom(ans, meta["pixel_size"])
        ans = self.to_image(ans)

        if self.image_transforms:
            ans["image"] = self.image_transforms(ans["image"])
        if self.mask_transforms and ans["mask"] is not None:
            ans["mask"] = self.mask_transforms(ans["mask"])

        return ans if self.val else self.crop(ans)

In [None]:
class BaselineModel(nn.Module):
    """Thin wrapper around the torch.hub U-Net ``milesial/Pytorch-UNet:unet_carvana``.

    NOTE(review): the Carvana checkpoint presumably has a multi-class head, while
    the loss in this notebook compares against a 1-channel mask. Confirm the
    output channel count matches the target shape.
    """

    def __init__(self):
        super().__init__()
        # Pretrained weights are fetched through torch.hub (network or hub cache needed).
        # Alternative backbone that was tried earlier:
        #   torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
        #                  in_channels=3, out_channels=1, init_features=32, pretrained=False)
        self.backbone = torch.hub.load(
            'milesial/Pytorch-UNet', 'unet_carvana', pretrained=True, scale=0.5
        )

    def forward(self, x):
        """Return the backbone output for the image batch ``x`` unchanged."""
        return self.backbone(x)

In [None]:
def dice(logit, target, eps=1e-6):
    """Soft Dice loss between ``sigmoid(logit)`` and ``target``.

    Dice is computed per sample over every non-batch dimension of ``target``:
    ``(2 * sum(p * t) + eps) / (sum(p**2 + t**2) + eps)``. The loss is
    ``1 - dice``, averaged over the batch.

    Args:
        logit: raw model output; passed through ``torch.sigmoid``.
        target: tensor broadcastable with ``logit``; dim 0 is the batch.
        eps: smoothing term in numerator and denominator. It prevents
            0/0 = NaN when the target is empty and sigmoid underflows to 0. It
            also gives loss ~0 for an all-background prediction on an
            all-background target.

    Returns:
        Scalar tensor with the mean Dice loss.
    """
    pred = torch.sigmoid(logit)

    sum_dims = list(range(1, target.dim()))
    intersection = torch.sum(pred * target, dim=sum_dims)
    denominator = torch.sum(pred ** 2 + target ** 2, dim=sum_dims)
    score = (2 * intersection + eps) / (denominator + eps)
    return (1 - score).mean()

In [None]:
df_train_source = pd.read_csv(config.train_metadata)

# Hold out `config.val_size` of the labelled images for validation. The fixed
# seed keeps the split identical across runs.
df_train, df_val = train_test_split(
    df_train_source,
    test_size=config.val_size,
    random_state=config.seed,
    shuffle=True,
)
df_train = df_train.reset_index(drop=True)
df_val = df_val.reset_index(drop=True)
assert len(df_train) + len(df_val) == len(df_train_source)
assert set(df_train["id"]).isdisjoint(df_val["id"])

df_test = pd.read_csv(config.test_metadata)
df_train.head()

In [None]:
# Column-wise summary of the full labelled metadata (dumping every row is unreadable).
df_train_source.describe(include="all").T

In [None]:
# Conversion shared by images and masks: PIL uint8 -> float tensor (C, H, W) in [0, 1].
general_transforms = [transforms.ToTensor()]

# Images are additionally normalised per channel with the usual ImageNet mean/std.
image_transforms = transforms.Compose(
    [
        *general_transforms,
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
# Masks only need the tensor conversion.
mask_transforms = transforms.Compose(general_transforms)

In [None]:
# Transforms shared by every split.
shared_transforms = dict(
    image_transforms=image_transforms,
    mask_transforms=mask_transforms,
)

# Training split: `crop_amount` random `crop_size` crops per image.
train_dataset = BaselineDataset(
    config.train_images,
    df_train,
    crop_size=config.crop_size,
    crop_amount=config.crop_amount,
    val=False,
    **shared_transforms,
)

# Validation split: full rescaled images, no cropping.
val_dataset = BaselineDataset(config.train_images, df_val, val=True, **shared_transforms)

# NOTE(review): the test split relies on the class defaults (crop_size=1,
# crop_amount=1, val=False), so each item would be a single 1x1 random crop.
# test.csv presumably has no "rle" column, so its masks are None. Confirm
# before using this for inference.
test_dataset = BaselineDataset(config.test_images, df_test, **shared_transforms)

In [None]:
train_dataloader = DataLoader(
    train_dataset, batch_size=config.batch_size,
    shuffle=True
)
# Validation serves full images one at a time (their sizes may differ). The
# averaged loss does not depend on order, so a fixed order keeps runs comparable.
val_dataloader = DataLoader(
    val_dataset, batch_size=1,
    shuffle=False
)
test_dataloader = DataLoader(
    test_dataset, batch_size=config.batch_size,
    shuffle=False
)

In [None]:
# Build the network directly on the configured device.
model = BaselineModel().to(config.device)

optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

In [None]:
def val(model, loss_fn, val_dataloader, device):
    """Mean of ``loss_fn`` over ``val_dataloader``, with ``model`` in eval mode.

    Runs under ``torch.no_grad()``. Each batch is a mapping with "image" and
    "mask" entries, both moved to ``device``. Leaves ``model`` in eval mode.

    Returns:
        Unweighted mean of the per-batch losses as a float, or NaN if the loader
        yields no batches (instead of raising ZeroDivisionError).
    """
    model.eval()
    total, n_batches = 0.0, 0
    with torch.no_grad():
        for batch in val_dataloader:
            mask_pred = model(batch["image"].to(device))
            loss = loss_fn(mask_pred, batch["mask"].to(device))
            total += loss.item()
            n_batches += 1
    return total / n_batches if n_batches else float("nan")

In [None]:
def train(model, optimizer, loss_fn, train_dataloader, val_dataloader, epochs, device, model_path=None):
    """Train ``model`` for ``epochs`` epochs, validating and checkpointing after each one.

    Each batch is a mapping with "image" and "mask" entries, both moved to
    ``device``. After every epoch the mean train loss and the result of
    ``val(...)`` are printed, and ``model.state_dict()`` is saved to
    ``<model_path>/model_<epoch>``.

    Args:
        model_path: checkpoint directory; defaults to ``config.model_path``.
            Created if missing.
    """
    if model_path is None:
        model_path = config.model_path
    # torch.save does not create parent directories. Without this, the first
    # checkpoint write fails only after a full epoch of training.
    os.makedirs(model_path, exist_ok=True)

    for epoch in range(epochs):
        model.train()
        loss_sum, n_steps = 0.0, 0
        for batch in train_dataloader:
            model.zero_grad()

            mask_pred = model(batch["image"].to(device))
            loss = loss_fn(mask_pred, batch["mask"].to(device))

            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            n_steps += 1
        train_loss = loss_sum / n_steps if n_steps else float("nan")
        val_loss = val(model, loss_fn, val_dataloader, device)

        print(f"Epoch {epoch + 1}/{epochs}")
        print(f"Train loss: {train_loss}")
        print(f"Val loss: {val_loss}")
        torch.save(model.state_dict(), os.path.join(model_path, f"model_{epoch}"))

In [None]:
# Run the full training loop with the Dice loss.
train(
    model, optimizer, dice, train_dataloader, val_dataloader,
    epochs=config.epoch, device=config.device,
)

In [None]:
# Preview of the index -> file path mapping (first 5 entries).
list(train_dataset.idx2name.items())[:5]