In [1]:
import numpy as np
import os
import itertools
import sys
sys.path.append('/home/lijunzhang/multibranch/')
from pathlib import Path
from scipy import stats
from scipy.optimize import linear_sum_assignment
from collections import OrderedDict
from ptflops import get_model_complexity_info
import matplotlib.pyplot as plt
import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from torch.utils.data import DataLoader

from framework.layer_node import Conv2dNode, InputNode
from main.layout import Layout
from main.algorithms import enum_layout_wo_rdt, init_S, coarse_to_fined, prob_inference
from main.auto_models import MTSeqBackbone, MTSeqModel, ComputeBlock
from main.head import ASPPHeadNode
from main.trainer import Trainer

from data.nyuv2_dataloader_adashare import NYU_v2
from data.pixel2pixel_loss import NYUCriterions
from data.pixel2pixel_metrics import NYUMetrics

In [2]:
# Fail fast when no GPU is visible: later cells move the model and the input
# batches to the GPU with .cuda(), which would otherwise fail much later.
assert torch.cuda.is_available()

# backbone and data

In [16]:
# Backbone configuration (MobileNetV2).
backbone = 'mobilenet'
prototxt = '../models/mobilenetv2.prototxt'

# Number of coarse branching points; D and coarse_B are two names for the same value.
coarse_B = 5
D = coarse_B

# mapping[g] is the list of consecutive indices that make up group g.
# Presumably these are fine-grained backbone block indices grouped per coarse
# block -- TODO confirm against the prototxt.
_group_starts = [0, 7, 18, 23, 31, 32, 33]
mapping = {
    group: list(range(first, stop))
    for group, (first, stop) in enumerate(zip(_group_starts, _group_starts[1:]))
}

In [4]:
# data
# NYUv2
# Dataset name, location, task list and the per-task `cls_num` dictionary that is
# later handed to MTSeqModel (presumably the number of output channels per task
# -- TODO confirm).
# NOTE(review): `dataroot` is an absolute cluster path and is not portable.
data = 'NYUv2'
dataroot = '/mnt/nfs/work1/huiguan/lijunzhang/policymtl/data/NYUv2/'
tasks = ['segment_semantic', 'normal', 'depth_zbuffer']
cls_num = {'segment_semantic': 40, 'normal':3, 'depth_zbuffer': 1}

# `dataset` is bound to the train split first and then rebound to the test split,
# so every later use of `dataset` (including sampleDataloader) reads the test split.
dataset = NYU_v2(dataroot, 'train', crop_h=321, crop_w=321)
trainDataloader = DataLoader(dataset, 32, shuffle=True)
dataset = NYU_v2(dataroot, 'test', crop_h=321, crop_w=321)
valDataloader = DataLoader(dataset, 32, shuffle=True)

# Single-sample, unshuffled loader; a later cell rebinds sampleDataloader with a
# seeded, shuffled loader of batch size `sample_size`.
sampleDataloader = DataLoader(dataset, 1, shuffle=False)

# One loss object and one metric object per task, keyed by task name.
criterionDict = {}
metricDict = {}
for task in tasks:
    print(task, flush=True)
    criterionDict[task] = NYUCriterions(task)
    metricDict[task] = NYUMetrics(task)

# (channels, height, width) matching the 321x321 crops requested above.
input_dim = (3,321,321)
# Number of tasks.
T = len(tasks)

segment_semantic
normal
depth_zbuffer


In [5]:
# Seed the DataLoader so that the same samples are drawn on every run
def seed_worker(worker_id):
    """Seed NumPy's and Python's global RNGs from torch's initial seed.

    Written to be used as a DataLoader ``worker_init_fn``. ``worker_id`` is
    accepted because that callback receives it, but it is not used.
    """
    # Reduce to 32 bits before handing the value to the seeding functions.
    derived_seed = torch.initial_seed() % (1 << 32)
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)

