# Imports


In [3]:

import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch

# Choose GPU if available, else CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


# Dataset loading

In [2]:

class BrainTumorDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.images = sorted(os.listdir(image_dir))
        self.masks = sorted(os.listdir(mask_dir))
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.masks[idx])

        # Load grayscale image and mask
        image = Image.open(img_path).convert("L")
        mask = Image.open(mask_path).convert("L")
        #img = img.resize((512, 512), Image.BILINEAR)
        #mask = mask.resize((512, 512), Image.NEAREST)
        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        mask = torch.where(mask > 0, 1.0, 0.0)
        return image, mask



In [3]:
# 4. Transforms
# -----------------------------
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),  # This will create 1-channel tensor
])


# Data loader

In [4]:

# Change the train images and masks
train_images = "/home/readinggroup/Desktop/Image_proc_Noman/CSE465_project/dataset/segmentation_task/train/images"
train_masks  = "/home/readinggroup/Desktop/Image_proc_Noman/CSE465_project/dataset/segmentation_task/train/masks"

full_dataset = BrainTumorDataset(train_images, train_masks, transform=transform)

# Split: 80% train, 20% validation
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

print(f"Train samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")

Train samples: 6292, Validation samples: 1574


# UNet++ architecture

In [4]:
# 5. U-Net++ Model (simplified)
# -----------------------------
import torch.nn as nn

class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.conv(x)

class UNetPlusPlus(nn.Module):
    def __init__(self, in_ch=1, out_ch=1):
        super().__init__()
        filters = [64, 128, 256, 512]

        # Encoder
        self.conv0_0 = ConvBlock(in_ch, filters[0])
        self.pool0 = nn.MaxPool2d(2)
        self.conv1_0 = ConvBlock(filters[0], filters[1])
        self.pool1 = nn.MaxPool2d(2)
        self.conv2_0 = ConvBlock(filters[1], filters[2])
        self.pool2 = nn.MaxPool2d(2)
        self.conv3_0 = ConvBlock(filters[2], filters[3])

        # Decoder with nested connections
        self.up2_1 = nn.ConvTranspose2d(filters[3], filters[2], 2, stride=2)
        self.conv2_1 = ConvBlock(filters[2]*2, filters[2])

        self.up1_2 = nn.ConvTranspose2d(filters[2], filters[1], 2, stride=2)
        self.conv1_2 = ConvBlock(filters[1]*2, filters[1])

        self.up0_3 = nn.ConvTranspose2d(filters[1], filters[0], 2, stride=2)
        self.conv0_3 = ConvBlock(filters[0]*2, filters[0])

        # Final output
        self.final = nn.Conv2d(filters[0], out_ch, 1)

    def forward(self, x):
        # Encoder
        x0_0 = self.conv0_0(x)
        x1_0 = self.conv1_0(self.pool0(x0_0))
        x2_0 = self.conv2_0(self.pool1(x1_0))
        x3_0 = self.conv3_0(self.pool2(x2_0))

        # Decoder
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up2_1(x3_0)], dim=1))
        x1_2 = self.conv1_2(torch.cat([x1_0, self.up1_2(x2_1)], dim=1))
        x0_3 = self.conv0_3(torch.cat([x0_0, self.up0_3(x1_2)], dim=1))

        out = torch.sigmoid(self.final(x0_3))
        return out

model = UNetPlusPlus().to(device)

# Loss and optimizer

In [6]:

def dice_loss(pred, target, smooth=1.):
    pred = pred.contiguous()
    target = target.contiguous()
    intersection = (pred * target).sum(dim=2).sum(dim=2)
    loss = 1 - ((2. * intersection + smooth) /
                (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth))
    return loss.mean()

optimizer = torch.optim.Adam(model.parameters(), lr= 3e-4)

# Training loop

In [7]:

num_epochs = 20
arr_loss = []
arr_val_loss = []
best_val_loss = float('inf')

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for images, masks in tqdm(train_loader):
        images, masks = images.to(device), masks.to(device)
        preds = model(images)
        loss = dice_loss(preds, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    train_loss /= len(train_loader)
    arr_loss.append(train_loss)
    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device)
            preds = model(images)
            loss = dice_loss(preds, masks)
            val_loss += loss.item()
    val_loss /= len(val_loader)
    arr_val_loss.append(val_loss)

    print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "./best_unetpp.pth")

