In [4]:
#!/usr/bin/env python
"""
Quick sanity check that PyTorch + CUDA + MONAI all work together.

Run:
    python sanity_check_monai.py

Exit code 0 = all checks passed.
"""

import sys
import traceback
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import monai
from monai.transforms import Compose, ScaleIntensityd, RandFlipd, EnsureChannelFirstd, EnsureTyped
from monai.data import Dataset, DataLoader
from monai.networks.nets import UNet

RESULTS = []  # one (name, passed_bool, msg) tuple per reported check, in call order


def report(name, ok, msg=""):
    """Record the outcome of one check and echo it to stdout.

    Args:
        name: Human-readable name of the check.
        ok: Truthy when the check passed, falsy otherwise.
        msg: Optional detail text (version info, traceback, ...).

    The tuple ``(name, ok, msg)`` is appended to the module-level ``RESULTS``
    list, which ``main`` later reads to build the summary.
    """
    outcome = (name, ok, msg)
    RESULTS.append(outcome)
    if ok:
        tag = "PASS"
    else:
        tag = "FAIL"
    print("[" + tag + "] " + f"{name} - {msg}")


def check_versions():
    """Report the torch / CUDA / cuDNN / MONAI version information as one check.

    The check passes whenever all version attributes can be read; any
    exception raised while collecting them is reported as a failure with the
    exception text as the message.
    """
    name = "Version Info"
    try:
        # Insertion order of this dict is the order of fields in the message.
        fields = {
            "torch": torch.__version__,
            "cuda_available": torch.cuda.is_available(),
            "cuda_version": torch.version.cuda,
            "cudnn_enabled": torch.backends.cudnn.enabled,
            "cudnn_version": torch.backends.cudnn.version(),
            "monai": monai.__version__,
        }
        summary = ", ".join(f"{key}={value}" for key, value in fields.items())
        report(name, True, summary)
    except Exception as e:
        report(name, False, str(e))


def check_cuda_basic():
    """Check that a small matrix multiplication runs on the CUDA device.

    Reports a failure immediately when ``torch.cuda.is_available()`` is False.
    Otherwise multiplies two random 8x8 CUDA tensors, synchronizes, and passes
    if every element of the product is finite. Any exception is reported as a
    failure with its full traceback.
    """
    name = "CUDA Basic Tensor Ops"
    cuda_ready = torch.cuda.is_available()
    if not cuda_ready:
        report(name, False, "torch.cuda.is_available() is False")
        return

    try:
        gpu = torch.device("cuda")
        lhs = torch.randn(8, 8, device=gpu)
        rhs = torch.randn(8, 8, device=gpu)
        product = torch.matmul(lhs, rhs)
        # Wait for the queued kernel so execution errors surface inside this try.
        torch.cuda.synchronize()
        if not torch.isfinite(product).all():
            report(name, False, "Non-finite result.")
        else:
            report(name, True, f"Device: {torch.cuda.get_device_name(gpu)}")
    except Exception:
        report(name, False, traceback.format_exc())


def make_synthetic_dict_batch_2d(batch_size=4, img_size=(1, 64, 64), num_classes=2):
    """Build random 2D samples as a list of ``{"image", "label"}`` dicts.

    The result is a plain ``list[dict]``, the form ``monai.data.Dataset``
    takes for dictionary transforms.

    Args:
        batch_size: Number of sample dicts to generate.
        img_size: Shape of each image, channel first, e.g. ``(C, H, W)``.
        num_classes: Labels are drawn from ``0 .. num_classes - 1``.

    Returns:
        A list of ``batch_size`` dicts. ``"image"`` is a float32 array of
        shape ``img_size`` from ``np.random.rand``; ``"label"`` is an int16
        array of shape ``img_size[1:]`` (no channel axis).
    """
    label_shape = img_size[1:]  # H, W

    def _one_sample():
        # Image is drawn before label so the global RNG is consumed in a fixed order.
        image = np.random.rand(*img_size).astype(np.float32)
        label = np.random.randint(0, num_classes, size=label_shape, dtype=np.int16)
        return {"image": image, "label": label}

    return [_one_sample() for _ in range(batch_size)]


def make_synthetic_dict_batch_3d(batch_size=2, img_size=(1, 32, 32, 32), num_classes=2):
    """Build random 3D samples as a list of ``{"image", "label"}`` dicts.

    Args:
        batch_size: Number of sample dicts to generate.
        img_size: Shape of each image, channel first, e.g. ``(C, D, H, W)``.
        num_classes: Labels are drawn from ``0 .. num_classes - 1``.

    Returns:
        A list of ``batch_size`` dicts. ``"image"`` is a float32 array of
        shape ``img_size`` from ``np.random.rand``; ``"label"`` is an int16
        array of shape ``img_size[1:]`` (D, H, W — no channel axis).
    """
    # Dict entries evaluate left to right, so each sample draws its image
    # from the global RNG before its label.
    return [
        {
            "image": np.random.rand(*img_size).astype(np.float32),
            "label": np.random.randint(0, num_classes, size=img_size[1:], dtype=np.int16),
        }
        for _ in range(batch_size)
    ]


