In [1]:
import numpy as np
import random
import os
import argparse
from pathlib import Path
import scipy.stats
import itertools
from statistics import mean
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from main.layout import Layout
from main.algorithms import enum_layout_wo_rdt, init_S, coarse_to_fined
from main.auto_models import MTSeqBackbone, MTSeqModel
from data.nyuv2_dataloader_adashare import NYU_v2
from data.taskonomy_dataloader_adashare import Taskonomy

# prepare params

In [2]:
# backbone

# resnet34
# backbone_type = 'resnet'
# prototxt = 'models/deeplab_resnet34_adashare.prototxt'
# D = coarse_B = 5
# mapping = {0:[0], 1:[1,2,3], 2:[4,5,6,7], 3:[8,9,10,11,12,13], 4:[14,15,16], 5:[17]}

# mobilenet
backbone_type = 'mobilenet'
prototxt = 'models/mobilenetv2.prototxt'
D = coarse_B = 5
mapping = {0:[0,1,2,3,4,5,6], 1:[7,8,9,10,11,12,13,14,15,16,17], 2:[18,19,20,21,22], 
           3:[23,24,25,26,27,28,29,30], 4:[31], 5:[32]}

In [3]:
# data
# NYUv2
# data = 'NYUv2'
# dataset = NYU_v2('/mnt/nfs/work1/huiguan/lijunzhang/policymtl/data/NYUv2/', 'test', crop_h=321, crop_w=321)
# dataloader = DataLoader(dataset, 1, shuffle=False)
# tasks = ['segment_semantic', 'normal', 'depth_zbuffer']
# cls_num = {'segment_semantic': 40, 'normal':3, 'depth_zbuffer': 1}
# T = len(tasks)
# layout_idx = 2 # ind. models

# Taskonomy
data = 'Taskonomy'
dataset = Taskonomy('/mnt/nfs/work1/huiguan/lijunzhang/policymtl/data/Taskonomy', 'test_small', crop_h=224, crop_w=224)
dataloader = DataLoader(dataset, 1, shuffle=False)
tasks = ['segment_semantic','normal','depth_zbuffer','keypoints2d','edge_texture']
cls_num = {'segment_semantic': 17, 'normal': 3, 'depth_zbuffer': 1, 'keypoints2d': 1, 'edge_texture': 1}
T = len(tasks)
layout_idx = 4 # ind. models

In [4]:
# ind. weights
ckpt_PATH = '/mnt/nfs/work1/huiguan/lijunzhang/multibranch/checkpoint/'
# weight_PATH = ckpt_PATH + 'NYUv2/baseline/NR/2/segment_semantic_normal_depth_zbuffer.model' # NYUv2 + Resnet34
# weight_PATH = ckpt_PATH + 'NYUv2/baseline/NM/2/segment_semantic_normal_depth_zbuffer.model' # NYUv2 + MobileNetV2
# weight_PATH = ckpt_PATH + 'Taskonomy/baseline/TR/4/segment_semantic_normal_depth_zbuffer_keypoints2d_edge_texture.model' # Taskonomy + Resnet34
weight_PATH = ckpt_PATH + 'Taskonomy/baseline/TM/4/segment_semantic_normal_depth_zbuffer_keypoints2d_edge_texture.model' # Taskonomy + MobileNetV2

In [8]:
# real results
real_date = '0216'

# procedure

In [9]:
with torch.no_grad():
    backbone = MTSeqBackbone(prototxt)
    fined_B = len(backbone.basic_blocks)
    feature_dim = backbone(torch.rand(1,3,224,224)).shape[1]

In [10]:
# layout
layout_list = [] 
S0 = init_S(T, coarse_B) # initial state
L = Layout(T, coarse_B, S0) # initial layout
layout_list.append(L)
enum_layout_wo_rdt(L, layout_list)

layout = layout_list[layout_idx]
print('Coarse Layout:', flush=True)
print(layout, flush=True)

layout = coarse_to_fined(layout, fined_B, mapping)
print('Fined Layout:', flush=True)
print(layout, flush=True)

# model
with torch.no_grad():
    model = MTSeqModel(prototxt, layout=layout, feature_dim=feature_dim, cls_num=cls_num)
    model = model.cuda()
    
# load ind. model weights
model.load_state_dict(torch.load(weight_PATH)['state_dict'])

