# Imports


In [1]:

import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch

# Choose GPU if available, else CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


# Dataset loading

In [2]:

class BrainTumorDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.images = sorted(os.listdir(image_dir))
        self.masks = sorted(os.listdir(mask_dir))
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.masks[idx])

        # Load grayscale image and mask
        image = Image.open(img_path).convert("L")
        mask = Image.open(mask_path).convert("L")
        #img = img.resize((512, 512), Image.BILINEAR)
        #mask = mask.resize((512, 512), Image.NEAREST)
        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        mask = torch.where(mask > 0, 1.0, 0.0)
        return image, mask



In [3]:
# 4. Transforms
# -----------------------------
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),  # This will create 1-channel tensor
])


# Data loader

In [4]:

# Change the train images and masks
train_images = "/home/readinggroup/Desktop/Image_proc_Noman/CSE465_project/dataset/segmentation_task/train/images"
train_masks  = "/home/readinggroup/Desktop/Image_proc_Noman/CSE465_project/dataset/segmentation_task/train/masks"

full_dataset = BrainTumorDataset(train_images, train_masks, transform=transform)

# Split: 80% train, 20% validation
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

print(f"Train samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")

Train samples: 6292, Validation samples: 1574


# UNet++ architecture

In [5]:
# 5. U-Net++ Model (simplified)
# -----------------------------
import torch.nn as nn

class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.conv(x)

class UNetPlusPlus(nn.Module):
    def __init__(self, in_ch=1, out_ch=1):
        super().__init__()
        filters = [64, 128, 256, 512]

        # Encoder
        self.conv0_0 = ConvBlock(in_ch, filters[0])
        self.pool0 = nn.MaxPool2d(2)
        self.conv1_0 = ConvBlock(filters[0], filters[1])
        self.pool1 = nn.MaxPool2d(2)
        self.conv2_0 = ConvBlock(filters[1], filters[2])
        self.pool2 = nn.MaxPool2d(2)
        self.conv3_0 = ConvBlock(filters[2], filters[3])

        # Decoder with nested connections
        self.up2_1 = nn.ConvTranspose2d(filters[3], filters[2], 2, stride=2)
        self.conv2_1 = ConvBlock(filters[2]*2, filters[2])

        self.up1_2 = nn.ConvTranspose2d(filters[2], filters[1], 2, stride=2)
        self.conv1_2 = ConvBlock(filters[1]*2, filters[1])

        self.up0_3 = nn.ConvTranspose2d(filters[1], filters[0], 2, stride=2)
        self.conv0_3 = ConvBlock(filters[0]*2, filters[0])

        # Final output
        self.final = nn.Conv2d(filters[0], out_ch, 1)

    def forward(self, x):
        # Encoder
        x0_0 = self.conv0_0(x)
        x1_0 = self.conv1_0(self.pool0(x0_0))
        x2_0 = self.conv2_0(self.pool1(x1_0))
        x3_0 = self.conv3_0(self.pool2(x2_0))

        # Decoder
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up2_1(x3_0)], dim=1))
        x1_2 = self.conv1_2(torch.cat([x1_0, self.up1_2(x2_1)], dim=1))
        x0_3 = self.conv0_3(torch.cat([x0_0, self.up0_3(x1_2)], dim=1))

        out = torch.sigmoid(self.final(x0_3))
        return out

model = UNetPlusPlus().to(device)

# Loss and optimizer

In [6]:

def dice_loss(pred, target, smooth=1.):
    pred = pred.contiguous()
    target = target.contiguous()
    intersection = (pred * target).sum(dim=2).sum(dim=2)
    loss = 1 - ((2. * intersection + smooth) /
                (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth))
    return loss.mean()

optimizer = torch.optim.Adam(model.parameters(), lr= 3e-4)

# Training loop

In [7]:

num_epochs = 20
arr_loss = []
arr_val_loss = []
best_val_loss = float('inf')

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for images, masks in tqdm(train_loader):
        images, masks = images.to(device), masks.to(device)
        preds = model(images)
        loss = dice_loss(preds, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    train_loss /= len(train_loader)
    arr_loss.append(train_loss)
    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device)
            preds = model(images)
            loss = dice_loss(preds, masks)
            val_loss += loss.item()
    val_loss /= len(val_loader)
    arr_val_loss.append(val_loss)

    print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "./best_unetpp.pth")

100%|█████████████████████████████████████████| 787/787 [01:19<00:00,  9.84it/s]


Epoch [1/20] Train Loss: 0.6184 | Val Loss: 0.4427


