In [1]:
import mxnet as mx

def get_iterators(batch_size, data_shape=(3, 224, 224)):
    train = mx.io.ImageRecordIter(
        path_imgrec         = '/data4/srip_face/img/jump_detector/jump_train.rec', 
        data_name           = 'data',
        label_name          = 'softmax_label',
        batch_size          = batch_size,
        data_shape          = data_shape,
        shuffle             = True,
        rand_crop           = True,
        rand_mirror         = True)
    val = mx.io.ImageRecordIter(
        path_imgrec         = '/data4/srip_face/img/jump_detector/jump_valid.rec',
        data_name           = 'data',
        label_name          = 'softmax_label',
        batch_size          = batch_size,
        data_shape          = data_shape,
        rand_crop           = False,
        rand_mirror         = False)
    return (train, val)

In [2]:
import os, urllib
def download(url):
    filename = url.split("/")[-1]
    if not os.path.exists(filename):
        urllib.urlretrieve(url, filename)
        
def get_model(prefix, epoch):
    download(prefix+'-symbol.json')
    download(prefix+'-%04d.params' % (epoch,))

get_model('http://data.mxnet.io/models/imagenet/resnet/152-layers/resnet-152', 0)
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-152', 0)

In [3]:
def get_fine_tune_model(symbol, arg_params, num_classes, layer_name='flatten0'):
    """
    symbol: the pre-trained network symbol
    arg_params: the argument parameters of the pre-trained model
    num_classes: the number of classes for the fine-tune datasets
    layer_name: the layer name before the last fully-connected layer
    """
    all_layers = symbol.get_internals()
    net = all_layers[layer_name+'_output']
    net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name='fc1')
    net = mx.symbol.SoftmaxOutput(data=net, name='softmax')
    new_args = dict({k:arg_params[k] for k in arg_params if 'fc1' not in k})
    return (net, new_args)

In [4]:
import logging
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)

def fit(symbol, arg_params, aux_params, train, val, batch_size, num_gpus=1, num_epoch=1):
    devs = [mx.gpu(i) for i in range(num_gpus)] # replace mx.gpu by mx.cpu for CPU training
    mod = mx.mod.Module(symbol=symbol, context=devs)
    mod.bind(data_shapes=train.provide_data, label_shapes=train.provide_label)
    mod.init_params(initializer=mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2))
    mod.set_params(arg_params, aux_params, allow_missing=True)
    mod.fit(train, val, 
        num_epoch=num_epoch,
        batch_end_callback = mx.callback.Speedometer(batch_size, 10),        
        kvstore='device',
        optimizer='sgd',
        optimizer_params={'learning_rate':0.009},
        eval_metric='acc')
    
    return mod

In [5]:
num_classes = 2 # This is binary classification
batch_per_gpu = 4
num_gpus = 1
epoch = 3
(new_sym, new_args) = get_fine_tune_model(sym, arg_params, num_classes)

batch_size = batch_per_gpu * num_gpus
(train, val) = get_iterators(batch_size)
mod = fit(new_sym, new_args, aux_params, train, val, batch_size, num_gpus, epoch)
metric = mx.metric.Accuracy()
mod_score = mod.score(val, metric)
print mod_score

2017-09-08 17:24:36,699 Already bound, ignoring bind()
  allow_missing=allow_missing, force_init=force_init)
2017-09-08 17:24:42,953 Epoch[0] Batch [10]	Speed: 11.33 samples/sec	accuracy=0.681818
2017-09-08 17:24:46,415 Epoch[0] Batch [20]	Speed: 11.56 samples/sec	accuracy=0.725000
2017-09-08 17:24:49,886 Epoch[0] Batch [30]	Speed: 11.53 samples/sec	accuracy=0.775000
2017-09-08 17:24:53,372 Epoch[0] Batch [40]	Speed: 11.48 samples/sec	accuracy=0.775000
2017-09-08 17:24:56,831 Epoch[0] Batch [50]	Speed: 11.57 samples/sec	accuracy=0.725000
2017-09-08 17:25:00,322 Epoch[0] Batch [60]	Speed: 11.46 samples/sec	accuracy=0.875000
2017-09-08 17:25:03,795 Epoch[0] Batch [70]	Speed: 11.52 samples/sec	accuracy=0.700000
2017-09-08 17:25:07,287 Epoch[0] Batch [80]	Speed: 11.46 samples/sec	accuracy=0.825000
2017-09-08 17:25:10,754 Epoch[0] Batch [90]	Speed: 11.54 samples/sec	accuracy=0.700000
2017-09-08 17:25:14,223 Epoch[0] Batch [100]	Speed: 11.54 samples/sec	accuracy=0.850000
2017-09-08 17:25:17,

