In [1]:
from torchsummary import summary
from thop import profile
import torch

def print_model_summary(model, input_size):
    """Profile ``model`` on one random input and print a short cost summary.

    A tensor of shape ``input_size`` is drawn with ``torch.randn``, moved to
    the device of the model's first parameter, and handed to ``thop.profile``.
    The two values returned by ``profile`` are printed as an operation count
    (raw and divided by 1e9) and a parameter count, followed by a separator.

    Args:
        model: ``torch.nn.Module`` to profile.
        input_size: Full shape of the dummy input, batch dimension included.

    Returns:
        None. The summary is written to stdout.

    Note:
        NOTE(review): thop's documentation calls the first value returned by
        ``profile`` "MACs". The "FLOPs" label is kept unchanged here so that
        existing recorded output stays comparable -- confirm the intended unit.
    """
    # A module without parameters has no device to read; fall back to the CPU
    # instead of letting next() raise StopIteration.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    input_tensor = torch.randn(input_size).to(device)
    flops, params = profile(model, inputs=(input_tensor,))

    print(f"Model: {model.__class__.__name__}")
    print(f"FLOPs: {flops:,}, GFLOPs: {flops / 1e9:.2f}")
    print(f"Parameters: {params:,}")
    print("-" * 50)


In [2]:
from src.models import CSPBlock, UnetResBlock


# Settings shared by both blocks so they are profiled on the same input.
in_channels = 64
out_channels = 128
imgsz = 96

# Dummy input shape: batch of 1, `in_channels` channels, cubic volume of side `imgsz`.
x = (1, in_channels, imgsz, imgsz, imgsz)

block = CSPBlock(
    spatial_dims=3,
    in_channels=in_channels,  # input channel count (adjusted)
    out_channels=out_channels,
    kernel_size=3,
    stride=2,
    norm_name="batch",
    act_name=("leakyrelu", {"inplace": True, "negative_slope": 0.01}),
    dropout=None,
    split_ratio=0.5,
    n=2,
)

# Print summaries
# The header names the class that is actually profiled here (CSPBlock).
print("CSPBlock Summary:")
print_model_summary(block, x)

# Same constructor arguments as above, minus the CSP-only `split_ratio` and `n`.
block = UnetResBlock(
    spatial_dims=3,
    in_channels=in_channels,
    out_channels=out_channels,
    kernel_size=3,
    stride=2,
    norm_name="batch",
    act_name=("leakyrelu", {"inplace": True, "negative_slope": 0.01}),
    dropout=None,
)

# Print summaries
print("UnetResBlock Summary:")
print_model_summary(block, x)

SwinTransformer Summary:
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv3d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
Model: CSPBlock
FLOPs: 51,640,270,848.0, GFLOPs: 51.64
Parameters: 466,944.0
--------------------------------------------------
UnetResBlock Summary:
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv3d'>.
[INFO] Register count_relu() for <class 'torch.nn.modules.activation.LeakyReLU'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm3d'>.
Model: UnetResBlock
FLOPs: 74,459,381,760.0, GFLOPs: 74.46
Parameters: 672,512.0
--------------------------------------------------


In [2]:
from src.models.swincspunetr3plus import SwinCSPUNETR3plus
# Shape of the dummy input handed to print_model_summary; its channel entry
# and last three entries match `in_channels` and `img_size` below.
x = (1, 1, 96, 96, 96)

swin_unetr = SwinCSPUNETR3plus(
    # input / output geometry
    spatial_dims=3,
    img_size=(96, 96, 96),
    in_channels=1,
    out_channels=7,
    # encoder layout
    feature_size=48,
    depths=(2, 2, 2, 2),
    num_heads=(3, 6, 12, 24),
    downsample="merging",
    use_v2=True,
    n=2,
    # normalisation and dropout-style rates (all rates set to 0.0)
    norm_name="instance",
    normalize=True,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    dropout_path_rate=0.0,
    # checkpointing flag passed through to the model
    use_checkpoint=True,
)

# Report operation and parameter counts for this model.
print("SwinCSPUNETR_unet Summary:")
print_model_summary(swin_unetr, x)

import torch
from torch.profiler import ProfilerActivity
from torch.profiler import profile as profilee

# Build a dummy input on the same device as the model weights.
device = next(swin_unetr.parameters()).device
input_tensor = torch.randn(x).to(device)

# Only request CUDA activity when the model actually lives on a CUDA device.
# For a CPU-resident model there are no CUDA kernels to record, and sorting
# the table by CUDA time would not rank anything meaningfully.
on_cuda = device.type == "cuda"
activities = [ProfilerActivity.CPU]
if on_cuda:
    activities.append(ProfilerActivity.CUDA)

# Profiling
with profilee(
    activities=activities,
    on_trace_ready=torch.profiler.tensorboard_trace_handler('./log'),  # TensorBoard integration
    record_shapes=True,
    with_stack=True,
) as prof:
    swin_unetr(input_tensor)

# Print the profiling results, ranked by the time column that matches the device.
sort_key = "cuda_time_total" if on_cuda else "cpu_time_total"
print(prof.key_averages().table(sort_by=sort_key, row_limit=10))



  from .autonotebook import tqdm as notebook_tqdm


