# Imports


In [5]:
#Step 2: Imports
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
from tqdm import tqdm
import segmentation_models_pytorch as smp


# Dataset loading

In [6]:
#Step 3: Dataset Class (Brain Tumor Segmentation)
from PIL import Image
import os

class BrainTumorDataset(torch.utils.data.Dataset):
    """Paired image/mask dataset for brain-tumour segmentation.

    Images and masks are paired purely by position after sorting each
    directory listing. The two directories must therefore contain the same
    number of (visible) files, and their sorted orders must correspond
    one-to-one.

    Args:
        images_dir: Directory holding the input images (loaded as RGB).
        masks_dir: Directory holding the masks (loaded as single-channel "L").
        transform: Optional callable applied to every image. It is also applied
            to every mask unless ``mask_transform`` is given.
        mask_transform: Optional callable applied to every mask instead of
            ``transform``. For example, pass a nearest-neighbour resize so mask
            labels are not interpolated at boundaries.

    Raises:
        ValueError: If the two directories contain different numbers of files.
    """

    def __init__(self, images_dir, masks_dir, transform=None, mask_transform=None):
        self.images = self._list_files(images_dir)
        self.masks = self._list_files(masks_dir)
        # Pairing is positional, so a count mismatch would silently misalign
        # every sample after the first missing file.
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images in {images_dir!r} but "
                f"{len(self.masks)} masks in {masks_dir!r}; images and masks are "
                "paired by sorted order, so the counts must match."
            )
        self.transform = transform
        # Fall back to the image transform so existing callers keep the old behaviour.
        self.mask_transform = mask_transform if mask_transform is not None else transform

    @staticmethod
    def _list_files(directory):
        """Return the sorted full paths of the visible regular files in ``directory``."""
        # Skip hidden entries (e.g. ".DS_Store") and subdirectories, which cannot be
        # opened as images and would otherwise break the positional pairing.
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if not name.startswith(".") and os.path.isfile(os.path.join(directory, name))
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Context managers close the underlying file handles deterministically.
        with Image.open(self.images[idx]) as raw_image:
            image = raw_image.convert("RGB")
        with Image.open(self.masks[idx]) as raw_mask:
            mask = raw_mask.convert("L")

        if self.transform:
            image = self.transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)

        return image, mask

In [7]:
#Step 4: Data Preparation
IMAGE_SIZE = (512, 512)

# Shared preprocessing: resize to a fixed size, then turn the PIL image into a
# float tensor. BrainTumorDataset applies this to the masks as well as the images.
# NOTE(review): torchvision's Resize defaults to bilinear interpolation, which
# blends mask values along tumour boundaries. Consider a nearest-neighbour resize
# for masks.
transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
    ]
)

# Data loader

In [8]:


# Dataset location. Override it with the BRAIN_TUMOR_SEG_ROOT environment
# variable instead of editing the machine-specific default below.
DATA_ROOT = os.environ.get(
    "BRAIN_TUMOR_SEG_ROOT",
    "/home/readinggroup/Desktop/Image_proc_Noman/CSE465_project/dataset/segmentation_task",
)
train_images = os.path.join(DATA_ROOT, "train", "images")
train_masks = os.path.join(DATA_ROOT, "train", "masks")

SEED = 42             # seeds the train/val split and torch's global RNG
TRAIN_FRACTION = 0.8  # share of samples used for training; the rest is validation
BATCH_SIZE = 4

# Seed torch's global RNG so loader shuffling and later weight init are repeatable.
torch.manual_seed(SEED)

dataset = BrainTumorDataset(train_images, train_masks, transform=transform)

train_size = int(TRAIN_FRACTION * len(dataset))
val_size = len(dataset) - train_size
# A dedicated seeded generator makes the split identical on every run. This keeps
# the validation set, and the "best" checkpoint selected on it, reproducible.
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED),
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")


Train: 6292, Val: 1574


# U-Net architecture (ResNet-34 encoder)

In [9]:
#Step 5: U-Net model with a ResNet-34 encoder
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# This is a convolutional U-Net whose encoder is ResNet-34 initialised from
# ImageNet weights. It is NOT a Swin Transformer / Swin-Unet.
ENCODER_NAME = "resnet34"
ENCODER_WEIGHTS = "imagenet"

