In [None]:
 !pip install segmentation-models-pytorch


Collecting segmentation-models-pytorch
[?25l  Downloading https://files.pythonhosted.org/packages/65/54/8953f9f7ee9d451b0f3be8d635aa3a654579abf898d17502a090efe1155a/segmentation_models_pytorch-0.1.3-py3-none-any.whl (66kB)
[K     |████████████████████████████████| 71kB 2.9MB/s 
[?25hCollecting timm==0.3.2
[?25l  Downloading https://files.pythonhosted.org/packages/51/2d/39ecc56fbb202e1891c317e8e44667299bc3b0762ea2ed6aaaa2c2f6613c/timm-0.3.2-py3-none-any.whl (244kB)
[K     |████████████████████████████████| 245kB 4.3MB/s 
Collecting efficientnet-pytorch==0.6.3
  Downloading https://files.pythonhosted.org/packages/b8/cb/0309a6e3d404862ae4bc017f89645cf150ac94c14c88ef81d215c8e52925/efficientnet_pytorch-0.6.3.tar.gz
Collecting pretrainedmodels==0.7.4
[?25l  Downloading https://files.pythonhosted.org/packages/84/0e/be6a0e58447ac16c938799d49bfb5fb7a80ac35e137547fc6cee2c08c4cf/pretrainedmodels-0.7.4.tar.gz (58kB)
[K     |████████████████████████████████| 61kB 4.4MB/s 
Collecting munch
  

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
import torch.nn as nn
import logging
import cv2
import os
from torch.utils.data import DataLoader, random_split
from torch import optim
from tqdm import tqdm

In [None]:
# Mount Google Drive into the Colab filesystem at /content/drive so that the
# dataset archive stored under MyDrive can be read by the cells below.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import zipfile

# Unpack the dataset archive next to where it is stored on Drive.
# The context manager closes the archive handle even if extraction fails part-way.
with zipfile.ZipFile("/content/drive/MyDrive/HT/histology_dataset.zip", 'r') as zip_ref:
    zip_ref.extractall("/content/drive/MyDrive/HT")

In [None]:
class histologyDataset(Dataset):
    """Image / multi-channel mask pairs read from a directory tree.

    Layout expected on disk::

        imgs_dir/<image name>
        gt_dir/<class name>/<image name>    # one mask per class, same file name

    Args:
        imgs_dir: directory holding the input images.
        gt_dir: directory holding one sub-directory per ground-truth class.
        augs: stored on the instance as ``self.augs``.
            NOTE(review): it is not applied in ``__getitem__``; its expected
            call signature is not visible here, so wiring it in is left as a TODO.

    Each item is a dict with:
        'image': float tensor with a leading axis of size 1 added in front of
            the array returned by ``plt.imread``.
        'mask': float tensor whose first axis indexes ``self.gt_classes``.
    """

    def __init__(self, imgs_dir, gt_dir, augs=None):
        self.imgs_dir = imgs_dir
        self.masks_dir = gt_dir
        # Sorted so that mask channel k always refers to the same class.
        self.gt_classes = self._visible_entries(gt_dir, os.path.isdir)
        self.num_classes = len(self.gt_classes)
        # os.listdir returns entries in arbitrary order; sorting makes indices
        # (and therefore any seeded split) reproducible between runs.
        self.im_names = self._visible_entries(imgs_dir, os.path.isfile)
        self.augs = augs

    @staticmethod
    def _visible_entries(directory, keep):
        """Return the sorted names in `directory` that are not hidden and satisfy `keep(path)`."""
        return sorted(
            entry
            for entry in os.listdir(directory)
            # Skip dot-entries such as `.ipynb_checkpoints` or `.DS_Store`.
            if not entry.startswith(".") and keep(os.path.join(directory, entry))
        )

    def __len__(self):
        return len(self.im_names)

    def __getitem__(self, idx):
        im_name = self.im_names[idx]
        gt_names = [os.path.join(self.masks_dir, gt_class, im_name) for gt_class in self.gt_classes]

        # The second positional parameter of plt.imread is `format` (a string such
        # as 'png'), not a cv2-style grayscale flag, so it is not passed here.
        # NOTE(review): this assumes the image files are single-channel — confirm.
        img = np.expand_dims(plt.imread(os.path.join(self.imgs_dir, im_name)), axis=0)
        # np.stack raises immediately if the per-class masks disagree in shape.
        mask = np.stack([plt.imread(path) for path in gt_names])

        return {
            'image': torch.from_numpy(img).float(),
            'mask': torch.from_numpy(mask).float()
        }

In [None]:
# The archive is extracted under the Drive folder, not the working directory,
# so the relative "./histology_dataset/..." paths do not resolve.
# NOTE(review): assumes the zip contains a top-level `histology_dataset/` folder — confirm.
dataset = histologyDataset(
    "/content/drive/MyDrive/HT/histology_dataset/train/images/",
    "/content/drive/MyDrive/HT/histology_dataset/train/GT/",
)
# The mask stacks one channel per GT class along the first axis, which imshow
# cannot draw directly; collapse it to a per-pixel class index for display.
plt.imshow(dataset[0]['mask'].argmax(dim=0).numpy())

FileNotFoundError: ignored

In [None]:
 import segmentation_models_pytorch as smp
 #import collections.abc as container_abcs

# Build an FPN through the smp factory: "resnet34" encoder initialised with
# "imagenet" weights, 1 input channel and 9 output classes.
# NOTE(review): `model` is assigned again in the training cell (`model = FPN()`),
# so this instance is only used if that cell is not run.
model = smp.FPN(
    encoder_name="resnet34",       # encoder backbone identifier
    encoder_weights="imagenet",     # pretrained weight set for the encoder
    in_channels=1,                  # channels of the input images
    classes=9,                      # channels of the predicted mask
)

In [None]:
from segmentation_models_pytorch.encoders import get_preprocessing_fn

# Use the same encoder name as the models in this notebook ("resnet34") so the
# preprocessing always matches the backbone that is actually trained.
# NOTE(review): ImageNet normalisation statistics are defined per RGB channel;
# confirm they are appropriate for the single-channel inputs used here.
preprocess_input = get_preprocessing_fn('resnet34', pretrained='imagenet')

In [None]:
from typing import Optional, Union
from segmentation_models_pytorch.fpn.decoder import FPNDecoder
from segmentation_models_pytorch.base import SegmentationModel, SegmentationHead, ClassificationHead
from segmentation_models_pytorch.encoders import get_encoder


class FPN(SegmentationModel):
    """Segmentation model assembled from an smp encoder, an ``FPNDecoder`` and a
    ``SegmentationHead``, with an optional ``ClassificationHead``.

    Args:
        encoder_name: passed to ``get_encoder`` and used in ``self.name``.
        encoder_depth: passed to ``get_encoder`` (``depth``) and to the decoder.
        encoder_weights: passed to ``get_encoder`` (``weights``).
        decoder_pyramid_channels: ``pyramid_channels`` of the decoder.
        decoder_segmentation_channels: ``segmentation_channels`` of the decoder.
        decoder_merge_policy: ``merge_policy`` of the decoder.
        decoder_dropout: ``dropout`` of the decoder.
        in_channels: passed to ``get_encoder``.
        classes: ``out_channels`` of the segmentation head.
        activation: ``activation`` of the segmentation head.
        upsampling: ``upsampling`` of the segmentation head.
        aux_params: when not ``None``, keyword arguments for a
            ``ClassificationHead`` fed by the last encoder channel count;
            otherwise ``self.classification_head`` is ``None``.
    """

    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_pyramid_channels: int = 256,
        decoder_segmentation_channels: int = 128,
        decoder_merge_policy: str = "add",
        decoder_dropout: float = 0.2,
        in_channels: int = 1,
        classes: int = 9,
        activation: Optional[str] = "softmax",
        upsampling: int = 4,
        aux_params: Optional[dict] = None,
    ):
        super().__init__()

        # Sub-modules are attached in the order encoder -> decoder -> heads.
        backbone = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
        )
        self.encoder = backbone

        pyramid = FPNDecoder(
            encoder_channels=backbone.out_channels,
            encoder_depth=encoder_depth,
            pyramid_channels=decoder_pyramid_channels,
            segmentation_channels=decoder_segmentation_channels,
            dropout=decoder_dropout,
            merge_policy=decoder_merge_policy,
        )
        self.decoder = pyramid

        self.segmentation_head = SegmentationHead(
            in_channels=pyramid.out_channels,
            out_channels=classes,
            activation=activation,
            kernel_size=1,
            upsampling=upsampling,
        )

        # The auxiliary classifier exists only when its parameters are supplied.
        self.classification_head = (
            None
            if aux_params is None
            else ClassificationHead(in_channels=backbone.out_channels[-1], **aux_params)
        )

        self.name = f"fpn-{encoder_name}"
        self.initialize()



In [None]:
# torch.nn has no DiceLoss; use the implementation shipped with segmentation_models_pytorch.
from segmentation_models_pytorch.utils.losses import DiceLoss

VAL_PERCENT = 0.1
EPOCHS = 100
BATCH_SIZE = 16
LR = 0.1
SEED = 42  # fixes the train/validation split between runs
# The archive is extracted under the Drive folder, not the working directory.
# NOTE(review): assumes the zip contains a top-level `histology_dataset/` folder — confirm.
DATA_ROOT = "/content/drive/MyDrive/HT/histology_dataset/train"

# Train on the GPU when one is available (pin_memory below only helps in that case).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = histologyDataset(os.path.join(DATA_ROOT, "images"), os.path.join(DATA_ROOT, "GT"))
n_val = int(len(dataset) * VAL_PERCENT)
n_train = len(dataset) - n_val
train_set, val_set = random_split(
    dataset, [n_train, n_val], generator=torch.Generator().manual_seed(SEED)
)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=8, pin_memory=True)
# No drop_last here: with a small validation split it would discard every sample.
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=8, pin_memory=True)

model = FPN().to(device)
criterion = DiceLoss()
optimizer = optim.RMSprop(model.parameters(), lr=LR, weight_decay=1e-8, momentum=0.9)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)

for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0

    print(f"{epoch}/{EPOCHS}")
    for batch in tqdm(train_loader):
        imgs = batch['image'].to(device, non_blocking=True)
        true_masks = batch['mask'].to(device, non_blocking=True)
        pred_mask = model(imgs)
        loss = criterion(pred_mask, true_masks)
        epoch_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_value_(model.parameters(), 0.1)
        optimizer.step()

    train_loss = epoch_loss / max(len(train_loader), 1)

    # Evaluate on the held-out split without tracking gradients.
    model.eval()
    val_total = 0
    with torch.no_grad():
        for batch in val_loader:
            imgs = batch['image'].to(device, non_blocking=True)
            true_masks = batch['mask'].to(device, non_blocking=True)
            val_total += criterion(model(imgs), true_masks).item()

    # ReduceLROnPlateau only adjusts the learning rate when it is stepped with a
    # metric; fall back to the training loss if the validation split is empty.
    if len(val_loader) > 0:
        val_loss = val_total / len(val_loader)
        scheduler.step(val_loss)
        print(f"train loss: {train_loss:.4f} | val loss: {val_loss:.4f}")
    else:
        scheduler.step(train_loss)
        print(f"train loss: {train_loss:.4f}")

FileNotFoundError: ignored