Commit
Showing 8 changed files with 314 additions and 3 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
|
||
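[The previous post](http://blog.csdn.net/u012162613/article/details/45397033) summarized the basics of using Keras. I expect everyone who has tried it was amazed at how concise it is. When I came across the framework on GitHub a little over ten days ago, few people were following Keras; over the last couple of days I have seen more and more people following and using it, on both GitHub and Weibo. So this post gives another brief introduction to using Keras, to help you get started. | ||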
|
||
It covers the following three topics: | ||
|
||
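- Training a CNN and saving the trained model. | ||
- Using the CNN as a feature extractor and training an SVM on the extracted features. | ||
- Visualizing the feature maps produced by the CNN's convolutional layers. | ||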
|
||
MNIST is used as the example again. The MNIST data used in the code can be downloaded here: | ||
[http://pan.baidu.com/s/1qCdS6](http://pan.baidu.com/s/1qCdS6). The code for this post is on my GitHub: [dive_into_keras](https://github.com/wepe/MachineLearning/tree/master/DeepLearning%20Tutorials) | ||
|
||
|
||
---------- | ||
|
||
|
||
###1. Loading the data | ||
|
||
The data is stored as image files. They are read with Python's PIL module and converted to numpy.array. The code for this part is in `data.py`. | ||
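
After the images are read into an array, a common preprocessing step is to scale the pixel values and subtract the mean. The following is only a minimal sketch of that step, not the contents of `data.py`: the function name `normalize` and the random stand-in images are assumptions made for illustration.

```python
import numpy as np


def normalize(images):
    """Scale pixel values to [0, 1], then subtract the global mean."""
    data = images.astype("float32")
    data /= data.max()   # the largest pixel value becomes 1.0
    data -= data.mean()  # the global mean becomes (approximately) 0.0
    return data


# Hypothetical stand-in for the MNIST array: 8 random 1x28x28 grayscale images.
rng = np.random.RandomState(0)
fake_images = rng.randint(0, 256, size=(8, 1, 28, 28)).astype("uint8")

normalized = normalize(fake_images)
print(normalized.shape, float(normalized.mean()))
```

Note that zero-centering subtracts the mean of the data; subtracting the standard deviation instead would shift the data but would not center it.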
|
||
|
||
---------- | ||
|
||
|
||
###2. Training the CNN and saving the trained CNN model | ||
|
||
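Split the data loaded in the previous step into training data (X_train, 30000 samples) and validation data (X_val, 12000 samples), then build and train the CNN model. The val-accuracy differs from epoch to epoch, so we save the model at the point where val-accuracy is highest, using Python's cPickle module. (The Keras developers have recently been adding support for saving models with hdf5. I tried it without success and opened an issue on GitHub that nobody has answered, so I assume it is not finished yet. hdf5 compresses better, so the saved file would be smaller.) | ||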
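
The "save only when validation accuracy improves" pattern can be sketched on its own, without Keras. This is a minimal illustration under stated assumptions: the "model" is a plain dict standing in for a trained network, the accuracies are made-up numbers, and `save_best` is a hypothetical helper. It uses Python 3's `pickle`, which replaces the Python 2 `cPickle` module used in this post.

```python
import os
import pickle
import tempfile


def save_best(val_accuracies, path):
    """Pickle a checkpoint each time validation accuracy improves."""
    best_accuracy = 0.0
    for epoch, val_accuracy in enumerate(val_accuracies):
        # Hypothetical stand-in for a trained model: a dict describing its state.
        model = {"epoch": epoch, "val_accuracy": val_accuracy}
        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            with open(path, "wb") as f:
                pickle.dump(model, f)
    return best_accuracy


# Made-up accuracies for illustration only, not results from this post.
checkpoint_path = os.path.join(tempfile.mkdtemp(), "model.pkl")
best = save_best([0.90, 0.95, 0.93, 0.96, 0.94], checkpoint_path)

with open(checkpoint_path, "rb") as f:
    restored = pickle.load(f)
print(best, restored)
```

Because the file is overwritten only on improvement, what remains on disk after the loop is the checkpoint from the best epoch rather than the last one.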
|
||
The code for this part is in `cnn.py`. Run: | ||
|
||
``` | ||
python cnn.py | ||
``` | ||
|
||
Epoch 4 reaches a validation accuracy of 96.45%. When the run finishes you get the file model.pkl, which holds the model that achieved 96.45%: | ||
|
||
![Training output of cnn.py](http://img.blog.csdn.net/20150508155724085) | ||
|
||
|
||
---------- | ||
|
||
|
||
###3. Using the CNN for feature extraction and training an SVM on the extracted features | ||
|
||
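The previous step produced a CNN model with a val-accuracy of 96.45%. Papers often take the output of a CNN's fully connected layer as features and then train another classifier on them. I tried that here: the output of the fully connected layer serves as each sample's feature vector, and an SVM is trained on it. The SVM is the implementation from scikit-learn. | ||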
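
The idea of "extract features with one model, train a second classifier on them" can be sketched without the trained CNN. Everything below is an assumption made for illustration: the data is synthetic (`make_blobs`), and `extract_features` is a fixed random projection followed by ReLU that merely stands in for the fully connected layer's output. Only the `SVC` step mirrors what `cnn-svm.py` does; `gamma="scale"` is set explicitly here and is not taken from the post.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.svm import SVC

# Synthetic stand-in for the image data: 600 points, 20 dimensions, 3 classes.
X, y = make_blobs(n_samples=600, n_features=20, centers=3, random_state=0)

# Hypothetical stand-in for the CNN's fully connected layer:
# a fixed random projection to 32 dimensions followed by ReLU.
rng = np.random.RandomState(0)
W = rng.normal(size=(20, 32))


def extract_features(inputs):
    return np.maximum(inputs.dot(W), 0.0)


features = extract_features(X)
train_x, test_x = features[:400], features[400:]
train_y, test_y = y[:400], y[400:]

# Train the SVM on the extracted features, then score it on held-out samples.
clf = SVC(C=1.0, kernel="rbf", gamma="scale")
clf.fit(train_x, train_y)
accuracy = float(np.mean(clf.predict(test_x) == test_y))
print("svm-on-features accuracy:", accuracy)
```

The SVM never sees the raw inputs, only the feature vectors, which is the same division of labor as between the CNN and the SVM in this section.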
|
||
The code for this part is in `cnn-svm.py`. Run: | ||
|
||
``` | ||
python cnn-svm.py | ||
``` | ||
|
||
This produces the output shown below. The cnn-svm accuracy rises to 97.89%: | ||
|
||
![Output of cnn-svm.py](http://img.blog.csdn.net/20150508155806689) | ||
|
||
|
||
---------- | ||
|
||
|
||
###4. Visualizing the feature maps after the CNN's convolutional layers | ||
|
||
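The feature maps after the convolutional layers and the feature vector after the fully connected layer are displayed as images using the matplotlib library. The code for this part is in `get_feature_map.py`. Run: | ||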
|
||
``` | ||
python get_feature_map.py | ||
``` | ||
|
||
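This gives the output of the fully connected layer and 4 feature maps output by the first convolutional layer: | ||
|
||
![Output of the fully connected layer](http://img.blog.csdn.net/20150508155842678) | ||
|
||
![Convolutional-layer feature map](http://img.blog.csdn.net/20150508155724909) | ||
|
||
![Convolutional-layer feature map](http://img.blog.csdn.net/20150508155810914) | ||
|
||
![Convolutional-layer feature map](http://img.blog.csdn.net/20150508155833190) | ||
|
||
![Convolutional-layer feature map](http://img.blog.csdn.net/20150508160043578) | ||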
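
The sizes of the feature maps shown above follow from the layer settings in `cnn.py`. The sketch below only does the arithmetic for that layer stack. It assumes stride-1 'valid' convolutions and max pooling that drops any remainder, which is consistent with the `Dense(16*4*4, 128)` input size in `cnn.py`; `conv_valid` and `max_pool` are hypothetical helper names.

```python
def conv_valid(size, kernel):
    """Output side length of a 'valid' convolution with stride 1."""
    return size - kernel + 1


def max_pool(size, pool):
    """Output side length of non-overlapping max pooling (remainder dropped)."""
    return size // pool


# (kind, number of filters, kernel or pool size), in the order used by cnn.py
layers = [
    ("conv", 4, 5),
    ("conv", 8, 3),
    ("pool", None, 2),
    ("conv", 16, 3),
    ("pool", None, 2),
]

channels, side = 1, 28  # MNIST input: 1 channel, 28x28 pixels
shapes = []
for kind, n_filters, k in layers:
    if kind == "conv":
        channels, side = n_filters, conv_valid(side, k)
    else:
        side = max_pool(side, k)
    shapes.append((channels, side, side))

flattened = channels * side * side
print(shapes, flattened)
```

Under these assumptions the first convolution yields 4 maps of 24x24 pixels, and the flattened size after the last pooling layer matches the input size of the first `Dense` layer.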
|
||
|
||
---------- | ||
Please credit the source when reposting: http://blog.csdn.net/u012162613/article/details/45581421 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
''' | ||
Author:wepon | ||
Code:https://github.com/wepe | ||
File: cnn-svm.py | ||
''' | ||
from __future__ import print_function | ||
import cPickle | ||
import theano | ||
from sklearn.ensemble import RandomForestClassifier | ||
from sklearn.svm import SVC | ||
from data import load_data | ||
|
||
|
||
def svc(traindata,trainlabel,testdata,testlabel): | ||
    print("Start training SVM...") | ||
    svcClf = SVC(C=1.0,kernel="rbf",cache_size=3000) | ||
    svcClf.fit(traindata,trainlabel) | ||
|
||
    pred_testlabel = svcClf.predict(testdata) | ||
    num = len(pred_testlabel) | ||
    accuracy = len([1 for i in range(num) if testlabel[i]==pred_testlabel[i]])/float(num) | ||
    print("cnn-svm Accuracy:",accuracy) | ||
|
||
def rf(traindata,trainlabel,testdata,testlabel): | ||
    print("Start training Random Forest...") | ||
    rfClf = RandomForestClassifier(n_estimators=400,criterion='gini') | ||
    rfClf.fit(traindata,trainlabel) | ||
|
||
    pred_testlabel = rfClf.predict(testdata) | ||
    num = len(pred_testlabel) | ||
    accuracy = len([1 for i in range(num) if testlabel[i]==pred_testlabel[i]])/float(num) | ||
    print("cnn-rf Accuracy:",accuracy) | ||
|
||
if __name__ == "__main__": | ||
    #load data,split into traindata and testdata | ||
    data, label = load_data() | ||
    (traindata,testdata) = (data[0:30000],data[30000:]) | ||
    (trainlabel,testlabel) = (label[0:30000],label[30000:]) | ||
    #use origin_model to predict testdata | ||
    origin_model = cPickle.load(open("model.pkl","rb")) | ||
    pred_testlabel = origin_model.predict_classes(testdata,batch_size=1, verbose=1) | ||
    num = len(testlabel) | ||
    accuracy = len([1 for i in range(num) if testlabel[i]==pred_testlabel[i]])/float(num) | ||
    print(" Origin_model Accuracy:",accuracy) | ||
    #define theano function to get output of FC layer | ||
    get_feature = theano.function([origin_model.layers[0].input],origin_model.layers[11].output(train=False),allow_input_downcast=False) | ||
    feature = get_feature(data) | ||
    #train svm using FC-layer feature | ||
    svc(feature[0:30000],label[0:30000],feature[30000:],label[30000:]) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
#coding:utf-8 | ||
|
||
''' | ||
Author:wepon | ||
Code:https://github.com/wepe | ||
File:cnn.py | ||
GPU run command: | ||
THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python cnn.py | ||
CPU run command: | ||
python cnn.py | ||
''' | ||
#Import the modules and components used | ||
from __future__ import absolute_import | ||
from __future__ import print_function | ||
from keras.models import Sequential | ||
from keras.layers.core import Dense, Dropout, Activation, Flatten | ||
from keras.layers.convolutional import Convolution2D, MaxPooling2D | ||
from keras.optimizers import SGD | ||
from keras.utils import np_utils, generic_utils | ||
from six.moves import range | ||
from data import load_data | ||
import random,cPickle | ||
|
||
|
||
nb_epoch = 5 | ||
batch_size = 100 | ||
nb_class = 10 | ||
#Load the data | ||
data, label = load_data() | ||
|
||
#The labels are the 10 classes 0~9. Keras expects binary class matrices, so convert them with this function provided by Keras | ||
label = np_utils.to_categorical(label, nb_class) | ||
|
||
#14 layers; each "add" represents one layer | ||
def create_model(): | ||
    model = Sequential() | ||
    model.add(Convolution2D(4, 1, 5, 5, border_mode='valid')) | ||
    model.add(Activation('relu')) | ||
|
||
    model.add(Convolution2D(8,4, 3, 3, border_mode='valid')) | ||
    model.add(Activation('relu')) | ||
    model.add(MaxPooling2D(poolsize=(2, 2))) | ||
|
||
    model.add(Convolution2D(16, 8, 3, 3, border_mode='valid')) | ||
    model.add(Activation('relu')) | ||
    model.add(MaxPooling2D(poolsize=(2, 2))) | ||
|
||
    model.add(Flatten()) | ||
    model.add(Dense(16*4*4, 128, init='normal')) | ||
    model.add(Activation('relu')) | ||
    model.add(Dropout(0.2)) | ||
|
||
    model.add(Dense(128, nb_class, init='normal')) | ||
    model.add(Activation('softmax')) | ||
    return model | ||
|
||
|
||
############# | ||
#Start training the model | ||
############## | ||
model = create_model() | ||
sgd = SGD(l2=0.0,lr=0.01, decay=1e-6, momentum=0.9, nesterov=True) | ||
model.compile(loss='categorical_crossentropy', optimizer=sgd) | ||
|
||
(X_train,X_val) = (data[0:30000],data[30000:]) | ||
(Y_train,Y_val) = (label[0:30000],label[30000:]) | ||
best_accuracy = 0.0 | ||
for e in range(nb_epoch): | ||
    #shuffle the data each epoch | ||
    num = len(Y_train) | ||
    index = [i for i in range(num)] | ||
    random.shuffle(index) | ||
    X_train = X_train[index] | ||
    Y_train = Y_train[index] | ||
|
||
    print('Epoch', e) | ||
    print("Training...") | ||
    #integer division, so that range() receives an int | ||
    batch_num = len(Y_train)//batch_size | ||
    progbar = generic_utils.Progbar(X_train.shape[0]) | ||
    for i in range(batch_num): | ||
        loss,accuracy = model.train(X_train[i*batch_size:(i+1)*batch_size], Y_train[i*batch_size:(i+1)*batch_size],accuracy=True) | ||
        progbar.add(batch_size, values=[("train loss", loss),("train accuracy:", accuracy)] ) | ||
|
||
    #save the model of best val-accuracy | ||
    print("Validation...") | ||
    val_loss,val_accuracy = model.evaluate(X_val, Y_val, batch_size=1,show_accuracy=True) | ||
    if best_accuracy<val_accuracy: | ||
        best_accuracy = val_accuracy | ||
        cPickle.dump(model,open("./model.pkl","wb")) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
#coding:utf-8 | ||
""" | ||
Author:wepon | ||
Code:https://github.com/wepe | ||
File: data.py | ||
download data here: http://pan.baidu.com/s/1qCdS6 | ||
""" | ||
|
||
|
||
import os | ||
from PIL import Image | ||
import numpy as np | ||
|
||
#Read the 42000 images in the mnist folder. They are grayscale, so there is 1 channel; for color input, replace the 1 with 3. Image size is 28*28 | ||
def load_data(): | ||
    data = np.empty((42000,1,28,28),dtype="float32") | ||
    label = np.empty((42000,),dtype="uint8") | ||
    imgs = os.listdir("./mnist") | ||
    num = len(imgs) | ||
    for i in range(num): | ||
        img = Image.open("./mnist/"+imgs[i]) | ||
        arr = np.asarray(img,dtype="float32") | ||
        data[i,:,:,:] = arr | ||
        label[i] = int(imgs[i].split('.')[0]) | ||
    #Scale to [0, 1], then subtract the mean (zero-centering) | ||
    scale = np.max(data) | ||
    data /= scale | ||
    mean = np.mean(data) | ||
    data -= mean | ||
    return data,label | ||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
""" | ||
Author:wepon | ||
Code:https://github.com/wepe | ||
File: get_feature_map.py | ||
1. visualize feature map of Convolution Layer, Fully Connected layer | ||
2. rewrite the code so you can treat CNN as feature extractor, see file: cnn-svm.py | ||
""" | ||
from __future__ import print_function | ||
import cPickle,theano | ||
from data import load_data | ||
import matplotlib.pyplot as plt | ||
import matplotlib.cm as cm | ||
|
||
#load the saved model | ||
model = cPickle.load(open("model.pkl","rb")) | ||
|
||
#define theano function to get output of FC layer | ||
get_feature = theano.function([model.layers[0].input],model.layers[11].output(train=False),allow_input_downcast=False) | ||
|
||
#define theano function to get output of a Conv layer (in create_model() of cnn.py, layers[2] is the second Convolution2D; layers[0] is the first) | ||
get_featuremap = theano.function([model.layers[0].input],model.layers[2].output(train=False),allow_input_downcast=False) | ||
|
||
|
||
data, label = load_data() | ||
|
||
# visualize feature of Fully Connected layer | ||
#data[0:10] contains 10 images | ||
feature = get_feature(data[0:10]) #visualize these images' FC-layer features | ||
plt.imshow(feature,cmap = cm.Greys_r) | ||
plt.show() | ||
|
||
#visualize feature map of Convolution Layer | ||
num_fmap = 4 #number of feature maps | ||
featuremap = get_featuremap(data[0:10]) #compute once, outside the loop | ||
for i in range(num_fmap): | ||
    plt.imshow(featuremap[0][i],cmap = cm.Greys_r) #visualize 4 feature maps of the first image | ||
    plt.show() |