In [1]:
from pathlib import Path
import sys

sys.path.append(str(Path(sys.path[0]).parent))

In [26]:

from isegm.model.is_plainvit_model import PlainVitModel
from isegm.model.modeling.transformer_helper.cross_entropy_loss import CrossEntropyLoss


embed_dim = 160
neck_dims = [96, 192, 288, 384]

def params_vit_tiny_448(**kwargs):

    backbone_params = dict(img_size=(448, 448), patch_size=(16,16), in_chans=3,
        embed_dim=embed_dim, depth=8, num_heads=4, mlp_ratio=4, qkv_bias=True,)

    neck_params = dict(in_dim=embed_dim, out_dims=neck_dims,)

    head_params = dict(in_channels=neck_dims, in_index=[0, 1, 2, 3],
        dropout_ratio=0.1, num_classes=1, loss_decode=CrossEntropyLoss(),
        align_corners=False, channels=128,)

    return backbone_params, neck_params, head_params


vith_backbone_params, vith_neck_params, vith_head_params = params_vit_tiny_448()
model_vitt = PlainVitModel(use_disks=True, norm_radius=5, with_prev_mask=True, 
    backbone_params=vith_backbone_params, neck_params=vith_neck_params, 
    head_params=vith_head_params)

In [3]:
def get_params_count(model):
    return sum(param.numel() for param in model.parameters())

for model in [model_vitt]:
    print('----')
    print('{:.1f} M'.format(get_params_count(model) / 1024 / 1024))
    print('{:.1f} M'.format(get_params_count(model) * 4.0 / 1024 / 1024))
    print('{:.1f} M'.format(get_params_count(model.backbone) * 4.0 / 1024 / 1024))
    print('{:.1f} M'.format(get_params_count(model.neck) * 4.0 / 1024 / 1024))
    print('{:.1f} M'.format(get_params_count(model.head) * 4.0 / 1024 / 1024))

----
13.6 M
54.3 M
38.7 M
11.4 M
3.6 M


In [27]:
import torch
from thop import profile
from thop import clever_format


# ViT-B, ViT-L, ViT-H 448
input = torch.randn(1, 4, 448, 448)
point = torch.randn(1, 2, 3)

for model in [model_vitt]:
    model.eval()

    macs, params = profile(model, inputs=(input, point))
    gflops, params = clever_format([macs*2, params], "%.5f")

    print(gflops, params)


vitb_backbone_params, vitb_neck_params, vitb_head_params = params_vit_tiny_448()
model_vitb = PlainVitModel(use_disks=True, norm_radius=5, with_prev_mask=True, 
    backbone_params=vitb_backbone_params, neck_params=vitb_neck_params, 
    head_params=vitb_head_params)

# ViT-B-224
input = torch.randn(1, 4, 448, 448)
point = torch.randn(1, 2, 3)

for model in [model_vitb]:
    model.eval()

    macs, params = profile(model, inputs=(input, point))
    gflops, params = clever_format([macs*2, params], "%.5f")

    print(gflops, params)


[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
10.52171G 3.71619M
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register coun