<a href="https://colab.research.google.com/github/zrghassabi/U-net-Segmentation/blob/main/UNetbasedSegmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#https://www.kaggle.com/suvooo/unet-pytorch
import numpy as np
import pandas as pd
import os, glob
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image

import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Function

from torchvision import utils
from torch import optim
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2

In [None]:
# Optional: install/upgrade albumentations if the runtime's copy is missing or
# too old for the `from albumentations.pytorch import ToTensorV2` import above.
# Restart the runtime afterwards so the upgraded package is picked up.
!pip install -U albumentations
# import albumentations
# from albumentations.pytorch import ToTensorV2

In [None]:
#print("train data size: ", len(glob.glob("../train_masks/*")))
#data = pd.read_csv("../train_masks.csv")
#data

In [None]:
# Download the dataset from Kaggle directly with the kaggle CLI, see:
# https://www.analyticsvidhya.com/blog/2021/06/how-to-load-kaggle-datasets-directly-into-google-colab/
! pip install kaggle

In [None]:
#! mkdir ~/.kaggle
#/kaggle.json
! cp /kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json

In [None]:
# Force-reinstall the latest kaggle CLI without touching its dependencies.
!pip install --upgrade --force-reinstall --no-deps kaggle

In [None]:
# Unzip the competition archive fetched by the `kaggle competitions download`
# cell below -- run that download cell before this one.
# NOTE(review): this extracts to ./content/ (relative to the working
# directory), while later cells read from /kaggle/working/...; make the paths
# agree before running the rest of the notebook.
!unzip '/content/carvana-image-masking-challenge.zip' -d './content/'



In [None]:
# https://www.kaggle.com/c/carvana-image-masking-challenge/data
# Accept the competition rules on the Kaggle website before downloading the data.
!kaggle competitions download -c carvana-image-masking-challenge

In [None]:
#print("train data size: ", len(glob.glob("../train_masks/*")))
#data = pd.read_csv("../train_masks.csv")
#data

In [None]:
# Sanity check: show one training image next to its ground-truth mask.
# NOTE(review): the /kaggle/working paths come from the original Kaggle kernel;
# on Colab, point them at wherever the unzipped data actually lives.
sample_img_path = "/kaggle/working/train/11acc40dc0ea_03.jpg"
sample_mask_path = "/kaggle/working/train_masks/11acc40dc0ea_03_mask.gif"

# Context managers close the files as soon as the pixels are copied to numpy.
with Image.open(sample_img_path) as im:
    img = np.array(im)
with Image.open(sample_mask_path) as im:
    img_mask = np.array(im)

fig = plt.figure(figsize=(10, 10))

plt.subplot(1, 2, 1)
plt.imshow(img)

plt.subplot(1, 2, 2)
plt.imshow(img_mask)

#Dataset

In [None]:
class CarvanaDataset(Dataset):
    """Carvana car images paired with their segmentation masks.

    Args:
        root_dir: directory containing a ``train/`` folder of images and a
            ``train_masks/`` folder of ``<image stem>_mask.gif`` files.
        train_img_list: indexable sequence of image file names (relative to
            ``root_dir/train``), e.g. the ``img`` column of train_masks.csv.

    Each item is ``(img, mask)``:
        * ``img``  -- float tensor (C, 256, 256), normalised with the
          ImageNet mean/std below (the image is assumed to be 3-channel RGB);
        * ``mask`` -- tensor (1, 256, 256) with the raw mask values and the
          dtype of the decoded GIF.
    """

    def __init__(self, root_dir, train_img_list):
        super().__init__()
        self.img_dir = os.path.join(root_dir, "train")
        self.mask_dir = os.path.join(root_dir, "train_masks")
        self.img_list = train_img_list
        # A single joint pipeline: albumentations applies identical spatial ops
        # to image and mask, resizes the mask with nearest-neighbour
        # interpolation (so label values are not blended), and normalises only
        # the image.
        self.transform = A.Compose([
            A.Resize(256, 256),
            A.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ])

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        img_name = self.img_list[idx]
        img_abs_path = os.path.join(self.img_dir, img_name)
        # splitext strips only the final extension, so stems containing dots
        # still map to the right mask file.
        mask_abs_path = os.path.join(
            self.mask_dir, os.path.splitext(img_name)[0] + "_mask.gif")

        # Context managers close the underlying files right after decoding.
        with Image.open(img_abs_path) as im:
            img = np.array(im)
        with Image.open(mask_abs_path) as im:
            mask = np.array(im)

        out = self.transform(image=img, mask=mask)
        mask = out["mask"]
        # ToTensorV2 returns a 2-D mask as (H, W); add the channel axis so it
        # lines up with the (N, 1, H, W) logits of a single-class model.
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)
        return out["image"], mask

