In [2]:
from torchsummary import summary
from thop import profile
import torch

def print_model_summary(model, input_size) -> None:
    """Profile ``model`` with thop on one random input and print the totals.

    Args:
        model: Module to profile. The dummy input is moved to the device of
            its first parameter (CPU when the module has no parameters).
        input_size: Shape handed to ``torch.randn`` to build the dummy input;
            it is used as-is, so it must already include any batch dimension.

    Prints the model's class name, the operation count returned by
    ``thop.profile`` (raw and divided by 1e9), the parameter count, and a
    50-character separator line. Nothing is returned.
    """
    # A module without parameters would make a bare ``next(...)`` raise
    # StopIteration; fall back to CPU instead.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    input_tensor = torch.randn(input_size).to(device)
    flops, params = profile(model, inputs=(input_tensor,))

    # NOTE(review): thop's README names the first return value "macs"
    # (multiply-accumulates); the "FLOPs" label is kept for output
    # compatibility -- confirm which unit is intended.
    # thop hands back floats, so cast to int to avoid printing a trailing
    # ".0" on what are whole-number counts.
    print(f"Model: {model.__class__.__name__}")
    print(f"FLOPs: {int(flops):,}, GFLOPs: {flops / 1e9:.2f}")
    print(f"Parameters: {int(params):,}")
    print("-" * 50)


# Example model
import torch.nn as nn
from monai.networks.nets import SwinUNETR, SwinTransformer

# SwinTransformer test: the transformer on its own.
swin_transformer = SwinTransformer(
    # Geometry / patching
    spatial_dims=3,
    in_chans=1,
    patch_size=(2, 2, 2),
    window_size=(7, 7, 7),
    # Width and depth
    embed_dim=48,
    depths=(2, 2, 2, 2),
    num_heads=(3, 6, 12, 24),
    mlp_ratio=4.0,
    qkv_bias=True,
    # Regularisation (all disabled)
    drop_rate=0.0,
    attn_drop_rate=0.0,
    drop_path_rate=0.0,
    # Remaining options
    norm_layer=nn.LayerNorm,
    downsample="merging",
    use_checkpoint=True,
    use_v2=True,
)

# The complete SwinUNETR model.
swin_unetr = SwinUNETR(
    # Geometry / channels
    spatial_dims=3,
    img_size=(96, 96, 96),
    in_channels=1,
    out_channels=7,
    # Width and depth
    feature_size=48,
    depths=(2, 2, 2, 2),
    num_heads=(3, 6, 12, 24),
    # Normalisation
    norm_name="instance",
    normalize=True,
    # Regularisation (all disabled)
    drop_rate=0.0,
    attn_drop_rate=0.0,
    dropout_path_rate=0.0,
    # Remaining options
    downsample="merging",
    use_checkpoint=True,
    use_v2=False,
)

# Both models are profiled with the same 5-element input shape.
swin_transformer_input = swin_unetr_input = (1, 1, 96, 96, 96)

# Print one summary per model, in the order the models were built.
for _heading, _model, _input_size in (
    ("SwinTransformer Summary:", swin_transformer, swin_transformer_input),
    ("\nComplete SwinUNETR Summary:", swin_unetr, swin_unetr_input),
):
    print(_heading)
    print_model_summary(_model, _input_size)



SwinTransformer Summary:
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv3d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
[INFO] Register count_relu() for <class 'torch.nn.modules.activation.LeakyReLU'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.instancenorm.InstanceNorm3d'>.
Model: SwinTransformer
FLOPs: 40,532,101,656.0, GFLOPs: 40.53
Parameters: 18,439,632.0
--------------------------------------------------

Complete SwinUNETR Summary:
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv3d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.

torch.Size([1, 48, 48, 48, 48])
torch.Size([1, 96, 24, 24, 24])
torch.Size([1, 192, 12, 12, 12])
torch.Size([1, 384, 6, 6, 6])
torch.Size([1, 768, 3, 3, 3])
enc0: torch.Size([1, 48, 96, 96, 96])
enc1 torch.Size([1, 48, 48, 48, 48])
torch.Size([1, 96, 24, 24, 24])
torch.Size([1, 192, 12, 12, 12])
torch.Size([1, 768, 3, 3, 3])
Model: SwinUNETR
FLOPs: 329,543,087,640.0, GFLOPs: 329.54
Parameters: 61,989,223.0