# Live Demo: Training a 3D UNet with MONAI

## Set up environment

In [None]:
!pip -q install "monai-weekly" "torch>=2.1" "tqdm"

In [None]:
import monai, torch, os, tempfile, matplotlib.pyplot as plt
from monai.data import DataLoader, CacheDataset
from monai.networks.nets import UNet
from monai.losses import DiceLoss
from monai.metrics import DiceMetric

## Download the dataset

In [None]:

from monai.apps import download_and_extract
# Work inside a fresh temporary directory: the archive is saved there and
# unpacked next to it.
root_dir = tempfile.mkdtemp()
resource = 'https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar'
compressed_file = os.path.join(root_dir, "spleen.tar")
download_and_extract(
    url=resource,
    filepath=compressed_file,
    output_dir=root_dir,
)


## Prepare the dataset

In [None]:
from monai.data import Dataset
from glob import glob
# Collect the training volumes and their segmentation masks. Both listings
# are sorted so that the i-th image lines up with the i-th label (this
# assumes the two folders use matching file names).
data_dir = os.path.join(root_dir, "Task09_Spleen")
images = sorted(glob(os.path.join(data_dir, "imagesTr", "*.nii.gz")))
labels = sorted(glob(os.path.join(data_dir, "labelsTr", "*.nii.gz")))

# Fail early with a clear message: an empty listing would otherwise only
# surface as an IndexError below, and zip() silently drops unpaired files.
if not images:
    raise FileNotFoundError(f"No training images found under {data_dir}")
if len(images) != len(labels):
    raise ValueError(
        f"Found {len(images)} images but {len(labels)} labels under {data_dir}"
    )

# One record per case; the keys are the ones the transforms look up later.
train_files = [{"image": img, "label": lbl} for img, lbl in zip(images, labels)]

print(f"Number of files: {len(train_files)}")
print(train_files[0])


In [None]:
from monai import transforms as T

def make_transform(n_pixels: int = 64):
    """Build the pipeline that turns a file record into model inputs.

    Args:
        n_pixels: Edge length of the first two spatial axes after resizing.
            The third axis is resized to ``n_pixels // 2``.

    Returns:
        A ``T.Compose`` that takes a dict with ``'image'`` and ``'label'``
        file paths and returns a dict holding the loaded, channel-first,
        resized volumes, with the image intensities rescaled.
    """
    keys = ['image', 'label']
    spatial_size = (n_pixels, n_pixels, n_pixels // 2)

    return T.Compose([
        # step 1: load the data, nii.gz format to tensor
        T.LoadImaged(keys=keys),
        # step 2: add an extra "channel" dimension (pytorch/monai convention)
        T.EnsureChannelFirstd(keys=keys),
        # step 3: resize the data to a uniform size. The volumes are 3D, so
        # the image needs 'trilinear' interpolation ('bilinear' is a 2D-only
        # mode in torch.nn.functional.interpolate). The label uses 'nearest'
        # so that no new, in-between class values are created.
        T.ResizeD(keys=keys, spatial_size=spatial_size, mode=['trilinear', 'nearest']),
        # step 4: rescale the image intensity between 0 and 1
        T.ScaleIntensityD(keys=['image']),
    ])


# Higher-resolution pipeline, used only for the visual check in the next cell.
transform = make_transform(256)

In [None]:
# Run the pipeline on the first case and inspect the result.
data = transform(train_files[0])
image = data['image']
label = data['label']

# Left panel: one slice (index 75 along the last axis) of the image.
# Right panel: the same slice with the label drawn semi-transparently on top.
slice_index = 75
fig, ax = plt.subplots(1, 2)
ax[0].imshow(image[0, ..., slice_index])
ax[1].imshow(image[0, ..., slice_index])
ax[1].imshow(label[0, ..., slice_index], alpha=0.5)
fig.tight_layout()

# Summary statistics and shapes of the transformed sample.
print(f"pixel mean: {image.mean()}")
print(f"pixel std: {image.std()}")
print(f"Image shape: {image.shape}")
print(f"Label shape: {label.shape}")
print(f"Label values: {torch.unique(label)}")

In [None]:
# Create the dataset and dataloader.
# cache_rate=1 asks CacheDataset to cache every item, so the (deterministic)
# preprocessing is not repeated on every epoch.
train_ds = train_dataset = CacheDataset(train_files, transform=make_transform(64), cache_rate=1)

# Use MONAI's DataLoader (imported at the top of the notebook) instead of the
# plain torch one: it is a drop-in subclass that collates MONAI's
# metadata-carrying tensors with MONAI's own collate function.
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)

# Sanity check: pull a single batch and look at the tensor shapes.
batch = next(iter(train_loader))
image = batch['image']
label = batch['label']
print(f'Image shape: {image.shape}')
print(f'Label shape: {label.shape}')


## Prepare the model for training

In [None]:
# Define a 3D U-Net using MONAI: one input channel (the scan) and two output
# channels (one score map per label class).
unet_config = dict(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128),
    strides=(2, 2, 2),
    num_res_units=2,
)
model = UNet(**unet_config)

# Total number of scalar entries across all parameter tensors.
n_params = sum(map(torch.Tensor.numel, model.parameters()))
print(f"Model has {n_params} parameters.")

In [None]:
# Device to run calculations on - "cuda" means gpu; fall back to the cpu
# when no gpu is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device {device}")

# Move the model's parameters to the chosen device (done in place).
model.to(device)

# Dice loss on softmax-ed network outputs; the integer label map is one-hot
# encoded inside the loss so that it matches the two output channels.
loss_fn = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-4)


In [None]:
from tqdm import tqdm


# Train for a fixed number of passes over the training set.
max_epochs = 100
for epoch in range(max_epochs):
    model.train()  # make sure training-mode layer behavior is active
    epoch_loss = 0.0

    # Number epochs from 1 in the progress bar so that it agrees with the
    # summary line printed at the end of each epoch.
    progress = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{max_epochs}", leave=False)
    for batch_data in progress:
        inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)

        # Standard optimization step: reset gradients, forward pass, loss,
        # backward pass, parameter update.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        # .item() extracts a plain Python float from the loss tensor.
        epoch_loss += loss.item()

    print(f"Epoch {epoch+1}, avg loss: {epoch_loss/len(train_loader):.4f}")


In [None]:
# Switch to evaluation mode and predict the segmentation of the first
# training case; no gradients are needed for inference.
model.eval()
with torch.no_grad():
    sample = train_ds[0]
    # unsqueeze(0) adds the batch dimension the network expects
    input_volume = sample["image"].unsqueeze(0).to(device)
    logits = model(input_volume)
    # argmax over the channel axis picks the highest-scoring class per voxel
    pred = torch.argmax(logits, dim=1).cpu()[0]

import numpy as np, matplotlib.pyplot as plt
# Show the middle slice (along the last axis) of the image, the ground
# truth and the prediction side by side.
mid_slice = pred.shape[-1]//2
panels = [
    ('Image', input_volume.cpu()[0, 0, :, :, mid_slice]),
    ('Ground truth', sample["label"][0, :, :, mid_slice]),
    ('Prediction', pred[:, :, mid_slice]),
]
plt.figure(figsize=(12, 4))
for position, (title, slice_2d) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.imshow(slice_2d, cmap='gray')
    plt.title(title)
    plt.axis('off')
plt.show()
