# 图像识别迁移学习
使用 VGG16、InceptionV3、ResNet50、ResNet152 等预训练模型，通过迁移学习在自己的图像识别数据集上进行微调。

请先将自己的数据集整理成以下目录结构：
![](./data_structure.png)

即在数据目录下分别建立 `train`、`valid`、`test` 三个子目录，每个子目录下按类别编号建立文件夹（如 `0`、`1`、…），再把图片放入对应的类别文件夹。下面以 Oxford 102 Flowers（102 类花卉）数据集为例。

### 下载数据与预处理

In [None]:
#!/usr/bin/env python
import os
import glob
import tarfile
import numpy as np
from scipy.io import loadmat
from shutil import copyfile, rmtree
import sys
import config

if sys.version_info[0] >= 3:
    from urllib.request import urlretrieve
else:
    from urllib import urlretrieve

# Directory (relative to the process working directory) that holds the raw
# downloads and the extracted archive.
data_path = 'data'


def download_file(url, dest=None):
    """Download ``url`` to ``dest``.

    Args:
        url: URL to fetch (anything ``urlretrieve`` accepts).
        dest: target file path. Defaults to ``data_path/<last URL segment>``.

    The payload is first written to ``dest + '.part'`` and renamed to ``dest``
    only after ``urlretrieve`` returns successfully. An interrupted or failed
    download therefore never leaves a truncated file at ``dest``. Callers that
    test ``os.path.isfile(dest)`` to skip re-downloading would otherwise
    mistake such a file for a complete one.

    Raises:
        Whatever ``urlretrieve`` raises; the partial ``.part`` file is removed
        first.
    """
    if not dest:
        dest = os.path.join(data_path, url.split('/')[-1])
    tmp_dest = dest + '.part'
    try:
        urlretrieve(url, tmp_dest)
    except BaseException:
        # BaseException so Ctrl-C (KeyboardInterrupt) also cleans up.
        if os.path.exists(tmp_dest):
            os.remove(tmp_dest)
        raise
    # os.replace (Python 3.3+) overwrites an existing dest on every platform;
    # fall back to os.rename on Python 2, which this file still supports.
    replace = getattr(os, 'replace', os.rename)
    replace(tmp_dest, dest)


# Download the Oxford102 dataset (images, labels and the official
# train/test/valid split) into ``data_path``, skipping files already present.
if not os.path.exists(data_path):
    os.mkdir(data_path)

flowers_archive_path = os.path.join(data_path, '102flowers.tgz')
if not os.path.isfile(flowers_archive_path):
    print('Downloading images...')
    download_file('http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz')
# The archive is extracted on every run. Use a context manager so the file
# handle is closed. Where the running Python provides extraction filters,
# opt into 'data' so members with absolute paths or '..' components are
# rejected instead of being written outside data_path.
with tarfile.open(flowers_archive_path) as flowers_archive:
    if hasattr(tarfile, 'data_filter'):
        flowers_archive.extractall(path=data_path, filter='data')
    else:
        flowers_archive.extractall(path=data_path)

image_labels_path = os.path.join(data_path, 'imagelabels.mat')
if not os.path.isfile(image_labels_path):
    print("Downloading image labels...")
    download_file('http://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat')

setid_path = os.path.join(data_path, 'setid.mat')
if not os.path.isfile(setid_path):
    print("Downloading train/test/valid splits...")
    download_file('http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat')

# Read .mat file containing training, testing, and validation sets.
# The stored image ids are 1-based; subtract one to index numpy arrays.
setid = loadmat(setid_path)

idx_train = setid['trnid'][0] - 1
idx_test = setid['tstid'][0] - 1
idx_valid = setid['valid'][0] - 1

# Read .mat file containing image labels.
image_labels = loadmat(image_labels_path)['labels'][0]

# Subtract one to get 0-based labels
image_labels -= 1

# NOTE(review): pairing by position assumes the sorted file names follow the
# same order as the label vector (e.g. zero-padded names) - confirm.
files = sorted(glob.glob(os.path.join(data_path, 'jpg', '*.jpg')))
# zip() silently truncates to the shorter input, which would shift labels.
# Fail loudly if the extracted images and the label file disagree.
if len(files) != len(image_labels):
    raise RuntimeError(
        'Found %d images but %d labels; is the archive fully extracted?'
        % (len(files), len(image_labels)))
# Rows of (image_path, label). np.array coerces both columns to strings, so
# the label column can be used directly as a class directory name.
labels = np.array(list(zip(files, image_labels)))

# Directory of this script, used when building destination paths for images.
# ``__file__`` is undefined inside a notebook kernel, so fall back to the
# process working directory there.
try:
    cwd = os.path.dirname(os.path.realpath(__file__))
except NameError:
    cwd = os.getcwd()

# Start from an empty output directory on every run.
if os.path.exists(config.data_dir):
    rmtree(config.data_dir, ignore_errors=True)
