# U-Net on HexAI-HipKneeBonSeg dataset

### U-Net Model

In [None]:
import torch
import torch.nn as nn
from torch.nn import ConvTranspose2d, Conv2d, MaxPool2d, ReLU
from torchvision.transforms import CenterCrop
from torch.nn import functional as F


class Block(nn.Module):
    """Two 3x3 convolutions, each followed by batch norm and ReLU.

    Spatial size is preserved (kernel 3, padding 1, replicate padding); only the
    channel count changes from ``in_Channels`` to ``out_Channels``.

    :param in_Channels: number of input channels.
    :param out_Channels: number of output channels of both convolutions.
    """

    def __init__(self, in_Channels, out_Channels):
        super().__init__()
        self.conv1 = Conv2d(in_Channels, out_Channels, kernel_size=3, padding=1, padding_mode='replicate')
        self.conv2 = Conv2d(out_Channels, out_Channels, kernel_size=3, padding=1, padding_mode='replicate')
        self.ReLU = ReLU()
        # One BatchNorm per convolution: a single shared module would mix the
        # running statistics and affine parameters of two different layers.
        self.BatchNorm = torch.nn.BatchNorm2d(out_Channels)   # after conv1 (name kept)
        self.BatchNorm2 = torch.nn.BatchNorm2d(out_Channels)  # after conv2

    def forward(self, x):
        x = self.ReLU(self.BatchNorm(self.conv1(x)))
        return self.ReLU(self.BatchNorm2(self.conv2(x)))


class Encoder(nn.Module):
    """U-Net contracting path.

    :param channels: channel sizes ``(in, c1, c2, ...)``; one ``Block`` is built
        for every consecutive pair. Any sequence works (the default is a tuple to
        avoid a shared mutable default argument).
    """

    def __init__(self, channels=(1, 64, 128, 256, 512, 1024)):
        super().__init__()
        channels = list(channels)
        self.blocks = nn.ModuleList(
            [Block(c_in, c_out) for c_in, c_out in zip(channels[:-1], channels[1:])]
        )
        self.pooling_layer = MaxPool2d(kernel_size=2, stride=2)  # The kernel size for max pool in the U-net model is 2

    def forward(self, x):
        """Return the output of every block, shallowest first (used as skip connections)."""
        outputs = []
        last_index = len(self.blocks) - 1
        for i, block in enumerate(self.blocks):
            x = block(x)
            outputs.append(x)
            # The deepest output is never consumed after pooling, so skip that pool.
            if i < last_index:
                x = self.pooling_layer(x)
        return outputs



class FinalBlock(nn.Module):
    """Last decoder stage: two 3x3 conv+BN+ReLU layers and a 1x1 projection.

    The output keeps the input's spatial size and has ``out_Channels`` channels.
    It is batch-normalised but not passed through an activation.

    :param in_Channels: channels of the concatenated decoder/encoder features.
    :param hidden_Channels: channels of the two 3x3 convolutions.
    :param out_Channels: number of output channels.
    """

    def __init__(self, in_Channels, hidden_Channels, out_Channels):
        super().__init__()
        self.conv1 = Conv2d(in_Channels, hidden_Channels, kernel_size=3, padding=1, padding_mode='replicate')
        self.conv2 = Conv2d(hidden_Channels, hidden_Channels, kernel_size=3, padding=1, padding_mode='replicate')
        # A 1x1 convolution needs no padding; padding=1 here grew the map by one
        # replicated pixel on every side (H+2 x W+2).
        self.conv3 = Conv2d(hidden_Channels, out_Channels, kernel_size=1)
        self.batchnorm1 = nn.BatchNorm2d(hidden_Channels)  # after conv1
        self.batchnorm2 = nn.BatchNorm2d(out_Channels)     # after conv3
        # conv2 gets its own statistics instead of reusing batchnorm1.
        self.batchnorm3 = nn.BatchNorm2d(hidden_Channels)
        self.ReLU = ReLU()

    def forward(self, x):
        x = self.ReLU(self.batchnorm1(self.conv1(x)))
        x = self.ReLU(self.batchnorm3(self.conv2(x)))
        return self.batchnorm2(self.conv3(x))


