# 迁移学习---使用google Inception-v3预训练模型实现花的分类

## 1. 下载数据集

```
curl -O http://download.tensorflow.org/example_images/flower_photos.tgz
tar xzf flower_photos.tgz
```

## 2. 下载预训练的 Inception-v3 模型

```
wget https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip
unzip inception_dec_2015.zip
```

## 3. 实现迁移学习代码

In [2]:
import glob
import os.path
import random
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

In [3]:
# Presumably the length of the feature vector produced at the bottleneck
# tensor -- not used in this part of the file; confirm against the model.
BOTTLENECK_TENSOR_SIZE = 2048
# Name of the graph tensor whose output is used as the image feature vector
# (presumably; the graph itself is loaded elsewhere).
BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'

# Name of the graph tensor that receives the raw JPEG bytes (presumably).
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'

# Directory and file name of the downloaded Inception-v3 model.
MODEL_DIR = '/path/to/model'
MODEL_FILE = 'classify_image_graph_def.pb'

# Directory where computed feature vectors are cached as text files, one
# sub-directory per class, so each image is only run through the model once.
CACHE_DIR = '/tmp/bottleneck'

# Image data folder; each sub-directory holds the images of one class.
INPUT_DATA = '/path/to/flower_data'
# Percentage of the data assigned to the validation and test sets.
VALIDATION_PERCENTAGE = 10
TEST_PERCENTAGE = 10

# Neural network training settings (not used in this part of the file).
LEARNING_RATE = 0.01
STEPS = 4000
BATCH = 100

In [4]:
def create_image_lists(testing_percentage, validation_percentage):
    """Collect the images under INPUT_DATA and split them per class.

    Every directory that os.walk finds below INPUT_DATA (the root itself is
    skipped) is treated as one class and is looked up as
    INPUT_DATA/<directory base name>. Each JPEG file in it is randomly
    assigned to the validation, testing or training list.

    Args:
        testing_percentage: Approximate percentage (0-100) of images that
            go to the testing list.
        validation_percentage: Approximate percentage (0-100) of images
            that go to the validation list.

    Returns:
        A dict keyed by lower-cased directory name. Each value is a dict
        with the keys 'dir' (original directory name) and 'training',
        'testing', 'validation' (lists of image base names). Directories
        without any JPEG file are left out.
    """
    result = {}
    # os.walk yields the root directory first, followed by its descendants.
    sub_dirs = [entry[0] for entry in os.walk(INPUT_DATA)]
    print('sub_dirs', sub_dirs)

    extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
    # Skip the first entry: it is INPUT_DATA itself, not a class directory.
    for sub_dir in sub_dirs[1:]:
        dir_name = os.path.basename(sub_dir)

        # Gather all JPEG files of this class. Where glob matches names
        # case-insensitively, '*.jpg' and '*.JPG' return the same files, so
        # duplicates are dropped to keep one image from being placed in more
        # than one split.
        file_list = []
        seen = set()
        for extension in extensions:
            file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)
            for path in glob.glob(file_glob):
                if path not in seen:
                    seen.add(path)
                    file_list.append(path)
        if not file_list:
            continue

        # The class label is the lower-cased directory name.
        label_name = dir_name.lower()
        training_images = []
        testing_images = []
        validation_images = []
        for file_name in file_list:
            base_name = os.path.basename(file_name)
            # Draw an integer in [0, 100) and bucket it by the percentages.
            chance = np.random.randint(100)
            if chance < validation_percentage:
                validation_images.append(base_name)
            elif chance < (validation_percentage + testing_percentage):
                testing_images.append(base_name)
            else:
                training_images.append(base_name)

        result[label_name] = {
            'dir': dir_name,
            'training': training_images,
            'testing': testing_images,
            'validation': validation_images,
        }
    return result

In [None]:
# Build the per-class train/test/validation split once as a sanity check;
# the returned dict is not stored here.
create_image_lists(TEST_PERCENTAGE, VALIDATION_PERCENTAGE)