# Dedicated generator with a fixed seed so the shuffled sample order is repeatable.
g = torch.Generator()
g.manual_seed(0)
# Batch size of the sampling loader; the feature-extraction cell also uses it
# to report the number of samples.
sample_size = 32
# Rebinds sampleDataloader. `dataset` is whatever the data cell bound last (the 'test' split).
# NOTE(review): worker_init_fn is only invoked in worker processes and num_workers
# is not set here -- confirm whether seed_worker has any effect as configured.
sampleDataloader = DataLoader(dataset, sample_size, shuffle=True, worker_init_fn=seed_worker,generator=g)

In [6]:
# ind. weights
# Location of the checkpoint whose 'state_dict' entry is loaded into the
# independent-layout MTSeqModel in a later cell.
# NOTE(review): absolute cluster path; not portable.
ckpt_PATH = '/mnt/nfs/work1/huiguan/lijunzhang/multibranch/checkpoint/'
weight_PATH = ckpt_PATH + 'NYUv2/ind/mobilenet/segment_semantic_normal_depth_zbuffer.model' # NYUv2 + MobileNetV2, from the same init
# Alternative checkpoint, kept for reference:
# weight_PATH = ckpt_PATH + 'NYUv2/baseline/WPreMobile/2/segment_semantic_normal_depth_zbuffer.model' # NYUv2 + MobileNetV2, from the same init

# load independent model weights

In [12]:
# Build the bare backbone once, only to read off its number of basic blocks and
# the channel width of its output.
# The network is bound to its own name so that the `backbone` string set in the
# configuration cell (used later to build a results file name) is not replaced by
# a module object when the notebook is run top to bottom.
with torch.no_grad():
    backbone_net = MTSeqBackbone(prototxt)
    fined_B = len(backbone_net.basic_blocks)
    # Channel dimension (axis 1) of the output for a dummy 224x224 RGB input.
    feature_dim = backbone_net(torch.rand(1,3,224,224)).shape[1]

In [13]:
# Independent layout: at each of the fined_B blocks every task sits alone in
# its own singleton set, i.e. nothing is shared between tasks.
S = [[{task_id} for task_id in range(T)] for _ in range(fined_B)]
layout = Layout(T, fined_B, S)
print('Ind. Layout:', flush=True)
print(layout, flush=True)

# Build the multi-task model for that layout, move it to the GPU and restore
# the trained weights from the checkpoint's 'state_dict' entry.
with torch.no_grad():
    model = MTSeqModel(prototxt, layout=layout, feature_dim=feature_dim, cls_num=cls_num).cuda()
    model.load_state_dict(torch.load(weight_PATH)['state_dict'])

Ind. Layout:
[[{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}]]
Construct MTSeqModel from Layout:
[[{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}],

In [10]:
# Sanity-check the restored weights: validate() reports the per-task validation
# loss and metrics (see the cell output).
trainer = Trainer(model, tasks, trainDataloader, valDataloader, criterionDict, metricDict)
trainer.validate()

[Iter 1 Task segm] Val Loss: 1.5592
{'mIoU': 0.1994, 'Pixel Acc': 0.4897}
[Iter 1 Task norm] Val Loss: 0.0736
{'Angle Mean': 20.1623, 'Angle Median': 18.2471, 'Angle 11.25': 11.3339, 'Angle 22.5': 73.9606, 'Angle 30': 86.6333}
[Iter 1 Task dept] Val Loss: 0.9887
{'abs_err': 0.9773, 'rel_err': 0.397, 'sigma_1.25': 38.0331, 'sigma_1.25^2': 67.1293, 'sigma_1.25^3': 84.166}


# feature correlation

In [7]:
import pickle
def save_obj(obj, name):
    """Pickle ``obj`` to ``./exp/<name>.pkl`` using the highest pickle protocol.

    The target directory is created when it does not exist yet, so the first
    save in a fresh working directory does not fail with FileNotFoundError.

    Args:
        obj: any picklable object.
        name: file stem relative to ``./exp/``; ``'.pkl'`` is always appended,
            so a name that already ends in ``.pkl`` yields ``<name>.pkl.pkl``.
    """
    path = './exp/' + name + '.pkl'
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)

