In [1]:
import sys
sys.path.append('../')
import numpy as np
import random
import os
import argparse
from pathlib import Path
import scipy.stats
import itertools
from statistics import mean
import pickle
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from main.layout import Layout
from main.algorithms import enum_layout_wo_rdt, init_S, coarse_to_fined
from data.domainnet_dataloader import DomainNet, MultiDomainSampler
from main.efficientnet import EffNetV2_FC

In [2]:
# backbone
backbone_type = 'effnetv2'
D = coarse_B = 4
fined_B = 42
mapping = {0: [0, 1, 2, 3, 4, 5, 6], 1: [7, 8, 9, 10],
           2: [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25],
           3: [26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41], 4: [42]}

# data
data_type = 'DomainNet'
tasks = ['real', 'painting', 'quickdraw', 'clipart', 'infograph', 'sketch']
cls_num = {task: 345 for task in tasks}
T = len(tasks)

# dataloader
dataroot = '/work/lijunzhang_umass_edu/data/multibranch/data/DomainNet/'
with open(os.path.join(dataroot, 'DomainNet.json'), 'r') as f:
    split_domain_info = json.load(f)
sampler = MultiDomainSampler(split_domain_info, 'val', 1, tasks, False)
dataset = DomainNet(dataroot, 'val')
dataloader = DataLoader(dataset, batch_sampler=sampler)

# ind. weights
layout_idx = 5 # ind. models
ckpt_PATH = '/work/lijunzhang_umass_edu/data/multibranch/checkpoint/DomainNet/'
weight_PATH = ckpt_PATH + 'verify_1228/5/' + '_'.join(tasks) + '.model'

# real results
real_date = '1228'

In [3]:
# layout enumerate
with open('../ntask/DomainNetLayout.pkl', 'rb') as f:
    layout_list = pickle.load(f)

In [4]:
# ind. model
layout = layout_list[layout_idx]
print('Coarse Layout:', flush=True)
print(layout, flush=True)

layout = coarse_to_fined(layout, fined_B, mapping)
print('Fined Layout:', flush=True)
print(layout, flush=True)

# model
with torch.no_grad():
    model = EffNetV2_FC(tasks=tasks, layout=layout, cls_num=cls_num)
    model = model.cuda()
    
# load ind. model weights
model.load_state_dict(torch.load(weight_PATH)['state_dict'])

