In [1]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import sys
import time
import glob
import logging
import torch
import numpy as np
from tqdm import tqdm
from thop import profile
from random import shuffle
import torch.nn as nn
import torch.utils
import torchvision
from torch.utils.tensorboard import SummaryWriter
from config_search import config
from dataloader import get_train_loader, get_valid_loader
from tools.datasets import Cityscapes
from architect_lbt import Architect
from model_search import Network_Multi_Path as Network
from model_seg import Network_Multi_Path_Infer
from utils.darts_utils import create_exp_dir, save, plot_op, plot_path_width, objective_acc_lat
from utils.init_func import init_weight
from eval import SegEvaluator
from matplotlib import pyplot as plt



In [2]:
# Search network: Network_Multi_Path from model_search, sized from the search config.
model = Network(
    config.num_classes,
    config.layers,
    Fch=config.Fch,
    width_mult_list=config.width_mult_list,
    prun_modes=['max'],
    stem_head_width=config.stem_head_width,
)
# GPU transfer is disabled in this notebook:
# model = model.cuda()

# Second network: torchvision DeepLabV3-ResNet50 with the same number of classes,
# built without pretrained segmentation weights and without the auxiliary head.
# NOTE(review): depending on the torchvision version the backbone may still load
# pretrained weights by default -- confirm if a fully random init is expected.
student = torchvision.models.segmentation.deeplabv3_resnet50(
    pretrained=False,
    progress=True,
    num_classes=config.num_classes,
    aux_loss=None,
)
# student = student.cuda()

In [3]:
# The architect is handed both networks.
architect = Architect(model, student)

# Initialise the search model's weights (Kaiming-normal initialiser, BatchNorm
# eps/momentum taken from the config). Kept after the Architect construction,
# exactly as in the original cell order.
init_weight(
    model,
    nn.init.kaiming_normal_,
    nn.BatchNorm2d,
    config.bn_eps,
    config.bn_momentum,
    mode='fan_in',
    nonlinearity='relu',
)

# Dataset description passed to the Cityscapes dataset used by the evaluator.
data_setting = dict(
    img_root=config.img_root_folder,
    gt_root=config.gt_root_folder,
    train_source=config.train_source,
    eval_source=config.eval_source,
    down_sampling=config.down_sampling,
)

# Shuffle the training indices to make sure the dataset split is balanced.
# NOTE(review): no random seed is set in the visible cells, so the split
# changes from run to run.
index_select = [*range(config.num_train_imgs)]
shuffle(index_select)

# Two training loaders over the same shuffled index list.
# NOTE(review): presumably the negative portion (train_portion - 1) selects the
# complementary part of index_select -- confirm in dataloader.get_train_loader.
train_loader_model = get_train_loader(
    config,
    Cityscapes,
    portion=config.train_portion,
    index_select=index_select,
)
train_loader_arch = get_train_loader(
    config,
    Cityscapes,
    portion=config.train_portion - 1,
    index_select=index_select,
)

# Evaluator bound to the search model on the 'val' split.
evaluator = SegEvaluator(
    Cityscapes(data_setting, 'val', None),
    config.num_classes,
    config.image_mean,
    config.image_std,
    model,
    config.eval_scale_array,
    config.eval_flip,
    0,
    config=config,
    verbose=False,
    save_path=None,
    show_image=False,
)
valid_loader = get_valid_loader(config, Cityscapes)

using downsampling: 2
Found 1487 images
using downsampling: 2
Found 1488 images
using downsampling: 2
Found 500 images
using downsampling: 2
Found 500 images


In [4]:
base_lr = config.lr

# Collect the weights of the listed sub-modules of the search model, in this
# order, as one flat list for the optimizer.
parameters = [
    weight
    for submodule in (
        model.stem,
        model.cells,
        model.refine32,
        model.refine16,
        model.head0,
        model.head1,
        model.head2,
        model.head02,
        model.head12,
    )
    for weight in submodule.parameters()
]

# One SGD optimizer per network, sharing lr / momentum / weight decay.
optimizer = torch.optim.SGD(
    parameters,
    lr=base_lr,
    momentum=config.momentum,
    weight_decay=config.weight_decay,
)
optimizer_stud = torch.optim.SGD(
    student.parameters(),
    lr=base_lr,
    momentum=config.momentum,
    weight_decay=config.weight_decay,
)

# lr policy: exponential decay with the same factor for both optimizers.
lr_policy = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.978)
lr_policy_stud = torch.optim.lr_scheduler.ExponentialLR(optimizer_stud, gamma=0.978)


In [5]:
# Pull a single batch from each loader for the smoke-test cells below.
# The builtin next() is used instead of the iterator's `.next()` method:
# Python 3 iterators only guarantee `__next__`, and recent PyTorch DataLoader
# iterators no longer expose a `.next()` alias.
# Each batch is indexed with the 'data' and 'label' keys.

# Batch from the model-weights training split.
minibatch = next(iter(train_loader_model))
imgs = minibatch['data']
labels = minibatch['label']

# Batch from the architecture training split.
minibatch_val = next(iter(train_loader_arch))
imgs_val = minibatch_val['data']
labels_val = minibatch_val['label']

# Batch from the validation loader; its images are the ones named "unlabeled"
# (the labels are still read here into labels_unlabeled).
minibatch_unlabeled = next(iter(valid_loader))
imgs_unlabeled = minibatch_unlabeled['data']
labels_unlabeled = minibatch_unlabeled['label']

In [6]:
# Forward the student on the training batch and compute its cross-entropy loss.
criterion = nn.CrossEntropyLoss()
out = student(imgs)

# Resize the student's 'out' logits to the spatial size of the labels before
# the loss. The size is read from the label tensor (its last two dimensions)
# rather than hard-coded as (224//8, 448//8): the loss requires the two spatial
# sizes to match, so this is the same value wherever the fixed size worked and
# keeps working if the crop size or down-sampling in the config changes.
# assumes labels is laid out as (..., H, W) -- TODO confirm against the loader
logits = nn.functional.interpolate(out['out'], size=tuple(labels.shape[-2:]))
loss = criterion(logits, labels)

In [None]:
# Loss of the search model on the batch drawn from train_loader_model,
# computed through the model's own `_loss` method.
loss_model = model._loss(imgs, labels)

In [None]:
# Call both Architect update routines on the same three batches: the
# train_loader_model batch (imgs, labels), the train_loader_arch batch
# (imgs_val, labels_val) and the images of the valid_loader batch
# (imgs_unlabeled), with the base learning rate and the model optimizer.
# NOTE(review): what each routine updates is defined in architect_lbt and is
# not visible here -- confirm before relying on the call order.
architect.step(imgs, labels, imgs_val, labels_val, imgs_unlabeled, base_lr, optimizer, unrolled=True)
# step1 takes the same arguments plus the student's optimizer.
architect.step1(imgs, labels, imgs_val, labels_val, imgs_unlabeled, base_lr, optimizer, optimizer_stud, unrolled=True)
