In [None]:
import torch
import pytorch_lightning as pl
from monai.networks.nets import UNETR
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.data import CacheDataset, DataLoader, load_decathlon_datalist
from monai.transforms import (
    AsDiscrete,
    Compose,
    EnsureChannelFirstd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    RandFlipd,
    RandRotate90d,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    Resized,
)
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, list_data_collate, decollate_batch, DataLoader
from monai.losses import DiceLoss
import pytorch_lightning as pl


In [None]:
import torch
import pytorch_lightning as pl
from monai.networks.nets import UNETR
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.data import CacheDataset, DataLoader, load_decathlon_datalist
from monai.transforms import (
    AsDiscrete,
    Compose,
    EnsureChannelFirstd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    RandFlipd,
    RandRotate90d,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    Resized,
    EnsureType,
    EnsureTyped,
    CropForegroundd,
)

class MONAIUNETR3DSegmentation(pl.LightningModule):
    """Lightning module training a MONAI UNETR for 2-class (background/foreground) 3D segmentation.

    Samples are ``{"image": path, "label": path}`` dicts. Every volume is loaded,
    resampled, intensity-windowed to [-1024, 1024] -> [0, 1], cropped to the image
    foreground and resized to ``roi_size`` so it matches UNETR's fixed input size.

    Args:
        train_datalist: list of ``{"image", "label"}`` path dicts used for training.
        val_datalist: list of ``{"image", "label"}`` path dicts used for validation.
        root_dir: stored on the instance; not otherwise used by this class.
        batch_size: training batch size (validation always uses batch size 1).
        learning_rate: Adam learning rate.
    """

    # UNETR requires a fixed spatial input size; preprocessing resizes every volume to it.
    ROI_SIZE = (80, 80, 80)
    NUM_CLASSES = 2  # background + foreground

    def __init__(self, train_datalist, val_datalist, root_dir, batch_size=1, learning_rate=1e-4):
        super().__init__()
        self.train_datalist = train_datalist
        self.val_datalist = val_datalist
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.roi_size = self.ROI_SIZE

        self._model = UNETR(
            in_channels=1,
            out_channels=self.NUM_CLASSES,
            img_size=self.roi_size,
            feature_size=16,
            hidden_size=768,
            mlp_dim=3072,
            num_heads=12,
            pos_embed="perceptron",
            norm_name="instance",
            res_block=True,
            dropout_rate=0.0,
        )

        self.loss_function = DiceLoss(to_onehot_y=True, softmax=True)
        # DiceMetric expects binarized one-hot tensors, so per-sample logits are
        # argmax'ed + one-hot encoded and labels are one-hot encoded before scoring.
        self.post_pred = Compose(
            [EnsureType("tensor", device="cpu"), AsDiscrete(argmax=True, to_onehot=self.NUM_CLASSES)]
        )
        self.post_label = Compose([EnsureType("tensor", device="cpu"), AsDiscrete(to_onehot=self.NUM_CLASSES)])
        self.dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
        self.best_val_dice = 0
        self.best_val_epoch = 0
        self.validation_step_outputs = []

        # Alternative preprocessing at a finer spacing; not used by the datasets built below.
        self.common_transforms = self._deterministic_transforms(
            pixdim=(1.1, 1.1, 1.40), spatial_size=self.roi_size
        )

        # Populated lazily by _build_datasets (via prepare_data / setup).
        self.train_ds = None
        self.val_ds = None

    @staticmethod
    def _deterministic_transforms(pixdim, spatial_size):
        """Return the shared load -> resample -> reorient -> window -> crop -> resize pipeline.

        Args:
            pixdim: target voxel spacing passed to ``Spacingd``.
            spatial_size: output size passed to ``Resized``.
        """
        keys = ["image", "label"]
        return Compose(
            [
                LoadImaged(keys=keys),
                EnsureChannelFirstd(keys=keys),
                # Labels use nearest interpolation so they stay integer class indices.
                Spacingd(keys=keys, pixdim=pixdim, mode=("bilinear", "nearest")),
                Orientationd(keys=keys, axcodes="RAS"),
                ScaleIntensityRanged(
                    keys=["image"],
                    a_min=-1024,
                    a_max=1024,
                    b_min=0.0,
                    b_max=1.0,
                    clip=True,
                ),
                CropForegroundd(keys=keys, source_key="image"),
                Resized(keys=keys, spatial_size=spatial_size, mode=("trilinear", "nearest")),
                EnsureTyped(keys=keys),
            ]
        )

    def forward(self, x):
        """Return raw per-class logits from the UNETR backbone."""
        return self._model(x)

    def _build_datasets(self):
        """Create the cached train/validation datasets once; later calls are no-ops."""
        if self.train_ds is not None and self.val_ds is not None:
            return
        # Train and validation share the same deterministic preprocessing so validation
        # Dice is measured on the same input distribution the model is trained on.
        # Random augmentations (RandCropByPosNegLabeld, RandFlipd, RandRotate90d,
        # RandShiftIntensityd) can be appended to the training pipeline only.
        train_transforms = self._deterministic_transforms(pixdim=(1.5, 1.5, 2.0), spatial_size=self.roi_size)
        val_transforms = self._deterministic_transforms(pixdim=(1.5, 1.5, 2.0), spatial_size=self.roi_size)

        self.train_ds = CacheDataset(data=self.train_datalist, transform=train_transforms, cache_rate=1.0, num_workers=4)
        self.val_ds = CacheDataset(data=self.val_datalist, transform=val_transforms, cache_rate=1.0, num_workers=4)

    def prepare_data(self):
        """Build the cached datasets (kept for callers that invoke it directly)."""
        self._build_datasets()

    def setup(self, stage=None):
        """Ensure datasets exist on every process; Lightning runs prepare_data on one process only."""
        self._build_datasets()

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=4)

    def configure_optimizers(self):
        """Adam over the UNETR parameters at ``self.learning_rate``."""
        return torch.optim.Adam(self._model.parameters(), lr=self.learning_rate)

    def training_step(self, batch, batch_idx):
        """Compute Dice loss on one batch and log it as ``train_loss``."""
        images, labels = batch["image"], batch["label"]
        loss = self.loss_function(self(images), labels)
        # A returned {"log": ...} dict is ignored by current Lightning; log explicitly.
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, batch_size=images.shape[0])
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute validation loss and accumulate per-sample Dice on binarized one-hot masks."""
        images, labels = batch["image"], batch["label"]
        logits = self(images)
        val_loss = self.loss_function(logits, labels)
        pred_masks = [self.post_pred(pred) for pred in decollate_batch(logits)]
        label_masks = [self.post_label(lbl) for lbl in decollate_batch(labels)]
        self.dice_metric(y_pred=pred_masks, y=label_masks)
        self.log("val_loss", val_loss, prog_bar=True, batch_size=images.shape[0])
        return val_loss

    def on_validation_epoch_end(self):
        """Aggregate epoch Dice, track the best epoch, log ``val_dice`` and reset the metric."""
        dice_score = self.dice_metric.aggregate().item()
        self.dice_metric.reset()
        if dice_score > self.best_val_dice:
            self.best_val_dice = dice_score
            self.best_val_epoch = self.current_epoch
        self.log("val_dice", dice_score, prog_bar=True)
        print(f"Validation Dice Score: {dice_score:.4f}")  # Print validation dice score



In [None]:
# Checkpointing: keep only the single best model, ranked by the "val_dice"
# value logged during validation (higher Dice is better).
CHECKPOINT_DIR = "checkpoints_unetr_test/"

checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor="val_dice",
    mode="max",
    save_top_k=1,
    dirpath=CHECKPOINT_DIR,
    filename="best-checkpoint",
    verbose=True,
)



In [None]:
import glob
import os

# AeroPath layout: each case holds a CT (*_CT_HR.nii.gz) and an airway mask
# (*_CT_HR_label_airways.nii.gz). The last N_VAL_CASES cases are held out for validation.
AEROPATH_DIR = "AeroPath"
IMAGE_SUFFIX = "_CT_HR.nii.gz"
LABEL_SUFFIX = "_CT_HR_label_airways.nii.gz"
N_VAL_CASES = 9

train_images = sorted(glob.glob(os.path.join(AEROPATH_DIR, f"**/*{IMAGE_SUFFIX}"), recursive=True))

# Pair each CT with its label by file name rather than by sorted position, so a
# missing label raises instead of silently shifting every later image/label pair.
data_dicts = []
for image_path in train_images:
    label_path = image_path[: -len(IMAGE_SUFFIX)] + LABEL_SUFFIX
    if not os.path.exists(label_path):
        raise FileNotFoundError(f"No airway label for {image_path} (expected {label_path})")
    data_dicts.append({"image": image_path, "label": label_path})
train_labels = [item["label"] for item in data_dicts]

if len(data_dicts) <= N_VAL_CASES:
    raise ValueError(
        f"Found {len(data_dicts)} cases under {AEROPATH_DIR!r}; need more than {N_VAL_CASES} "
        "to leave a non-empty training split."
    )
train_files, val_files = data_dicts[:-N_VAL_CASES], data_dicts[-N_VAL_CASES:]

In [None]:
# Train UNETR on the AeroPath split (train_files / val_files) built above.
root_dir = "data/"  # stored on the module as root_dir
MODEL_WEIGHTS_PATH = "unetr_model.pth"

segmentation_task = MONAIUNETR3DSegmentation(train_files, val_files, root_dir)

trainer = pl.Trainer(max_epochs=100, callbacks=[checkpoint_callback])
trainer.fit(segmentation_task)

# Persist the final weights: the inference cell below loads "unetr_model.pth",
# which previously was never written because this save was commented out.
torch.save(segmentation_task.state_dict(), MODEL_WEIGHTS_PATH)


NameError: name 'net' is not defined

In [None]:

# Load the trained weights for inference. train_files / val_files are the AeroPath
# lists built earlier (the former train_datalist / val_datalist names were never defined).
model = MONAIUNETR3DSegmentation(train_files, val_files, root_dir)
# map_location lets GPU-saved weights load on a CPU-only kernel.
model.load_state_dict(torch.load("unetr_model.pth", map_location="cpu"))
model.eval()

# NOTE(review): these test volumes come from the MSD spleen task, while the model is
# trained on AeroPath airways — confirm this is the intended test set. The spleen data
# is downloaded by the last cell of this notebook, so run that first on a fresh kernel.
test_datalist = load_decathlon_datalist("Task09_Spleen/dataset.json", True, "test")

# Preprocessing mirrors training (same spacing, intensity window and foreground crop);
# a different window (e.g. -175..250) would feed the network out-of-distribution inputs.
test_transforms = Compose(
    [
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys=["image"]),
        Spacingd(keys=["image"], pixdim=(1.5, 1.5, 2.0), mode="bilinear"),
        Orientationd(keys=["image"], axcodes="RAS"),
        ScaleIntensityRanged(keys=["image"], a_min=-1024, a_max=1024, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=["image"], source_key="image"),
        Resized(keys=["image"], spatial_size=(80, 80, 80), mode="trilinear"),
        EnsureTyped(keys=["image"]),
    ]
)
test_ds = CacheDataset(data=test_datalist, transform=test_transforms, cache_rate=1.0, num_workers=4)
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=4)

# Predict per-voxel class indices; `predictions` keeps the last batch for the cells below.
with torch.no_grad():
    for batch in test_loader:
        logits = model(batch["image"])
        predictions = torch.argmax(logits, dim=1)
        # Summarize rather than dumping the full 3D tensor.
        print(tuple(predictions.shape), torch.unique(predictions).tolist())

In [None]:
# Imported here because this cell runs before the later `import nibabel` cell on a fresh kernel.
import nibabel as nib
import numpy as np

# List the distinct label values present in one reference spleen segmentation.
label = nib.load("Task09_Spleen/labelsTr/spleen_63.nii.gz")
np.unique(label.get_fdata())

In [None]:
import nibabel as nib
import numpy as np

# Save the first sample of the last inference batch as NIfTI. Writing to a new name
# (instead of reassigning `predictions`) keeps this cell safe to re-run, and
# torch tensors have no .astype, so move to CPU/NumPy first.
prediction_volume = predictions[0].detach().cpu().numpy().astype(float)
# NOTE(review): the prediction lives on the resampled/cropped 80^3 grid, so reusing the
# original label's affine likely misplaces it in scanner space — confirm before relying on it.
reference_affine = nib.load("AeroPath/1/1_CT_HR_label_airways.nii.gz").affine
nib.save(nib.Nifti1Image(prediction_volume, reference_affine), "unetr_test_pred.nii.gz")



In [None]:
import os
from monai.apps import download_and_extract

# MSD Task09 (spleen) archive; md5 is passed as the expected archive hash.
resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar"
md5 = "410d4a301da4e5b2f6f86ec3ddba524e"

download_root = os.getcwd()
compressed_file = os.path.join(download_root, "Task09_Spleen.tar")
data_dir = os.path.join(download_root, "Task09_Spleen")
if not os.path.exists(data_dir):
    # Extract into the same directory the existence check (and the "Task09_Spleen/..."
    # paths used by earlier cells) look in; extracting into root_dir ("data/") made the
    # check always fail and left the data where nothing reads it.
    download_and_extract(resource, compressed_file, output_dir=download_root, hash_val=md5)