[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qbeer/patho-segmentation-unet/blob/master/segmentation.ipynb)

In [1]:
# Fetch the project repository, then copy its contents into the working
# directory. Later cells import `patho` and read from the relative path
# `data`; presumably both are provided by the repository -- confirm.
!git clone https://github.com/qbeer/patho-segmentation-unet
!cp -a patho-segmentation-unet/* .

Cloning into 'patho-segmentation-unet'...
remote: Enumerating objects: 545, done.[K
remote: Counting objects: 100% (545/545), done.[K
remote: Compressing objects: 100% (525/525), done.[K
remote: Total 845 (delta 27), reused 535 (delta 20), pack-reused 300
Receiving objects: 100% (845/845), 48.25 MiB | 15.83 MiB/s, done.
Resolving deltas: 100% (185/185), done.


In [0]:
from patho import Model, UNET
import os
import cv2
from torchvision import transforms, datasets
import torch
import numpy as np
from PIL import Image

In [0]:
# Build the UNET network with its default arguments and wrap it in the
# project's Model object, whose train() method is called in a later cell.
net = UNET()
model = Model(net)

In [0]:
class DigestPathDataset(object):
    """Paired image / mask dataset read from a directory tree.

    Images are listed from ``<root>/resized_images`` and masks from
    ``<root>/resized_masks``. Both listings are sorted and then paired by
    position, so the two directories must contain matching files.

    Args:
        root: Directory containing the ``resized_images`` and
            ``resized_masks`` sub-directories.
        transforms: Optional callable ``(img, mask) -> (img, mask)`` applied
            to the tensors before they are returned, or ``None`` to skip.

    Raises:
        ValueError: If the number of images and masks differ, because
            positional pairing would then be misaligned.
    """

    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms
        # Sort both listings so that position i in one directory
        # corresponds to position i in the other.
        self.imgs = sorted(os.listdir(os.path.join(root, "resized_images")))
        self.masks = sorted(os.listdir(os.path.join(root, "resized_masks")))
        if len(self.imgs) != len(self.masks):
            raise ValueError(
                "Found {} images but {} masks under {!r}; cannot pair them "
                "by position.".format(len(self.imgs), len(self.masks), root)
            )

    def __getitem__(self, idx):
        """Return ``(img, mask)`` for sample ``idx`` as float32 tensors.

        ``img`` has shape (3, H, W) with values scaled to [0, 1].
        ``mask`` has shape (1, H', W') and keeps the raw 8-bit pixel values.
        """
        # Load the image and its mask.
        img_path = os.path.join(self.root, "resized_images", self.imgs[idx])
        mask_path = os.path.join(self.root, "resized_masks", self.masks[idx])
        img = Image.open(img_path).convert("RGB")
        # The mask is read as a single-channel 8-bit grayscale image.
        mask = Image.open(mask_path).convert("L")

        # (H', W') -> (1, H', W'): add a leading channel axis.
        # NOTE(review): mask values are left in the 0..255 range; if the
        # training loss expects targets in [0, 1] they need rescaling or
        # thresholding -- confirm against the loss used by the model.
        mask = np.array(mask)[np.newaxis, ...]
        mask = torch.as_tensor(mask, dtype=torch.float32)

        # PIL yields (H, W, C). The channel axis must be *moved* to the
        # front with transpose; reshape would only reinterpret the buffer
        # and scramble the pixels across channels.
        img = np.ascontiguousarray(np.array(img).transpose(2, 0, 1)) / 255.
        img = torch.as_tensor(img, dtype=torch.float32)

        if self.transforms is not None:
            img, mask = self.transforms(img, mask)

        return img, mask

    def __len__(self):
        """Return the number of image files found under ``root``."""
        return len(self.imgs)

In [0]:
# Build the dataset from the local `data` directory, without transforms.
dataset = DigestPathDataset(root='data', transforms=None)

# Re-index the dataset through a random permutation of all its samples.
indices = torch.randperm(len(dataset)).tolist()
dataset = torch.utils.data.Subset(dataset, indices)

# Batches of three samples, reshuffled every epoch, loaded by one worker.
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=3,
    shuffle=True,
    num_workers=1,
)

In [0]:
# Train the wrapped network on batches from data_loader. EPOCH=15 presumably
# sets the number of passes over the data (the first column of the log
# below) -- confirm against patho.Model.train.
# NOTE(review): the logged loss is negative throughout. If the loss is a
# cross-entropy this suggests targets outside [0, 1] (the dataset returns
# masks as raw 0..255 values) -- confirm against the loss used by Model.
model.train(data_loader, EPOCH=15)

[1,    10] loss: -1084.482
[1,    20] loss: -965.835
[1,    30] loss: -659.558
[1,    40] loss: -1109.700
[1,    50] loss: -954.255
[1,    60] loss: -935.493
[1,    70] loss: -933.014
[1,    80] loss: -917.600
[2,    10] loss: -879.052
[2,    20] loss: -1092.183
[2,    30] loss: -858.566
[2,    40] loss: -839.545
[2,    50] loss: -1047.354
[2,    60] loss: -1125.781
[2,    70] loss: -990.331
[2,    80] loss: -844.833
[3,    10] loss: -941.994
[3,    20] loss: -1150.375
[3,    30] loss: -773.399
[3,    40] loss: -824.213
[3,    50] loss: -877.181
[3,    60] loss: -1012.367
[3,    70] loss: -974.258
[3,    80] loss: -1032.250
[4,    10] loss: -1144.589
[4,    20] loss: -812.298
[4,    30] loss: -996.895
[4,    40] loss: -973.395
[4,    50] loss: -1008.895
[4,    60] loss: -963.186
[4,    70] loss: -952.456
[4,    80] loss: -703.682
[5,    10] loss: -909.557
[5,    20] loss: -1012.984
[5,    30] loss: -999.090
[5,    40] loss: -1039.811
[5,    50] loss: -699.950
[5,    60] loss: -773.254
