In [1]:
import numpy as np
import os
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import matplotlib.pyplot as plt
from scipy.special import softmax

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from framework.mtl_model import MTLModel
from framework.trainer import Trainer
from data.heads.pixel2pixel import ASPPHeadNode

from data.dataloader.cityscapes_dataloader import CityScapes
from data.metrics.pixel2pixel_loss import CityScapesCriterions
from data.metrics.pixel2pixel_metrics import CityScapesMetrics

from data.dataloader.nyuv2_dataloader import NYU_v2
from data.metrics.pixel2pixel_loss import NYUCriterions
from data.metrics.pixel2pixel_metrics import NYUMetrics

from data.dataloader.taskonomy_dataloader import Taskonomy
from data.metrics.pixel2pixel_loss import TaskonomyCriterions
from data.metrics.pixel2pixel_metrics import TaskonomyMetrics

# Data

In [2]:
dataroot = '/mnt/nfs/work1/huiguan/lijunzhang/policymtl/data/Cityscapes/'
# dataroot = 'datasets/Cityscapes/'

headsDict = nn.ModuleDict()
trainDataloaderDict = {}
valDataloaderDict = {}
criterionDict = {}
metricDict = {}

tasks = ['segment_semantic', 'depth_zbuffer']
task_cls_num = {'segment_semantic': 19, 'depth_zbuffer': 1}
for task in tasks:
    headsDict[task] = ASPPHeadNode(512, task_cls_num[task])

    # For model trainer
    dataset = CityScapes(dataroot, 'train', task, crop_h=224, crop_w=224)
    trainDataloaderDict[task] = DataLoader(dataset, 16, shuffle=True)

    dataset = CityScapes(dataroot, 'test', task)
    valDataloaderDict[task] = DataLoader(dataset, 8, shuffle=True)

    criterionDict[task] = CityScapesCriterions(task)
    metricDict[task] = CityScapesMetrics(task)

In [25]:
dataroot = '/mnt/nfs/work1/huiguan/lijunzhang/policymtl/data/NYUv2/'
# dataroot = 'datasets/NYUv2/'

headsDict = nn.ModuleDict()
trainDataloaderDict = {}
valDataloaderDict = {}
criterionDict = {}
metricDict = {}

tasks = ['segment_semantic', 'normal', 'depth_zbuffer']
task_cls_num = {'segment_semantic': 40, 'normal':3, 'depth_zbuffer': 1}
for task in tasks:
    headsDict[task] = ASPPHeadNode(512, task_cls_num[task])

    # For model trainer
    dataset = NYU_v2(dataroot, 'train', task, crop_h=321, crop_w=321)
    trainDataloaderDict[task] = DataLoader(dataset, 16, shuffle=True)

    dataset = NYU_v2(dataroot, 'test', task, crop_h=321, crop_w=321)
    valDataloaderDict[task] = DataLoader(dataset, 8, shuffle=True)

    criterionDict[task] = NYUCriterions(task)
    metricDict[task] = NYUMetrics(task)

In [22]:
dataroot = '/mnt/nfs/work1/huiguan/lijunzhang/policymtl/data/Taskonomy/'
# dataroot = 'datasets/Taskonomy/'

headsDict = nn.ModuleDict()
trainDataloaderDict = {}
valDataloaderDict = {}
criterionDict = {}
metricDict = {}

tasks = ['segment_semantic', 'normal', 'depth_zbuffer', 'keypoints2d', 'edge_texture']
task_cls_num = {'segment_semantic': 17, 'normal': 3, 'depth_zbuffer': 1, 'keypoints2d': 1, 'edge_texture': 1}
for task in tasks:
    headsDict[task] = ASPPHeadNode(512, task_cls_num[task])

    # For model trainer
    dataset = Taskonomy(dataroot, 'train', task, crop_h=224, crop_w=224)
    trainDataloaderDict[task] = DataLoader(dataset, batch_size=16, shuffle=True)

    dataset = Taskonomy(dataroot, 'test_small', task, crop_h=224, crop_w=224)
    valDataloaderDict[task] = DataLoader(dataset, batch_size=1, shuffle=False)

    criterionDict[task] = TaskonomyCriterions(task, dataroot)
    metricDict[task] = TaskonomyMetrics(task, dataroot)

