# In this notebook I will test loadin `MONAI` and working with it

In [25]:
# !pip install "monai[all]"

## Loading a Pretrained Model in MONAI
MONAI’s `monai.networks.nets` and Model Zoo let you directly pull pretrained models.


In [26]:
# from monai.networks.nets import UNet
from torch import cuda
import torch

device = "cuda" if cuda.is_available() else "cpu"
device, torch.__version__


('cuda', '2.6.0+cu126')

In [27]:
from monai.networks.nets import UNet  
from torchinfo import summary

# model: MONAI UNet for 2D
model = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=3,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)

summary(
    model,
    input_size=(1, 1, 256, 256),  # (batch, channels, height, width)
    col_names=("input_size", "output_size", "num_params", "trainable"),
)



Layer (type:depth-idx)                                                                     Input Shape               Output Shape              Param #                   Trainable
UNet                                                                                       [1, 1, 256, 256]          [1, 3, 256, 256]          --                        True
├─Sequential: 1-1                                                                          [1, 1, 256, 256]          [1, 3, 256, 256]          --                        True
│    └─ResidualUnit: 2-1                                                                   [1, 1, 256, 256]          [1, 16, 128, 128]         --                        True
│    │    └─Conv2d: 3-1                                                                    [1, 1, 256, 256]          [1, 16, 128, 128]         160                       True
│    │    └─Sequential: 3-2                                                                [1, 1, 256, 256]          [1, 16, 

In [28]:
from monai.networks.nets import SegResNetVAE


model_2 = SegResNetVAE(
    input_image_size=(224, 224, 144), # (H, W, D)
    vae_estimate_std=False,
    vae_nz=256,
    spatial_dims=3,
    init_filters=8,
    in_channels=4,
    out_channels=3,
    dropout_prob=None,
    blocks_down=(1, 2, 2, 4),
    blocks_up=(1, 1, 1),
    upsample_mode="nontrainable",
).to(device)

summary(
    model_2,
    input_size=(1, 4, 144, 224, 224),  # (batch, channels, depth, height, width)
    col_names=("input_size", "output_size", "num_params", "trainable"),
)

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Trainable
SegResNetVAE                             [1, 4, 144, 224, 224]     [1, 3, 144, 224, 224]     21,733,717                True
├─Convolution: 1-1                       [1, 4, 144, 224, 224]     [1, 8, 144, 224, 224]     --                        True
│    └─Conv3d: 2-1                       [1, 4, 144, 224, 224]     [1, 8, 144, 224, 224]     864                       True
├─ModuleList: 1-2                        --                        --                        --                        True
│    └─Sequential: 2-2                   [1, 8, 144, 224, 224]     [1, 8, 144, 224, 224]     --                        True
│    │    └─Identity: 3-1                [1, 8, 144, 224, 224]     [1, 8, 144, 224, 224]     --                        --
│    │    └─ResBlock: 3-2                [1, 8, 144, 224, 224]     [1, 8, 144, 224, 224]     3,488                     True
│    

In [29]:
import torch
from collections import OrderedDict
ckpt = torch.load("../bundles/brats_mri_segmentation/models/model.pt", map_location="cpu") 
sd = ckpt["model_state_dict"] if isinstance(ckpt, dict) and "model_state_dict" in ckpt else ckpt
sd = OrderedDict((k.replace("module.", ""), v) for k, v in sd.items())

print(">>> checkpoint keys sample (first 60):")
for k in list(sd.keys())[:60]:
    print(k)
print("... total ckpt keys:", len(sd))

# print first conv / final conv shapes if present
first_conv = None
final_conv = None
for k in sd:
    if "conv" in k and "weight" in k and first_conv is None:
        first_conv = (k, sd[k].shape)
    if ("final" in k or "conv_final" in k or "convInit" in k) and "weight" in k:
        final_conv = (k, sd[k].shape)
print("first conv key/shape:", first_conv)
print("final conv key/shape (if found):", final_conv)


>>> checkpoint keys sample (first 60):
convInit.conv.weight
down_layers.0.1.norm1.weight
down_layers.0.1.norm1.bias
down_layers.0.1.norm2.weight
down_layers.0.1.norm2.bias
down_layers.0.1.conv1.conv.weight
down_layers.0.1.conv2.conv.weight
down_layers.1.0.conv.weight
down_layers.1.1.norm1.weight
down_layers.1.1.norm1.bias
down_layers.1.1.norm2.weight
down_layers.1.1.norm2.bias
down_layers.1.1.conv1.conv.weight
down_layers.1.1.conv2.conv.weight
down_layers.1.2.norm1.weight
down_layers.1.2.norm1.bias
down_layers.1.2.norm2.weight
down_layers.1.2.norm2.bias
down_layers.1.2.conv1.conv.weight
down_layers.1.2.conv2.conv.weight
down_layers.2.0.conv.weight
down_layers.2.1.norm1.weight
down_layers.2.1.norm1.bias
down_layers.2.1.norm2.weight
down_layers.2.1.norm2.bias
down_layers.2.1.conv1.conv.weight
down_layers.2.1.conv2.conv.weight
down_layers.2.2.norm1.weight
down_layers.2.2.norm1.bias
down_layers.2.2.norm2.weight
down_layers.2.2.norm2.bias
down_layers.2.2.conv1.conv.weight
down_layers.2.2.co

In [30]:
# after you create your MONAI UNet instance as `model`
md = model.state_dict()
print("\n>>> model keys sample (first 60):")
for k in list(md.keys())[:60]:
    print(k)
print("... total model keys:", len(md))

# check model first & final conv shapes
for k in md:
    if "conv" in k and "weight" in k:
        print("model sample conv:", k, md[k].shape)
        break
# final conv (likely the last weight)
for k in reversed(list(md.keys())):
    if "conv" in k and "weight" in k:
        print("model final conv sample:", k, md[k].shape)
        break



>>> model keys sample (first 60):
model.0.conv.unit0.conv.weight
model.0.conv.unit0.conv.bias
model.0.conv.unit0.adn.A.weight
model.0.conv.unit1.conv.weight
model.0.conv.unit1.conv.bias
model.0.conv.unit1.adn.A.weight
model.0.residual.weight
model.0.residual.bias
model.1.submodule.0.conv.unit0.conv.weight
model.1.submodule.0.conv.unit0.conv.bias
model.1.submodule.0.conv.unit0.adn.A.weight
model.1.submodule.0.conv.unit1.conv.weight
model.1.submodule.0.conv.unit1.conv.bias
model.1.submodule.0.conv.unit1.adn.A.weight
model.1.submodule.0.residual.weight
model.1.submodule.0.residual.bias
model.1.submodule.1.submodule.0.conv.unit0.conv.weight
model.1.submodule.1.submodule.0.conv.unit0.conv.bias
model.1.submodule.1.submodule.0.conv.unit0.adn.A.weight
model.1.submodule.1.submodule.0.conv.unit1.conv.weight
model.1.submodule.1.submodule.0.conv.unit1.conv.bias
model.1.submodule.1.submodule.0.conv.unit1.adn.A.weight
model.1.submodule.1.submodule.0.residual.weight
model.1.submodule.1.submodule.0.r

In [31]:
# find first conv weight key and print shape
for k in sd.keys():
    if "conv" in k and "weight" in k:
        print(k, sd[k].shape)
        break


convInit.conv.weight torch.Size([16, 4, 3, 3, 3])
