replace upx8 by ConvTranspose2d, 4% performance boost to 79%
windyrobin committed May 30, 2019
1 parent 1a420d4 commit bd3085e
Showing 9 changed files with 143 additions and 25 deletions.
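In short, the commit replaces the heads' single parameter-free ×8 upsample (presumably the "upx8" of the commit title) with a stack of three learned ×2 transposed convolutions, wrapped in the new DeConvBnRelu/XmHead modules shown below. A minimal sketch of the idea, with shapes assumed for a 1024×2048 Cityscapes input:

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 256, 128, 256)  # an H/8 x W/8 feature map

# Before: one parameter-free x8 interpolation back to full resolution.
up = F.interpolate(x, scale_factor=8, mode='bilinear', align_corners=True)

# After: three stride-2 transposed convolutions (channels 256->128->64->32,
# matching the DeConvBnRelu stack in the diff below), each doubling H and W.
deconvs = nn.Sequential(
    nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),
    nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
    nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
)
print(up.shape)          # torch.Size([1, 256, 1024, 2048])
print(deconvs(x).shape)  # torch.Size([1, 32, 1024, 2048])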
18 changes: 10 additions & 8 deletions furnace/datasets/cityscapes/cityscapes.py
@@ -18,16 +18,18 @@ def get_class_colors(*args):
         return [[128, 64, 128], [244, 35, 232], [70, 70, 70],
                 [102, 102, 156], [190, 153, 153], [153, 153, 153],
                 [250, 170, 30], [220, 220, 0], [107, 142, 35],
-                [152, 251, 152], [70, 130, 180], [220, 20, 60], [255, 0, 0],
-                [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
-                [0, 0, 230], [119, 11, 32]]
+                [152, 251, 152], [70, 130, 180], [220, 20, 60],
+                [255, 0, 0], [0, 0, 142], [0, 0, 70],
+                [0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32]]
 
     @classmethod
     def get_class_names(*args):
-        return ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
-                'traffic light', 'traffic sign',
-                'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
-                'truck', 'bus', 'train', 'motorcycle', 'bicycle']
+        return ['road', 'sidewalk', 'building',
+                'wall', 'fence', 'pole',
+                'traffic light', 'traffic sign', 'vegetation',
+                'terrain', 'sky', 'person',
+                'rider', 'car', 'truck',
+                'bus', 'train', 'motorcycle', 'bicycle']
 
     @classmethod
     def transform_label(cls, pred, name):
@@ -39,4 +41,4 @@ def transform_label(cls, pred, name):
         new_name = (name.split('.')[0]).split('_')[:-1]
         new_name = '_'.join(new_name) + '.png'
 
-        return label, new_name
+        return label, new_name
14 changes: 9 additions & 5 deletions furnace/engine/inferencer.py
@@ -11,11 +11,11 @@
 from utils.pyt_utils import load_model, link_file, ensure_dir
 from utils.img_utils import pad_image_to_shape, normalize
 
-import tensorrt as trt
-import example
-import common
-import pycuda.driver as cuda
-import pycuda.autoinit
+#import tensorrt as trt
+#import example
+#import common
+#import pycuda.driver as cuda
+#import pycuda.autoinit

logger = get_logger()

@@ -252,7 +252,11 @@ def val_func_process(self, input_data, device=None):
                                           dtype=np.float32)
         print("inputs:")
         print('input shape:', input_data.shape)
+        print('self.iter:', self.iter)
+        self.print_statics(input_data)
+        input_data.tofile('sample_bin/np_' + str(self.iter) + '.bin')
+        self.iter = self.iter + 1
 
         input_data = torch.FloatTensor(input_data).cuda(device)
 
         with torch.cuda.device(input_data.get_device()):
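The added lines dump each preprocessed input to a raw .bin file, presumably for offline comparison against the (now commented-out) TensorRT path. A minimal sketch of reading such a dump back, assuming float32 data and a known input shape:

import numpy as np

# Hypothetical shape; use whatever val_func_process actually fed the network.
INPUT_SHAPE = (1, 3, 1024, 2048)

def load_dump(path, shape=INPUT_SHAPE):
    # ndarray.tofile writes raw bytes with no header, so dtype and shape
    # must be supplied when reading the dump back.
    return np.fromfile(path, dtype=np.float32).reshape(shape)

sample = load_dump('sample_bin/np_0.bin')
print(sample.mean(), sample.std())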
7 changes: 7 additions & 0 deletions furnace/seg_opr/loss_opr.py
@@ -64,12 +64,19 @@ def __init__(self, ignore_label, reduction='mean', thresh=0.6, min_kept=256,
                                              ignore_index=ignore_label)
 
     def forward(self, pred, target):
+        #print('pred shape:')
+        #print(pred.shape)
+        #print('target shape:')
+        #print(target.shape)
+
         b, c, h, w = pred.size()
         target = target.view(-1)
         valid_mask = target.ne(self.ignore_label)
         target = target * valid_mask.long()
         num_valid = valid_mask.sum()
 
+        #print('num valid:')
+        #print(num_valid)
         prob = F.softmax(pred, dim=1)
         prob = (prob.transpose(0, 1)).reshape(c, -1)
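For context on the truncated hunk: with thresh and min_kept in its constructor, this criterion is an online hard example mining (OHEM) loss, keeping only pixels whose predicted probability for the ground-truth class is low. A hedged sketch of the usual selection step, not necessarily the repository's exact code:

import torch
import torch.nn.functional as F

def ohem_pixel_mask(pred, target, ignore_label, thresh=0.6, min_kept=256):
    # pred: (B, C, H, W) logits; target: (B, H, W) integer labels.
    b, c, h, w = pred.size()
    target = target.view(-1)
    valid_mask = target.ne(ignore_label)
    target = target * valid_mask.long()

    prob = F.softmax(pred, dim=1)
    prob = prob.transpose(0, 1).reshape(c, -1)            # (C, N)
    gt_prob = prob[target, torch.arange(target.numel())]  # prob of the GT class
    gt_prob = gt_prob.masked_fill(~valid_mask, 1.0)       # never select ignored pixels

    # Keep pixels under thresh, raising the bar if that keeps fewer than min_kept.
    n_valid = int(valid_mask.sum())
    if min_kept > 0 and n_valid > min_kept:
        kth = gt_prob.sort().values[min_kept - 1]
        thresh = max(thresh, float(kth))
    return valid_mask & gt_prob.le(thresh)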
37 changes: 37 additions & 0 deletions furnace/seg_opr/seg_oprs.py
@@ -47,6 +47,43 @@ def forward(self, x):
         #print ('running weight :')
         #print (self.bn.weight)
         return [x, x1, x2]
+
+class DeConvBnRelu(nn.Module):
+    def __init__(self, in_planes, out_planes, ksize, stride, pad, output_pad, dilation=1,
+                 groups=1, has_bn=True, norm_layer=nn.BatchNorm2d, bn_eps=1e-5,
+                 has_relu=True, inplace=True, has_bias=False, debug=False):
+        super(DeConvBnRelu, self).__init__()
+        self.conv = nn.ConvTranspose2d(in_planes, out_planes, kernel_size=ksize,
+                                       stride=stride, padding=pad, output_padding=output_pad,
+                                       dilation=dilation, groups=groups, bias=has_bias)
+        self.has_bn = has_bn
+        if self.has_bn:
+            self.bn = norm_layer(out_planes, eps=bn_eps)
+        self.has_relu = has_relu
+        if self.has_relu:
+            self.relu = nn.ReLU(inplace=inplace)
+        self.debug = debug
+
+    def forward(self, x):
+        if self.debug == False:
+            x = self.conv(x)
+            if self.has_bn:
+                x = self.bn(x)
+            if self.has_relu:
+                x = self.relu(x)
+
+            return x
+        else:
+            x = self.conv(x)
+            x1 = self.bn(x)
+            x2 = self.relu(x1)
+            #print ('running meaning :')
+            #print (self.bn.running_mean)
+            #print ('running var :')
+            #print (self.bn.running_var)
+            #print ('running weight :')
+            #print (self.bn.weight)
+            return [x, x1, x2]
 
 
 
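One detail worth calling out: with ksize=3, stride=2, pad=1, output_pad=1, the transposed convolution's output size is (H − 1)·stride − 2·pad + ksize + output_pad = 2H, so each DeConvBnRelu exactly doubles the spatial resolution. A quick check:

import torch
import torch.nn as nn

deconv = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2,
                            padding=1, output_padding=1, bias=False)
x = torch.randn(1, 128, 32, 64)
print(deconv(x).shape)  # torch.Size([1, 64, 64, 128]) -- H and W doubled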
2 changes: 1 addition & 1 deletion furnace/utils/img_utils.py
@@ -182,6 +182,6 @@ def normalize(img, mean, std):
     # pytorch pretrained model need the input range: 0-1
     new_img = img.astype(np.float32) / 255.0
     new_img = new_img - mean
-    # img = img / std
+    img = img / std
 
     return new_img
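Note that the newly uncommented line assigns to img rather than new_img, so the division by std never reaches the returned array. If full mean/std normalization is the intent, the function would presumably read:

import numpy as np

def normalize(img, mean, std):
    # pytorch pretrained models need the input range 0-1
    new_img = img.astype(np.float32) / 255.0
    new_img = new_img - mean
    new_img = new_img / std  # divide the scaled image, not the raw input
    return new_img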
6 changes: 4 additions & 2 deletions furnace/utils/init_func.py
@@ -11,7 +11,8 @@
 def __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum,
                   **kwargs):
     for name, m in feature.named_modules():
-        if isinstance(m, (nn.Conv2d, nn.Conv3d)):
+        #if isinstance(m, (nn.Conv2d, nn.Conv3d)):
+        if isinstance(m, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d)):
             conv_init(m.weight, **kwargs)
         elif isinstance(m, norm_layer):
             m.eps = bn_eps
@@ -39,7 +40,8 @@ def group_weight(weight_group, module, norm_layer, lr):
             group_decay.append(m.weight)
             if m.bias is not None:
                 group_no_decay.append(m.bias)
-        elif isinstance(m, (nn.Conv2d, nn.Conv3d)):
+        #elif isinstance(m, (nn.Conv2d, nn.Conv3d)):
+        elif isinstance(m, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d)):
             group_decay.append(m.weight)
             if m.bias is not None:
                 group_no_decay.append(m.bias)
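Both changes are needed because nn.ConvTranspose2d is not a subclass of nn.Conv2d; without it in the isinstance tuples, the new deconvolution weights would keep PyTorch's default initialization and fall outside the weight-decay group. A quick illustration:

import torch.nn as nn

m = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2,
                       padding=1, output_padding=1)
print(isinstance(m, (nn.Conv2d, nn.Conv3d)))                      # False
print(isinstance(m, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d)))  # True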
2 changes: 1 addition & 1 deletion model/bisenet/cityscapes.bisenet.R18.speed/infer.py
@@ -83,7 +83,7 @@ def func_per_iteration(self, data, device):
 
 network = BiSeNet(config.num_classes, is_training=False,
                   criterion=None, ohem_criterion=None)
-dataset = TestData('./fe_test')
+dataset = TestData('./sample_test')
 
 if args.speed_test:
     device = all_dev[0]
80 changes: 72 additions & 8 deletions model/bisenet/cityscapes.bisenet.R18.speed/network.py
@@ -9,7 +9,7 @@
 import numpy as np
 from config import config
 from base_model import resnet18
-from seg_opr.seg_oprs import ConvBnRelu, AttentionRefinement, FeatureFusion
+from seg_opr.seg_oprs import ConvBnRelu, DeConvBnRelu, AttentionRefinement, FeatureFusion
 
 
 def get():
@@ -50,15 +50,24 @@ def __init__(self, out_planes, is_training,
                                  has_relu=True, has_bias=False)]
 
         if is_training:
-            heads = [BiSeNetHead(conv_channel, out_planes, 16,
+            #heads = [BiSeNetHead(conv_channel, out_planes, 16,
+            #                     True, norm_layer),
+            #         BiSeNetHead(conv_channel, out_planes, 8,
+            #                     True, norm_layer),
+            #         BiSeNetHead(conv_channel * 2, out_planes, 8,
+            #                     False, norm_layer)]
+            heads = [XmHead(conv_channel, out_planes, 16,
                                  True, norm_layer),
-                     BiSeNetHead(conv_channel, out_planes, 8,
+                     XmHead(conv_channel, out_planes, 8,
                                  True, norm_layer),
-                     BiSeNetHead(conv_channel * 2, out_planes, 8,
+                     XmHead(conv_channel * 2, out_planes, 8,
                                  False, norm_layer)]
         else:
+            #heads = [None, None,
+            #         BiSeNetHead(conv_channel * 2, out_planes, 8,
+            #                     False, norm_layer)]
             heads = [None, None,
-                     BiSeNetHead(conv_channel * 2, out_planes, 8,
+                     XmHead(conv_channel * 2, out_planes, 8,
                                  False, norm_layer)]
 
         self.ffm = FeatureFusion(conv_channel * 2, conv_channel * 2,
@@ -119,11 +128,12 @@ def forward(self, data, label=None):
         pred_out.append(concate_fm)
 
         if self.is_training:
-            aux_loss0 = self.ohem_criterion(self.heads[0](pred_out[0]), label)
-            aux_loss1 = self.ohem_criterion(self.heads[1](pred_out[1]), label)
+            #aux_loss0 = self.ohem_criterion(self.heads[0](pred_out[0]), label)
+            #aux_loss1 = self.ohem_criterion(self.heads[1](pred_out[1]), label)
             main_loss = self.ohem_criterion(self.heads[-1](pred_out[2]), label)
 
-            loss = main_loss + aux_loss0 + aux_loss1
+            loss = main_loss
+            #loss = main_loss + aux_loss0 + aux_loss1
             return loss
 
         head_out = self.heads[-1](pred_out[-1])
@@ -215,6 +225,60 @@ def forward(self, x):
 
         return output
 
+class XmHead(nn.Module):
+    def __init__(self, in_planes, out_planes, scale,
+                 is_aux=False, norm_layer=nn.BatchNorm2d):
+        super(XmHead, self).__init__()
+        if is_aux:
+            #assert(scale == 16 and in_planes ==256)
+            dconv0 = DeConvBnRelu(in_planes, 128, 3, 2, 1, 1,
+                                  has_bn=True, norm_layer=norm_layer,
+                                  has_relu=True, has_bias=False)
+
+            dconv1 = DeConvBnRelu(128, 64, 3, 2, 1, 1,
+                                  has_bn=True, norm_layer=norm_layer,
+                                  has_relu=True, has_bias=False)
+            dconv2 = DeConvBnRelu(64, 32, 3, 2, 1, 1,
+                                  has_bn=True, norm_layer=norm_layer,
+                                  has_relu=True, has_bias=False)
+
+            deconv_arr = [dconv0, dconv1, dconv2]
+        else:
+            assert(scale == 8 and in_planes == 256)
+            dconv0 = DeConvBnRelu(in_planes, 128, 3, 2, 1, 1,
+                                  has_bn=True, norm_layer=norm_layer,
+                                  has_relu=True, has_bias=False)
+
+            dconv1 = DeConvBnRelu(128, 64, 3, 2, 1, 1,
+                                  has_bn=True, norm_layer=norm_layer,
+                                  has_relu=True, has_bias=False)
+            dconv2 = DeConvBnRelu(64, 32, 3, 2, 1, 1,
+                                  has_bn=True, norm_layer=norm_layer,
+                                  has_relu=True, has_bias=False)
+
+            deconv_arr = [dconv0, dconv1, dconv2]
+
+        self.deconv_3x3_arr = nn.ModuleList(deconv_arr)
+        # self.dropout = nn.Dropout(0.1)
+        if is_aux:
+            self.conv_1x1 = nn.Conv2d(32, out_planes, kernel_size=1,
+                                      stride=1, padding=0)
+        else:
+            self.conv_1x1 = nn.Conv2d(32, out_planes, kernel_size=1,
+                                      stride=1, padding=0)
+        self.scale = scale
+        self.in_planes = in_planes
+
+    def forward(self, x):
+        #aux not supported now
+        #if self.scale == 8:
+        fm = self.deconv_3x3_arr[0](x)
+        fm = self.deconv_3x3_arr[1](fm)
+        fm = self.deconv_3x3_arr[2](fm)
+        # fm = self.dropout(fm)
+        output = self.conv_1x1(fm)
+        return output
+

if __name__ == "__main__":
model = BiSeNet(22, None)
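As a sanity check on the new head: for the main output (is_aux=False, scale=8), the three stride-2 DeConvBnRelu stages take channels 256→128→64→32 while upsampling ×8, and the final 1×1 conv maps to the class count. A hypothetical smoke test, assuming the definitions above are on the import path:

import torch

head = XmHead(256, 19, scale=8, is_aux=False)  # 19 Cityscapes classes
x = torch.randn(1, 256, 128, 256)              # H/8 x W/8 feature map
print(head(x).shape)                           # torch.Size([1, 19, 1024, 2048])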
2 changes: 2 additions & 0 deletions model/bisenet/cityscapes.bisenet.R18.speed/train.py
@@ -105,6 +105,8 @@
     else:
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         model = DataParallelModel(model, device_ids=engine.devices)
+        print('model to device:')
+        print(device)
         model.to(device)
 
     engine.register_state(dataloader=train_loader, model=model,
