In [1]:
import numpy as np
import os
from pathlib import Path
import scipy.stats
from scipy.optimize import linear_sum_assignment
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from torch.utils.data import DataLoader

from framework.layer_node import Conv2dNode
from main.layout import Layout
from main.algorithms import enum_layout_wo_rdt, init_S, coarse_to_fined
from main.auto_models import MTSeqBackbone, MTSeqModel
from data.nyuv2_dataloader_adashare import NYU_v2

In [2]:
# Later cells move the model (model.cuda()) and input batches (.cuda()) to the GPU,
# so fail fast here with an explicit message instead of a bare AssertionError.
assert torch.cuda.is_available(), 'This notebook requires a CUDA-capable GPU (torch.cuda.is_available() is False).'

# backbone and data

In [3]:
# Backbone: MobileNetV2, described by a prototxt file.
backbone_type = 'mobilenet'
prototxt = 'models/mobilenetv2.prototxt'

# Number of coarse blocks; D and coarse_B are the same value under two names.
coarse_B = 5
D = coarse_B

# Coarse-block index -> fine-grained basic-block indices it covers.
# Keys 0..4 are the D coarse blocks (the feature-extraction cell iterates range(D)).
# Key 5 (fine block 32) is outside range(D); its use inside coarse_to_fined is not
# visible here — presumably the final block/head, TODO confirm.
mapping = {
    0: list(range(0, 7)),
    1: list(range(7, 18)),
    2: list(range(18, 23)),
    3: list(range(23, 31)),
    4: [31],
    5: [32],
}

In [4]:
# Dataset: NYUv2 test split, iterated one image at a time in a fixed order.
data = 'NYUv2'
# Root of the NYUv2 data. Set the NYUV2_ROOT environment variable to point at a local
# copy instead of editing this machine-specific absolute path (it remains the default).
NYUV2_ROOT = os.environ.get('NYUV2_ROOT', '/gypsum/work1/huiguan/lijunzhang/policymtl/data/NYUv2/')
dataset = NYU_v2(NYUV2_ROOT, 'test', crop_h=321, crop_w=321)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

# Tasks and the per-task head output size passed to MTSeqModel as cls_num.
tasks = ['segment_semantic', 'normal', 'depth_zbuffer']
cls_num = {'segment_semantic': 40, 'normal': 3, 'depth_zbuffer': 1}
T = len(tasks)
# Index into the enumerated layout list; 2 selects the independent-models layout
# (each task in its own branch at every block, as the printed layout shows).
layout_idx = 2

In [5]:
# Checkpoint locations for the independent-model (ind.) weights: NYUv2 + MobileNetV2.
ckpt_PATH = '/gypsum/work1/huiguan/lijunzhang/multibranch/checkpoint/'
weight_PATH = f'{ckpt_PATH}NYUv2/baseline/NM/2/segment_semantic_normal_depth_zbuffer.model'

# Build the model and load the independent-model weights

In [6]:
# Build a standalone backbone only to read two sizes from it:
#   fined_B     - number of fine-grained basic blocks,
#   feature_dim - dim 1 of the backbone output for a dummy 1x3x224x224 input.
# Gradients are never needed here.
with torch.no_grad():
    backbone = MTSeqBackbone(prototxt)
    fined_B = len(backbone.basic_blocks)
    dummy_batch = torch.rand(1, 3, 224, 224)  # values are irrelevant; only the output shape is read
    feature_dim = backbone(dummy_batch).shape[1]

In [None]:
# Enumerate layouts starting from the initial layout and pick the one at layout_idx.
layout_list = []
S0 = init_S(T, coarse_B)  # initial state
L = Layout(T, coarse_B, S0)  # initial layout
layout_list.append(L)
# Extends layout_list in place (it holds only the initial layout before this call).
enum_layout_wo_rdt(L, layout_list)

coarse_layout = layout_list[layout_idx]
print('Coarse Layout:', flush=True)
print(coarse_layout, flush=True)

# Expand the coarse layout to the backbone's fine-grained basic blocks via `mapping`.
layout = coarse_to_fined(coarse_layout, fined_B, mapping)
print('Fined Layout:', flush=True)
print(layout, flush=True)

