In [1]:
from torchsummary import summary
from thop import profile
import torch
from torch.profiler import ProfilerActivity
from torch.profiler import profile as profilee
    
def print_model_summary(model, input_size):
    input_tensor = torch.randn(input_size)
    device = next(model.parameters()).device
    input_tensor = input_tensor.to(device)
    flops, params = profile(model, inputs=(input_tensor,))

    print(f"Model: {model.__class__.__name__}")
    print(f"FLOPs: {flops:,}, GFLOPs: {flops / 1e9:.2f}")
    print(f"Parameters: {params:,}")
    print("-" * 50)

def profile_model(model, input_size, log_dir='./log'):

    input_tensor = torch.randn(input_size)
    device = next(model.parameters()).device
    input_tensor = input_tensor.to(device)

    # 프로파일링
    with profilee(
        activities=[
            ProfilerActivity.CPU, 
            ProfilerActivity.CUDA
        ],
        on_trace_ready=torch.profiler.tensorboard_trace_handler(log_dir),  # TensorBoard 연동
        record_shapes=True,
        with_stack=True
    ) as prof:
        model(input_tensor)

    # 프로파일링 결과 출력
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

In [2]:
from src.models import *
from monai.networks.layers.factories import Act, Norm

enc_channels = (32, 64, 128, 256)
enc_strides = (2, 2, 2)
num_layers_enc = (1, 1, 1, 1)

core_channels = 64
dec_channels = (128, 64, 32)
dec_strides = (2, 2, 2)
num_layers_dec = (1, 1, 1)

skip_map = {
    0: [("enc", 2)],       # 디코더0 => 인코더2
    1: [("enc", 3), ("enc", 1)],  # 디코더1 => 인코더1 + 디코더0
    2: [("enc", 3), ("dec", 0), ("enc", 0)]   # 디코더2 => 인코더0 + 디코더1
}

net = FlexibleUNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        encoder_channels=enc_channels,
        encoder_strides=enc_strides,
        core_channels=core_channels,
        decoder_channels=dec_channels,
        decoder_strides=dec_strides,
        num_layers_encoder=num_layers_enc,
        num_layers_decoder=num_layers_dec,
        skip_connections=skip_map,
        kernel_size=3,
        up_kernel_size=3,
        act=Act.PRELU,
        norm=Norm.INSTANCE,
        dropout=0.0,
        bias=True,
        mode="trilinear",
        align_corners=False,
    )

x = (1, 1, 96, 96, 96)
print_model_summary(net, x)

profile_model(net, x)


  from .autonotebook import tqdm as notebook_tqdm


