# Imports


In [1]:
#Step 2: Imports
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
from tqdm import tqdm
import segmentation_models_pytorch as smp


  from .autonotebook import tqdm as notebook_tqdm


# Dataset loading

In [2]:
#Step 3: Dataset Class (Brain Tumor Segmentation)
from PIL import Image
import os

class BrainTumorDataset(torch.utils.data.Dataset):
    """Paired image/mask dataset for binary brain-tumour segmentation.

    Images and masks are matched by position after sorting the file names in
    each directory, so both directories must contain the same number of files
    whose sorted orders correspond one-to-one.

    Args:
        images_dir: Directory with the input images (loaded and converted to RGB).
        masks_dir: Directory with the masks (loaded and converted to mode "L").
        transform: Optional callable applied to every image. When
            ``mask_transform`` is None it is also applied to every mask
            (the original behaviour).
        mask_transform: Optional callable applied to masks instead of
            ``transform``. Useful for resizing masks with nearest-neighbour
            interpolation so label values are not blended at region borders.

    Raises:
        ValueError: If the two directories contain different numbers of files.
    """

    def __init__(self, images_dir, masks_dir, transform=None, mask_transform=None):
        self.images = sorted(os.path.join(images_dir, name) for name in os.listdir(images_dir))
        self.masks = sorted(os.path.join(masks_dir, name) for name in os.listdir(masks_dir))
        # Pairing is purely positional: one missing/extra file would shift every
        # later pair, so fail fast instead of silently training on wrong labels.
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images in {images_dir!r} but "
                f"{len(self.masks)} masks in {masks_dir!r}; they are paired by "
                f"sorted position and must match one-to-one."
            )
        self.transform = transform
        self.mask_transform = mask_transform

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(image, mask)`` after the transforms."""
        # Context managers close the underlying file handles; convert() returns
        # a fully loaded, independent image.
        with Image.open(self.images[idx]) as img:
            image = img.convert("RGB")
        with Image.open(self.masks[idx]) as msk:
            mask = msk.convert("L")

        if self.transform is not None:
            image = self.transform(image)

        # Fall back to the shared transform when no mask-specific one is given.
        mask_tf = self.mask_transform if self.mask_transform is not None else self.transform
        if mask_tf is not None:
            mask = mask_tf(mask)

        return image, mask

In [3]:
#Step 4: Data Preparation
# Fixed spatial size every sample is resized to before being turned into a tensor.
IMAGE_SIZE = (512, 512)

# Resize to IMAGE_SIZE, then ToTensor (8-bit PIL images become float tensors in [0, 1]).
# NOTE(review): the dataset applies this same pipeline to the masks, so masks are
# resized with Resize's default (bilinear) interpolation instead of nearest-neighbour.
transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
    ]
)

# Data loader

In [4]:


# Machine-specific dataset root; everything below is derived from it.
DATA_ROOT = "/home/readinggroup/Desktop/Image_proc_Noman/CSE465_project/dataset/segmentation_task"
train_images = os.path.join(DATA_ROOT, "train", "images")
train_masks  = os.path.join(DATA_ROOT, "train", "masks")

dataset = BrainTumorDataset(train_images, train_masks, transform=transform)

# 80/20 train/validation split. The seeded generator makes the partition identical
# on every run, so validation losses (and the "best" checkpoint chosen from them)
# are comparable across runs instead of depending on a fresh random split.
SPLIT_SEED = 42
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")


Train: 6292, Val: 1574


# U-Net architecture (ResNet-34 encoder)

In [5]:
#Step 5: U-Net model (ResNet-34 encoder)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Standard CNN U-Net from segmentation_models_pytorch. The encoder is a ResNet-34
# with ImageNet-pretrained weights -- it is NOT a Swin Transformer.
model = smp.Unet(
    encoder_name="resnet34",     # ResNet-34 CNN backbone
    encoder_weights="imagenet",  # ImageNet-pretrained encoder weights
    in_channels=3,               # RGB input, matching .convert("RGB") in the dataset
    classes=1                    # single output channel for the binary tumour mask
).to(device)

# Loss and optimizer

In [6]:

# Binary Dice loss computed on raw model logits (from_logits=True is the
# DiceLoss default, spelled out here), optimised with Adam.
LEARNING_RATE = 3e-4
loss_fn = smp.losses.DiceLoss(mode='binary', from_logits=True)
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# Training loop

In [7]:
#Step 6: Training Loop
num_epochs = 20
best_val_loss = float('inf')
arr_loss = []      # mean training loss per epoch
arr_val_loss = []  # mean validation loss per epoch


def _train_epoch():
    """One optimisation pass over ``train_loader``; returns the mean per-batch loss."""
    model.train()
    running_total = 0
    for batch_images, batch_masks in tqdm(train_loader):
        batch_images = batch_images.to(device)
        batch_masks = batch_masks.to(device)
        batch_loss = loss_fn(model(batch_images), batch_masks)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        running_total += batch_loss.item()
    return running_total / len(train_loader)


def _validate_epoch():
    """Gradient-free pass over ``val_loader``; returns the mean per-batch loss."""
    model.eval()
    running_total = 0
    with torch.no_grad():
        for batch_images, batch_masks in val_loader:
            batch_images = batch_images.to(device)
            batch_masks = batch_masks.to(device)
            running_total += loss_fn(model(batch_images), batch_masks).item()
    return running_total / len(val_loader)


for epoch in range(num_epochs):
    train_loss = _train_epoch()
    arr_loss.append(train_loss)

    val_loss = _validate_epoch()
    arr_val_loss.append(val_loss)

    print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    # Checkpoint only when validation loss improves on the best seen so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "./best_swinunet.pth")

100%|███████████████████████████████████████| 1573/1573 [01:22<00:00, 19.07it/s]


Epoch [1/20] Train Loss: 0.3611 | Val Loss: 0.2464


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.35it/s]


Epoch [2/20] Train Loss: 0.2511 | Val Loss: 0.2265


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.24it/s]


Epoch [3/20] Train Loss: 0.2221 | Val Loss: 0.2589


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.29it/s]


Epoch [4/20] Train Loss: 0.2047 | Val Loss: 0.2834


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.37it/s]


Epoch [5/20] Train Loss: 0.2031 | Val Loss: 0.2335


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.23it/s]


Epoch [6/20] Train Loss: 0.1916 | Val Loss: 0.2295


100%|███████████████████████████████████████| 1573/1573 [01:22<00:00, 19.18it/s]


Epoch [7/20] Train Loss: 0.1861 | Val Loss: 0.2202


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.25it/s]


Epoch [8/20] Train Loss: 0.1759 | Val Loss: 0.2144


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.22it/s]


Epoch [9/20] Train Loss: 0.1680 | Val Loss: 0.1679


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.24it/s]


Epoch [10/20] Train Loss: 0.1570 | Val Loss: 0.1592


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.25it/s]


Epoch [11/20] Train Loss: 0.1572 | Val Loss: 0.1774


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.31it/s]


Epoch [12/20] Train Loss: 0.1587 | Val Loss: 0.1581


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.37it/s]


Epoch [13/20] Train Loss: 0.1429 | Val Loss: 0.1725


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.23it/s]


Epoch [14/20] Train Loss: 0.1420 | Val Loss: 0.1565


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.22it/s]


Epoch [15/20] Train Loss: 0.1536 | Val Loss: 0.1573


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.26it/s]


Epoch [16/20] Train Loss: 0.1457 | Val Loss: 0.1511


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.21it/s]


Epoch [17/20] Train Loss: 0.1271 | Val Loss: 0.1505


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.33it/s]


Epoch [18/20] Train Loss: 0.1356 | Val Loss: 0.1725


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.29it/s]


Epoch [19/20] Train Loss: 0.1288 | Val Loss: 0.1424


100%|███████████████████████████████████████| 1573/1573 [01:21<00:00, 19.28it/s]


Epoch [20/20] Train Loss: 0.1218 | Val Loss: 0.1429


In [8]:
# Show the per-epoch training history, then the validation history.
for history in (arr_loss, arr_val_loss):
    print(history)

[0.361053874404735, 0.25110906032870245, 0.222121250758259, 0.20470547463889344, 0.20305545233043226, 0.1915684346173572, 0.18609065154670382, 0.17589696766234322, 0.16803134829090116, 0.15703129582699135, 0.1572187702821003, 0.15868924172608168, 0.14285741315400305, 0.14202350139314782, 0.15360473805904692, 0.14567229111938815, 0.12714307723497026, 0.13559026784072314, 0.12880292172956376, 0.12182841358560706]
[0.24637850090331836, 0.22649635094676526, 0.2589454066934924, 0.28340094010842026, 0.2335007553778324, 0.22945683377648368, 0.22019361118374742, 0.21442013646140318, 0.1679384458791181, 0.15922598260913404, 0.17739438072679006, 0.15812735796579855, 0.17250827333043675, 0.15649576371696394, 0.15734674800471002, 0.15112149896960572, 0.15053618181175388, 0.17245382963098246, 0.14239043284793795, 0.14292416188317508]