In [None]:
# Build the dataset from the image names in the 'img' column of
# train_masks.csv and split it 80/20 into training and validation subsets.
train_img_list = pd.read_csv("/kaggle/working/train_masks.csv")['img']
dataset = CarvanaDataset("/kaggle/working", train_img_list)

train_size = int(len(train_img_list) * 0.8)
val_size = len(train_img_list) - train_size

# Seeded generator so the train/val split is the same on every run.
split_generator = torch.Generator().manual_seed(42)
train_set, val_set = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=split_generator)

train_loader = DataLoader(train_set, batch_size=8, shuffle=True)
# Validation metrics do not depend on order; keep it deterministic.
val_loader = DataLoader(val_set, batch_size=8, shuffle=False)

#Model

In [None]:
class DoubleConv(nn.Module):
    """Two ``3x3 conv -> BatchNorm -> ReLU`` stages; height/width are preserved.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # First stage maps in->out, second stage maps out->out.
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(stage_in, out_channels, 3, stride=1, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        # A single Sequential named ``main`` keeps state_dict keys
        # (main.0.*, main.1.*, ...) in the same layout.
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        out = self.main(x)
        return out

class Down(nn.Module):
    """``DoubleConv`` followed by a 2x2 max-pool, which halves height and width.

    NOTE(review): the UNet class in this notebook does not use ``Down``; it
    applies its own max-pool *before* each DoubleConv instead.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv_stage = DoubleConv(in_channels, out_channels)
        pool_stage = nn.MaxPool2d(2)
        self.main = nn.Sequential(conv_stage, pool_stage)

    def forward(self, x):
        for layer in self.main:
            x = layer(x)
        return x
    
    
