In [2]:
from monai.networks.nets import SegResNet
import torch
import torch.nn as nn
from loguru import logger
class Monai_SegResNet(nn.Module):
    """
    DynUNet with registry
    This is a prototype implementation of a DynUnet model.

    Examples::
        from config import cfg
        from utils.registry import MODEL

        model = MODEL['DynUnet'](cfg=cfg)
    """
    def __init__(self):
        super().__init__()
        # MODEL misc
        in_channels = 2    
        spatial_dims = 3
        out_channels = 105


        # Define model
        self.model = SegResNet(
            spatial_dims=3,
            in_channels=1,
            init_filters=32,
            dropout_prob=0.2,
            out_channels=out_channels
        )
        self._out_channels = out_channels
        
        self._register_shape_hooks()

    @property
    def out_channels(self):
        """Return the output channels of the model."""
        return self._out_channels

    def forward(self, x):
        return self.model(x)
    
    def _register_shape_hooks(self):
        def hook_fn(name):
            def hook(module, input, output):
                logger.info(f"Layer {name} output shape: {output.shape}")
            return hook
        
        bottleneck = self.model.down_layers[-1][-1].conv2
        bottleneck.register_forward_hook(hook_fn("bottleneck"))

m = Monai_SegResNet()
print(m)
data = torch.randn(1, 1, 96, 96, 96)
output = m(data)


Monai_SegResNet(
  (model): SegResNet(
    (act_mod): ReLU(inplace=True)
    (convInit): Convolution(
      (conv): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    )
    (down_layers): ModuleList(
      (0): Sequential(
        (0): Identity()
        (1): ResBlock(
          (norm1): GroupNorm(8, 32, eps=1e-05, affine=True)
          (norm2): GroupNorm(8, 32, eps=1e-05, affine=True)
          (act): ReLU(inplace=True)
          (conv1): Convolution(
            (conv): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          )
          (conv2): Convolution(
            (conv): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          )
        )
      )
      (1): Sequential(
        (0): Convolution(
          (conv): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
        )
        (1): ResBlock(
          (norm1): GroupNorm

[32m2025-01-08 11:02:21.232[0m | [1mINFO    [0m | [36m__main__[0m:[36mhook[0m:[36m47[0m - [1mLayer bottleneck output shape: torch.Size([1, 256, 12, 12, 12])[0m


In [11]:
print(m.model.down_layers[-2][0].conv)

Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)


In [7]:
from loguru import logger
anatomy_ckpt_path = '/Users/keyi/Desktop/wholeBody_ct_segmentation/models/model_lowres.pt'
if not anatomy_ckpt_path:
    raise ValueError("Pretrained anatomy model path not specified in config")
    
logger.info(f"Loading pretrained anatomy model from {anatomy_ckpt_path}")
ckpt = torch.load(anatomy_ckpt_path, map_location='cpu')

if 'state_dict' in ckpt:
    anatomy_state_dict = {k.replace('model.', ''): v for k, v in ckpt['state_dict'].items()}
else:
    anatomy_state_dict = ckpt
    
missing_keys, unexpected_keys = m.model.load_state_dict(
    anatomy_state_dict, strict=False
)

logger.info("Loaded pretrained anatomy model")
logger.info(f"missing_keys: {missing_keys}")
logger.info(f"unexpected_keys: {unexpected_keys}")


[32m2025-01-08 11:02:56.672[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mLoading pretrained anatomy model from /Users/keyi/Desktop/wholeBody_ct_segmentation/models/model_lowres.pt[0m
  ckpt = torch.load(anatomy_ckpt_path, map_location='cpu')
[32m2025-01-08 11:02:56.697[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m18[0m - [1mLoaded pretrained anatomy model[0m
[32m2025-01-08 11:02:56.697[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m19[0m - [1mmissing_keys: [][0m
[32m2025-01-08 11:02:56.697[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m20[0m - [1munexpected_keys: [][0m


In [5]:
import torch
import torch.nn as nn

class AlignBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, 
                             kernel_size=1, bias=False)

    def forward(self, x):
        return self.conv(x)
    
B = 1
d= torch.randn(B, 256, 6,6,6)
a = AlignBlock(256, 512)
a(d).shape  

torch.Size([1, 512, 6, 6, 6])

In [14]:
import torch
input_data = torch.randn(1, 1, 96, 96, 96)  # B, C, H, W, D


def hook_fn(module, input, output):
    print(f"output shape in bottleneck: {output.shape}")



bottleneck = m.model.down_layers[-3][0].conv # last layer of the down_layers
hook = bottleneck.register_forward_hook(hook_fn)


with torch.no_grad():
    output = m.model(input_data)
    

hook.remove()


print(f"input_data shape: {input_data.shape}")
print(f"output shape: {output.shape}")

[32m2025-01-08 11:45:08.728[0m | [1mINFO    [0m | [36m__main__[0m:[36mhook[0m:[36m47[0m - [1mLayer bottleneck output shape: torch.Size([1, 256, 12, 12, 12])[0m


output shape in bottleneck: torch.Size([1, 128, 24, 24, 24])
input_data shape: torch.Size([1, 1, 96, 96, 96])
output shape: torch.Size([1, 105, 96, 96, 96])


In [None]:
# zip (unet, unet), zip(dynunet, dynunet)
# segResNet bottleneck 512 channels
# unet bottleneck 1024 channels
# conv1x1x1 (512 -> 1024) channels feature distribution 
# unet bottleneck 1024 channels 

In [None]:
# 1. fuse dynunet with 2 bottleneck channels -> begins at 50 epochs
# CT +PET from the pathology transformation -> CT anatomy model
# CT_ana -> Anatomy model
# CT+PET -> pathology model