In [1]:
import numpy as np
import random
import os
import argparse
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from main.layout import Layout
from main.algorithms import enum_layout_wo_rdt, init_S, coarse_to_fined
from main.auto_models import MTSeqBackbone, MTSeqModel
from main.head import ASPPHeadNode

In [2]:
seed = 10
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

# Generate backbone init weights

In [3]:
# prepare params (number of blocks and feature dim) automatically 
# prototxt = 'models/deeplab_resnet34_adashare.prototxt'
# backbone_PATH = '/mnt/nfs/work1/huiguan/lijunzhang/multibranch/checkpoint/init/resnet34_backbone.model'
prototxt = 'models/mobilenetv2_shorter.prototxt'
backbone_PATH = '/mnt/nfs/work1/huiguan/lijunzhang/multibranch/checkpoint/init/mobilenetv2_shorter_backbone.model'

with torch.cuda.device(0):
    backbone = MTSeqBackbone(prototxt)
    fined_B = len(backbone.basic_blocks)
    feature_dim = backbone(torch.rand(1,3,224,224)).shape[1]

In [4]:
backbone_init = backbone.state_dict()
new_backbone_init = backbone_init.copy()
for old_key in backbone_init:
    new_key = old_key[13:]
    new_backbone_init[new_key] = new_backbone_init.pop(old_key)
print(new_backbone_init)