def check_monai_transforms():
    """Check that a MONAI dict-transform pipeline and DataLoader yield a batch.

    Builds a ``Compose`` of dictionary transforms, wraps synthetic 2D samples
    in a ``Dataset``, pulls the first batch from a ``DataLoader`` and checks
    the dimensionality of the collated image and label. Any exception,
    including a failed shape check, is reported as a failure with its
    traceback.
    """
    name = "MONAI Dict Transforms (2D)"
    try:
        transforms = Compose([
            EnsureChannelFirstd(keys=["image"], channel_dim=0),  # already channel-first but safe
            ScaleIntensityd(keys=["image"]),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=-1),
            EnsureTyped(keys=["image", "label"]),
        ])
        ds = Dataset(data=make_synthetic_dict_batch_2d(), transform=transforms)
        loader = DataLoader(ds, batch_size=2, shuffle=False, num_workers=2)
        batch = next(iter(loader))
        img, lbl = batch["image"], batch["label"]
        # Raise explicitly instead of using `assert`: assert statements are
        # removed under `python -O`, which would turn this check into a no-op.
        # Expect image shape: (B, C, H, W)
        if img.ndim != 4:
            raise AssertionError(f"Expected 4D (B,C,H,W), got {img.shape}")
        if lbl.ndim not in (3, 4):
            raise AssertionError(f"Label expected 3D or 4D; got {lbl.shape}")
        report(name, True, f"image batch shape={tuple(img.shape)}")
    except Exception:
        report(name, False, traceback.format_exc())


def tiny_unet_2d(in_channels=1, out_channels=2, channels=(8, 16), strides=(2,), num_res_units=1):
    """Construct a small MONAI ``UNet`` with ``spatial_dims=2``.

    All arguments are forwarded unchanged to ``UNet`` under the same keyword
    names; the defaults give a two-level network.
    """
    config = dict(
        spatial_dims=2,
        in_channels=in_channels,
        out_channels=out_channels,
        channels=channels,
        strides=strides,
        num_res_units=num_res_units,
    )
    return UNet(**config)


def tiny_unet_3d(in_channels=1, out_channels=2, channels=(8, 16), strides=(2,), num_res_units=1):
    """Construct a small MONAI ``UNet`` with ``spatial_dims=3``.

    All arguments are forwarded unchanged to ``UNet`` under the same keyword
    names; the defaults give a two-level network.
    """
    net = UNet(
        **{
            "spatial_dims": 3,
            "in_channels": in_channels,
            "out_channels": out_channels,
            "channels": channels,
            "strides": strides,
            "num_res_units": num_res_units,
        }
    )
    return net


def _train_step(model, inputs, labels, device):
    """Run one forward/backward pass and return the cross-entropy loss.

    Args:
        model: Module mapping ``inputs`` to logits of shape ``(B, C, ...)``.
        inputs: Input batch tensor; moved to ``device``.
        labels: Integer class labels, shaped either ``(B, ...)`` or with an
            extra axis at dim 1 (same rank as ``inputs``); moved to ``device``
            and cast to ``long``.
        device: Device the batch is moved to before the forward pass.

    Returns:
        The loss as a Python float. Gradients are accumulated on the model's
        parameters; zeroing gradients and the optimizer step are left to the
        caller.
    """
    model.train()
    batch = inputs.to(device)
    target = labels.to(device)
    # cross_entropy takes class indices as long, shaped (B, ...) against logits (B, C, ...).
    target = target.long()
    # Labels with the same rank as the input carry an axis at dim 1; squeeze it.
    if target.dim() == batch.dim():
        target = target.squeeze(1)
    loss = F.cross_entropy(model(batch), target)
    loss.backward()
    return loss.item()


