In [10]:
from torchsummary import summary
from thop import profile
import torch
def print_model_summary(model, input_size):
    # Create a random input tensor with the given input size
    input_tensor = torch.randn(input_size)

    # Calculate FLOPs and parameters
    flops, params = profile(model, inputs=(input_tensor,))

    print(f"FLOPs: {flops:,}, GFLOPs: {flops / 1e9:.2f}")
    print(f"Parameters: {params:,}")

# Example model
import torch.nn as nn
from monai.networks.nets import UNETR, SwinUNETR

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


model = SwinUNETR(
    img_size=(96, 96, 96),
    in_channels=1,
    out_channels=7,
    feature_size=48,
    use_checkpoint=True,
    drop_rate = 0.25,
    attn_drop_rate = 0.25,
    use_v2 = True,
).to(device)
# Input size (batch_size, channels, height, width)
input_size = (1, 1, 96, 96, 96)


# Example usage
print_model_summary(model, input_size)

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv3d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
[INFO] Register count_relu() for <class 'torch.nn.modules.activation.LeakyReLU'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.instancenorm.InstanceNorm3d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose3d'>.
FLOPs: 355,370,190,360.0, GFLOPs: 355.37
Parameters: 72,564,583.0
