In [1]:
import os
import csv
import time
import random
import sys
import math
import torch
import numpy as np
import pandas as pd
from PIL import Image
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchsummary import summary
from torch.nn.parallel import DataParallel
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import Dataset, DataLoader
import torch.optim.lr_scheduler as lr_scheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [2]:
#######################################   DATALOADER    ###########################################
class MedicalImageSegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_files = sorted(os.listdir(image_dir))
        self.mask_files = sorted(os.listdir(mask_dir))
    def __len__(self):
        return len(self.image_files)
    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        msk_name = self.mask_files[idx]
        img_path = os.path.join(self.image_dir, img_name)
        msk_path = os.path.join(self.mask_dir, msk_name)
        img = np.load(img_path)
        msk = np.load(msk_path)
        img = np.expand_dims(img, axis=0)
        msk = np.expand_dims(msk, axis=0)
        subject_id = img_name.split('_')[0]
        return {'image': torch.from_numpy(img), 'mask': torch.from_numpy(msk)}

test_image_folder = "/ssd_scratch/ATLAS/Training/test/images"
test_mask_folder = "/ssd_scratch/ATLAS/Training/test/masks"
test_dataset = MedicalImageSegmentationDataset(test_image_folder, test_mask_folder)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [3]:
# 2D TransAttUnet  Architecture
'''Original Code: https://github.com/YishuLiu/TransAttUnet/blob/main/model/TransAttUnet.py
'''
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.double_conv(x)
class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)
class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

class MultiConv(nn.Module):
    def __init__(self, in_ch, out_ch, attn=True):
        super(MultiConv, self).__init__()

        self.fuse_attn = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.PReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.PReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=1),
            nn.BatchNorm2d(out_ch),
            nn.Softmax2d() if attn else nn.PReLU()
        )

    def forward(self, x):
        return self.fuse_attn(x)

class PAM_Module(nn.Module):
    """空间注意力模块"""
    def __init__(self, in_dim):
        super(PAM_Module, self).__init__()
        self.chanel_in = in_dim

        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        m_batchsize, C, height, width = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)

        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, height, width)

        out = self.gamma * out + x
        return out

