In [6]:
#import modules
import argparse
import os
import random
from tkinter.tix import IMAGE
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split

# Device
# Index of the GPU to use when CUDA is available; DEVICE is derived from it
# (previously this constant was ignored and 'cuda:0' was hardcoded).
CUDA_DEVICE_NUM = 0
DEVICE = torch.device(f'cuda:{CUDA_DEVICE_NUM}' if torch.cuda.is_available() else 'cpu')
print('Device:', DEVICE)
print('CUDA available:', torch.cuda.is_available())

# Reproducibility: seed Python's and PyTorch's RNGs (torch.manual_seed also
# seeds every CUDA device). Shuffling and weight init depend on these.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Hyperparameters
LEARNING_RATE = 0.0001
NUM_EPOCHS = 5
BATCH_SIZE = 4
WORKERS = 4  # DataLoader worker processes

# Data locations, relative to the notebook's working directory
TRAINDATA = "ISIC/ISIC-2017_Training_Data"
TESTDATA = "ISIC/ISIC-2017_Test_v2_Data"
VALIDDATA = "ISIC/ISIC-2017_Validation_Data"


Device: cuda:0
True


In [19]:

# All three ISIC splits share one preprocessing pipeline.
# NOTE(review): ISIC-2017 images come in many resolutions; without a fixed
# Resize the default collate cannot stack a batch of BATCH_SIZE > 1, and the
# 4-level UNet needs H and W divisible by 16. 256 is a placeholder choice.
IMAGE_SIZE = 256

image_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


def make_loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader using the notebook-wide batch/worker settings."""
    return torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE,
                                       shuffle=shuffle, num_workers=WORKERS)


# NOTE(review): ImageFolder expects one subdirectory per class and yields
# (image, class_index) pairs; a segmentation UNet needs (image, mask) pairs.
# Confirm the directory layout / consider a custom Dataset for the masks.
train_dataset = dset.ImageFolder(root=TRAINDATA, transform=image_transform)
test_dataset = dset.ImageFolder(root=TESTDATA, transform=image_transform)
valid_dataset = dset.ImageFolder(root=VALIDDATA, transform=image_transform)

# Only the training split is shuffled; evaluation order stays deterministic.
train_dataloader = make_loader(train_dataset, shuffle=True)
test_dataloader = make_loader(test_dataset, shuffle=False)
# Validation gets its own name so it does not overwrite test_dataloader.
valid_dataloader = make_loader(valid_dataset, shuffle=False)


In [20]:
class ConvBlock(nn.Module):
    """Conv2d (3x3, padding 1) -> BatchNorm2d -> ReLU.

    Args:
        in_channels: channels of the incoming feature map.
        num_filters: number of output channels produced by the convolution.
    """

    def __init__(self, in_channels, num_filters):
        super().__init__()
        # Submodule names stay `conv` / `bn` so state_dict keys are unchanged.
        self.conv = nn.Conv2d(in_channels, num_filters, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(num_filters)

    def forward(self, x):
        # ReLU is stateless, so it is applied through the functional API.
        return F.relu(self.bn(self.conv(x)))
    
def double_conv(in_channels, out_channels):
    """Build two 3x3 convolutions (padding 1), each followed by an in-place ReLU.

    The first convolution maps ``in_channels -> out_channels`` and the second
    keeps ``out_channels``. No normalization is applied.

    Returns:
        nn.Sequential with layers [Conv2d, ReLU, Conv2d, ReLU].
    """
    layers = []
    for source_channels in (in_channels, out_channels):
        layers.append(nn.Conv2d(source_channels, out_channels, 3, padding=1))
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)

class EncoderBlock(nn.Module):
    """One encoder stage: a double convolution followed by 2x2 max-pooling.

    ``forward`` returns ``(features, pooled)``: the full-resolution output of
    the double convolution (presumably kept for a skip connection) and its
    2x-downsampled version.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = double_conv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        features = self.conv(x)
        return features, self.pool(features)

