# Imports


# In [None]:

import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch

# Run on CUDA when PyTorch reports it as available; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)


# Dataset loading

# In [None]:

class BrainTumorDataset(Dataset):
    """Paired grayscale image / binary mask dataset read from two directories.

    Images and masks are paired by position after sorting the file names of
    each directory, so corresponding files must sort into the same order.
    Sub-directories and hidden entries (names starting with ".") are ignored
    because they cannot be opened as images.

    Args:
        image_dir: Directory holding the input images.
        mask_dir: Directory holding the segmentation masks.
        transform: Optional callable applied to the image and, in a separate
            call, to the mask. A transform with random behaviour would
            therefore treat the two differently. When omitted (or when it
            does not return tensors) both are converted with ``ToTensor``.

    Raises:
        ValueError: If the two directories hold a different number of files.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.images = self._list_files(image_dir)
        self.masks = self._list_files(mask_dir)
        # Pairing is purely positional, so a count mismatch means at least
        # one image would be matched with the wrong (or no) mask.
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images in {image_dir!r} but "
                f"{len(self.masks)} masks in {mask_dir!r}; counts must match."
            )
        self.transform = transform

    @staticmethod
    def _list_files(directory):
        """Return the sorted names of the visible regular files in *directory*."""
        return sorted(
            name
            for name in os.listdir(directory)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(directory, name))
        )

    @staticmethod
    def _as_tensor(img):
        """Return *img* unchanged if it is a tensor, else convert it with ToTensor."""
        if torch.is_tensor(img):
            return img
        return transforms.ToTensor()(img)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample *idx*; the mask holds only 0.0 / 1.0."""
        img_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.masks[idx])

        # Load grayscale image and mask; the context managers close the
        # underlying files as soon as the pixel data has been converted.
        with Image.open(img_path) as img_file:
            image = img_file.convert("L")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")

        if self.transform is not None:
            image = self.transform(image)
            mask = self.transform(mask)

        # The threshold below needs a tensor; without this a missing or
        # PIL-returning transform would fail on ``mask > 0``.
        image = self._as_tensor(image)
        mask = self._as_tensor(mask)

        # Any non-zero mask pixel counts as foreground.
        mask = (mask > 0).float()
        return image, mask



# In [None]:
# 4. Transforms
# -----------------------------
# Preprocessing pipeline with a single step: ToTensor converts each PIL image
# to a tensor (one channel for the grayscale "L" images the dataset loads).
transform = transforms.Compose([transforms.ToTensor()])


# Data loader

# In [None]:

# Set these to the directories holding the training images and their masks.
train_images = ""
train_masks  = ""

full_dataset = BrainTumorDataset(train_images, train_masks, transform=transform)

# Split: 80% train, 20% validation
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
# Seed the split so the same samples land in the validation set on every run;
# otherwise a reloaded checkpoint could be "validated" on images it trained on
# and validation losses from different runs would not be comparable.
train_dataset, val_dataset = random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

# Only the training data is reshuffled each epoch; validation order is fixed.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

print(f"Train samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")

# UNet++ architecture

# In [None]:
# 5. U-Net++ Model (simplified)
# -----------------------------
import torch.nn as nn

class ConvBlock(nn.Module):
    """Two consecutive 3x3 convolutions, each followed by BatchNorm and ReLU.

    Both convolutions use ``padding=1``, so height and width are preserved;
    only the channel count changes from ``in_ch`` to ``out_ch``.

    Args:
        in_ch: Number of channels of the input feature map.
        out_ch: Number of channels produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second stage keeps out_ch.
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        # Kept as one Sequential named ``conv`` so parameter names
        # (conv.0, conv.1, conv.3, conv.4) stay stable for saved checkpoints.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv/BN/ReLU stages to *x* and return the result."""
        features = self.conv(x)
        return features

