# Imports


In [3]:
#Step 2: Imports
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
from tqdm import tqdm
import segmentation_models_pytorch as smp


# Dataset loading

In [2]:
#Step 3: Dataset Class (Brain Tumor Segmentation)
from PIL import Image
import os

class BrainTumorDataset(torch.utils.data.Dataset):
    def __init__(self, images_dir, masks_dir, transform=None):
        self.images = sorted([os.path.join(images_dir, f) for f in os.listdir(images_dir)])
        self.masks  = sorted([os.path.join(masks_dir, f) for f in os.listdir(masks_dir)])
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(self.images[idx]).convert("RGB")
        mask  = Image.open(self.masks[idx]).convert("L")

        if self.transform:
            image = self.transform(image)
            mask  = self.transform(mask)

        return image, mask

In [3]:
#Step 4: Data Preparation
transform = transforms.Compose([
    transforms.Resize((512,512)),
    transforms.ToTensor()
])

# Data loader

In [4]:


train_images = "/home/readinggroup/Desktop/Image_proc_Noman/CSE465_project/dataset/segmentation_task/train/images"
train_masks  = "/home/readinggroup/Desktop/Image_proc_Noman/CSE465_project/dataset/segmentation_task/train/masks"

dataset = BrainTumorDataset(train_images, train_masks, transform=transform)

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")


Train: 6292, Val: 1574


# Swin-Unet architecture

In [4]:
#Step 5: Swin U-Net Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = smp.Unet(
    encoder_name="resnet34", # Swin Transformer
    encoder_weights="imagenet",
    in_channels=3,
    classes=1
).to(device)

# Loss and optimizer

In [6]:

loss_fn = smp.losses.DiceLoss(mode='binary')
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

# Training loop

In [7]:
#Step 6: Training Loop
num_epochs = 20
best_val_loss = float('inf')
arr_loss = []
arr_val_loss = []

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for images, masks in tqdm(train_loader):
        images, masks = images.to(device), masks.to(device)
        preds = model(images)
        loss = loss_fn(preds, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    train_loss /= len(train_loader)
    arr_loss.append(train_loss)
    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device)
            preds = model(images)
            loss = loss_fn(preds, masks)
            val_loss += loss.item()
    val_loss /= len(val_loader)
    arr_val_loss.append(val_loss)
    
    print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "./best_swinunet.pth")

100%|███████████████████████████████████████| 1573/1573 [01:22<00:00, 19.07it/s]


Epoch [1/20] Train Loss: 0.3611 | Val Loss: 0.2464


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.35it/s]


Epoch [2/20] Train Loss: 0.2511 | Val Loss: 0.2265


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.24it/s]


Epoch [3/20] Train Loss: 0.2221 | Val Loss: 0.2589


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.29it/s]


Epoch [4/20] Train Loss: 0.2047 | Val Loss: 0.2834


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.37it/s]


Epoch [5/20] Train Loss: 0.2031 | Val Loss: 0.2335


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.23it/s]


Epoch [6/20] Train Loss: 0.1916 | Val Loss: 0.2295


100%|███████████████████████████████████████| 1573/1573 [01:22<00:00, 19.18it/s]


Epoch [7/20] Train Loss: 0.1861 | Val Loss: 0.2202


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.25it/s]


Epoch [8/20] Train Loss: 0.1759 | Val Loss: 0.2144


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.22it/s]


Epoch [9/20] Train Loss: 0.1680 | Val Loss: 0.1679


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.24it/s]


Epoch [10/20] Train Loss: 0.1570 | Val Loss: 0.1592


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.25it/s]


Epoch [11/20] Train Loss: 0.1572 | Val Loss: 0.1774


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.31it/s]


Epoch [12/20] Train Loss: 0.1587 | Val Loss: 0.1581


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.37it/s]


Epoch [13/20] Train Loss: 0.1429 | Val Loss: 0.1725


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.23it/s]


Epoch [14/20] Train Loss: 0.1420 | Val Loss: 0.1565


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.22it/s]


Epoch [15/20] Train Loss: 0.1536 | Val Loss: 0.1573


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.26it/s]


Epoch [16/20] Train Loss: 0.1457 | Val Loss: 0.1511


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.21it/s]


Epoch [17/20] Train Loss: 0.1271 | Val Loss: 0.1505


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.33it/s]


Epoch [18/20] Train Loss: 0.1356 | Val Loss: 0.1725


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.29it/s]


Epoch [19/20] Train Loss: 0.1288 | Val Loss: 0.1424


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.28it/s]


Epoch [20/20] Train Loss: 0.1218 | Val Loss: 0.1429


In [8]:
print(arr_loss)
print(arr_val_loss)