def check_monai_unet_2d_training():
    """Check one forward/backward/optimizer step of a tiny 2D UNet on CUDA.

    Builds the model with ``tiny_unet_2d``, takes one synthetic sample, runs
    ``_train_step`` followed by an Adam step, and reports the loss. Reports a
    failure with a short message when CUDA is unavailable, and with a full
    traceback on any exception.
    """
    name = "MONAI UNet 2D Forward/Backward"
    # This check always runs on the GPU. Without CUDA, fail with the same
    # short message check_cuda_basic uses instead of a traceback from `.to()`.
    if not torch.cuda.is_available():
        report(name, False, "torch.cuda.is_available() is False")
        return

    try:
        device = torch.device("cuda")
        model = tiny_unet_2d().to(device)
        opt = torch.optim.Adam(model.parameters(), lr=1e-3)

        # One synthetic sample; unsqueeze(0) adds the batch axis.
        ds = Dataset(data=make_synthetic_dict_batch_2d(batch_size=2), transform=None)
        batch = ds[0]
        img = torch.from_numpy(batch["image"]).unsqueeze(0)  # (1,C,H,W)
        lbl = torch.from_numpy(batch["label"]).unsqueeze(0)  # (1,H,W)

        opt.zero_grad(set_to_none=True)
        loss_val = _train_step(model, img, lbl, device)
        opt.step()
        report(name, True, f"loss={loss_val:.4f}")
    except Exception:
        report(name, False, traceback.format_exc())


def check_monai_unet_3d_training():
    """Check one forward/backward/optimizer step of a tiny 3D UNet on CUDA.

    Builds the model with ``tiny_unet_3d``, takes one synthetic volume, runs
    ``_train_step`` followed by an Adam step, and reports the loss. Reports a
    failure with a short message when CUDA is unavailable, and with a full
    traceback on any exception.
    """
    name = "MONAI UNet 3D Forward/Backward"
    # This check always runs on the GPU. Without CUDA, fail with the same
    # short message check_cuda_basic uses instead of a traceback from `.to()`.
    if not torch.cuda.is_available():
        report(name, False, "torch.cuda.is_available() is False")
        return

    try:
        device = torch.device("cuda")
        model = tiny_unet_3d().to(device)
        opt = torch.optim.Adam(model.parameters(), lr=1e-3)

        # One synthetic volume; unsqueeze(0) adds the batch axis.
        ds = Dataset(data=make_synthetic_dict_batch_3d(batch_size=1), transform=None)
        batch = ds[0]
        img = torch.from_numpy(batch["image"]).unsqueeze(0)  # (1,C,D,H,W)
        lbl = torch.from_numpy(batch["label"]).unsqueeze(0)  # (1,D,H,W)

        opt.zero_grad(set_to_none=True)
        loss_val = _train_step(model, img, lbl, device)
        opt.step()
        report(name, True, f"loss={loss_val:.4f}")
    except Exception:
        report(name, False, traceback.format_exc())


def main():
    """Run every sanity check in order, print a summary, and return an exit code.

    Returns:
        0 when every entry recorded in ``RESULTS`` passed, 1 otherwise.
    """
    print("=== PyTorch + MONAI Sanity Check ===")
    checks = (
        check_versions,
        check_cuda_basic,
        check_monai_transforms,
        check_monai_unet_2d_training,
        check_monai_unet_3d_training,
    )
    for run_check in checks:
        run_check()

    print("\n=== Summary ===")
    for name, ok, msg in RESULTS:
        status = "PASS" if ok else "FAIL"
        print(f"{status:5s} | {name:30s} | {msg}")

    if not all(ok for _, ok, _ in RESULTS):
        print("\nOne or more checks FAILED. Scroll up for traceback(s).")
        return 1

    print("\nAll checks passed.")
    return 0


# Script entry point: main() returns 0 when all checks passed and 1 otherwise;
# that value is handed to sys.exit() as the process exit code.
if __name__ == "__main__":
    sys.exit(main())


=== PyTorch + MONAI Sanity Check ===
[PASS] Version Info - torch=2.7.1+cu128, cuda_available=True, cuda_version=12.8, cudnn_enabled=True, cudnn_version=90701, monai=1.5.0
[PASS] CUDA Basic Tensor Ops - Device: NVIDIA GeForce RTX 5090
[PASS] MONAI Dict Transforms (2D) - image batch shape=(2, 1, 64, 64)
[PASS] MONAI UNet 2D Forward/Backward - loss=0.8133
[PASS] MONAI UNet 3D Forward/Backward - loss=0.8152

=== Summary ===
PASS  | Version Info                   | torch=2.7.1+cu128, cuda_available=True, cuda_version=12.8, cudnn_enabled=True, cudnn_version=90701, monai=1.5.0
PASS  | CUDA Basic Tensor Ops          | Device: NVIDIA GeForce RTX 5090
PASS  | MONAI Dict Transforms (2D)     | image batch shape=(2, 1, 64, 64)
PASS  | MONAI UNet 2D Forward/Backward | loss=0.8133
PASS  | MONAI UNet 3D Forward/Backward | loss=0.8152

All checks passed.


SystemExit: 0