In [27]:
import torch
from models.swin_transformer import SwinTransformer

In [28]:
# Swin backbone hyper-parameters. These should match the configuration the
# checkpoint loaded below was trained with; mismatching parameter shapes make
# load_state_dict fail.
swin_config = dict(
    in_chans=3,                 # RGB input
    patch_size=4,               # 224x224 image -> 56x56 patch tokens (see model repr below)
    embed_dim=96,               # channel width of the first stage
    depths=[2, 2, 6, 2],        # number of blocks in each of the four stages
    num_heads=[3, 6, 12, 24],   # attention heads per stage
    window_size=14,
    mlp_ratio=4,
    qkv_bias=True,
    # Dropout / stochastic depth are all disabled; this notebook only runs inference.
    drop_rate=0,
    attn_drop_rate=0,
    drop_path_rate=0,
    ape=False,
    patch_norm=True,
)
model = SwinTransformer(**swin_config)

In [29]:
# Pretrained EsViT checkpoint. torch.load unpickles the file, so only use
# checkpoints from a trusted source.
# NOTE(review): absolute, user-specific path - consider a config cell or env var.
CHECKPOINT_PATH = "/home/yishido/pretrained_model/esvit/checkpoint_best.pth"
# map_location="cpu": load every tensor onto the CPU regardless of the device it was saved from.
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")

In [30]:
# The checkpoint dict holds the wanted weights under its "teacher" entry.
# The guard keeps this cell idempotent: after the first run `state_dict`
# already holds the unwrapped weights (no "teacher" key), so re-running it
# no longer raises KeyError.
if "teacher" in state_dict:
    state_dict = state_dict["teacher"]
# Strip the "module." prefix (presumably added by DataParallel/DDP wrapping)
# only at the START of each key. The previous str.replace also rewrote any
# "module." occurring mid-key (e.g. "submodule." -> "sub").
state_dict = {
    (key[len("module."):] if key.startswith("module.") else key): value
    for key, value in state_dict.items()
}

In [31]:
# strict=False because the model and checkpoint disagree on their heads:
# the model's classifier `head.weight`/`head.bias` is not in the checkpoint,
# and the checkpoint's projection heads (`head.mlp.*`, `head.last_layer.*`,
# `head_dense.*` - presumably the self-supervised training heads) are not in
# the model. Neither head is needed, since features are taken from
# `model.norm`. strict=False would also silently skip mismatched BACKBONE
# keys, so check that only head keys were left out.
load_result = model.load_state_dict(state_dict, strict=False)
head_prefixes = ("head.", "head_dense.")
missing_backbone = [k for k in load_result.missing_keys if not k.startswith(head_prefixes)]
unexpected_backbone = [k for k in load_result.unexpected_keys if not k.startswith(head_prefixes)]
assert not missing_backbone, f"backbone weights missing from checkpoint: {missing_backbone}"
assert not unexpected_backbone, f"checkpoint keys not used by the backbone: {unexpected_backbone}"
load_result

_IncompatibleKeys(missing_keys=['head.weight', 'head.bias'], unexpected_keys=['head_dense.mlp.0.weight', 'head_dense.mlp.0.bias', 'head_dense.mlp.2.weight', 'head_dense.mlp.2.bias', 'head_dense.mlp.4.weight', 'head_dense.mlp.4.bias', 'head_dense.last_layer.weight_g', 'head_dense.last_layer.weight_v', 'head.mlp.0.weight', 'head.mlp.0.bias', 'head.mlp.2.weight', 'head.mlp.2.bias', 'head.mlp.4.weight', 'head.mlp.4.bias', 'head.last_layer.weight_g', 'head.last_layer.weight_v'])

In [32]:
# Display the module tree to inspect the architecture (stage resolutions, window sizes, heads).
model

SwinTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
    (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
  )
  (pos_drop): Dropout(p=0, inplace=False)
  (layers): ModuleList(
    (0): BasicLayer(
      dim=96, input_resolution=(56, 56), depth=2
      (blocks): ModuleList(
        (0): SwinTransformerBlock(
          dim=96, input_resolution=(56, 56), num_heads=3, window_size=14, shift_size=0 mlp_ratio=4
          (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            dim=96, window_size=(14, 14), num_heads=3
            (qkv): Linear(in_features=96, out_features=288, bias=True)
            (attn_drop): Dropout(p=0, inplace=False)
            (proj): Linear(in_features=96, out_features=96, bias=True)
            (proj_drop): Dropout(p=0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((96

In [33]:
import torchinfo

# Per-layer table of tensor shapes, parameter counts and multiply-adds
# for a single 3x224x224 input (runs one forward pass on the current device).
col_names = [
    "input_size",
    "output_size",
    "num_params",
    "mult_adds",
]
torchinfo.summary(
    model,
    input_size=(1, 3, 224, 224),
    col_names=col_names,
)

Layer (type:depth-idx)                             Input Shape               Output Shape              Param #                   Mult-Adds
SwinTransformer                                    [1, 3, 224, 224]          [1, 1000]                 --                        --
├─PatchEmbed: 1-1                                  [1, 3, 224, 224]          [1, 3136, 96]             --                        --
│    └─Conv2d: 2-1                                 [1, 3, 224, 224]          [1, 96, 56, 56]           4,704                     14,751,744
│    └─LayerNorm: 2-2                              [1, 3136, 96]             [1, 3136, 96]             192                       192
├─Dropout: 1-2                                     [1, 3136, 96]             [1, 3136, 96]             --                        --
├─ModuleList: 1-3                                  --                        --                        --                        --
│    └─BasicLayer: 2-3                             [1, 3136,

In [38]:
def extract(target, inputs, model=None):
    """Run one forward pass and return the output of a given submodule.

    A forward hook is registered on ``target``. The model is put in eval mode
    and run on ``inputs`` without building an autograd graph, and the tensor
    that ``target`` produced is returned.

    Args:
        target: Submodule of ``model`` whose output should be captured.
        inputs: Input batch passed to ``model``.
        model: Network to run. Defaults to the notebook-level ``model`` so the
            original two-argument call keeps working.

    Returns:
        A detached copy of ``target``'s output tensor (from its last call, if
        it is called more than once during the forward pass).

    Raises:
        RuntimeError: If ``target`` was never called during the forward pass
            (e.g. it is not part of ``model``). Previously this silently
            returned stale data from an earlier call.
    """
    if model is None:
        # Backward-compatible fallback to the global used by the original code.
        model = globals()["model"]

    # Closure-local storage instead of a module-level global, so a result from
    # a previous call can never leak into this one.
    captured = {}

    def forward_hook(module, module_inputs, outputs):
        # detach(): cut the tensor out of the autograd graph.
        # clone(): copy it, because an in-place layer later in the network
        # (e.g. ReLU(inplace=True)) could otherwise overwrite the captured
        # values before we return them.
        captured["features"] = outputs.detach().clone()

    handle = target.register_forward_hook(forward_hook)
    try:
        model.eval()
        with torch.no_grad():
            model(inputs)
    finally:
        # Always unregister; a failed forward must not leave the hook attached.
        handle.remove()

    if "features" not in captured:
        raise RuntimeError("target module was not called during the forward pass")
    return captured["features"]

In [39]:
# Use the GPU when one is available and fall back to CPU, so the notebook
# still runs on machines without CUDA (the original hard-coded "cuda").
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# Seed the CPU RNG so the dummy input - and therefore `emb` - is reproducible
# across re-runs. The input is drawn on the CPU and then moved, so the values
# are the same whichever device is used.
torch.manual_seed(0)
# Random tensor standing in for one 3x224x224 image; used to probe feature shapes.
image = torch.randn(1, 3, 224, 224).to(device)
# The backbone's `norm` layer; its output is the embedding extracted below.
target_module = model.norm

In [40]:
# Capture the output of `target_module` for the dummy input.
emb = extract(inputs=image, target=target_module)

In [41]:
emb.size()  # identical torch.Size to `emb.shape`; the output below shows [1, 49, 768]

torch.Size([1, 49, 768])