In [1]:
import os
import csv
import time
import random
import sys
import math
import torch
import numpy as np
import pandas as pd
from PIL import Image
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchsummary import summary
from torch.nn.parallel import DataParallel
from torch.utils.data import Dataset, DataLoader
import torch.optim.lr_scheduler as lr_scheduler
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [2]:
#######################################   DATALOADER    ###########################################
class MedicalImageSegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_files = sorted(os.listdir(image_dir))
        self.mask_files = sorted(os.listdir(mask_dir))
    def __len__(self):
        return len(self.image_files)
    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        msk_name = self.mask_files[idx]
        img_path = os.path.join(self.image_dir, img_name)
        msk_path = os.path.join(self.mask_dir, msk_name)
        img = np.load(img_path)
        msk = np.load(msk_path)
        img = np.expand_dims(img, axis=0)
        msk = np.expand_dims(msk, axis=0)
        subject_id = img_name.split('_')[0]
        return {'image': torch.from_numpy(img), 'mask': torch.from_numpy(msk)}

test_image_folder = "/ssd_scratch/ATLAS/Training/test/images"
test_mask_folder = "/ssd_scratch/ATLAS/Training/test/masks"
test_dataset = MedicalImageSegmentationDataset(test_image_folder, test_mask_folder)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [3]:
######################################     MODEL 2D Attention U-NET         ###########################
class conv_block(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(conv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.2), 
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.2),  
        )
    def forward(self, x):
        x = self.conv(x)
        return x
class up_conv(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        x = self.up(x)
        return x
class Attention_block(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)
    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi
class AttU_Net(nn.Module):
    def __init__(self, img_ch=1, output_ch=1, num_classes=1):
        super(AttU_Net, self).__init__()
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)
        self.Conv2 = conv_block(ch_in=64, ch_out=128)
        self.Conv3 = conv_block(ch_in=128, ch_out=256)
        self.Conv4 = conv_block(ch_in=256, ch_out=512)
        self.Conv5 = conv_block(ch_in=512, ch_out=1024)

        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)

        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
        
        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
        
        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)
        self.final_conv = nn.Conv2d(output_ch, num_classes, kernel_size=1, stride=1, padding=0)
    def forward(self, x):
        # encoding path
        x1 = self.Conv1(x)
        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)
        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)
        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)
        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)
        
        d5 = self.Up5(x5)
        x4 = self.Att5(g=d5, x=x4)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4, x=x3)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3, x=x2)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2, x=x1)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)
        return d1
model = AttU_Net(img_ch=1, output_ch=1, num_classes=1)
model = DataParallel(model)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.load_state_dict(torch.load('/home/prantik.deb/notebooks/2D_models_demo/best_model_att_unet2D.pth', map_location=device))
model.to(device)

