# Fine-tune with Pre-trained Models

In practice, the datasets we work with are often relatively small, so we rarely train a neural network from scratch, i.e. starting from randomly initialized parameters. Instead, it is common to first train a network on a large-scale dataset and then use it either as an initialization or as a fixed feature extractor. [predict.ipynb](./predict.ipynb) explains how to do feature extraction; this tutorial focuses on how to use a pre-trained model to fine-tune a new network.

The idea of fine-tuning is to take a pre-trained model and replace its last fully-connected layer with a new one that outputs the desired number of classes and is initialized with random values. We then train as usual, except that we often use a smaller learning rate, since the pre-trained parameters are likely already close to a good solution.
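On the parameter side, this idea can be sketched with plain NumPy arrays. This is a minimal illustration, not MXNet code: the dictionary layout, the `fc1_` prefix, the 2048-dimensional feature size and the helper name `replace_classifier` are assumptions chosen to mirror the ResNet checkpoints used later.

```python
import numpy as np

def replace_classifier(pretrained, num_classes, prefix="fc1_", seed=0):
    """Copy every pre-trained array except the old classifier and attach a
    freshly initialized classifier with `num_classes` outputs.

    `pretrained` maps parameter names to NumPy arrays; `prefix` (a hypothetical
    default) marks the parameters of the last fully-connected layer.
    """
    rng = np.random.default_rng(seed)
    # Keep the feature extractor: everything that is not the old classifier.
    params = {k: v.copy() for k, v in pretrained.items() if not k.startswith(prefix)}
    # The new layer reads the same features as the old one.
    feature_dim = pretrained[prefix + "weight"].shape[1]
    # Random Gaussian weights (std = sqrt(2 / fan_in)) and a zero bias.
    params[prefix + "weight"] = rng.normal(
        0.0, np.sqrt(2.0 / feature_dim), size=(num_classes, feature_dim))
    params[prefix + "bias"] = np.zeros(num_classes)
    return params

# Toy "pre-trained" model: one backbone array plus a 1000-class head.
pretrained = {
    "conv0_weight": np.ones((64, 3, 7, 7)),
    "fc1_weight": np.ones((1000, 2048)),
    "fc1_bias": np.ones(1000),
}
new_params = replace_classifier(pretrained, num_classes=256)
print(new_params["fc1_weight"].shape)  # (256, 2048)
```

Only the new head starts from scratch; every other array is carried over unchanged, which is why a small learning rate is usually enough.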

As an example, we use models pre-trained on the ImageNet dataset and fine-tune them on the smaller Caltech-256 dataset. The same approach applies to other datasets as well, even for quite different applications such as face identification.

We will show that, even with simple hyper-parameter settings, we can match and even outperform state-of-the-art results on Caltech-256.

| Network | Validation accuracy |
| --- | --- |
| ResNet-50 | 77.4% |
| ResNet-152 | 86.4% |

## Prepare data

We follow the standard protocol of sampling 60 images from each class for the training set and using the rest for the validation set. Images are resized so that their shorter edge is 256 pixels and then packed into `.rec` files. The script to prepare the data is as follows.


```sh
wget http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar
tar -xf 256_ObjectCategories.tar

mkdir -p caltech_256_train_60
for i in 256_ObjectCategories/*; do
    c=`basename $i`
    mkdir -p caltech_256_train_60/$c
    for j in `ls $i/*.jpg | shuf | head -n 60`; do
        mv $j caltech_256_train_60/$c/
    done
done

python ~/mxnet/tools/im2rec.py --list True --recursive True caltech-256-60-train caltech_256_train_60/
python ~/mxnet/tools/im2rec.py --list True --recursive True caltech-256-60-val 256_ObjectCategories/
python ~/mxnet/tools/im2rec.py --resize 256 --quality 90 --num-thread 16 caltech-256-60-val 256_ObjectCategories/
python ~/mxnet/tools/im2rec.py --resize 256 --quality 90 --num-thread 16 caltech-256-60-train caltech_256_train_60/
```
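For readers who prefer Python, the per-class split done by the shell loop above (`shuf | head -n 60`) can be sketched as follows. The helper `split_per_class` and the class and file names are hypothetical; the sketch only works on in-memory lists and does not move files. Every Caltech-256 class has at least 80 images, so each class also contributes to the validation set.

```python
import random

def split_per_class(images_by_class, num_train=60, seed=0):
    """Randomly pick `num_train` images per class for training; the remaining
    images of that class become the validation set."""
    rng = random.Random(seed)
    train, val = {}, {}
    for cls, files in images_by_class.items():
        shuffled = list(files)
        rng.shuffle(shuffled)               # like `shuf`
        train[cls] = sorted(shuffled[:num_train])  # like `head -n 60`
        val[cls] = sorted(shuffled[num_train:])    # everything left over
    return train, val

# Hypothetical example: two classes with 80 and 100 images.
images = {
    "class_a": ["a_%04d.jpg" % i for i in range(80)],
    "class_b": ["b_%04d.jpg" % i for i in range(100)],
}
train, val = split_per_class(images)
print({c: (len(train[c]), len(val[c])) for c in images})
# {'class_a': (60, 20), 'class_b': (60, 40)}
```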

The following code downloads the pre-generated `.rec` files. This may take a few minutes.

```python
import os
import urllib.request

def download(url):
    """Download `url` into the current directory unless the file already exists."""
    filename = url.split("/")[-1]
    if not os.path.exists(filename):
        urllib.request.urlretrieve(url, filename)
    return filename

download('http://data.mxnet.io/data/caltech-256/caltech-256-60-train.rec')
download('http://data.mxnet.io/data/caltech-256/caltech-256-60-val.rec')
```

Next we define the function which returns the data iterators.

```python
import mxnet as mx

def get_iterators(batch_size, data_shape=(3, 224, 224)):
    # training data: shuffled, with random crops and horizontal flips
    train = mx.io.ImageRecordIter(
        path_imgrec         = 'caltech-256-60-train.rec',
        data_name           = 'data',
        label_name          = 'softmax_label',
        batch_size          = batch_size,
        data_shape          = data_shape,
        shuffle             = True,
        rand_crop           = True,
        rand_mirror         = True)
    # validation data: no random augmentation
    val = mx.io.ImageRecordIter(
        path_imgrec         = 'caltech-256-60-val.rec',
        data_name           = 'data',
        label_name          = 'softmax_label',
        batch_size          = batch_size,
        data_shape          = data_shape,
        rand_crop           = False,
        rand_mirror         = False)
    return (train, val)
```
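`rand_crop` and `rand_mirror` make every epoch see slightly different versions of the training images, while the validation iterator stays deterministic. The NumPy sketch below approximates that behavior for intuition only; it is not MXNet's implementation, and the assumption that the non-random path takes a center crop reflects the default `ImageRecordIter` behavior as we understand it. The input is a hypothetical 3x256x341 image, i.e. shorter edge 256 as produced by `im2rec.py --resize 256`.

```python
import numpy as np

def random_crop_mirror(img, out_hw=(224, 224), rng=None):
    """Training-time augmentation: a random crop plus a random horizontal flip.
    `img` is a CHW array."""
    if rng is None:
        rng = np.random.default_rng()
    _, h, w = img.shape
    oh, ow = out_hw
    y = rng.integers(0, h - oh + 1)   # random top edge
    x = rng.integers(0, w - ow + 1)   # random left edge
    crop = img[:, y:y + oh, x:x + ow]
    if rng.random() < 0.5:
        crop = crop[:, :, ::-1]       # mirror along the width axis
    return crop

def center_crop(img, out_hw=(224, 224)):
    """Validation-time preprocessing: a deterministic center crop."""
    _, h, w = img.shape
    oh, ow = out_hw
    y, x = (h - oh) // 2, (w - ow) // 2
    return img[:, y:y + oh, x:x + ow]

img = np.arange(3 * 256 * 341, dtype=np.float32).reshape(3, 256, 341)
print(random_crop_mirror(img).shape, center_crop(img).shape)
# (3, 224, 224) (3, 224, 224)
```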

We then download a pre-trained 50-layer ResNet model and load it into memory.

Note: if `load_checkpoint` reports an error, delete the downloaded files and call `get_model` again.

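```python
import os
import urllib.request

def get_model(prefix, epoch):
    """Download the symbol (.json) and parameter (.params) files of a checkpoint if missing."""
    for url in (prefix + '-symbol.json', prefix + '-%04d.params' % epoch):
        if not os.path.exists(url.split('/')[-1]):
            urllib.request.urlretrieve(url, url.split('/')[-1])

get_model('http://data.mxnet.io/models/imagenet/resnet/50-layers/resnet-50', 0)
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-50', 0)
```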

Printing the result of `load_checkpoint` shows what it returns: the network symbol, a dictionary of argument parameters (weights and biases) and a dictionary of auxiliary parameters. Note the `fc1_weight` and `fc1_bias` entries, which belong to the last fully-connected layer of the pre-trained classifier.

```python
mx.model.load_checkpoint('resnet-50', 0)
```

```
(<mxnet.symbol.Symbol at 0x81d9f60>,
 {'bn0_beta': <mxnet.ndarray.NDArray at 0x8263198>,
  'bn0_gamma': <mxnet.ndarray.NDArray at 0x81a8f28>,
  'bn1_beta': <mxnet.ndarray.NDArray at 0x81bc5c0>,
  'bn1_gamma': <mxnet.ndarray.NDArray at 0x81bcb38>,
  'bn_data_beta': <mxnet.ndarray.NDArray at 0x81a5d68>,
  'bn_data_gamma': <mxnet.ndarray.NDArray at 0x81b1f28>,
  'conv0_weight': <mxnet.ndarray.NDArray at 0x81b1160>,
  'fc1_bias': <mxnet.ndarray.NDArray at 0x81bc048>,
  'fc1_weight': <mxnet.ndarray.NDArray at 0x82630b8>,
  'stage1_unit1_bn1_beta': <mxnet.ndarray.NDArray at 0x81bc320>,
  'stage1_unit1_bn1_gamma': <mxnet.ndarray.NDArray at 0x81b19b0>,
  'stage1_unit1_bn2_beta': <mxnet.ndarray.NDArray at 0x81bc240>,
  'stage1_unit1_bn2_gamma': <mxnet.ndarray.NDArray at 0x81a8ba8>,
  'stage1_unit1_bn3_beta': <mxnet.ndarray.NDArray at 0x81b1cf8>,
  'stage1_unit1_bn3_gamma': <mxnet.ndarray.NDArray at 0x81bc278>,
  'stage1_unit1_conv1_weight': <mxnet.ndarray.NDArray at 0x81a88d0>,
  'stage1_unit1_
```

## Train

We first define a function that replaces the last fully-connected layer of a given network.

```python
def get_fine_tune_model(symbol, arg_params, num_classes, layer_name='flatten0'):
    """
    symbol: the pre-trained network symbol
    arg_params: the argument parameters of the pre-trained model
    num_classes: the number of classes in the fine-tuning dataset
    layer_name: the name of the layer before the last fully-connected layer
    """
    all_layers = symbol.get_internals()
    # take the output of the layer just before the old classifier
    net = all_layers[layer_name + '_output']
    # attach a new fully-connected layer with num_classes outputs
    net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name='fc1')
    net = mx.symbol.SoftmaxOutput(data=net, name='softmax')
    # drop the old fc1 parameters; the new fc1 has a different shape
    new_args = {k: arg_params[k] for k in arg_params if 'fc1' not in k}
    return (net, new_args)
```
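The `'fc1' not in k` test relies on substring matching. That is enough for removing `fc1_weight` and `fc1_bias`, but it would also drop any other parameter whose name merely contains `fc1`. A slightly stricter variant matches on the `fc1_` prefix instead; it is sketched below with plain dictionaries, and the helper name `drop_layer_params` and the `fc10_weight` entry are hypothetical.

```python
def drop_layer_params(arg_params, layer_name="fc1"):
    """Keep every parameter that does not belong to layer `layer_name`.

    Matching the exact prefix `layer_name + "_"` avoids dropping parameters
    of other layers whose names only contain the substring.
    """
    prefix = layer_name + "_"
    return {k: v for k, v in arg_params.items() if not k.startswith(prefix)}

params = {"conv0_weight": 0, "fc1_weight": 1, "fc1_bias": 2, "fc10_weight": 3}
print(sorted(drop_layer_params(params)))     # ['conv0_weight', 'fc10_weight']
print(sorted(k for k in params if "fc1" not in k))  # ['conv0_weight']
```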

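Now we create a module. We first call `init_params` to randomly initialize all parameters, and then use `set_params` to overwrite every parameter except those of the new last fully-connected layer with the pre-trained values. `allow_missing=True` lets `set_params` skip the `fc1` parameters, which are absent from the pre-trained arguments.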

```python
import logging
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)

def fit(symbol, arg_params, aux_params, train, val, batch_size, num_gpus):
    devs = [mx.gpu(i) for i in range(num_gpus)]
    mod = mx.mod.Module(symbol=symbol, context=devs)
    mod.bind(data_shapes=train.provide_data, label_shapes=train.provide_label)
    # randomly initialize everything, including the new fc1 layer
    mod.init_params(initializer=mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2))
    # copy in the pre-trained weights; fc1 is missing and keeps its random values
    mod.set_params(arg_params, aux_params, allow_missing=True)
    mod.fit(train, val,
        num_epoch=8,
        batch_end_callback=mx.callback.Speedometer(batch_size, 10),
        kvstore='device',
        optimizer='sgd',
        optimizer_params={'learning_rate': 0.01},
        eval_metric='acc')
```
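The initializer passed to `init_params` draws Gaussian weights whose scale depends on each layer's fan-in. As we understand `mx.init.Xavier`, with `rnd_type='gaussian'`, `factor_type='in'` and `magnitude=2` the standard deviation is sqrt(2 / fan_in), the scheme often called He or MSRA initialization. The NumPy sketch below reproduces that rule for intuition; it is an approximation, not MXNet's code, and `xavier_gaussian_in` is a hypothetical name. For the new `fc1` layer on ResNet-50's 2048-dimensional pooled features, the standard deviation is sqrt(2 / 2048) = 1/32.

```python
import numpy as np

def xavier_gaussian_in(shape, magnitude=2.0, rng=None):
    """Sample weights from N(0, std**2) with std = sqrt(magnitude / fan_in).

    fan_in is the number of inputs feeding one output unit: shape[1] times
    the kernel area for convolution weights (out, in, kh, kw), or shape[1]
    for fully-connected weights (num_hidden, input_dim).
    """
    if rng is None:
        rng = np.random.default_rng()
    fan_in = int(np.prod(shape[1:]))
    std = np.sqrt(magnitude / fan_in)
    return rng.normal(0.0, std, size=shape), std

# New classifier of the fine-tuned ResNet-50: 256 classes x 2048 features.
w, std = xavier_gaussian_in((256, 2048), rng=np.random.default_rng(0))
print(std)  # 0.03125
```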

Then we can start training. We use AWS EC2 g2.8xlarge, which has 8 GPUs.

```python
num_classes = 256
batch_per_gpu = 16
num_gpus = 8

(new_sym, new_args) = get_fine_tune_model(sym, arg_params, num_classes)

batch_size = batch_per_gpu * num_gpus
(train, val) = get_iterators(batch_size)
fit(new_sym, new_args, aux_params, train, val, batch_size, num_gpus)
```
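With `batch_per_gpu = 16` and `num_gpus = 8`, each iterator batch holds 128 images, and the module divides every batch among the devices in `context`. The pure-Python sketch below shows the arithmetic of an even split; the helper `split_batch` is hypothetical and is not MXNet's internal slicing code.

```python
def split_batch(batch_size, num_devices):
    """Return (start, stop) index pairs that divide a batch as evenly as possible."""
    base, extra = divmod(batch_size, num_devices)
    slices, start = [], 0
    for i in range(num_devices):
        # the first `extra` devices take one additional sample
        size = base + (1 if i < extra else 0)
        slices.append((start, start + size))
        start += size
    return slices

batch_per_gpu, num_gpus = 16, 8
batch_size = batch_per_gpu * num_gpus
print(batch_size, split_batch(batch_size, num_gpus)[:2])
# 128 [(0, 16), (16, 32)]
```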

After only 8 epochs, the model reaches 78% validation accuracy. This matches state-of-the-art results reported on Caltech-256, e.g. [VGG](http://www.robots.ox.ac.uk/~vgg/research/deep_eval/).

Next, we try another pre-trained model. It was trained on the complete ImageNet dataset (11K classes), which is about 10x larger than the 1K-class ImageNet subset, using a roughly 3x deeper ResNet (152 layers instead of 50).

```python
get_model('http://data.mxnet.io/models/imagenet-11k/resnet-152/resnet-152', 0)
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-152', 0)
(new_sym, new_args) = get_fine_tune_model(sym, arg_params, num_classes)
fit(new_sym, new_args, aux_params, train, val, batch_size, num_gpus)
```

```
2016-10-22 18:35:42,274 Already binded, ignoring bind()
2016-10-22 18:35:55,659 Epoch[0] Batch [10]	Speed: 139.63 samples/sec	Train-accuracy=0.070312
2016-10-22 18:36:04,814 Epoch[0] Batch [20]	Speed: 139.83 samples/sec	Train-accuracy=0.349219
2016-10-22 18:36:13,991 Epoch[0] Batch [30]	Speed: 139.49 samples/sec	Train-accuracy=0.585156
2016-10-22 18:36:23,163 Epoch[0] Batch [40]	Speed: 139.57 samples/sec	Train-accuracy=0.642188
2016-10-22 18:36:32,309 Epoch[0] Batch [50]	Speed: 139.97 samples/sec	Train-accuracy=0.728906
2016-10-22 18:36:41,426 Epoch[0] Batch [60]	Speed: 140.41 samples/sec	Train-accuracy=0.760156
2016-10-22 18:36:50,531 Epoch[0] Batch [70]	Speed: 140.60 samples/sec	Train-accuracy=0.778906
2016-10-22 18:36:59,631 Epoch[0] Batch [80]	Speed: 140.68 samples/sec	Train-accuracy=0.786719
2016-10-22 18:37:08,742 Epoch[0] Batch [90]	Speed: 140.51 samples/sec	Train-accuracy=0.797656
2016-10-22 18:37:17,857 Epoch[0] Batch [100]	Speed: 140.45 samples/sec	Train-accuracy=0.823438
...
```
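The `Speed` column in the log can be checked by hand. `mx.callback.Speedometer(batch_size, 10)` reports throughput every 10 batches, so each line covers 10 × 128 = 1280 samples. The sketch below recomputes the speed from two consecutive log timestamps; it only approximates the callback, which measures the interval with its own internal timer rather than from the logged timestamps.

```python
from datetime import datetime

batch_size, frequent = 128, 10
fmt = "%Y-%m-%d %H:%M:%S,%f"
# timestamps of the "Batch [10]" and "Batch [20]" lines
t_batch10 = datetime.strptime("2016-10-22 18:35:55,659", fmt)
t_batch20 = datetime.strptime("2016-10-22 18:36:04,814", fmt)
elapsed = (t_batch20 - t_batch10).total_seconds()  # 9.155 seconds
speed = frequent * batch_size / elapsed             # samples per second
print(round(speed, 1))  # 139.8, close to the logged 139.83
```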

After a single epoch, this model already reaches 83% validation accuracy, and after 8 epochs the validation accuracy rises to 86.4%.