class DecoderBlock(nn.Module):
    """One decoder stage: 2x transposed-conv upsampling, skip concat, double conv.

    Args:
        in_channels: channels of the incoming (coarser) feature map.
        out_channels: channels after upsampling and after the double convolution.
        skip_channels: channels of the ``skip_features`` tensor given to ``forward``.
    """

    def __init__(self, in_channels, out_channels, skip_channels):
        super().__init__()
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        # The double conv sees the upsampled map concatenated with the skip.
        self.conv = double_conv(out_channels + skip_channels, out_channels)

    def forward(self, x, skip_features):
        x = self.upconv(x)
        # Max-pooling floors odd heights/widths, so the upsampled map can be
        # 1 px smaller than the skip. Pad it (or crop, for negative differences)
        # so torch.cat does not fail on a spatial-size mismatch.
        diff_h = skip_features.size(2) - x.size(2)
        diff_w = skip_features.size(3) - x.size(3)
        if diff_h or diff_w:
            x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                          diff_h // 2, diff_h - diff_h // 2])
        x = torch.cat([x, skip_features], dim=1)  # concatenate along channels
        return self.conv(x)


In [21]:




class UNet(nn.Module):
    """U-Net with four encoder/decoder levels and a 1x1 convolution head.

    Encoder widths are 64-128-256-512 and the bridge has 1024 channels. Each
    decoder level does three things in order: it upsamples with a stride-2
    transposed convolution, concatenates the matching encoder output (skip
    first, upsampled second), and applies a double convolution. The output is
    passed through a sigmoid and has the same spatial size as the input, for
    inputs at least 16 px per side.

    Args:
        input_channels: channels of the input image (3 for RGB).
        out_channels: channels of the output map (default 1).
    """

    def __init__(self, input_channels, out_channels=1):
        super().__init__()

        # Encoder
        self.enc1 = double_conv(input_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = double_conv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = double_conv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.enc4 = double_conv(256, 512)
        self.pool4 = nn.MaxPool2d(2)

        # Bridge
        self.bridge = double_conv(512, 1024)

        # Decoder: each dec* input is upsampled + skip channels concatenated
        self.up1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec1 = double_conv(1024, 512)
        self.up2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec2 = double_conv(512, 256)
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec3 = double_conv(256, 128)
        self.up4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec4 = double_conv(128, 64)

        # Final Layer
        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

    @staticmethod
    def _match_size(x, ref):
        """Pad (or crop) ``x`` spatially so its height and width equal ``ref``'s."""
        # Needed when an odd H/W was floored by max-pooling on the way down.
        diff_h = ref.size(2) - x.size(2)
        diff_w = ref.size(3) - x.size(3)
        if diff_h == 0 and diff_w == 0:
            return x
        return F.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                         diff_h // 2, diff_h - diff_h // 2])

    def _up_block(self, up, dec, x, skip):
        """Upsample ``x``, concatenate ``skip`` in front along channels, apply ``dec``."""
        upsampled = self._match_size(up(x), skip)
        return dec(torch.cat((skip, upsampled), dim=1))

    def forward(self, x):
        # Encoder: keep each level's pre-pool output for the skip connections.
        s1 = self.enc1(x)
        s2 = self.enc2(self.pool1(s1))
        s3 = self.enc3(self.pool2(s2))
        s4 = self.enc4(self.pool3(s3))

        # Bridge
        b1 = self.bridge(self.pool4(s4))

        # Decoder
        d1 = self._up_block(self.up1, self.dec1, b1, s4)
        d2 = self._up_block(self.up2, self.dec2, d1, s3)
        d3 = self._up_block(self.up3, self.dec3, d2, s2)
        d4 = self._up_block(self.up4, self.dec4, d3, s1)

        # NOTE(review): sigmoid makes this return probabilities; if training uses
        # nn.BCEWithLogitsLoss, return the raw logits instead.
        return torch.sigmoid(self.final(d4))


# Place the parameters on the configured device; input batches must be moved
# to DEVICE as well before calling the model.
model = UNet(input_channels=3).to(DEVICE)

TRAIN 