def load_obj(name):
    """Return the object unpickled from ``./exp/<name>.pkl``.

    Args:
        name: file stem relative to ``./exp/``; ``'.pkl'`` is appended here.

    Returns:
        Whatever object was stored in the file.
    """
    pkl_path = './exp/' + name + '.pkl'
    # NOTE(review): unpickling can execute arbitrary code; only load files
    # produced by this project.
    with open(pkl_path, 'rb') as handle:
        loaded = pickle.load(handle)
    return loaded

In [None]:
# extract features
# Run K batches through the model and collect, for every task and every coarse
# branching point, the features at the last block of that coarse group.
K = 2
print('The number of samples: {}'.format(K*sample_size))
# Index of the last fine-grained block of each coarse group.
D_pos = [mapping[d][-1] for d in range(D)]
# Per (task, branching point) list of per-batch arrays; concatenated once below
# instead of re-stacking the growing array on every batch.
feat_chunks = {task: [[] for _ in range(D)] for task in tasks}

for i, batch in enumerate(sampleDataloader):
    if i >= K:
        break

    with torch.no_grad():
        x = batch['input'].cuda()
        all_feats = model.extract_features(x)
    # all_feats is indexed by task position first, then by block index.
    for idx, t in enumerate(tasks):
        for d in range(D):
            temp = all_feats[idx][D_pos[d]]
            # Always store NumPy arrays: stacking torch tensors with NumPy
            # fails for CUDA tensors and would leave mixed types in D_feats.
            if torch.is_tensor(temp):
                temp = temp.detach().cpu().numpy()
            # Use the actual batch size so a short final batch is handled
            # correctly instead of being reshaped with `sample_size`.
            n_samples, channel = temp.shape[0], temp.shape[1]
            feat_chunks[t][d].append(temp.reshape(n_samples, channel, -1))

# D_feats[task][block] = array(batch, channel, vec_feat); None if no batch was seen.
D_feats = {
    task: [np.concatenate(chunks, axis=0) if chunks else None
           for chunks in feat_chunks[task]]
    for task in tasks
}

In [14]:
def channel_corr(feat1, feat2):
    """Pairwise Pearson correlation between the channels of two feature maps.

    All channel pairs are correlated in a single vectorized ``np.corrcoef``
    call rather than one call per (i, j) pair.

    Args:
        feat1: array-like of shape (C1, N), one row per channel.
        feat2: array-like of shape (C2, N) with the same number of columns.

    Returns:
        Array of shape (C1, C2). Entry (i, j) is the correlation between
        ``feat1[i]`` and ``feat2[j]`` mapped from [-1, 1] to [0, 1] via
        ``(corr + 1) / 2`` so it can be used as a probability. Pairs whose
        correlation is undefined (a channel with zero variance gives NaN) are
        treated as -1, i.e. they map to 0.
    """
    feat1 = np.asarray(feat1, dtype=np.float64)
    feat2 = np.asarray(feat2, dtype=np.float64)
    n_rows = feat1.shape[0]

    # Constant channels divide by a zero standard deviation inside corrcoef;
    # the resulting NaNs are handled explicitly below, so silence the warning.
    with np.errstate(divide='ignore', invalid='ignore'):
        stacked_corr = np.corrcoef(feat1, feat2)

    # corrcoef stacks feat1 on top of feat2; the top-right block holds the
    # feat1-vs-feat2 correlations.
    corr = stacked_corr[:n_rows, n_rows:]
    corr = np.where(np.isnan(corr), -1.0, corr)
    return (corr + 1) / 2