class PositionEmbeddingLearned(nn.Module):
    """
    可学习的位置编码
    """
    def __init__(self, num_pos_feats=256, len_embedding=32):
        super().__init__()
        self.row_embed = nn.Embedding(len_embedding, num_pos_feats)
        self.col_embed = nn.Embedding(len_embedding, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.uniform_(self.row_embed.weight)
        nn.init.uniform_(self.col_embed.weight)

    def forward(self, tensor_list):
        x = tensor_list
        h, w = x.shape[-2:]
        i = torch.arange(w, device=x.device)
        j = torch.arange(h, device=x.device)

        x_emb = self.col_embed(i)
        y_emb = self.row_embed(j)

        pos = torch.cat([
            x_emb.unsqueeze(0).repeat(h, 1, 1),
            y_emb.unsqueeze(1).repeat(1, w, 1),
        ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)

        return pos

class ScaledDotProductAttention(nn.Module):
    '''自注意力模块'''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature ** 0.5
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, x, mask=None):
        m_batchsize, d, height, width = x.size()
        q = x.view(m_batchsize, d, -1)
        k = x.view(m_batchsize, d, -1)
        k = k.permute(0, 2, 1)
        v = x.view(m_batchsize, d, -1)

        attn = torch.matmul(q / self.temperature, k)

        if mask is not None:
            # 给需要mask的地方设置一个负无穷
            attn = attn.masked_fill(mask == 0, -1e9)

        attn = self.dropout(F.softmax(attn, dim=-1))
        output = torch.matmul(attn, v)
        output = output.view(m_batchsize, d, height, width)
        return output

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class UNet_Attention_Transformer_Multiscale(nn.Module):
    def __init__(self, n_channels=1, n_classes=1, bilinear=True):
        super(UNet_Attention_Transformer_Multiscale, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(1024, 256 // factor, bilinear)
        self.up3 = Up(512, 128 // factor, bilinear)
        self.up4 = Up(256, 64, bilinear)
        self.outc = OutConv(128, n_classes)

        '''位置编码'''
        self.pos = PositionEmbeddingLearned(512 // factor)

        '''空间注意力机制'''
        self.pam = PAM_Module(512)

        '''自注意力机制'''
        self.sdpa = ScaledDotProductAttention(512)

        '''残差多尺度连接'''
        self.fuse1 = MultiConv(768, 256)
        self.fuse2 = MultiConv(384, 128)
        self.fuse3 = MultiConv(192, 64)
        self.fuse4 = MultiConv(128, 64)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        '''Setting 1'''
        x5_pam = self.pam(x5)
        '''Setting 2'''
        x5_pos = self.pos(x5)
        x5 = x5 + x5_pos
        x5_sdpa = self.sdpa(x5)
        x6 = self.up1(x5, x4)
        x5_scale = F.interpolate(x5, size=x6.shape[2:], mode='bilinear', align_corners=True)
        x6_cat = torch.cat((x5_scale, x6), 1)
        x7 = self.up2(x6_cat, x3)
        x6_scale = F.interpolate(x6, size=x7.shape[2:], mode='bilinear', align_corners=True)
        x7_cat = torch.cat((x6_scale, x7), 1)
        x8 = self.up3(x7_cat, x2)
        x7_scale = F.interpolate(x7, size=x8.shape[2:], mode='bilinear', align_corners=True)
        x8_cat = torch.cat((x7_scale, x8), 1)
        x9 = self.up4(x8_cat, x1)
        x8_scale = F.interpolate(x8, size=x9.shape[2:], mode='bilinear', align_corners=True)
        x9 = torch.cat((x8_scale, x9), 1)
        logits = self.outc(x9)
        return logits
        
model = UNet_Attention_Transformer_Multiscale(n_channels=1, n_classes=1)
model = DataParallel(model)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.load_state_dict(torch.load('/home/prantik.deb/notebooks/2D_models_demo/best_model_transatt_unet2D.pth', map_location=device))
model.to(device)

DataParallel(
  (module): UNet_Attention_Transformer_Multiscale(
    (inc): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (down1): Down(
      (maxpool_conv): Sequential(
        (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (1): DoubleConv(
          (double_conv): Sequential(
            (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            (3): Conv2d(128, 128, ker

In [4]:
##################################         LOSS FUNCTION      ########################################
class DiceLoss(nn.Module):
    def __init__(self, squared_denom=False):
        super(DiceLoss, self).__init__()
        self.smooth = sys.float_info.epsilon
        self.squared_denom = squared_denom
    def forward(self, x, target):
        x = x.view(-1)
        target = target.view(-1)
        intersection = (x * target).sum()
        numer = 2. * intersection + self.smooth
        factor = 2 if self.squared_denom else 1
        denom = x.pow(factor).sum() + target.pow(factor).sum() + self.smooth
        dice_index = numer / denom
        return 1 - dice_index
class BCEWithLogitsAndDiceLoss(nn.Module):
    def __init__(self, bce_weight=0.1, smooth=1.):
        super(BCEWithLogitsAndDiceLoss, self).__init__()
        self.bce_weight = bce_weight
        self.smooth = smooth
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.dice_loss = DiceLoss()
    def forward(self, inputs, targets):
        bce_loss = self.bce_loss(inputs, targets)
        dice_loss = self.dice_loss(torch.sigmoid(inputs), targets)
        loss = self.bce_weight * bce_loss + (1. - self.bce_weight) * dice_loss
        return loss.mean()
criterion = BCEWithLogitsAndDiceLoss(bce_weight=0.1)

def dice_coefficient(inputs, labels, smooth=1):
    inputs = inputs.view(-1)
    labels = labels.view(-1)
    intersection = (inputs * labels).sum()
    union = inputs.sum() + labels.sum()
    dice = (2. * intersection + smooth) / (union + smooth)
    return dice
# IOU
def IoU(output, labels):
    smooth = 1.
    intersection = torch.logical_and(output, labels).sum()
    union = torch.logical_or(output, labels).sum()
    return (intersection + smooth) / (union + smooth)

In [5]:
# Saving Results
ep3 = []
model.eval()
test_loss = 0.0
test_dice = 0.0
test_iou = 0.0
num_slices = 0
test_precision = 0.0
test_recall = 0.0
if not os.path.exists('/ssd_scratch/ATLAS_2/results_2d/results_att_trans_unet'):
    os.makedirs('/ssd_scratch/ATLAS_2/results_2d/results_att_trans_unet')
with torch.no_grad(): 
    for i, data in enumerate(test_dataloader):
        inputs, labels = data['image'], data['mask']
        inputs = inputs.to('cuda').float()
        labels = labels.to('cuda')
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.item()        
        batch_dice = []
        batch_iou = []
        batch_precision = []
        batch_recall = []
        for j in range(outputs.shape[0]):
            dice = dice_coefficient(torch.sigmoid(outputs[j]), labels[j]).item()
            iou = IoU(outputs[j] > 0.5, labels[j] > 0.5).item()            
            true_positives = torch.sum((outputs[j] > 0.5) & (labels[j] > 0.5)).item()
            false_positives = torch.sum((outputs[j] > 0.5) & (labels[j] <= 0.5)).item()
            false_negatives = torch.sum((outputs[j] <= 0.5) & (labels[j] > 0.5)).item()            
            precision = true_positives / (true_positives + false_positives + 1e-6)
            recall = true_positives / (true_positives + false_negatives + 1e-6)            
            batch_dice.append(dice)
            batch_iou.append(iou)
            batch_precision.append(precision)
            batch_recall.append(recall)            
        test_dice += np.mean(batch_dice)
        test_iou += np.mean(batch_iou)
        test_precision += np.mean(batch_precision)
        test_recall += np.mean(batch_recall)       
        # Save the image, ground truth mask, and predicted mask together for comparison
        for j in range(len(inputs)):
            image = inputs[j].cpu().numpy().transpose((1, 2, 0))
            ground_truth_mask = labels[j].cpu().numpy().squeeze()  
            predicted_mask = torch.sigmoid(outputs[j]).cpu().numpy() > 0.5
            predicted_mask = predicted_mask.squeeze() 
            plt.figure()
            plt.subplot(1, 3, 1)
            plt.imshow(image, cmap='gray')
            plt.title('Image')
            plt.axis('off')
            plt.subplot(1, 3, 2)
            plt.imshow(ground_truth_mask, cmap='gray')
            plt.title('Ground Truth Mask')
            plt.axis('off')
            plt.subplot(1, 3, 3)
            plt.imshow(predicted_mask, cmap='gray')
            plt.title('Predicted Mask')
            plt.axis('off')
            plt.tight_layout()
            plt.savefig('/ssd_scratch/ATLAS_2/results_2d/results_att_trans_unet/result_{}_{}.png'.format(i, j), dpi=100)
            plt.close()            
        #After processing all the batches, the average metrics per slice are computed by dividing the 
        #accumulated metrics (test_loss, test_dice, test_iou, test_precision, test_recall) by the total 
        #number of slices (num_slices)
    # Calculate average metrics for the test dataset
    avg_test_loss = test_loss / len(test_dataloader)
    avg_test_dice = test_dice / len(test_dataloader)
    avg_test_iou = test_iou / len(test_dataloader)
    avg_test_precision = test_precision / len(test_dataloader)
    avg_test_recall = test_recall / len(test_dataloader)
    # Append epoch metrics to the list
    ep3.append([avg_test_loss, avg_test_dice, avg_test_iou, avg_test_precision, avg_test_recall])
    # Print the average metrics
    print('Average Test Dice: {:.4f}'.format(avg_test_dice))
    print('Average Test IoU: {:.4f}'.format(avg_test_iou))
    print('Average Test Precision: {:.4f}'.format(avg_test_precision))
    print('Average Test Recall: {:.4f}'.format(avg_test_recall))

ep_df = pd.DataFrame(np.array(ep3), columns=['Loss', 'Dice', 'IoU', 'Precision', 'Recall'])
ep_df.to_csv('metrics_test_att_trans_unet.csv', index=False) 

Average Test Dice: 0.5794
Average Test IoU: 0.4719
Average Test Precision: 0.6564
Average Test Recall: 0.5827