class Up(nn.Module):
    """U-Net decoder stage: 2x transposed-conv upsample, concatenate the skip
    connection, then ``DoubleConv``.

    Args:
        in_channels: channels of the decoder input ``x``.
        out_channels: channels produced by the stage.

    The skip tensor ``copy`` given to :meth:`forward` must carry
    ``in_channels // 2`` channels, so that the concatenation matches the
    ``DoubleConv(mid_channels * 2, ...)`` input.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        mid_channels = in_channels // 2
        self.up_conv = nn.ConvTranspose2d(in_channels, mid_channels,
                                          kernel_size=2, stride=2, padding=0)

        self.double_conv = DoubleConv(mid_channels * 2, out_channels)

    def forward(self, x, copy):
        x = self.up_conv(x)
        # When an encoder feature map had an odd size, max-pooling floored it,
        # so the upsampled map can be smaller than the skip map. Zero-pad it
        # (split evenly between both sides) so torch.cat sees matching H/W.
        diff_h = copy.size(2) - x.size(2)
        diff_w = copy.size(3) - x.size(3)
        if diff_h or diff_w:
            x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                          diff_h // 2, diff_h - diff_h // 2])
        x = torch.cat([copy, x], dim=1)

        return self.double_conv(x)
    
class UNet(nn.Module):
    """Five-level U-Net producing per-pixel logits.

    Args:
        img_channels: channels of the input image.
        num_classes: channels of the output logit map.
        features: width of the first encoder level; each deeper level doubles
            it (features, 2f, 4f, 8f, 16f at the bottleneck).

    For inputs whose height and width are divisible by 16, the output is
    (N, num_classes, H, W).
    """

    def __init__(self, img_channels, num_classes, features=64):
        super().__init__()
        self.max_pool = nn.MaxPool2d(2)

        # Encoder: dc1..dc5 (dc5 is the bottleneck), registered in order.
        widths = [features * 2 ** level for level in range(5)]
        in_widths = [img_channels] + widths[:-1]
        for idx, (c_in, c_out) in enumerate(zip(in_widths, widths), start=1):
            setattr(self, f"dc{idx}", DoubleConv(c_in, c_out))

        # Decoder: up1..up4, each halving the channel count on the way up.
        for idx, c_in in enumerate(reversed(widths[1:]), start=1):
            setattr(self, f"up{idx}", Up(c_in, c_in // 2))

        # 1x1 conv mapping the final features to class logits.
        self.final = nn.Conv2d(features, num_classes, 1, 1, 0)

    def forward(self, x):
        # Contracting path: keep each level's output as a skip connection.
        skips = []
        for level, encoder in enumerate((self.dc1, self.dc2, self.dc3, self.dc4)):
            x = encoder(x if level == 0 else self.max_pool(x))
            skips.append(x)
        x = self.dc5(self.max_pool(x))  # bottleneck

        # Expansive path: consume the skips deepest-first.
        decoders = (self.up1, self.up2, self.up3, self.up4)
        for decoder, skip in zip(decoders, reversed(skips)):
            x = decoder(x, skip)
        return self.final(x)
    
    

#Train and Test

In [None]:
class ImageSegment(object):
    """Train, validate and visualise a single-class UNet segmenter.

    Args:
        train_loader: iterable of ``(img, mask)`` batches used for training.
        val_loader: iterable of ``(img, mask)`` batches used for the per-epoch
            validation score and for the visual check in :meth:`test`.
        device: ``torch.device`` that the model and batches are moved to.
    """

    def __init__(self, train_loader, val_loader, device):
        super().__init__()
        self.device = device
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.num_classes = 1
        self.img_channels = 3
        self.unet = UNet(self.img_channels, self.num_classes).to(device)
        self.optim = optim.RMSprop(
            self.unet.parameters(),
            lr=1e-4, momentum=0.9, weight_decay=1e-8)

        # Stepped once per epoch with the validation Dice score (higher is
        # better), hence mode "max".
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optim, "max", patience=2)

        # One logit channel against a {0, 1} mask -> binary cross-entropy.
        self.criterion = nn.BCEWithLogitsLoss()

    def dice_calc(self, gt, pred):
        """Return the batch Dice coefficient of ``gt`` vs. thresholded ``pred``.

        Args:
            gt: ground-truth mask tensor, assumed to hold 0/1 values.
            pred: raw logits with the same shape as ``gt``. A sigmoid is
                applied and the result is thresholded at 0.5.

        Returns:
            0-dim tensor ``2 * |P & G| / (|P| + |G| + 1e-8)`` over the whole batch.
        """
        pred = (torch.sigmoid(pred) >= .5).float()
        return (2 * (pred * gt).sum()) / ((pred + gt).sum() + 1e-8)

    def _train_one_epoch(self):
        """Run one pass over ``train_loader``; return (mean loss, mean Dice)."""
        self.unet.train()
        total_loss = 0.0
        total_dice = 0.0
        num_batches = 0

        # A fresh tqdm per epoch: a tqdm object disables itself once closed.
        loop = tqdm(self.train_loader, leave=False)
        for img, mask in loop:
            img = img.to(self.device)
            mask = mask.to(self.device).float()

            self.optim.zero_grad()
            mask_pred = self.unet(img)
            loss = self.criterion(mask_pred, mask)
            loss.backward()
            self.optim.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            with torch.no_grad():
                total_dice += self.dice_calc(mask, mask_pred).item()
            num_batches += 1

            loop.set_postfix(loss=batch_loss)

        if num_batches == 0:
            return 0.0, 0.0
        return total_loss / num_batches, total_dice / num_batches

    def _evaluate(self):
        """Return the mean Dice over ``val_loader``, or None if it is empty."""
        was_training = self.unet.training
        self.unet.eval()  # BatchNorm must use running stats, not update them
        total_dice = 0.0
        num_batches = 0
        try:
            with torch.no_grad():
                for img, mask in self.val_loader:
                    img = img.to(self.device)
                    mask = mask.to(self.device).float()
                    total_dice += self.dice_calc(mask, self.unet(img)).item()
                    num_batches += 1
        finally:
            self.unet.train(was_training)
        return total_dice / num_batches if num_batches else None

    def train(self, num_epochs=1):
        """Train for ``num_epochs`` epochs.

        Each epoch prints the mean training loss and Dice, scores the model on
        ``val_loader`` and steps the LR scheduler with that validation Dice.
        """
        for epoch in range(num_epochs):
            train_loss, train_dice = self._train_one_epoch()
            val_dice = self._evaluate()
            if val_dice is not None:
                self.scheduler.step(val_dice)
            val_msg = "n/a" if val_dice is None else "%f" % val_dice
            print("Epoch %d| loss: %f | dice score %f | val dice %s"
                  % (epoch + 1, train_loss, train_dice, val_msg))

    def test(self):
        """Plot, for the first image of one ``val_loader`` batch, the ground
        truth mask, the predicted probabilities and the thresholded prediction.
        """
        was_training = self.unet.training
        self.unet.eval()  # BatchNorm must use running stats, not update them
        try:
            with torch.no_grad():
                images, masks = next(iter(self.val_loader))
                images = images.to(self.device)
                probs = torch.sigmoid(self.unet(images)).cpu().numpy()
        finally:
            self.unet.train(was_training)

        masks = masks.cpu().numpy()
        # Threshold the *prediction* (same 0.5 cut-off as dice_calc).
        pred_bin = (probs >= 0.5).astype(int)

        fig, axes = plt.subplots(1, 3, figsize=(15, 15))

        axes[0].imshow(masks[0][0])
        axes[0].set_title('Ground Truth Mask')

        axes[1].imshow(probs[0][0])
        axes[1].set_title('Probabilistic Mask')

        axes[2].imshow(pred_bin[0][0])
        axes[2].set_title('Probabilistic Mask threshold')

In [None]:
# Train on the first GPU when one is available, otherwise on the CPU, then
# visualise a prediction on a validation batch.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
img_seq = ImageSegment(train_loader, val_loader, device)
img_seq.train(num_epochs=7)
img_seq.test()