Coarse Layout:
[[{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}]]
Fined Layout:
[[{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}]]
Construct MTSeqModel from Layout:
[[{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, 

FileNotFoundError: [Errno 2] No such file or directory: '/mnt/nfs/work1/huiguan/lijunzhang/multibranch/checkpoint/Taskonomy/baseline/NR/4/segment_semantic_normal_depth_zbuffer_keypoints2d_edge_texture.model'

In [13]:
# extract features
K = 500
D_pos = [mapping[d][-1] for d in range(D)]
D_feats = {task: [None]*D for task in tasks}

for i, batch in enumerate(dataloader):
    if i >= K:
        break
    
    with torch.no_grad():
        x = batch['input'].cuda()
        all_feats = model.extract_features(x)

    idx = 0
    for t in tasks:
        for d in range(D):
            vec_feat = all_feats[idx][D_pos[d]].cpu().numpy().reshape(1,-1)
            if D_feats[t][d] is None:
                D_feats[t][d] = vec_feat
            else:
                temp = D_feats[t][d]
                D_feats[t][d] = np.vstack((temp,vec_feat))
        idx += 1

In [14]:
# compute RDM
RDM = {task: np.zeros((D,K,K)) for task in tasks}
for t in tasks:
    for d in range(D):
        RDM[t][d,:,:] = 1 - np.corrcoef(D_feats[t][d])

In [15]:
# compute RSA
RSA = np.ones((D,T,T))
for d in range(D):
    for two_task in itertools.combinations(tasks, 2):
        rho = scipy.stats.spearmanr(np.tril(RDM[two_task[0]][d,:,:]).reshape(-1), np.tril(RDM[two_task[1]][d,:,:]).reshape(-1))[0]
        RSA[d,tasks.index(two_task[0]),tasks.index(two_task[1])] = RSA[d,tasks.index(two_task[1]),tasks.index(two_task[0])] = rho
print(RSA)

[[[1.         0.99654141 0.99381004 0.99014505 0.99293988]
  [0.99654141 1.         0.99715696 0.99130068 0.99499507]
  [0.99381004 0.99715696 1.         0.99092965 0.99425948]
  [0.99014505 0.99130068 0.99092965 1.         0.99795831]
  [0.99293988 0.99499507 0.99425948 0.99795831 1.        ]]

 [[1.         0.96637632 0.96896415 0.94644115 0.96457368]
  [0.96637632 1.         0.9785001  0.95675202 0.95459148]
  [0.96896415 0.9785001  1.         0.97153634 0.96678255]
  [0.94644115 0.95675202 0.97153634 1.         0.95624757]
  [0.96457368 0.95459148 0.96678255 0.95624757 1.        ]]

 [[1.         0.96103835 0.96726784 0.92090276 0.93047752]
  [0.96103835 1.         0.97475784 0.90728607 0.92582239]
  [0.96726784 0.97475784 1.         0.9148951  0.92784373]
  [0.92090276 0.90728607 0.9148951  1.         0.91753083]
  [0.93047752 0.92582239 0.92784373 0.91753083 1.        ]]

 [[1.         0.95778592 0.96561384 0.88098284 0.88301105]
  [0.95778592 1.         0.98047995 0.87691217 0.8

In [16]:
# set score for each layout
for layout in layout_list:
    score = 0
    for d in range(D):
        C_cluster = []
        for task_set in layout[d]:
            max_dis = 0
            for two_task in itertools.combinations(task_set, 2):
                max_dis = max(max_dis, 1-RSA[d,two_task[0],two_task[1]])
            C_cluster.append(max_dis)
        score += mean(C_cluster)
    layout.score = score

In [17]:
# sort layout
layout_order = sorted(range(len(layout_list)), key=lambda k: layout_list[k],reverse=False)
step = 1
for i in range(0,len(layout_order),step):
    print(layout_order[i])
    L = layout_list[layout_order[i]]
    print(L)
    print(L.score)
    print('=' * 100)

4
[[{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}]]
0
5
[[{0}, {1}, {3, 4}, {2}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}]]
0.0005104213555107706
150
[[{0}, {1, 2}, {4}, {3}], [{0}, {4}, {3}, {2}, {1}], [{0}, {4}, {3}, {2}, {1}], [{0}, {4}, {3}, {2}, {1}], [{0}, {4}, {3}, {2}, {1}]]
0.0007107602188133888
2275
[[{0, 1}, {2}, {4}, {3}], [{2}, {4}, {3}, {1}, {0}], [{2}, {4}, {3}, {1}, {0}], [{2}, {4}, {3}, {1}, {0}], [{2}, {4}, {3}, {1}, {0}]]
0.0008646463327494358
55
[[{0}, {1, 4}, {3}, {2}], [{0}, {3}, {2}, {4}, {1}], [{0}, {3}, {2}, {4}, {1}], [{0}, {3}, {2}, {4}, {1}], [{0}, {3}, {2}, {4}, {1}]]
0.001251233382417144
10
[[{0}, {1}, {3}, {2, 4}], [{0}, {1}, {3}, {4}, {2}], [{0}, {1}, {3}, {4}, {2}], [{0}, {1}, {3}, {4}, {2}], [{0}, {1}, {3}, {4}, {2}]]
0.0014351292477611244
1285
[[{0, 2}, {1}, {4}, {3}], [{1}, {4}, {3}, {2}, {0}], [{1}, 

0.02447836420562815
692
[[{1, 2, 3}, {0, 4}], [{3}, {1, 2}, {4}, {0}], [{3}, {1, 2}, {4}, {0}], [{3}, {1, 2}, {4}, {0}], [{3}, {4}, {0}, {2}, {1}]]
0.024630761859865208
2339
[[{0, 1}, {4}, {2, 3}], [{0, 1}, {4}, {2, 3}], [{4}, {1}, {0}, {3}, {2}], [{4}, {1}, {0}, {3}, {2}], [{4}, {1}, {0}, {3}, {2}]]
0.02487209218106042
2084
[[{1}, {0, 2, 3, 4}], [{1}, {3}, {0, 2, 4}], [{1}, {3}, {4}, {0, 2}], [{1}, {3}, {4}, {2}, {0}], [{1}, {3}, {4}, {2}, {0}]]
0.02491928487795564
223
[[{0}, {4}, {1, 2, 3}], [{0}, {4}, {3}, {1, 2}], [{0}, {4}, {3}, {1, 2}], [{0}, {4}, {3}, {1, 2}], [{0}, {4}, {3}, {1, 2}]]
0.024963081835942508
3495
[[{3}, {0, 1, 2, 4}], [{3}, {1, 2, 4}, {0}], [{3}, {0}, {4}, {1, 2}], [{3}, {0}, {4}, {2}, {1}], [{3}, {0}, {4}, {2}, {1}]]
0.02497677040682863
4180
[[{0, 1, 2, 3, 4}], [{0}, {2}, {1, 3, 4}], [{0}, {2}, {1}, {4}, {3}], [{0}, {2}, {1}, {4}, {3}], [{0}, {2}, {1}, {4}, {3}]]
0.024991117257733474
4222
[[{0, 1, 2, 3, 4}], [{0}, {3}, {1, 2, 4}], [{0}, {3}, {1}, {4}, {2}], [{0}, 

[[{0, 1, 2, 3, 4}], [{0, 3}, {4}, {1, 2}], [{4}, {1, 2}, {3}, {0}], [{4}, {3}, {0}, {2}, {1}], [{4}, {3}, {0}, {2}, {1}]]
0.04118506708172613
135
[[{0}, {2}, {1, 3, 4}], [{0}, {2}, {1, 3, 4}], [{0}, {2}, {4}, {1, 3}], [{0}, {2}, {4}, {3}, {1}], [{0}, {2}, {4}, {3}, {1}]]
0.04121442769705196
3290
[[{0, 1, 2}, {4}, {3}], [{0, 1, 2}, {4}, {3}], [{0, 1, 2}, {4}, {3}], [{4}, {3}, {1}, {0, 2}], [{4}, {3}, {1}, {0, 2}]]
0.04130897958212716
5846
[[{0, 1, 2, 3, 4}], [{0, 1, 2}, {4}, {3}], [{4}, {3}, {2}, {0, 1}], [{4}, {3}, {2}, {0, 1}], [{4}, {3}, {2}, {1}, {0}]]
0.041356770848587236
4107
[[{0, 1, 2, 3, 4}], [{0}, {1}, {3, 4}, {2}], [{0}, {1}, {3, 4}, {2}], [{0}, {1}, {2}, {4}, {3}], [{0}, {1}, {2}, {4}, {3}]]
0.04141034553410708
2761
[[{0, 1, 3}, {4}, {2}], [{0, 1, 3}, {4}, {2}], [{4}, {2}, {3}, {0, 1}], [{4}, {2}, {3}, {0, 1}], [{4}, {2}, {3}, {1}, {0}]]
0.041431863735572685
1411
[[{1, 3, 4}, {0, 2}], [{4}, {1, 3}, {2}, {0}], [{4}, {1, 3}, {2}, {0}], [{4}, {2}, {0}, {3}, {1}], [{4}, {2}, {0}

[[{0, 1, 2, 3, 4}], [{0, 2}, {3}, {1, 4}], [{3}, {1, 4}, {2}, {0}], [{3}, {2}, {0}, {4}, {1}], [{3}, {2}, {0}, {4}, {1}]]
0.05388080344869002
2342
[[{0, 1}, {4}, {2, 3}], [{0, 1}, {4}, {2, 3}], [{0, 1}, {4}, {3}, {2}], [{0, 1}, {4}, {3}, {2}], [{0, 1}, {4}, {3}, {2}]]
0.05388799438192039
838
[[{0, 3}, {2, 4}, {1}], [{0, 3}, {2, 4}, {1}], [{0, 3}, {1}, {4}, {2}], [{1}, {4}, {2}, {3}, {0}], [{1}, {4}, {2}, {3}, {0}]]
0.053898229467622724
3607
[[{3}, {0, 1, 2, 4}], [{3}, {2}, {0, 1, 4}], [{3}, {2}, {0, 1, 4}], [{3}, {2}, {4}, {0, 1}], [{3}, {2}, {4}, {1}, {0}]]
0.05394562193247886
2608
[[{2, 3}, {0, 1, 4}], [{2, 3}, {1, 4}, {0}], [{2, 3}, {0}, {4}, {1}], [{0}, {4}, {1}, {3}, {2}], [{0}, {4}, {1}, {3}, {2}]]
0.05396551857391839
3389
[[{3, 4}, {0, 1, 2}], [{3, 4}, {0, 1, 2}], [{4}, {3}, {1, 2}, {0}], [{4}, {3}, {1, 2}, {0}], [{4}, {3}, {0}, {2}, {1}]]
0.05399442722776018
5470
[[{0, 1, 2, 3, 4}], [{0, 1, 4}, {3}, {2}], [{3}, {2}, {4}, {0, 1}], [{3}, {2}, {4}, {0, 1}], [{3}, {2}, {4}, {0, 1}]

[[{0, 1, 2, 3, 4}], [{1, 3}, {0, 2, 4}], [{0, 2, 4}, {3}, {1}], [{3}, {1}, {0}, {4}, {2}], [{3}, {1}, {0}, {4}, {2}]]
0.07324418442749359
3542
[[{3}, {0, 1, 2, 4}], [{3}, {1, 4}, {0, 2}], [{3}, {1, 4}, {0, 2}], [{3}, {0, 2}, {4}, {1}], [{3}, {4}, {1}, {2}, {0}]]
0.0732446455967213
4224
[[{0, 1, 2, 3, 4}], [{0}, {3}, {1, 2, 4}], [{0}, {3}, {2, 4}, {1}], [{0}, {3}, {2, 4}, {1}], [{0}, {3}, {1}, {4}, {2}]]
0.07327720947667082
3808
[[{4}, {0, 1, 2, 3}], [{4}, {1, 2, 3}, {0}], [{4}, {0}, {2}, {1, 3}], [{4}, {0}, {2}, {1, 3}], [{4}, {0}, {2}, {3}, {1}]]
0.07329390713613218
584
[[{0, 4}, {2, 3}, {1}], [{0, 4}, {2, 3}, {1}], [{0, 4}, {1}, {3}, {2}], [{0, 4}, {1}, {3}, {2}], [{1}, {3}, {2}, {4}, {0}]]
0.07330134031520205
2482
[[{2, 3, 4}, {0, 1}], [{2, 3, 4}, {0, 1}], [{2, 3, 4}, {1}, {0}], [{1}, {0}, {2}, {4}, {3}], [{1}, {0}, {2}, {4}, {3}]]
0.07332081993814432
4328
[[{0, 1, 2, 3, 4}], [{1, 2, 3, 4}, {0}], [{0}, {3, 4}, {1, 2}], [{0}, {1, 2}, {4}, {3}], [{0}, {4}, {3}, {2}, {1}]]
0.0733429929

4193
[[{0, 1, 2, 3, 4}], [{0}, {2}, {1, 3, 4}], [{0}, {2}, {1, 3, 4}], [{0}, {2}, {3}, {1, 4}], [{0}, {2}, {3}, {4}, {1}]]
0.086390235235344
2852
[[{2, 4}, {0, 1, 3}], [{2, 4}, {3}, {0, 1}], [{2, 4}, {3}, {0, 1}], [{3}, {0, 1}, {4}, {2}], [{3}, {0, 1}, {4}, {2}]]
0.08639290436921518
5071
[[{0, 1, 2, 3, 4}], [{0, 2, 3}, {4}, {1}], [{0, 2, 3}, {4}, {1}], [{4}, {1}, {2, 3}, {0}], [{4}, {1}, {0}, {3}, {2}]]
0.0864279355395502
4869
[[{0, 1, 2, 3, 4}], [{1, 3, 4}, {0, 2}], [{0, 2}, {3, 4}, {1}], [{1}, {2}, {0}, {4}, {3}], [{1}, {2}, {0}, {4}, {3}]]
0.08647757224706656
4301
[[{0, 1, 2, 3, 4}], [{1, 2, 3, 4}, {0}], [{0}, {1, 3}, {4}, {2}], [{0}, {1, 3}, {4}, {2}], [{0}, {4}, {2}, {3}, {1}]]
0.08650964432854055
6021
[[{0, 1, 2, 3, 4}], [{3}, {0, 1, 2, 4}], [{3}, {2}, {0, 1, 4}], [{3}, {2}, {1}, {0, 4}], [{3}, {2}, {1}, {4}, {0}]]
0.08653231330292845
3855
[[{4}, {0, 1, 2, 3}], [{4}, {1, 3}, {0, 2}], [{4}, {1, 3}, {0, 2}], [{4}, {0, 2}, {3}, {1}], [{4}, {0, 2}, {3}, {1}]]
0.08655466293033126
254


2744
[[{0, 1, 3}, {4}, {2}], [{4}, {2}, {1}, {0, 3}], [{4}, {2}, {1}, {0, 3}], [{4}, {2}, {1}, {0, 3}], [{4}, {2}, {1}, {0, 3}]]
0.09449018804537694
2423
[[{2, 3, 4}, {0, 1}], [{2, 3, 4}, {1}, {0}], [{1}, {0}, {3}, {2, 4}], [{1}, {0}, {3}, {2, 4}], [{1}, {0}, {3}, {2, 4}]]
0.09459912157496853
2581
[[{2, 3}, {0, 1, 4}], [{0, 1, 4}, {3}, {2}], [{3}, {2}, {1}, {0, 4}], [{3}, {2}, {1}, {0, 4}], [{3}, {2}, {1}, {0, 4}]]
0.09465558189370038
361
[[{1, 2, 3, 4}, {0}], [{0}, {3, 4}, {1, 2}], [{0}, {3, 4}, {1, 2}], [{0}, {3, 4}, {2}, {1}], [{0}, {2}, {1}, {4}, {3}]]
0.09465717410217024
5884
[[{0, 1, 2, 3, 4}], [{3, 4}, {0, 1, 2}], [{3, 4}, {1, 2}, {0}], [{1, 2}, {0}, {4}, {3}], [{1, 2}, {0}, {4}, {3}]]
0.09470089102120685
5611
[[{0, 1, 2, 3, 4}], [{2, 4}, {0, 1, 3}], [{0, 1, 3}, {4}, {2}], [{4}, {2}, {3}, {0, 1}], [{4}, {2}, {3}, {1}, {0}]]
0.09470125767266813
1469
[[{1, 3, 4}, {0, 2}], [{1, 3, 4}, {0, 2}], [{3}, {1, 4}, {2}, {0}], [{3}, {1, 4}, {2}, {0}], [{3}, {2}, {0}, {4}, {1}]]
0.0947056977

1232
[[{1, 2}, {0, 3, 4}], [{1, 2}, {0, 3, 4}], [{1, 2}, {3}, {0, 4}], [{3}, {0, 4}, {2}, {1}], [{3}, {2}, {1}, {4}, {0}]]
0.10471381664325566
640
[[{0, 4}, {3}, {1, 2}], [{0, 4}, {3}, {1, 2}], [{0, 4}, {3}, {1, 2}], [{0, 4}, {3}, {1, 2}], [{3}, {1, 2}, {4}, {0}]]
0.10474177479082866
6833
[[{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{0, 1}, {4}, {2, 3}], [{4}, {1}, {0}, {3}, {2}], [{4}, {1}, {0}, {3}, {2}]]
0.10476930851625108
1422
[[{1, 3, 4}, {0, 2}], [{0, 2}, {4}, {1, 3}], [{0, 2}, {4}, {1, 3}], [{4}, {1, 3}, {2}, {0}], [{4}, {2}, {0}, {3}, {1}]]
0.1047932381299326
5397
[[{0, 1, 2, 3, 4}], [{2, 3, 4}, {0, 1}], [{0, 1}, {3}, {2, 4}], [{0, 1}, {3}, {4}, {2}], [{0, 1}, {3}, {4}, {2}]]
0.10485779505915799
6931
[[{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{0, 1, 3}, {4}, {2}], [{4}, {2}, {3}, {0, 1}], [{4}, {2}, {3}, {1}, {0}]]
0.10487195651013609
2432
[[{2, 3, 4}, {0, 1}], [{2, 3, 4}, {1}, {0}], [{2, 3, 4}, {1}, {0}], [{1}, {0}, {3}, {2, 4}], [{1}, {0}, {3}, {2, 4}]]
0.10492835216761981
4895
[[{0

0.11495624984621944
603
[[{0, 4}, {2}, {1, 3}], [{0, 4}, {2}, {1, 3}], [{2}, {1, 3}, {4}, {0}], [{2}, {1, 3}, {4}, {0}], [{2}, {1, 3}, {4}, {0}]]
0.11501569427495173
4338
[[{0, 1, 2, 3, 4}], [{1, 2, 3, 4}, {0}], [{0}, {3}, {1, 2, 4}], [{0}, {3}, {2}, {1, 4}], [{0}, {3}, {2}, {1, 4}]]
0.11502307708339023
5960
[[{0, 1, 2, 3, 4}], [{3}, {0, 1, 2, 4}], [{3}, {1, 2, 4}, {0}], [{3}, {0}, {2}, {1, 4}], [{3}, {0}, {2}, {1, 4}]]
0.11502307708339023
6017
[[{0, 1, 2, 3, 4}], [{3}, {0, 1, 2, 4}], [{3}, {2}, {0, 1, 4}], [{3}, {2}, {1, 4}, {0}], [{3}, {2}, {1, 4}, {0}]]
0.11502307708339023
5979
[[{0, 1, 2, 3, 4}], [{3}, {0, 1, 2, 4}], [{3}, {1, 2}, {0, 4}], [{3}, {1, 2}, {0, 4}], [{3}, {1, 2}, {4}, {0}]]
0.11502452038178612
3137
[[{2}, {0, 1, 3, 4}], [{2}, {0, 1, 3, 4}], [{2}, {1, 3}, {0, 4}], [{2}, {0, 4}, {3}, {1}], [{2}, {3}, {1}, {4}, {0}]]
0.1150329383857484
1885
[[{1, 4}, {0, 2, 3}], [{1, 4}, {0, 2, 3}], [{0, 2, 3}, {4}, {1}], [{4}, {1}, {2}, {0, 3}], [{4}, {1}, {2}, {3}, {0}]]
0.1150362091080

0.1241292051832652
1187
[[{1, 2}, {0, 3, 4}], [{1, 2}, {4}, {0, 3}], [{1, 2}, {4}, {0, 3}], [{4}, {0, 3}, {2}, {1}], [{4}, {0, 3}, {2}, {1}]]
0.12418955955302818
2441
[[{2, 3, 4}, {0, 1}], [{2, 3, 4}, {0, 1}], [{3, 4}, {2}, {1}, {0}], [{3, 4}, {2}, {1}, {0}], [{3, 4}, {2}, {1}, {0}]]
0.12419350226329354
642
[[{0, 4}, {3}, {1, 2}], [{0, 4}, {3}, {1, 2}], [{0, 4}, {3}, {1, 2}], [{0, 4}, {3}, {1, 2}], [{0, 4}, {3}, {2}, {1}]]
0.12419398751984959
4059
[[{4}, {0, 1, 2, 3}], [{4}, {0, 1, 2, 3}], [{4}, {0, 1, 2, 3}], [{4}, {1, 2}, {0, 3}], [{4}, {2}, {1}, {3}, {0}]]
0.12424293193802152
6434
[[{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{0, 4}, {3}, {1, 2}], [{0, 4}, {3}, {2}, {1}], [{3}, {2}, {1}, {4}, {0}]]
0.12424924250324684
6267
[[{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{0}, {2, 3, 4}, {1}], [{0}, {1}, {3, 4}, {2}], [{0}, {1}, {2}, {4}, {3}]]
0.12424953544984374
3105
[[{2}, {0, 1, 3, 4}], [{2}, {4}, {0, 1, 3}], [{2}, {4}, {0, 1, 3}], [{2}, {4}, {0, 1, 3}], [{2}, {4}, {1, 3}, {0}]]
0.1243016864031

0.13960777815981207
5889
[[{0, 1, 2, 3, 4}], [{3, 4}, {0, 1, 2}], [{3, 4}, {1, 2}, {0}], [{3, 4}, {1, 2}, {0}], [{1, 2}, {0}, {4}, {3}]]
0.13961748780145625
613
[[{0, 4}, {2}, {1, 3}], [{0, 4}, {2}, {1, 3}], [{0, 4}, {2}, {1, 3}], [{0, 4}, {2}, {3}, {1}], [{0, 4}, {2}, {3}, {1}]]
0.13963026918775923
618
[[{0, 4}, {3}, {1, 2}], [{0, 4}, {3}, {1, 2}], [{0, 4}, {3}, {1, 2}], [{0, 4}, {3}, {1, 2}], [{0, 4}, {3}, {1, 2}]]
0.13963489875416538
6277
[[{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{0}, {1, 4}, {3}, {2}], [{0}, {1, 4}, {3}, {2}], [{0}, {1, 4}, {3}, {2}]]
0.13969619774358732
539
[[{1, 2, 3, 4}, {0}], [{1, 2, 3, 4}, {0}], [{1, 2, 3, 4}, {0}], [{0}, {3}, {1, 2, 4}], [{0}, {3}, {2, 4}, {1}]]
0.13972011708008883
1029
[[{1, 2, 4}, {0, 3}], [{1, 2, 4}, {0, 3}], [{1, 2, 4}, {3}, {0}], [{3}, {0}, {2}, {1, 4}], [{3}, {0}, {2}, {1, 4}]]
0.13974528600151476
4651
[[{0, 1, 2, 3, 4}], [{1, 2, 4}, {0, 3}], [{1, 2, 4}, {3}, {0}], [{3}, {0}, {2, 4}, {1}], [{3}, {0}, {2, 4}, {1}]]
0.13977593944824884
438


0.15641069780875141
5100
[[{0, 1, 2, 3, 4}], [{1, 4}, {0, 2, 3}], [{0, 2, 3}, {4}, {1}], [{0, 2, 3}, {4}, {1}], [{4}, {1}, {2}, {0, 3}]]
0.15646280977467444
1659
[[{1, 3}, {0, 2, 4}], [{1, 3}, {4}, {0, 2}], [{1, 3}, {4}, {0, 2}], [{1, 3}, {4}, {0, 2}], [{1, 3}, {4}, {2}, {0}]]
0.1565350323065186
6598
[[{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{0, 2}, {3, 4}, {1}], [{0, 2}, {3, 4}, {1}], [{1}, {2}, {0}, {4}, {3}]]
0.15656621697708117
7058
[[{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{3, 4}, {0, 1, 2}], [{3, 4}, {0}, {2}, {1}], [{0}, {2}, {1}, {4}, {3}]]
0.15659664977109766
4769
[[{0, 1, 2, 3, 4}], [{1, 2}, {0, 3, 4}], [{1, 2}, {4}, {0, 3}], [{1, 2}, {4}, {0, 3}], [{4}, {0, 3}, {2}, {1}]]
0.156630081905119
6629
[[{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{1, 3, 4}, {0, 2}], [{3}, {1, 4}, {2}, {0}], [{3}, {2}, {0}, {4}, {1}]]
0.1566313128844751
5832
[[{0, 1, 2, 3, 4}], [{2}, {0, 1, 3, 4}], [{2}, {0, 1, 3, 4}], [{2}, {0, 1, 3, 4}], [{2}, {0, 1}, {4}, {3}]]
0.15664819409406042
7168
[[{0, 1, 2, 3, 4}], 

[[{1, 3, 4}, {0, 2}], [{1, 3, 4}, {0, 2}], [{0, 2}, {4}, {1, 3}], [{0, 2}, {4}, {1, 3}], [{4}, {1, 3}, {2}, {0}]]
0.16956086044472557
6004
[[{0, 1, 2, 3, 4}], [{3}, {0, 1, 2, 4}], [{3}, {2, 4}, {0, 1}], [{3}, {2, 4}, {0, 1}], [{3}, {2, 4}, {0, 1}]]
0.16958108644054537
5310
[[{0, 1, 2, 3, 4}], [{1}, {0, 2, 3, 4}], [{1}, {0, 2, 3, 4}], [{1}, {0, 2, 3, 4}], [{1}, {0}, {3}, {2, 4}]]
0.16958612451414118
1510
[[{1, 3, 4}, {0, 2}], [{1, 3, 4}, {0, 2}], [{1, 3, 4}, {0, 2}], [{0, 2}, {3, 4}, {1}], [{0, 2}, {1}, {4}, {3}]]
0.16959585768285146
2826
[[{2, 4}, {0, 1, 3}], [{2, 4}, {1, 3}, {0}], [{2, 4}, {1, 3}, {0}], [{2, 4}, {1, 3}, {0}], [{0}, {4}, {2}, {3}, {1}]]
0.16960158517430615
3460
[[{3, 4}, {0, 1, 2}], [{3, 4}, {0, 1, 2}], [{3, 4}, {0, 1, 2}], [{3, 4}, {2}, {0, 1}], [{2}, {0, 1}, {4}, {3}]]
0.16960254391326252
6069
[[{0, 1, 2, 3, 4}], [{3}, {0, 1, 2, 4}], [{3}, {0, 1, 2, 4}], [{3}, {2, 4}, {0, 1}], [{3}, {2, 4}, {0, 1}]]
0.169630584548949
1626
[[{1, 3}, {0, 2, 4}], [{1, 3}, {2, 4}, {0}], 

2012
[[{1}, {0, 2, 3, 4}], [{1}, {2, 3}, {0, 4}], [{1}, {2, 3}, {0, 4}], [{1}, {2, 3}, {0, 4}], [{1}, {2, 3}, {4}, {0}]]
0.1863466788652177
7205
[[{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{0}, {1}, {3}, {2, 4}], [{0}, {1}, {3}, {4}, {2}]]
0.1863747499953071
7207
[[{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{0}, {1}, {4}, {2, 3}], [{0}, {1}, {4}, {3}, {2}]]
0.1864794685393144
6634
[[{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{1, 3, 4}, {0, 2}], [{4}, {1, 3}, {2}, {0}], [{4}, {1, 3}, {2}, {0}]]
0.18649614067506673
7214
[[{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{0}, {1, 4}, {3}, {2}], [{0}, {3}, {2}, {4}, {1}]]
0.18662219944191913
3224
[[{2}, {0, 1, 3, 4}], [{2}, {0, 1, 3, 4}], [{2}, {0, 1, 3, 4}], [{2}, {1, 4}, {0, 3}], [{2}, {0, 3}, {4}, {1}]]
0.18668244145759777
6560
[[{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{1, 2}, {0, 3, 4}], [{0, 3, 4}, {2}, {1}], [{2}, {1}, {3, 4}, {0}]]
0.18671562794480248
6383
[[{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{1, 2, 

0.20633309847856957
7132
[[{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{3}, {0, 1, 2, 4}], [{3}, {0, 1, 2, 4}], [{3}, {1, 4}, {0, 2}]]
0.20642159551672698
2681
[[{2, 3}, {0, 1, 4}], [{2, 3}, {0, 1, 4}], [{2, 3}, {1, 4}, {0}], [{2, 3}, {1, 4}, {0}], [{1, 4}, {0}, {3}, {2}]]
0.20646731071432242
1502
[[{1, 3, 4}, {0, 2}], [{1, 3, 4}, {0, 2}], [{1, 3, 4}, {0, 2}], [{0, 2}, {3, 4}, {1}], [{0, 2}, {3, 4}, {1}]]
0.20662218503953833
5526
[[{0, 1, 2, 3, 4}], [{2, 3}, {0, 1, 4}], [{2, 3}, {1}, {0, 4}], [{2, 3}, {1}, {0, 4}], [{2, 3}, {1}, {4}, {0}]]
0.20691358111345262
5936
[[{0, 1, 2, 3, 4}], [{3, 4}, {0, 1, 2}], [{3, 4}, {0, 1, 2}], [{3, 4}, {0, 1, 2}], [{0, 1, 2}, {4}, {3}]]
0.20692963157674957
2501
[[{2, 3, 4}, {0, 1}], [{2, 3, 4}, {0, 1}], [{2, 3, 4}, {0, 1}], [{0, 1}, {3}, {2, 4}], [{0, 1}, {3}, {2, 4}]]
0.206968370186512
7271
[[{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{0, 4}, {3}, {1, 2}], [{3}, {1, 2}, {4}, {0}]]
0.20700483142038703
6624
[[{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [

[[{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{2, 3}, {0, 1, 4}], [{2, 3}, {1, 4}, {0}], [{2, 3}, {0}, {4}, {1}]]
0.25329811595200646
7476
[[{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{3}, {0, 1, 2, 4}], [{3}, {1, 2, 4}, {0}]]
0.2534413783434751
7488
[[{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{3}, {0, 1, 2, 4}], [{3}, {2}, {0, 1, 4}]]
0.2534413783434751
7077
[[{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{3, 4}, {0, 1, 2}], [{3, 4}, {0, 1, 2}], [{3, 4}, {1}, {0, 2}]]
0.2536514700466737
5050
[[{0, 1, 2, 3, 4}], [{1, 3}, {0, 2, 4}], [{1, 3}, {0, 2, 4}], [{1, 3}, {0, 2, 4}], [{3}, {1}, {0}, {4}, {2}]]
0.2536651597481451
7466
[[{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{3, 4}, {0, 1, 2}], [{0, 1, 2}, {4}, {3}]]
0.25379894742224446
1694
[[{1, 3}, {0, 2, 4}], [{1, 3}, {0, 2, 4}], [{1, 3}, {2}, {0, 4}], [{1, 3}, {2}, {0, 4}], [{1, 3}, {2}, {0, 4}]]
0.25387281042964527
815
[[{1, 2, 3}, {0, 4}], [{1, 2, 3}, {0, 4}], [{1, 2, 3}, {0, 4}], [{1, 2, 3}, {0, 4}], [{3},

0.29525258345908245
6438
[[{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{1, 2, 3}, {0, 4}], [{0, 4}, {2, 3}, {1}], [{0, 4}, {2, 3}, {1}]]
0.29591876957122143
6965
[[{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{2, 4}, {0, 1, 3}], [{2, 4}, {0, 1, 3}], [{4}, {2}, {1}, {0, 3}]]
0.2961737531594568
7250
[[{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{1, 2, 3, 4}, {0}], [{0}, {2, 3}, {1, 4}]]
0.29620702968799084
2736
[[{2, 3}, {0, 1, 4}], [{2, 3}, {0, 1, 4}], [{2, 3}, {0, 1, 4}], [{2, 3}, {0, 1, 4}], [{2, 3}, {4}, {0, 1}]]
0.2967840105958853
6538
[[{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{1, 2, 4}, {0, 3}], [{1, 2, 4}, {0, 3}], [{1, 2, 4}, {3}, {0}]]
0.296873450797261
1948
[[{1, 4}, {0, 2, 3}], [{1, 4}, {0, 2, 3}], [{1, 4}, {0, 2, 3}], [{1, 4}, {0, 2, 3}], [{0, 2, 3}, {4}, {1}]]
0.2970670152283125
6447
[[{0, 1, 2, 3, 4}], [{0, 1, 2, 3, 4}], [{1, 2, 3}, {0, 4}], [{0, 4}, {2}, {1, 3}], [{0, 4}, {2}, {1, 3}]]
0.2971091394705196
5672
[[{0, 1, 2, 3, 4}], [{2, 4}, {0, 1, 3}], [{2, 4}, {0, 1, 3}], [{2,

In [18]:
# SROCC
def load_obj(name):
    with open('./ntask/' + name + '.pkl', 'rb') as f:
        return pickle.load(f)

real_log = load_obj('real_results_'+data+'_'+backbone_type+'_' + real_date)
idx = real_log['layout']
real_results = real_log['val_acc']
real = np.mean(real_results[:], axis=1)

est = [-layout_list[i].score for i in idx]
scipy.stats.spearmanr(est,real)

SpearmanrResult(correlation=-0.14736842105263157, pvalue=0.5352473458918693)

# Qizheng

In [5]:
weights = torch.load(weight_PATH)['state_dict']

In [27]:
D = 0
for key in weights:
    if 'weight' in key and 'backbone' in key:
        D += 1
D = int(D/3)

D_params = {task: [] for task in tasks}
for key in weights:
    if 'weight' in key and 'backbone' in key:
        t_idx = int(key.split('.')[2]) % 3
        D_params[tasks[t_idx]].append(weights[key].cpu().numpy().reshape(-1))

assert len(D_params[tasks[0]]) == len(D_params[tasks[1]]) == len(D_params[tasks[2]]) == D

In [28]:
# compute PSA
PSA = np.ones((D,T,T))
for d in range(D):
    for two_task in itertools.combinations(tasks, 2):
#         rho = scipy.stats.spearmanr(D_params[two_task[0]][d], D_params[two_task[1]][d])[0]
        rho = np.linalg.norm(D_params[two_task[0]][d] - D_params[two_task[1]][d])
        PSA[d,tasks.index(two_task[0]),tasks.index(two_task[1])] = PSA[d,tasks.index(two_task[1]),tasks.index(two_task[0])] = rho
print(PSA)

[[[ 1.          5.02830696  4.037117  ]
  [ 5.02830696  1.          3.63590288]
  [ 4.037117    3.63590288  1.        ]]

 [[ 1.          2.27176285  6.36746168]
  [ 2.27176285  1.          5.70963478]
  [ 6.36746168  5.70963478  1.        ]]

 [[ 1.          8.24909496  6.26709414]
  [ 8.24909496  1.          5.65949821]
  [ 6.26709414  5.65949821  1.        ]]

 [[ 1.          1.55947053  5.54519272]
  [ 1.55947053  1.          5.11428213]
  [ 5.54519272  5.11428213  1.        ]]

 [[ 1.          7.70594931  5.81748295]
  [ 7.70594931  1.          5.33898735]
  [ 5.81748295  5.33898735  1.        ]]

 [[ 1.          1.7155509   5.9160285 ]
  [ 1.7155509   1.          4.93741941]
  [ 5.9160285   4.93741941  1.        ]]

 [[ 1.          5.39825487  4.28039837]
  [ 5.39825487  1.          3.43453503]
  [ 4.28039837  3.43453503  1.        ]]

 [[ 1.          2.56993985  4.3446207 ]
  [ 2.56993985  1.          3.0612309 ]
  [ 4.3446207   3.0612309   1.        ]]

 [[ 1.          5.346772

In [54]:
# 'segment_semantic', 'normal', 'depth_zbuffer'
distance_sum = []
for d in range(D):
    distance_sum.append(PSA[d,0,2])
print('segment_semantic vs normal')
print(mean(distance_sum))
print(max(distance_sum))
print(min(distance_sum))
print('=' * 20)

distance_sum = []
for d in range(D):
    distance_sum.append(PSA[d,0,1])
print('segment_semantic vs depth_zbuffer')
print(mean(distance_sum))
print(max(distance_sum))
print(min(distance_sum))
print('=' * 20)

distance_sum = []
for d in range(D):
    distance_sum.append(PSA[d,1,2])
print('normal vs depth_zbuffer')
print(mean(distance_sum))
print(max(distance_sum))
print(min(distance_sum))
print('=' * 20)

segment_semantic vs normal
5.4316812422540455
20.918720245361328
0.6082245707511902
segment_semantic vs depth_zbuffer
5.506681526700656
22.134845733642578
0.6820645928382874
normal vs depth_zbuffer
3.010250359773636
8.058507919311523
0.27812889218330383


In [30]:
np.min(PSA)

0.27812889218330383

In [29]:
# PSA[PSA==1] = 0
np.max(PSA)

22.134845733642578