2017-09-08 17:29:56,282 Epoch[0] Batch [910]	Speed: 11.48 samples/sec	accuracy=0.825000
2017-09-08 17:29:59,786 Epoch[0] Batch [920]	Speed: 11.42 samples/sec	accuracy=0.875000
2017-09-08 17:30:03,267 Epoch[0] Batch [930]	Speed: 11.49 samples/sec	accuracy=0.700000
2017-09-08 17:30:06,746 Epoch[0] Batch [940]	Speed: 11.50 samples/sec	accuracy=0.825000
2017-09-08 17:30:10,223 Epoch[0] Batch [950]	Speed: 11.51 samples/sec	accuracy=0.875000
2017-09-08 17:30:13,698 Epoch[0] Batch [960]	Speed: 11.52 samples/sec	accuracy=0.825000
2017-09-08 17:30:17,177 Epoch[0] Batch [970]	Speed: 11.50 samples/sec	accuracy=0.925000
2017-09-08 17:30:20,657 Epoch[0] Batch [980]	Speed: 11.50 samples/sec	accuracy=0.825000
2017-09-08 17:30:24,134 Epoch[0] Batch [990]	Speed: 11.51 samples/sec	accuracy=0.775000
2017-09-08 17:30:27,611 Epoch[0] Batch [1000]	Speed: 11.51 samples/sec	accuracy=0.775000
2017-09-08 17:30:31,088 Epoch[0] Batch [1010]	Speed: 11.51 samples/sec	accuracy=0.875000
2017-09-08 17:30:34,576 Epoch[

2017-09-08 17:35:20,100 Epoch[0] Batch [1840]	Speed: 11.48 samples/sec	accuracy=0.975000
2017-09-08 17:35:23,587 Epoch[0] Batch [1850]	Speed: 11.48 samples/sec	accuracy=0.850000
2017-09-08 17:35:27,069 Epoch[0] Batch [1860]	Speed: 11.49 samples/sec	accuracy=0.875000
2017-09-08 17:35:30,559 Epoch[0] Batch [1870]	Speed: 11.46 samples/sec	accuracy=0.925000
2017-09-08 17:35:34,047 Epoch[0] Batch [1880]	Speed: 11.47 samples/sec	accuracy=0.775000
2017-09-08 17:35:37,529 Epoch[0] Batch [1890]	Speed: 11.49 samples/sec	accuracy=0.950000
2017-09-08 17:35:41,011 Epoch[0] Batch [1900]	Speed: 11.49 samples/sec	accuracy=0.900000
2017-09-08 17:35:44,483 Epoch[0] Batch [1910]	Speed: 11.52 samples/sec	accuracy=0.775000
2017-09-08 17:35:47,960 Epoch[0] Batch [1920]	Speed: 11.51 samples/sec	accuracy=0.950000
2017-09-08 17:35:51,442 Epoch[0] Batch [1930]	Speed: 11.49 samples/sec	accuracy=0.825000
2017-09-08 17:35:54,924 Epoch[0] Batch [1940]	Speed: 11.49 samples/sec	accuracy=0.875000
2017-09-08 17:35:58,4

2017-09-08 17:40:43,786 Epoch[0] Batch [2770]	Speed: 11.51 samples/sec	accuracy=0.850000
2017-09-08 17:40:47,267 Epoch[0] Batch [2780]	Speed: 11.50 samples/sec	accuracy=0.775000
2017-09-08 17:40:50,743 Epoch[0] Batch [2790]	Speed: 11.51 samples/sec	accuracy=0.925000
2017-09-08 17:40:54,228 Epoch[0] Batch [2800]	Speed: 11.48 samples/sec	accuracy=0.825000
2017-09-08 17:40:57,709 Epoch[0] Batch [2810]	Speed: 11.49 samples/sec	accuracy=0.950000
2017-09-08 17:41:01,188 Epoch[0] Batch [2820]	Speed: 11.50 samples/sec	accuracy=0.975000
2017-09-08 17:41:04,667 Epoch[0] Batch [2830]	Speed: 11.50 samples/sec	accuracy=0.925000
2017-09-08 17:41:08,147 Epoch[0] Batch [2840]	Speed: 11.50 samples/sec	accuracy=0.900000
2017-09-08 17:41:11,626 Epoch[0] Batch [2850]	Speed: 11.50 samples/sec	accuracy=0.950000
2017-09-08 17:41:15,099 Epoch[0] Batch [2860]	Speed: 11.52 samples/sec	accuracy=0.925000
2017-09-08 17:41:18,575 Epoch[0] Batch [2870]	Speed: 11.51 samples/sec	accuracy=0.875000
2017-09-08 17:41:22,0

2017-09-08 17:46:07,760 Epoch[0] Batch [3700]	Speed: 11.50 samples/sec	accuracy=0.825000
2017-09-08 17:46:11,239 Epoch[0] Batch [3710]	Speed: 11.50 samples/sec	accuracy=0.875000
2017-09-08 17:46:14,715 Epoch[0] Batch [3720]	Speed: 11.51 samples/sec	accuracy=0.975000
2017-09-08 17:46:18,191 Epoch[0] Batch [3730]	Speed: 11.51 samples/sec	accuracy=0.850000
2017-09-08 17:46:21,669 Epoch[0] Batch [3740]	Speed: 11.51 samples/sec	accuracy=0.825000
2017-09-08 17:46:25,145 Epoch[0] Batch [3750]	Speed: 11.51 samples/sec	accuracy=0.900000
2017-09-08 17:46:28,623 Epoch[0] Batch [3760]	Speed: 11.51 samples/sec	accuracy=0.950000
2017-09-08 17:46:32,100 Epoch[0] Batch [3770]	Speed: 11.51 samples/sec	accuracy=0.825000
2017-09-08 17:46:35,577 Epoch[0] Batch [3780]	Speed: 11.51 samples/sec	accuracy=0.875000
2017-09-08 17:46:39,057 Epoch[0] Batch [3790]	Speed: 11.50 samples/sec	accuracy=0.850000
2017-09-08 17:46:42,537 Epoch[0] Batch [3800]	Speed: 11.50 samples/sec	accuracy=0.950000
2017-09-08 17:46:46,0

2017-09-08 17:53:14,221 Epoch[1] Batch [490]	Speed: 11.58 samples/sec	accuracy=0.900000
2017-09-08 17:53:17,665 Epoch[1] Batch [500]	Speed: 11.62 samples/sec	accuracy=0.850000
2017-09-08 17:53:21,119 Epoch[1] Batch [510]	Speed: 11.58 samples/sec	accuracy=0.900000
2017-09-08 17:53:24,580 Epoch[1] Batch [520]	Speed: 11.56 samples/sec	accuracy=0.750000
2017-09-08 17:53:28,047 Epoch[1] Batch [530]	Speed: 11.54 samples/sec	accuracy=0.900000
2017-09-08 17:53:31,505 Epoch[1] Batch [540]	Speed: 11.57 samples/sec	accuracy=0.975000
2017-09-08 17:53:34,964 Epoch[1] Batch [550]	Speed: 11.57 samples/sec	accuracy=0.875000
2017-09-08 17:53:38,424 Epoch[1] Batch [560]	Speed: 11.56 samples/sec	accuracy=1.000000
2017-09-08 17:53:41,903 Epoch[1] Batch [570]	Speed: 11.50 samples/sec	accuracy=0.925000
2017-09-08 17:53:45,361 Epoch[1] Batch [580]	Speed: 11.57 samples/sec	accuracy=0.925000
2017-09-08 17:53:48,824 Epoch[1] Batch [590]	Speed: 11.56 samples/sec	accuracy=0.925000
2017-09-08 17:53:52,282 Epoch[1]

2017-09-08 17:58:36,017 Epoch[1] Batch [1420]	Speed: 11.58 samples/sec	accuracy=0.900000
2017-09-08 17:58:39,466 Epoch[1] Batch [1430]	Speed: 11.60 samples/sec	accuracy=0.725000
2017-09-08 17:58:42,935 Epoch[1] Batch [1440]	Speed: 11.53 samples/sec	accuracy=0.900000
2017-09-08 17:58:46,393 Epoch[1] Batch [1450]	Speed: 11.57 samples/sec	accuracy=0.950000
2017-09-08 17:58:49,856 Epoch[1] Batch [1460]	Speed: 11.56 samples/sec	accuracy=0.850000
2017-09-08 17:58:53,308 Epoch[1] Batch [1470]	Speed: 11.59 samples/sec	accuracy=0.900000
2017-09-08 17:58:56,773 Epoch[1] Batch [1480]	Speed: 11.55 samples/sec	accuracy=0.825000
2017-09-08 17:59:00,229 Epoch[1] Batch [1490]	Speed: 11.58 samples/sec	accuracy=0.950000
2017-09-08 17:59:03,692 Epoch[1] Batch [1500]	Speed: 11.56 samples/sec	accuracy=0.900000
2017-09-08 17:59:07,146 Epoch[1] Batch [1510]	Speed: 11.58 samples/sec	accuracy=0.950000
2017-09-08 17:59:10,604 Epoch[1] Batch [1520]	Speed: 11.57 samples/sec	accuracy=0.975000
2017-09-08 17:59:14,0

2017-09-08 18:03:57,737 Epoch[1] Batch [2350]	Speed: 11.56 samples/sec	accuracy=0.900000
2017-09-08 18:04:01,199 Epoch[1] Batch [2360]	Speed: 11.56 samples/sec	accuracy=0.925000
2017-09-08 18:04:04,679 Epoch[1] Batch [2370]	Speed: 11.50 samples/sec	accuracy=0.925000
2017-09-08 18:04:08,136 Epoch[1] Batch [2380]	Speed: 11.58 samples/sec	accuracy=0.950000
2017-09-08 18:04:11,602 Epoch[1] Batch [2390]	Speed: 11.54 samples/sec	accuracy=0.925000
2017-09-08 18:04:15,071 Epoch[1] Batch [2400]	Speed: 11.53 samples/sec	accuracy=0.825000
2017-09-08 18:04:18,540 Epoch[1] Batch [2410]	Speed: 11.53 samples/sec	accuracy=0.875000
2017-09-08 18:04:22,000 Epoch[1] Batch [2420]	Speed: 11.56 samples/sec	accuracy=0.900000
2017-09-08 18:04:25,461 Epoch[1] Batch [2430]	Speed: 11.56 samples/sec	accuracy=0.950000
2017-09-08 18:04:28,918 Epoch[1] Batch [2440]	Speed: 11.58 samples/sec	accuracy=0.900000
2017-09-08 18:04:32,402 Epoch[1] Batch [2450]	Speed: 11.49 samples/sec	accuracy=0.875000
2017-09-08 18:04:35,8

2017-09-08 18:09:19,837 Epoch[1] Batch [3280]	Speed: 11.56 samples/sec	accuracy=0.900000
2017-09-08 18:09:23,296 Epoch[1] Batch [3290]	Speed: 11.57 samples/sec	accuracy=0.950000
2017-09-08 18:09:26,763 Epoch[1] Batch [3300]	Speed: 11.54 samples/sec	accuracy=0.975000
2017-09-08 18:09:30,221 Epoch[1] Batch [3310]	Speed: 11.57 samples/sec	accuracy=0.900000
2017-09-08 18:09:33,679 Epoch[1] Batch [3320]	Speed: 11.57 samples/sec	accuracy=0.875000
2017-09-08 18:09:37,132 Epoch[1] Batch [3330]	Speed: 11.59 samples/sec	accuracy=0.950000
2017-09-08 18:09:40,594 Epoch[1] Batch [3340]	Speed: 11.56 samples/sec	accuracy=0.950000
2017-09-08 18:09:44,055 Epoch[1] Batch [3350]	Speed: 11.56 samples/sec	accuracy=0.925000
2017-09-08 18:09:47,573 Epoch[1] Batch [3360]	Speed: 11.37 samples/sec	accuracy=0.950000
2017-09-08 18:09:51,024 Epoch[1] Batch [3370]	Speed: 11.59 samples/sec	accuracy=0.975000
2017-09-08 18:09:54,494 Epoch[1] Batch [3380]	Speed: 11.53 samples/sec	accuracy=0.925000
2017-09-08 18:09:57,9

2017-09-08 18:16:25,288 Epoch[2] Batch [70]	Speed: 11.57 samples/sec	accuracy=1.000000
2017-09-08 18:16:28,738 Epoch[2] Batch [80]	Speed: 11.60 samples/sec	accuracy=0.925000
2017-09-08 18:16:32,193 Epoch[2] Batch [90]	Speed: 11.58 samples/sec	accuracy=0.925000
2017-09-08 18:16:35,651 Epoch[2] Batch [100]	Speed: 11.57 samples/sec	accuracy=0.925000
2017-09-08 18:16:39,105 Epoch[2] Batch [110]	Speed: 11.59 samples/sec	accuracy=0.925000
2017-09-08 18:16:42,564 Epoch[2] Batch [120]	Speed: 11.57 samples/sec	accuracy=0.950000
2017-09-08 18:16:46,031 Epoch[2] Batch [130]	Speed: 11.54 samples/sec	accuracy=0.950000
2017-09-08 18:16:49,487 Epoch[2] Batch [140]	Speed: 11.58 samples/sec	accuracy=0.925000
2017-09-08 18:16:52,957 Epoch[2] Batch [150]	Speed: 11.53 samples/sec	accuracy=0.950000
2017-09-08 18:16:56,413 Epoch[2] Batch [160]	Speed: 11.58 samples/sec	accuracy=1.000000
2017-09-08 18:16:59,870 Epoch[2] Batch [170]	Speed: 11.58 samples/sec	accuracy=0.875000
2017-09-08 18:17:03,327 Epoch[2] Ba

2017-09-08 18:21:50,582 Epoch[2] Batch [1010]	Speed: 11.55 samples/sec	accuracy=0.925000
2017-09-08 18:21:54,061 Epoch[2] Batch [1020]	Speed: 11.50 samples/sec	accuracy=0.825000
2017-09-08 18:21:57,512 Epoch[2] Batch [1030]	Speed: 11.59 samples/sec	accuracy=0.900000
2017-09-08 18:22:00,978 Epoch[2] Batch [1040]	Speed: 11.55 samples/sec	accuracy=0.925000
2017-09-08 18:22:04,428 Epoch[2] Batch [1050]	Speed: 11.60 samples/sec	accuracy=0.975000
2017-09-08 18:22:07,911 Epoch[2] Batch [1060]	Speed: 11.49 samples/sec	accuracy=0.950000
2017-09-08 18:22:11,367 Epoch[2] Batch [1070]	Speed: 11.58 samples/sec	accuracy=0.925000
2017-09-08 18:22:14,829 Epoch[2] Batch [1080]	Speed: 11.56 samples/sec	accuracy=0.950000
2017-09-08 18:22:18,296 Epoch[2] Batch [1090]	Speed: 11.54 samples/sec	accuracy=0.950000
2017-09-08 18:22:21,750 Epoch[2] Batch [1100]	Speed: 11.59 samples/sec	accuracy=0.925000
2017-09-08 18:22:25,213 Epoch[2] Batch [1110]	Speed: 11.55 samples/sec	accuracy=0.950000
2017-09-08 18:22:28,6

2017-09-08 18:27:12,593 Epoch[2] Batch [1940]	Speed: 11.58 samples/sec	accuracy=0.950000
2017-09-08 18:27:16,045 Epoch[2] Batch [1950]	Speed: 11.59 samples/sec	accuracy=0.950000
2017-09-08 18:27:19,495 Epoch[2] Batch [1960]	Speed: 11.60 samples/sec	accuracy=0.975000
2017-09-08 18:27:22,944 Epoch[2] Batch [1970]	Speed: 11.60 samples/sec	accuracy=0.925000
2017-09-08 18:27:26,395 Epoch[2] Batch [1980]	Speed: 11.59 samples/sec	accuracy=0.925000
2017-09-08 18:27:29,854 Epoch[2] Batch [1990]	Speed: 11.57 samples/sec	accuracy=0.975000
2017-09-08 18:27:33,316 Epoch[2] Batch [2000]	Speed: 11.56 samples/sec	accuracy=0.875000
2017-09-08 18:27:36,765 Epoch[2] Batch [2010]	Speed: 11.60 samples/sec	accuracy=0.975000
2017-09-08 18:27:40,209 Epoch[2] Batch [2020]	Speed: 11.62 samples/sec	accuracy=0.950000
2017-09-08 18:27:43,658 Epoch[2] Batch [2030]	Speed: 11.60 samples/sec	accuracy=0.900000
2017-09-08 18:27:47,106 Epoch[2] Batch [2040]	Speed: 11.61 samples/sec	accuracy=0.975000
2017-09-08 18:27:50,5

2017-09-08 18:32:34,797 Epoch[2] Batch [2870]	Speed: 11.55 samples/sec	accuracy=1.000000
2017-09-08 18:32:38,258 Epoch[2] Batch [2880]	Speed: 11.56 samples/sec	accuracy=0.975000
2017-09-08 18:32:41,724 Epoch[2] Batch [2890]	Speed: 11.54 samples/sec	accuracy=0.950000
2017-09-08 18:32:45,186 Epoch[2] Batch [2900]	Speed: 11.56 samples/sec	accuracy=0.950000
2017-09-08 18:32:48,640 Epoch[2] Batch [2910]	Speed: 11.58 samples/sec	accuracy=0.925000
2017-09-08 18:32:52,125 Epoch[2] Batch [2920]	Speed: 11.49 samples/sec	accuracy=0.975000
2017-09-08 18:32:55,577 Epoch[2] Batch [2930]	Speed: 11.59 samples/sec	accuracy=0.950000
2017-09-08 18:32:59,046 Epoch[2] Batch [2940]	Speed: 11.54 samples/sec	accuracy=0.850000
2017-09-08 18:33:02,508 Epoch[2] Batch [2950]	Speed: 11.56 samples/sec	accuracy=0.975000
2017-09-08 18:33:05,964 Epoch[2] Batch [2960]	Speed: 11.58 samples/sec	accuracy=1.000000
2017-09-08 18:33:09,422 Epoch[2] Batch [2970]	Speed: 11.57 samples/sec	accuracy=1.000000
2017-09-08 18:33:12,8

2017-09-08 18:37:56,878 Epoch[2] Batch [3800]	Speed: 11.57 samples/sec	accuracy=0.975000
2017-09-08 18:38:00,334 Epoch[2] Batch [3810]	Speed: 11.58 samples/sec	accuracy=0.950000
2017-09-08 18:38:03,801 Epoch[2] Batch [3820]	Speed: 11.54 samples/sec	accuracy=0.950000
2017-09-08 18:38:07,265 Epoch[2] Batch [3830]	Speed: 11.55 samples/sec	accuracy=0.950000
2017-09-08 18:38:10,716 Epoch[2] Batch [3840]	Speed: 11.60 samples/sec	accuracy=0.900000
2017-09-08 18:38:14,186 Epoch[2] Batch [3850]	Speed: 11.53 samples/sec	accuracy=0.975000
2017-09-08 18:38:17,640 Epoch[2] Batch [3860]	Speed: 11.58 samples/sec	accuracy=0.800000
2017-09-08 18:38:21,097 Epoch[2] Batch [3870]	Speed: 11.58 samples/sec	accuracy=0.975000
2017-09-08 18:38:24,554 Epoch[2] Batch [3880]	Speed: 11.58 samples/sec	accuracy=0.950000
2017-09-08 18:38:28,017 Epoch[2] Batch [3890]	Speed: 11.55 samples/sec	accuracy=0.950000
2017-09-08 18:38:31,469 Epoch[2] Batch [3900]	Speed: 11.59 samples/sec	accuracy=0.875000
2017-09-08 18:38:34,9

[('accuracy', 0.9155470249520153)]


In [6]:
prefix = 'resnet-mxnet-jump-start'
mc = mod.save_checkpoint(prefix, epoch)

2017-09-08 18:43:27,838 Saved checkpoint to "resnet-mxnet-jump-start-0003.params"
