# Modified 3D U-Net: Training and Implementation

## Key Organ-Specific Modifications Highlighted

Loss function:

loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([5.0]))


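Penalizes missed organ voxels (false negatives) more heavily than background voxels (`pos_weight=5.0`).

Handles the class imbalance common in organ segmentation, where the organ occupies only a small fraction of the volume.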

BatchNorm included in all Conv blocks

Stabilizes MRI intensity variations

Helps faster convergence

Smaller network depth

Only 2 downsampling steps plus a bottleneck: enough capacity for a single organ while reducing GPU memory usage

Sigmoid applied at inference

y = torch.sigmoid(y)
y = (y > 0.5).float()


Produces binary organ mask

## Files used for training are 4D files

Input -> input.nii (4D MRI file)

Output -> output.nii (4D stomach segmentation mask, used as the training label)

## Saved files

- unet_trained.pth - This is the trained UNet and can be used for segmenting the stomach
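- predicted_segmentation.nii - This is a 4D file showing how well the UNet performed (note: predictions are made on the same volume used for training, so this is not a held-out evaluation)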



In [1]:
import nibabel as nib
import torch
import numpy as np

# Load nii files as float32: get_fdata() defaults to float64, which doubles the
# memory needed for a (192, 192, 60, 176) series.
img = nib.load("input.nii").get_fdata(dtype=np.float32)     # shape: (X,Y,Z,T)
seg = nib.load("output.nii").get_fdata(dtype=np.float32)    # shape: (X,Y,Z,T)

print("Input shape:", img.shape)
print("Seg shape:  ", seg.shape)

# Frames are paired index-by-index below, so image and mask must be matching
# 4D volumes.
if img.ndim != 4 or img.shape != seg.shape:
    raise ValueError(
        f"expected matching 4D volumes, got {img.shape} and {seg.shape}"
    )

# Normalize MRI for stability: one global z-score over the whole series.
# Statistics are accumulated in float64 for accuracy, then applied in place
# in float32 to avoid another full-size temporary array.
mean = np.float32(img.mean(dtype=np.float64))
std = np.float32(img.std(dtype=np.float64))
img -= mean
img /= std + np.float32(1e-8)

# BCEWithLogitsLoss expects targets in [0, 1]; collapse any non-zero label to 1.
seg = (seg > 0).astype(np.float32)

# Convert to torch (from_numpy shares the buffer instead of copying it)
img = torch.from_numpy(img)
seg = torch.from_numpy(seg)

# Move time first and add a channel dimension → shape becomes (T, 1, X, Y, Z)
img = img.permute(3, 0, 1, 2).unsqueeze(1)
seg = seg.permute(3, 0, 1, 2).unsqueeze(1)

print("Torch input:", img.shape)
print("Torch seg:  ", seg.shape)

Input shape: (192, 192, 60, 176)
Seg shape:   (192, 192, 60, 176)
Torch input: torch.Size([176, 1, 192, 192, 60])
Torch seg:   torch.Size([176, 1, 192, 192, 60])


In [2]:
import torch.nn as nn

class DoubleConv(nn.Module):
    """Two stacked (Conv3d 3x3x3 -> BatchNorm3d -> ReLU) stages.

    ``padding=1`` keeps the spatial size unchanged; only the channel count
    changes from ``in_ch`` to ``out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True),

            nn.Conv3d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.net(x)


class UNet3D(nn.Module):
    """Shallow 3D U-Net: two 2x down-sampling stages, a bottleneck, two up stages.

    Takes a ``(N, 1, D, H, W)`` tensor and returns ``(N, 1, D, H, W)`` raw
    logits. No sigmoid is applied, so the output pairs with
    ``nn.BCEWithLogitsLoss`` for training and ``torch.sigmoid`` + threshold
    at inference.

    The two max-pools halve each spatial dimension twice, so the skip
    connections only line up when D, H and W are multiples of 4. Other sizes
    are zero-padded at the far end of each spatial axis, and the logits are
    cropped back to the input size.
    """

    # Total spatial down-sampling factor of the encoder (two MaxPool3d(2)).
    _SIZE_MULTIPLE = 4

    def __init__(self):
        super().__init__()
        self.d1 = DoubleConv(1, 32)
        self.p1 = nn.MaxPool3d(2)

        self.d2 = DoubleConv(32, 64)
        self.p2 = nn.MaxPool3d(2)

        self.d3 = DoubleConv(64, 128)

        self.u2 = nn.ConvTranspose3d(128, 64, 2, stride=2)
        self.db2 = DoubleConv(128, 64)

        self.u1 = nn.ConvTranspose3d(64, 32, 2, stride=2)
        self.db1 = DoubleConv(64, 32)

        # final 1-channel prediction
        self.final = nn.Conv3d(32, 1, kernel_size=1)

    def forward(self, x):
        depth, height, width = x.shape[-3:]

        # pad() takes (before, after) pairs starting from the LAST dimension.
        pad = []
        for size in (width, height, depth):
            pad += [0, -size % self._SIZE_MULTIPLE]
        if any(pad):
            x = nn.functional.pad(x, pad)

        c1 = self.d1(x)
        p1 = self.p1(c1)

        c2 = self.d2(p1)
        p2 = self.p2(c2)

        bottleneck = self.d3(p2)

        u2 = self.u2(bottleneck)
        merge2 = torch.cat([u2, c2], dim=1)
        c2d = self.db2(merge2)

        u1 = self.u1(c2d)
        merge1 = torch.cat([u1, c1], dim=1)
        c1d = self.db1(merge1)

        # Raw logits: the sigmoid is applied by the loss / at inference time.
        logits = self.final(c1d)
        # Drop any padding added above (a no-op when nothing was padded).
        return logits[..., :depth, :height, :width]

In [3]:
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

# Run on the GPU when one is available; the model and loss weight live there.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print("Using:", device)

# Each sample is one time frame (1, X, Y, Z) paired with its mask frame;
# frames are reshuffled every epoch.
dataset = TensorDataset(img, seg)
loader = DataLoader(dataset, shuffle=True, batch_size=2)

model = UNet3D().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# The network outputs raw logits, so the fused sigmoid + BCE loss is used.
# pos_weight=5 weights the positive (mask == 1) term 5x to offset class imbalance.
positive_weight = torch.tensor([5.0], device=device)
loss_fn = nn.BCEWithLogitsLoss(pos_weight=positive_weight)

epochs = 10

Using: cuda


In [4]:
# Train for `epochs` passes over the shuffled frames, printing the mean
# per-batch loss of each epoch, then persist the learned weights.
for epoch_no in range(1, epochs + 1):
    model.train()
    epoch_loss = 0.0

    for inputs, targets in tqdm(loader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = loss_fn(logits, targets)
        batch_loss.backward()
        optimizer.step()

        epoch_loss += batch_loss.item()

    mean_loss = epoch_loss / len(loader)
    print(f"Epoch {epoch_no} Loss = {mean_loss:.4f}")


# ------------------------------------------------
# SAVE TRAINED UNET MODEL
# ------------------------------------------------
weights = model.state_dict()
torch.save(weights, "unet_trained.pth")
print("✅ Model saved as unet_trained.pth")

100%|██████████████████████████████████████████████████████████████████████████████████| 88/88 [17:40<00:00, 12.05s/it]


Epoch 1 Loss = 0.4111


100%|██████████████████████████████████████████████████████████████████████████████████| 88/88 [17:42<00:00, 12.07s/it]


Epoch 2 Loss = 0.3323


100%|██████████████████████████████████████████████████████████████████████████████████| 88/88 [17:41<00:00, 12.07s/it]


Epoch 3 Loss = 0.2928


100%|██████████████████████████████████████████████████████████████████████████████████| 88/88 [17:41<00:00, 12.07s/it]


Epoch 4 Loss = 0.2597


100%|██████████████████████████████████████████████████████████████████████████████████| 88/88 [17:41<00:00, 12.07s/it]


Epoch 5 Loss = 0.2349


100%|██████████████████████████████████████████████████████████████████████████████████| 88/88 [17:41<00:00, 12.06s/it]


Epoch 6 Loss = 0.2145


100%|██████████████████████████████████████████████████████████████████████████████████| 88/88 [17:41<00:00, 12.06s/it]


Epoch 7 Loss = 0.1957


100%|██████████████████████████████████████████████████████████████████████████████████| 88/88 [17:41<00:00, 12.06s/it]


Epoch 8 Loss = 0.1792


100%|██████████████████████████████████████████████████████████████████████████████████| 88/88 [17:41<00:00, 12.06s/it]


Epoch 9 Loss = 0.1640


100%|██████████████████████████████████████████████████████████████████████████████████| 88/88 [17:41<00:00, 12.06s/it]

Epoch 10 Loss = 0.1506
✅ Model saved as unet_trained.pth





In [5]:
# Predict a binary mask for every time frame, one frame at a time to bound
# GPU memory, then restore NIfTI axis order.
model.eval()
preds = []

with torch.no_grad():
    for frame in torch.split(img, 1, dim=0):          # each (1,1,X,Y,Z)
        logits = model(frame.to(device))
        # Model returns logits; sigmoid then 0.5 threshold gives a 0/1 mask.
        mask = (torch.sigmoid(logits) > 0.5).float()
        preds.append(mask.cpu())

# (T,1,X,Y,Z) -> (T,X,Y,Z) -> (X,Y,Z,T)
pred_4d = torch.cat(preds, dim=0).squeeze(1).permute(1, 2, 3, 0)
pred_np = pred_4d.numpy()

In [6]:
# Reuse the input scan's affine and header so the predicted mask overlays the
# MRI correctly in viewers. An identity affine (np.eye(4)) would discard voxel
# spacing, orientation and origin.
reference = nib.load("input.nii")
mask = pred_np.astype(np.uint8)   # values are 0/1, so uint8 is lossless
nii = nib.Nifti1Image(mask, reference.affine, reference.header)
nii.set_data_dtype(np.uint8)      # override the input's on-disk dtype
nib.save(nii, "predicted_segmentation.nii")

print("Saved predicted_segmentation.nii")

Saved predicted_segmentation.nii