# Model

In [3]:
prototxt = 'models/deeplab_resnet34_adashare.prototxt'
# prototxt = 'models/mobilenetv2.prototxt' # the input dim of heads should be changed to 1280
# prototxt = 'models/mnasnet.prototxt' # the input dim of heads should be changed to 1280
mtlmodel = MTLModel(prototxt, headsDict)
mtlmodel = mtlmodel.cuda()

# Train

In [None]:
checkpoint = 'checkpoint/'
trainer = Trainer(mtlmodel, trainDataloaderDict, valDataloaderDict, criterionDict, metricDict, 
                  print_iters=100, val_iters=500, save_num=1, policy_update_iters=100)

### pre_train

In [None]:
trainer.pre_train(iters=10000, lr=0.0001, savePath=checkpoint+'Cityscapes/')

### alter_train

In [None]:
loss_lambda = {'segment_semantic': 1, 'depth_zbuffer': 1, 'policy':0.0005}
trainer.alter_train_with_reg(iters=20000, policy_network_iters=(100,400), policy_lr=0.01, network_lr=0.0001, 
                             loss_lambda=loss_lambda,
                             savePath=checkpoint+'Cityscapes/')

### sample policy from trained policy distribution and save

In [None]:
policy_list = {'segment_semantic': [], 'depth_zbuffer': []}
name_list = {'segment_semantic': [], 'depth_zbuffer': []}

In [None]:
for name, param in mtlmodel.named_parameters():
    if 'policy' in name and not torch.eq(param, torch.tensor([0., 0., 0.]).cuda()).all():
        if 'segment_semantic' in name:
            policy_list['segment_semantic'].append(param.data.cpu().detach().numpy())
            name_list['segment_semantic'].append(name)
        elif 'depth_zbuffer' in name:
            policy_list['depth_zbuffer'].append(param.data.cpu().detach().numpy())
            name_list['depth_zbuffer'].append(name)

In [None]:
sample_policy_dict = OrderedDict()
for task in tasks:
    for name, policy in zip(name_list[task], policy_list[task]):
        distribution = softmax(policy, axis=-1)
        distribution /= sum(distribution)
        choice = np.random.choice((0,1,2), p=distribution)
        if choice == 0:
            sample_policy_dict[name] = torch.tensor([1.0,0.0,0.0]).cuda()
        elif choice == 1:
            sample_policy_dict[name] = torch.tensor([0.0,1.0,0.0]).cuda()
        elif choice == 2:
            sample_policy_dict[name] = torch.tensor([0.0,0.0,1.0]).cuda()

In [None]:
sample_state = {'state_dict': sample_policy_dict}
torch.save(sample_state, sample_path + 'sample_policy.model')

### post train from scratch

In [None]:
loss_lambda = {'segment_semantic': 1, 'depth_zbuffer': 1}
trainer.post_train(iters=40000, lr=0.001, 
                   decay_lr_freq=4000, decay_lr_rate=0.5,
                   loss_lambda=loss_lambda,
                   savePath=checkpoint+'Cityscapes/', reload='sample_policy.model')

### get the validation results in the paper 

In [None]:
mtlmodel.load_state_dict(torch.load('CityScapes.model'))
trainer.validate('mtl', hard=True) 

# Inference

In [4]:
mtlmodel.load_state_dict(torch.load('/mnt/nfs/work1/huiguan/lijunzhang/policymtl/checkpoint/Github-Models/CityScapes.model'))

<All keys matched successfully>

In [27]:
mtlmodel.load_state_dict(torch.load('/mnt/nfs/work1/huiguan/lijunzhang/policymtl/checkpoint/Github-Models/NYUv2.model'))

<All keys matched successfully>

In [5]:
mtlmodel.load_state_dict(torch.load('/mnt/nfs/work1/huiguan/lijunzhang/policymtl/checkpoint/Github-Models/Taskonomy.model'))

<All keys matched successfully>

In [None]:
mtlmodel.load_state_dict(torch.load('CityScapes.model'))
output = mtlmodel(x, task='segment_semantic', hard=True)

