In [1]:
import torch
import collections
import copy

In [2]:
# Source checkpoint: the detectron2 rpn_r50_fpn_1x experiment output; its weights are read
# from the top-level 'model' key by the cells below.
# NOTE(review): torch.load unpickles arbitrary objects — only load checkpoints you trust.
DETECTRON2_CKPT_PATH = '../detectron2/exps/rpn_r50_fpn_1x/model_final.pth'
detectron_model = torch.load(DETECTRON2_CKPT_PATH, map_location=torch.device('cpu'))

In [3]:
# Show the Detectron2 parameter names stored under the checkpoint's 'model' key;
# these are the source keys that the conversion cell maps to MMDetection names.
detectron_model['model'].keys()

odict_keys(['backbone.fpn_lateral2.weight', 'backbone.fpn_lateral2.bias', 'backbone.fpn_output2.weight', 'backbone.fpn_output2.bias', 'backbone.fpn_lateral3.weight', 'backbone.fpn_lateral3.bias', 'backbone.fpn_output3.weight', 'backbone.fpn_output3.bias', 'backbone.fpn_lateral4.weight', 'backbone.fpn_lateral4.bias', 'backbone.fpn_output4.weight', 'backbone.fpn_output4.bias', 'backbone.fpn_lateral5.weight', 'backbone.fpn_lateral5.bias', 'backbone.fpn_output5.weight', 'backbone.fpn_output5.bias', 'backbone.bottom_up.stem.conv1.weight', 'backbone.bottom_up.stem.conv1.norm.weight', 'backbone.bottom_up.stem.conv1.norm.bias', 'backbone.bottom_up.stem.conv1.norm.running_mean', 'backbone.bottom_up.stem.conv1.norm.running_var', 'backbone.bottom_up.res2.0.shortcut.weight', 'backbone.bottom_up.res2.0.shortcut.norm.weight', 'backbone.bottom_up.res2.0.shortcut.norm.bias', 'backbone.bottom_up.res2.0.shortcut.norm.running_mean', 'backbone.bottom_up.res2.0.shortcut.norm.running_var', 'backbone.bottom_

In [4]:
# Target template: an MMDetection checkpoint whose 'state_dict' key names are used to
# validate every converted key, and whose container is re-saved with the converted weights.
# NOTE(review): torch.load unpickles arbitrary objects — only load checkpoints you trust.
MMDET_CKPT_PATH = './exps/mask_rcnn_r50_fpn_ms1333_detectron2-caffe-freezeBN_l1-loss_roialign-v2_1x_noctr/latest.pth'
mmdet_model = torch.load(MMDET_CKPT_PATH, map_location=torch.device('cpu'))

In [5]:
# Show the MMDetection parameter names stored under the checkpoint's 'state_dict' key;
# converted Detectron2 keys are asserted to exist in this set.
mmdet_model['state_dict'].keys()

odict_keys(['backbone.conv1.weight', 'backbone.bn1.weight', 'backbone.bn1.bias', 'backbone.bn1.running_mean', 'backbone.bn1.running_var', 'backbone.bn1.num_batches_tracked', 'backbone.layer1.0.conv1.weight', 'backbone.layer1.0.bn1.weight', 'backbone.layer1.0.bn1.bias', 'backbone.layer1.0.bn1.running_mean', 'backbone.layer1.0.bn1.running_var', 'backbone.layer1.0.bn1.num_batches_tracked', 'backbone.layer1.0.conv2.weight', 'backbone.layer1.0.bn2.weight', 'backbone.layer1.0.bn2.bias', 'backbone.layer1.0.bn2.running_mean', 'backbone.layer1.0.bn2.running_var', 'backbone.layer1.0.bn2.num_batches_tracked', 'backbone.layer1.0.conv3.weight', 'backbone.layer1.0.bn3.weight', 'backbone.layer1.0.bn3.bias', 'backbone.layer1.0.bn3.running_mean', 'backbone.layer1.0.bn3.running_var', 'backbone.layer1.0.bn3.num_batches_tracked', 'backbone.layer1.0.downsample.0.weight', 'backbone.layer1.0.downsample.1.weight', 'backbone.layer1.0.downsample.1.bias', 'backbone.layer1.0.downsample.1.running_mean', 'backbone.

In [6]:
def _convert_res_stage_key(key, parts):
    """Map a ``backbone.bottom_up.res{N}.{block}.*`` key to ``backbone.layer{N-1}.{block}.*``.

    Args:
        key: The full Detectron2 key (used only in the error message).
        parts: ``key.split('.')``.

    Returns:
        The MMDetection backbone key.

    Raises:
        KeyError: If the key does not match the shortcut / conv / norm patterns.
    """
    # Detectron2 stages res2..res5 become MMDetection layer1..layer4; parts[3] is the block index.
    prefix = f'backbone.layer{int(parts[2][-1]) - 1}.{parts[3]}'
    suffix = parts[-1]
    if 'shortcut' in parts[4]:
        # Projection shortcut: its conv maps to downsample.0, its norm to downsample.1.
        if parts[-2] == 'shortcut':
            return f'{prefix}.downsample.0.{suffix}'
        if parts[-3] == 'shortcut':
            return f'{prefix}.downsample.1.{suffix}'
    elif 'conv' in parts[-2]:
        return f'{prefix}.conv{int(parts[-2][-1])}.{suffix}'
    elif parts[-2] == 'norm':
        # The norm is nested under its conv (convK.norm.*) and maps to bnK.
        return f'{prefix}.bn{int(parts[-3][-1])}.{suffix}'
    raise KeyError(f'{key} is invalid')


def _convert_roi_heads_key(key, parts):
    """Map a ``roi_heads.*`` key to the MMDetection ``bbox_head`` / ``mask_head`` naming.

    Raises:
        KeyError: If the key matches none of the known ROI-head patterns.
    """
    suffix = parts[-1]
    if parts[1] == 'box_head':
        # fc1, fc2, ... map to shared_fcs.0, shared_fcs.1, ...
        return f'bbox_head.shared_fcs.{int(parts[2][-1]) - 1}.{suffix}'
    if parts[2] == 'cls_score':
        return f'bbox_head.fc_cls.{suffix}'
    if parts[2] == 'bbox_pred':
        return f'bbox_head.fc_reg.{suffix}'
    if 'mask_fcn' in parts[2]:
        # mask_fcn1, mask_fcn2, ... map to convs.0, convs.1, ...
        return f'mask_head.convs.{int(parts[2][-1]) - 1}.conv.{suffix}'
    if 'deconv' in parts[2]:
        return f'mask_head.upsample.{suffix}'
    if 'roi_heads.mask_head.predictor' in key:
        return f'mask_head.conv_logits.{suffix}'
    raise KeyError(f'{key} is invalid')


def convert_detectron2_key(key):
    """Map one Detectron2 parameter name to its MMDetection ``state_dict`` name.

    Args:
        key: A key of ``detectron_model['model']``,
            e.g. ``'backbone.bottom_up.res2.0.conv1.norm.weight'``.

    Returns:
        The MMDetection key, or ``None`` for anchor-generator entries, which are
        deliberately skipped.

    Raises:
        KeyError: If ``key`` matches no known naming pattern. Previously such keys were
            only printed and the loop silently reused the previous key's target name.
    """
    parts = key.split('.')
    suffix = parts[-1]  # trailing parameter/buffer name, e.g. 'weight'
    if 'backbone.fpn_lateral' in key:
        # fpn_lateral2..5 map to lateral_convs.0..3
        return f'neck.lateral_convs.{int(parts[-2][-1]) - 2}.conv.{suffix}'
    if 'backbone.fpn_output' in key:
        # fpn_output2..5 map to fpn_convs.0..3
        return f'neck.fpn_convs.{int(parts[-2][-1]) - 2}.conv.{suffix}'
    # The stem norm check must precede the stem conv check: both share the conv1 prefix.
    if 'backbone.bottom_up.stem.conv1.norm.' in key:
        return f'backbone.bn1.{suffix}'
    if 'backbone.bottom_up.stem.conv1.' in key:
        return f'backbone.conv1.{suffix}'
    if 'backbone.bottom_up.res' in key:
        return _convert_res_stage_key(key, parts)
    # Must be checked before the generic 'rpn' test below, as in the original chain.
    if 'proposal_generator.anchor_generator' in key:
        return None
    if 'rpn' in key:
        if 'conv' in parts[2]:
            return f'rpn_head.rpn_conv.{suffix}'
        if 'objectness_logits' in parts[2]:
            return f'rpn_head.rpn_cls.{suffix}'
        if 'anchor_deltas' in parts[2]:
            return f'rpn_head.rpn_reg.{suffix}'
        raise KeyError(f'{key} is invalid')
    if 'roi_heads' in key:
        return _convert_roi_heads_key(key, parts)
    raise KeyError(f'{key} is not converted!!')


new_state_dict = collections.OrderedDict()
# Shallow copy: only the key set is consumed (popped) below, so cloning every tensor is
# unnecessary. Keys left in mmdet_model_copy afterwards had no Detectron2 source.
mmdet_model_copy = mmdet_model['state_dict'].copy()
for k, v in detectron_model['model'].items():
    name = convert_detectron2_key(k)
    if name is None:
        continue
    assert name in mmdet_model_copy, f"{k} converted to {name} but not in models"
    # Popping also catches two Detectron2 keys mapping onto the same MMDetection key.
    mmdet_model_copy.pop(name)
    new_state_dict[name] = v

In [7]:
# Replace the template's weights with the converted Detectron2 tensors, keeping every other
# top-level entry of the loaded checkpoint, and write it relative to the working directory.
# NOTE(review): the non-'state_dict' entries (e.g. meta/optimizer, if present) come from the
# mmdet run loaded above and do not describe these weights — confirm they should be kept.
mmdet_model.update(state_dict=new_state_dict)
torch.save(mmdet_model, 'mmdet_detectron2_rpn_r50.pth')