In [1]:
import os
import torch
import torch.nn as nn
from torch.optim import Adam
from monai.networks.nets import UNet
from monai.data import Dataset, DataLoader, pad_list_data_collate
from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd, ToTensord, ResizeD
import numpy as np
from scipy.ndimage import label

# ------------------------------------------------
# 1. Prepare the 2D Datasets and DataLoaders
# ------------------------------------------------

train_2d_dir = './dataset/train_2d'
val_2d_dir = './dataset/val_2d'

train_data_2d = []
for patient in os.listdir(train_2d_dir):
    patient_path = os.path.join(train_2d_dir, patient)
    if os.path.isdir(patient_path):
        for file in os.listdir(patient_path):
            if file.endswith('.nii') and '_gt' not in file:
                image_path = os.path.join(patient_path, file)
                if '_slice' in file:
                    parts = file.split('_slice')
                    gt_file = parts[0] + '_gt_slice' + parts[1]
                else:
                    gt_file = file.replace('.nii', '_gt.nii')
                gt_path = os.path.join(patient_path, gt_file)
                if os.path.exists(gt_path):
                    train_data_2d.append({"image": image_path, "label": gt_path})
                else:
                    print("Warning: No corresponding ground truth for", image_path)
print(f"Found {len(train_data_2d)} 2D training samples.")

val_data_2d = []
for patient in os.listdir(val_2d_dir):
    patient_path = os.path.join(val_2d_dir, patient)
    if os.path.isdir(patient_path):
        for file in os.listdir(patient_path):
            if file.endswith('.nii') and '_gt' not in file:
                image_path = os.path.join(patient_path, file)
                if '_slice' in file:
                    parts = file.split('_slice')
                    gt_file = parts[0] + '_gt_slice' + parts[1]
                else:
                    gt_file = file.replace('.nii', '_gt.nii')
                gt_path = os.path.join(patient_path, gt_file)
                if os.path.exists(gt_path):
                    val_data_2d.append({"image": image_path, "label": gt_path})
                else:
                    print("Warning: No corresponding ground truth for", image_path)
print(f"Found {len(val_data_2d)} 2D validation samples.")

transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    ResizeD(keys=["image", "label"], spatial_size=(352, 352)),
    ToTensord(keys=["image", "label"]),
])

train_ds_2d = Dataset(data=train_data_2d, transform=transforms)
val_ds_2d = Dataset(data=val_data_2d, transform=transforms)

train_loader_2d = DataLoader(train_ds_2d, batch_size=10, shuffle=True, collate_fn=pad_list_data_collate)
val_loader_2d = DataLoader(val_ds_2d, batch_size=10, shuffle=False, collate_fn=pad_list_data_collate)

# ------------------------------------------------
# 2. Build the 2D U-Net Model and Training Setup
# ------------------------------------------------

in_channels = 1
out_channels = 4
initial_features = 48

channels = (initial_features, initial_features*2, initial_features*4, initial_features*8)
strides = (2, 2, 2)
num_res_units = 2