os.mkdir(config.data_dir)


def move_files(dir_name, labels, num_classes=102):
    """Copy labelled images into ``config.data_dir/<dir_name>/<label>/``.

    Despite the name, files are copied with ``shutil.copyfile``, not moved.

    Args:
        dir_name: split sub-directory to populate, e.g. 'train', 'test' or
            'valid'.
        labels: iterable of ``(image_path, label)`` rows. The label is used
            (via ``str``) as the class sub-directory name.
        num_classes: number of class directories ``0 .. num_classes-1`` to
            create. Defaults to 102, the Oxford 102 class count.
    """
    cur_dir_path = os.path.join(config.data_dir, dir_name)
    if not os.path.exists(cur_dir_path):
        os.mkdir(cur_dir_path)

    # Create every class directory, even for classes with no images in this
    # split, so all splits share the same class layout.
    for class_id in range(num_classes):
        os.mkdir(os.path.join(cur_dir_path, str(class_id)))

    for src, label in labels:
        src = str(src)
        # Build the destination from the same base as the directories created
        # above. Like ``src``, it is resolved against the process working
        # directory, so the target class directory always exists regardless of
        # where the script is launched from.
        dst = os.path.join(cur_dir_path, str(label), os.path.basename(src))
        copyfile(src, dst)


# Populate the three split directories from the official index lists.
# NOTE(review): the official *test* indices feed 'train' and the official
# *train* indices feed 'test'. This is presumably deliberate (the published
# Oxford 102 test split is larger than its train split), but confirm before
# comparing reported test accuracy with other work.
for split_name, split_indices in (('train', idx_test),
                                  ('test', idx_train),
                                  ('valid', idx_valid)):
    move_files(split_name, labels[split_indices, :])

### 迁移学习与调优

In [None]:
import numpy as np
import argparse
import traceback
import os

np.random.seed(1337)  # for reproducibility

import util
import config


def parse_args(argv=None):
    """Parse the command-line options for fine-tuning.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, as before. Passing an explicit list lets
            the parser be used from a notebook or a test.

    Returns:
        argparse.Namespace with ``data_dir`` (None if omitted), ``model``
        (one of the config.MODEL_* names), ``nb_epoch`` (default 1000) and
        ``freeze_layers_number`` (None if omitted).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', help='Path to data dir')
    parser.add_argument('--model', type=str, required=True, help='Base model architecture', choices=[
        config.MODEL_RESNET50,
        config.MODEL_RESNET152,
        config.MODEL_INCEPTION_V3,
        config.MODEL_VGG16])
    parser.add_argument('--nb_epoch', type=int, default=1000, help='Number of training epochs')
    # NOTE(review): no default, so this is None when omitted; presumably the
    # model class then applies its own freezing policy - confirm in util.
    parser.add_argument('--freeze_layers_number', type=int, help='will freeze the first N layers and unfreeze the rest')
    return parser.parse_args(argv)


def init():
    """Prepare global state before training.

    Takes the lock with ``util.lock()`` and runs util's setup steps in their
    required order. It then ensures ``config.trained_dir`` exists (a single
    ``os.mkdir``, so its parent must already exist). The lock is not released
    here; the ``__main__`` block calls ``util.unlock()`` in its ``finally``.
    """
    util.lock()
    util.set_img_format()
    util.override_keras_directory_iterator_next()
    util.set_classes_from_train_dir()
    util.set_samples_info()

    output_dir = config.trained_dir
    if os.path.exists(output_dir):
        return
    os.mkdir(output_dir)


def train(nb_epoch, freeze_layers_number):
    """Build the configured model wrapper and run its training loop.

    Args:
        nb_epoch: number of epochs forwarded to the model wrapper.
        freeze_layers_number: forwarded unchanged; may be None when the CLI
            option was omitted.
    """
    weights = util.get_class_weight(config.train_dir)
    wrapper = util.get_model_class_instance(
        nb_epoch=nb_epoch,
        class_weight=weights,
        freeze_layers_number=freeze_layers_number,
    )
    wrapper.train()
    print('Training is finished!')


if __name__ == '__main__':
    import sys

    # Parse arguments and apply config overrides before anything that takes
    # the lock. An argparse error (SystemExit) or a failing config.set_paths()
    # then never reaches the ``finally`` below, which would otherwise call
    # util.unlock() for a lock this process never acquired.
    args = parse_args()
    if args.data_dir:
        config.data_dir = args.data_dir
        config.set_paths()
    # --model is declared required=True in parse_args, so it is always set.
    config.model = args.model

    failed = False
    try:
        init()
        train(args.nb_epoch, args.freeze_layers_number)
    except Exception as e:
        print(e)
        traceback.print_exc()
        # Remember the failure so the process exits non-zero; previously
        # errors were reported but the exit status was always 0.
        failed = True
    finally:
        util.unlock()

    if failed:
        sys.exit(1)