class UNetPlusPlus(nn.Module):
    """Four-level encoder-decoder segmentation network with sigmoid output.

    Despite the name, each decoder stage concatenates exactly one encoder
    feature map with the upsampled output of the stage below it; no
    intermediate nested nodes are built, so the skip layout is that of a
    plain U-Net.

    Args:
        in_ch: Number of channels of the input image.
        out_ch: Number of channels of the predicted map.

    The forward pass returns a tensor with the same height and width as the
    input and ``out_ch`` channels, with values produced by a sigmoid.
    """

    def __init__(self, in_ch=1, out_ch=1):
        super().__init__()
        filters = [64, 128, 256, 512]

        # Encoder
        self.conv0_0 = ConvBlock(in_ch, filters[0])
        self.pool0 = nn.MaxPool2d(2)
        self.conv1_0 = ConvBlock(filters[0], filters[1])
        self.pool1 = nn.MaxPool2d(2)
        self.conv2_0 = ConvBlock(filters[1], filters[2])
        self.pool2 = nn.MaxPool2d(2)
        self.conv3_0 = ConvBlock(filters[2], filters[3])

        # Decoder: each transposed conv doubles the spatial size and halves
        # the channels; the following block sees skip + upsampled channels.
        self.up2_1 = nn.ConvTranspose2d(filters[3], filters[2], 2, stride=2)
        self.conv2_1 = ConvBlock(filters[2]*2, filters[2])

        self.up1_2 = nn.ConvTranspose2d(filters[2], filters[1], 2, stride=2)
        self.conv1_2 = ConvBlock(filters[1]*2, filters[1])

        self.up0_3 = nn.ConvTranspose2d(filters[1], filters[0], 2, stride=2)
        self.conv0_3 = ConvBlock(filters[0]*2, filters[0])

        # Final output: 1x1 convolution down to the requested channel count
        self.final = nn.Conv2d(filters[0], out_ch, 1)

    @staticmethod
    def _match_size(upsampled, skip):
        """Zero-pad *upsampled* so its height and width equal those of *skip*.

        MaxPool2d(2) rounds odd sizes down, so after the x2 upsampling the
        decoder map can be one pixel smaller than the encoder map it is
        concatenated with. When the sizes already agree the tensor is
        returned untouched.
        """
        pad_h = skip.size(2) - upsampled.size(2)
        pad_w = skip.size(3) - upsampled.size(3)
        if pad_h == 0 and pad_w == 0:
            return upsampled
        # F.pad takes the last dimension first: (left, right, top, bottom).
        return nn.functional.pad(
            upsampled,
            [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
        )

    def forward(self, x):
        """Return the sigmoid-activated prediction for the batch *x*."""
        # Encoder
        x0_0 = self.conv0_0(x)
        x1_0 = self.conv1_0(self.pool0(x0_0))
        x2_0 = self.conv2_0(self.pool1(x1_0))
        x3_0 = self.conv3_0(self.pool2(x2_0))

        # Decoder: upsample, align with the skip connection, concatenate
        # along the channel dimension, then refine.
        up2 = self._match_size(self.up2_1(x3_0), x2_0)
        x2_1 = self.conv2_1(torch.cat([x2_0, up2], dim=1))
        up1 = self._match_size(self.up1_2(x2_1), x1_0)
        x1_2 = self.conv1_2(torch.cat([x1_0, up1], dim=1))
        up0 = self._match_size(self.up0_3(x1_2), x0_0)
        x0_3 = self.conv0_3(torch.cat([x0_0, up0], dim=1))

        out = torch.sigmoid(self.final(x0_3))
        return out

# Build the network with its default channel counts, then move it to `device`.
model = UNetPlusPlus()
model = model.to(device)

# Loss and optimizer

# In [None]:

def dice_loss(pred, target, smooth=1.):
    """Return the soft Dice loss between *pred* and *target*.

    Both tensors are reduced over dimensions 2 and 3, a Dice coefficient is
    computed for every remaining (dim 0, dim 1) entry, and the mean of
    ``1 - dice`` over those entries is returned.

    Args:
        pred: Predicted values; must have at least four dimensions.
        target: Reference values, multiplied element-wise with *pred*.
        smooth: Constant added to numerator and denominator, which keeps the
            ratio defined when both inputs sum to zero.
    """
    pred = pred.contiguous()
    target = target.contiguous()

    def _sum_dims_2_and_3(tensor):
        # After the first reduction the former dim 3 becomes dim 2.
        return tensor.sum(dim=2).sum(dim=2)

    overlap = _sum_dims_2_and_3(pred * target)
    numerator = 2. * overlap + smooth
    denominator = _sum_dims_2_and_3(pred) + _sum_dims_2_and_3(target) + smooth
    per_entry_loss = 1 - (numerator / denominator)
    return per_entry_loss.mean()

# Adam over every parameter of `model`, with a learning rate of 3e-4.
optimizer = torch.optim.Adam(params=model.parameters(), lr=3e-4)

# Training loop

# In [None]:

num_epochs = 20
best_val_loss = float('inf')

for epoch in range(num_epochs):
    # ---- Training pass: one optimizer step per batch ----
    model.train()
    train_loss = 0
    for images, masks in tqdm(train_loader):
        images = images.to(device)
        masks = masks.to(device)

        # Clear gradients left over from the previous step before computing
        # new ones for this batch.
        optimizer.zero_grad()
        preds = model(images)
        loss = dice_loss(preds, masks)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Mean of the per-batch losses for this epoch.
    train_loss = train_loss / len(train_loader)

    # ---- Validation pass: eval mode, no gradient tracking ----
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(device)
            masks = masks.to(device)
            preds = model(images)
            loss = dice_loss(preds, masks)
            val_loss += loss.item()
    val_loss = val_loss / len(val_loader)

    print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    # Checkpoint only when validation loss improves on the best seen so far.
    if best_val_loss > val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "./best_unetpp.pth")