100%|█████████████████████████████████████████| 787/787 [01:19<00:00,  9.84it/s]


Epoch [1/20] Train Loss: 0.6184 | Val Loss: 0.4427


100%|█████████████████████████████████████████| 787/787 [01:18<00:00,  9.99it/s]


Epoch [2/20] Train Loss: 0.3481 | Val Loss: 0.3529


100%|█████████████████████████████████████████| 787/787 [01:18<00:00, 10.03it/s]


Epoch [3/20] Train Loss: 0.3125 | Val Loss: 0.2882


100%|█████████████████████████████████████████| 787/787 [01:18<00:00, 10.03it/s]


Epoch [4/20] Train Loss: 0.2828 | Val Loss: 0.3059


100%|█████████████████████████████████████████| 787/787 [01:18<00:00, 10.02it/s]


Epoch [5/20] Train Loss: 0.2666 | Val Loss: 0.2742


100%|█████████████████████████████████████████| 787/787 [01:18<00:00, 10.06it/s]


Epoch [6/20] Train Loss: 0.2524 | Val Loss: 0.2601


100%|█████████████████████████████████████████| 787/787 [01:18<00:00, 10.00it/s]


Epoch [7/20] Train Loss: 0.2450 | Val Loss: 0.2410


100%|█████████████████████████████████████████| 787/787 [01:18<00:00,  9.98it/s]


Epoch [8/20] Train Loss: 0.2332 | Val Loss: 0.2660


100%|█████████████████████████████████████████| 787/787 [01:19<00:00,  9.96it/s]


Epoch [9/20] Train Loss: 0.2226 | Val Loss: 0.2499


100%|█████████████████████████████████████████| 787/787 [01:18<00:00,  9.99it/s]


Epoch [10/20] Train Loss: 0.2093 | Val Loss: 0.2217


100%|█████████████████████████████████████████| 787/787 [01:18<00:00, 10.04it/s]


Epoch [11/20] Train Loss: 0.2160 | Val Loss: 0.2103


100%|█████████████████████████████████████████| 787/787 [01:19<00:00,  9.93it/s]


Epoch [12/20] Train Loss: 0.1983 | Val Loss: 0.2135


100%|█████████████████████████████████████████| 787/787 [01:19<00:00,  9.93it/s]


Epoch [13/20] Train Loss: 0.1919 | Val Loss: 0.2514


100%|█████████████████████████████████████████| 787/787 [01:18<00:00, 10.00it/s]


Epoch [14/20] Train Loss: 0.1864 | Val Loss: 0.2109


100%|█████████████████████████████████████████| 787/787 [01:18<00:00, 10.01it/s]


Epoch [15/20] Train Loss: 0.1823 | Val Loss: 0.2077


100%|█████████████████████████████████████████| 787/787 [01:18<00:00, 10.00it/s]


Epoch [16/20] Train Loss: 0.1768 | Val Loss: 0.1987


100%|█████████████████████████████████████████| 787/787 [01:18<00:00, 10.01it/s]


Epoch [17/20] Train Loss: 0.1751 | Val Loss: 0.1933


100%|█████████████████████████████████████████| 787/787 [01:18<00:00, 10.04it/s]


Epoch [18/20] Train Loss: 0.1729 | Val Loss: 0.1964


100%|█████████████████████████████████████████| 787/787 [01:18<00:00, 10.03it/s]


Epoch [19/20] Train Loss: 0.1637 | Val Loss: 0.1842


100%|█████████████████████████████████████████| 787/787 [01:18<00:00, 10.03it/s]


Epoch [20/20] Train Loss: 0.1583 | Val Loss: 0.1819


In [8]:
print(arr_loss)
print(arr_val_loss)

[0.618406601013495, 0.3481415968667598, 0.3124854265034123, 0.2827809602495524, 0.26660745069289904, 0.25242166413323247, 0.24497065041805918, 0.2332423186305972, 0.22263491756769116, 0.20926080510645081, 0.21596734447279303, 0.1983425939287588, 0.19188927698127803, 0.18637375375480786, 0.18228953460461, 0.17684716698175162, 0.1750772531218874, 0.1729141300996195, 0.16368162277920273, 0.1583376938744603]
[0.4427055951756269, 0.35286171337220873, 0.28817695526756004, 0.30594591365247814, 0.2741560292274214, 0.26008611161091605, 0.2410309955023872, 0.26604929146579076, 0.24994958158071875, 0.22167358373476165, 0.2103497960845831, 0.2134701632303635, 0.25144485996913185, 0.21085874331632848, 0.20774140186267456, 0.1987439627347864, 0.1932508124388414, 0.19643148832817367, 0.1842288132473297, 0.18192721835247755]