class Decoder(nn.Module):
    """U-Net expanding path.

    ``channels`` is read as ``[c0, c1, ..., c_{n-2}, out]`` and builds three kinds of layer:

    - ``len(channels) - 2`` transposed convolutions upsampling ``c_i -> c_{i+1}``;
    - ``len(channels) - 3`` ``Block``s that fuse each upsampled map with its encoder skip connection;
    - a ``FinalBlock(c_{n-3}, c_{n-2}, out)`` applied to the last fusion.
    """

    def __init__(self, channels):
        super().__init__()
        n = len(channels)
        upsamplers = []
        for idx in range(n - 2):
            upsamplers.append(ConvTranspose2d(channels[idx], channels[idx + 1], kernel_size=2, stride=2))
        self.up_convs = nn.ModuleList(upsamplers)
        self.blocks = nn.ModuleList([Block(channels[idx], channels[idx + 1]) for idx in range(n - 3)])
        self.final_block = FinalBlock(channels[-3],  # in_channel
                                      channels[-2],  # hidden_channel
                                      channels[-1])  # out_channel

    def forward(self, encoder_output):
        """Decode the encoder feature maps (given shallowest first)."""
        skips = list(reversed(encoder_output))  # deepest first
        x = self.up_convs[0](skips[0])
        for step, block in enumerate(self.blocks):
            skip = self.copy_and_crop(x, skips[step + 1])
            fused = torch.cat([x, skip], dim=1)
            x = self.up_convs[step + 1](block(fused))
        skip = self.copy_and_crop(x, encoder_output[0])
        return self.final_block(torch.cat([x, skip], dim=1))

    def copy_and_crop(self, x, enc_features):
        """Center-crop ``enc_features`` to the spatial size of the 4-D tensor ``x``."""
        _, _, height, width = x.shape
        return CenterCrop([height, width])(enc_features)

class UNet(nn.Module):
    """U-Net built from ``Encoder`` and ``Decoder``.

    :param config: object providing ``num_classes`` (appended as the decoder's
        output channel count) and ``image_shape`` (output size when ``retain_dim``).
    :param encoder_channels: channel sizes of the contracting path, input first.
    :param decoder_channels: channel sizes of the expanding path, deepest first.
    :param retain_dim: if True, resize the output to ``config.image_shape`` with
        ``F.interpolate`` (its default, nearest-neighbour mode).
    """

    def __init__(self, config, encoder_channels=(1, 64, 128, 256, 512, 1024),
                 decoder_channels=(1024, 512, 256, 128, 64),
                 retain_dim=True):
        super().__init__()

        self.encoder = Encoder(list(encoder_channels))
        # Build a new list rather than appending in place. Appending mutated the
        # shared default (every later UNet() got an extra decoder stage) and any
        # list passed in by the caller.
        self.decoder = Decoder(list(decoder_channels) + [config.num_classes])
        self.output_size = config.image_shape
        self.retain_dim = retain_dim

    def forward(self, x):
        encoder_outputs = self.encoder(x)
        output = self.decoder(encoder_outputs)
        if self.retain_dim:
            output = F.interpolate(output, self.output_size)

        return output



### Dataset

In [1]:
from torch.utils.data import Dataset
import cv2
import numpy as np
import torch
import SimpleITK as sitk
import pydicom
import numpy as np


