In [None]:
import os
import numpy as np
import glob
import cv2
import random
import pathlib
import matplotlib.pyplot as plt
from pprint import pprint
import torchinfo
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.transforms import transforms 
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torchvision
from PIL import Image

In [None]:
from tqdm import tqdm

In [None]:
# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

In [None]:
!nvidia-smi

In [None]:
# --- Training hyper-parameters -------------------------------------------------
lr = 1e-4
batch_size = 32
img_height = 256
img_width = 256
pin_memory = True

# --- Dataset location ----------------------------------------------------------
# All four directories hang off one root so the notebook can run on another machine
# by setting SKIN_CANCER_DATASET_DIR, without editing any cell. The default is the
# location the notebook was originally written against.
data_root = os.environ.get(
    'SKIN_CANCER_DATASET_DIR',
    r'C:/Users/utkar/Desktop/ML/pytorch/image segmentation/SkinCancerDataset',
).rstrip('/\\')

train_img_dir = f'{data_root}/train/images'
train_mask_dir = f'{data_root}/train/mask'
val_img_dir = f'{data_root}/test/images'
val_mask_dir = f'{data_root}/test/mask'

In [None]:
# dataset source: https://www.kaggle.com/datasets/surajghuwalewala/ham1000-segmentation-and-classification

class ISDataset(Dataset):
    """Image/mask pairs for binary segmentation, read from two parallel directories.

    Every file found in ``img_dir`` is one sample. Its mask is looked up in
    ``mask_dir`` under the same file name with ``.jpg`` replaced by ``_mask.png``.

    Args:
        img_dir: Directory containing the input images.
        mask_dir: Directory containing the matching masks.
        img_transform: Callable applied to the RGB image array, or ``None`` to skip.
        mask_transform: Callable applied to the greyscale mask array, or ``None`` to skip.
    """

    def __init__(self, img_dir, mask_dir, img_transform, mask_transform):
        super().__init__()
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.img_transform = img_transform
        self.mask_transform = mask_transform
        # os.listdir returns entries in arbitrary order; sorting makes an index refer
        # to the same sample on every run and on every machine.
        self.img = sorted(os.listdir(img_dir))

    def __len__(self):
        """Return the number of files listed in ``img_dir``."""
        return len(self.img)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(img, mask)`` after the optional transforms."""
        file_name = self.img[index]
        img_path = os.path.join(self.img_dir, file_name)
        mask_path = os.path.join(self.mask_dir, file_name.replace('.jpg', '_mask.png'))
        img = np.array(Image.open(img_path).convert('RGB'))   # 3-channel array
        mask = np.array(Image.open(mask_path).convert('L'))   # single-channel array

        if self.img_transform is not None:
            img = self.img_transform(img)
        if self.mask_transform is not None:
            mask = self.mask_transform(mask)

        return img, mask

In [None]:
# Image pipeline: array -> float tensor -> resize to the configured network input size.
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((img_height, img_width)),
    # Identity normalisation (mean 0, std 1); kept as the place to plug in dataset statistics.
    transforms.Normalize([0, 0, 0], [1, 1, 1]),
])

# Mask pipeline: masks are labels, so they must be resized with nearest-neighbour
# interpolation. The default (bilinear) blends foreground and background along the
# lesion border, producing fractional targets that can never equal a thresholded
# prediction when accuracy is computed.
mask_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(
        (img_height, img_width),
        interpolation=torchvision.transforms.InterpolationMode.NEAREST,
    ),
])

train_ds = ISDataset(
    img_dir=train_img_dir,
    mask_dir=train_mask_dir,
    img_transform=img_transform,
    mask_transform=mask_transform
)

val_ds = ISDataset(
    img_dir=val_img_dir,
    mask_dir=val_mask_dir,
    img_transform=img_transform,
    mask_transform=mask_transform
)

# Only the training data is shuffled; validation order stays fixed so that
# per-sample inspection is repeatable.
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    pin_memory=pin_memory,
    shuffle=True
)

val_loader = DataLoader(
    val_ds,
    batch_size=batch_size,
    pin_memory=pin_memory,
    shuffle=False
)

In [None]:
def bottleneck(in_c, out_c, regular=False, dilated=False, dilation=None, asymm=False, down_sampling=False, up_sampling=False):
    """Build the convolutional branch of one bottleneck block.

    Every variant has the same skeleton: a reducing convolution to ``in_c // 2``
    channels, one or two middle convolutions (each followed by BatchNorm + PReLU),
    and a final 1x1 expansion followed by BatchNorm only. Exactly one mode flag is
    expected; if several are set, the first one in the order ``down_sampling``,
    ``up_sampling``, ``regular``, ``dilated``, ``asymm`` wins.

    Args:
        in_c: Number of input channels.
        out_c: Number of channels of the block the caller assembles.
        regular: 1x1 reduce, 3x3 conv, 1x1 expand.
        dilated: Like ``regular`` but the 3x3 conv uses ``dilation`` (padding = dilation,
            so the spatial size is preserved).
        dilation: Dilation rate; required when ``dilated`` is True.
        asymm: The middle conv is factorised into a 1x5 followed by a 5x1 conv.
        down_sampling: The reducing conv is 2x2 with stride 2.
        up_sampling: The reducing conv is a 2x2, stride-2 transposed conv.

    Returns:
        ``nn.Sequential`` for every mode except ``down_sampling``, which returns
        ``(main, pool)``: ``main`` outputs ``in_c * (out_c // in_c - 1)`` channels
        and ``pool`` is a 2x2 max-pool that returns ``(pooled, indices)``.
        Concatenating ``main`` and the pooled input gives ``out_c`` channels when
        ``out_c`` is a multiple of ``in_c``.

    Raises:
        ValueError: If no mode flag is set, or ``dilated`` is set without ``dilation``.
    """
    mid_c = in_c // 2
    branch_out = out_c

    # Convolutions are created in forward order (reduce, middle, expand), which keeps
    # the random-initialisation order and the Sequential indices of the layers stable.
    if down_sampling:
        convs = [
            nn.Conv2d(in_c, mid_c, kernel_size=(2, 2), stride=2),
            nn.Conv2d(mid_c, mid_c, kernel_size=(3, 3), padding=1),
        ]
        # The max-pooled input supplies the remaining channels after concatenation.
        branch_out = in_c * (out_c // in_c - 1)
    elif up_sampling:
        convs = [
            nn.ConvTranspose2d(in_c, mid_c, kernel_size=(2, 2), stride=(2, 2)),
            nn.Conv2d(mid_c, mid_c, kernel_size=(3, 3), padding=1),
        ]
    elif regular:
        convs = [
            nn.Conv2d(in_c, mid_c, kernel_size=(1, 1)),
            nn.Conv2d(mid_c, mid_c, kernel_size=(3, 3), padding=1),
        ]
    elif dilated:
        if dilation is None:
            raise ValueError("bottleneck(dilated=True) requires an explicit `dilation`")
        convs = [
            nn.Conv2d(in_c, mid_c, kernel_size=(1, 1)),
            nn.Conv2d(mid_c, mid_c, kernel_size=(3, 3), dilation=dilation, padding=dilation),
        ]
    elif asymm:
        convs = [
            nn.Conv2d(in_c, mid_c, kernel_size=(1, 1)),
            nn.Conv2d(mid_c, mid_c, (1, 5), bias=False, padding=(0, 2)),
            nn.Conv2d(mid_c, mid_c, (5, 1), bias=False, padding=(2, 0)),
        ]
    else:
        raise ValueError(
            "bottleneck() needs one of regular, dilated, asymm, down_sampling or up_sampling"
        )

    layers = []
    for stage_conv in convs:
        layers += [stage_conv, nn.BatchNorm2d(mid_c), nn.PReLU()]
    # Final expansion has no activation.
    layers += [nn.Conv2d(mid_c, branch_out, kernel_size=(1, 1)), nn.BatchNorm2d(branch_out)]
    conv = nn.Sequential(*layers)

    if down_sampling:
        conv1 = nn.Sequential(nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), return_indices=True))
        return conv, conv1
    return conv


In [None]:
def conv1x1(in_c, out_c):
    """Return a one-layer ``nn.Sequential`` holding a 1x1 convolution from ``in_c`` to ``out_c`` channels."""
    pointwise = nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=(1, 1))
    return nn.Sequential(pointwise)

In [None]:
class ENET(nn.Module):
    """Encoder/decoder segmentation network built from ``bottleneck`` blocks.

    Takes a 3-channel image batch and returns a 1-channel map of raw scores
    (no sigmoid is applied here). The name suggests the ENet architecture, but
    note that every block fuses its branch with its input by channel
    concatenation followed by a 1x1 convolution, not by addition.

    Resolution path: the initial block and the two down-sampling bottlenecks each
    use stride 2, and the two up-sampling bottlenecks plus the final transposed
    convolution each use stride 2, so the output should match the input size when
    height and width are divisible by 8.
    """
    def __init__(self):
        """Create all layers; relies on the notebook-level ``bottleneck`` and ``conv1x1`` helpers."""
        super(ENET, self).__init__()
        
        # Initial block: a strided conv (13 channels) and a max-pool of the raw image
        # (3 channels) are concatenated in forward() into 16 channels.
        self.initial1 = nn.Conv2d(3, 13, kernel_size=(3, 3), stride=2, padding=1)
        self.initial2 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # Stage 1: down-sample 16 -> 64 channels, then four regular bottlenecks.
        # Each conv1x1_* fuses "branch output ++ block input" back to the stage width,
        # which is why the regular ones take twice the stage width as input.
        self.bn10_1, self.bn10_2 = bottleneck(16, 64, down_sampling=True)   # returns (conv branch, max-pool branch)
        self.conv1x1_10 = conv1x1(64, 64)
        self.bn11 = bottleneck(64, 64, regular=True)
        self.conv1x1_11 = conv1x1(128, 64)
        self.bn12 = bottleneck(64, 64, regular=True)
        self.conv1x1_12 = conv1x1(128, 64)
        self.bn13 = bottleneck(64, 64, regular=True)
        self.conv1x1_13 = conv1x1(128, 64)
        self.bn14 = bottleneck(64, 64, regular=True)
        self.conv1x1_14 = conv1x1(128, 64)
        
        # Stage 2: down-sample 64 -> 128 channels, then eight blocks cycling through
        # regular / dilated(2) / asymmetric / dilated(4) / regular / dilated(8) /
        # asymmetric / dilated(8).
        self.bn20_1, self.bn20_2 = bottleneck(64, 128, down_sampling=True)  # returns (conv branch, max-pool branch)
        self.conv1x1_20 = conv1x1(128, 128)
        self.bn21 = bottleneck(128, 128, regular=True)
        self.conv1x1_21 = conv1x1(256, 128)
        self.bn22 = bottleneck(128, 128, dilated=True, dilation=2)
        self.conv1x1_22 = conv1x1(256, 128)
        self.bn23 = bottleneck(128, 128, asymm=True)
        self.conv1x1_23 = conv1x1(256, 128)
        self.bn24 = bottleneck(128, 128, dilated=True, dilation=4)
        self.conv1x1_24 = conv1x1(256, 128)
        self.bn25 = bottleneck(128, 128, regular=True)
        self.conv1x1_25 = conv1x1(256, 128)
        self.bn26 = bottleneck(128, 128, dilated=True, dilation=8)
        self.conv1x1_26 = conv1x1(256, 128)
        self.bn27 = bottleneck(128, 128, asymm=True)
        self.conv1x1_27 = conv1x1(256, 128)
        self.bn28 = bottleneck(128, 128, dilated=True, dilation=8)
        self.conv1x1_28 = conv1x1(256, 128)

        # Stage 3: the same eight-block sequence as stage 2, without a down-sampling block.
        self.bn31 = bottleneck(128, 128, regular=True)
        self.conv1x1_31 = conv1x1(256, 128)
        self.bn32 = bottleneck(128, 128, dilated=True, dilation=2)
        self.conv1x1_32 = conv1x1(256, 128)
        self.bn33 = bottleneck(128, 128, asymm=True)
        self.conv1x1_33 = conv1x1(256, 128)
        self.bn34 = bottleneck(128, 128, dilated=True, dilation=4)
        self.conv1x1_34 = conv1x1(256, 128)
        self.bn35 = bottleneck(128, 128, regular=True)
        self.conv1x1_35 = conv1x1(256, 128)
        self.bn36 = bottleneck(128, 128, dilated=True, dilation=8)
        self.conv1x1_36 = conv1x1(256, 128)
        self.bn37 = bottleneck(128, 128, asymm=True)
        self.conv1x1_37 = conv1x1(256, 128)
        self.bn38 = bottleneck(128, 128, dilated=True, dilation=8)
        self.conv1x1_38 = conv1x1(256, 128)
        
        # Stage 4 (decoder): up-sample 128 -> 64 channels, then two regular bottlenecks.
        # Only the transposed-conv branch is used; there is no max-unpooling branch.
        self.bn40_1 = bottleneck(128, 64, up_sampling=True)
        self.conv1x1_40 = conv1x1(64, 64)
        self.bn41 = bottleneck(64, 64, regular=True)
        self.conv1x1_41 = conv1x1(128, 64)
        self.bn42 = bottleneck(64, 64, regular=True)
        self.conv1x1_42 = conv1x1(128, 64)
        
        # Stage 5 (decoder): up-sample 64 -> 16 channels, then one regular bottleneck.
        self.bn50_1 = bottleneck(64, 16, up_sampling=True)
        self.conv1x1_50 = conv1x1(16, 16)
        self.bn51 = bottleneck(16, 16, regular=True)
        self.conv1x1_51 = conv1x1(32, 16)
        
        # Output head: stride-2 transposed conv down to a single channel.
        self.out = nn.ConvTranspose2d(16, 1, kernel_size=(2, 2), stride=(2, 2))
        
    def forward(self, img):
        """Run the network on ``img`` and return the single-channel score map.

        ``img`` must have 3 channels on dim 1 (required by ``initial1``). All
        concatenations are along dim 1 (channels).
        """
        # Initial block: both branches read the raw image and are joined to 16 channels.
        x_i1 = self.initial1(img)
        x_i2 = self.initial2(img)
        x_i = torch.concat([x_i1, x_i2], 1)
        
        # Stage 1. The pooling indices are returned by the pool branch but not used later.
        x10_1 = self.bn10_1(x_i)
        x10_2, indices_10 = self.bn10_2(x_i)      # max-pool branch
        x10 = self.conv1x1_10(torch.concat([x10_1, x10_2], 1))
        # Pattern repeated for every block below:
        #   x  = bottleneck(previous)            -> branch output
        #   xN = conv1x1(concat[x, previous])    -> fused block output
        x = self.bn11(x10)
        x11 = self.conv1x1_11(torch.concat([x, x10], 1))
        x = self.bn12(x11)
        x12 = self.conv1x1_12(torch.concat([x, x11], 1))
        x = self.bn13(x12)
        x13 = self.conv1x1_13(torch.concat([x, x12], 1))
        x = self.bn14(x13)
        x14 = self.conv1x1_14(torch.concat([x, x13], 1))
        
        # Stage 2. Again the pooling indices are unused.
        x20_1 = self.bn20_1(x14)
        x20_2, indices_20 = self.bn20_2(x14)     # max-pool branch
        x20 = self.conv1x1_20(torch.concat([x20_1, x20_2], 1))
        x = self.bn21(x20)
        x21 = self.conv1x1_21(torch.concat([x, x20], 1))
        x = self.bn22(x21)
        x22 = self.conv1x1_22(torch.concat([x, x21], 1))
        x = self.bn23(x22)
        x23 = self.conv1x1_23(torch.concat([x, x22], 1))
        x = self.bn24(x23)
        x24 = self.conv1x1_24(torch.concat([x, x23], 1))
        x = self.bn25(x24)
        x25 = self.conv1x1_25(torch.concat([x, x24], 1))
        x = self.bn26(x25)
        x26 = self.conv1x1_26(torch.concat([x, x25], 1))
        x = self.bn27(x26)
        x27 = self.conv1x1_27(torch.concat([x, x26], 1))
        x = self.bn28(x27)
        x28 = self.conv1x1_28(torch.concat([x, x27], 1))
        
        # Stage 3.
        x = self.bn31(x28)
        x31 = self.conv1x1_31(torch.concat([x, x28], 1))
        x = self.bn32(x31)
        x32 = self.conv1x1_32(torch.concat([x, x31], 1))
        x = self.bn33(x32)
        x33 = self.conv1x1_33(torch.concat([x, x32], 1))
        x = self.bn34(x33)
        x34 = self.conv1x1_34(torch.concat([x, x33], 1))
        x = self.bn35(x34)
        x35 = self.conv1x1_35(torch.concat([x, x34], 1))
        x = self.bn36(x35)
        x36 = self.conv1x1_36(torch.concat([x, x35], 1))
        x = self.bn37(x36)
        x37 = self.conv1x1_37(torch.concat([x, x36], 1))
        x = self.bn38(x37)
        x38 = self.conv1x1_38(torch.concat([x, x37], 1))
        
        # Stage 4: up-sample; the up-sampling block has a single branch, so its
        # 1x1 conv is applied directly with no concatenation.
        x40_1 = self.bn40_1(x38)
        x40_1 = self.conv1x1_40(x40_1)
        x = self.bn41(x40_1)
        x41 = self.conv1x1_41(torch.concat([x, x40_1], 1))
        x = self.bn42(x41)
        x42 = self.conv1x1_42(torch.concat([x, x41], 1))
        
        # Stage 5: same single-branch up-sampling as stage 4.
        x50_1 = self.bn50_1(x42)
        x50_1 = self.conv1x1_50(x50_1)
        x = self.bn51(x50_1)
        x51 = self.conv1x1_51(torch.concat([x, x50_1], 1))
        
        # Raw single-channel scores; the caller applies sigmoid or a logits-based loss.
        x_out = self.out(x51)
        return x_out
        

In [None]:
# Build the network and place it on the selected device explicitly: every batch is
# moved to `device` before the forward pass, so the parameters must live there too.
model = ENET().to(device)
print(model)

In [None]:
# Layer-by-layer summary for one RGB image at the configured training resolution.
pprint(torchinfo.summary(model, input_size=(1, 3, img_height, img_width)))

In [None]:
# Shape smoke test on a 512x512 random image (larger than the training resolution).
# Run in eval mode without autograd so the random input neither updates the
# BatchNorm running statistics nor builds a gradient graph; the previous
# train/eval mode is restored afterwards.
x = torch.rand((1, 3, 512, 512))
was_training = model.training
model.eval()
with torch.no_grad():
    # Send the input to wherever the model's parameters actually live.
    y = model(x.to(next(model.parameters()).device))
model.train(was_training)
print(y.shape)

In [None]:
# Adam over every model parameter, using the learning rate from the config cell.
optimizer = optim.Adam(model.parameters(), lr=lr)
# Binary cross-entropy computed directly on the raw network outputs (logits).
loss_fn = nn.BCEWithLogitsLoss()

In [None]:
def check_acc(loader, model, device='cuda'):
    """Evaluate pixel accuracy and mean per-batch Dice score of ``model`` on ``loader``.

    Predictions are ``sigmoid(model(x)) > 0.5``. The model is switched to eval mode
    for the duration of the call and its previous train/eval mode is restored
    afterwards, even if evaluation raises.

    Args:
        loader: Iterable yielding ``(inputs, targets)`` batches.
        model: Network returning one logit per target element.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(accuracy_percent, mean_dice)`` as Python floats; both are ``0.0`` when
        ``loader`` yields no batches. The same numbers are printed.
    """
    num_correct = 0
    num_pixels = 0
    dice_total = 0.0
    num_batches = 0
    was_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device)
                preds = (torch.sigmoid(model(x)) > 0.5).float()
                # .item() keeps the accumulators as plain Python numbers, so the
                # report prints "123/456" rather than a tensor repr.
                num_correct += (preds == y).sum().item()
                num_pixels += torch.numel(preds)   # total number of elements in the batch
                # Dice = 2|P∩T| / (|P| + |T|); the epsilon guards the all-empty case.
                dice_total += ((2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)).item()
                num_batches += 1
    finally:
        model.train(was_training)

    accuracy = 100.0 * num_correct / num_pixels if num_pixels else 0.0
    mean_dice = dice_total / num_batches if num_batches else 0.0
    print(
        f"Got {num_correct}/{num_pixels} with acc {accuracy:.2f}"
    )
    print(f"Dice score: {mean_dice}")
    return accuracy, mean_dice
            

In [None]:
# Loss scaling for mixed-precision training: the loss is multiplied by a large factor
# before backward() so that small float16 gradients do not underflow to zero, and the
# gradients are unscaled again before the optimizer step. It is only meaningful on
# CUDA, so it is disabled (a pass-through) on CPU-only machines.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

In [None]:
epochs = 5

# Mixed precision is only used when training on CUDA.
use_amp = str(device).startswith('cuda')
# No-op when the model already lives on `device`; guarantees parameters and batches
# are on the same device before the first forward pass.
model.to(device)

for epoch in range(epochs):
    # training
    model.train()
    loop = tqdm(train_loader)

    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device)
        target = targets.float().to(device)

        # forward
        with torch.cuda.amp.autocast(enabled=use_amp):
            predictions = model(data)
            loss = loss_fn(predictions, target)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop
        loop.set_postfix(loss=loss.item())

    # Validate once per epoch. Running the full validation pass after every batch
    # multiplies the epoch time by roughly the number of validation batches.
    check_acc(val_loader, model, device=device)
    print(f'epoch {epoch}')
        

In [None]:
# Show one validation sample: ground-truth mask, predicted mask and the input image.
sample_number = 442   # 1-based position in val_ds
# Index the dataset directly instead of loading every earlier sample from disk, and
# clamp to the last sample so a shorter dataset still displays something.
sample_idx = min(sample_number, len(val_ds)) - 1
print(sample_idx + 1)
img, mask = val_ds[sample_idx]

# Inference: eval mode (BatchNorm uses its running statistics, not those of this single
# image) and no autograd; the previous train/eval mode is restored afterwards.
was_training = model.training
model.eval()
with torch.no_grad():
    logits = model(img.unsqueeze(0).to(next(model.parameters()).device))
model.train(was_training)
# Drop the batch and channel dims and turn logits into foreground probabilities.
mask_m = torch.sigmoid(logits)[0, 0].cpu().numpy()

print('target')
plt.imshow(mask.permute(1, 2, 0), cmap='gray')
plt.show()
print('predicted')
plt.imshow(mask_m, cmap='gray', vmin=0, vmax=1)
plt.show()
print('image')
plt.imshow(img.permute(1, 2, 0))
plt.show()