OrderedDict([('0.compute_nodes.0.basicOp.weight', tensor([[[[ 1.3214e-02, -1.3274e-02,  2.3736e-03],
          [ 7.9568e-03,  1.5397e-02,  4.3011e-03],
          [-4.6026e-03,  1.7103e-02, -4.4552e-04]],

         [[-2.5175e-04,  3.7678e-03, -9.5644e-03],
          [-7.9996e-03,  9.2212e-03,  6.2399e-03],
          [-6.0951e-03, -1.6816e-03, -5.9685e-03]],

         [[-1.1773e-02,  7.0389e-03,  6.4397e-03],
          [ 8.4068e-03, -1.8877e-03, -1.6305e-02],
          [-3.5811e-03, -1.0351e-02,  6.4356e-03]]],


        [[[-3.4358e-03,  2.5661e-03,  2.1507e-03],
          [-4.1671e-03, -6.9684e-03,  5.7453e-03],
          [ 1.1987e-02, -1.7920e-02, -1.8156e-02]],

         [[ 4.4644e-03,  2.2364e-02, -1.2954e-02],
          [ 8.2039e-03,  1.6479e-02, -7.1404e-03],
          [ 1.0946e-02,  9.5298e-03,  2.2544e-02]],

         [[-2.2213e-03,  6.7141e-03,  7.2080e-03],
          [ 3.4055e-03, -3.6158e-03, -8.3882e-03],
          [ 1.2334e-03, -3.2187e-03, -5.5042e-03]]],


        [[[-8.42

In [5]:
torch.save(new_backbone_init, backbone_PATH)

# Generate heads init weights

In [6]:
cls_num = {'segment_semantic': 40, 'normal':3, 'depth_zbuffer': 1}
heads = nn.ModuleDict()
for task in cls_num:
    heads[task] = ASPPHeadNode(feature_dim, cls_num[task])
print(heads.state_dict())

OrderedDict([('segment_semantic.fc1.conv1.weight', tensor([[[[ 1.4484e-02,  1.4752e-02, -7.0488e-04],
          [ 7.9305e-03,  2.4236e-03,  1.2516e-02],
          [-1.7712e-02,  5.8785e-03, -5.5560e-03]],

         [[ 1.3903e-02,  1.3218e-02,  6.6151e-04],
          [ 1.8091e-02, -1.5456e-02,  5.5825e-04],
          [-1.3689e-02,  3.0212e-03, -1.7784e-02]],

         [[-3.1166e-03,  1.0481e-02,  1.6425e-02],
          [ 7.9230e-03,  1.1596e-02,  9.8163e-03],
          [-1.7195e-02, -9.5541e-03,  6.1133e-03]],

         ...,

         [[-1.2605e-02,  1.1169e-05,  5.2314e-03],
          [-8.8059e-03, -1.7133e-02, -1.4958e-02],
          [-7.8358e-03,  1.3669e-02, -1.5154e-02]],

         [[-1.0398e-03, -6.1174e-03, -1.3224e-03],
          [ 1.5357e-02,  3.0674e-03,  1.1333e-02],
          [-1.4694e-02, -8.9166e-03, -1.6159e-02]],

         [[ 2.8265e-04, -3.5911e-03, -1.7962e-02],
          [-3.3298e-03, -1.0845e-02, -1.5433e-02],
          [ 1.8099e-02,  1.6728e-02, -1.6496e-02]]],


  

In [7]:
heads_PATH = '/mnt/nfs/work1/huiguan/lijunzhang/multibranch/checkpoint/init/mobilenetv2_shorter_NYUv2_heads.model'

In [8]:
torch.save(heads.state_dict(), heads_PATH)

# Load into MTLModel

In [10]:
# layout
T = 3
# coarse_B = 5
# mapping = {0:[0], 1:[1,2,3], 2:[4,5,6,7], 3:[8,9,10,11,12,13], 4:[14,15,16], 5:[17]}

coarse_B = 8
mapping = {0:[0], 1:[1,2], 2:[3,4,5,6], 3:[7,8,9,10,11], 4:[12,13,14,15,16,17], 5:[18,19,20,21,22], 
           6:[23,24,25,26,27], 7:[28,29,30], 8:[31]} 

layout_idx = 10
layout_list = [] 
S0 = init_S(T, coarse_B) # initial state
L = Layout(T, coarse_B, S0) # initial layout
layout_list.append(L)
enum_layout_wo_rdt(L, layout_list)

layout = layout_list[layout_idx]
print('Coarse Layout:', flush=True)
print(layout, flush=True)

layout = coarse_to_fined(layout, fined_B, mapping)
print('Fined Layout:', flush=True)
print(layout, flush=True)

Coarse Layout:
[[{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}]]
Fined Layout:
[[{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}]]


In [11]:
cls_num = {'segment_semantic': 40, 'normal':3, 'depth_zbuffer': 1}
model = MTSeqModel(prototxt, layout=layout, feature_dim=feature_dim, cls_num=cls_num, backbone_init=backbone_PATH, heads_init=heads_PATH)

Construct MTSeqModel from Layout:
[[{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}], [{1}, {0, 2}]]


In [12]:
cls_num = {'segment_semantic': 40, 'normal':3}
model = MTSeqModel(prototxt, branch=2, fined_B=fined_B, feature_dim=feature_dim, cls_num=cls_num, backbone_init=backbone_PATH, heads_init=heads_PATH)

Construct MTSeqModel from Layout:
[[{0, 1}], [{0, 1}], [{0}, {1}], [{0}, {1}], [{0}, {1}], [{0}, {1}], [{0}, {1}], [{0}, {1}], [{0}, {1}], [{0}, {1}], [{0}, {1}], [{0}, {1}], [{0}, {1}], [{0}, {1}], [{0}, {1}], [{0}, {1}], [{0}, {1}], [{0}, {1}], [{0}, {1}], [{0}, {1}], [{0}, {1}], [{0}, {1}], [{0}, {1}], [{0}, {1}], [{0}, {1}], [{0}, {1}], [{0}, {1}], [{0}, {1}], [{0}, {1}], [{0}, {1}], [{0}, {1}]]


In [11]:
for name, param in model.named_parameters():
    print(name)
    print(param)
    print('='*50)

backbone.mtl_blocks.0.compute_nodes.0.basicOp.weight
Parameter containing:
tensor([[[[-7.7927e-03,  7.1920e-03,  2.4667e-02,  ..., -7.8490e-03,
           -1.3857e-02,  4.4186e-03],
          [-5.5116e-03, -6.5113e-03,  6.4475e-03,  ...,  3.6505e-03,
           -1.3725e-02,  4.1747e-03],
          [ 4.4338e-03, -1.3171e-02, -7.9573e-03,  ..., -3.0204e-05,
            4.8463e-03,  3.4709e-03],
          ...,
          [-2.4296e-03,  9.2065e-03,  7.8446e-03,  ..., -2.0239e-02,
            1.1997e-02,  9.3471e-03],
          [-1.6499e-02,  1.6926e-03, -1.8079e-02,  ...,  2.0526e-03,
            8.8274e-03, -1.2498e-02],
          [ 1.8529e-02, -1.5828e-04,  1.7973e-02,  ...,  1.1346e-02,
            7.7730e-03,  3.3484e-03]],

         [[ 8.1177e-03, -3.1329e-03,  5.6567e-03,  ..., -1.7700e-02,
            8.2081e-03,  2.2929e-02],
          [-1.4322e-02,  5.2998e-03,  3.7853e-03,  ...,  1.9638e-02,
           -1.1993e-02, -4.1179e-03],
          [-1.6891e-03, -5.9618e-03, -3.0025e-03,  .

Parameter containing:
tensor([[[[-7.1133e-03,  6.4171e-03,  3.6215e-03],
          [-9.6833e-03, -1.1567e-02, -1.1857e-02],
          [ 1.3795e-02,  7.2378e-03, -1.8775e-03]],

         [[ 1.4502e-02,  1.4452e-02,  9.9973e-03],
          [ 7.4654e-03,  1.1832e-02,  1.4128e-03],
          [ 1.3514e-02,  1.0823e-02, -4.9443e-03]],

         [[ 4.1606e-04, -7.1200e-03, -1.5573e-02],
          [ 2.0428e-03,  7.7274e-05, -8.4774e-03],
          [-7.8139e-03, -9.9624e-03,  1.5358e-03]],

         ...,

         [[-1.8970e-03, -1.0217e-02, -1.6293e-02],
          [-9.0035e-03,  2.0375e-02, -1.3569e-02],
          [ 7.7085e-03, -4.0894e-03,  3.0768e-03]],

         [[-1.8769e-02, -9.4257e-03,  1.3278e-03],
          [-1.3247e-02, -1.5673e-02, -8.2622e-03],
          [ 8.8869e-03,  9.6137e-03, -1.0095e-02]],

         [[ 9.6691e-04,  7.8296e-04, -1.0381e-02],
          [-6.6235e-03, -7.1830e-04,  1.4790e-02],
          [-6.8299e-03, -7.5142e-03, -1.5235e-02]]],


        [[[-6.9080e-03, -9.0059

Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0.

Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1.

Parameter containing:
tensor([[[[-0.0264]],

         [[ 0.0230]],

         [[-0.0216]],

         ...,

         [[ 0.0226]],

         [[ 0.0181]],

         [[-0.0055]]],


        [[[-0.0279]],

         [[ 0.0230]],

         [[ 0.0046]],

         ...,

         [[-0.0232]],

         [[ 0.0282]],

         [[ 0.0036]]],


        [[[ 0.0251]],

         [[-0.0306]],

         [[-0.0073]],

         ...,

         [[-0.0245]],

         [[ 0.0010]],

         [[-0.0073]]],


        ...,


        [[[ 0.0277]],

         [[-0.0164]],

         [[ 0.0293]],

         ...,

         [[-0.0043]],

         [[ 0.0307]],

         [[-0.0250]]],


        [[[-0.0207]],

         [[-0.0164]],

         [[-0.0090]],

         ...,

         [[ 0.0185]],

         [[ 0.0171]],

         [[-0.0190]]],


        [[[ 0.0218]],

         [[ 0.0094]],

         [[ 0.0163]],

         ...,

         [[-0.0067]],

         [[ 0.0061]],

         [[-0.0239]]]], requires_grad=True)
heads.depth_zb