In [7]:
from sklearn.metrics import confusion_matrix
from numpy import ndarray
import cv2 as cv
import torch
import numpy as np
import os
from torchvision import transforms
from PIL import Image
def dice(TP, FP, FN):
    return (2*TP)/(FP + (2*TP) + FN)

def iou(TP, FP, FN):
    return TP/(TP + FP + FN)

def ppv(TP, FP):
    return TP/(FP + TP)

def accuracy(TP, TN, FP, FN):
    return (TP + TN)/(TP + TN + FP + FN)

def sensitivity(TP, FN):
    return TP/(TP+FN)

def perf_measure(y_actual, y_pred):
    TP = 0
    FP = 0
    TN = 0
    FN = 0

    for i in range(len(y_pred)):
        if y_actual[i]==y_pred[i]==1:
           TP += 1
        if y_pred[i]==1 and y_actual[i]!=y_pred[i]:
           FP += 1
        if y_actual[i]==y_pred[i]==0:
           TN += 1
        if y_pred[i]==0 and y_actual[i]!=y_pred[i]:
           FN += 1

    return (TN, FP, FN, TP)








image_dir = "/home/readinggroup/Desktop/Image_proc_Noman/CSE465_project/dataset/segmentation_task/test/images/"
mask_dir = "/home/readinggroup/Desktop/Image_proc_Noman/CSE465_project/dataset/segmentation_task/test/masks/"

img_paths = sorted([os.path.join(image_dir, f) for f in os.listdir(image_dir) if os.path.isfile(os.path.join(image_dir, f))])
mask_paths = sorted([os.path.join(mask_dir, f) for f in os.listdir(mask_dir) if os.path.isfile(os.path.join(mask_dir, f))])





# # Empty arrays to keep the evaluation results
arr_dice = []
arr_iou = []
arr_ppv = []
arr_accuracy = []
arr_sensitivity = []
MODEL_PATH = "/home/readinggroup/Desktop/Image_proc_Noman/CSE465_project/model_weights/best_unetplusplus.pth"
device = "cuda" if torch.cuda.is_available() else "cpu"

model = UNetPlusPlus().to(device)
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device(device)))



for i in range(len(img_paths)):

      SINGLE_IMG_PATH = img_paths[i]
      MASK_PATH = mask_paths[i]

      # model = UNet(in_channels=1, num_classes=1).to(device)
      # model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device(device)))

      transform = transforms.Compose([
          transforms.Resize((512, 512)),
          transforms.ToTensor()])

      img = transform(Image.open(SINGLE_IMG_PATH).convert("L")).float().to(device)
      mask = transform(Image.open(MASK_PATH).convert("L")).float().to(device)

      img = img.unsqueeze(0)
      mask = mask.unsqueeze(0)
      pred_mask = model(img)

      # img = img.squeeze(0).cpu().detach()
      # mask = mask.squeeze(0).cpu().detach()
      # # mask = mask.cpu().detach()
      # img = img.permute(1, 2, 0)
      # mask = mask.permute(1, 2, 0)

      # pred_mask = pred_mask.squeeze(0).cpu().detach()
      # pred_mask = pred_mask.permute(1, 2, 0)
      pred_mask[pred_mask <= 0.5]=0
      pred_mask[pred_mask > 0.5]=1


      # # Binarize the predicted mask
      pred_mask = pred_mask.squeeze(0).cpu().detach().numpy() # Remove the batch dimension
      mask = mask.squeeze(0).cpu().detach().numpy()


      mask = mask.flatten()
      pred_mask = pred_mask.flatten()
    
      # print(mask)
      TN, FP, FN, TP = perf_measure(mask, pred_mask)



      print(TN, FP, FN, TP)

      arr_dice.append(dice(TP, FP, FN))
      arr_iou.append(iou(TP, FP, FN))
      arr_ppv.append(ppv(TP, FP))
      arr_accuracy.append(accuracy(TP, TN, FP, FN))
      arr_sensitivity.append(sensitivity(TP, FN))



