<a href="https://colab.research.google.com/github/rcbusinesstechlab/realtime-face-recognition/blob/main/KAN_UNet_BrainTumor_Segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# KAN-UNet Brain Tumor Segmentation
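This notebook implements a full pipeline for brain tumor segmentation using a 3D U-Net whose per-voxel output is refined by a Kolmogorov-Arnold Network (KAN) layer. The stages are:
- data loading and preprocessing;
- training with Dice loss;
- evaluation with Dice and the 95th-percentile Hausdorff distance (HD95);
- logging to Weights & Biases.

Replace the placeholder file paths in the data cell before running.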

In [None]:
!pip install monai nibabel wandb

In [None]:
import os
import random

import torch
import numpy as np
import nibabel as nib
import monai
from monai.transforms import *
from monai.networks.nets import UNet
from monai.networks.utils import one_hot
from monai.metrics import compute_meandice, compute_hausdorff_distance
from monai.metrics import compute_dice
from monai.data import DataLoader, Dataset, decollate_batch
from monai.losses import DiceLoss
import wandb

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG the pipeline touches (Python, NumPy, PyTorch).
# This makes DataLoader shuffling and weight initialisation reproducible
# across Restart & Run All. torch.manual_seed also seeds all CUDA devices.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Sample data paths (replace with actual)
train_files = [{"image": "path_to_image.nii.gz", "label": "path_to_label.nii.gz"}]
# NOTE(review): labels must be integer class indices in [0, 3) to match the
# model's 3 output channels. Datasets using other label values (e.g. BraTS
# {0, 1, 2, 4}) need remapping first.

# Fail fast with a clear message. Otherwise a missing file only surfaces as a
# nibabel error inside the training loop, because Dataset loads lazily.
missing_paths = [
    path for item in train_files for path in item.values() if not os.path.exists(path)
]
if missing_paths:
    raise FileNotFoundError(
        f"Replace the placeholder paths in train_files; not found: {missing_paths}"
    )

# Define transforms
train_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    # EnsureChannelFirstd replaces AddChanneld, which MONAI 1.3 removed.
    EnsureChannelFirstd(keys=["image", "label"]),
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    # Assumes MRI input (typical for brain-tumour data).
    # MRI intensities have no fixed physical scale, so a CT Hounsfield window
    # such as [-1000, 1000] does not apply. Instead, z-score each channel over
    # its non-zero voxels.
    NormalizeIntensityd(keys=["image"], nonzero=True, channel_wise=True),
    # The UNet downsamples four times with stride 2, so each spatial dim must be
    # divisible by 2**4 = 16. Otherwise the skip connections mismatch in size.
    DivisiblePadd(keys=["image", "label"], k=16),
    ToTensord(keys=["image", "label"]),
])

train_ds = Dataset(data=train_files, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True)

In [None]:
# SimpleKAN assumed to be available as simplekan.py
from simplekan import SimpleKAN

class KAN_UNet(torch.nn.Module):
    """U-Net whose per-voxel output channels are re-mapped by a KAN layer.

    The MONAI ``UNet`` produces ``out_channels`` values per voxel. Each voxel's
    channel vector is then passed through ``SimpleKAN`` (``out_channels`` in,
    ``out_channels`` out), so the KAN acts as a non-linear head applied
    independently at every spatial location.

    Args:
        in_channels: number of input image channels (default 1).
        out_channels: number of output channels / classes (default 3).
        spatial_dims: dimensionality of the input volumes (default 3).
    """

    def __init__(self, in_channels: int = 1, out_channels: int = 3, spatial_dims: int = 3):
        super().__init__()
        self.unet = UNet(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        )
        # The KAN sees one voxel's channel vector at a time, so its width must
        # equal the UNet's out_channels. The 256 used previously is the UNet's
        # deepest encoder width, which is not the width of the UNet's output.
        self.kan = SimpleKAN(input_size=out_channels, output_size=out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the UNet, then apply the KAN per voxel; output shape equals the UNet's."""
        x = self.unet(x)
        b, c, *dims = x.shape
        # (B, C, *spatial) -> (B * voxels, C): one row per voxel.
        # reshape, not view, because movedim returns a non-contiguous tensor.
        # NOTE(review): flattened to 2D assuming SimpleKAN accepts
        # (batch, features) input; confirm against simplekan.py.
        x_flat = x.movedim(1, -1).reshape(-1, c)
        x_out = self.kan(x_flat)
        # (B * voxels, C) -> (B, *spatial, C) -> (B, C, *spatial)
        return x_out.reshape(b, *dims, c).movedim(-1, 1)

In [None]:
# Training loop
MAX_EPOCHS = 50
LEARNING_RATE = 1e-4

model = KAN_UNet().to(device)
# to_onehot_y one-hot encodes the index-valued labels to match the model's
# channels; softmax is applied to the model outputs inside the loss.
loss_fn = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(MAX_EPOCHS):
    model.train()
    epoch_loss, n_steps = 0.0, 0
    for batch in train_loader:
        images, labels = batch["image"].to(device), batch["label"].to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        n_steps += 1
    # Without this guard, an empty loader would leave `loss` undefined
    # and the print below would raise NameError.
    if n_steps == 0:
        raise RuntimeError("train_loader yielded no batches; check train_files.")
    # Report the mean loss over the epoch, not just the last batch's loss.
    print(f"Epoch {epoch+1} Loss: {epoch_loss / n_steps:.4f}")

In [None]:
# Evaluation
# NOTE(review): this iterates train_loader, so the scores measure fit on the
# training data, not generalisation. Use a held-out loader before reporting.
model.eval()
dice_scores = []
hd95_scores = []

with torch.no_grad():
    for batch in train_loader:
        images, labels = batch["image"].to(device), batch["label"].to(device)
        logits = model(images)
        num_classes = logits.shape[1]
        # Both MONAI metrics expect one-hot (B, C, ...) tensors.
        # A single channel of class indices would be scored as one mask, and
        # include_background=False would drop that lone channel entirely.
        preds_onehot = one_hot(torch.argmax(logits, dim=1, keepdim=True), num_classes=num_classes)
        labels_onehot = one_hot(labels[:, 0:1, ...], num_classes=num_classes)

        dice = compute_dice(y_pred=preds_onehot, y=labels_onehot, include_background=False)
        hd95 = compute_hausdorff_distance(
            y_pred=preds_onehot, y=labels_onehot, include_background=False, percentile=95.0
        )

        dice_scores.append(dice.cpu().numpy())
        hd95_scores.append(hd95.cpu().numpy())

# Per-class scores may be NaN when a class is absent from a case.
# nanmean skips those entries instead of propagating NaN to the average.
print("Average Dice:", np.nanmean(dice_scores))
print("Average HD95:", np.nanmean(hd95_scores))

In [None]:
# W&B logging
# Logs the averaged evaluation metrics as a one-off run.
run = wandb.init(project="kan_unet_segmentation")
try:
    run.log({
        # nanmean ignores per-class NaN entries, matching the evaluation summary.
        "dice_avg": float(np.nanmean(dice_scores)),
        "hd95_avg": float(np.nanmean(hd95_scores)),
    })
finally:
    # Close the run so its data is flushed and it is not left open in the kernel.
    run.finish()