
Commit

Updated training code implemented in our IJCAI-2018 paper and the corresponding readme.
DeppMeng authored and zlmzju committed Jul 12, 2018
1 parent b20f61f commit 659e6bf
Showing 8 changed files with 807 additions and 7 deletions.
23 changes: 16 additions & 7 deletions README.md
@@ -35,23 +35,26 @@ DFN-MR1 | 56 | 1.7M | 4.94 | 24.46 | 1.66
DFN-MR2 | 32 | 14.9M | 3.94 | 19.25 | **1.51**
DFN-MR3 | 50 | 24.8M | **3.57** | **19.00**| 1.55

- Training and validation error (%) on ImageNet:

![imagenet_curve](visualize/paper/imagenet_curve.png)
Method | Depth | #Params | Top-1 train | Top-5 train | Top-1 val | Top-5 val |
--------|:-----:|:-------:|:-----------:|:-----------:|:---------:|:---------:|
ResNet | 98 | 45.0M | 15.09 | 3.25 | 23.38 | 6.79 |
DFN-MR | 50 | 46.4M | **14.46** | **3.16** | **23.16** | **6.61** |
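
The depth of 50 listed for DFN-MR is consistent with the block counts in `symbol_cross.py` from this commit. The check below is a minimal sketch, assuming that depth counts the 7x7 stem convolution, the three convolutions of each bottleneck block along a single branch, and the final fully connected layer; this counting rule is an assumption and is not stated in the README.

```python
# Block counts per stage, copied from `blocks_num` in symbol_cross.py.
blocks_num = (1, 2, 11, 2)

# Assumption: each bottleneck block contributes three convolutions
# (1x1, 3x3, 1x1) along one branch, as built by get_deep2bott.
convs_per_block = 3

# Assumption: the stem convolution and the final FC layer are also counted.
stem_and_fc = 2

depth = convs_per_block * sum(blocks_num) + stem_and_fc
print(depth)  # 50
```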

## Requirements
- Install [MXNet](http://mxnet.readthedocs.io/en/latest/how_to/build.html) on a machine (Windows, Linux, or Mac OS) with a CUDA GPU and, optionally, [cuDNN](https://developer.nvidia.com/cudnn).

- Apply my modified data processing patch on the latest MXNet by merging the pull request:

```shell
git pull origin pull/3936/head master
```

- (Recommended) If you fail to apply the above patch, you can simply use [my MXNet repository](https://github.com/zlmzju/mxnet/tree/fusenet):

```shell
git clone --recursive -b fusenet https://github.com/zlmzju/mxnet.git
```

## How to Train
@@ -60,18 +63,24 @@ Step by step tutorial with jupyter notebook is now available, please check the f

### dataset
You can prepare the `*.rec` file by yourself, or simply download the `Cifar` dataset from [data.dmlc.ml](http://data.dmlc.ml/mxnet/data/) or my [google drive](https://drive.google.com/open?id=0By55MQnF3PHCQmRhRTBuWk5DRkk) (recommended), which includes both `Cifar` and `SVHN` datasets.
For the ImageNet dataset, follow the [official MXNet document](http://mxnet.incubator.apache.org/tutorials/vision/large_scale_classification.html?highlight=imagenet) to prepare it.

### training
The current code supports training different deeply-fused nets on Cifar-10, Cifar-100, SVHN and ImageNet, such as the `plain` network, `resnet`, `cross` (dfn-mr), `half` (dfn-il), `side` (dfn-il without identities), `fuse3` (three fusions), `fuse6` (six fusions), and `ensemble` (with shared weights; training code will come later). All the networks are contained in the `network` folder.

For example, the following command trains the `DFN-MR` network (called `cross` in the code) on Cifar-10.

```shell
python train_model.py --dataset=cifar10 --network=cross --depth=56 --gpus=0,1 --data-dir=<dataset location>
```

To train the DFN-MR network on ImageNet, run:

```shell
python train_imagenet.py --network=symbol_cross --gpus=0,1,2,3 --data-dir=<dataset location>
```

## Other usages

### visualization
182 changes: 182 additions & 0 deletions show_results.py
@@ -0,0 +1,182 @@
import sys
import argparse
import numpy as np
import matplotlib.pyplot as plt

def read_file(file_name):
    # Parse a training log into per-epoch metrics.
    # Returns two dicts keyed by 1-based epoch; each value holds the lists
    # 'loss', 'top1' and 'top5'.
    with open(file_name) as log_file:
        contents=log_file.readlines()
    train_result_dict={}
    val_result_dict={}
    for oneline in contents:
        array=oneline.split()
        if len(array)<4:
            continue
        if len(array)<6:
            # Validation line: the fourth field carries the 0-based epoch in brackets.
            epoch=int(array[3].replace(']','').split('[')[-1])+1
            if epoch not in val_result_dict:
                val_result_dict[epoch]={}
                val_result_dict[epoch]['loss']=[]
                val_result_dict[epoch]['top1']=[]
                val_result_dict[epoch]['top5']=[]
            key=array[4][:14]
            if key=='Validation-cro':
                val_result_dict[epoch]['loss'].append(float(array[4].split('=')[-1]))
            elif key=='Validation-acc':
                val_result_dict[epoch]['top1'].append(float(array[4].split('=')[-1]))
            elif key=='Validation-top':
                val_result_dict[epoch]['top5'].append(float(array[4].split('=')[-1]))
        elif array[5]=='[5000]':
            # Training line logged at batch 5000 of the epoch.
            epoch=int(array[3].replace(']','').split('[')[-1])+1
            # Fixed: the original tested val_result_dict here.
            if epoch not in train_result_dict:
                train_result_dict[epoch]={}
                train_result_dict[epoch]['loss']=[]
                train_result_dict[epoch]['top1']=[]
                train_result_dict[epoch]['top5']=[]
            # Values above 1 are treated as percentages and rescaled to [0,1].
            scale= 0.01 if float(array[-1].replace(')','').split('(')[-1])>1 else 1
            train_result_dict[epoch]['loss'].append(float(array[9].replace(')','').split('(')[-1]))
            train_result_dict[epoch]['top1'].append(scale*float(array[11].replace(')','').split('(')[-1]))
            train_result_dict[epoch]['top5'].append(scale*float(array[-1].replace(')','').split('(')[-1]))
    return train_result_dict,val_result_dict

def get_paper(style='acc'):
    # Coarse top-1 error (%) at selected epochs; the remaining epochs are
    # filled in by linear interpolation.
    x=[1,5,10,15,20,25,30,
       31,35,40,45,50,55,60,
       61,65,70,75,80,85,90]
    top1_val_coarse=[89.3,46.5,45,44.3,42.8,42,41.7,
                     32, 29, 29,29.2,29,28.8,29,
                     26.2,26.1,26,25.8,25.6,25.0,24.7]
    top1_train_coarse=[95,55,50,48,47,46.5,46,
                       34,32,30.5,30,30,30,30,
                       24,23.5,22.8,21,20.5,20.3,20.1]
    xvals = np.linspace(1, 90, 90)
    top1_val_fit = np.interp(xvals, x, top1_val_coarse)
    top1_train_fit = np.interp(xvals, x, top1_train_coarse)
    #return
    bias=0.0
    scale=1.0
    best=np.argmin
    if style=='acc':
        # Convert error (%) to accuracy (%): 100 - error.
        bias=100.0
        scale*=-1.0
        best=np.argmax
    top1_val=[bias+scale*value for value in top1_val_fit]
    top5_val=[bias+scale*7.8]*len(top1_val)
    top1_train=[bias+scale*value for value in top1_train_fit]
    top5_train=[-1.0]*len(top1_train)
    return top1_val,top5_val,top1_train,top5_train,best

def print_result(result_dict,prefix,style='acc'):
    # Collect the last logged top1/top5/loss of each epoch, as accuracy (%)
    # for style 'acc' or as error (%) for style 'err'.
    top1=[]
    top5=[]
    loss=[]
    bias=0.0
    scale=100.0
    best=np.argmax
    if style=='err':
        bias=100.0
        scale=-100.0
        best=np.argmin
    length=len(result_dict)
    total_epochs=100
    length=total_epochs if length>total_epochs else length
    for epoch in range(1,length+1):
        res=result_dict[epoch]
        top1.append(bias+scale*res['top1'][-1])
        top5.append(bias+scale*res['top5'][-1])
        loss.append(res['loss'][-1])
        # print(' * %s epoch # %02d top1: %7.3f top5: %7.3f loss: %7.3f'%\
        # (prefix, epoch,top1[epoch-1],top5[epoch-1],loss[epoch-1]))
    return top1,top5,loss,best

def main(argv):
    parser = argparse.ArgumentParser()
    # Command-line options (all optional).
    parser.add_argument(
        "--net",
        default='all',
        help="Logs to show: a value containing '50_4', '50' or '101' selects " +
             "the matching dfn-mr/resnet pair; any other value (default: all) " +
             "shows the 101-layer 8-GPU pair."
    )
    parser.add_argument(
        "--style",
        default='acc',
        help="Style to show the accuracy: acc or err (default: acc). " +
             "For example: accuracy=100.0%% for 'acc' and err=0.0%% for 'err'."
    )
    parser.add_argument(
        "--dir",
        default='snapshot',
        help="Directory of the result txt file (default: snapshot)."
    )
    parser.add_argument(
        "--plot",
        default='val',
        help="Plot train, val, or both of them with all (default: val)."
    )
    args = parser.parse_args()
    #whole
    if '50_4' in args.net:
        networks=['dfn-mr_50_4gpu','resnet_50_4gpu']
    elif '50' in args.net:
        networks=['dfn-mr_50_8gpu','resnet_50_8gpu']
    elif '101' in args.net:
        networks=['dfn-mr_101_8gpu','resnet_101_8gpu']
        #'dfn-mr_50_8gpu','resnet_50_8gpu','origin_resnet_50'
        #'dfn-mr_50_4gpu','resnet_50_4gpu','origin_resnet_50'
        #'dfn-mr_101_8gpu','resnet_101_8gpu','origin_resnet_101']
    else:
        networks=['dfn-mr_101_8gpu','resnet_101_8gpu']
    #figure params
    colors=['black','blue','orange','green','red','cyan','pink']
    plt.figure()
    #one net
    print(args.net)
    import glob
    # Honor --dir instead of the hard-coded 'snapshot' (same path by default).
    all_logs=glob.glob(args.dir+'/*/*/*.txt')
    for idx in range(len(networks)):
        net=networks[idx]
        if 'paper' not in net:
            # Use the last log whose path contains the network name.
            file_name=None
            for log_name in all_logs:
                if net in log_name:
                    file_name=log_name
            if file_name is None:
                raise IOError('no log file matching %r under %s'%(net,args.dir))
            print(file_name)
            train_res,val_res=read_file(file_name)
        else:
            top1_val,top5_val,top1_train,top5_train,best=get_paper(args.style)
        if 'middle' in net:
            net=net.replace('middle','dfn-mr')
        title='(solid lines: 1-crop val error; dashed lines: training error)'
        if args.plot=='train' or args.plot=='all':
            if net!='paper':
                top1,top5,loss,best=print_result(train_res,'Finished',args.style)
            report_idx=-1#best(top1)
            plot1=plt.plot(range(1,len(top1)+1),top1,color=colors[idx], linestyle=':', linewidth=2.0, alpha=1 if args.plot=='train' else 0.5,
                           marker='o' if args.plot=='train' else None,
                           label='%6s top1:%7.2f%% top5:%7.2f%%'%(net,top1[report_idx],top5[report_idx]) if args.plot=='train' else None)

        if args.plot=='val' or args.plot=='all':
            if net!='paper':
                top1,top5,loss,best=print_result(val_res,'Finished',args.style)
            elif net=='paper':
                top1=top1_val
                top5=top5_val
            if args.plot=='val':
                title='(validation error)'
            report_idx=-1#best(top1)
            plot2=plt.plot(range(1,len(top1)+1),top1,color=colors[idx], linestyle='-',
                           marker='o' if args.plot=='val' else None,linewidth=2.0,
                           label='%6s (%d) top1:%7.2f%% top5:%7.2f%%'%(net,len(top1),top1[report_idx],top5[report_idx]))
    #other figure style
    plt.ylim([15,90])
    plt.xticks(np.arange(0, 50, 10))
    #title
    plt.title('Performance of different architectures on ImageNet %s'%(title),fontsize=28)
    plt.ylabel('top1 error',fontsize=24)
    plt.xlabel('epoch',fontsize=24)
    plt.legend(prop={'size':24,'family':'monospace'})
    plt.grid(True)
    plt.show()

if __name__ == "__main__":
    main(sys.argv)
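
The validation branch of `read_file` can be illustrated on a single line. The following is a minimal sketch, not the script itself: the sample log line, including the `Node[0]` and `Epoch[3]` fields, is a hypothetical layout chosen only so that the epoch sits in the fourth whitespace-separated field and the metric in the fifth, which is what the indexing in `read_file` expects. The helper name `parse_validation_line` is likewise made up for this example.

```python
# Metric prefixes (first 14 characters) mapped to the keys used by read_file.
PREFIX_TO_KEY = {
    'Validation-cro': 'loss',
    'Validation-acc': 'top1',
    'Validation-top': 'top5',
}

def parse_validation_line(line):
    """Return (epoch, key, value) for a five-field validation line, else None."""
    fields = line.split()
    if len(fields) != 5:
        return None
    # 'Epoch[3]' -> 'Epoch[3' -> '3' -> 3, then shift to a 1-based epoch.
    epoch = int(fields[3].replace(']', '').split('[')[-1]) + 1
    key = PREFIX_TO_KEY.get(fields[4][:14])
    if key is None:
        return None
    value = float(fields[4].split('=')[-1])
    return epoch, key, value

# Hypothetical log line; the real log layout is not shown in this commit.
sample = '2018-07-12 10:00:00,000 Node[0] Epoch[3] Validation-accuracy=0.768400'
print(parse_validation_line(sample))
```

Unlike `read_file`, the sketch rejects four-field lines outright instead of indexing into a missing fifth field.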
95 changes: 95 additions & 0 deletions symbol_cross.py
@@ -0,0 +1,95 @@
import mxnet as mx
import math
import random


def get_conv(name, data, kout, kernel, stride, pad, relu=True):
    # Convolution -> BatchNorm -> optional ReLU.
    conv = mx.symbol.Convolution(name=name, data=data, num_filter=kout, kernel=kernel, stride=stride, pad=pad, no_bias=True)
    bn = mx.symbol.BatchNorm(name=name + '_bn', data=conv, fix_gamma=False, momentum=0.9, eps=2e-5)
    return (mx.symbol.Activation(name=name + '_relu', data=bn, act_type='relu') if relu else bn)

def get_pre(name, data, kout, kernel, stride, pad, relu=True):
    # Pre-activation variant: BatchNorm -> ReLU -> Convolution.
    data = mx.symbol.BatchNorm(name=name + '_bn', data=data, fix_gamma=False, momentum=0.9, eps=2e-5)
    data = mx.symbol.Activation(name=name + '_relu', data=data, act_type='relu')
    conv = mx.symbol.Convolution(name=name, data=data, num_filter=kout, kernel=kernel, stride=stride, pad=pad, no_bias=True)
    return conv

def get_deep(name, data, kin, kout, stride,relu=True):
    conv1 = get_conv(name=name+'_conv1', data=data , kout=kout, kernel=(3, 3), stride=stride, pad=(1, 1))
    conv2 = get_conv(name=name+'_conv2', data=conv1, kout=kout, kernel=(3, 3), stride=(1, 1), pad=(1, 1),relu=relu)
    return conv2

def get_deep2(name, data, kin, kout, stride,relu=True):
    # Two 3x3 convolutions, each followed by BatchNorm; no ReLU after the second.
    conv = mx.symbol.Convolution(name=name+'_conv1', data=data, num_filter=kout, kernel=(3,3), stride=stride, pad=(1,1), no_bias=True)
    conv = mx.symbol.BatchNorm(name=name + '_bn1', data=conv, fix_gamma=False, momentum=0.9, eps=2e-5)
    conv = mx.symbol.Activation(name=name + '_relu1', data=conv, act_type='relu')

    conv = mx.symbol.Convolution(name=name+'_conv2', data=conv, num_filter=kout, kernel=(3,3), stride=(1,1), pad=(1,1), no_bias=True)
    conv = mx.symbol.BatchNorm(name=name + '_bn2', data=conv, fix_gamma=False, momentum=0.9, eps=2e-5)
    return conv

def get_deep2bott(name, data, kin, kout, stride,relu=True):
    # Bottleneck: 1x1 reduce to kout/4, 3x3 (carries the stride), 1x1 expand to kout.
    # Each convolution is followed by BatchNorm; no ReLU after the last one.
    conv = mx.symbol.Convolution(name=name+'_conv1', data=data, num_filter=int(kout/4), kernel=(1,1), stride=(1,1), pad=(0,0), no_bias=True)
    conv = mx.symbol.BatchNorm(name=name + '_bn1', data=conv, fix_gamma=False, momentum=0.9, eps=2e-5)
    conv = mx.symbol.Activation(name=name + '_relu1', data=conv, act_type='relu')
    conv = mx.symbol.Convolution(name=name+'_conv2', data=conv, num_filter=int(kout/4), kernel=(3,3), stride=stride, pad=(1,1), no_bias=True)
    conv = mx.symbol.BatchNorm(name=name + '_bn2', data=conv, fix_gamma=False, momentum=0.9, eps=2e-5)
    conv = mx.symbol.Activation(name=name + '_relu2', data=conv, act_type='relu')
    conv = mx.symbol.Convolution(name=name+'_conv3', data=conv, num_filter=kout, kernel=(1,1), stride=(1,1), pad=(0,0), no_bias=True)
    conv = mx.symbol.BatchNorm(name=name + '_bn3', data=conv, fix_gamma=False, momentum=0.9, eps=2e-5)

    return conv


def get_fusion(name, datal, datar, kin, kout, stride):
    if kin == kout:
        # Same width: average the two shortcuts and add the result to both deep branches.
        shortcut1l = datal
        deep1l = get_deep2bott(name+'_deep1l', datal, kin, kout, stride)

        shortcut1r = datar
        deep1r = get_deep2bott(name+'_deep1r', datar, kin, kout, stride)

        fuse = 0.5 * (shortcut1l + shortcut1r)

        datal=deep1l+fuse
        datar=deep1r+fuse
    else:
        # Width changes: no shortcut, each branch only passes through its bottleneck.
        datal = get_deep2bott(name+'_deep1l', datal, kin, kout, stride)
        datar = get_deep2bott(name+'_deep1r', datar, kin, kout, stride)


    datal = mx.symbol.Activation(name=name+'_relu', data=datal, act_type='relu')
    datar = mx.symbol.Activation(name=name+'_relu', data=datar, act_type='relu')
    return datal, datar

def get_group(name, datal, datar, num_block, kin, kout, stride):
    # Only the first block of a group applies the stride and changes the width.
    for idx in range(num_block):
        datal, datar = get_fusion(name=name+'_b%d'%(idx+1), datal=datal, datar=datar, kin=kin, kout=kout, stride=stride if idx == 0 else (1, 1))
        kin = kout
    return datal, datar

def get_symbol(num_classes=1000):
    # setup model parameters
    blocks_num = (1,2,11,2)
    channels = 64
    # start network definition
    data = mx.symbol.Variable(name='data')
    # stage conv1_x
    conv1 = mx.symbol.Convolution(name='g0', data=data, num_filter=channels, kernel=(7,7), stride=(2,2), pad=(3,3), no_bias=True)
    pool1 = mx.symbol.Pooling(name='g0_pool', data=conv1, kernel=(3, 3), stride=(2, 2), pad=(1, 1), pool_type='max')
    # stage conv2_x, conv3_x, conv4_x, conv5_x
    conv2_x_1, conv2_x_2 = get_group(name='g1', datal=pool1, datar=pool1, num_block=blocks_num[0], kin=channels, kout=channels*4, stride=(1,1))

    conv3_x_1, conv3_x_2 = get_group(name='g2', datal=conv2_x_1, datar=conv2_x_2, num_block=blocks_num[1], kin=channels*4, kout=channels*8, stride=(2,2))

    conv4_x_1, conv4_x_2 = get_group(name='g3', datal=conv3_x_1, datar=conv3_x_2, num_block=blocks_num[2], kin=channels*8, kout=channels*16, stride=(2,2))

    conv5_x_1, conv5_x_2 = get_group(name='g4', datal=conv4_x_1, datar=conv4_x_2, num_block=blocks_num[3], kin=channels*16, kout=channels*32, stride=(2,2))

    # concatenate the two branches before global pooling and the classifier
    conv5_x= mx.symbol.Concat(conv5_x_1,conv5_x_2)
    avg = mx.symbol.Pooling(name='global_pool', data=conv5_x, kernel=(7, 7), stride=(1, 1), pool_type='avg')
    flatten = mx.sym.Flatten(name="flatten", data=avg)
    fc = mx.symbol.FullyConnected(name='fc_score', data=flatten, num_hidden=num_classes)
    softmax = mx.symbol.SoftmaxOutput(name='softmax', data=fc)

    return softmax
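
The arithmetic in the `kin == kout` branch of `get_fusion` can be checked without MXNet. The sketch below is only an illustration on plain Python lists: `fuse_step` and `relu` are hypothetical helper names, and the "deep" inputs are made-up numbers standing in for the bottleneck outputs that `get_deep2bott` would produce.

```python
def relu(values):
    """Element-wise ReLU on a list of floats."""
    return [max(0.0, v) for v in values]

def fuse_step(short_l, short_r, deep_l, deep_r):
    """Average the two shortcuts, add the average to each deep branch, apply ReLU."""
    fuse = [0.5 * (a + b) for a, b in zip(short_l, short_r)]
    out_l = relu([d + f for d, f in zip(deep_l, fuse)])
    out_r = relu([d + f for d, f in zip(deep_r, fuse)])
    return out_l, out_r

# Made-up two-element activations for the left and right branches.
out_l, out_r = fuse_step(
    short_l=[1.0, -2.0], short_r=[3.0, 0.0],
    deep_l=[0.5, 0.5], deep_r=[-3.0, 2.0],
)
print(out_l)  # [2.5, 0.0]
print(out_r)  # [0.0, 1.0]
```

Both branches receive the same averaged shortcut, so information from one branch reaches the other at every block where the widths match.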
