# 迁移学习(Xception)

In [None]:
import os
import shutil
from keras.models import *
from keras.layers import *
from keras.applications import *
from keras.preprocessing.image import *
from keras.applications import *
import matplotlib.pyplot as plt
import pandas as pd
import cv2
import numpy as np

In [None]:
# Parameters shared by the cells below.

# Working-directory folders that receive the train / validation splits
# (each gets one sub-folder per class).
train_file_name, valid_file_name = 'train', 'valid'

# Xception works at a fixed 299x299 input size.
IM_WIDTH = IM_HEIGHT = 299
batch_size = 128

nb_epoch = 5
FC_SIZE = 128  # units in the new fully connected layer added on top of Xception

In [None]:
# Sort the labelled images into train/ and valid/ folders (one sub-folder per
# class, the layout flow_from_directory expects) and mirror the test images
# into test/. Symlinks are used, so no image data is copied.
valid_ratio = 0.1
cur_path = os.getcwd()
src_train_dir = os.path.join(cur_path, 'imagedata', 'train')
src_test_dir = os.path.join(cur_path, 'imagedata', 'test')

# os.listdir returns entries in arbitrary order; sort so the split is the
# same on every run.
train_filenames = sorted(os.listdir(src_train_dir))
test_filenames = os.listdir(src_test_dir)


def _split_class(filenames, prefix, ratio):
    """Return ``(train, valid)`` lists of the files whose name starts with *prefix*.

    The last ``int(count * ratio)`` matching files go to validation. Cutting
    at an explicit index (instead of slicing with ``[-k:]`` / ``[:-k]``)
    keeps ``k == 0`` from sending every file to validation and none to
    training.
    """
    members = [name for name in filenames if name.startswith(prefix)]
    cut = len(members) - int(len(members) * ratio)
    return members[:cut], members[cut:]


# Split each class separately so both appear in validation in the same
# proportion.
train_cat, valid_cat = _split_class(train_filenames, 'cat', valid_ratio)
train_dog, valid_dog = _split_class(train_filenames, 'dog', valid_ratio)
train_data = train_cat + train_dog
valid_data = valid_cat + valid_dog
valid_count = len(valid_data)


def cre_rem_dir(dirname):
    """(Re)create *dirname* as an empty directory, deleting any previous contents."""
    if os.path.exists(dirname):
        shutil.rmtree(dirname)
    os.mkdir(dirname)


def _link_all(filenames, src_dir, dst_dir):
    """Create ``dst_dir/<name>`` as a symlink to ``src_dir/<name>`` for each name."""
    for name in filenames:
        os.symlink(os.path.join(src_dir, name), os.path.join(dst_dir, name))


for split_dir in (train_file_name, valid_file_name):
    cre_rem_dir(split_dir)
    os.mkdir(os.path.join(split_dir, 'cat'))
    os.mkdir(os.path.join(split_dir, 'dog'))
cre_rem_dir('test')

_link_all(train_cat, src_train_dir, os.path.join(train_file_name, 'cat'))
_link_all(train_dog, src_train_dir, os.path.join(train_file_name, 'dog'))
_link_all(valid_cat, src_train_dir, os.path.join(valid_file_name, 'cat'))
_link_all(valid_dog, src_train_dir, os.path.join(valid_file_name, 'dog'))
_link_all(test_filenames, src_test_dir, 'test')

In [None]:
# Data augmentation and input pipelines.
# `Xception` (from keras.applications) is the model-builder function and has
# no `preprocess_input` attribute; the preprocessing function lives in the
# keras.applications.xception module.
from keras.applications.xception import preprocess_input

# Random augmentation is applied to the training images only.
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    rotation_range=30,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)

# Validation images are only preprocessed. Augmenting them would make the
# validation metrics noisy and not comparable from epoch to epoch.
valid_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

# Keras expects target_size as (height, width).
train_generator = train_datagen.flow_from_directory(
    train_file_name,
    target_size=(IM_HEIGHT, IM_WIDTH),
    batch_size=batch_size,
)
validation_generator = valid_datagen.flow_from_directory(
    valid_file_name,
    target_size=(IM_HEIGHT, IM_WIDTH),
    batch_size=batch_size,
)



In [None]:
# ImageNet-pretrained Xception without its top (classification) layers.
base_model = Xception(include_top=False, weights='imagenet')

# Add a new classification head.
def add_new_last_layer(base_model, nb_classes):
    """Stack a new classifier head on top of *base_model*.

    The head is GlobalAveragePooling2D -> Dense(FC_SIZE, relu) ->
    Dense(nb_classes, softmax).

    Args:
        base_model: model whose ``output`` feeds the new head.
        nb_classes: number of output classes (units of the softmax layer).

    Returns:
        A Keras ``Model`` from ``base_model.input`` to the softmax output.
    """
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(FC_SIZE, activation='relu')(x)
    predictions = Dense(nb_classes, activation='softmax')(x)
    # Keras 2 names these arguments inputs/outputs. The Keras 1 spelling
    # (input=/output=) is deprecated and rejected by newer Keras versions.
    return Model(inputs=base_model.input, outputs=predictions)


