In [1]:
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import os
import itertools
import sys
sys.path.append('../')
from pathlib import Path
import random
from scipy import stats
from scipy.optimize import linear_sum_assignment
from collections import OrderedDict
from ptflops import get_model_complexity_info
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from torch.utils.data import DataLoader

from framework.layer_node import Conv2dNode, InputNode
from main.layout import Layout
from main.algorithms import enum_layout_wo_rdt, init_S, coarse_to_fined
from main.auto_models import MTSeqBackbone, MTSeqModel, ComputeBlock
from main.head import ASPPHeadNode
from main.trainer import Trainer
from main.algs_FMTL import simple_alignment, complex_alignment

from data.nyuv2_dataloader_adashare import NYU_v2
from data.pixel2pixel_loss import NYUCriterions
from data.pixel2pixel_metrics import NYUMetrics

In [2]:
assert torch.cuda.is_available()

In [3]:
import pickle
def save_obj(obj, name):
    with open('./exp/'+ name + '.pkl', 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)

def load_obj(name):
    with open('./exp/' + name + '.pkl', 'rb') as f:
        return pickle.load(f)

# backbone and data

In [4]:
# backbone
# mobilenet
backbone_type = 'mobilenet'
prototxt = '../models/mobilenetv2.prototxt'
D = coarse_B = 5
mapping = {0:[0,1,2,3,4,5,6], 1:[7,8,9,10,11,12,13,14,15,16,17], 2:[18,19,20,21,22], 
           3:[23,24,25,26,27,28,29,30], 4:[31], 5:[32]}

In [12]:
# data
# NYUv2
data = 'NYUv2'
dataroot = '/mnt/nfs/work1/huiguan/lijunzhang/policymtl/data/NYUv2/'
tasks = ['segment_semantic', 'normal', 'depth_zbuffer']
cls_num = {'segment_semantic': 40, 'normal':3, 'depth_zbuffer': 1}

dataset = NYU_v2(dataroot, 'train', crop_h=321, crop_w=321)
trainDataloader = DataLoader(dataset, 32, shuffle=True)

criterionDict = {}
metricDict = {}
for task in tasks:
    print(task, flush=True)
    criterionDict[task] = NYUCriterions(task)
    metricDict[task] = NYUMetrics(task)

input_dim = (3,321,321)
T = len(tasks)

segment_semantic
normal
depth_zbuffer


In [6]:
# fix dataloader for fix mini-batches
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(0)
trainDataloaderFix = DataLoader(dataset, 96, shuffle=True,worker_init_fn=seed_worker,generator=g)

In [7]:
len(trainDataloaderFix)

9

In [8]:
# ind. weights
ckpt_PATH = '/mnt/nfs/work1/huiguan/lijunzhang/multibranch/checkpoint/'
weight_PATH = ckpt_PATH + 'NYUv2/ind/mobilenet/segment_semantic_normal_depth_zbuffer.model' # NYUv2 + MobileNetV2, from the same init
# weight_PATH = ckpt_PATH + 'NYUv2/baseline/WPreMobile/2/segment_semantic_normal_depth_zbuffer.model' # NYUv2 + MobileNetV2, from the same init

# load independent model weights

In [9]:
with torch.no_grad():
    backbone = MTSeqBackbone(prototxt)
    fined_B = len(backbone.basic_blocks)
    feature_dim = backbone(torch.rand(1,3,224,224)).shape[1]

In [10]:
# ind. layout
S = []
for i in range(fined_B):
    S.append([set([x]) for x in range(T)])
layout = Layout(T, fined_B, S) 
print('Ind. Layout:', flush=True)
print(layout, flush=True)

# model
with torch.no_grad():
    model = MTSeqModel(prototxt, layout=layout, feature_dim=feature_dim, cls_num=cls_num)
#     model = model.cuda()

    # load ind. model weights
    model.load_state_dict(torch.load(weight_PATH)['state_dict'])

