In [None]:
!pip install pydicom -q

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.8 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.1/1.8 MB[0m [31m3.7 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m1.8/1.8 MB[0m [31m34.0 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m26.7 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# Mount Google Drive so the dataset path used later (/content/drive/MyDrive/...) is reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import torch
from torch import nn, optim
import time

import torch.utils.data as data
from torch.utils.data import DataLoader
import pydicom
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
from glob import glob
import re
from tqdm.notebook import tqdm_notebook

def get_default_device():
    """Return the torch device to compute on: CUDA when a GPU is visible, else the CPU."""
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)

device = get_default_device()
device

device(type='cuda')

In [None]:
class CTDataset(data.Dataset):
    """Paired CT slices and segmentation masks, split 80/20 by case directory.

    Layout read by this class:
        base_path/<case>/DICOM_anon/<DICOM slice files>
        base_path/<case>/Ground/<mask image files>

    Case directories are sorted by name. The first 80% form the training split
    and the remaining 20% the test split, so no case appears in both. Within a
    case, slices and masks are paired by position after sorting each folder by
    the first integer found in the file name.

    Args:
        base_path: Root directory containing one sub-directory per case.
        trainset: If True, use the first 80% of cases; otherwise the rest.

    Raises:
        ValueError: If a case folder holds a different number of DICOM files
            and mask files (positional pairing would otherwise be wrong).
    """

    def __init__(self, base_path, trainset=True):
        super(CTDataset, self).__init__()

        directories = sorted(
            d for d in os.listdir(base_path)
            if os.path.isdir(os.path.join(base_path, d))
        )
        split_idx = int(0.8 * len(directories))
        directories = directories[:split_idx] if trainset else directories[split_idx:]

        self.dcm_files = []
        self.label_files = []

        for directory in directories:
            folder_path = os.path.join(base_path, directory)
            dcm_folder = os.path.join(folder_path, 'DICOM_anon')
            label_folder = os.path.join(folder_path, 'Ground')

            sorted_dcm_files = sorted(os.listdir(dcm_folder), key=self.extract_file_order)
            sorted_label_files = sorted(os.listdir(label_folder), key=self.extract_file_order)

            # Pairing below is purely positional. A count mismatch would either
            # raise IndexError mid-loop or silently attach masks to the wrong slices.
            if len(sorted_dcm_files) != len(sorted_label_files):
                raise ValueError(
                    f"{folder_path}: {len(sorted_dcm_files)} DICOM files but "
                    f"{len(sorted_label_files)} label files"
                )

            self.dcm_files.extend(os.path.join(dcm_folder, name) for name in sorted_dcm_files)
            self.label_files.extend(os.path.join(label_folder, name) for name in sorted_label_files)

    def extract_file_order(self, file_name):
        """Sort key: the first run of digits in ``file_name`` as an int (0 if none)."""
        match = re.search(r'(\d+)', file_name)
        return int(match.group()) if match else 0

    def __getitem__(self, index):
        """Return ``(image, mask)`` float32 tensors for sample ``index``.

        Both arrays get a leading channel axis via ``[None, :, :]`` (this
        indexing assumes each file decodes to a 2-D array). Mask values are
        returned as stored in the file; no thresholding is applied here.
        """
        img_path = self.dcm_files[index]
        mask_path = self.label_files[index]

        # `dcmread` is pydicom's supported reader; the old `read_file` alias was
        # removed in pydicom 3.0 and the install cell does not pin a version.
        dcm_file = pydicom.dcmread(img_path)
        # astype() returns a fresh, writable array, so no extra copy is needed.
        pixel_data = dcm_file.pixel_array.astype(np.float32)[None, :, :]

        # For verification
        with Image.open(mask_path) as mask_img:
            # copy(): np.asarray on a PIL image may be a read-only view.
            label = np.asarray(mask_img)[None, :, :].copy()

        return torch.from_numpy(pixel_data), torch.from_numpy(label).float()

    def __len__(self):
        """Number of (slice, mask) pairs in this split."""
        return len(self.dcm_files)



In [None]:
class UNet(nn.Module):
    """Small U-Net mapping an image to per-pixel raw scores.

    Encoder: four ConvBlock + 2x2 max-pool stages (4, 8, 16, 32 channels),
    then a 64-channel ConvBlock bottleneck. Decoder: four stages that
    upsample with ConvTransBlock, concatenate the matching encoder feature map
    on the channel axis, and fuse with a ConvBlock. A 1x1 convolution maps the
    final 4 channels to ``out_channels`` scores; no output activation is applied.

    Because there are four 2x poolings and four 2x upsamplings, H and W must be
    divisible by 16 for the skip concatenations to line up.

    Attribute names (Conv1..Conv5, MaxPool1..4, Up2..5, UpConv2..5, FinalConv)
    and their registration order are fixed, so state_dict keys stay stable.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super(UNet, self).__init__()

        widths = [4, 8, 16, 32, 64]
        encoder_in = [in_channels] + widths[:3]  # input channels of Conv1..Conv4

        # Contracting path: Conv{k} then MaxPool{k} for k = 1..4, followed by the
        # Conv5 bottleneck. setattr keeps the attribute names used in checkpoints.
        for level in range(1, 5):
            setattr(self, f"Conv{level}", ConvBlock(encoder_in[level - 1], widths[level - 1]))
            setattr(self, f"MaxPool{level}", nn.MaxPool2d(kernel_size=2, stride=2))
        self.Conv5 = ConvBlock(widths[3], widths[4])

        # Expanding path, deepest level first. UpConv{k} receives the upsampled
        # tensor concatenated with a skip tensor, hence widths[k-1] input channels.
        for level in range(5, 1, -1):
            setattr(self, f"Up{level}", ConvTransBlock(widths[level - 1], widths[level - 2]))
            setattr(self, f"UpConv{level}", ConvBlock(widths[level - 1], widths[level - 2]))

        self.FinalConv = nn.Conv2d(widths[0], out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        skips = []
        features = x
        for level in range(1, 5):
            features = getattr(self, f"Conv{level}")(features)
            skips.append(features)  # pre-pool activations feed the decoder
            features = getattr(self, f"MaxPool{level}")(features)

        features = self.Conv5(features)

        for level in range(5, 1, -1):
            features = getattr(self, f"Up{level}")(features)
            # Skip tensor first, upsampled tensor second (fixed channel order).
            features = torch.cat((skips.pop(), features), dim=1)
            features = getattr(self, f"UpConv{level}")(features)

        return self.FinalConv(features)

class ConvTransBlock(nn.Module):
    """Decoder upsampling step: 2x Upsample -> 3x3 Conv (in_ch -> out_ch) -> BatchNorm -> ReLU.

    Despite the name, no transposed convolution is involved. The spatial size
    is doubled by ``nn.Upsample``; the channel count is changed by the conv.
    """
    def __init__(self, in_ch, out_ch):
        super(ConvTransBlock, self).__init__()
        stages = [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        # Stored as the Sequential `self.up` so state_dict keys remain `up.1.*`, `up.2.*`.
        self.up = nn.Sequential(*stages)

    def forward(self, x):
        return self.up(x)

class ConvBlock(nn.Module):
    """Two stacked (3x3 Conv -> BatchNorm -> LeakyReLU) units.

    stride 1 with padding 1 keeps H and W unchanged; only the channel count
    changes, from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()

        def conv_bn_act(c_in, c_out):
            # One 3x3 conv followed by BatchNorm and an in-place LeakyReLU.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=True),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(inplace=True),
            ]

        # Flattened into one Sequential named `model` so the layer indices
        # (model.0/1/3/4) and hence state_dict keys stay the same.
        self.model = nn.Sequential(
            *conv_bn_act(in_channels, out_channels),
            *conv_bn_act(out_channels, out_channels),
        )

    def forward(self, X):
        return self.model(X)


# class ConvTransBlock(nn.Module):
#     """
#     ConvTranspose2D-BatchNorm-ReLu
#     """

#     def __init__(self, in_channels, out_channels):
#         super(ConvTransBlock, self).__init__()
#         self.up_model = nn.Sequential(
#             nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
#             nn.BatchNorm2d(out_channels),
#             nn.LeakyReLU(inplace=True),
#         )

#     def forward(self, X):
#         return self.up_model(X)


In [None]:
def save_model(model, filepath=None):
    """Save ``model.state_dict()`` to ``filepath`` and return the path written.

    Args:
        model: The module whose parameters and buffers are saved.
        filepath: Destination path. When None, a timestamped file
            ``./unet-<Mon><DD>-<HH>-<MM>.pt`` in the working directory is used.

    Returns:
        The path the state dict was written to.
    """
    if filepath is None:
        # No ':' in the timestamp: colons are not allowed in Windows file names,
        # so a checkpoint named that way cannot be copied to such a filesystem.
        current_time = time.strftime("%b%d-%H-%M")
        filepath = f"./unet-{current_time}.pt"

    torch.save(model.state_dict(), filepath)
    print(f"Model saved to {filepath}")
    return filepath

def load_model(model, filepath, map_location="cpu"):
    """Load a saved state dict from ``filepath`` into ``model`` and return ``model``.

    Args:
        model: Module with the same architecture as the one that was saved.
        filepath: Path to the checkpoint file.
        map_location: Where ``torch.load`` places the loaded tensors. Defaults
            to "cpu" so a checkpoint written on a GPU runtime can be read on a
            CPU-only one. ``load_state_dict`` then copies the values into the
            model's existing parameters, wherever they live.

    Returns:
        ``model``, with its parameters and buffers overwritten in place.
    """
    # weights_only=True: a state dict only needs tensors, and this stops
    # torch.load from unpickling arbitrary objects from an untrusted file.
    state_dict = torch.load(filepath, map_location=map_location, weights_only=True)
    model.load_state_dict(state_dict)
    print(f"Model loaded from {filepath}")
    return model

def jaccard_score(y_true, y_pred):
    """Intersection-over-union of two masks (numpy arrays or torch tensors).

    Any non-zero element counts as foreground. Comparing with ``!= 0`` first,
    rather than applying ``&``/``|`` to the raw values, makes bool, 0/1 and
    0/255 masks behave identically: bitwise ops would count ``255 & 1`` as 1
    and raise on float inputs. ``y_pred`` must already be thresholded — raw
    logits or probabilities would be almost entirely non-zero.

    For batched inputs the score is pooled over every element of the batch,
    not averaged per sample.
    TODO: Modify for batches of data (per-sample scores).

    Returns:
        float in [0, 1]; 1.0 when neither mask has any foreground.
    """
    true_fg = y_true != 0
    pred_fg = y_pred != 0
    intersection = (true_fg & pred_fg).sum()
    union = (true_fg | pred_fg).sum()

    if union == 0:
        return 1.0  # No foreground in either mask: treat as a perfect match.
    return float(intersection) / float(union)

In [None]:
learning_rate = 1e-3
# NOTE(review): Adam's weight_decay is an L2 term added to the gradient; 1e-1 is
# unusually strong for it and may swamp the data loss — confirm this is intended.
weight_decay = 1e-1
num_epoch = 100
batch_size = 8
SEED = 0

data_dir = "/content/drive/MyDrive/Team Fancy/Train_Sets/CT"
trainset = CTDataset(data_dir)
testset = CTDataset(data_dir, trainset=False)

n_train = len(trainset)

# Seed before building the loaders and the model so weight init and the
# shuffle order are reproducible.
torch.manual_seed(SEED)
# Shuffle training batches so a step doesn't see only consecutive slices of one
# case; evaluation order does not matter.
train_loader = DataLoader(trainset, batch_size, shuffle=True)
test_loader = DataLoader(testset, batch_size)

unet = UNet().to(device)

# UNet() has a single output channel. nn.CrossEntropyLoss soft-maxes over the
# channel axis, and log-softmax over one channel is always 0, so with these
# same-shape float targets that loss and its gradient are identically zero and
# nothing is learned. A single-logit binary mask needs BCEWithLogitsLoss, which
# expects targets in [0, 1]; CTDataset returns raw mask values, so they are
# binarised here (any non-zero pixel = foreground).
_bce_with_logits = nn.BCEWithLogitsLoss()

def criterion(pred, target):
    """Binary segmentation loss: BCE-with-logits against the binarised mask."""
    return _bce_with_logits(pred, (target != 0).float())

optimizer = optim.Adam(unet.parameters(), lr=learning_rate, weight_decay=weight_decay)

In [None]:
def train_model(unet, train_loader, test_loader, optimizer, criterion, n_epochs=5, device=None):
    """Train ``unet`` for ``n_epochs``, evaluating on ``test_loader`` after each epoch.

    Args:
        unet: Model to optimise (updated in place).
        train_loader: Sized iterable of ``(x, y)`` batches used for gradient steps.
        test_loader: Sized iterable of ``(x, y)`` batches used only for evaluation.
        optimizer: Optimiser over ``unet``'s parameters.
        criterion: ``criterion(pred, y)`` returning a scalar loss tensor.
        n_epochs: Number of passes over ``train_loader``.
        device: Device batches are moved to. Defaults to the device of the
            model's first parameter, so inputs always match the weights.

    Returns:
        ``(train_loss_list, val_loss_list)``: per-epoch mean loss per batch,
        as Python floats.
    """
    if device is None:
        device = next(unet.parameters()).device

    train_loss_list, val_loss_list = [], []
    for epoch in range(1, n_epochs + 1):

        unet.train()  # training mode
        train_loss = 0.0
        for x, y in tqdm_notebook(train_loader, total=len(train_loader)):
            x, y = x.to(device), y.to(device)
            pred = unet(x)
            loss = criterion(pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # .item() yields a plain float; summing the tensor itself would keep
            # every batch's autograd graph alive for the whole epoch.
            train_loss += loss.item()
        # Assumes criterion returns a per-batch mean (torch's default reduction),
        # so average over batches, not samples.
        train_loss /= max(len(train_loader), 1)
        train_loss_list.append(train_loss)

        unet.eval()  # evaluation mode
        test_loss = 0.0
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                test_loss += criterion(unet(x), y).item()
        test_loss /= max(len(test_loader), 1)
        val_loss_list.append(test_loss)

        print(f"Epoch: {epoch}; train_loss: {train_loss}; val_loss: {test_loss}")

    return train_loss_list, val_loss_list

# NOTE(review): runs 2 epochs although the config cell defines num_epoch = 100 — confirm which is intended.
train_model(
    unet=unet,
    train_loader=train_loader,
    test_loader=test_loader,
    optimizer=optimizer,
    criterion=criterion,
    n_epochs=2,
)

  0%|          | 0/288 [00:00<?, ?it/s]

In [None]:
batches_per_epoch = len(train_loader)  # mini-batches in one pass over the training split
batches_per_epoch