In [15]:
%load_ext autoreload
%autoreload 2
import numpy as np
import os
import sys
sys.path.append('/home/yiminghuang/AutoMTL')
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
# import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from mtl_pytorch.layer_node import Conv2dNode, BN2dNode
from mtl_pytorch.base_node import BasicNode

from mtl_pytorch.trainer import Trainer
from data.heads.pixel2pixel import ASPPHeadNode

from data.dataloader.cityscapes_dataloader import CityScapes
from data.metrics.pixel2pixel_loss import CityScapesCriterions
from data.metrics.pixel2pixel_metrics import CityScapesMetrics

from mobilenetv2 import mobilenet_v2

from mtl_model import mtl_model

from data.dataloader.nyuv2_dataloader import NYU_v2
from data.metrics.pixel2pixel_loss import NYUCriterions
from data.metrics.pixel2pixel_metrics import NYUMetrics

from data.dataloader.taskonomy_dataloader import Taskonomy
from data.metrics.pixel2pixel_loss import TaskonomyCriterions
from data.metrics.pixel2pixel_metrics import TaskonomyMetrics

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Data

In [28]:
print(isinstance(Conv2dNode(1, 1, 1), BasicNode))


False


In [2]:
print(torch.cuda.is_available())
dataroot = 'datasets/Cityscapes/'

headsDict = nn.ModuleDict()
trainDataloaderDict = {}
valDataloaderDict = {}
criterionDict = {}
metricDict = {}

tasks = ['segment_semantic', 'depth_zbuffer']
task_cls_num = {'segment_semantic': 19, 'depth_zbuffer': 1}
for task in tasks:
    headsDict[task] = ASPPHeadNode(1280, task_cls_num[task])

    # For model trainer
    dataset = CityScapes(dataroot, 'train', task, crop_h=224, crop_w=224)
    trainDataloaderDict[task] = DataLoader(dataset, 16, shuffle=True)

    dataset = CityScapes(dataroot, 'test', task)
    valDataloaderDict[task] = DataLoader(dataset, 16, shuffle=True)

    criterionDict[task] = CityScapesCriterions(task)
    metricDict[task] = CityScapesMetrics(task)

False


# Model

In [5]:
mtlmodel = mobilenet_v2(False, heads_dict=headsDict)

# Train

In [7]:
checkpoint = 'checkpoint/'
trainer = Trainer(mtlmodel, trainDataloaderDict, valDataloaderDict, criterionDict, metricDict, 
                  print_iters=100, val_iters=500, save_num=1, policy_update_iters=100)

### pre_train

In [9]:
trainer.pre_train(iters=1, lr=0.0001, savePath=checkpoint+'Cityscapes/')

100%|██████████| 1/1 [00:08<00:00,  8.78s/it]


### alter_train

In [10]:
loss_lambda = {'segment_semantic': 1, 'depth_zbuffer': 1, 'policy':0.0005}
trainer.alter_train_with_reg(iters=1, policy_network_iters=(100,400), policy_lr=0.01, network_lr=0.0001,
                             loss_lambda=loss_lambda,
                             savePath=checkpoint+'Cityscapes/')

100%|██████████| 1/1 [00:06<00:00,  6.99s/it]


### sample policy from trained policy distribution and save

In [11]:
policy_list = {'segment_semantic': [], 'depth_zbuffer': []}
name_list = {'segment_semantic': [], 'depth_zbuffer': []}

In [13]:
for name, param in mtlmodel.named_parameters():
    if 'policy' in name :
        print(name)
        if 'segment_semantic' in name:
            policy_list['segment_semantic'].append(param.data.cpu().detach().numpy())
            name_list['segment_semantic'].append(name)
        elif 'depth_zbuffer' in name:
            policy_list['depth_zbuffer'].append(param.data.cpu().detach().numpy())
            name_list['depth_zbuffer'].append(name)