Ind. Layout:
[[{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}]]
Construct MTSeqModel from Layout:
[[{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}], [{0}, {1}, {2}],

In [15]:
# compute r0 --> stop convergency loss for layout convergency iter estimation
loss_lst = {task:[] for task in tasks}
# model = model.cuda()
model.train()
for i, data in enumerate(trainDataloader):
    if i > 50:
        break
    x = data['input'].cuda()
    output = model(x)
    for task in tasks:
        y = data[task].cuda()
        if task + '_mask' in data:
            tloss = criterionDict[task](output[task], y, data[task + '_mask'].cuda())
        else:
            tloss = criterionDict[task](output[task], y)
        loss_lst[task].append(tloss.item())
        print('{}: {:.4f}'.format(task,tloss.item()))
    print('-'*30)

target = {task: np.mean(loss_lst[task]) for task in tasks}
print('r0: {}'.format(target))
# save_obj(target, 'r0')

RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same

In [10]:
target = load_obj('r0')

In [11]:
target

{'segment_semantic': 0.7207628560066223,
 'normal': 0.07310347080230713,
 'depth_zbuffer': 1.1167707586288451}

# enum layouts and channel alignment

In [12]:
# enum layout
layout_list = [] 
S0 = init_S(T, coarse_B) # initial state
L = Layout(T, coarse_B, S0) # initial layout
layout_list.append(L)
enum_layout_wo_rdt(L, layout_list)

In [13]:
align_choice = 2 # 0: no align; 1: simple align (use out_ord only); 2: complex align

In [14]:
model = model.cpu()
if align_choice == 1:
    simple_alignment(model, tasks)
elif align_choice == 2:
    complex_alignment(model, tasks)
elif align_choice == 0:
    pass

# est. convergence rate

In [None]:
seed = 10
total_iters, short_iters, start, step, batch_num = 20000, 200, 0, 50, 10
load_weight = False
smooth_weight, target_ratio = 0.0, 0.8


layout_est = []
layout_idx = -1
# For each layout
for L in layout_list:
    set_seed(seed)
    
    layout_idx += 1
    print('Layout {}'.format(layout_idx))
    layout = coarse_to_fined(L, fined_B, mapping)
    print('Fined Layout:', flush=True)
    print(layout, flush=True)
    
    mtl_model = MTSeqModel(prototxt, layout=layout, feature_dim=feature_dim, cls_num=cls_num, verbose=False)
    if load_weight:
        # Step 1: create weight init state_dict
        mtl_init = OrderedDict()
        for name, module in mtl_model.named_modules():
            if isinstance(module, ComputeBlock):
                task_set = module.task_set
                layer_idx = module.layer_idx
                if len(task_set) > 1:
                    merge_flag = True
                else:
                    # Type 1: save the whole block weights from the corresponding ind. model when no merging
                    merge_flag = False
                    for block in model.backbone.mtl_blocks:
                        if task_set == block.task_set and block.layer_idx == layer_idx:
                            for ind_name, param in block.named_parameters():
                                mtl_init['.'.join([name, ind_name])] = param  
                            # for BN running mean and running var
                            for ind_name, param in block.named_buffers():
                                mtl_init['.'.join([name, ind_name])] = param

            # # Type 2: when the current block have merged operators, save mean weights for convs
            elif isinstance(module, Conv2dNode) and merge_flag: 
                task_convs = [] # store conv weights from task's ind. block
                for task in task_set:
                    # identify task-corresponding block in the well-trained ind. models 
                    for block in model.backbone.mtl_blocks:
                        if task in block.task_set and block.layer_idx == layer_idx:
                            task_module = block.compute_nodes[int(name.split('.')[-1])]  
                            temp_weight = task_module.basicOp.weight # no channel alignment or no align variable
                            if align_choice == 1 and task_module.out_ord is not None: # simple alignment
                                temp_weight = temp_weight[task_module.out_ord]
                            elif align_choice == 2: # complex alignment
                                if task_module.in_ord is not None:
                                    temp_weight = temp_weight[:,task_module.in_ord]
                                if task_module.out_ord is not None: 
                                    temp_weight = temp_weight[task_module.out_ord]
                            task_convs.append(temp_weight)
                weight_anchor = torch.mean(torch.stack(task_convs),dim=0)
                mtl_init[name+'.basicOp.weight'] = weight_anchor

            # Type 3: save heads' weights
            elif 'heads' in name and isinstance(module, ASPPHeadNode): 
                ind_head = model.heads[name.split('.')[-1]]
                for ind_name, param in ind_head.named_parameters():
                    mtl_init['.'.join([name, ind_name])] = param
                for ind_name, param in ind_head.named_buffers():
                    mtl_init['.'.join([name, ind_name])] = param
        mtl_model.load_state_dict(mtl_init,strict=False)
        print('Finish Weight Loading.', flush=True)
        print('-'*80)
    
    mtl_model = mtl_model.cuda()
    mtl_model.train()
    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, mtl_model.parameters()), lr=0.01)
    
    # For each task
    alpha_lst = {task: [] for task in tasks}
    est_iter_lst ={task: [] for task in tasks}
    final_loss_lst = {task: [] for task in tasks}
    
    # Step 2: Save short train loss list
    for idx, data in enumerate(trainDataloaderFix):
        if idx >= batch_num:
            break
        x = data['input'].cuda()
        if load_weight:
            mtl_model.load_state_dict(mtl_init,strict=False)
        else:
            mtl_model.reset_parameters()
        
        loss_lst = {task: [] for task in tasks}
        for it in range(short_iters):
            optimizer.zero_grad()
            output = mtl_model(x)
            loss = 0
            for task in tasks:
                y = data[task].cuda()
                if task + '_mask' in data:
                    tloss = criterionDict[task](output[task], y, data[task + '_mask'].cuda())
                else:
                    tloss = criterionDict[task](output[task], y)
                loss_lst[task].append(tloss.item())
                loss += tloss
            loss.backward()
            optimizer.step()
    
        for task in tasks:
            print('Task {}:'.format(task))
            sm_loss_lst = smooth(loss_lst[task], smooth_weight)

            # Step 3: Take smoothed loss samples from window slices
            loss_samples = window_loss_samples(sm_loss_lst, start, step=step)
            print('\t\tLoss Samples: {}'.format(loss_samples))
            if loss_samples == False:
                print('\t\tBad Loss Samples')
                continue

            # Step 4: Compute convergence rate 
            alpha = compute_alpha2(loss_samples)
            alpha_lst[task].append(alpha)
            print('\t\tAlpha: {}'.format(alpha))

            # Step 5,6: Estimate final loss after 20000 iters and iters to reach target loss
            n = (total_iters - start)//step
            est_n, final_loss = est_final_loss2(loss_samples, n, alpha, target[task]*target_ratio)
            if est_n != -1:
                est_iter = start + est_n*step
            else:
                est_iter = est_n
            print('\t\tEst Iter: {}'.format(est_iter))
            print('\t\tFinal Loss: {}'.format(final_loss))
            est_iter_lst[task].append(est_iter)
            final_loss_lst[task].append(final_loss)
        print('-'*80)
           
    layout_est.append({'alpha':alpha_lst, 'est_iter': est_iter_lst, 'est_loss': final_loss_lst})
    print('='*80)
    torch.cuda.empty_cache()

