[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qbeer/patho-segmentation-unet/blob/master/segmentation.ipynb)

In [0]:
# Fetch the project repository and copy its contents into the working
# directory so that `from patho import Model, UNET` in the next cell resolves.
# NOTE(review): re-running this cell in the same runtime makes `git clone`
# report that the destination directory already exists.
!git clone https://github.com/qbeer/patho-segmentation-unet
!cp -a patho-segmentation-unet/* .

In [0]:
from patho import Model, UNET
import os
import cv2
from torchvision import transforms, datasets
import torch
import numpy as np
from PIL import Image

In [0]:
# Build the project's U-Net and wrap it in the project's `Model` helper,
# whose `train(...)` method is used at the end of the notebook.
net = UNET()
model = Model(net)

In [0]:
class DigestPathDataset(torch.utils.data.Dataset):
    """Paired histology image / segmentation-mask dataset read from disk.

    ``root`` must contain two sibling directories, ``resized_images`` and
    ``resized_masks``. Files are paired purely by sorted filename order, so
    the i-th image (sorted) must correspond to the i-th mask (sorted).

    Each item is an ``(img, mask)`` tuple of ``float32`` tensors:

    * ``img``  -- shape ``(3, H, W)``, RGB values scaled to ``[0, 1]``;
    * ``mask`` -- shape ``(1, H', W')``, raw 8-bit grayscale values
      (0-255, not rescaled).

    NOTE(review): the data used with UNET in this notebook is presumably
    572x572 images and 388x388 masks; shapes are taken from the files rather
    than hard-coded, so other sizes load instead of failing.

    Args:
        root: Directory containing ``resized_images`` and ``resized_masks``.
        transforms: Optional callable ``(img, mask) -> (img, mask)`` applied
            to the tensors of every item, or ``None`` for no transform.

    Raises:
        ValueError: if the two directories contain different numbers of
            entries, which would make the filename-order pairing misaligned.
    """

    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms
        # Sort both listings so image i and mask i line up by filename.
        self.imgs = sorted(os.listdir(os.path.join(root, "resized_images")))
        self.masks = sorted(os.listdir(os.path.join(root, "resized_masks")))
        if len(self.imgs) != len(self.masks):
            raise ValueError(
                f"found {len(self.imgs)} images but {len(self.masks)} masks "
                f"under {root!r}; images and masks must pair one-to-one"
            )

    @staticmethod
    def _image_to_tensor(rgb):
        """Convert an ``(H, W, 3)`` uint8 array to a ``(3, H, W)`` float32 tensor in [0, 1]."""
        # PIL/NumPy store pixels channel-last. Move the channel axis with a
        # transpose; a reshape would reinterpret the memory and scramble pixels.
        chw = np.transpose(rgb, (2, 0, 1)) / 255.
        return torch.as_tensor(chw, dtype=torch.float32)

    @staticmethod
    def _mask_to_tensor(gray):
        """Convert an ``(H, W)`` array to a ``(1, H, W)`` float32 tensor, values unchanged."""
        return torch.as_tensor(gray[np.newaxis], dtype=torch.float32)

    def __getitem__(self, idx):
        """Load the ``idx``-th image/mask pair and return it as tensors."""
        img_path = os.path.join(self.root, "resized_images", self.imgs[idx])
        mask_path = os.path.join(self.root, "resized_masks", self.masks[idx])
        # Context managers close the underlying files promptly; ``convert``
        # returns a new, fully loaded image that outlives the ``with`` block.
        with Image.open(img_path) as im:
            img = im.convert("RGB")
        with Image.open(mask_path) as im:
            # Single-channel 8-bit grayscale ("L").
            # NOTE(review): values are kept as-is (not scaled to [0, 1]);
            # confirm this matches the loss used by patho.Model.
            mask = im.convert("L")

        img = self._image_to_tensor(np.array(img))
        mask = self._mask_to_tensor(np.array(mask))

        if self.transforms is not None:
            img, mask = self.transforms(img, mask)

        return img, mask

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.imgs)

In [0]:
# Build the training pipeline from every image/mask pair under ./data,
# with no extra per-item transforms.
dataset = DigestPathDataset(root='data', transforms=None)

# Re-index the dataset through a random permutation of all its indices.
# NOTE(review): the DataLoader below already uses shuffle=True, so this
# permutation adds no extra randomness to training -- confirm whether it was
# meant to become a train/validation split.
indices = torch.randperm(len(dataset)).tolist()
dataset = torch.utils.data.Subset(dataset=dataset, indices=indices)

# Mini-batches of 3, reshuffled every epoch, loaded by a single worker.
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=3,
    shuffle=True,
    num_workers=1,
)

In [0]:
# Train through the project's Model wrapper on the loader built above.
# `EPOCH=15` is presumably the number of passes over the data -- confirm
# against patho.Model.train.
model.train(data_loader, EPOCH=15)