Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Updated training code implemented in our IJCAI-2018 paper and the cor…
…responding readme.
- Loading branch information
Showing
8 changed files
with
807 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
import sys | ||
import argparse | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
|
||
def read_file(file_name): | ||
file=open(file_name) | ||
contents=file.readlines() | ||
train_result_dict={} | ||
val_result_dict={} | ||
for oneline in contents: | ||
array=oneline.split() | ||
if len(array)<4: | ||
continue | ||
if len(array)<6: | ||
epoch=int(array[3].replace(']','').split('[')[-1])+1 | ||
if not val_result_dict.has_key(epoch): | ||
val_result_dict[epoch]={} | ||
val_result_dict[epoch]['loss']=[] | ||
val_result_dict[epoch]['top1']=[] | ||
val_result_dict[epoch]['top5']=[] | ||
key=array[4][:14] | ||
if key=='Validation-cro': | ||
val_result_dict[epoch]['loss'].append(float(array[4].split('=')[-1])) | ||
elif key=='Validation-acc': | ||
val_result_dict[epoch]['top1'].append(float(array[4].split('=')[-1])) | ||
elif key=='Validation-top': | ||
val_result_dict[epoch]['top5'].append(float(array[4].split('=')[-1])) | ||
elif array[5]=='[5000]': | ||
epoch=int(array[3].replace(']','').split('[')[-1])+1 | ||
if not val_result_dict.has_key(epoch): | ||
train_result_dict[epoch]={} | ||
train_result_dict[epoch]['loss']=[] | ||
train_result_dict[epoch]['top1']=[] | ||
train_result_dict[epoch]['top5']=[] | ||
scale= 0.01 if float(array[-1].replace(')','').split('(')[-1])>1 else 1 | ||
train_result_dict[epoch]['loss'].append(float(array[9].replace(')','').split('(')[-1])) | ||
train_result_dict[epoch]['top1'].append(scale*float(array[11].replace(')','').split('(')[-1])) | ||
train_result_dict[epoch]['top5'].append(scale*float(array[-1].replace(')','').split('(')[-1])) | ||
return train_result_dict,val_result_dict | ||
|
||
def get_paper(style='acc'): | ||
x=[1,5,10,15,20,25,30, | ||
31,35,40,45,50,55,60, | ||
61,65,70,75,80,85,90] | ||
top1_val_coarse=[89.3,46.5,45,44.3,42.8,42,41.7, | ||
32, 29, 29,29.2,29,28.8,29, | ||
26.2,26.1,26,25.8,25.6,25.0,24.7]; | ||
top1_train_coarse=[95,55,50,48,47,46.5,46, | ||
34,32,30.5,30,30,30,30, | ||
24,23.5,22.8,21,20.5,20.3,20.1]; | ||
xvals = np.linspace(1, 90, 90) | ||
top1_val_fit = np.interp(xvals, x, top1_val_coarse) | ||
top1_train_fit = np.interp(xvals, x, top1_train_coarse) | ||
#return | ||
bias=0.0 | ||
scale=1.0 | ||
best=np.argmin | ||
if style=='acc': | ||
bias=100.0 | ||
scale*=-1.0 | ||
best=np.argmax | ||
top1_val=[bias+scale*value for value in top1_val_fit] | ||
top5_val=[bias+scale*7.8]*len(top1_val) | ||
top1_train=[bias+scale*value for value in top1_train_fit] | ||
top5_train=[-1.0]*len(top1_train) | ||
return top1_val,top5_val,top1_train,top5_train,best | ||
|
||
def print_result(result_dict,prefix,style='acc'): | ||
top1=[] | ||
top5=[] | ||
loss=[] | ||
bias=0.0 | ||
scale=100.0 | ||
best=np.argmax | ||
if style=='err': | ||
bias=100.0 | ||
scale=-100.0 | ||
best=np.argmin | ||
length=len(result_dict) | ||
total_epochs=100 | ||
length=total_epochs if length>total_epochs else length | ||
for epoch in range(1,length+1): | ||
res=result_dict[epoch] | ||
top1.append(bias+scale*res['top1'][-1]) | ||
top5.append(bias+scale*res['top5'][-1]) | ||
loss.append(res['loss'][-1]) | ||
# print(' * %s epoch # %02d top1: %7.3f top5: %7.3f loss: %7.3f'%\ | ||
# (prefix, epoch,top1[epoch-1],top5[epoch-1],loss[epoch-1])) | ||
return top1,top5,loss,best | ||
|
||
def main(argv): | ||
parser = argparse.ArgumentParser() | ||
# Required arguments: input and output files. | ||
parser.add_argument( | ||
"--net", | ||
default='all', | ||
help="Network to show: plain, origin, fuse[1-3], and all." + | ||
"For example: 'python script.py fuse3'." | ||
) | ||
parser.add_argument( | ||
"--style", | ||
default='acc', | ||
help="Style to show the accuracy: acc or err (default: acc)." + | ||
"For example: accuracy=100.0%% for 'acc' and err=0.0%% for 'err'." | ||
) | ||
parser.add_argument( | ||
"--dir", | ||
default='snapshot', | ||
help="Directory of the result txt file (default: snapshot)." | ||
) | ||
parser.add_argument( | ||
"--plot", | ||
default='val', | ||
help="plot train or val, or both of them (default: all)." | ||
) | ||
args = parser.parse_args() | ||
#whole | ||
if '50_4' in args.net: | ||
networks=['dfn-mr_50_4gpu','resnet_50_4gpu'] | ||
elif '50' in args.net: | ||
networks=['dfn-mr_50_8gpu','resnet_50_8gpu'] | ||
elif '101' in args.net: | ||
networks=['dfn-mr_101_8gpu','resnet_101_8gpu'] | ||
#'dfn-mr_50_8gpu','resnet_50_8gpu','origin_resnet_50' | ||
#'dfn-mr_50_4gpu','resnet_50_4gpu','origin_resnet_50' | ||
#'dfn-mr_101_8gpu','resnet_101_8gpu','origin_resnet_101'] | ||
else: | ||
networks=['dfn-mr_101_8gpu','resnet_101_8gpu'] | ||
#figure params | ||
colors=['black','blue','orange','green','red','cyan','pink'] | ||
plt.figure() | ||
#one net | ||
print args.net | ||
import glob | ||
all_logs=glob.glob('snapshot/*/*/*.txt') | ||
for idx in range(len(networks)): | ||
net=networks[idx] | ||
if 'paper' not in net: | ||
for log_name in all_logs: | ||
if net in log_name: | ||
file_name=log_name | ||
print file_name | ||
train_res,val_res=read_file(file_name) | ||
else: | ||
top1_val,top5_val,top1_train,top5_train,best=get_paper(args.style) | ||
if 'middle' in net: | ||
net=net.replace('middle','dfn-mr') | ||
title='(solid lines: 1-crop val error; dashed lines: training error)' | ||
if args.plot=='train' or args.plot=='all': | ||
if net!='paper': | ||
top1,top5,loss,best=print_result(train_res,'Finished',args.style) | ||
report_idx=-1#best(top1) | ||
plot1=plt.plot(range(1,len(top1)+1),top1,color=colors[idx], linestyle=':', linewidth=2.0, alpha=1 if args.plot=='train' else 0.5, | ||
marker='o' if args.plot=='train' else None, | ||
label='%6s top1:%7.2f%% top5:%7.2f%%'%(net,top1[report_idx],top5[report_idx]) if args.plot=='train' else None) | ||
|
||
if args.plot=='val' or args.plot=='all': | ||
if net!='paper': | ||
top1,top5,loss,best=print_result(val_res,'Finished',args.style) | ||
elif net=='paper': | ||
top1=top1_val | ||
top5=top5_val | ||
if args.plot=='val': | ||
title='(validation error)' | ||
report_idx=-1#best(top1) | ||
plot2=plt.plot(range(1,len(top1)+1),top1,color=colors[idx], linestyle='-', | ||
marker='o' if args.plot=='val' else None,linewidth=2.0, | ||
label='%6s (%d) top1:%7.2f%% top5:%7.2f%%'%(net,len(top1),top1[report_idx],top5[report_idx])) | ||
#other figure style | ||
plt.ylim([15,90]) | ||
plt.xticks(np.arange(0, 50, 10)) | ||
#title | ||
plt.title('Performance of different architectures on ImageNet %s'%(title),fontsize=28) | ||
plt.ylabel('top1 error',fontsize=24) | ||
plt.xlabel('epoch',fontsize=24) | ||
plt.legend(prop={'size':24,'family':'monospace'}) | ||
plt.grid(True) | ||
plt.show() | ||
|
||
if __name__ == "__main__": | ||
main(sys.argv) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import mxnet as mx | ||
import math | ||
import random | ||
|
||
|
||
def get_conv(name, data, kout, kernel, stride, pad, relu=True): | ||
conv = mx.symbol.Convolution(name=name, data=data, num_filter=kout, kernel=kernel, stride=stride, pad=pad, no_bias=True) | ||
bn = mx.symbol.BatchNorm(name=name + '_bn', data=conv, fix_gamma=False, momentum=0.9, eps=2e-5) | ||
return (mx.symbol.Activation(name=name + '_relu', data=bn, act_type='relu') if relu else bn) | ||
|
||
def get_pre(name, data, kout, kernel, stride, pad, relu=True): | ||
data = mx.symbol.BatchNorm(name=name + '_bn', data=data, fix_gamma=False, momentum=0.9, eps=2e-5) | ||
data = mx.symbol.Activation(name=name + '_relu', data=data, act_type='relu') | ||
conv = mx.symbol.Convolution(name=name, data=data, num_filter=kout, kernel=kernel, stride=stride, pad=pad, no_bias=True) | ||
return conv | ||
|
||
def get_deep(name, data, kin, kout, stride,relu=True): | ||
conv1 = get_conv(name=name+'_conv1', data=data , kout=kout, kernel=(3, 3), stride=stride, pad=(1, 1)) | ||
conv2 = get_conv(name=name+'_conv2', data=conv1, kout=kout, kernel=(3, 3), stride=(1, 1), pad=(1, 1),relu=relu) | ||
return conv2 | ||
|
||
def get_deep2(name, data, kin, kout, stride,relu=True): | ||
conv = mx.symbol.Convolution(name=name+'_conv1', data=data, num_filter=kout, kernel=(3,3), stride=stride, pad=(1,1), no_bias=True) | ||
conv = mx.symbol.BatchNorm(name=name + '_bn1', data=conv, fix_gamma=False, momentum=0.9, eps=2e-5) | ||
conv = mx.symbol.Activation(name=name + '_relu1', data=conv, act_type='relu') | ||
|
||
conv = mx.symbol.Convolution(name=name+'_conv2', data=conv, num_filter=kout, kernel=(3,3), stride=(1,1), pad=(1,1), no_bias=True) | ||
conv = mx.symbol.BatchNorm(name=name + '_bn2', data=conv, fix_gamma=False, momentum=0.9, eps=2e-5) | ||
return conv | ||
|
||
def get_deep2bott(name, data, kin, kout, stride,relu=True): | ||
conv = mx.symbol.Convolution(name=name+'_conv1', data=data, num_filter=int(kout/4), kernel=(1,1), stride=(1,1), pad=(0,0), no_bias=True) | ||
conv = mx.symbol.BatchNorm(name=name + '_bn1', data=conv, fix_gamma=False, momentum=0.9, eps=2e-5) | ||
conv = mx.symbol.Activation(name=name + '_relu1', data=conv, act_type='relu') | ||
conv = mx.symbol.Convolution(name=name+'_conv2', data=conv, num_filter=int(kout/4), kernel=(3,3), stride=stride, pad=(1,1), no_bias=True) | ||
conv = mx.symbol.BatchNorm(name=name + '_bn2', data=conv, fix_gamma=False, momentum=0.9, eps=2e-5) | ||
conv = mx.symbol.Activation(name=name + '_relu2', data=conv, act_type='relu') | ||
conv = mx.symbol.Convolution(name=name+'_conv3', data=conv, num_filter=kout, kernel=(1,1), stride=(1,1), pad=(0,0), no_bias=True) | ||
conv = mx.symbol.BatchNorm(name=name + '_bn3', data=conv, fix_gamma=False, momentum=0.9, eps=2e-5) | ||
|
||
return conv | ||
|
||
|
||
def get_fusion(name, datal, datar, kin, kout, stride): | ||
if kin == kout: | ||
shortcut1l = datal | ||
deep1l = get_deep2bott(name+'_deep1l', datal, kin, kout, stride) | ||
|
||
shortcut1r = datar | ||
deep1r = get_deep2bott(name+'_deep1r', datar, kin, kout, stride) | ||
|
||
fuse = 0.5 * (shortcut1l + shortcut1r) | ||
|
||
datal=deep1l+fuse | ||
datar=deep1r+fuse | ||
else: | ||
datal = get_deep2bott(name+'_deep1l', datal, kin, kout, stride) | ||
datar = get_deep2bott(name+'_deep1r', datar, kin, kout, stride) | ||
|
||
|
||
datal = mx.symbol.Activation(name=name+'_relu', data=datal, act_type='relu') | ||
datar = mx.symbol.Activation(name=name+'_relu', data=datar, act_type='relu') | ||
return datal, datar | ||
|
||
def get_group(name, datal, datar, num_block, kin, kout, stride): | ||
for idx in range(num_block): | ||
datal, datar = get_fusion(name=name+'_b%d'%(idx+1), datal=datal, datar=datar, kin=kin, kout=kout, stride=stride if idx == 0 else (1, 1)) | ||
kin = kout | ||
return datal, datar | ||
|
||
def get_symbol(num_classes=1000): | ||
# setup model parameters | ||
blocks_num = (1,2,11,2) | ||
channels = 64 | ||
# start network definition | ||
data = mx.symbol.Variable(name='data') | ||
# stage conv1_x | ||
conv1 = mx.symbol.Convolution(name='g0', data=data, num_filter=channels, kernel=(7,7), stride=(2,2), pad=(3,3), no_bias=True) | ||
pool1 = mx.symbol.Pooling(name='g0_pool', data=conv1, kernel=(3, 3), stride=(2, 2), pad=(1, 1), pool_type='max') | ||
# stage conv2_x, conv3_x, conv4_x, conv5_x | ||
conv2_x_1, conv2_x_2 = get_group(name='g1', datal=pool1, datar=pool1, num_block=blocks_num[0], kin=channels, kout=channels*4, stride=(1,1)) | ||
|
||
conv3_x_1, conv3_x_2 = get_group(name='g2', datal=conv2_x_1, datar=conv2_x_2, num_block=blocks_num[1], kin=channels*4, kout=channels*8, stride=(2,2)) | ||
|
||
conv4_x_1, conv4_x_2 = get_group(name='g3', datal=conv3_x_1, datar=conv3_x_2, num_block=blocks_num[2], kin=channels*8, kout=channels*16, stride=(2,2)) | ||
|
||
conv5_x_1, conv5_x_2 = get_group(name='g4', datal=conv4_x_1, datar=conv4_x_2, num_block=blocks_num[3], kin=channels*16, kout=channels*32, stride=(2,2)) | ||
|
||
conv5_x= mx.symbol.Concat(conv5_x_1,conv5_x_2) | ||
avg = mx.symbol.Pooling(name='global_pool', data=conv5_x, kernel=(7, 7), stride=(1, 1), pool_type='avg') | ||
flatten = mx.sym.Flatten(name="flatten", data=avg) | ||
fc = mx.symbol.FullyConnected(name='fc_score', data=flatten, num_hidden=num_classes) | ||
softmax = mx.symbol.SoftmaxOutput(name='softmax', data=fc) | ||
|
||
return softmax |
Oops, something went wrong.