254975 5780 1181 208
254594 5114 2013 423
252207 3506 4809 1622
252134 5577 2817 1616
252880 7723 1501 40
248953 7812 4004 1375
256083 5224 425 412
256572 3707 229 1636
247781 4185 10178 0
255835 5418 423 468
256186 5667 118 173
254952 5136 1719 337
254878 3167 1389 2710
246522 5495 9451 676
254093 4373 2721 957
249929 8900 2691 624
250738 9814 832 760
249928 3426 8333 457
245714 4937 9124 2369
253095 4009 1649 3391
254997 3990 1122 2035
252974 3998 1874 3298
254268 4378 1590 1908
254574 1691 3082 2797
253197 7782 232 933
250750 5134 5214 1046
246332 7392 8230 190
249454 6873 4348 1469
253322 3670 1847 3305
256198 4344 429 1173
250595 3445 6013 2091
247697 5475 6179 2793
254276 5194 428 2246
255800 3309 2018 1017
254033 4556 1525 2030
258151 2207 245 1541
244624 2290 12940 2290
252854 2146 3983 3161
250567 6051 4841 685
248397 5863 7195 689
247110 8361 5251 1422
249282 5631 6747 484
245479 5141 10761 763
254817 4147 400 2780
244887 4024 11143 2090
244801 4079 10434 2830
252056 6869 296

258239 1954 256 1695
257630 1828 260 2426
252878 2736 2093 4437
258077 1297 557 2213
251755 2442 2906 5041
253196 4340 1020 3588
256082 2434 276 3352
255004 5502 254 1384
256682 1950 446 3066
246786 3187 7632 4539
247250 5694 6651 2549
244892 3577 12295 1380
248145 5135 8861 3
243877 4225 11994 2048
242361 3628 9457 6698
238976 3060 14302 5806
230482 4458 23553 3651
239718 1606 14563 6257
243942 4544 11415 2243
245342 3416 10013 3373
254261 2764 2690 2429
250667 4044 2406 5027
255811 2862 411 3060
253816 5167 608 2553
250511 4040 1839 5754
245358 3475 10040 3271
246684 3907 2044 9509
230097 3229 26144 2674
252243 3258 4453 2190
253340 5082 2127 1595
245983 1029 9030 6102
234299 5407 15044 7394
241864 5146 11022 4112
241895 5062 10991 4196
240686 1981 13613 5864
248045 5802 8055 242
234905 6111 20712 416
242329 5590 13080 1145
244116 5256 11394 1378
252518 1083 4263 4280
257457 3077 426 1184
256825 807 503 4009
254081 1634 281 6148
257747 1982 345 2070
255876 1398 352 4518
255438 3388 1

258321 2407 306 1110
252191 5573 3385 995
255245 3268 2557 1074
257964 2926 343 911
257291 2848 1513 492
255905 4046 1058 1135
257140 4392 232 380
256759 3969 509 907
253240 5327 556 3021
255731 2361 2525 1527
251437 5455 4504 748
258811 837 872 1624
259069 2056 146 873
254649 2083 2046 3366
257245 3581 602 716
254987 4714 1735 708
256548 3806 1666 124
254465 6978 566 135
254413 5750 1964 17
254716 4383 1859 1186
258486 1879 488 1291
254061 6544 1158 381
254663 4295 3186 0
255086 5575 1467 16
257651 3805 684 4
257749 1130 1041 2224
255192 5240 1712 0
254866 5611 851 816
257433 3459 390 862
254992 5417 973 762
257092 1631 859 2562
256861 3245 1373 665
259229 1601 244 1070
258211 2521 295 1117
259388 1680 125 951
257118 4228 158 640
254467 6144 1458 75
253370 5295 660 2819
254564 3588 2737 1255
253376 6297 2471 0
257112 3663 424 945
257222 2856 424 1642
258930 2561 87 566
258008 3411 136 589
258628 2802 317 397
258272 2911 961 0
257141 3319 1684 0
254344 6054 1746 0
253812 7321 1011 0
25

In [8]:
print("Average dice: ", np.mean(arr_dice))
print("Average iou: ", np.mean(arr_iou))
print("Average ppv: ", np.mean(arr_ppv))
print("Average accuracy: ", np.mean(arr_accuracy))
print("Average sensitivity: ", np.mean(arr_sensitivity))

Average dice:  0.4044408815246971
Average iou:  0.28023155874891237
Average ppv:  0.3843992194679883
Average accuracy:  0.97532041239184
Average sensitivity:  0.5288162957616644