[0.361053874404735, 0.25110906032870245, 0.222121250758259, 0.20470547463889344, 0.20305545233043226, 0.1915684346173572, 0.18609065154670382, 0.17589696766234322, 0.16803134829090116, 0.15703129582699135, 0.1572187702821003, 0.15868924172608168, 0.14285741315400305, 0.14202350139314782, 0.15360473805904692, 0.14567229111938815, 0.12714307723497026, 0.13559026784072314, 0.12880292172956376, 0.12182841358560706]
[0.24637850090331836, 0.22649635094676526, 0.2589454066934924, 0.28340094010842026, 0.2335007553778324, 0.22945683377648368, 0.22019361118374742, 0.21442013646140318, 0.1679384458791181, 0.15922598260913404, 0.17739438072679006, 0.15812735796579855, 0.17250827333043675, 0.15649576371696394, 0.15734674800471002, 0.15112149896960572, 0.15053618181175388, 0.17245382963098246, 0.14239043284793795, 0.14292416188317508]


In [7]:
from sklearn.metrics import confusion_matrix
from numpy import ndarray
import cv2 as cv
import torch
import numpy as np
import os
from torchvision import transforms
from PIL import Image
def dice(TP, FP, FN):
    return (2*TP)/(FP + (2*TP) + FN)

def iou(TP, FP, FN):
    return TP/(TP + FP + FN)

def ppv(TP, FP):
    return TP/(FP + TP)

def accuracy(TP, TN, FP, FN):
    return (TP + TN)/(TP + TN + FP + FN)

def sensitivity(TP, FN):
    return TP/(TP+FN)

def perf_measure(y_actual, y_pred):
    TP = 0
    FP = 0
    TN = 0
    FN = 0

    for i in range(len(y_pred)):
        if y_actual[i]==y_pred[i]==1:
           TP += 1
        if y_pred[i]==1 and y_actual[i]!=y_pred[i]:
           FP += 1
        if y_actual[i]==y_pred[i]==0:
           TN += 1
        if y_pred[i]==0 and y_actual[i]!=y_pred[i]:
           FN += 1

    return (TN, FP, FN, TP)








image_dir = "/home/readinggroup/Desktop/Image_proc_Noman/CSE465_project/dataset/segmentation_task/test/images/"
mask_dir = "/home/readinggroup/Desktop/Image_proc_Noman/CSE465_project/dataset/segmentation_task/test/masks/"

img_paths = sorted([os.path.join(image_dir, f) for f in os.listdir(image_dir) if os.path.isfile(os.path.join(image_dir, f))])
mask_paths = sorted([os.path.join(mask_dir, f) for f in os.listdir(mask_dir) if os.path.isfile(os.path.join(mask_dir, f))])





# # Empty arrays to keep the evaluation results
arr_dice = []
arr_iou = []
arr_ppv = []
arr_accuracy = []
arr_sensitivity = []
MODEL_PATH = "/home/readinggroup/Desktop/Image_proc_Noman/CSE465_project/model_weights/best_swinunet.pth"
device = "cuda" if torch.cuda.is_available() else "cpu"

# model = UNet(in_channels=1, num_classes=1).to(device)
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device(device)))



for i in range(len(img_paths)):

      SINGLE_IMG_PATH = img_paths[i]
      MASK_PATH = mask_paths[i]

      # model = UNet(in_channels=1, num_classes=1).to(device)
      # model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device(device)))

      transform = transforms.Compose([
          transforms.Resize((512, 512)),
          transforms.ToTensor()])

      img = transform(Image.open(SINGLE_IMG_PATH).convert("RGB")).float().to(device)
      mask = transform(Image.open(MASK_PATH).convert("L")).float().to(device)

      img = img.unsqueeze(0)
      mask = mask.unsqueeze(0)
      pred_mask = model(img)

      # img = img.squeeze(0).cpu().detach()
      # mask = mask.squeeze(0).cpu().detach()
      # # mask = mask.cpu().detach()
      # img = img.permute(1, 2, 0)
      # mask = mask.permute(1, 2, 0)

      # pred_mask = pred_mask.squeeze(0).cpu().detach()
      # pred_mask = pred_mask.permute(1, 2, 0)
      pred_mask[pred_mask <= 0.5]=0
      pred_mask[pred_mask > 0.5]=1


      # # Binarize the predicted mask
      pred_mask = pred_mask.squeeze(0).cpu().detach().numpy() # Remove the batch dimension
      mask = mask.squeeze(0).cpu().detach().numpy()


      mask = mask.flatten()
      pred_mask = pred_mask.flatten()
    
      # print(mask)
      TN, FP, FN, TP = perf_measure(mask, pred_mask)



      print(TN, FP, FN,  TP)

      arr_dice.append(dice(TP, FP, FN))
      arr_iou.append(iou(TP, FP, FN))
      arr_ppv.append(ppv(TP, FP))
      arr_accuracy.append(accuracy(TP, TN, FP, FN))
      arr_sensitivity.append(sensitivity(TP, FN))



