# Imports


In [None]:
#Step 2: Imports
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
from tqdm import tqdm
import segmentation_models_pytorch as smp


# Dataset loading

In [None]:
#Step 3: Dataset Class (Brain Tumor Segmentation)
from PIL import Image
import os

class BrainTumorDataset(torch.utils.data.Dataset):
    """Paired image / segmentation-mask dataset read from two directories.

    Every entry returned by ``os.listdir`` for each directory is used. The
    two listings are sorted by full path and paired by position, so an image
    and its mask must sort to the same index in their respective directory.

    Args:
        images_dir: Directory containing the input images.
        masks_dir: Directory containing the segmentation masks.
        transform: Optional callable applied to each image (opened as RGB).
            When ``mask_transform`` is not given it is applied to the mask
            as well.
        mask_transform: Optional callable applied to each mask (opened as
            single-channel "L"). Lets the mask use a different pipeline from
            the image; when ``None`` the mask falls back to ``transform``.

    Raises:
        ValueError: If the two directories hold a different number of
            entries, since positional pairing would then be misaligned.
    """

    def __init__(self, images_dir, masks_dir, transform=None, mask_transform=None):
        self.images = sorted(os.path.join(images_dir, f) for f in os.listdir(images_dir))
        self.masks = sorted(os.path.join(masks_dir, f) for f in os.listdir(masks_dir))

        # Pairing is by index, so a count mismatch means at least one image
        # would be matched with the wrong mask (or with none at all).
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} entries in {images_dir!r} but "
                f"{len(self.masks)} in {masks_dir!r}; images and masks must "
                "pair one-to-one."
            )

        self.transform = transform
        self.mask_transform = mask_transform

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load and return the ``(image, mask)`` pair at position ``idx``."""
        # Context managers close the underlying files as soon as the pixel
        # data has been converted into a new in-memory image.
        with Image.open(self.images[idx]) as img_file:
            image = img_file.convert("RGB")
        with Image.open(self.masks[idx]) as mask_file:
            mask = mask_file.convert("L")

        # Resolved on every access so that reassigning ``self.transform``
        # after construction still affects the mask when no dedicated mask
        # transform was provided.
        mask_transform = self.mask_transform
        if mask_transform is None:
            mask_transform = self.transform

        if self.transform:
            image = self.transform(image)
        if mask_transform:
            mask = mask_transform(mask)

        return image, mask

In [None]:
# Step 4: Data preparation
# Preprocessing pipeline: resize every input to 512x512, then convert the PIL
# image to a tensor. A 224x224 resize is kept below as a disabled alternative.
# NOTE(review): this same pipeline is handed to BrainTumorDataset, which also
# runs it on the masks; Resize presumably interpolates, which may leave
# non-binary values along mask edges -- confirm this is acceptable.
transform = transforms.Compose(
    [
        # transforms.Resize((224, 224)),
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
    ]
)

# Data loader

In [None]:


# Directories holding the training images and their masks.
# NOTE: both are empty placeholders; set them to real paths before running.
train_images = ""
train_masks = ""

dataset = BrainTumorDataset(
    images_dir=train_images,
    masks_dir=train_masks,
    transform=transform,
)

# 80/20 train/validation split. The validation size is the remainder, so the
# two sizes always add up to len(dataset).
train_size = int(len(dataset) * 0.8)
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, lengths=[train_size, val_size])

# Only the training batches are shuffled.
train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=4, shuffle=False)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")


# Swin-Unet architecture (note: the model built below is a U-Net with a ResNet-34 encoder, not a Swin Transformer)

In [None]:
# Step 5: Model definition
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# NOTE: this step is labelled "Swin U-Net", but the encoder configured here is
# "resnet34" with ImageNet weights; no Swin Transformer is involved.
# The network takes 3-channel input and produces a single output channel.
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=1,
)
model = model.to(device)

# Loss and optimizer

In [None]:

# Dice loss in binary mode, matching the single output channel of the model.
# NOTE(review): the training loop passes raw model outputs to this loss with
# no explicit sigmoid; DiceLoss presumably applies it internally by default --
# confirm against the segmentation_models_pytorch documentation.
loss_fn = smp.losses.DiceLoss(mode="binary")

# Adam over every model parameter with a fixed learning rate of 3e-4.
optimizer = torch.optim.Adam(params=model.parameters(), lr=3e-4)

# Training loop

In [None]:
# Step 6: Training loop
# Trains for a fixed number of epochs, validates after each one, and keeps the
# weights of the epoch with the lowest mean validation loss.
num_epochs = 10
best_val_loss = float("inf")

for epoch in range(num_epochs):
    # ---- Training pass ----
    model.train()
    train_loss = 0
    for images, masks in tqdm(train_loader):
        images = images.to(device)
        masks = masks.to(device)

        # Clear gradients left over from the previous step before the new
        # forward/backward pass.
        optimizer.zero_grad()
        preds = model(images)
        loss = loss_fn(preds, masks)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Mean of the per-batch losses.
    train_loss = train_loss / len(train_loader)

    # ---- Validation pass (no gradient tracking) ----
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(device)
            masks = masks.to(device)
            preds = model(images)
            loss = loss_fn(preds, masks)
            val_loss += loss.item()
    val_loss = val_loss / len(val_loader)

    print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    # Checkpoint only on a strict improvement of the validation loss.
    # The file name says "swinunet" but the weights saved are those of
    # whatever `model` was built earlier in the notebook.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "./best_swinunet.pth")