DataParallel(
  (module): AttU_Net(
    (Maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (Conv1): conv_block(
      (conv): Sequential(
        (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Dropout(p=0.2, inplace=False)
        (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): ReLU(inplace=True)
        (7): Dropout(p=0.2, inplace=False)
      )
    )
    (Conv2): conv_block(
      (conv): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Dropout(p=0.2, inplace=False)
        (4): Con

In [4]:
##################################         LOSS FUNCTION      ########################################
class DiceLoss(nn.Module):
    def __init__(self, squared_denom=False):
        super(DiceLoss, self).__init__()
        self.smooth = sys.float_info.epsilon
        self.squared_denom = squared_denom
    def forward(self, x, target):
        x = x.view(-1)
        target = target.view(-1)
        intersection = (x * target).sum()
        numer = 2. * intersection + self.smooth
        factor = 2 if self.squared_denom else 1
        denom = x.pow(factor).sum() + target.pow(factor).sum() + self.smooth
        dice_index = numer / denom
        return 1 - dice_index
class BCEWithLogitsAndDiceLoss(nn.Module):
    def __init__(self, bce_weight=0.1, smooth=1.):
        super(BCEWithLogitsAndDiceLoss, self).__init__()
        self.bce_weight = bce_weight
        self.smooth = smooth
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.dice_loss = DiceLoss()
    def forward(self, inputs, targets):
        bce_loss = self.bce_loss(inputs, targets)
        dice_loss = self.dice_loss(torch.sigmoid(inputs), targets)
        loss = self.bce_weight * bce_loss + (1. - self.bce_weight) * dice_loss
        return loss.mean()
criterion = BCEWithLogitsAndDiceLoss(bce_weight=0.1)

def dice_coefficient(inputs, labels, smooth=1):
    inputs = inputs.view(-1)
    labels = labels.view(-1)
    intersection = (inputs * labels).sum()
    union = inputs.sum() + labels.sum()
    dice = (2. * intersection + smooth) / (union + smooth)
    return dice
# IOU
def IoU(output, labels):
    smooth = 1.
    intersection = torch.logical_and(output, labels).sum()
    union = torch.logical_or(output, labels).sum()
    return (intersection + smooth) / (union + smooth)

In [5]:
# Saving Results
ep3 = []
model.eval()
test_loss = 0.0
test_dice = 0.0
test_iou = 0.0
num_slices = 0
test_precision = 0.0
test_recall = 0.0
if not os.path.exists('/ssd_scratch/ATLAS_2/results_2d/results_attunet'):
    os.makedirs('/ssd_scratch/ATLAS_2/results_2d/results_attunet')
with torch.no_grad(): 
    for i, data in enumerate(test_dataloader):
        inputs, labels = data['image'], data['mask']
        inputs = inputs.to('cuda').float()
        labels = labels.to('cuda')
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.item()        
        batch_dice = []
        batch_iou = []
        batch_precision = []
        batch_recall = []
        for j in range(outputs.shape[0]):
            dice = dice_coefficient(torch.sigmoid(outputs[j]), labels[j]).item()
            iou = IoU(outputs[j] > 0.5, labels[j] > 0.5).item()            
            true_positives = torch.sum((outputs[j] > 0.5) & (labels[j] > 0.5)).item()
            false_positives = torch.sum((outputs[j] > 0.5) & (labels[j] <= 0.5)).item()
            false_negatives = torch.sum((outputs[j] <= 0.5) & (labels[j] > 0.5)).item()            
            precision = true_positives / (true_positives + false_positives + 1e-6)
            recall = true_positives / (true_positives + false_negatives + 1e-6)            
            batch_dice.append(dice)
            batch_iou.append(iou)
            batch_precision.append(precision)
            batch_recall.append(recall)            
        test_dice += np.mean(batch_dice)
        test_iou += np.mean(batch_iou)
        test_precision += np.mean(batch_precision)
        test_recall += np.mean(batch_recall)       
        # Save the image, ground truth mask, and predicted mask together for comparison
        for j in range(len(inputs)):
            image = inputs[j].cpu().numpy().transpose((1, 2, 0))
            ground_truth_mask = labels[j].cpu().numpy().squeeze()  
            predicted_mask = torch.sigmoid(outputs[j]).cpu().numpy() > 0.5
            predicted_mask = predicted_mask.squeeze() 
            plt.figure()
            plt.subplot(1, 3, 1)
            plt.imshow(image, cmap='gray')
            plt.title('Image')
            plt.axis('off')
            plt.subplot(1, 3, 2)
            plt.imshow(ground_truth_mask, cmap='gray')
            plt.title('Ground Truth Mask')
            plt.axis('off')
            plt.subplot(1, 3, 3)
            plt.imshow(predicted_mask, cmap='gray')
            plt.title('Predicted Mask')
            plt.axis('off')
            plt.tight_layout()
            plt.savefig('/ssd_scratch/ATLAS_2/results_2d/results_attunet/result_{}_{}.png'.format(i, j), dpi=100)
            plt.close()            
        #After processing all the batches, the average metrics per slice are computed by dividing the 
        #accumulated metrics (test_loss, test_dice, test_iou, test_precision, test_recall) by the total 
        #number of slices (num_slices)
    # Calculate average metrics for the test dataset
    avg_test_loss = test_loss / len(test_dataloader)
    avg_test_dice = test_dice / len(test_dataloader)
    avg_test_iou = test_iou / len(test_dataloader)
    avg_test_precision = test_precision / len(test_dataloader)
    avg_test_recall = test_recall / len(test_dataloader)
    # Append epoch metrics to the list
    ep3.append([avg_test_loss, avg_test_dice, avg_test_iou, avg_test_precision, avg_test_recall])
    # Print the average metrics
    print('Average Test Dice: {:.4f}'.format(avg_test_dice))
    print('Average Test IoU: {:.4f}'.format(avg_test_iou))
    print('Average Test Precision: {:.4f}'.format(avg_test_precision))
    print('Average Test Recall: {:.4f}'.format(avg_test_recall))

ep_df = pd.DataFrame(np.array(ep3), columns=['Loss', 'Dice', 'IoU', 'Precision', 'Recall'])
ep_df.to_csv('metrics_test_attunet.csv', index=False) 

Average Test Dice: 0.4680
Average Test IoU: 0.3784
Average Test Precision: 0.6117
Average Test Recall: 0.4291