In [5]:
def get_image_path(image_lists, image_dir, label_name, index, category):
    """Return the path of one image of a given class and data split.

    Args:
        image_lists: Dict mapping a label name to a dict with the key 'dir'
            and one list of image base names per category.
        image_dir: Root directory that the class sub-directory is joined to.
        label_name: Key of the class in image_lists.
        index: Any integer; it is wrapped around the length of the category
            list, so it never has to be a valid list index.
        category: Key of the split to pick from, e.g. 'training'.

    Returns:
        image_dir/<class dir>/<image base name> as a string.

    Raises:
        KeyError: If label_name or category is unknown.
        ZeroDivisionError: If the selected category list is empty.
    """
    label_lists = image_lists[label_name]
    category_list = label_lists[category]
    # Wrap the index so callers can pass arbitrary (e.g. random) integers.
    mod_index = index % len(category_list)
    base_name = category_list[mod_index]
    sub_dir = label_lists['dir']
    return os.path.join(image_dir, sub_dir, base_name)

In [6]:
def get_bottleneck_path(image_lists, label_name, index, category):
    """Return the cache file path for the feature vector of one image.

    The path mirrors the image path, but is rooted at CACHE_DIR and has
    '.txt' appended to the image file name.
    """
    cached_image_path = get_image_path(
        image_lists, CACHE_DIR, label_name, index, category)
    return cached_image_path + '.txt'

def run_bottleneck_on_image(sess, image_data, image_data_tensor, bottleneck_tensor):
    """Run one image through the model and return its feature vector.

    Args:
        sess: Object with a run(fetches, feed_dict) method, presumably a
            TensorFlow session holding the loaded model.
        image_data: Image content fed to image_data_tensor.
        image_data_tensor: Input tensor (feed key) for the image data.
        bottleneck_tensor: Tensor whose value is fetched.

    Returns:
        The fetched value with all size-1 dimensions removed.
    """
    feed = {image_data_tensor: image_data}
    raw_output = sess.run(bottleneck_tensor, feed)
    print(raw_output.shape)
    # Drop the size-1 axes so callers get a flat vector.
    return np.squeeze(raw_output)


In [7]:
def get_or_create_bottleneck(sess, image_lists, label_name, index, category, jpeg_data_tensor, bottleneck_tensor):
    """Return the feature vector of one image, computing it only if needed.

    The vector is cached as a comma-separated text file below CACHE_DIR. On
    a cache miss the image is read from INPUT_DATA, run through the model
    and the result is written to the cache file.

    Args:
        sess: Session passed on to run_bottleneck_on_image.
        image_lists: Dict as returned by create_image_lists.
        label_name: Key of the class in image_lists.
        index: Integer index of the image; wrapped by get_image_path.
        category: Split to pick from, e.g. 'training'.
        jpeg_data_tensor: Input tensor that receives the raw image bytes.
        bottleneck_tensor: Tensor whose value is the feature vector.

    Returns:
        The feature vector: the array from run_bottleneck_on_image on a
        cache miss, or a list of floats parsed from the cache file on a hit.
    """
    label_lists = image_lists[label_name]
    sub_dir = label_lists['dir']
    sub_dir_path = os.path.join(CACHE_DIR, sub_dir)
    # Create the per-class cache directory (and missing parents) on demand.
    if not os.path.exists(sub_dir_path):
        os.makedirs(sub_dir_path)
    bottleneck_path = get_bottleneck_path(image_lists, label_name, index, category)

    if not os.path.exists(bottleneck_path):
        # Cache miss: compute the feature vector from the original image.
        image_path = get_image_path(image_lists, INPUT_DATA, label_name, index, category)
        # Read the raw image bytes; the with-block closes the file handle.
        with gfile.FastGFile(image_path, 'rb') as image_file:
            image_data = image_file.read()
        bottleneck_values = run_bottleneck_on_image(sess, image_data, jpeg_data_tensor, bottleneck_tensor)
        bottleneck_string = ','.join(str(x) for x in bottleneck_values)
        with open(bottleneck_path, 'w') as bottleneck_file:
            bottleneck_file.write(bottleneck_string)
    else:
        # Cache hit: parse the stored comma-separated values.
        with open(bottleneck_path, 'r') as bottleneck_file:
            bottleneck_string = bottleneck_file.read()
        bottleneck_values = [float(x) for x in bottleneck_string.split(',')]

    return bottleneck_values