In [3]:
def get_kernels_strides():
    """
    This function is only used for decathlon datasets with the provided patch sizes.
    When refering this method for other tasks, please ensure that the patch size for each spatial dimension should
    be divisible by the product of all strides in the corresponding dimension.
    In addition, the minimal spatial size should have at least one dimension that has twice the size of
    the product of all strides. For patch sizes that cannot find suitable strides, an error will be raised.
    """
    sizes, spacings = [96,96,96], [1.0, 1.0, 2.5]
    input_size = sizes
    strides, kernels = [], []
    while True:
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        for idx, (i, j) in enumerate(zip(sizes, stride)):
            if i % j != 0:
                raise ValueError(
                    f"Patch size is not supported, please try to modify the size {input_size[idx]} in the spatial dimension {idx}."
                )
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)

    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])
    return kernels, strides

In [16]:

import torch.nn as nn
import torch
from monai.networks.nets import DynUNet, UNet
from omegaconf import DictConfig
from loguru import logger
class Monai_DynUNet(nn.Module):
    """
    DynUNet with registry
    This is a prototype implementation of a DynUnet model.

    Examples::
        from config import cfg
        from utils.registry import MODEL

        model = MODEL['DynUnet'](cfg=cfg)
    """
    def __init__(self):
        super().__init__()
        # MODEL misc
        in_channels = 2    
        spatial_dims = 3
        out_channels = 2
        
       

        # Define model
        self.model = UNet(
            spatial_dims=3,
            in_channels=1,
            channels=(32,64,128,256,512,1024),
            out_channels=145,
            #kernel_size=kernels,
            strides=(2,2,2,2,2),
            #400->200->100->50->25
            norm="instance",
        )
        self._out_channels = out_channels

    @property
    def out_channels(self):
        """Return the output channels of the model."""
        return self._out_channels

    def forward(self, x):
        return self.model(x)

m = Monai_DynUNet()
m.model

UNet(
  (model): Sequential(
    (0): Convolution(
      (conv): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
      (adn): ADN(
        (N): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (D): Dropout(p=0.0, inplace=False)
        (A): PReLU(num_parameters=1)
      )
    )
    (1): SkipConnection(
      (submodule): Sequential(
        (0): Convolution(
          (conv): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
          (adn): ADN(
            (N): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
            (D): Dropout(p=0.0, inplace=False)
            (A): PReLU(num_parameters=1)
          )
        )
        (1): SkipConnection(
          (submodule): Sequential(
            (0): Convolution(
              (conv): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
              (adn): ADN(
                (N): Ins

In [36]:
import torch
import torch.nn as nn
from loguru import logger
anatomy_ckpt_path = '/Users/keyi/Desktop/epoch=399-step=30460.ckpt'
if not anatomy_ckpt_path:
    raise ValueError("Pretrained anatomy model path not specified in config")
    
logger.info(f"Loading pretrained anatomy model from {anatomy_ckpt_path}")
ckpt = torch.load(anatomy_ckpt_path, map_location='cpu')

print(f"checkpoint keys: {ckpt['state_dict'].keys()}")
print(f"model keys: {m.model.state_dict().keys()}")
if 'state_dict' in ckpt:
    anatomy_state_dict = {k.replace('model.model.', 'model.'): v for k, v in ckpt['state_dict'].items()}
else:
    anatomy_state_dict = ckpt
    
missing_keys, unexpected_keys = m.model.load_state_dict(
    anatomy_state_dict, strict=False
)
a = torch.randn(1,1,96,96,96)
o = m.model(a)
logger.info(f"output shape: {o.shape}")
logger.info(f"missing_keys: {missing_keys}")
logger.info(f"unexpected_keys: {unexpected_keys}")
logger.info("Loaded pretrained anatomy model")

[32m2025-01-17 15:52:28.961[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m8[0m - [1mLoading pretrained anatomy model from /Users/keyi/Desktop/epoch=399-step=30460.ckpt[0m
  ckpt = torch.load(anatomy_ckpt_path, map_location='cpu')


checkpoint keys: odict_keys(['model.model.0.conv.weight', 'model.model.0.conv.bias', 'model.model.0.adn.A.weight', 'model.model.1.submodule.0.conv.weight', 'model.model.1.submodule.0.conv.bias', 'model.model.1.submodule.0.adn.A.weight', 'model.model.1.submodule.1.submodule.0.conv.weight', 'model.model.1.submodule.1.submodule.0.conv.bias', 'model.model.1.submodule.1.submodule.0.adn.A.weight', 'model.model.1.submodule.1.submodule.1.submodule.0.conv.weight', 'model.model.1.submodule.1.submodule.1.submodule.0.conv.bias', 'model.model.1.submodule.1.submodule.1.submodule.0.adn.A.weight', 'model.model.1.submodule.1.submodule.1.submodule.1.submodule.0.conv.weight', 'model.model.1.submodule.1.submodule.1.submodule.1.submodule.0.conv.bias', 'model.model.1.submodule.1.submodule.1.submodule.1.submodule.0.adn.A.weight', 'model.model.1.submodule.1.submodule.1.submodule.1.submodule.1.submodule.conv.weight', 'model.model.1.submodule.1.submodule.1.submodule.1.submodule.1.submodule.conv.bias', 'model.mo

[32m2025-01-17 15:52:29.820[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m23[0m - [1moutput shape: torch.Size([1, 145, 96, 96, 96])[0m
[32m2025-01-17 15:52:29.821[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m24[0m - [1mmissing_keys: [][0m
[32m2025-01-17 15:52:29.821[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m25[0m - [1munexpected_keys: ['loss.dice.class_weight'][0m
[32m2025-01-17 15:52:29.821[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1mLoaded pretrained anatomy model[0m


In [6]:
import re
feature_fused_channels = ["320_conv1", "320_conv2"]
channel = 320
if channel in feature_fused_channels:
    print("yes")
else:
    print("no")

no


In [7]:
import torch
input_data = torch.randn(1, 1, 96, 96, 96)  # B, C, H, W, D


def hook_fn(module, input, output):
    print(f"output shape in bottleneck: {output.shape}")



bottleneck = m.model.bottleneck.conv1.conv
hook = bottleneck.register_forward_hook(hook_fn)


with torch.no_grad():
    output = m.model(input_data)
    

hook.remove()


print(f"input_data shape: {input_data.shape}")
print(f"output shape: {output.shape}")


output shape in bottleneck: torch.Size([1, 320, 6, 6, 6])
input_data shape: torch.Size([1, 1, 96, 96, 96])
output shape: torch.Size([1, 4, 145, 96, 96, 96])
