In [19]:
import argparse
import pickle
import re

import torch

import sys
sys.path.insert(0, '..')
from models.baselines.nonlocal_net import resnet50, resnet101
# from models.convert_caffe2pytorch import convert_i3d_weights

In [20]:
# Directory holding the pre-trained caffe2 long-term-feature-bank checkpoints.
# Change it here only; every checkpoint path below is derived from it.
LFB_DIR = '/data/OnlineActionRecognition/models/pre-trained/long-term-feature-banks'

charades50_file = LFB_DIR + '/charades_r50_baseline_i3d_nl.pkl'
charades101_file = LFB_DIR + '/charades_r101_baseline_i3d_nl.pkl'
ava50_file = LFB_DIR + '/ava_r50_baseline_i3d_nl.pkl'
ava101_file = LFB_DIR + '/ava_r101_baseline_i3d_nl.pkl'

In [28]:
# List the blob names stored in the Charades R50 caffe2 checkpoint
# (optimizer momentum buffers are dropped since they are not model weights).
# NOTE: pickle.load can execute arbitrary code -- only open trusted checkpoint files.
with open(charades50_file, 'rb') as fh:  # `with` closes the handle (the bare open() leaked it)
    data = pickle.load(fh, encoding='latin')['blobs']
data = {k: v for k, v in data.items() if 'momentum' not in k}
sorted(data.keys())