[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv3d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.instancenorm.InstanceNorm3d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_prelu() for <class 'torch.nn.modules.activation.PReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose3d'>.




Model: FlexibleUNet
FLOPs: 395,546,222,592.0, GFLOPs: 395.55
Parameters: 2,984,044.0
--------------------------------------------------


  warn("CUDA is not available, disabling CUDA profiling")


--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                            Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                        aten::to         0.00%       6.085us         0.00%       6.085us       0.869us             7  
                    aten::conv3d         0.00%     391.381us        99.28%      202.897s       15.607s            13  
               aten::convolution         0.00%       2.009ms        99.55%      203.443s       12.715s            16  
              aten::_convolution         0.00%       2.487ms        99.55%      203.441s       12.715s            16  
               aten::slow_conv3d         0.00%     305.212us        99.28%      202.894s       15.607s            13  
       aten::slow_conv3d_forward        99.25%  

In [3]:
from src.models import DP_UNet
from monai.networks.nets import UNet

unet_dp = DP_UNet(
    # img_size = (96, 96, 96),
    spatial_dims=3,
    in_channels=1,
    out_channels=7,
    channels=(32,64,128,256,512),
    strides=(2, 2, 2, 2),
    # num_res_units=0,
)

unet = UNet(
    # img_size = (96, 96, 96),
    spatial_dims=3,
    in_channels=1,
    out_channels=7,
    channels=(32,64,128,256),
    strides=(2, 2, 2),
    # num_res_units=0,
)

x = (1, 1, 96, 96, 96)
print_model_summary(unet_dp, x)
print_model_summary(unet, x)

profile_model(unet_dp, x)
profile_model(unet, x)

# print(unet_dp)
# print("====================================")
# print(unet)

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv3d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.instancenorm.InstanceNorm3d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_prelu() for <class 'torch.nn.modules.activation.PReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose3d'>.
Model: DP_UNet
FLOPs: 32,260,349,952.0, GFLOPs: 32.26
Parameters: 836,169.0
--------------------------------------------------
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv3d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.instancenorm.InstanceNorm3d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_prelu() for <class 'torch.nn.modules.activation.PReLU'>.
[INFO] Register count_convNd() for <class 'torch.nn.mod

In [3]:
from monai.networks.nets import UNet

model = UNet(
    # img_size = (96, 96, 96),
    spatial_dims=3,
    in_channels=1,
    out_channels=7,
    channels=(48, 64, 80, 80),
    strides=(2, 2, 1),
    # num_res_units=0,
)

x = (1, 1, 96, 96, 96)
print_model_summary(model, x)

print(model)

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv3d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.instancenorm.InstanceNorm3d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_prelu() for <class 'torch.nn.modules.activation.PReLU'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose3d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
Model: UNet
FLOPs: 43,854,151,680.0, GFLOPs: 43.85
Parameters: 856,189.0
--------------------------------------------------
UNet(
  (model): Sequential(
    (0): Convolution(
      (conv): Conv3d(1, 48, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
      (adn): ADN(
        (N): InstanceNorm3d(48, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (D): Dropout(p=0.0, inplace=False)
        (A): PReLU(num_parameters=1)
      )
    )
    (1): SkipConnection(
      (subm

In [2]:
from src.models import UNet

model = UNet(
    # img_size = (96, 96, 96),
    spatial_dims=3,
    in_channels=1,
    out_channels=7,
    channels=(48, 64, 80, 80),
    strides=(2, 2, 1),
    # num_res_units=0,
)

x = (1, 1, 96, 96, 96)
print_model_summary(model, x)

print(model)

  from .autonotebook import tqdm as notebook_tqdm


[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv3d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.instancenorm.InstanceNorm3d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_prelu() for <class 'torch.nn.modules.activation.PReLU'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose3d'>.




Model: UNet
FLOPs: 43,854,151,680.0, GFLOPs: 43.85
Parameters: 856,189.0
--------------------------------------------------
UNet(
  (encoder1): Encoder(
    (conv): Convolution(
      (conv): Conv3d(1, 48, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
      (adn): ADN(
        (N): InstanceNorm3d(48, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (D): Dropout(p=0.0, inplace=False)
        (A): PReLU(num_parameters=1)
      )
    )
  )
  (encoder2): Encoder(
    (conv): Convolution(
      (conv): Conv3d(48, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
      (adn): ADN(
        (N): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (D): Dropout(p=0.0, inplace=False)
        (A): PReLU(num_parameters=1)
      )
    )
  )
  (encoder3): Encoder(
    (conv): Convolution(
      (conv): Conv3d(64, 80, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (adn): ADN(
        (N): I

In [5]:
from monai.networks.nets import UNETR

model = UNETR(
    img_size = (96, 96, 96),
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    # channels=(16, 32, 64, 128, 256),
    # strides=(2, 2, 2, 2),
    # num_res_units=2,
)

x = (1, 1, 96, 96, 96)
print_model_summary(model, x)


[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv3d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_relu() for <class 'torch.nn.modules.activation.LeakyReLU'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.instancenorm.InstanceNorm3d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose3d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
Model: UNETR
FLOPs: 82,521,317,376.0, GFLOPs: 82.52
Parameters: 92,617,937.0
--------------------------------------------------


In [2]:
from src.models import CSPBlock, UnetResBlock


in_channels = 64
out_channels = 128
imgsz  = 96
block = CSPBlock(
        spatial_dims=3,
        in_channels=in_channels,  # 입력 채널 수정
        out_channels=out_channels,
        kernel_size=3,
        stride=2,
        norm_name="batch",
        act_name=("leakyrelu", {"inplace": True, "negative_slope": 0.01}),
        dropout=None,
        split_ratio=0.5,
        n=2
    )
x = (1, in_channels, imgsz, imgsz, imgsz)

# Print summaries
print("SwinTransformer Summary:")
print_model_summary(block, x)

block = UnetResBlock(
        spatial_dims=3,
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=3,
        stride=2,
        norm_name="batch",
        act_name=("leakyrelu", {"inplace": True, "negative_slope": 0.01}),
        dropout=None,
     
    )

# Print summaries
print("UnetResBlock Summary:")
print_model_summary(block, x)

SwinTransformer Summary:
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv3d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
Model: CSPBlock
FLOPs: 51,640,270,848.0, GFLOPs: 51.64
Parameters: 466,944.0
--------------------------------------------------
UnetResBlock Summary:
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv3d'>.
[INFO] Register count_relu() for <class 'torch.nn.modules.activation.LeakyReLU'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm3d'>.
Model: UnetResBlock
FLOPs: 74,459,381,760.0, GFLOPs: 74.46
Parameters: 672,512.0
--------------------------------------------------


In [3]:
import torch
import time

# 입력 텐서
x = torch.randn(1, 1, 8, 8, 8).cuda()

# 업샘플링 설정
upsample1 = torch.nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
upsample2 = torch.nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)

# align_corners=True 실행 시간
start = time.time()
for _ in range(100):
    y1 = upsample1(x)
end = time.time()
print(f"align_corners=True 실행 시간: {end - start:.6f}초")

# align_corners=False 실행 시간
start = time.time()
for _ in range(100):
    y2 = upsample2(x)
end = time.time()
print(f"align_corners=False 실행 시간: {end - start:.6f}초")


align_corners=True 실행 시간: 0.007005초
align_corners=False 실행 시간: 0.003999초


In [2]:
from src.models.swincspunetr3plus import SwinCSPUNETR3plus
x = (1, 1, 96, 96, 96)
swin_unetr = SwinCSPUNETR3plus(
    img_size=(96, 96, 96),
    in_channels=1,
    out_channels=7,
    feature_size=48,
    depths=(2, 2, 2, 2),
    num_heads=(3, 6, 12, 24),
    norm_name="instance",
    drop_rate=0.0,
    attn_drop_rate=0.0,
    dropout_path_rate=0.0,
    normalize=True,
    use_checkpoint=True,
    spatial_dims=3,
    downsample="merging",
    use_v2=True,
    n=2,
)

# Print summaries
print("SwinCSPUNETR_unet Summary:")
print_model_summary(swin_unetr, x)

# Call the function
profile_model(swin_unetr, x)



  from .autonotebook import tqdm as notebook_tqdm


SwinCSPUNETR_unet Summary:
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv3d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
[INFO] Register count_relu() for <class 'torch.nn.modules.activation.LeakyReLU'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.instancenorm.InstanceNorm3d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool3d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose3d'>.
Model: SwinCSPUNETR3plus
FLOPs: 629,032,886,808.0, GFLOPs: 629.03
Parameters: 55,484,983.0
--------------------------------------------------
------

In [2]:
from src.models.swincspunetr import SwinCSPUNETR
x = (1, 1, 96, 96, 96)
swin_unetr = SwinCSPUNETR(
    img_size=(96, 96, 96),
    in_channels=1,
    out_channels=7,
    feature_size=48,
    depths=(2, 2, 2, 2),
    num_heads=(3, 6, 12, 24),
    norm_name="instance",
    drop_rate=0.0,
    attn_drop_rate=0.0,
    dropout_path_rate=0.0,
    normalize=True,
    use_checkpoint=True,
    spatial_dims=3,
    downsample="merging",
    use_v2=True,
    n=2,
)

# Print summaries
print("SwinCSPUNETR Summary:")
print_model_summary(swin_unetr, x)



SwinCSPUNETR Summary:
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv3d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
[INFO] Register count_relu() for <class 'torch.nn.modules.activation.LeakyReLU'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.instancenorm.InstanceNorm3d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose3d'>.
Model: SwinCSPUNETR
FLOPs: 289,518,045,720.0, GFLOPs: 289.52
Parameters: 62,104,375.0
--------------------------------------------------


In [5]:


# Example model
import torch.nn as nn
from monai.networks.nets import SwinUNETR, SwinTransformer

# SwinTransformer 테스트
swin_transformer = SwinTransformer(
    in_chans=1,
    embed_dim=48,
    window_size=(7, 7, 7),
    patch_size=(2, 2, 2),
    depths=(2, 2, 2, 2),
    num_heads=(3, 6, 12, 24),
    mlp_ratio=4.0,
    qkv_bias=True,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    drop_path_rate=0.0,
    norm_layer=nn.LayerNorm,
    use_checkpoint=True,
    spatial_dims=3,
    downsample="merging",
    use_v2=True
)

# 전체 SwinUNETR 모델
swin_unetr = SwinUNETR(
    img_size=(96, 96, 96),
    in_channels=1,
    out_channels=7,
    feature_size=48,
    depths=(2, 2, 2, 2),
    num_heads=(3, 6, 12, 24),
    norm_name="instance",
    drop_rate=0.0,
    attn_drop_rate=0.0,
    dropout_path_rate=0.0,
    normalize=True,
    use_checkpoint=True,
    spatial_dims=3,
    downsample="merging",
    use_v2=True
)

# Input sizes
swin_transformer_input = (1, 1, 96, 96, 96)
swin_unetr_input = (1, 1, 96, 96, 96)

# # Print summaries
# print("SwinTransformer Summary:")
# print_model_summary(swin_transformer, swin_transformer_input)

print("\nComplete SwinUNETR Summary:")
print_model_summary(swin_unetr, swin_unetr_input)




Complete SwinUNETR Summary:
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv3d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
[INFO] Register count_relu() for <class 'torch.nn.modules.activation.LeakyReLU'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.instancenorm.InstanceNorm3d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose3d'>.
torch.Size([1, 48, 48, 48, 48])
torch.Size([1, 96, 24, 24, 24])
torch.Size([1, 192, 12, 12, 12])
torch.Size([1, 384, 6, 6, 6])
torch.Size([1, 768, 3, 3, 3])
enc0: torch.Size([1, 48, 96, 96, 96])
enc1 torch.Size([1, 48, 48, 48, 48])
torch.Size([1, 96, 24, 24, 24])
torch.Size([1, 192, 12, 12, 12])
torch.

torch.Size([1, 48, 48, 48, 48])
torch.Size([1, 96, 24, 24, 24])
torch.Size([1, 192, 12, 12, 12])
torch.Size([1, 384, 6, 6, 6])
torch.Size([1, 768, 3, 3, 3])
enc0: torch.Size([1, 48, 96, 96, 96])
enc1 torch.Size([1, 48, 48, 48, 48])
torch.Size([1, 96, 24, 24, 24])
torch.Size([1, 192, 12, 12, 12])
torch.Size([1, 768, 3, 3, 3])
Model: SwinUNETR
FLOPs: 329,543,087,640.0, GFLOPs: 329.54
Parameters: 61,989,223.0