def setup_to_transfer_learn(model, base_model):
    """Freeze every layer of *base_model*, then compile *model*.

    Layers of *model* that are not in *base_model* keep their current
    trainable flag. Keras only picks up ``trainable`` changes when the model
    is compiled, so the compile happens once, after all layers are frozen.

    Args:
        model: model to compile (built on top of *base_model*).
        base_model: model whose layers are frozen.
    """
    for layer in base_model.layers:
        layer.trainable = False
    # Compile exactly once. Previously this sat inside the loop, so it ran
    # once per base layer (and never when base_model had no layers).
    model.compile(optimizer='rmsprop',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

        
# Two output classes, one per class sub-folder (cat/ and dog/).
model = add_new_last_layer(base_model, nb_classes=2)
setup_to_transfer_learn(model, base_model=base_model)

In [None]:
# Train the model (the Xception base layers were frozen above) and save it.
history = model.fit_generator(
    train_generator,
    # One epoch = one pass over each directory. These counts are given
    # explicitly because some Keras 2 releases require them for generators.
    steps_per_epoch=int(np.ceil(train_generator.samples / float(batch_size))),
    epochs=nb_epoch,  # Keras 2 name of the Keras 1 `nb_epoch` argument
    validation_data=validation_generator,
    validation_steps=int(np.ceil(validation_generator.samples / float(batch_size))),
    # class_weight='auto' was removed: Keras expects a {class_index: weight}
    # dict or None, and the string is not a supported value.
)
model.save('Xception')

In [None]:
def plot_training(history):
    """Plot per-epoch training vs. validation accuracy and loss.

    Args:
        history: the object returned by ``fit``/``fit_generator``. Its
            ``history`` dict must contain accuracy and loss curves for both
            training and validation.
    """
    logs = history.history
    # Keras < 2.3 logs the 'accuracy' metric under 'acc'; newer versions
    # use 'accuracy'. Accept either.
    acc_key = 'acc' if 'acc' in logs else 'accuracy'
    acc = logs[acc_key]
    val_acc = logs['val_' + acc_key]
    loss = logs['loss']
    val_loss = logs['val_loss']
    epochs = range(len(acc))

    # Training points in red and validation line in blue, with a legend, so
    # the two curves can be told apart.
    plt.figure()
    plt.plot(epochs, acc, 'r.', label='Training')
    plt.plot(epochs, val_acc, 'b-', label='Validation')
    plt.title('Training and validation accuracy')
    plt.legend()

    plt.figure()
    plt.plot(epochs, loss, 'r.', label='Training')
    plt.plot(epochs, val_loss, 'b-', label='Validation')
    plt.title('Training and validation loss')
    plt.legend()

    plt.show()


plot_training(history)

In [None]:
from keras.preprocessing import image
# `Xception` is the model-builder function; its preprocessing function lives
# in the keras.applications.xception module.
from keras.applications.xception import preprocess_input

# Load each test image, classify it, and write the submission file.
current_path = os.getcwd()
absolute_path = os.path.join(current_path, 'imagedata', 'test')
dataframe = pd.read_csv(os.path.join(current_path, 'imagedata', 'sample_submission.csv'))

for image_name in os.listdir(absolute_path):
    stem = os.path.splitext(image_name)[0]
    # Test images are named <number>.<ext>. Skip anything else (e.g. hidden
    # files) instead of crashing in int().
    if not stem.isdigit():
        continue
    index = int(stem)

    # Image preprocessing: same size and preprocessing as the training data.
    img = image.load_img(os.path.join(absolute_path, image_name),
                         target_size=(IM_HEIGHT, IM_WIDTH))
    x = np.expand_dims(image.img_to_array(img), axis=0)
    x = preprocess_input(x)

    # Classify the image. flow_from_directory numbers classes in
    # alphanumeric folder order, so column 0 is 'cat' and column 1 is 'dog'.
    # A hard decision is written as 0.005 / 0.995.
    preds = model.predict(x)
    pred = 0.005 if preds[0][0] > preds[0][1] else 0.995
    # DataFrame.set_value was removed in pandas 1.0; .at is the label-based
    # replacement. NOTE(review): row index-1 assumes the template lists
    # ids 1..N in order - confirm against sample_submission.csv.
    dataframe.at[index - 1, 'label'] = pred

print(dataframe)

dataframe.to_csv('test.csv', index=False, sep=',')