features.models.0.models.0.policy.segment_semantic
features.models.0.models.0.policy.depth_zbuffer
features.models.1.conv.models.0.policy.segment_semantic
features.models.1.conv.models.0.policy.depth_zbuffer
features.models.1.conv.models.3.policy.segment_semantic
features.models.1.conv.models.3.policy.depth_zbuffer
features.models.2.conv.models.0.policy.segment_semantic
features.models.2.conv.models.0.policy.depth_zbuffer
features.models.2.conv.models.3.policy.segment_semantic
features.models.2.conv.models.3.policy.depth_zbuffer
features.models.2.conv.models.6.policy.segment_semantic
features.models.2.conv.models.6.policy.depth_zbuffer
features.models.3.conv.models.0.policy.segment_semantic
features.models.3.conv.models.0.policy.depth_zbuffer
features.models.3.conv.models.3.policy.segment_semantic
features.models.3.conv.models.3.policy.depth_zbuffer
features.models.3.conv.models.6.policy.segment_semantic
features.models.3.conv.models.6.policy.depth_zbuffer
features.models.4.conv.models

In [10]:
from collections import OrderedDict
from scipy.special import softmax

sample_policy_dict = OrderedDict()
for task in tasks:
    for name, policy in zip(name_list[task], policy_list[task]):
        # distribution = softmax(policy, axis=-1
        distribution = softmax(policy, axis=-1)
        distribution /= sum(distribution)

        choice = np.random.choice((0,1,2), p=distribution)
        if choice == 0:
            sample_policy_dict[name] = torch.tensor([1.0,0.0,0.0])
        elif choice == 1:
            sample_policy_dict[name] = torch.tensor([0.0,1.0,0.0])
        elif choice == 2:
            sample_policy_dict[name] = torch.tensor([0.0,0.0,1.0])

In [11]:
sample_path = 'checkpoint/CityScapes/'
# app.run(debug=True, use_reloader=False)
sample_state = {'state_dict': sample_policy_dict}
torch.save(sample_state, sample_path + 'sample_policy.model')

FileNotFoundError: [Errno 2] No such file or directory: 'checkpoint/CityScapes/sample_policy.model'

### post train from scratch

In [None]:
loss_lambda = {'segment_semantic': 1, 'depth_zbuffer': 1}
trainer.post_train(iters=1, lr=0.001,
                   decay_lr_freq=4000, decay_lr_rate=0.5,
                   loss_lambda=loss_lambda,
                   savePath=checkpoint+'Cityscapes/',
                   reload='sample_policy.model')

### get the validation results in the paper

In [16]:
# for name, param in mtlmodel.named_parameters():
#     print(name, param)
import platform
print(platform.platform())
trainer.load_model('', 'CityScapes.model')
# mtlmodel.load_state_dict(torch.load('CityScapes.model', map_location=torch.device('cpu')))
trainer.validate('mtl', hard=True) 

Linux-3.10.0-1160.36.2.el7.x86_64-x86_64-with-centos-7.9.2009-Core


KeyboardInterrupt: 

# Inference

In [17]:
x = torch.from_numpy(np.random.random([32, 3, 3, 3])).float()
# mtlmodel.load_state_dict(torch.load('CityScapes.model', map_location=torch.device('cpu')))
output = mtlmodel(x, task='segment_semantic', hard=True)
print(output)

tensor([[[[ 9.0390e-01]],

         [[-2.6199e-01]],

         [[ 6.9180e-01]],

         [[-2.5097e-01]],

         [[-1.6299e-01]],

         [[-1.0068e-01]],

         [[-1.2644e-01]],

         [[-3.3990e-02]],

         [[ 3.6126e-01]],

         [[-2.3961e-01]],

         [[-1.3378e-01]],

         [[-5.0832e-02]],

         [[-3.5709e-02]],

         [[ 1.7853e-01]],

         [[-1.2255e-01]],

         [[-3.1906e-01]],

         [[-1.0427e-01]],

         [[-2.2184e-01]],

         [[-3.1533e-01]]],


        [[[ 9.0334e-01]],

         [[-4.0676e-01]],

         [[ 3.4300e-01]],

         [[ 5.5431e-04]],

         [[-2.5747e-01]],

         [[-2.0981e-01]],

         [[-4.8425e-01]],

         [[-3.8828e-01]],

         [[ 5.5483e-01]],

         [[-1.8591e-01]],

         [[-9.5170e-02]],

         [[-3.3478e-01]],

         [[-1.0877e-01]],

         [[ 1.8385e-01]],

         [[ 2.7789e-03]],

         [[ 4.4495e-02]],

         [[-2.6155e-01]],

         [[-1.4751e-03]],