model = smp.Unet(
    encoder_name=ENCODER_NAME,
    encoder_weights=ENCODER_WEIGHTS,
    in_channels=3,  # BrainTumorDataset converts every image to RGB
    classes=1,      # single output channel for the binary tumour mask
).to(device)

Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /home/readinggroup/.cache/torch/hub/checkpoints/resnet34-333f7ec4.pth
100%|██████████████████████████████████████| 83.3M/83.3M [00:07<00:00, 11.6MB/s]


# Loss and optimizer

In [10]:

# Binary Dice loss on the model's single-channel output, optimised with Adam.
# NOTE(review): the U-Net is built without an `activation`, so this relies on
# DiceLoss's default logit handling. Confirm against the smp docs.
LEARNING_RATE = 3e-4
loss_fn = smp.losses.DiceLoss(mode='binary')
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# Training loop

In [11]:
#Step 6: Training Loop
num_epochs = 10
best_val_loss = float('inf')
arr_loss = []      # mean training loss, one entry per epoch
arr_val_loss = []  # mean validation loss, one entry per epoch


def _train_one_epoch(model, loader, loss_fn, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    model.train()
    total = 0
    for images, masks in tqdm(loader):
        images, masks = images.to(device), masks.to(device)
        loss = loss_fn(model(images), masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / len(loader)


def _evaluate(model, loader, loss_fn, device):
    """Return the mean batch loss of ``model`` over ``loader`` without updating weights."""
    model.eval()
    total = 0
    with torch.no_grad():
        for images, masks in loader:
            images, masks = images.to(device), masks.to(device)
            total += loss_fn(model(images), masks).item()
    # NOTE(review): this is a per-batch average, so a smaller final batch is
    # weighted the same as a full one.
    return total / len(loader)


for epoch in range(num_epochs):
    train_loss = _train_one_epoch(model, train_loader, loss_fn, optimizer, device)
    arr_loss.append(train_loss)

    val_loss = _evaluate(model, val_loader, loss_fn, device)
    arr_val_loss.append(val_loss)

    print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    # Keep only the checkpoint with the lowest validation loss seen so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "./best_swinunet.pth")

100%|███████████████████████████████████████| 1573/1573 [01:22<00:00, 18.96it/s]


Epoch [1/10] Train Loss: 0.3687 | Val Loss: 0.2850


100%|███████████████████████████████████████| 1573/1573 [01:22<00:00, 19.13it/s]


Epoch [2/10] Train Loss: 0.2536 | Val Loss: 0.2736


100%|███████████████████████████████████████| 1573/1573 [01:22<00:00, 19.18it/s]


Epoch [3/10] Train Loss: 0.2295 | Val Loss: 0.2615


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.23it/s]


Epoch [4/10] Train Loss: 0.2198 | Val Loss: 0.2523


100%|███████████████████████████████████████| 1573/1573 [01:22<00:00, 19.13it/s]


Epoch [5/10] Train Loss: 0.2105 | Val Loss: 0.2212


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.19it/s]


Epoch [6/10] Train Loss: 0.1938 | Val Loss: 0.2260


100%|███████████████████████████████████████| 1573/1573 [01:22<00:00, 19.16it/s]


Epoch [7/10] Train Loss: 0.1897 | Val Loss: 0.2604


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.21it/s]


Epoch [8/10] Train Loss: 0.1839 | Val Loss: 0.1921


100%|███████████████████████████████████████| 1573/1573 [01:22<00:00, 19.17it/s]


Epoch [9/10] Train Loss: 0.1722 | Val Loss: 0.2226


100%|███████████████████████████████████████| 1573/1573 [01:22<00:00, 19.14it/s]


Epoch [10/10] Train Loss: 0.1656 | Val Loss: 0.1558


In [12]:
# Per-epoch loss history: training losses on the first line, validation on the second.
print(arr_loss, arr_val_loss, sep="\n")

[0.3687105577081461, 0.2536239789234023, 0.22948379843991346, 0.21980670494254198, 0.21050529319316963, 0.19377554990043325, 0.18969096011290942, 0.18391680327177198, 0.17218180337076497, 0.16560396432725016]
[0.2849577792404872, 0.2736116537285335, 0.2615335691096214, 0.2522834233826187, 0.22122373662624262, 0.22596219484576113, 0.2603836198748671, 0.19209829139225373, 0.22262895984698067, 0.1557987990415641]