['conv1_w',
 'lr',
 'model_iter',
 'nonlocal_conv3_1_bn_b',
 'nonlocal_conv3_1_bn_s',
 'nonlocal_conv3_1_g_b',
 'nonlocal_conv3_1_g_w',
 'nonlocal_conv3_1_out_b',
 'nonlocal_conv3_1_out_w',
 'nonlocal_conv3_1_phi_b',
 'nonlocal_conv3_1_phi_w',
 'nonlocal_conv3_1_theta_b',
 'nonlocal_conv3_1_theta_w',
 'nonlocal_conv3_3_bn_b',
 'nonlocal_conv3_3_bn_s',
 'nonlocal_conv3_3_g_b',
 'nonlocal_conv3_3_g_w',
 'nonlocal_conv3_3_out_b',
 'nonlocal_conv3_3_out_w',
 'nonlocal_conv3_3_phi_b',
 'nonlocal_conv3_3_phi_w',
 'nonlocal_conv3_3_theta_b',
 'nonlocal_conv3_3_theta_w',
 'nonlocal_conv4_1_bn_b',
 'nonlocal_conv4_1_bn_s',
 'nonlocal_conv4_1_g_b',
 'nonlocal_conv4_1_g_w',
 'nonlocal_conv4_1_out_b',
 'nonlocal_conv4_1_out_w',
 'nonlocal_conv4_1_phi_b',
 'nonlocal_conv4_1_phi_w',
 'nonlocal_conv4_1_theta_b',
 'nonlocal_conv4_1_theta_w',
 'nonlocal_conv4_3_bn_b',
 'nonlocal_conv4_3_bn_s',
 'nonlocal_conv4_3_g_b',
 'nonlocal_conv4_3_g_w',
 'nonlocal_conv4_3_out_b',
 'nonlocal_conv4_3_out_w',
 'nonl

In [21]:
def _caffe2_to_pytorch_key_map(blob_names):
    """Build a ``{pytorch_state_dict_key: caffe2_blob_name}`` mapping for I3D-ResNet blobs."""
    # (\d+) rather than (.): ResNet-101 has multi-digit block indices (res4_10_..., 
    # nonlocal_conv4_11_...). A single-character group either failed to match those keys
    # or truncated "11" to "1", silently dropping or mis-assigning weights.
    downsample_pat = re.compile(r'res(\d+)_(\d+)_branch1_')
    conv_pat = re.compile(r'res(\d+)_(\d+)_branch2([abc])_')
    nonlocal_pat = re.compile(r'nonlocal_conv(\d+)_(\d+)_')
    m2num = dict(zip('abc', [1, 2, 3]))
    # caffe2 suffixes: _w = weight, _b = bias, _s = (batch-norm) scale -> weight.
    suffix_dict = {'b': 'bias', 'w': 'weight', 's': 'weight'}
    nonlocal_dict = {'out': 'W.0', 'bn': 'W.1', 'phi': 'phi.1', 'g': 'g.1', 'theta': 'theta'}

    key_map = {'conv1.weight': 'conv1_w',
               'bn1.weight': 'res_conv1_bn_s',
               'bn1.bias': 'res_conv1_bn_b',
               'fc.weight': 'pred_w',
               'fc.bias': 'pred_b'}

    for key in blob_names:
        parts = key.split('_')
        conv_match = conv_pat.match(key)
        ds_match = downsample_pat.match(key)
        nl_match = nonlocal_pat.match(key)
        if conv_match:
            layer, block, module = conv_match.groups()
            name = 'bn' if 'bn_' in key else 'conv'
            suffix = suffix_dict[parts[-1]]
            # caffe2 stage resN corresponds to PyTorch layer(N-1).
            new_key = 'layer%d.%d.%s%d.%s' % (int(layer) - 1, int(block), name, m2num[module], suffix)
        elif ds_match:
            layer, block = ds_match.groups()
            # Shortcut branch: index 0 is the conv (_w), index 1 its batch norm (_bn_s/_bn_b).
            module = 0 if key[-1] == 'w' else 1
            suffix = suffix_dict[parts[-1]]
            new_key = 'layer%d.%d.downsample.%d.%s' % (int(layer) - 1, int(block), module, suffix)
        elif nl_match:
            layer, block = nl_match.groups()
            nl_op = nonlocal_dict[parts[-2]]
            suffix = suffix_dict[parts[-1]]
            new_key = 'layer%d.%d.nonlocal_block.%s.%s' % (int(layer) - 1, int(block), nl_op, suffix)
        else:
            continue
        key_map[new_key] = key
    return key_map


def convert_i3d_weights(weigths_file_path, new_model, save_model=False, new_model_name=None,
                        verbose=True):
    """Load pre-trained caffe2 I3D-ResNet (non-local) weights into a PyTorch model.

    Expanded from https://github.com/Tushar-N/pytorch-resnet3d.

    The checkpoint is a pickle whose ``'blobs'`` entry maps caffe2 blob names
    (e.g. ``res4_10_branch2a_w``) to numpy arrays. Every blob with a counterpart
    in ``new_model.state_dict()`` is shape-checked and copied into ``new_model``
    in place; model entries without a caffe2 counterpart keep their current values.

    Args:
        weigths_file_path: Path to the caffe2 ``.pkl`` checkpoint. It is
            unpickled, so it must come from a trusted source.
        new_model: PyTorch model whose state_dict keys follow the
            ``layer{i}.{j}.conv{k}`` / ``downsample`` / ``nonlocal_block`` naming.
        save_model: If True, also write the converted tensors to
            ``'{new_model_name}.pth'``.
        new_model_name: Output path without the ``.pth`` extension; required
            when ``save_model`` is True.
        verbose: If True (default), print the blob -> parameter mapping and the
            learnable parameters that received no caffe2 weights.

    Returns:
        ``new_model`` itself, with the converted weights loaded.

    Raises:
        AssertionError: if a caffe2 blob and its target entry differ in shape.
        ValueError: if ``save_model`` is True but ``new_model_name`` is None.
    """
    if save_model and new_model_name is None:
        raise ValueError('new_model_name is required when save_model=True')

    # NOTE: pickle.load can execute arbitrary code -- only load trusted checkpoints.
    with open(weigths_file_path, 'rb') as fh:
        data = pickle.load(fh, encoding='latin')['blobs']
    # Optimizer momentum buffers are not model weights.
    data = {k: v for k, v in data.items() if 'momentum' not in k}

    key_map = _caffe2_to_pytorch_key_map(data)
    state_dict = new_model.state_dict()
    # Skip hard-coded mappings (e.g. pred_w) whose blob is absent instead of raising KeyError.
    mapped_keys = [key for key in state_dict if key in key_map and key_map[key] in data]

    # Check that weight dimensions match before touching the model.
    for key in mapped_keys:
        data_shape = tuple(data[key_map[key]].shape)
        pth_shape = tuple(state_dict[key].shape)
        if data_shape != pth_shape:
            # Explicit raise (not `assert`) so the check survives `python -O`.
            raise AssertionError('Size Mismatch {} != {} in {}'.format(data_shape, pth_shape, key))
        if verbose:
            print('{:24s} --> {:40s} | {:21s}'.format(key_map[key], key, str(data_shape)))

    new_state_dict = {key: torch.from_numpy(data[key_map[key]]) for key in mapped_keys}
    # Previously the converted tensors were only saved, never loaded, so the returned model
    # still had its random initialisation. strict=False keeps unmapped entries untouched.
    new_model.load_state_dict(new_state_dict, strict=False)

    if verbose:
        unmapped = [name for name, _ in new_model.named_parameters() if name not in new_state_dict]
        if unmapped:
            print('Parameters without caffe2 weights (left at initialisation): {}'.format(unmapped))

    if save_model:
        # Saving new model weights (only the converted entries, as before).
        torch.save(new_state_dict, '{}.pth'.format(new_model_name))

    return new_model

In [34]:
# Temporal input length passed to the project's I3D-ResNet builders.
FRAME_NUM = 32
# Output classes of the classification heads for the Charades and AVA models.
CHARADES_NUM_CLASSES = 157
AVA_NUM_CLASSES = 80

# Freshly initialised non-local I3D networks; convert_i3d_weights fills them in below.
blank_resnet50_i3d = resnet50(non_local=True, frame_num=FRAME_NUM, num_classes=CHARADES_NUM_CLASSES)
blank_resnet101_i3d = resnet101(non_local=True, frame_num=FRAME_NUM, num_classes=CHARADES_NUM_CLASSES)
ava_blank_resnet50_i3d = resnet50(non_local=True, frame_num=FRAME_NUM, num_classes=AVA_NUM_CLASSES)
ava_blank_resnet101_i3d = resnet101(non_local=True, frame_num=FRAME_NUM, num_classes=AVA_NUM_CLASSES)

In [25]:
# Convert the Charades ResNet-50 non-local checkpoint and write it out as a .pth file.
charades50_i3dnl = convert_i3d_weights(
    weigths_file_path=charades50_file,
    new_model=blank_resnet50_i3d,
    save_model=True,
    new_model_name=('/data/OnlineActionRecognition/models/pre-trained/'
                    'long-term-feature-banks/charades_r50_i3d_nl_32x2'),
)

conv1_w                  --> conv1.weight                             | (64, 3, 5, 7, 7)     
res_conv1_bn_s           --> bn1.weight                               | (64,)                
res_conv1_bn_b           --> bn1.bias                                 | (64,)                
res2_0_branch2a_w        --> layer1.0.conv1.weight                    | (64, 64, 3, 1, 1)    
res2_0_branch2a_bn_s     --> layer1.0.bn1.weight                      | (64,)                
res2_0_branch2a_bn_b     --> layer1.0.bn1.bias                        | (64,)                
res2_0_branch2b_w        --> layer1.0.conv2.weight                    | (64, 64, 1, 3, 3)    
res2_0_branch2b_bn_s     --> layer1.0.bn2.weight                      | (64,)                
res2_0_branch2b_bn_b     --> layer1.0.bn2.bias                        | (64,)                
res2_0_branch2c_w        --> layer1.0.conv3.weight                    | (256, 64, 1, 1, 1)   
res2_0_branch2c_bn_s     --> layer1.0.bn3.weight            

In [32]:
# Convert the Charades ResNet-101 non-local checkpoint and write it out as a .pth file.
charades101_i3dnl = convert_i3d_weights(
    weigths_file_path=charades101_file,
    new_model=blank_resnet101_i3d,
    save_model=True,
    new_model_name=('/data/OnlineActionRecognition/models/pre-trained/'
                    'long-term-feature-banks/charades_r101_i3d_nl_32x2'),
)

conv1_w                  --> conv1.weight                             | (64, 3, 5, 7, 7)     
res_conv1_bn_s           --> bn1.weight                               | (64,)                
res_conv1_bn_b           --> bn1.bias                                 | (64,)                
res2_0_branch2a_w        --> layer1.0.conv1.weight                    | (64, 64, 3, 1, 1)    
res2_0_branch2a_bn_s     --> layer1.0.bn1.weight                      | (64,)                
res2_0_branch2a_bn_b     --> layer1.0.bn1.bias                        | (64,)                
res2_0_branch2b_w        --> layer1.0.conv2.weight                    | (64, 64, 1, 3, 3)    
res2_0_branch2b_bn_s     --> layer1.0.bn2.weight                      | (64,)                
res2_0_branch2b_bn_b     --> layer1.0.bn2.bias                        | (64,)                
res2_0_branch2c_w        --> layer1.0.conv3.weight                    | (256, 64, 1, 1, 1)   
res2_0_branch2c_bn_s     --> layer1.0.bn3.weight            

In [35]:
# Convert the AVA ResNet-50 non-local checkpoint and write it out as a .pth file.
ava50_i3dnl = convert_i3d_weights(
    weigths_file_path=ava50_file,
    new_model=ava_blank_resnet50_i3d,
    save_model=True,
    new_model_name=('/data/OnlineActionRecognition/models/pre-trained/'
                    'long-term-feature-banks/ava_r50_i3d_nl_32x2'),
)

conv1_w                  --> conv1.weight                             | (64, 3, 5, 7, 7)     
res_conv1_bn_s           --> bn1.weight                               | (64,)                
res_conv1_bn_b           --> bn1.bias                                 | (64,)                
res2_0_branch2a_w        --> layer1.0.conv1.weight                    | (64, 64, 3, 1, 1)    
res2_0_branch2a_bn_s     --> layer1.0.bn1.weight                      | (64,)                
res2_0_branch2a_bn_b     --> layer1.0.bn1.bias                        | (64,)                
res2_0_branch2b_w        --> layer1.0.conv2.weight                    | (64, 64, 1, 3, 3)    
res2_0_branch2b_bn_s     --> layer1.0.bn2.weight                      | (64,)                
res2_0_branch2b_bn_b     --> layer1.0.bn2.bias                        | (64,)                
res2_0_branch2c_w        --> layer1.0.conv3.weight                    | (256, 64, 1, 1, 1)   
res2_0_branch2c_bn_s     --> layer1.0.bn3.weight            

In [36]:
# Convert the AVA ResNet-101 non-local checkpoint and write it out as a .pth file.
ava101_i3dnl = convert_i3d_weights(
    weigths_file_path=ava101_file,
    new_model=ava_blank_resnet101_i3d,
    save_model=True,
    new_model_name=('/data/OnlineActionRecognition/models/pre-trained/'
                    'long-term-feature-banks/ava_r101_i3d_nl_32x2'),
)

conv1_w                  --> conv1.weight                             | (64, 3, 5, 7, 7)     
res_conv1_bn_s           --> bn1.weight                               | (64,)                
res_conv1_bn_b           --> bn1.bias                                 | (64,)                
res2_0_branch2a_w        --> layer1.0.conv1.weight                    | (64, 64, 3, 1, 1)    
res2_0_branch2a_bn_s     --> layer1.0.bn1.weight                      | (64,)                
res2_0_branch2a_bn_b     --> layer1.0.bn1.bias                        | (64,)                
res2_0_branch2b_w        --> layer1.0.conv2.weight                    | (64, 64, 1, 3, 3)    
res2_0_branch2b_bn_s     --> layer1.0.bn2.weight                      | (64,)                
res2_0_branch2b_bn_b     --> layer1.0.bn2.bias                        | (64,)                
res2_0_branch2c_w        --> layer1.0.conv3.weight                    | (256, 64, 1, 1, 1)   
res2_0_branch2c_bn_s     --> layer1.0.bn3.weight            