class Dataset(Dataset):
    """Paired (DICOM image, mask file) dataset; pairs are matched by list index.

    :param paths: dict with ``'path_to_images'`` and ``'path_to_masks'`` lists.
    :param transforms_image: callable applied to the uint8 image array, or None.
    :param transforms_mask: callable applied to the uint8 mask array (reduced to
        2-D where possible), or None.
    :param transform_both: optional callable applied to image and mask
        concatenated along dim 0; both must be tensors at that point.

    ``__getitem__`` returns ``(image, mask)``; ``mask`` is ``(mask > 0)`` cast to
    ``torch.float``. NOTE(review): that cast needs a tensor, so a transform is
    expected to produce one.
    """
    # The class keeps the name of the torch base class it shadows because other
    # modules import it as ``Dataset``.

    def __init__(self, paths, transforms_image, transforms_mask, transform_both=None):
        self.path_to_images = paths['path_to_images']
        self.path_to_masks = paths['path_to_masks']
        self.transforms_image = transforms_image
        self.transforms_mask = transforms_mask
        self.transforms_both = transform_both
        # Sizes of returned images, for inspection.
        # NOTE(review): with DataLoader(num_workers > 0) this set is filled in the
        # worker processes, so the main-process copy may stay empty.
        self.image_sizes = set()

    def __len__(self):
        return len(self.path_to_images)

    @staticmethod
    def _to_uint8(image):
        """Convert a pixel array to uint8 without integer wrap-around.

        Arrays already within [0, 255] are cast unchanged. Wider ranges (e.g.
        12/16-bit DICOM) are min-max rescaled to [0, 255] instead of being
        reduced modulo 256 by ``astype``. A constant out-of-range image maps to 0.
        """
        if image.dtype == np.uint8 or image.size == 0:
            return image.astype(np.uint8)
        low = float(image.min())
        high = float(image.max())
        if low >= 0 and high <= 255:
            return image.astype(np.uint8)
        span = high - low
        if span == 0:
            return np.zeros(image.shape, dtype=np.uint8)
        scaled = (image.astype(np.float64) - low) / span * 255.0
        return np.round(scaled).astype(np.uint8)

    @staticmethod
    def _reduce_mask(mask):
        """Collapse a 3-D mask array to 2-D where the layout is recognised.

        ``(H, W, 3)`` arrays are converted to grayscale. Other 3-D arrays are
        treated as slices along axis 0, and the first slice with a non-zero
        pixel is returned. Anything else is returned unchanged: a 2-D mask, or a
        stack whose slices are all empty.
        NOTE(review): only the first labelled slice is used; confirm whether
        several labelled slices should be merged instead.
        """
        if mask.ndim == 3 and mask.shape[2] == 3:
            return np.squeeze(cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY))
        if mask.ndim == 3:
            for mask_slice in mask:
                if np.sum(mask_slice) > 0:
                    # Return immediately. The old loop kept indexing the
                    # already-reduced 2-D array with [i, :, :] and raised IndexError.
                    return np.squeeze(mask_slice)
        return mask

    def __getitem__(self, index):
        image = pydicom.dcmread(self.path_to_images[index]).pixel_array
        if image.ndim == 3 and image.shape[-1] in (3, 4):
            try:
                # pydicom returns colour pixel data in RGB order, not BGR.
                image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
            except cv2.error:
                pass  # dtype unsupported by cvtColor: keep the array (best effort, as before)
        image = self._to_uint8(image)

        mask = sitk.GetArrayFromImage(sitk.ReadImage(self.path_to_masks[index]))
        if self.transforms_image is not None:
            image = self.transforms_image(image)
        mask = self._reduce_mask(mask).astype(np.uint8)
        if self.transforms_mask is not None:
            if mask.ndim == 3:
                mask = np.squeeze(mask)
            if mask.ndim == 3:
                mask = np.squeeze(mask[0])
            mask = self.transforms_mask(mask)
        if self.transforms_both is not None:
            stacked = torch.cat([image, mask], dim=0)  # shape=(2xHxW)
            stacked = self.transforms_both(stacked)
            image, mask = torch.chunk(stacked, chunks=2, dim=0)

        mask = (mask > 0).to(torch.float)
        self.image_sizes.add(image.size())
        return image, mask


### Utils

In [None]:
from torchvision import transforms
from imutils import paths
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from .dataset import Dataset, UnlabeledDataset, TestDataset
import os
import pandas as pd


def _image_transform(config, blur):
    """Image pipeline: to PIL, resize to ``config.input_size``, tensor, normalise, optional blur."""
    steps = [transforms.ToPILImage(),
             transforms.Resize(config.input_size),
             transforms.ToTensor(),
             transforms.Normalize(0.5, 0.1)]
    if blur:
        steps.append(transforms.GaussianBlur(kernel_size=3))
    return transforms.Compose(steps)


def _mask_transform(config):
    """Mask pipeline: to PIL, resize to ``config.input_size``, tensor."""
    return transforms.Compose([transforms.ToPILImage(),
                               transforms.Resize(config.input_size),
                               transforms.ToTensor()])