model = UNet(
    spatial_dims=2,
    in_channels=in_channels,
    out_channels=out_channels,
    channels=channels,
    strides=strides,
    num_res_units=num_res_units,
    norm="batch",
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# ------------------------------------------------
# Weighted Cross-Entropy Loss Setup
# ------------------------------------------------
class_weights = torch.tensor([0.2, 0.3, 0.3, 0.2], device=device, dtype=torch.float)
loss_function = nn.CrossEntropyLoss(weight=class_weights)
optimizer = Adam(model.parameters(), lr=1e-4)

# ------------------------------------------------
# 3. Utility: Keep Largest Connected Component
# ------------------------------------------------
def keep_largest_component(pred_mask):
    """
    Given a 2D predicted mask (H, W) with integer labels,
    return a new mask where for each non-background class (label != 0),
    only the largest connected component is kept.
    """
    output = np.zeros_like(pred_mask)
    for cls in np.unique(pred_mask):
        if cls == 0:
            continue  # skip background
        # create binary mask for this class
        binary_mask = (pred_mask == cls).astype(np.int32)
        labeled_array, num_features = label(binary_mask)
        if num_features == 0:
            continue
        sizes = np.bincount(labeled_array.ravel())
        sizes[0] = 0  # ignore background
        largest_label = sizes.argmax()
        largest_component = (labeled_array == largest_label)
        output[largest_component] = cls
    return output

# ------------------------------------------------
# 4. Training Loop with Validation Postprocessing
# ------------------------------------------------

num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    for batch_data in train_loader_2d:
        inputs = batch_data["image"].to(device)   # (B, 1, 352, 352)
        labels = batch_data["label"].to(device)     # (B, 1, 352, 352)
        labels = labels.squeeze(1).long()
        
        optimizer.zero_grad()
        outputs = model(inputs)                     # (B, 4, 352, 352)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    
    avg_loss = epoch_loss / len(train_loader_2d)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Training Loss: {avg_loss:.4f}")
    
    # Validation loop with postprocessing
    model.eval()
    with torch.no_grad():
        val_loss = 0.0
        for val_data in val_loader_2d:
            val_inputs = val_data["image"].to(device)
            val_labels = val_data["label"].to(device).squeeze(1).long()
            val_outputs = model(val_inputs)
            loss_val = loss_function(val_outputs, val_labels)
            val_loss += loss_val.item()
            
            # Obtain predictions (before postprocessing).
            preds_batch = torch.argmax(val_outputs, dim=1)  # shape: (B, 352, 352)
            
            # Now apply postprocessing per sample: keep largest connected component for each class.
            preds_np = preds_batch.cpu().numpy()
            processed_preds = []
            for pred in preds_np:
                # Process each 2D slice individually.
                processed_pred = keep_largest_component(pred)
                processed_preds.append(processed_pred)
            # You can further evaluate the processed_preds if desired.
            
        avg_val_loss = val_loss / len(val_loader_2d)
        print(f"Epoch {epoch+1}/{num_epochs}, Validation Loss: {avg_val_loss:.4f}")

print("Training complete.")


2025-02-24 20:56:29.422756: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-02-24 20:56:29.461367: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-02-24 20:56:29.461382: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-02-24 20:56:29.461404: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-24 20:56:29.469201: I tensorflow/core/platform/cpu_feature_g

Found 1526 2D training samples.
Found 376 2D validation samples.
Epoch 1/10, Average Training Loss: 1.0136
Epoch 1/10, Validation Loss: 0.8631
Epoch 2/10, Average Training Loss: 0.7721
Epoch 2/10, Validation Loss: 0.6838
Epoch 3/10, Average Training Loss: 0.5660
Epoch 3/10, Validation Loss: 0.4691
Epoch 4/10, Average Training Loss: 0.3741
Epoch 4/10, Validation Loss: 0.3181
Epoch 5/10, Average Training Loss: 0.2565
Epoch 5/10, Validation Loss: 0.2333
Epoch 6/10, Average Training Loss: 0.1887
Epoch 6/10, Validation Loss: 0.1742
Epoch 7/10, Average Training Loss: 0.1463
Epoch 7/10, Validation Loss: 0.1461
Epoch 8/10, Average Training Loss: 0.1170
Epoch 8/10, Validation Loss: 0.1254
Epoch 9/10, Average Training Loss: 0.0966
Epoch 9/10, Validation Loss: 0.1068
Epoch 10/10, Average Training Loss: 0.0827
Epoch 10/10, Validation Loss: 0.0869
Training complete.


In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from monai.metrics import DiceMetric, HausdorffDistanceMetric

# 1. Save the model checkpoint.
torch.save(model.state_dict(), "2d_unet_checkpoint_first_model_celoss_largestcomponent.pth")
print("Model checkpoint saved.")

# 2. Visualize predictions on a validation batch.
model.eval()
with torch.no_grad():
    batch = next(iter(val_loader_2d))
    inputs = batch["image"].to(device)
    labels = batch["label"].to(device)
    outputs = model(inputs)
    # Get the predicted segmentation as the argmax over the channel dimension.
    preds = torch.argmax(outputs, dim=1).cpu().numpy()

# Move inputs and labels to CPU for visualization.
inputs = inputs.cpu().numpy()   # shape: (B, 1, H, W)
labels = labels.cpu().numpy()   # shape: (B, 1, H, W)

# Number of samples to display.
num_to_show = min(6, inputs.shape[0])
fig, axs = plt.subplots(3, num_to_show, figsize=(4*num_to_show, 12))
for i in range(num_to_show):
    # Input image.
    axs[0, i].imshow(inputs[i, 0, :, :], cmap="gray")
    axs[0, i].set_title(f"Input {i}")
    axs[0, i].axis("off")
    # Ground truth.
    axs[1, i].imshow(labels[i, 0, :, :], cmap="jet")
    axs[1, i].set_title(f"Ground Truth {i}")
    axs[1, i].axis("off")
    # Prediction.
    axs[2, i].imshow(preds[i, :, :], cmap="jet")
    axs[2, i].set_title(f"Prediction {i}")
    axs[2, i].axis("off")
plt.tight_layout()
plt.show()

# 3. Evaluate the model on the full validation set.
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
hd_metric = HausdorffDistanceMetric(include_background=True, percentile=95)

total_correct = 0
total_pixels = 0

model.eval()
with torch.no_grad():
    for val_batch in val_loader_2d:
        val_inputs = val_batch["image"].to(device)
        val_labels = val_batch["label"].to(device)
        val_outputs = model(val_inputs)
        # Obtain predictions.
        preds_batch = torch.argmax(val_outputs, dim=1, keepdim=True)  # shape: (B, 1, H, W)
        # Update metrics.
        dice_metric(y_pred=preds_batch, y=val_labels)
        hd_metric(y_pred=preds_batch, y=val_labels)
        # Pixel-wise accuracy.
        total_correct += (preds_batch == val_labels).sum().item()
        total_pixels += torch.numel(val_labels)

dice_score = dice_metric.aggregate().item()
hd_score = hd_metric.aggregate().item()
accuracy = total_correct / total_pixels

print(f"Dice Score: {dice_score:.4f}")
print(f"Hausdorff Distance (95th percentile): {hd_score:.4f}")
print(f"Pixel-wise Accuracy: {accuracy:.4f}")


2025-02-24 21:56:25.665484: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-02-24 21:56:25.703519: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-02-24 21:56:25.703539: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-02-24 21:56:25.703556: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-24 21:56:25.710885: I tensorflow/core/platform/cpu_feature_g

NameError: name 'model' is not defined