In [5]:
from framework.layer_node import Conv2dNode

In [7]:
from framework.layer_node import Conv2dNode
for name, module in mtlmodel.named_modules():
    if isinstance(module, Conv2dNode):
        print(name)
        print(module)

net.0
Conv2dNode(
  (taskOp): ModuleDict(
    (segment_semantic): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (depth_zbuffer): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  )
  (dsOp): ModuleDict(
    (segment_semantic): ModuleList(
      (0): Conv2d(3, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LazyLayer()
    )
    (depth_zbuffer): ModuleList(
      (0): Conv2d(3, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LazyLayer()
    )
  )
  (policy): ParameterDict(
      (segment_semantic): Parameter containing: [torch.cuda.FloatTensor of size 3 (GPU 0)]
      (depth_zbuffer): Parameter containing: [torch.cuda.FloatTensor of size 3 (GPU 0)]
  )
  (basicOp): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padd

In [6]:
print(mtlmodel)

MTLModel(
  (headsDict): ModuleDict(
    (segment_semantic): ASPPHeadNode(
      (fc1): Classification_Module(
        (conv1): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(6, 6), dilation=(6, 6))
        (conv2): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
        (conv3): Conv2d(1024, 19, kernel_size=(1, 1), stride=(1, 1))
        (relu): ReLU(inplace=True)
        (dropout): Dropout(p=0.5, inplace=False)
      )
      (fc2): Classification_Module(
        (conv1): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(12, 12), dilation=(12, 12))
        (conv2): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
        (conv3): Conv2d(1024, 19, kernel_size=(1, 1), stride=(1, 1))
        (relu): ReLU(inplace=True)
        (dropout): Dropout(p=0.5, inplace=False)
      )
      (fc3): Classification_Module(
        (conv1): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(18, 18), dilation=(18, 18))
        (conv2): Conv2d(1024, 1024, 

In [36]:
import time
times = []
count = 0
for i, data in enumerate(valDataloaderDict[task]):
    if count < 50:
        count += 1
        x = data['input'].cuda()
        
        temp = 0
        for task in tasks:
            with torch.no_grad():
                start_time = time.time()
                output = mtlmodel(x, task=task, hard=True)
                end_time = time.time()
                temp += end_time-start_time
    else:
        break
    times.append(temp)

In [37]:
from statistics import mean
mean(times[1:])

0.08018125319967465

In [38]:
times

[0.10284781455993652,
 0.09690117835998535,
 0.09766864776611328,
 0.08123469352722168,
 0.0810401439666748,
 0.08152437210083008,
 0.07942008972167969,
 0.09122633934020996,
 0.0755317211151123,
 0.08093833923339844,
 0.07398033142089844,
 0.08546662330627441,
 0.07645416259765625,
 0.07405638694763184,
 0.07825493812561035,
 0.0767514705657959,
 0.08428430557250977,
 0.07411432266235352,
 0.0762331485748291,
 0.07668161392211914,
 0.07435083389282227,
 0.08052802085876465,
 0.07570457458496094,
 0.07380533218383789,
 0.07825303077697754,
 0.07603764533996582,
 0.07615303993225098,
 0.07601499557495117,
 0.0779409408569336,
 0.0773916244506836,
 0.07851552963256836,
 0.08067774772644043,
 0.08235836029052734,
 0.08546257019042969,
 0.07727241516113281,
 0.08809280395507812,
 0.0784299373626709,
 0.07792258262634277,
 0.07626867294311523,
 0.07564425468444824,
 0.08414697647094727,
 0.07849979400634766,
 0.08218908309936523,
 0.07685041427612305,
 0.07764649391174316,
 0.07741165161132

In [36]:
from statistics import mean
mean(times)

0.06985962390899658

In [22]:
times

[0.060842037200927734,
 0.05027294158935547,
 0.04941534996032715,
 0.04808211326599121,
 0.04568982124328613,
 0.04546856880187988,
 0.05202603340148926,
 0.047086238861083984,
 0.0366673469543457,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.03671693801879883,
 0.036716

In [None]:
0.060842037200927734