def load_data(config):
    """
    Build the train and test DataLoaders.

    Images are listed from ``<config.path_to_dataset>/Images`` and masks from
    ``<config.path_to_dataset>/Annotations_NII``. Both listings are sorted and
    paired by position, so image and mask file names must sort identically.
    Test-image names (basename up to the first '.') are written to
    ``output.txt`` in the current working directory. Only the training images
    get a Gaussian blur.

    :param config: initialized config class (uses ``path_to_dataset``,
        ``test_size``, ``input_size`` and ``batch_size``)
    :return: ``(train_loader, test_loader)``
    """
    image_dir = os.path.join(config.path_to_dataset, 'Images')
    # Mask paths must be joined with the directory they were listed from; the
    # old code listed 'Annotations_NII' but built paths under 'Annotations'.
    mask_dir = os.path.join(config.path_to_dataset, 'Annotations_NII')
    image_paths = [os.path.join(image_dir, name) for name in sorted(os.listdir(image_dir))]
    mask_paths = [os.path.join(mask_dir, name) for name in sorted(os.listdir(mask_dir))]

    train_images, test_images, train_masks, test_masks = train_test_split(
        image_paths, mask_paths, test_size=config.test_size, random_state=7)
    with open("output.txt", "w") as txt_file:
        for line in test_images:
            # basename is separator-agnostic; split('/') fails on Windows paths.
            txt_file.write(os.path.basename(line).split('.')[0] + "\n")

    train_set = Dataset(paths={'path_to_images': train_images, 'path_to_masks': train_masks},
                        transforms_image=_image_transform(config, blur=True),
                        transforms_mask=_mask_transform(config))
    test_set = Dataset(paths={'path_to_images': test_images, 'path_to_masks': test_masks},
                       transforms_image=_image_transform(config, blur=False),
                       transforms_mask=_mask_transform(config))

    print(f"[INFO] found {len(train_set)} examples in the training set...")
    print(f"[INFO] found {len(test_set)} examples in the test set...")
    # os.cpu_count() may return None, which DataLoader does not accept.
    num_workers = os.cpu_count() or 0
    train_loader = DataLoader(train_set, shuffle=True, batch_size=config.batch_size, num_workers=num_workers)
    test_loader = DataLoader(test_set, shuffle=False, batch_size=config.batch_size, num_workers=num_workers)

    return train_loader, test_loader

### Training

In [None]:
import time
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import ExponentialLR
import matplotlib.pyplot as plt
from torchvision.ops import sigmoid_focal_loss
from torch.utils.data import DataLoader
from .dataset import PseudoDataset
from torch.nn.functional import sigmoid
from torchvision import transforms
import os