Coarse Layout:
[[{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}]]
Fined Layout:
[[{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}

<All keys matched successfully>

In [5]:
def organize_batch(batch):
        assert batch['img'].shape[0] == len(tasks)
        assert len(batch['img_idx']) == len(tasks)
        new_batch = {}
        for d_idx, task in enumerate(tasks):
            new_batch['%s_img' % task] = batch['img'][d_idx: (d_idx + 1)]
            new_batch['%s_gt' % task] = batch['img_idx'][d_idx: (d_idx + 1) ]
        return new_batch

# extract features
K = 500
D_pos = [mapping[d][-1] for d in range(D)]
D_feats = {task: [None]*D for task in tasks}

for i, data in enumerate(dataloader):
    if i >= K:
        break
    
    with torch.no_grad():
        batch = organize_batch(data)
        for t in tasks:
            x = batch['%s_img' % t].cuda()
            all_feats = model.extract_features(x, t)
            for d in range(D):
                vec_feat = all_feats[D_pos[d]].reshape(1,-1)
                if D_feats[t][d] is None:
                    D_feats[t][d] = vec_feat
                else:
                    temp = D_feats[t][d]
                    D_feats[t][d] = np.vstack((temp,vec_feat))

In [6]:
# compute RDM
RDM = {task: np.zeros((D,K,K)) for task in tasks}
for t in tasks:
    for d in range(D):
        RDM[t][d,:,:] = 1 - np.corrcoef(D_feats[t][d])

In [7]:
# compute RSA
RSA = np.ones((D,T,T))
for d in range(D):
    for two_task in itertools.combinations(tasks, 2):
        rho = scipy.stats.spearmanr(np.tril(RDM[two_task[0]][d,:,:]).reshape(-1), np.tril(RDM[two_task[1]][d,:,:]).reshape(-1))[0]
        RSA[d,tasks.index(two_task[0]),tasks.index(two_task[1])] = RSA[d,tasks.index(two_task[1]),tasks.index(two_task[0])] = rho
print(RSA)

[[[1.         0.85597928 0.8647014  0.86410371 0.85938134 0.85616359]
  [0.85597928 1.         0.85278386 0.85730403 0.85871903 0.84908825]
  [0.8647014  0.85278386 1.         0.85450798 0.86215098 0.85290658]
  [0.86410371 0.85730403 0.85450798 1.         0.87040819 0.84496261]
  [0.85938134 0.85871903 0.86215098 0.87040819 1.         0.85344221]
  [0.85616359 0.84908825 0.85290658 0.84496261 0.85344221 1.        ]]

 [[1.         0.86071242 0.86547909 0.86543589 0.85274173 0.85618191]
  [0.86071242 1.         0.86466844 0.85906337 0.85240213 0.85012968]
  [0.86547909 0.86466844 1.         0.86041607 0.85298934 0.8510069 ]
  [0.86543589 0.85906337 0.86041607 1.         0.852628   0.84587004]
  [0.85274173 0.85240213 0.85298934 0.852628   1.         0.85865953]
  [0.85618191 0.85012968 0.8510069  0.84587004 0.85865953 1.        ]]

 [[1.         0.85526802 0.85401772 0.85701891 0.85784956 0.85839344]
  [0.85526802 1.         0.85594198 0.86335698 0.85733581 0.85909114]
  [0.85401772 0.

In [8]:
# set score for each layout
for layout in layout_list:
    score = 0
    for d in range(D):
        C_cluster = []
        for task_set in layout[d]:
            max_dis = 0
            for two_task in itertools.combinations(task_set, 2):
                max_dis = max(max_dis, 1-RSA[d,two_task[0],two_task[1]])
            C_cluster.append(max_dis)
        score += mean(C_cluster)
    layout.score = score

In [9]:
# sort layout
layout_order = sorted(range(len(layout_list)), key=lambda k: layout_list[k],reverse=False)
step = 1
for i in range(0,len(layout_order),step):
    if i>100:
        break
    print(layout_order[i])
    L = layout_list[layout_order[i]]
    print(L)
    print(L.score)
    print('=' * 100)

5
[[{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}], [{0}, {1}, {2}, {3}, {5}, {4}]]
0
14
[[{0}, {1}, {2}, {5}, {3, 4}], [{0}, {1}, {2}, {5}, {4}, {3}], [{0}, {1}, {2}, {5}, {4}, {3}], [{0}, {1}, {2}, {5}, {4}, {3}]]
0.02591836162797658
10320
[[{0, 2}, {1}, {3}, {5}, {4}], [{1}, {3}, {5}, {4}, {2}, {0}], [{1}, {3}, {5}, {4}, {2}, {0}], [{1}, {3}, {5}, {4}, {2}, {0}]]
0.027059720502592956
6750
[[{0, 3}, {1}, {2}, {5}, {4}], [{1}, {2}, {5}, {4}, {3}, {0}], [{1}, {2}, {5}, {4}, {3}, {0}], [{1}, {2}, {5}, {4}, {3}, {0}]]
0.027179258636245308
60
[[{0}, {1}, {2, 4}, {5}, {3}], [{0}, {1}, {5}, {3}, {4}, {2}], [{0}, {1}, {5}, {3}, {4}, {2}], [{0}, {1}, {5}, {3}, {4}, {2}]]
0.027569803201217623
4720
[[{0, 4}, {1}, {2}, {5}, {3}], [{1}, {2}, {5}, {3}, {4}, {0}], [{1}, {2}, {5}, {3}, {4}, {0}], [{1}, {2}, {5}, {3}, {4}, {0}]]
0.02812373287747638
460
[[{0}, {1, 4}, {2}, {5}, {3}], [{0}, {2}, {5}, {3}, {4}, {1}], [{0}, {2}, {5}, {3}, {4}, {1}], [{0}, {2

In [10]:
# SROCC
def load_obj(name):
    with open('../ntask/' + name + '.pkl', 'rb') as f:
        return pickle.load(f)

real_log = load_obj('real_results_'+data_type+'_'+backbone_type+'_' + real_date)
idx = real_log['layout']
real_results = real_log['val_acc']
real = np.mean(real_results[:], axis=1)

est = [-layout_list[i].score for i in idx]
scipy.stats.spearmanr(est,real)

SpearmanrResult(correlation=0.519534797516449, pvalue=7.093542535844954e-09)