Layout 0
Fined Layout:
[[{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}], [{0, 1, 2}]]
Task segment_semantic:
		Loss Samples: [3.2886670446395874, 2.958666591644287, 2.836396155357361, 2.7248942041397095]
		Alpha: 0.09285630247136648
		Est Iter: 1100
		Final Loss: -41.23616573052018
Task normal:
		Loss Samples: [0.1140168035030365, 0.08978825688362121, 0.0878688395023346, 0.08668279767036438]
		Alpha: 0.1898632418534569
		Est Iter: 1450
		Final Loss: -0.33502806546669794
Task depth_zbuffer:
		Loss Samples: [1.3181030833721161, 0.9099568724632263, 0.6679672265052795, 0.4986503142118454]
		Alpha: 0.6831871100992039
		Est Iter: 100
		Final Lo

In [26]:
layout_est

[{'alpha': {'segment_semantic': [0.08767862140094398,
    0.16889808907113696,
    0.27173512013653806,
    0.13678814705173523,
    0.21557560522823227],
   'normal': [0.10529386608303633,
    0.009307225004601889,
    0.10804536014145633,
    0.16873231107424447,
    0.140873784388299],
   'depth_zbuffer': [0.6728202011950046,
    1.7250111370604329,
    0.3633880413110632,
    0.37397310308029286,
    1.2183369648647833]},
  'est_iter': {'segment_semantic': [1000, 1000, 750, 600, 950],
   'normal': [700, 100, 100, 100, 100],
   'depth_zbuffer': [100, 100, 100, 100, 100]},
  'est_loss': {'segment_semantic': [-42.52371637728947,
    -38.834884491691994,
    -50.80944338488352,
    -66.52761254552476,
    -41.640885575688145],
   'normal': [-0.3945712930723772,
    -0.349581303161049,
    -0.3808884450914222,
    -0.6318260920549031,
    -0.4086360111280515],
   'depth_zbuffer': [-24.06044101768165,
    0.39352482945122924,
    -49.70203580426662,
    -46.607082384837746,
    0.3440762

In [132]:
loss_samples = window_loss_samples(loss_lst['segment_semantic'], start, step=step, smooth_weight=0.1)
# loss_samples = window_loss_samples(loss_lst['segment_semantic'], start, indices=[10,20,30,50], smooth_weight=0.1)
print(loss_samples)

[3.6645989847183227, 3.5812311005592345, 3.485876207590103, 3.392624996685982]
False


In [127]:
# loss_samples=[2.9,2.4,2.1,1.99]
alpha = compute_alpha2(loss_samples)
est_final_loss2(loss_samples, 1000, alpha, target['segment_semantic'])

(19, -143.93745662272735)

In [128]:
alpha

0.6756569368775177

## helper functions in algs_EstCon.py

In [15]:
def smooth(scalars, weight):  # Weight between 0 and 1
    last = scalars[0]  # First value in the plot (first timestep)
    smoothed = list()
    for point in scalars:
        smoothed_val = last * weight + (1 - weight) * point  # Calculate smoothed value
        smoothed.append(smoothed_val)                        # Save it
        last = smoothed_val                                  # Anchor the last smoothed value
    return smoothed

def set_seed(seed):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

In [16]:
def window_loss_samples(loss_lst, start, step=None, indices=None):
    if (step != None and start + step * 4 > len(loss_lst)) or (indices !=None and len(indices) != 4):
        print('Wrong Window Slices for Loss Samples!')
        return False
    if step != None:
        samples = [np.mean(loss_lst[(start+step*i):(start+step*(i+1))]) for i in range(4)]
    elif indices !=None:
        samples = [np.mean(loss_lst[start:indices[i]]) if i == 0 else np.mean(loss_lst[indices[i-1]:indices[i]]) for i in range(4)]
    else:
        print('No Slices or Step designed for Loss Samples!')
        return False
    if judge_loss_samples(samples):
        return samples
    else:
        return False

def judge_loss_samples(samples):
    diff_prev = 1000
    for i in range(1, len(samples)):
        prev = samples[i-1]
        cur = samples[i]
        diff = prev - cur
        if prev < cur or diff > diff_prev:
            return False
        else:
            diff_prev = diff
    return True

In [17]:
def good_loss_samples(loss_lst, indices, tol=10):
    if len(indices) != 4:
        print('Wrong Indices for Loss Samples!')
        return False
    # sample 0
    samples = [loss_lst[indices[0]]]
    # sample 1,2,3
    diff_prev = 1000
    for idx in range(1,len(indices)):
        prev = samples[idx-1]
        cur = loss_lst[indices[idx]]
        diff = prev - cur
        if cur < prev and diff < diff_prev:
            samples.append(cur)
            diff_prev = diff
        else:
            for i in range(indices[idx]-tol,indices[idx]+tol):
                prev = samples[idx-1]
                cur = loss_lst[i]
                diff = prev - cur
                if cur < prev and diff < diff_prev:
                    samples.append(cur)
                    diff_prev = diff
                    break
                else:
                    return False
    return samples

In [18]:
def compute_alpha(loss_samples):
    return np.log(loss_samples[2]/loss_samples[1])/ np.log(loss_samples[1]/loss_samples[0])

def compute_alpha2(loss_samples):
    return np.log(np.abs((loss_samples[3]-loss_samples[2])/(loss_samples[2]-loss_samples[1])))/ \
            np.log(np.abs((loss_samples[2]-loss_samples[1])/(loss_samples[1]-loss_samples[0])))

In [19]:
def est_recov_n(loss_samples, target, alpha):
    up = np.log(target/loss_samples[0]) * (alpha - 1)
    down = np.log(loss_samples[1]/loss_samples[0])
    return np.log(up/down + 1) / np.log(alpha)

In [20]:
def est_final_loss(loss_samples, n, alpha):
    x0, x1 = np.log(loss_samples[0]), np.log(loss_samples[1])
    for i in range(2, n+1):
        x2 = alpha * (x1-x0) + x1
        x0 = x1
        x1 = x2
    return np.exp(x2)

def est_final_loss2(loss_samples, n, alpha, target):
    x0, x1 = np.log(loss_samples[0]-loss_samples[1]), np.log(loss_samples[1]-loss_samples[2])
    temp = loss_samples[2]
    est_n = -1
    flag = True
    
    for i in range(2, n+1):
        x2 = alpha * (x1-x0) + x1
        if np.isinf(x2):
            break
        else:
            temp -= np.exp(x2)
        if temp < target and flag:
            est_n = i
            flag = False
        x0 = x1
        x1 = x2
    return est_n, temp
