In [1]:
# Model configurations keyed by name. Each entry describes a top-down pose
# model as nested dicts: a ViT backbone plus a heatmap keypoint head.
# `build_vitpose` reads the `backbone` and `keypoint_head` entries; the
# `train_cfg` / `test_cfg` entries are kept for reference.
models = {
    # ViT-B backbone + classic decoder: two deconv layers (256 filters,
    # kernel 4 each), then a final conv with kernel 1 producing 17 channels.
    "ViTPose_base_coco_256x192": {
        'type': 'TopDown',
        'pretrained': None,
        'backbone': {
            'type': 'ViT',
            'img_size': (256, 192),
            'patch_size': 16,
            'embed_dim': 768,
            'depth': 12,
            'num_heads': 12,
            'ratio': 1,
            'use_checkpoint': False,
            'mlp_ratio': 4,
            'qkv_bias': True,
            'drop_path_rate': 0.3,
        },
        'keypoint_head': {
            'type': 'TopdownHeatmapSimpleHead',
            'in_channels': 768,
            'num_deconv_layers': 2,
            'num_deconv_filters': (256, 256),
            'num_deconv_kernels': (4, 4),
            'extra': {'final_conv_kernel': 1},
            'out_channels': 17,
            'loss_keypoint': {'type': 'JointsMSELoss', 'use_target_weight': True},
        },
        'train_cfg': {},
        'test_cfg': {},
    },

    # Same backbone with the "simple" decoder: no deconv layers, an
    # `upsample` factor of 4 and a final conv with kernel 3.
    "ViTPose_base_simple_coco_256x192": {
        'type': 'TopDown',
        'pretrained': None,
        'backbone': {
            'type': 'ViT',
            'img_size': (256, 192),
            'patch_size': 16,
            'embed_dim': 768,
            'depth': 12,
            'num_heads': 12,
            'ratio': 1,
            'use_checkpoint': False,
            'mlp_ratio': 4,
            'qkv_bias': True,
            'drop_path_rate': 0.3,
        },
        'keypoint_head': {
            'type': 'TopdownHeatmapSimpleHead',
            'in_channels': 768,
            'num_deconv_layers': 0,
            'num_deconv_filters': [],
            'num_deconv_kernels': [],
            'upsample': 4,
            'extra': {'final_conv_kernel': 3},
            'out_channels': 17,
            'loss_keypoint': {'type': 'JointsMSELoss', 'use_target_weight': True},
        },
        'train_cfg': {},
        'test_cfg': {
            'flip_test': True,
            'post_process': 'default',
            'shift_heatmap': False,
            'target_type': 'GaussianHeatmap',
            'modulate_kernel': 11,
            'use_udp': True,
        },
    },
    }


In [2]:
import torch
import torch.nn as nn

from pretrained_models.vit import ViT
from pretrained_models.topdown_heatmap_simple_head import TopdownHeatmapSimpleHead


class VitPoseModel(nn.Module):
    """Top-down pose model: a ViT backbone followed by a keypoint head."""

    def __init__(self, backbone, keypoint_head):
        super(VitPoseModel, self).__init__()
        # state_dict keys are prefixed with these attribute names
        # ("backbone.*", "keypoint_head.*"), so they must match the checkpoint.
        self.backbone = backbone
        self.keypoint_head = keypoint_head

    def forward(self, x):
        return self.keypoint_head(self.backbone(x))


def build_vitpose(model_name, checkpoint=None):
    """Build a ViTPose model from the module-level `models` config table.

    Args:
        model_name: key into `models` (e.g. 'ViTPose_base_coco_256x192').
        checkpoint: optional path to a checkpoint file. Either a dict holding
            the weights under 'state_dict', or a bare state dict.

    Returns:
        A `VitPoseModel` with parameters on CPU, in train mode (the
        nn.Module default). Call `.eval()` / `.to(device)` as needed.

    Raises:
        ValueError: if `model_name` is not a key of `models`.
    """
    try:
        cfg = models[model_name]
    except KeyError as err:
        # Only a missing key means "unknown config"; any other error (e.g. a
        # NameError) is a real bug and is allowed to surface.
        raise ValueError('not a correct config') from err

    head_cfg = cfg['keypoint_head']
    head_kwargs = dict(
        in_channels=head_cfg['in_channels'],
        out_channels=head_cfg['out_channels'],
        num_deconv_filters=head_cfg['num_deconv_filters'],
        num_deconv_kernels=head_cfg['num_deconv_kernels'],
        num_deconv_layers=head_cfg['num_deconv_layers'],
        extra=head_cfg['extra'],
    )
    # The "simple" config has no deconv layers and sets `upsample` instead.
    # Forward it when present so that config builds the head it describes
    # rather than silently dropping the key. Configs without `upsample` get
    # exactly the same arguments as before.
    # NOTE(review): assumes TopdownHeatmapSimpleHead accepts an `upsample`
    # kwarg, as the upstream ViTPose head does — confirm in pretrained_models.
    if 'upsample' in head_cfg:
        head_kwargs['upsample'] = head_cfg['upsample']
    head = TopdownHeatmapSimpleHead(**head_kwargs)

    bb_cfg = cfg['backbone']
    backbone = ViT(
        img_size=bb_cfg['img_size'],
        patch_size=bb_cfg['patch_size'],
        embed_dim=bb_cfg['embed_dim'],
        depth=bb_cfg['depth'],
        num_heads=bb_cfg['num_heads'],
        ratio=bb_cfg['ratio'],
        mlp_ratio=bb_cfg['mlp_ratio'],
        qkv_bias=bb_cfg['qkv_bias'],
        drop_path_rate=bb_cfg['drop_path_rate'],
    )

    pose = VitPoseModel(backbone, head)
    if checkpoint is not None:
        # The model above is built on CPU, so map the checkpoint to CPU too.
        # This lets GPU-saved checkpoints load on CPU-only machines.
        # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(checkpoint, map_location='cpu')
        if isinstance(ckpt, dict) and 'state_dict' in ckpt:
            state_dict = ckpt['state_dict']
        else:
            state_dict = ckpt
        pose.load_state_dict(state_dict)
    return pose