In [15]:
# For every unordered pair of tasks and every branching point: correlate the two
# tasks' channels, then pick the one-to-one channel matching with the largest
# total correlation and record its mean correlation and column assignment.
feat_corr = {}  # feat_corr[(task_a, task_b)][d] -> mean correlation after alignment
ch_align = {}   # ch_align[(task_a, task_b)][d]  -> matched column index per row
for two_task in itertools.combinations(tasks, 2):
    first_task, second_task = two_task
    pair_corr, pair_align = [], []
    feat_corr[two_task] = pair_corr
    ch_align[two_task] = pair_align

    for d in range(D):
        stacked_a = D_feats[first_task][d]
        stacked_b = D_feats[second_task][d]
        channel = stacked_a.shape[1]
        # (batch, C, vec_feat) -> (C, batch*vec_feat)
        feat1 = np.swapaxes(stacked_a, 1, 2).reshape(-1, channel).T
        feat2 = np.swapaxes(stacked_b, 1, 2).reshape(-1, channel).T

        corr_mat = channel_corr(feat1, feat2)
        row_idx, col_idx = linear_sum_assignment(corr_mat, maximize=True)
        # Mean correlation of the identity matching; kept for inspection only.
        pre_cost = np.diag(corr_mat).mean()
        post_corr = corr_mat[row_idx, col_idx].mean()

        pair_corr.append(post_corr)
        pair_align.append(col_idx)

  c /= stddev[:, None]
  c /= stddev[None, :]


In [23]:
# save_obj appends '.pkl' itself, so this writes './exp/feat_corr.pkl.pkl'.
# The matching load_obj call passes the same name, so the pair is consistent.
save_obj(feat_corr, 'feat_corr.pkl')

In [8]:
# Reload the cached correlations so feature extraction can be skipped.
# load_obj appends '.pkl' itself, so this reads './exp/feat_corr.pkl.pkl'.
feat_corr = load_obj('feat_corr.pkl')

# enumerate layouts and compute joint probability

In [9]:
# enum layout
# Start from the initial state/layout; enum_layout_wo_rdt is called with the
# list and, judging by the later cells, appends further layouts to it in place.
# NOTE(review): "wo_rdt" presumably means "without redundancy" -- confirm in main.algorithms.
layout_list = [] 
S0 = init_S(T, coarse_B) # initial state
L = Layout(T, coarse_B, S0) # initial layout
layout_list.append(L)
enum_layout_wo_rdt(L, layout_list)

In [10]:
# Re-key the pairwise correlations by task index instead of task name and
# prepend 1 for branching point 0 (the probability of branching point 0 is 1).
two_task_prob = {
    (tasks.index(first_task), tasks.index(second_task)): [1] + feat_corr[(first_task, second_task)]
    for first_task, second_task in itertools.combinations(tasks, 2)
}
print(two_task_prob)

{(0, 1): [1, 0.6362004371892473, 0.638888580009888, 0.5920661989524972, 0.2965104816678368, 0.19595387053885377], (0, 2): [1, 0.6989018835315278, 0.5262198025912297, 0.22032401748156186, 0.016121479555399526, 0.021005046841675767], (1, 2): [1, 0.6315847199298901, 0.5240788359859585, 0.21673846313033832, 0.015832776606861592, 0.0206323016000405]}


In [11]:
# Run for all L
# prob_inference is called for its side effect: the layout's `prob` attribute is
# read immediately afterwards and printed as a percentage.
for L in layout_list:
    print(L)
    prob_inference(L, two_task_prob)
    print(L.prob*100)
    print('=' * 100)