SwinCSPUNETR_unet Summary:
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv3d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
[INFO] Register count_relu() for <class 'torch.nn.modules.activation.LeakyReLU'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.instancenorm.InstanceNorm3d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool3d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose3d'>.
Model: SwinCSPUNETR3plus
FLOPs: 629,032,886,808.0, GFLOPs: 629.03
Parameters: 55,484,983.0
--------------------------------------------------
------

In [2]:
from src.models.swincspunetr import SwinCSPUNETR
# Shape of the dummy input handed to print_model_summary; its channel entry
# and last three entries match `in_channels` and `img_size` below.
x = (1, 1, 96, 96, 96)

swin_unetr = SwinCSPUNETR(
    # input / output geometry
    spatial_dims=3,
    img_size=(96, 96, 96),
    in_channels=1,
    out_channels=7,
    # encoder layout
    feature_size=48,
    depths=(2, 2, 2, 2),
    num_heads=(3, 6, 12, 24),
    downsample="merging",
    use_v2=True,
    n=2,
    # normalisation and dropout-style rates (all rates set to 0.0)
    norm_name="instance",
    normalize=True,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    dropout_path_rate=0.0,
    # checkpointing flag passed through to the model
    use_checkpoint=True,
)

# Report operation and parameter counts for this model.
print("SwinCSPUNETR Summary:")
print_model_summary(swin_unetr, x)



SwinCSPUNETR Summary:
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv3d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
[INFO] Register count_relu() for <class 'torch.nn.modules.activation.LeakyReLU'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.instancenorm.InstanceNorm3d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose3d'>.
Model: SwinCSPUNETR
FLOPs: 289,518,045,720.0, GFLOPs: 289.52
Parameters: 62,104,375.0
--------------------------------------------------


In [5]:


# Example model
import torch.nn as nn
from monai.networks.nets import SwinUNETR, SwinTransformer

# SwinTransformer test: the stand-alone backbone. It is constructed here, but
# its summary call further down is commented out.
swin_transformer = SwinTransformer(
    # input and patching
    spatial_dims=3,
    in_chans=1,
    patch_size=(2, 2, 2),
    window_size=(7, 7, 7),
    # stage layout
    embed_dim=48,
    depths=(2, 2, 2, 2),
    num_heads=(3, 6, 12, 24),
    mlp_ratio=4.0,
    qkv_bias=True,
    downsample="merging",
    use_v2=True,
    # normalisation and dropout-style rates (all rates set to 0.0)
    norm_layer=nn.LayerNorm,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    drop_path_rate=0.0,
    # checkpointing flag passed through to the model
    use_checkpoint=True,
)

# Full SwinUNETR model
swin_unetr = SwinUNETR(
    # input / output geometry
    spatial_dims=3,
    img_size=(96, 96, 96),
    in_channels=1,
    out_channels=7,
    # encoder layout
    feature_size=48,
    depths=(2, 2, 2, 2),
    num_heads=(3, 6, 12, 24),
    downsample="merging",
    use_v2=True,
    # normalisation and dropout-style rates (all rates set to 0.0)
    norm_name="instance",
    normalize=True,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    dropout_path_rate=0.0,
    # checkpointing flag passed through to the model
    use_checkpoint=True,
)

# Dummy input shapes, one per model (currently identical).
swin_transformer_input = (1, 1, 96, 96, 96)
swin_unetr_input = (1, 1, 96, 96, 96)

# The stand-alone backbone summary is disabled:
# print("SwinTransformer Summary:")
# print_model_summary(swin_transformer, swin_transformer_input)

# Report operation and parameter counts for the full model.
print("\nComplete SwinUNETR Summary:")
print_model_summary(swin_unetr, swin_unetr_input)




Complete SwinUNETR Summary:
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv3d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
[INFO] Register count_relu() for <class 'torch.nn.modules.activation.LeakyReLU'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.instancenorm.InstanceNorm3d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose3d'>.
torch.Size([1, 48, 48, 48, 48])
torch.Size([1, 96, 24, 24, 24])
torch.Size([1, 192, 12, 12, 12])
torch.Size([1, 384, 6, 6, 6])
torch.Size([1, 768, 3, 3, 3])
enc0: torch.Size([1, 48, 96, 96, 96])
enc1 torch.Size([1, 48, 48, 48, 48])
torch.Size([1, 96, 24, 24, 24])
torch.Size([1, 192, 12, 12, 12])
torch.

torch.Size([1, 48, 48, 48, 48])
torch.Size([1, 96, 24, 24, 24])
torch.Size([1, 192, 12, 12, 12])
torch.Size([1, 384, 6, 6, 6])
torch.Size([1, 768, 3, 3, 3])
enc0: torch.Size([1, 48, 96, 96, 96])
enc1 torch.Size([1, 48, 48, 48, 48])
torch.Size([1, 96, 24, 24, 24])
torch.Size([1, 192, 12, 12, 12])
torch.Size([1, 768, 3, 3, 3])
Model: SwinUNETR
FLOPs: 329,543,087,640.0, GFLOPs: 329.54
Parameters: 61,989,223.0