255967 4679 1498 0
257364 2208 2572 0
253916 2142 2753 3333
257099 517 2740 1788
257600 2975 1569 0
256514 502 725 4403
254981 6151 1012 0
258171 2160 144 1669
243861 8105 10178 0
258856 2120 1163 5
259323 2384 425 12
254931 5024 2107 82
257489 450 497 3708
251130 1028 1438 8548
253256 5973 179 2736
256178 2978 1611 1377
260155 434 713 842
252146 1499 3645 4854
250252 582 2586 8724
256679 442 815 4208
258225 1017 418 2484
256304 781 942 4117
258105 518 967 2554
256000 328 765 5051
259556 1410 216 962
251371 4561 3934 2278
253431 669 894 7150
255069 1698 460 4917
256156 1126 732 4130
260308 191 461 1184
251557 3120 740 6727
252090 1178 4741 4135
259109 267 566 2202
256796 2739 1013 1596
258057 500 876 2711
259977 328 332 1507
246572 226 13937 1409
254414 698 824 6208
256282 667 1122 4073
254055 378 3588 4123
253412 1790 6781 161
249537 5107 7500 0
246058 5500 1430 9156
258383 574 338 2849
244755 3913 13476 0
245594 3636 3533 9381
249093 9908 3116 27
252560 6923 363 2298
256357 3716 213 

242393 6126 3459 10166
244539 1554 3618 12433
240832 1158 4320 15834
233743 2367 8952 17082
240173 1136 3480 17355
243015 5795 2790 10544
248111 1004 3215 9814
256295 1180 1130 3539
253537 1069 1861 5677
257604 783 707 3050
253494 5350 769 2531
253437 959 1115 6633
247949 1236 3404 9555
248743 1396 1992 10013
232565 1888 16598 11093
253750 2364 1773 4257
257157 1333 1952 1702
246523 499 2874 12248
238407 1516 5338 16883
246234 1520 2253 12137
246271 1460 2216 12197
241269 1434 2491 16950
250427 5746 2213 3758
240622 2423 4315 14784
247020 1925 3237 9962
248794 1672 4869 6809
253028 634 590 7892
259542 1013 205 1384
256973 505 849 3817
254981 492 586 6085
258927 761 349 2107
256258 896 468 4522
258441 502 830 2371
258216 428 834 2666
257526 477 372 3769
260012 327 313 1492
259342 495 447 1860
260450 293 297 1104
258397 335 460 2952
259584 572 357 1631
257736 401 665 3342
259168 887 179 1910
255672 1149 399 4924
256091 601 673 4779
261044 188 267 645
258079 591 369 3105
255933 464 1040 4

260197 357 208 1382
258373 335 510 2926
259895 381 288 1580
260210 616 222 1096
260408 245 378 1113
260749 240 223 932
261050 244 222 628
260223 696 101 1124
258110 529 383 3122
257396 1121 250 3377
259249 790 187 1918
259960 824 162 1198
259679 408 240 1817
261131 300 149 564
261027 333 200 584
261169 312 137 526
261048 317 188 591
260387 291 270 1196
260297 412 275 1160
260916 486 169 573
258355 766 432 2591
257686 1075 310 3073
258011 929 401 2803
259714 589 222 1619
260071 317 305 1451
260352 398 246 1148
259524 543 253 1824
257324 736 1057 3027
258774 493 759 2118
257484 307 1919 2434
260647 193 397 907
257959 1036 513 2636
248524 705 1498 11417
258613 504 466 2561
253992 3415 479 4258
260609 441 293 801
260025 311 291 1517
260776 220 251 897
261142 201 215 586
261347 305 100 392
261125 322 136 561
260976 694 108 366
257425 782 500 3437
257712 1202 458 2772
256981 362 1234 3567
260871 590 142 541
260002 975 138 1029
260393 1246 132 373
261113 175 363 493
259301 520 283 2040
261033

In [9]:
print("Average dice: ", np.mean(arr_dice))
print("Average iou: ", np.mean(arr_iou))
print("Average ppv: ", np.mean(arr_ppv))
print("Average accuracy: ", np.mean(arr_accuracy))
print("Average sensitivity: ", np.mean(arr_sensitivity))

Average dice:  0.7237057968239025
Average iou:  0.5998337245761356
Average ppv:  0.723284043482044
Average accuracy:  0.9907876480457395
Average sensitivity:  0.7600248441450289