100%|█████████████████████████████████████████| 787/787 [01:18<00:00,  9.99it/s]


Epoch [2/20] Train Loss: 0.3481 | Val Loss: 0.3529


100%|█████████████████████████████████████████| 787/787 [01:18<00:00, 10.03it/s]


Epoch [3/20] Train Loss: 0.3125 | Val Loss: 0.2882


100%|█████████████████████████████████████████| 787/787 [01:18<00:00, 10.03it/s]


Epoch [4/20] Train Loss: 0.2828 | Val Loss: 0.3059


100%|█████████████████████████████████████████| 787/787 [01:18<00:00, 10.02it/s]


Epoch [5/20] Train Loss: 0.2666 | Val Loss: 0.2742


100%|█████████████████████████████████████████| 787/787 [01:18<00:00, 10.06it/s]


Epoch [6/20] Train Loss: 0.2524 | Val Loss: 0.2601


100%|█████████████████████████████████████████| 787/787 [01:18<00:00, 10.00it/s]


Epoch [7/20] Train Loss: 0.2450 | Val Loss: 0.2410


100%|█████████████████████████████████████████| 787/787 [01:18<00:00,  9.98it/s]


Epoch [8/20] Train Loss: 0.2332 | Val Loss: 0.2660


100%|█████████████████████████████████████████| 787/787 [01:19<00:00,  9.96it/s]


Epoch [9/20] Train Loss: 0.2226 | Val Loss: 0.2499


100%|█████████████████████████████████████████| 787/787 [01:18<00:00,  9.99it/s]


Epoch [10/20] Train Loss: 0.2093 | Val Loss: 0.2217


100%|█████████████████████████████████████████| 787/787 [01:18<00:00, 10.04it/s]


Epoch [11/20] Train Loss: 0.2160 | Val Loss: 0.2103


100%|█████████████████████████████████████████| 787/787 [01:19<00:00,  9.93it/s]


Epoch [12/20] Train Loss: 0.1983 | Val Loss: 0.2135


100%|█████████████████████████████████████████| 787/787 [01:19<00:00,  9.93it/s]


Epoch [13/20] Train Loss: 0.1919 | Val Loss: 0.2514


100%|█████████████████████████████████████████| 787/787 [01:18<00:00, 10.00it/s]


Epoch [14/20] Train Loss: 0.1864 | Val Loss: 0.2109


100%|█████████████████████████████████████████| 787/787 [01:18<00:00, 10.01it/s]


Epoch [15/20] Train Loss: 0.1823 | Val Loss: 0.2077


100%|█████████████████████████████████████████| 787/787 [01:18<00:00, 10.00it/s]


Epoch [16/20] Train Loss: 0.1768 | Val Loss: 0.1987


100%|█████████████████████████████████████████| 787/787 [01:18<00:00, 10.01it/s]


Epoch [17/20] Train Loss: 0.1751 | Val Loss: 0.1933


100%|█████████████████████████████████████████| 787/787 [01:18<00:00, 10.04it/s]


Epoch [18/20] Train Loss: 0.1729 | Val Loss: 0.1964


100%|█████████████████████████████████████████| 787/787 [01:18<00:00, 10.03it/s]


Epoch [19/20] Train Loss: 0.1637 | Val Loss: 0.1842


100%|█████████████████████████████████████████| 787/787 [01:18<00:00, 10.03it/s]


Epoch [20/20] Train Loss: 0.1583 | Val Loss: 0.1819


In [8]:
print(arr_loss)
print(arr_val_loss)

[0.618406601013495, 0.3481415968667598, 0.3124854265034123, 0.2827809602495524, 0.26660745069289904, 0.25242166413323247, 0.24497065041805918, 0.2332423186305972, 0.22263491756769116, 0.20926080510645081, 0.21596734447279303, 0.1983425939287588, 0.19188927698127803, 0.18637375375480786, 0.18228953460461, 0.17684716698175162, 0.1750772531218874, 0.1729141300996195, 0.16368162277920273, 0.1583376938744603]
[0.4427055951756269, 0.35286171337220873, 0.28817695526756004, 0.30594591365247814, 0.2741560292274214, 0.26008611161091605, 0.2410309955023872, 0.26604929146579076, 0.24994958158071875, 0.22167358373476165, 0.2103497960845831, 0.2134701632303635, 0.25144485996913185, 0.21085874331632848, 0.20774140186267456, 0.1987439627347864, 0.1932508124388414, 0.19643148832817367, 0.1842288132473297, 0.18192721835247755]