[[{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}]]
subtree: [[0, 1, 5], [0, 2, 5], [1, 2, 5]]
0.008492297076642277
[[{1, 2}, {0}], [{1, 2}, {0}], [{1, 2}, {0}], [{1, 2}, {0}], [{1, 2}, {0}]]
subtree: [[0, 1, 0], [0, 2, 0], [1, 2, 5]]
2.06323016000405
[[{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}]]
subtree: [[0, 1, 0], [0, 2, 0], [1, 2, 0]]
100
[[{1, 2}, {0}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}]]
subtree: [[0, 1, 0], [0, 2, 0], [1, 2, 1]]
63.15847199298901
[[{1, 2}, {0}], [{1, 2}, {0}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}]]
subtree: [[0, 1, 0], [0, 2, 0], [1, 2, 2]]
52.40788359859585
[[{1, 2}, {0}], [{1, 2}, {0}], [{1, 2}, {0}], [{0}, {2}, {1}], [{0}, {2}, {1}]]
subtree: [[0, 1, 0], [0, 2, 0], [1, 2, 3]]
21.67384631303383
[[{1, 2}, {0}], [{1, 2}, {0}], [{1, 2}, {0}], [{1, 2}, {0}], [{0}, {2}, {1}]]
subtree: [[0, 1, 0], [0, 2, 0], [1, 2, 4]]
1.5832776606861592
[[{1}, {0, 2}], [{1}, {0, 2}], [

In [12]:
# Indices into layout_list, ordered from highest to lowest estimated probability
# (ties keep their original relative order).
layout_order = list(range(len(layout_list)))
layout_order.sort(key=lambda position: layout_list[position].prob, reverse=True)

In [13]:
# Walk the ranked layouts (every `step`-th one) and print index, layout and
# probability in percent, so a subset can be picked for verification.
step = 1
for ranked_index in layout_order[::step]:
    L = layout_list[ranked_index]
    print(ranked_index)
    print(L)
    print(L.prob*100)
    print('=' * 100)

2
[[{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}]]
100
8
[[{1}, {0, 2}], [{1}, {2}, {0}], [{1}, {2}, {0}], [{1}, {2}, {0}], [{1}, {2}, {0}]]
69.89018835315278
14
[[{2}, {0, 1}], [{2}, {0, 1}], [{2}, {1}, {0}], [{2}, {1}, {0}], [{2}, {1}, {0}]]
63.888858000988805
13
[[{2}, {0, 1}], [{2}, {1}, {0}], [{2}, {1}, {0}], [{2}, {1}, {0}], [{2}, {1}, {0}]]
63.620043718924734
3
[[{1, 2}, {0}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}]]
63.15847199298901
15
[[{2}, {0, 1}], [{2}, {0, 1}], [{2}, {0, 1}], [{2}, {1}, {0}], [{2}, {1}, {0}]]
59.20661989524972
9
[[{1}, {0, 2}], [{1}, {0, 2}], [{1}, {2}, {0}], [{1}, {2}, {0}], [{1}, {2}, {0}]]
52.62198025912297
4
[[{1, 2}, {0}], [{1, 2}, {0}], [{0}, {2}, {1}], [{0}, {2}, {1}], [{0}, {2}, {1}]]
52.40788359859585
16
[[{2}, {0, 1}], [{2}, {0, 1}], [{2}, {0, 1}], [{2}, {0, 1}], [{2}, {1}, {0}]]
29.651048166783678
27
[[{0, 1, 2}], [{2}, {0, 1}], [{2}, {1}, {0}], [{2}, {1}, {0}], [{2}, {1}, {0}]]
28

# SROCC

In [18]:
# Load previously saved reference results. The file name is built by string
# concatenation from the dataset and backbone names, so `backbone` must still be
# the string from the configuration cell here, not a network object.
real_early_iter_rel_pref = load_obj('real_early_iter_rel_pref_'+data+'_'+backbone)
# NOTE(review): the variable says "pref" while the key says "perf" -- presumably
# relative performance; confirm.
real_rel_pref = real_early_iter_rel_pref['final_rel_perf']

In [20]:
# Look up the estimated probability of each layout index stored under 'layout'.
# est_rank holds probabilities (not ranks) and is assumed to be aligned
# element-wise with real_rel_pref -- TODO confirm both lists share that order.
idx = real_early_iter_rel_pref['layout']
est_rank = [layout_list[i].prob for i in idx]

In [21]:
# Spearman rank correlation between the reference values and the estimated
# layout probabilities; the result is displayed as the cell output.
stats.spearmanr(real_rel_pref,est_rank)

SpearmanrResult(correlation=-0.5819548872180451, pvalue=0.00710443754510044)