# Build the multi-task model for this layout on the GPU.
with torch.no_grad():
    model = MTSeqModel(prototxt, layout=layout, feature_dim=feature_dim, cls_num=cls_num)
    model = model.cuda()

# This notebook only runs inference: eval mode makes BatchNorm use the loaded running
# statistics rather than batch statistics, and stops forward passes from updating them.
model.eval()

# Load the independent-model weights. map_location='cpu' makes loading independent of
# the device the checkpoint was saved on; load_state_dict copies into the CUDA params.
# Kept as the last expression so the cell still displays the key-matching result.
model.load_state_dict(torch.load(weight_PATH, map_location='cpu')['state_dict'])

Coarse Layout:
[[{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}]]
Fined Layout:
[[{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}]]
Construct MTSeqModel from Layout:
[[{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}],

In [16]:
# summary() derives the output shapes below by running a forward pass on random input.
# Run it in eval mode and without autograd so that pass cannot overwrite the loaded
# BatchNorm running statistics (train-mode BN updates them) or build a graph.
model.eval()
with torch.no_grad():
    summary(model, (3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 112, 112]             864
       BatchNorm2d-2         [-1, 32, 112, 112]              64
              ReLU-3         [-1, 32, 112, 112]               0
            Conv2d-4         [-1, 32, 112, 112]             864
       BatchNorm2d-5         [-1, 32, 112, 112]              64
              ReLU-6         [-1, 32, 112, 112]               0
            Conv2d-7         [-1, 32, 112, 112]             864
       BatchNorm2d-8         [-1, 32, 112, 112]              64
              ReLU-9         [-1, 32, 112, 112]               0
           Conv2d-10         [-1, 32, 112, 112]             288
      BatchNorm2d-11         [-1, 32, 112, 112]              64
             ReLU-12         [-1, 32, 112, 112]               0
           Conv2d-13         [-1, 32, 112, 112]             288
      BatchNorm2d-14         [-1, 32, 1

# show weights

In [None]:
# Walk model.named_modules() in order, print the first i * T Conv2dNode layers,
# and assign them round-robin to the tasks.
i = 1
taskConvList = {task: [] for task in tasks}
n_wanted = i * T

count = 0
for name, module in model.named_modules():
    if count >= n_wanted:
        break
    if not isinstance(module, Conv2dNode):
        continue
    print(name)
    print(module)
    # The k-th Conv2dNode goes to tasks[k % T].
    # NOTE(review): this assumes per-task nodes are registered in `tasks` order; the
    # printed layout lists branches as {0}, {2}, {1} — confirm before trusting labels.
    taskConvList[tasks[count % T]].append(module)
    count += 1

In [None]:
# Print the basicOp attribute of every collected Conv2dNode, labelled by task.
# taskConvList[task] is a list of nodes, so iterate over it: reading .basicOp off
# the list itself would raise AttributeError (and the original call was unclosed).
for task in tasks:
    for conv_node in taskConvList[task]:
        print(task, conv_node.basicOp)

# Align conv weights in blocks (TODO: not implemented yet)

# show block features

In [None]:
# Extract per-task features at the last fine-grained block of each coarse block
# for the first K test batches.
K = 1  # number of batches to process
D_pos = [mapping[d][-1] for d in range(D)]  # last fine-block index of coarse block d

# Inference: BatchNorm must use its running statistics and must not update them.
model.eval()

# Collect rows in lists and stack once at the end; calling np.vstack inside the loop
# re-copies the growing array every time (quadratic in K).
feat_rows = {task: [[] for _ in range(D)] for task in tasks}
with torch.no_grad():
    for batch in itertools.islice(dataloader, K):  # fetches exactly K batches, no extra
        x = batch['input'].cuda()
        all_feats = model.extract_features(x)
        # all_feats is indexed by task position (same order as `tasks`), then by
        # fine-block index taken from D_pos.
        for task_idx, task in enumerate(tasks):
            for d, pos in enumerate(D_pos):
                # reshape(1, -1) flattens the whole batch into one row; that is one
                # sample per row only because the DataLoader uses batch size 1.
                feat_rows[task][d].append(all_feats[task_idx][pos].cpu().numpy().reshape(1, -1))

# D_feats[task][d]: array with one row per processed batch, or None if none were processed.
D_feats = {task: [np.vstack(rows) if rows else None for rows in feat_rows[task]] for task in tasks}