In [3]:
# Build the ViT-B COCO 256x192 model and load its pretrained weights.
# The builder defined above is `build_vitpose`; `build_model` does not exist
# and raises NameError on a fresh kernel.
pose_path = 'vitpose-b.pth'
pose = build_vitpose('ViTPose_base_coco_256x192', pose_path)
print(pose)

VitPoseModel(
  (backbone): ViT(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), padding=(2, 2))
    )
    (blocks): ModuleList(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (1): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwi

In [4]:
# Smoke test: one 256x192 RGB input should produce 17 heatmaps of 64x48.
dummy_input = torch.randn(1, 3, 256, 192)
# Run in eval mode to disable train-only behaviour such as drop-path
# (drop_path_rate=0.3 in the config). Skip autograd bookkeeping, then restore
# the previous mode so later cells see the model unchanged.
was_training = pose.training
pose.eval()
with torch.no_grad():
    output1 = pose(dummy_input)
pose.train(was_training)
print(output1.shape)
assert output1.shape == (1, 17, 64, 48)

torch.Size([1, 17, 64, 48])


In [11]:
class MLPHead(nn.Module):
    """Fully connected regression head applied to a flattened feature map.

    Args:
        in_features: number of elements per sample after flattening every
            non-batch dimension of the input.
        num_outputs: size of the output vector.
        hidden_dims: widths of the hidden Linear+ReLU layers. The default
            (512, 256) reproduces the original architecture, including its
            state_dict keys (fc.0, fc.2, fc.4).
    """

    def __init__(self, in_features, num_outputs=15, hidden_dims=(512, 256)):
        super(MLPHead, self).__init__()
        layers = []
        prev_width = in_features
        for width in hidden_dims:
            layers += [nn.Linear(prev_width, width), nn.ReLU()]
            prev_width = width
        layers.append(nn.Linear(prev_width, num_outputs))
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        # Keep the batch dimension and merge the rest. Unlike .view(), which
        # raises on non-contiguous tensors, flatten() copies only when needed.
        return self.fc(x.flatten(1))

class ModifiedViTPoseModel(nn.Module):
    """ViTPose backbone reused under an MLP regression head.

    Args:
        original_model: a model exposing a `.backbone` submodule. The module
            is shared, not copied, so training this model also updates
            `original_model.backbone`.
        num_outputs: size of the regression output.
        in_features: flattened size of one backbone output. The default
            768 * 16 * 12 matches embed_dim 768 on a 256x192 input with
            16-pixel patches (a 16x12 feature grid). Override it for other
            backbones or input sizes.
    """

    def __init__(self, original_model, num_outputs=15, in_features=768 * 16 * 12):
        super(ModifiedViTPoseModel, self).__init__()
        self.backbone = original_model.backbone
        self.mlp_head = MLPHead(in_features=in_features, num_outputs=num_outputs)

    def forward(self, x):
        # assumes backbone output is (batch, embed_dim, height, width); MLPHead
        # flattens every non-batch dimension.
        features = self.backbone(x)
        return self.mlp_head(features)

# Wrap the backbone of the ViTPose model loaded above (`pose`) with a
# 15-output MLP head. The backbone module is shared with `pose`, not copied.
modified_model = ModifiedViTPoseModel(pose, num_outputs=15)
print(modified_model)

ModifiedViTPoseModel(
  (backbone): ViT(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), padding=(2, 2))
    )
    (blocks): ModuleList(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (1): Block(
        (norm1): LayerNorm((768,), eps=1e-06, e

In [12]:
# Smoke test: one 256x192 RGB input should map to a 15-dim output vector.
dummy_input = torch.randn(1, 3, 256, 192)
# Eval mode disables train-only behaviour (e.g. drop-path) and no_grad skips
# autograd bookkeeping. Restore the previous mode afterwards.
was_training = modified_model.training
modified_model.eval()
with torch.no_grad():
    output = modified_model(dummy_input)
modified_model.train(was_training)
print(output.shape)
assert output.shape == (1, 15)

torch.Size([1, 15])