class Train(nn.Module):
    """Supervised training loop for the segmentation model.

    ``self.results`` holds per-epoch metric lists. ``train`` fills only
    ``train_loss`` (mean batch loss per epoch); the other keys are placeholders.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.results = {"train_loss": [], "test_loss": [], 'IoU': [], 'precision': [], 'recall': []}
        self.epochs = self.config.num_epochs
        self.pseudo_labels = []
        self.pseudo_images = []
        self.transform_image = transforms.Compose([transforms.ToPILImage(),
                                                   transforms.Resize(config.input_size),
                                                   transforms.ToTensor(),
                                                   transforms.Normalize(0.5, 0.1),
                                                   transforms.GaussianBlur(kernel_size=3)])
        self.transform_mask = transforms.Compose([transforms.ToPILImage(),
                                                  transforms.Resize(config.input_size),
                                                  transforms.ToTensor()])

    def _build_loss(self):
        """Return the loss named by ``config.loss``.

        'BCE' selects BCEWithLogitsLoss, matching the binary masks and the sigmoid
        used for visualisation. CrossEntropyLoss over a single output channel is
        identically zero, since the softmax of one logit is 1. Any other value
        keeps CrossEntropyLoss.
        """
        if str(getattr(self.config, 'loss', '')).upper() == 'BCE':
            return nn.BCEWithLogitsLoss()
        return CrossEntropyLoss()

    def _build_optimizer(self, model):
        """Return Adam when ``config.optimizer`` is 'adam', otherwise SGD with momentum."""
        if str(getattr(self.config, 'optimizer', 'sgd')).lower() == 'adam':
            return Adam(model.parameters(), lr=self.config.lr)
        return SGD(model.parameters(), lr=self.config.lr, momentum=self.config.momentum)

    @staticmethod
    def _show_sample(prediction, mask):
        """Plot the thresholded first prediction channel and its mask for the first item."""
        probabilities = sigmoid(prediction.detach())[0].cpu().numpy()
        plt.imshow((probabilities[0] > 0.5) * 255)
        plt.show()
        plt.imshow(mask.cpu()[0][0])
        plt.show()

    def train(self, model, train_steps, train_loader, test_steps, test_loader, unlabeled_loader):
        """Train ``model`` for ``config.num_epochs`` epochs and return it.

        ``test_steps``, ``test_loader`` and ``unlabeled_loader`` are accepted for
        interface compatibility but are not used here. A checkpoint is saved every
        epoch, rotating over ``epoch1.pth`` .. ``epoch5.pth`` in
        ``config.path_to_models``.
        """
        loss_function = self._build_loss()
        optimizer = self._build_optimizer(model)
        # Use the configured decay instead of a hard-coded 0.8, which drove the
        # learning rate towards zero long before the configured number of epochs.
        scheduler = ExponentialLR(optimizer, gamma=getattr(self.config, 'scheduler_gamma', 0.8))
        start_time = time.time()

        for epoch in tqdm(range(self.epochs)):
            model.train()
            total_loss = 0.0
            num_batches = 0
            for i, (image, mask) in enumerate(train_loader):
                image = image.to(device=self.config.device)
                mask = mask.to(device=self.config.device)
                prediction = model(image)
                loss = loss_function(prediction, mask)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # .item() yields a Python float; summing tensors kept each batch's graph alive.
                total_loss += loss.item()
                num_batches += 1
                if i % 10 == 0:
                    print(f'Train loss in epoch {epoch} and on batch {i} is: {loss.item()}')
                    self._show_sample(prediction, mask)
            self.results["train_loss"].append(total_loss / max(num_batches, 1))
            torch.save(model, os.path.join(self.config.path_to_models,
                                           ''.join(['epoch', str(epoch % 5 + 1), '.pth'])))

            scheduler.step()
        end_time = time.time()
        print("[INFO] total time taken to train the model: {:.2f}s".format(end_time - start_time))
        return model



### Config

Create a `config.py` module similar to this one:

In [None]:
import torch
import os
import argparse
import wandb
from ast import literal_eval

# Defaults for the command-line options below. Most are strings because they are
# used as argparse defaults and converted to their real types in ``Config``.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
DATASET = 'JHIR_Knee'
MODEL = 'U-Net'
PRETRAINED = 'True'
PATH_TO_DATA = './data'
DATASET_FOLDER = 'JHIR_Hip_Knee_Datasets/Knee'
PATH_TO_MODELS = './models'
PATH_TO_REPORTS = './reports'
PATH_TO_RESULTS = './results'
TEST_SIZE = '0.2'
LEARNING_RATE = '2e-2'
MOMENTUM = '9e-1'
OPTIMIZER = 'sgd'
LOSS = 'BCE'
# The Dataset binarises masks (mask > 0) into a single channel, so the model
# must emit one logit channel; with 9 classes the loss shapes do not match.
NUM_CLASSES = '1'
BATCH_SIZE = '16'
NUM_EPOCHS = '10000'
SCHEDULER_GAMMA = '0.98'
INPUT_SIZE = '(224,224)'
DATA_READ = 'hard'
AUGMENTATION = 'False'
THRESHOLD = 0.5  # float default; Config converts a CLI-supplied value as well


def parse_args(arguments=None):
    """Parse the training script's command-line options.

    All defaults are plain strings except ``--threshold``; type conversion is
    left to ``Config``.

    :param arguments: list of argument strings, or None to read ``sys.argv``.
    :return: ``argparse.Namespace`` with one attribute per option.
    """
    option_table = (
        ("-m", "--model", MODEL, "model"),
        ("-pt", "--pretrained", PRETRAINED, "pretrained"),
        ("-ds", "--dataset_name", DATASET, "dataset name"),
        ("-ptd", "--path_to_data", PATH_TO_DATA, "path to the data folder"),
        ("-dsf", "--dataset_folder", DATASET_FOLDER, "dataset folder"),
        ("-ptds", "--path_to_dataset", os.path.join(PATH_TO_DATA, DATASET_FOLDER),
         "path to the dataset folder"),
        ("-ptm", "--path_to_models", PATH_TO_MODELS, "path to models"),
        ("-ptrp", "--path_to_reports", PATH_TO_REPORTS, "path to reports to be saved"),
        ("-ptrs", "--path_to_results", PATH_TO_RESULTS,
         "path to outputs of the generation model to be saved"),
        ("-ts", "--test_size", TEST_SIZE, "test size for train test split"),
        ("-lr", "--learning_rate", LEARNING_RATE, "enter the learning rate"),
        ("-mo", "--momentum", MOMENTUM, "Enter momentum value"),
        ("-o", "--optimizer", OPTIMIZER, "optimizer"),
        ("-l", "--loss", LOSS, "loss type"),
        ("-nc", "--num_classes", NUM_CLASSES, "number of classes for classification"),
        ("-bc", "--batch_size", BATCH_SIZE, "batch size in the training phase"),
        ("-ne", "--num_epochs", NUM_EPOCHS, "number of epochs in training phase"),
        ("-sg", "--scheduler_gamma", SCHEDULER_GAMMA, "value of gamma for exponential lr scheduler"),
        ("-is", "--input_size", INPUT_SIZE, "the size of the input to the model"),
        ("-dr", "--data_read", DATA_READ, "choose if you want to keep data in hard or memory"),
        ("-ag", "--augmentation", AUGMENTATION, "Does training include augmentation"),
        ("-th", "--threshold", THRESHOLD, "threshold for classifier of the segmentation model"),
    )
    parser = argparse.ArgumentParser(description="Total Knee Replacement Prediction Task")
    for short_flag, long_flag, default_value, help_text in option_table:
        parser.add_argument(short_flag, long_flag, default=default_value, help=help_text)
    return parser.parse_args(arguments)


class Config:
    """Run configuration built from the command line (see ``parse_args``).

    String options are converted to their working types here. Creating an
    instance also creates the models, models/<model>, reports and results
    directories if they are missing.
    """

    def __init__(self):
        args = parse_args()
        self.device = DEVICE
        self.model = args.model
        self.pretrained = (args.pretrained == 'True')
        self.augmentation = (args.augmentation == 'True')
        self.dataset = args.dataset_name
        self.path_to_data = args.path_to_data
        self.dataset_folder = args.dataset_folder
        self.path_to_dataset = args.path_to_dataset
        self.path_to_models = args.path_to_models
        os.makedirs(self.path_to_models, exist_ok=True)
        os.makedirs(os.path.join(self.path_to_models, self.model), exist_ok=True)
        self.reports_path = args.path_to_reports
        os.makedirs(self.reports_path, exist_ok=True)
        self.results_path = args.path_to_results
        os.makedirs(self.results_path, exist_ok=True)
        self.test_size = float(args.test_size)
        self.lr = float(args.learning_rate)
        self.momentum = float(args.momentum)
        self.optimizer = args.optimizer
        self.loss = args.loss
        self.num_classes = int(args.num_classes)
        self.num_epochs = int(args.num_epochs)
        self.batch_size = int(args.batch_size)
        self.scheduler_gamma = float(args.scheduler_gamma)
        self.image_shape = literal_eval(args.input_size)
        self.input_size = literal_eval(args.input_size)
        self.data_read = args.data_read
        # The default is a float but a CLI value arrives as a string; convert it
        # like the other numeric options so comparisons with it work.
        self.threshold = float(args.threshold)
        self.evaluation = False


In [None]:
from config import Config
from model import UNet, UNetPretrained
import utils as utils
from train import Train
import torch
import os

if __name__ == '__main__':
    config = Config()

    train = Train(config)
    model = UNet(config).to(device=config.device)
    # utils.load_data returns (train_loader, test_loader). Unpacking three values
    # raised ValueError; accept an optional third (unlabeled) loader instead.
    loaders = utils.load_data(config)
    train_loader, test_loader = loaders[0], loaders[1]
    unlabeled_loader = loaders[2] if len(loaders) > 2 else None
    train_steps = len(train_loader)
    test_steps = len(test_loader)
    model = train.train(model=model, train_steps=train_steps, train_loader=train_loader,
                        test_steps=test_steps, test_loader=test_loader, unlabeled_loader=unlabeled_loader)