In [1]:
from pathlib import Path
import sys

sys.path.append(str(Path(sys.path[0]).parent))

In [5]:

from isegm.model.is_plainvit_model import PlainVitModel
from isegm.model.modeling.transformer_helper.cross_entropy_loss import CrossEntropyLoss


model_folder = '/playpen-raid2/qinliu/projects/iSegFormer/weights'
vitb_path = model_folder + '/sbd_vitb_epoch_54.pth'
vitl_path = model_folder + '/sbd_vitl_epoch_54.pth'
vith_path = model_folder + '/sbd_vith_epoch_54.pth'


def params_vit_base_224(**kwargs):

    backbone_params = dict(img_size=(224, 224), patch_size=(16,16), in_chans=3,
        embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,)

    neck_params = dict(in_dim=768, out_dims=[128, 256, 512, 1024],)

    head_params = dict(in_channels=[128, 256, 512, 1024], in_index=[0, 1, 2, 3],
        dropout_ratio=0.1, num_classes=1, loss_decode=CrossEntropyLoss(),
        align_corners=False, channels=256,)

    return backbone_params, neck_params, head_params


def params_vit_base_448(**kwargs):

    backbone_params = dict(img_size=(448, 448), patch_size=(16,16), in_chans=3,
        embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,)

    neck_params = dict(in_dim=768, out_dims=[128, 256, 512, 1024],)

    head_params = dict(in_channels=[128, 256, 512, 1024], in_index=[0, 1, 2, 3],
        dropout_ratio=0.1, num_classes=1, loss_decode=CrossEntropyLoss(),
        align_corners=False, channels=256,)

    return backbone_params, neck_params, head_params


def params_vit_large_448(**kwargs):

    backbone_params = dict(img_size=(448, 448), patch_size=(16,16), in_chans=3,
        embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,)

    neck_params = dict(in_dim=1024, out_dims=[192, 384, 768, 1536],)

    head_params = dict(in_channels=[192, 384, 768, 1536], in_index=[0, 1, 2, 3],
        dropout_ratio=0.1, num_classes=1, loss_decode=CrossEntropyLoss(),
        align_corners=False, channels=256,)

    return backbone_params, neck_params, head_params


def params_vit_huge_448(**kwargs):

    backbone_params = dict(img_size=(448, 448), patch_size=(14,14), in_chans=3,
        embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,)

    neck_params = dict(in_dim=1280, out_dims=[240, 480, 960, 1920],)

    head_params = dict(in_channels=[240, 480, 960, 1920], in_index=[0, 1, 2, 3],
        dropout_ratio=0.1, num_classes=1, loss_decode=CrossEntropyLoss(),
        align_corners=False, channels=256,)

    return backbone_params, neck_params, head_params


vitb_backbone_params, vitb_neck_params, vitb_head_params = params_vit_base_448()
model_vitb = PlainVitModel(use_disks=True, norm_radius=5, with_prev_mask=True, 
    backbone_params=vitb_backbone_params, neck_params=vitb_neck_params, 
    head_params=vitb_head_params)

vitl_backbone_params, vitl_neck_params, vitl_head_params = params_vit_large_448()
model_vitl = PlainVitModel(use_disks=True, norm_radius=5, with_prev_mask=True, 
    backbone_params=vitl_backbone_params, neck_params=vitl_neck_params, 
    head_params=vitl_head_params)

vith_backbone_params, vith_neck_params, vith_head_params = params_vit_huge_448()
model_vith = PlainVitModel(use_disks=True, norm_radius=5, with_prev_mask=True, 
    backbone_params=vith_backbone_params, neck_params=vith_neck_params, 
    head_params=vith_head_params)

In [13]:
def get_params_count(model):
    return sum(param.numel() for param in model.parameters())

for model in [model_vitb, model_vitl, model_vith]:
    print('----')
    print('{:.1f} M'.format(get_params_count(model) * 4.0 / 1024 / 1024))
    print('{:.1f} M'.format(get_params_count(model.backbone) * 4.0 / 1024 / 1024))
    print('{:.1f} M'.format(get_params_count(model.neck) * 4.0 / 1024 / 1024))
    print('{:.1f} M'.format(get_params_count(model.head) * 4.0 / 1024 / 1024))

----
373.9 M
332.0 M
36.1 M
3.6 M
----
1236.7 M
1163.2 M
66.0 M
4.5 M
----
2526.0 M
2414.8 M
103.1 M
5.2 M


In [6]:
import torch
from thop import profile
from thop import clever_format


# ViT-B, ViT-L, ViT-H 448
input = torch.randn(1, 4, 448, 448)
point = torch.randn(1, 2, 3)

for model in [model_vitb, model_vitl, model_vith]:
    model.eval()

    macs, params = profile(model, inputs=(input, point))
    gflops, params = clever_format([macs*2, params], "%.5f")

    print(gflops, params)


vitb_backbone_params, vitb_neck_params, vitb_head_params = params_vit_base_224()
model_vitb = PlainVitModel(use_disks=True, norm_radius=5, with_prev_mask=True, 
    backbone_params=vitb_backbone_params, neck_params=vitb_neck_params, 
    head_params=vitb_head_params)

# ViT-B-224
input = torch.randn(1, 4, 224, 224)
point = torch.randn(1, 2, 3)

for model in [model_vitb]:
    model.eval()

    macs, params = profile(model, inputs=(input, point))
    gflops, params = clever_format([macs*2, params], "%.5f")

    print(gflops, params)


[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
169.77953G 96.45703M
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register co