According to surveys by automotive safety agencies, one in five traffic accidents is caused by a distracted driver. Every year, distracted driving injures roughly 42,500 people and kills about 3,000. These numbers are staggering.

State Farm wants to use in-vehicle dashboard cameras to detect whether a driver is in a distracted state, so that a warning can be issued.

What follows is my actual development process. I won't jump straight to the optimal solution; instead I'll write down every pitfall I ran into along the way, because these pitfalls really are classics.

Since the training workload is fairly heavy, all training for this project was done on rented Amazon EC2 instances, which are reasonably priced; more on that below.

The dataset comes from kaggle; there are three files to download, listed below. imgs.zip is the labeled dataset of driver-state images captured by the camera, and it is about 4 GB in size.
- imgs.zip - a zip archive of all training/test images (you need to download this yourself)
- sample_submission.csv - the required format for kaggle submissions
- driver_imgs_list.csv - per-file metadata: for each file name, the driver ID and the driver-state class ID of the image.
Now let's unzip imgs.zip and take a look.
from keras.layers import Input
from keras.layers.core import Lambda
from keras.layers import Dense, GlobalAveragePooling2D, Dropout
from keras.models import Model
from keras.preprocessing.image import ImageDataGenerator
from urllib.request import urlretrieve
from os.path import isfile, isdir
from tqdm import tqdm
import zipfile
import os
import h5py
import numpy as np
zip_name = 'imgs.zip'
train_dir_name = 'train'
test_dir_name = 'test'
link_path = 'train_link'
link_train_path = 'train_link/train'
link_valid_path = 'train_link/validation'
test_link = 'test_link'
test_link_path = 'test_link/data'
resnet_50_model_save_name = 'model_resnet50.h5'
inceptionv3_model_save_name = 'model_inceptionv3.h5'
xception_model_save_name = 'model_xception.h5'
## check if the train and test data exist
if not isdir(train_dir_name) or not isdir(test_dir_name):
    if not isfile(zip_name):
        print ("Please download imgs.zip from kaggle!")
        assert(False)
    else:
        with zipfile.ZipFile(zip_name) as azip:
            print ("Now to extract %s " % (zip_name))
            azip.extractall()
print ("Data is ready!")
Data is ready!
We will detect 10 driver states, as follows:
- c0: safe driving
- c1: texting - right hand
- c2: talking on the phone - right hand
- c3: texting - left hand
- c4: talking on the phone - left hand
- c5: operating the radio
- c6: drinking
- c7: reaching behind
- c8: hair and makeup
- c9: talking to passenger

The images for each driving state are stored separately: folders c0-c9 each hold the images of the corresponding state. At this point the directory structure looks roughly like this:
|----imgs.zip
|----train
|-----c0
|-----c1
|-----c2
|-----c3
|-----c4
|-----c5
|-----c6
|-----c7
|-----c8
|-----c9
|----test
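This one-folder-per-class layout is the convention Keras's `flow_from_directory` expects, and it also makes label extraction trivial: the label of an image is just the name of its parent folder. As a minimal stdlib sketch (the helper `list_labeled_images` and the tiny fake tree are made up for illustration, not part of the project code):

```python
import os
import tempfile

def list_labeled_images(train_dir):
    """Walk a train/c0..c9 layout and return (relative_path, class_label) pairs.

    The class label is simply the name of the immediate parent folder.
    """
    pairs = []
    for class_name in sorted(os.listdir(train_dir)):
        class_dir = os.path.join(train_dir, class_name)
        if not os.path.isdir(class_dir):
            continue
        for fname in sorted(os.listdir(class_dir)):
            pairs.append((os.path.join(class_name, fname), class_name))
    return pairs

# Build a tiny fake layout to demonstrate (two classes, two images each).
root = tempfile.mkdtemp()
for cls, imgs in [('c0', ['img_1.jpg', 'img_2.jpg']),
                  ('c1', ['img_3.jpg', 'img_4.jpg'])]:
    os.makedirs(os.path.join(root, cls))
    for img in imgs:
        open(os.path.join(root, cls, img), 'w').close()

pairs = list_labeled_images(root)
print(pairs)
```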
Next, let's gather some basic information about the dataset:
- count the number of training and test samples
- look at how the training images are distributed across the classes
import os
## count the training and test files
train_class_dir_names = os.listdir(train_dir_name)
test_size = len(os.listdir(test_dir_name))
train_size = 0
train_class_size = {}
for dname in train_class_dir_names:
    file_names = os.listdir(train_dir_name + '/' + dname)
    train_class_size[dname] = len(file_names)
    train_size = train_class_size[dname] + train_size
print ("Test file numbers: ", test_size)
print ("Train file numbers: ", train_size)
Test file numbers: 79726
Train file numbers: 22424
import matplotlib.pyplot as plt
%matplotlib inline
fig = plt.figure()
plt.bar(train_class_size.keys(), train_class_size.values(), 0.4, color="green")
plt.xlabel("Classes")
plt.ylabel("File nums")
plt.title("Classes distribution")
plt.show()
From the results above we can see that we have 79,726 test images and 22,424 training images in total. Within the training set, each state (class) contains roughly 2,000 images, so the class distribution is fairly balanced. The visualization above is organized by state; now let's look from another angle and see roughly how many images each driver contributes.
Next we need to extract driver_imgs_list.csv.zip:
# get driver_imgs_list_file
driver_imgs_list_zip = 'driver_imgs_list.csv.zip'
driver_imgs_list_file = 'driver_imgs_list.csv'
if not isfile(driver_imgs_list_file):
    if not isfile(driver_imgs_list_zip):
        print ("Please download driver_imgs_list.csv.zip from kaggle!")
        assert(False)
    else:
        with zipfile.ZipFile(driver_imgs_list_zip) as azip:
            print ("Now to extract %s " % (driver_imgs_list_zip))
            azip.extractall()
import pandas as pd
df = pd.read_csv(driver_imgs_list_file)
df.describe()
| | subject | classname | img |
|---|---|---|---|
| count | 22424 | 22424 | 22424 |
| unique | 26 | 10 | 22424 |
| top | p021 | c0 | img_2362.jpg |
| freq | 1237 | 2489 | 1 |
ts = df['subject'].value_counts()
print (ts)
fig = plt.figure(figsize=(30,10))
plt.bar(ts.index.tolist(), ts.iloc[:].tolist(), 0.4, color="green")
plt.xlabel("Driver ID")
plt.ylabel("File nums")
plt.title("Driver ID distribution")
plt.show()
p021 1237
p022 1233
p024 1226
p026 1196
p016 1078
p066 1034
p049 1011
p051 920
p014 876
p015 875
p035 848
p047 835
p012 823
p081 823
p064 820
p075 814
p061 809
p056 794
p050 790
p052 740
p002 725
p045 724
p039 651
p041 605
p042 591
p072 346
Name: subject, dtype: int64
As we can see, the training data comes from 26 different drivers, and each driver has images of every state. Driver p072 has the fewest images, 346, while driver p021 has the most, 1,237. So if we group the data by driver ID, the distribution is not uniform. But this does not affect our training, because what we mainly care about is whether the images are balanced across states, not across drivers. Next, let's quickly visualize some samples from one class.
import cv2
import numpy as np
state_des = {'c0':'safe driving','c1':'texting - right hand','c2':'talking on the phone - right','c3':'texting - left hand', \
             'c4':'talking on the phone - left hand','c5':'operating the radio','c6':'drinking','c7':'reaching behind','c8':'hair and makeup', \
             'c9':'talking to passenger'}
## class that you want to display
c = 0
## randomly choose file names from the class
dis_dir = train_dir_name + '/c' + str(c)
dis_filenames = os.listdir(dis_dir)
dis_list = np.random.randint(len(dis_filenames), size=(6))
dis_list = [dis_filenames[index] for index in dis_list]
plt.figure(1, figsize=(13, 13))
for i, filename in enumerate(dis_list):
    image = cv2.imread(dis_dir + '/' + str(filename))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    ax1 = plt.subplot(3, 3, i+1)
    plt.imshow(image)
    plt.axis("off")
    plt.title(state_des['c'+str(c)] + "\n" + str(image.shape))
plt.show()
After working through a few kaggle projects, I've developed a rough intuition for computer vision projects: if you plan to use a popular CNN model, transfer learning is always worth trying. For CNNs, the "low-level information" in images (lines, edges, eyes, ears, cat faces, dog faces) is shareable, so we can take the general "knowledge" that the big models have learned on large computer-vision datasets and transfer it to our own project, helping us extract general-purpose features. Furthermore:
- If we only have a small dataset, we can train just the weights of the top layers (the fully connected and output layers) and freeze everything below them. If the dataset is small and you insist on training layers beyond the top layers anyway, you are very likely to destroy the knowledge the pretrained model learned from its huge dataset and end up with worse results.
- If we have a medium-sized dataset, we can unfreeze a small number of convolutional layers and fine-tune their weights. To keep large weight updates from destroying what was previously learned, a small learning rate is recommended.
- If we have a huge dataset and the project has enough time for training, transfer learning is still worthwhile, but now we can unfreeze more layers than in the second case, possibly even all of them.
We already explored the dataset and found that the training set holds a bit over 20,000 images, so I'd classify our situation as the first case (after all, twenty-odd thousand images is still tiny compared to millions). So my plan is to do transfer learning on a single model and only update the weights of the top layers (you'll see later why this was highlighted in red).
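To make the "first case" concrete, here is a minimal numpy sketch of training only a top layer on top of a frozen feature extractor (everything here, `frozen_W`, `extract_features`, the toy labels, is made up for illustration; the real project uses a pretrained Keras model instead):

```python
import numpy as np

rng = np.random.RandomState(0)

# Pretend "pretrained" feature extractor: a frozen random projection + ReLU.
# frozen_W stands in for the conv-base weights; it is never updated.
frozen_W = rng.randn(20, 8)
frozen_W_before = frozen_W.copy()

def extract_features(X):
    return np.maximum(X @ frozen_W, 0)  # frozen forward pass only

# Trainable "top layer": logistic regression on the frozen features.
top_W = np.zeros(8)
top_b = 0.0

X = rng.randn(200, 20)
y = (X[:, 0] > 0).astype(float)  # toy binary labels

feats = extract_features(X)
lr = 0.1
for _ in range(500):  # gradient descent updates ONLY top_W / top_b
    p = 1.0 / (1.0 + np.exp(-(feats @ top_W + top_b)))
    grad = p - y
    top_W -= lr * feats.T @ grad / len(y)
    top_b -= lr * grad.mean()

acc = (((feats @ top_W + top_b) > 0) == (y == 1)).mean()
print("top-layer accuracy:", acc)
```

The point of the sketch is the division of labor: the base weights never move, so all of the pretrained "knowledge" survives, while the small top layer adapts to the new task.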
The first model I tried is resnet-50, which won the ILSVRC 2015 competition. It uses residual blocks to solve the degradation problem that appears as network depth grows: accuracy first rises, then saturates, and adding even more depth makes accuracy drop (note: this is training-set accuracy, not validation accuracy, so it is not an overfitting problem). I was confident a 50-layer network could easily handle this project without any worry about degradation, so why not? Without further ado, let's load the pretrained weights and start tuning!
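The residual-block idea mentioned above (output = learned residual + identity shortcut) can be sketched in a few lines of numpy. This is a toy dense version with made-up weights, not resnet's actual convolutional blocks:

```python
import numpy as np

rng = np.random.RandomState(1)
# Small random weights so the learned residual F(x) starts out near zero.
W1, W2 = rng.randn(16, 16) * 0.01, rng.randn(16, 16) * 0.01

def residual_block(x):
    # main branch: two transforms learn a residual F(x)
    f = np.maximum(x @ W1, 0) @ W2
    # shortcut: the input is added back, so the block only needs to
    # learn F(x) = H(x) - x instead of the full mapping H(x)
    return np.maximum(f + x, 0)

x = np.abs(rng.randn(4, 16))  # non-negative toy activations
out = residual_block(x)
print(out.shape)
```

Because the shortcut passes `x` through untouched, an untrained block is close to the identity, which is exactly why stacking many of them does not degrade training accuracy the way plain deep stacks do.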
A note here: my PC's graphics card is a GTX 960, which is fine for training small neural networks but rather underpowered for something the size of resnet. So I rented an Amazon AWS instance, a p3.2xlarge, whose spot price is about 1 USD per hour. If you have never touched AWS, you can learn how to use it for deep learning from this article (yet another free ad for Amazon). Also, if you want to use AWS from China, you will need a way around the firewall (a VPN). If that's not an option, Alibaba offers similar cloud instances, just a bit pricier; weigh it for yourself!
As mentioned above, I will only train the top-layer weights. So, to speed up the tuning that follows, I use bottleneck features to avoid repeating the same forward passes: I run every sample through the network once and save the output just below the top layers (for resnet this output has dimension 1×1×2048; from now on I'll call it the feature vector), then feed the saved features directly into the top layers. This makes tuning much faster, a classic trade of space for time!
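The caching pattern can be sketched without Keras. Here `frozen_base` is a cheap stand-in for resnet's frozen forward pass, and the cache file name is made up; the structure (one expensive pass, saved to disk, reloaded for every tuning run) is the point:

```python
import os
import tempfile
import numpy as np

rng = np.random.RandomState(42)

# Toy stand-in for the frozen convolutional base: a fixed, deterministic map.
W_base = rng.randn(100, 2048)

def frozen_base(X):
    # in the real project this is resnet50's forward pass up to the top layers
    return np.tanh(X @ W_base)

X_train = rng.randn(32, 100)

cache_path = os.path.join(tempfile.mkdtemp(), 'bottleneck_train.npy')

# 1) one forward pass through the frozen base, saved to disk ...
np.save(cache_path, frozen_base(X_train))

# 2) ... then every tuning run reloads the cached features instead of
#    recomputing the expensive forward pass.
feats = np.load(cache_path)
print(feats.shape)
```

Since the base is frozen, its output for a given image never changes, so caching is lossless; only the cheap top-layer training is repeated while tuning hyperparameters.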
## load pretrained resnet
from keras.applications import resnet50
resNet_input_shape = (224,224,3)
res_x = Input(shape=resNet_input_shape)
res_x = Lambda(resnet50.preprocess_input)(res_x)
res_model = resnet50.ResNet50(include_top=False, weights='imagenet', input_tensor=res_x, input_shape=resNet_input_shape)
res_model.summary()
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 224, 224, 3) 0
__________________________________________________________________________________________________
lambda_1 (Lambda) (None, 224, 224, 3) 0 input_1[0][0]
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D) (None, 230, 230, 3) 0 lambda_1[0][0]
__________________________________________________________________________________________________
conv1 (Conv2D) (None, 112, 112, 64) 9472 conv1_pad[0][0]
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization) (None, 112, 112, 64) 256 conv1[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, 112, 112, 64) 0 bn_conv1[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 55, 55, 64) 0 activation_1[0][0]
__________________________________________________________________________________________________
res2a_branch2a (Conv2D) (None, 55, 55, 64) 4160 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
bn2a_branch2a (BatchNormalizati (None, 55, 55, 64) 256 res2a_branch2a[0][0]
__________________________________________________________________________________________________
activation_2 (Activation) (None, 55, 55, 64) 0 bn2a_branch2a[0][0]
__________________________________________________________________________________________________
res2a_branch2b (Conv2D) (None, 55, 55, 64) 36928 activation_2[0][0]
__________________________________________________________________________________________________
bn2a_branch2b (BatchNormalizati (None, 55, 55, 64) 256 res2a_branch2b[0][0]
__________________________________________________________________________________________________
activation_3 (Activation) (None, 55, 55, 64) 0 bn2a_branch2b[0][0]
__________________________________________________________________________________________________
res2a_branch2c (Conv2D) (None, 55, 55, 256) 16640 activation_3[0][0]
__________________________________________________________________________________________________
res2a_branch1 (Conv2D) (None, 55, 55, 256) 16640 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
bn2a_branch2c (BatchNormalizati (None, 55, 55, 256) 1024 res2a_branch2c[0][0]
__________________________________________________________________________________________________
bn2a_branch1 (BatchNormalizatio (None, 55, 55, 256) 1024 res2a_branch1[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 55, 55, 256) 0 bn2a_branch2c[0][0]
bn2a_branch1[0][0]
__________________________________________________________________________________________________
activation_4 (Activation) (None, 55, 55, 256) 0 add_1[0][0]
__________________________________________________________________________________________________
res2b_branch2a (Conv2D) (None, 55, 55, 64) 16448 activation_4[0][0]
__________________________________________________________________________________________________
bn2b_branch2a (BatchNormalizati (None, 55, 55, 64) 256 res2b_branch2a[0][0]
__________________________________________________________________________________________________
activation_5 (Activation) (None, 55, 55, 64) 0 bn2b_branch2a[0][0]
__________________________________________________________________________________________________
res2b_branch2b (Conv2D) (None, 55, 55, 64) 36928 activation_5[0][0]
__________________________________________________________________________________________________
bn2b_branch2b (BatchNormalizati (None, 55, 55, 64) 256 res2b_branch2b[0][0]
__________________________________________________________________________________________________
activation_6 (Activation) (None, 55, 55, 64) 0 bn2b_branch2b[0][0]
__________________________________________________________________________________________________
res2b_branch2c (Conv2D) (None, 55, 55, 256) 16640 activation_6[0][0]
__________________________________________________________________________________________________
bn2b_branch2c (BatchNormalizati (None, 55, 55, 256) 1024 res2b_branch2c[0][0]
__________________________________________________________________________________________________
add_2 (Add) (None, 55, 55, 256) 0 bn2b_branch2c[0][0]
activation_4[0][0]
__________________________________________________________________________________________________
activation_7 (Activation) (None, 55, 55, 256) 0 add_2[0][0]
__________________________________________________________________________________________________
res2c_branch2a (Conv2D) (None, 55, 55, 64) 16448 activation_7[0][0]
__________________________________________________________________________________________________
bn2c_branch2a (BatchNormalizati (None, 55, 55, 64) 256 res2c_branch2a[0][0]
__________________________________________________________________________________________________
activation_8 (Activation) (None, 55, 55, 64) 0 bn2c_branch2a[0][0]
__________________________________________________________________________________________________
res2c_branch2b (Conv2D) (None, 55, 55, 64) 36928 activation_8[0][0]
__________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizati (None, 55, 55, 64) 256 res2c_branch2b[0][0]
__________________________________________________________________________________________________
activation_9 (Activation) (None, 55, 55, 64) 0 bn2c_branch2b[0][0]
__________________________________________________________________________________________________
res2c_branch2c (Conv2D) (None, 55, 55, 256) 16640 activation_9[0][0]
__________________________________________________________________________________________________
bn2c_branch2c (BatchNormalizati (None, 55, 55, 256) 1024 res2c_branch2c[0][0]
__________________________________________________________________________________________________
add_3 (Add) (None, 55, 55, 256) 0 bn2c_branch2c[0][0]
activation_7[0][0]
__________________________________________________________________________________________________
activation_10 (Activation) (None, 55, 55, 256) 0 add_3[0][0]
__________________________________________________________________________________________________
res3a_branch2a (Conv2D) (None, 28, 28, 128) 32896 activation_10[0][0]
__________________________________________________________________________________________________
bn3a_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3a_branch2a[0][0]
__________________________________________________________________________________________________
activation_11 (Activation) (None, 28, 28, 128) 0 bn3a_branch2a[0][0]
__________________________________________________________________________________________________
res3a_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_11[0][0]
__________________________________________________________________________________________________
bn3a_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3a_branch2b[0][0]
__________________________________________________________________________________________________
activation_12 (Activation) (None, 28, 28, 128) 0 bn3a_branch2b[0][0]
__________________________________________________________________________________________________
res3a_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_12[0][0]
__________________________________________________________________________________________________
res3a_branch1 (Conv2D) (None, 28, 28, 512) 131584 activation_10[0][0]
__________________________________________________________________________________________________
bn3a_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3a_branch2c[0][0]
__________________________________________________________________________________________________
bn3a_branch1 (BatchNormalizatio (None, 28, 28, 512) 2048 res3a_branch1[0][0]
__________________________________________________________________________________________________
add_4 (Add) (None, 28, 28, 512) 0 bn3a_branch2c[0][0]
bn3a_branch1[0][0]
__________________________________________________________________________________________________
activation_13 (Activation) (None, 28, 28, 512) 0 add_4[0][0]
__________________________________________________________________________________________________
res3b_branch2a (Conv2D) (None, 28, 28, 128) 65664 activation_13[0][0]
__________________________________________________________________________________________________
bn3b_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3b_branch2a[0][0]
__________________________________________________________________________________________________
activation_14 (Activation) (None, 28, 28, 128) 0 bn3b_branch2a[0][0]
__________________________________________________________________________________________________
res3b_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_14[0][0]
__________________________________________________________________________________________________
bn3b_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3b_branch2b[0][0]
__________________________________________________________________________________________________
activation_15 (Activation) (None, 28, 28, 128) 0 bn3b_branch2b[0][0]
__________________________________________________________________________________________________
res3b_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_15[0][0]
__________________________________________________________________________________________________
bn3b_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3b_branch2c[0][0]
__________________________________________________________________________________________________
add_5 (Add) (None, 28, 28, 512) 0 bn3b_branch2c[0][0]
activation_13[0][0]
__________________________________________________________________________________________________
activation_16 (Activation) (None, 28, 28, 512) 0 add_5[0][0]
__________________________________________________________________________________________________
res3c_branch2a (Conv2D) (None, 28, 28, 128) 65664 activation_16[0][0]
__________________________________________________________________________________________________
bn3c_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3c_branch2a[0][0]
__________________________________________________________________________________________________
activation_17 (Activation) (None, 28, 28, 128) 0 bn3c_branch2a[0][0]
__________________________________________________________________________________________________
res3c_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_17[0][0]
__________________________________________________________________________________________________
bn3c_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3c_branch2b[0][0]
__________________________________________________________________________________________________
activation_18 (Activation) (None, 28, 28, 128) 0 bn3c_branch2b[0][0]
__________________________________________________________________________________________________
res3c_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_18[0][0]
__________________________________________________________________________________________________
bn3c_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3c_branch2c[0][0]
__________________________________________________________________________________________________
add_6 (Add) (None, 28, 28, 512) 0 bn3c_branch2c[0][0]
activation_16[0][0]
__________________________________________________________________________________________________
activation_19 (Activation) (None, 28, 28, 512) 0 add_6[0][0]
__________________________________________________________________________________________________
res3d_branch2a (Conv2D) (None, 28, 28, 128) 65664 activation_19[0][0]
__________________________________________________________________________________________________
bn3d_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3d_branch2a[0][0]
__________________________________________________________________________________________________
activation_20 (Activation) (None, 28, 28, 128) 0 bn3d_branch2a[0][0]
__________________________________________________________________________________________________
res3d_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_20[0][0]
__________________________________________________________________________________________________
bn3d_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3d_branch2b[0][0]
__________________________________________________________________________________________________
activation_21 (Activation) (None, 28, 28, 128) 0 bn3d_branch2b[0][0]
__________________________________________________________________________________________________
res3d_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_21[0][0]
__________________________________________________________________________________________________
bn3d_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3d_branch2c[0][0]
__________________________________________________________________________________________________
add_7 (Add) (None, 28, 28, 512) 0 bn3d_branch2c[0][0]
activation_19[0][0]
__________________________________________________________________________________________________
activation_22 (Activation) (None, 28, 28, 512) 0 add_7[0][0]
__________________________________________________________________________________________________
res4a_branch2a (Conv2D) (None, 14, 14, 256) 131328 activation_22[0][0]
__________________________________________________________________________________________________
bn4a_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4a_branch2a[0][0]
__________________________________________________________________________________________________
activation_23 (Activation) (None, 14, 14, 256) 0 bn4a_branch2a[0][0]
__________________________________________________________________________________________________
res4a_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_23[0][0]
__________________________________________________________________________________________________
bn4a_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4a_branch2b[0][0]
__________________________________________________________________________________________________
activation_24 (Activation) (None, 14, 14, 256) 0 bn4a_branch2b[0][0]
__________________________________________________________________________________________________
res4a_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_24[0][0]
__________________________________________________________________________________________________
res4a_branch1 (Conv2D) (None, 14, 14, 1024) 525312 activation_22[0][0]
__________________________________________________________________________________________________
bn4a_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4a_branch2c[0][0]
__________________________________________________________________________________________________
bn4a_branch1 (BatchNormalizatio (None, 14, 14, 1024) 4096 res4a_branch1[0][0]
__________________________________________________________________________________________________
add_8 (Add) (None, 14, 14, 1024) 0 bn4a_branch2c[0][0]
bn4a_branch1[0][0]
__________________________________________________________________________________________________
activation_25 (Activation) (None, 14, 14, 1024) 0 add_8[0][0]
__________________________________________________________________________________________________
res4b_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_25[0][0]
__________________________________________________________________________________________________
bn4b_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4b_branch2a[0][0]
__________________________________________________________________________________________________
activation_26 (Activation) (None, 14, 14, 256) 0 bn4b_branch2a[0][0]
__________________________________________________________________________________________________
res4b_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_26[0][0]
__________________________________________________________________________________________________
bn4b_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4b_branch2b[0][0]
__________________________________________________________________________________________________
activation_27 (Activation) (None, 14, 14, 256) 0 bn4b_branch2b[0][0]
__________________________________________________________________________________________________
res4b_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_27[0][0]
__________________________________________________________________________________________________
bn4b_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4b_branch2c[0][0]
__________________________________________________________________________________________________
add_9 (Add) (None, 14, 14, 1024) 0 bn4b_branch2c[0][0]
activation_25[0][0]
__________________________________________________________________________________________________
activation_28 (Activation) (None, 14, 14, 1024) 0 add_9[0][0]
__________________________________________________________________________________________________
res4c_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_28[0][0]
__________________________________________________________________________________________________
bn4c_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4c_branch2a[0][0]
__________________________________________________________________________________________________
activation_29 (Activation) (None, 14, 14, 256) 0 bn4c_branch2a[0][0]
__________________________________________________________________________________________________
res4c_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_29[0][0]
__________________________________________________________________________________________________
bn4c_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4c_branch2b[0][0]
__________________________________________________________________________________________________
activation_30 (Activation) (None, 14, 14, 256) 0 bn4c_branch2b[0][0]
__________________________________________________________________________________________________
res4c_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_30[0][0]
__________________________________________________________________________________________________
bn4c_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4c_branch2c[0][0]
__________________________________________________________________________________________________
add_10 (Add) (None, 14, 14, 1024) 0 bn4c_branch2c[0][0]
activation_28[0][0]
__________________________________________________________________________________________________
activation_31 (Activation) (None, 14, 14, 1024) 0 add_10[0][0]
__________________________________________________________________________________________________
res4d_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_31[0][0]
__________________________________________________________________________________________________
bn4d_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4d_branch2a[0][0]
__________________________________________________________________________________________________
activation_32 (Activation) (None, 14, 14, 256) 0 bn4d_branch2a[0][0]
__________________________________________________________________________________________________
res4d_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_32[0][0]
__________________________________________________________________________________________________
bn4d_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4d_branch2b[0][0]
__________________________________________________________________________________________________
activation_33 (Activation) (None, 14, 14, 256) 0 bn4d_branch2b[0][0]
__________________________________________________________________________________________________
res4d_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_33[0][0]
__________________________________________________________________________________________________
bn4d_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4d_branch2c[0][0]
__________________________________________________________________________________________________
add_11 (Add) (None, 14, 14, 1024) 0 bn4d_branch2c[0][0]
activation_31[0][0]
__________________________________________________________________________________________________
activation_34 (Activation) (None, 14, 14, 1024) 0 add_11[0][0]
__________________________________________________________________________________________________
res4e_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_34[0][0]
__________________________________________________________________________________________________
bn4e_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4e_branch2a[0][0]
__________________________________________________________________________________________________
activation_35 (Activation) (None, 14, 14, 256) 0 bn4e_branch2a[0][0]
__________________________________________________________________________________________________
res4e_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_35[0][0]
__________________________________________________________________________________________________
bn4e_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4e_branch2b[0][0]
__________________________________________________________________________________________________
activation_36 (Activation) (None, 14, 14, 256) 0 bn4e_branch2b[0][0]
__________________________________________________________________________________________________
res4e_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_36[0][0]
__________________________________________________________________________________________________
bn4e_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4e_branch2c[0][0]
__________________________________________________________________________________________________
add_12 (Add) (None, 14, 14, 1024) 0 bn4e_branch2c[0][0]
activation_34[0][0]
__________________________________________________________________________________________________
activation_37 (Activation) (None, 14, 14, 1024) 0 add_12[0][0]
__________________________________________________________________________________________________
res4f_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_37[0][0]
__________________________________________________________________________________________________
bn4f_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4f_branch2a[0][0]
__________________________________________________________________________________________________
activation_38 (Activation) (None, 14, 14, 256) 0 bn4f_branch2a[0][0]
__________________________________________________________________________________________________
res4f_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_38[0][0]
__________________________________________________________________________________________________
bn4f_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4f_branch2b[0][0]
__________________________________________________________________________________________________
activation_39 (Activation) (None, 14, 14, 256) 0 bn4f_branch2b[0][0]
__________________________________________________________________________________________________
res4f_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_39[0][0]
__________________________________________________________________________________________________
bn4f_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4f_branch2c[0][0]
__________________________________________________________________________________________________
add_13 (Add) (None, 14, 14, 1024) 0 bn4f_branch2c[0][0]
activation_37[0][0]
__________________________________________________________________________________________________
activation_40 (Activation) (None, 14, 14, 1024) 0 add_13[0][0]
__________________________________________________________________________________________________
res5a_branch2a (Conv2D) (None, 7, 7, 512) 524800 activation_40[0][0]
__________________________________________________________________________________________________
bn5a_branch2a (BatchNormalizati (None, 7, 7, 512) 2048 res5a_branch2a[0][0]
__________________________________________________________________________________________________
activation_41 (Activation) (None, 7, 7, 512) 0 bn5a_branch2a[0][0]
__________________________________________________________________________________________________
res5a_branch2b (Conv2D) (None, 7, 7, 512) 2359808 activation_41[0][0]
__________________________________________________________________________________________________
bn5a_branch2b (BatchNormalizati (None, 7, 7, 512) 2048 res5a_branch2b[0][0]
__________________________________________________________________________________________________
activation_42 (Activation) (None, 7, 7, 512) 0 bn5a_branch2b[0][0]
__________________________________________________________________________________________________
res5a_branch2c (Conv2D) (None, 7, 7, 2048) 1050624 activation_42[0][0]
__________________________________________________________________________________________________
res5a_branch1 (Conv2D) (None, 7, 7, 2048) 2099200 activation_40[0][0]
__________________________________________________________________________________________________
bn5a_branch2c (BatchNormalizati (None, 7, 7, 2048) 8192 res5a_branch2c[0][0]
__________________________________________________________________________________________________
bn5a_branch1 (BatchNormalizatio (None, 7, 7, 2048) 8192 res5a_branch1[0][0]
__________________________________________________________________________________________________
add_14 (Add) (None, 7, 7, 2048) 0 bn5a_branch2c[0][0]
bn5a_branch1[0][0]
__________________________________________________________________________________________________
activation_43 (Activation) (None, 7, 7, 2048) 0 add_14[0][0]
__________________________________________________________________________________________________
res5b_branch2a (Conv2D) (None, 7, 7, 512) 1049088 activation_43[0][0]
__________________________________________________________________________________________________
bn5b_branch2a (BatchNormalizati (None, 7, 7, 512) 2048 res5b_branch2a[0][0]
__________________________________________________________________________________________________
activation_44 (Activation) (None, 7, 7, 512) 0 bn5b_branch2a[0][0]
__________________________________________________________________________________________________
res5b_branch2b (Conv2D) (None, 7, 7, 512) 2359808 activation_44[0][0]
__________________________________________________________________________________________________
bn5b_branch2b (BatchNormalizati (None, 7, 7, 512) 2048 res5b_branch2b[0][0]
__________________________________________________________________________________________________
activation_45 (Activation) (None, 7, 7, 512) 0 bn5b_branch2b[0][0]
__________________________________________________________________________________________________
res5b_branch2c (Conv2D) (None, 7, 7, 2048) 1050624 activation_45[0][0]
__________________________________________________________________________________________________
bn5b_branch2c (BatchNormalizati (None, 7, 7, 2048) 8192 res5b_branch2c[0][0]
__________________________________________________________________________________________________
add_15 (Add) (None, 7, 7, 2048) 0 bn5b_branch2c[0][0]
activation_43[0][0]
__________________________________________________________________________________________________
activation_46 (Activation) (None, 7, 7, 2048) 0 add_15[0][0]
__________________________________________________________________________________________________
res5c_branch2a (Conv2D) (None, 7, 7, 512) 1049088 activation_46[0][0]
__________________________________________________________________________________________________
bn5c_branch2a (BatchNormalizati (None, 7, 7, 512) 2048 res5c_branch2a[0][0]
__________________________________________________________________________________________________
activation_47 (Activation) (None, 7, 7, 512) 0 bn5c_branch2a[0][0]
__________________________________________________________________________________________________
res5c_branch2b (Conv2D) (None, 7, 7, 512) 2359808 activation_47[0][0]
__________________________________________________________________________________________________
bn5c_branch2b (BatchNormalizati (None, 7, 7, 512) 2048 res5c_branch2b[0][0]
__________________________________________________________________________________________________
activation_48 (Activation) (None, 7, 7, 512) 0 bn5c_branch2b[0][0]
__________________________________________________________________________________________________
res5c_branch2c (Conv2D) (None, 7, 7, 2048) 1050624 activation_48[0][0]
__________________________________________________________________________________________________
bn5c_branch2c (BatchNormalizati (None, 7, 7, 2048) 8192 res5c_branch2c[0][0]
__________________________________________________________________________________________________
add_16 (Add) (None, 7, 7, 2048) 0 bn5c_branch2c[0][0]
activation_46[0][0]
__________________________________________________________________________________________________
activation_49 (Activation) (None, 7, 7, 2048) 0 add_16[0][0]
__________________________________________________________________________________________________
avg_pool (AveragePooling2D) (None, 1, 1, 2048) 0 activation_49[0][0]
==================================================================================================
Total params: 23,587,712
Trainable params: 23,534,592
Non-trainable params: 53,120
__________________________________________________________________________________________________
We use a ResNet-50 model pre-trained on ImageNet, without loading the top layers, since the next step is to extract bottleneck features and the top layers are not needed. From the model summary above we can see that the output shape is 1×1×2048, but I want to save a 1-D feature vector, so a flatten-like operation is needed:
out = GlobalAveragePooling2D()(res_model.output)
res_vec_model = Model(inputs=res_model.input, outputs=out)
With the feature-extraction model in hand, we can start extracting the bottleneck features. Since I may later run the same procedure on other models, I wrapped it in a function (the two cells above only illustrate how the extraction model is built; the function below is what is actually used). The function first constructs a bottleneck-feature extraction model, doing exactly what the two cells above do, and then uses an ImageDataGenerator to perform the extraction. Why a generator? Because it does not load every image into memory at once; instead it reads data lazily, queue-style, as it is consumed, which greatly reduces memory usage. Let's estimate what loading everything at once would cost. One image is 224×224×3 (the input size ResNet-50 expects), and there are 22,424 training images (not counting the much larger test set). Loaded as uint8, the training set alone takes 224×224×3×22424 = 3,375,439,872 bytes, an impressive 3 GB or so, and the test set (79,726 images) takes over 11 GB on top of that, roughly 14 GB in total. That much memory gone before training even starts; even though the Tesla V100 on my AWS instance has 16 GB, leaving headroom for training means this loading strategy is best avoided.
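The back-of-the-envelope estimate above can be checked directly (the image counts are the ones `flow_from_directory` reports later in this notebook; this is just the arithmetic, not a measurement):

```python
# Rough memory cost of loading every image at once as uint8 (1 byte per value).
h, w, c = 224, 224, 3            # input size expected by ResNet-50
n_train, n_test = 22424, 79726   # image counts reported by flow_from_directory

bytes_per_image = h * w * c
train_bytes = bytes_per_image * n_train
test_bytes = bytes_per_image * n_test

print(f"train: {train_bytes / 2**30:.1f} GiB")  # ~3.1 GiB
print(f"test:  {test_bytes / 2**30:.1f} GiB")   # ~11.2 GiB
```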
def model_vector_catch(MODEL, image_size, vect_file_name, vec_dir, train_dir, test_dir, preprocessing=None):
"""
MODEL: the model class used to extract bottleneck features
image_size: MODEL input size (h, w)
vect_file_name: file name for the cached feature vectors
preprocessing: optional preprocessing function for the input
"""
if isfile(vec_dir + '/' + vect_file_name):
print ("%s already OK!" % (vect_file_name))
return
input_tensor = Input(shape=(image_size[0], image_size[1], 3))
if preprocessing:
## check if need preprocessing
input_tensor = Lambda(preprocessing)(input_tensor)
model_no_top = MODEL(include_top=False, weights='imagenet', input_tensor=input_tensor, input_shape=(image_size[0], image_size[1], 3))
## flatten the output shape and generate model
out = GlobalAveragePooling2D()(model_no_top.output)
new_model = Model(inputs=model_no_top.input, outputs=out)
## get image generator
gen = ImageDataGenerator()
test_gen = ImageDataGenerator()
"""
classes: listed explicitly so the class -> index mapping is deterministic
class_mode = None -- we only extract features, so no labels are needed
shuffle = False -- keep file order so features align with train_generator.classes
batch_size = 64
"""
class_list = ['c0', 'c1', 'c2', 'c3', 'c4', 'c5', 'c6', 'c7', 'c8', 'c9', ]
train_generator = gen.flow_from_directory(train_dir, image_size, color_mode='rgb', \
classes=class_list, class_mode=None, shuffle=False, batch_size=64)
#test_generator = test_gen.flow_from_directory(test_dir, image_size, color_mode='rgb', \
#class_mode=None, shuffle=False, batch_size=64)
"""
steps = None, by default, the steps = len(generator)
"""
train_vector = new_model.predict_generator(train_generator)
#test_vector = new_model.predict_generator(test_generator)
with h5py.File(vec_dir + "/" + (vect_file_name), 'w') as f:
f.create_dataset('x_train', data=train_vector)
f.create_dataset("y_train", data=train_generator.classes)
#f.create_dataset("test", data=test_vector)
print ("Model %s vector cached complete!" % (vect_file_name))
vec_dir = 'vect'
if not isdir(vec_dir):
os.mkdir(vec_dir)
res_vect_file_name = 'resnet50_vect.h5'
model_vector_catch(resnet50.ResNet50, resNet_input_shape[:2], res_vect_file_name, vec_dir, train_dir_name, test_dir_name, resnet50.preprocess_input)
Found 22424 images belonging to 10 classes.
Model resnet50_vect.h5 vector cached complete!
We have now extracted the bottleneck features for the training set (I am holding off on the test set until I have verified that this approach works; the test set is not needed for training, so there is no rush). Next comes building and training the top layers. What should the top look like? The original ResNet-50 has a single output layer of 1000 units, so following that design we use a single output layer of 10 units (one per class) with a softmax activation.
driver_classes = 10
input_tensor = Input(shape=(2048,))
x = Dropout(0.5)(input_tensor)
x = Dense(driver_classes, activation='softmax', name='res_dense_1')(x)
resnet50_model = Model(inputs=input_tensor, outputs=x)
Before training, the bottleneck features saved on disk need to be loaded back into memory for the upcoming fit. Furthermore, since the labels we saved are the integers 0-9, we now convert them to one-hot form.
import numpy as np
def convert_to_one_hot(Y, C):
Y = np.eye(C)[Y.reshape(-1)]
return Y
from sklearn.utils import shuffle
x_train = []
y_train = []
with h5py.File(vec_dir + '/' + res_vect_file_name, 'r') as f:
x_train = np.array(f['x_train'])
y_train = np.array(f['y_train'])
#one-hot vector
y_train = convert_to_one_hot(y_train, driver_classes)
x_train, y_train = shuffle(x_train, y_train, random_state=0)
Pay attention to the code above!!! There is a pitfall here that I fell into in an earlier project; my model would not converge no matter what, and debugging took a long time. It is the shuffle. You must shuffle! If your data is stored class by class and you skip the shuffle, every batch will contain a single class. The model learns one class's traits in one batch, then has to learn a completely different class's traits in the next and throw away what it just learned, a vicious cycle that never converges. Since I already paid for this mistake once, there is no need to demonstrate it again here.
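The pitfall is easy to reproduce in miniature with numpy: when labels are stored class by class, every small batch is single-class until a joint permutation of `x` and `y` (which is what `sklearn.utils.shuffle` performs) mixes them:

```python
import numpy as np

rng = np.random.RandomState(0)
y = np.repeat(np.arange(10), 8)     # labels stored class by class, 8 per class
# Batches of 4 taken in order: every batch contains exactly one class.
pure = sum(len(set(b)) == 1 for b in y.reshape(-1, 4))
print(pure, "of 20 unshuffled batches contain a single class")  # 20

perm = rng.permutation(len(y))      # one permutation, applied jointly to x and y
y_shuffled = y[perm]
mixed = sum(len(set(b)) > 1 for b in y_shuffled.reshape(-1, 4))
print(mixed, "of 20 shuffled batches mix classes")
```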
Next comes setting the compile parameters and training. The optimizer is Adam with default parameters; the code below uses a batch size of 32, runs 30 epochs to see how it goes, and splits off 20% of the data for validation.
resnet50_model.compile(optimizer='Adam', loss='categorical_crossentropy', metrics=['accuracy'])
hist = resnet50_model.fit(x_train, y_train, batch_size=32, epochs=30, validation_split=0.2)
Train on 17939 samples, validate on 4485 samples
Epoch 1/30
17939/17939 [==============================] - 5s 254us/step - loss: 1.2879 - acc: 0.5758 - val_loss: 0.4706 - val_acc: 0.8783
Epoch 2/30
17939/17939 [==============================] - 2s 98us/step - loss: 0.6077 - acc: 0.7996 - val_loss: 0.2997 - val_acc: 0.9298
Epoch 3/30
17939/17939 [==============================] - 2s 96us/step - loss: 0.4998 - acc: 0.8343 - val_loss: 0.2223 - val_acc: 0.9483
Epoch 4/30
17939/17939 [==============================] - 2s 94us/step - loss: 0.4569 - acc: 0.8463 - val_loss: 0.2140 - val_acc: 0.9452
Epoch 5/30
17939/17939 [==============================] - 2s 95us/step - loss: 0.4240 - acc: 0.8575 - val_loss: 0.2061 - val_acc: 0.9465
Epoch 6/30
17939/17939 [==============================] - 2s 94us/step - loss: 0.4089 - acc: 0.8653 - val_loss: 0.1641 - val_acc: 0.9550
Epoch 7/30
17939/17939 [==============================] - 2s 94us/step - loss: 0.3981 - acc: 0.8657 - val_loss: 0.1955 - val_acc: 0.9394
Epoch 8/30
17939/17939 [==============================] - 2s 95us/step - loss: 0.3967 - acc: 0.8688 - val_loss: 0.1512 - val_acc: 0.9574
Epoch 9/30
17939/17939 [==============================] - 2s 95us/step - loss: 0.3788 - acc: 0.8712 - val_loss: 0.1469 - val_acc: 0.9601
Epoch 10/30
17939/17939 [==============================] - 2s 96us/step - loss: 0.3852 - acc: 0.8723 - val_loss: 0.1646 - val_acc: 0.9518
Epoch 11/30
17939/17939 [==============================] - 2s 95us/step - loss: 0.3733 - acc: 0.8766 - val_loss: 0.1443 - val_acc: 0.9550
Epoch 12/30
17939/17939 [==============================] - 2s 95us/step - loss: 0.3858 - acc: 0.8735 - val_loss: 0.1474 - val_acc: 0.9603
Epoch 13/30
17939/17939 [==============================] - 2s 98us/step - loss: 0.3707 - acc: 0.8773 - val_loss: 0.1286 - val_acc: 0.9608
Epoch 14/30
17939/17939 [==============================] - 2s 97us/step - loss: 0.3597 - acc: 0.8810 - val_loss: 0.1340 - val_acc: 0.9581
Epoch 15/30
17939/17939 [==============================] - 2s 102us/step - loss: 0.3754 - acc: 0.8762 - val_loss: 0.2203 - val_acc: 0.9280
Epoch 16/30
17939/17939 [==============================] - 2s 99us/step - loss: 0.3603 - acc: 0.8821 - val_loss: 0.1307 - val_acc: 0.9628
Epoch 17/30
17939/17939 [==============================] - 2s 95us/step - loss: 0.3683 - acc: 0.8817 - val_loss: 0.1338 - val_acc: 0.9608
Epoch 18/30
17939/17939 [==============================] - 2s 95us/step - loss: 0.3640 - acc: 0.8819 - val_loss: 0.1180 - val_acc: 0.9654
Epoch 19/30
17939/17939 [==============================] - 2s 95us/step - loss: 0.3670 - acc: 0.8812 - val_loss: 0.1216 - val_acc: 0.9621
Epoch 20/30
17939/17939 [==============================] - 2s 95us/step - loss: 0.3505 - acc: 0.8855 - val_loss: 0.1128 - val_acc: 0.9672
Epoch 21/30
17939/17939 [==============================] - 2s 98us/step - loss: 0.3689 - acc: 0.8801 - val_loss: 0.1253 - val_acc: 0.9632
Epoch 22/30
17939/17939 [==============================] - 2s 95us/step - loss: 0.3670 - acc: 0.8808 - val_loss: 0.1418 - val_acc: 0.9538
Epoch 23/30
17939/17939 [==============================] - 2s 95us/step - loss: 0.3567 - acc: 0.8866 - val_loss: 0.1127 - val_acc: 0.9672
Epoch 24/30
17939/17939 [==============================] - 2s 96us/step - loss: 0.3484 - acc: 0.8870 - val_loss: 0.1269 - val_acc: 0.9614
Epoch 25/30
17939/17939 [==============================] - 2s 97us/step - loss: 0.3588 - acc: 0.8845 - val_loss: 0.1109 - val_acc: 0.9650
Epoch 26/30
17939/17939 [==============================] - 2s 95us/step - loss: 0.3672 - acc: 0.8822 - val_loss: 0.1100 - val_acc: 0.9666
Epoch 27/30
17939/17939 [==============================] - 2s 94us/step - loss: 0.3477 - acc: 0.8882 - val_loss: 0.1082 - val_acc: 0.9683
Epoch 28/30
17939/17939 [==============================] - 2s 96us/step - loss: 0.3647 - acc: 0.8839 - val_loss: 0.1119 - val_acc: 0.9686
Epoch 29/30
17939/17939 [==============================] - 2s 94us/step - loss: 0.3504 - acc: 0.8885 - val_loss: 0.1083 - val_acc: 0.9677
Epoch 30/30
17939/17939 [==============================] - 2s 96us/step - loss: 0.3562 - acc: 0.8877 - val_loss: 0.0988 - val_acc: 0.9706
I tried many optimizers (Adam, SGD, and others), including low learning rates and learning-rate decay, but every run looked strikingly similar to the one above. Two symptoms stand out:
- (Training loss stuck high) The training loss hovers around 0.35 and refuses to drop further (for reference, the top 10 entries on the kaggle leaderboard all have losses below 0.13).
- (Validation loss far below training loss) After 30 epochs, train loss is 0.35 while val_loss is 0.09, and val_loss is still trending downward.
Remember these: they are the two biggest pitfalls I ran into.
Let's analyze them one by one and work out solutions.
Set aside for now how low and seemingly perfect the validation loss is, and focus on the training loss. At 0.35 it is high, which by the usual reasoning means the model is underfitting. The standard playbook for underfitting goes:
- Training loss reflects bias, and bias reflects the model's fitting capacity, so when capacity is insufficient you can increase model complexity by adding more hidden layers or hidden units.
Result: I tried inserting two fully connected layers of 1024 units each before the output layer (I won't show the code here; readers can try it themselves), with no effect at all.
- Reduce regularization, which also effectively raises model capacity. For this project that means lowering the dropout rate (equivalently, raising the keep probability).
Result: no effect!
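For a sense of scale, the deeper head described above (two 1024-unit fully connected layers between the 2048-d bottleneck vector and the 10-way softmax) would add over 3 million trainable parameters; a quick count, assuming plain `Dense` layers (weights = in × out, plus one bias per unit):

```python
def dense_params(n_in, n_out):
    """Parameter count of a fully connected layer: weight matrix plus biases."""
    return n_in * n_out + n_out

# 2048-d bottleneck -> 1024 -> 1024 -> 10-way softmax
layers = [(2048, 1024), (1024, 1024), (1024, 10)]
total = sum(dense_params(i, o) for i, o in layers)
print(total)                      # 3158026
print(dense_params(2048, 10))     # 20490 for the single softmax head used above
```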
With both standard remedies having failed, the usual playbook was clearly exhausted. Was my direction of attack wrong? Should I stop analyzing the problem purely from the model's side and switch perspective? So I started analyzing the dataset itself.
Our strategy is transfer learning: the weights come from ResNet-50 pre-trained on ImageNet, and we rely on those ImageNet-trained weights to extract bottleneck features. But as the results show, this approach is not working for our problem, which means the extracted bottleneck features are not helping us learn. If that is the case, it points to one conclusion: our driver dataset is somehow "different" from ImageNet. Let's figure out where the difference lies.
- ImageNet contains images of 1000 classes, mostly drawn from varied scenes with complex backgrounds. Our driver images, by contrast, all come from a single scene (a driver sitting in a car with a steering wheel in front), and each driver's images come from the same video stream. In other words, the driver images are inherently very similar to one another!
So what happens when the pre-trained ResNet-50 (without its top layers) processes these highly similar driver images? It extracts highly similar bottleneck features! (I imagine the pre-trained model "reasoning": "I've extracted this image's content: a face, hands, hands on a steering wheel..." but what we need seems to be more than that.) If images of different driver states yield nearly identical features, those features are useless for the driver problem (well, "useless" is too strong, since the trained model still reaches 88% accuracy; rather, the extracted features do not carry enough information for this task), and good training results are out of reach.
Now that the cause is clear, so is the fix: unfreeze more layers, because the model needs to learn more driver-specific "knowledge". Once more layers are trainable, the bottleneck-feature approach becomes impractical, because intermediate layers tend to have large dimensions; for example, a 14×14×256 feature map over 20,000 samples takes about 4 GB, and adding the test set would exhaust memory (and a K80 instance on AWS has only 11 GB, nowhere near enough). So from here on we train the network as a whole.
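The 4 GB figure follows directly from the shapes (float32 features, ~20k training images; just the arithmetic, not a measurement):

```python
# Memory needed to cache a 14x14x256 intermediate feature map per training image.
values_per_image = 14 * 14 * 256   # 50176 float32 values
bytes_per_image = values_per_image * 4  # float32 = 4 bytes
total = bytes_per_image * 20000    # ~20k training images
print(f"{total / 2**30:.1f} GiB")  # ~3.7 GiB for the training set alone
```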
Most of what follows mirrors the earlier steps, except that more convolutional layers are now open for training. Before that, let's prepare the training data. This time I again use an ImageDataGenerator to produce the data, and train with fit_generator.
The validation data now has to be split out by hand (fit_generator does not support the validation_split parameter), and as before I use 20% of the samples for validation, taking 20% from each class. To save disk space, the train_link dataset is built out of symlinks.
classes = ['c0', 'c1', 'c2', 'c3', 'c4', 'c5', 'c6', 'c7', 'c8', 'c9']
if not isdir(test_link_path):
os.makedirs(test_link_path)
c_filenames = os.listdir(test_dir_name)
for file in c_filenames:
os.symlink('../../' + test_dir_name + '/' + file, test_link_path + '/' + file)
if not isdir(link_path):
os.makedirs(link_train_path)
os.makedirs(link_valid_path)
# make c0-c9
for c in classes:
os.makedirs(link_train_path + '/' + c)
os.makedirs(link_valid_path + '/' + c)
# create link from train dir
for c in classes:
# get path name
c_path = train_dir_name + '/' + c
train_dst_path = link_train_path + '/' + c
valid_dst_path = link_valid_path + '/' + c
# list all file name of this class
c_filenames = os.listdir(c_path)
valid_size = int (len(c_filenames)*0.2)
# create validation data of this class
for file in c_filenames[:valid_size]:
os.symlink('../../../' + c_path + '/' + file, valid_dst_path + '/' + file)
# create train data of this class
for file in c_filenames[valid_size:]:
os.symlink('../../../' + c_path + '/' + file, train_dst_path + '/' + file)
KeyboardInterrupt raised at the `os.symlink` call (the cell was interrupted while the test-image links were being created)
def get_data_generator(train_dir, valid_dir, test_dir, image_size):
#gen = ImageDataGenerator(shear_range=0.3, zoom_range=0.3, rotation_range=0.3)
gen = ImageDataGenerator()
gen_valid = ImageDataGenerator()
test_gen = ImageDataGenerator()
"""
classes: listed explicitly so the class -> index mapping is deterministic
class_mode = 'categorical' -- return one-hot labels
shuffle = True -- shuffle the training data
batch_size = 32
"""
class_list = ['c0', 'c1', 'c2', 'c3', 'c4', 'c5', 'c6', 'c7', 'c8', 'c9']
# create train generator
train_generator = gen.flow_from_directory(train_dir, image_size, color_mode='rgb', \
classes=class_list, class_mode='categorical', shuffle=True, batch_size=32)
# create validation generator
valid_generator = gen_valid.flow_from_directory(valid_dir, image_size, color_mode='rgb', \
classes=class_list, class_mode='categorical', shuffle=False, batch_size=32)
test_generator = test_gen.flow_from_directory(test_dir, image_size, color_mode='rgb', \
class_mode=None, shuffle=False, batch_size=32)
return train_generator, valid_generator, test_generator
#, test_generator
def model_built(MODEL, input_shape, preprocess_input, classes, last_frozen_layer_name):
"""
MODEL: pretrained model
input_shape: pre-trained model's input shape
preprocess_input: pre-trained model's preprocessing function
last_frozen_layer_name: last layer to freeze (None to train everything)
"""
## get pretrained model
x = Input(shape=input_shape)
if preprocess_input:
x = Lambda(preprocess_input)(x)
notop_model = MODEL(include_top=False, weights='imagenet', input_tensor=x, input_shape=input_shape)
x = GlobalAveragePooling2D()(notop_model.output)
## build top layer
x = Dropout(0.5, name='dropout_1')(x)
out = Dense(classes, activation='softmax', name='dense_1')(x)
ret_model = Model(inputs=notop_model.input, outputs=out)
## Frozen some layer
#for layer in ret_model.layers:
#layer.trainable = False
#if layer.name == last_frozen_layer_name:
#break
return ret_model
import pandas as pd
"""
def get_test_result(model_obj, test_generator, model_name="default"):
pred_test = model_obj.predict_generator(test_generator, len(test_generator))
pred_test = np.array(pred_test)
pred_test = pred_test.clip(min=0.005, max=0.995)
df = pd.read_csv("sample_submission.csv")
for i, fname in enumerate(test_generator.filenames):
df.loc[df["img"] == fname] = [fname] + list(pred_test[i])
df.to_csv('%s.csv' % (model_name), index=None)
print ('test result file %s.csv generated!' % (model_name))
df.head(10)
"""
def get_test_result(model_obj, test_generator, model_name="default"):
print("Now to predict")
pred_test = model_obj.predict_generator(test_generator, len(test_generator), verbose=1)
pred_test = np.array(pred_test)
pred_test = pred_test.clip(min=0.005, max=0.995)
print("create datasheet")
result = pd.DataFrame(pred_test, columns=['c0', 'c1', 'c2', 'c3',
'c4', 'c5', 'c6', 'c7',
'c8', 'c9'])
test_filenames = []
for f in test_generator.filenames:
test_filenames.append(os.path.basename(f))
result.loc[:, 'img'] = pd.Series(test_filenames, index=result.index)
result.to_csv('%s.csv' % (model_name), index=None)
print ('test result file %s.csv generated!' % (model_name))
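The `clip(min=0.005, max=0.995)` above bounds the penalty a confidently wrong prediction can incur under the competition's log loss: after clipping, the worst per-image loss is -ln(0.005) ≈ 5.3, instead of the ≈ 34.5 you would get if a true-class probability fell to the ~1e-15 floor that Kaggle-style log-loss metrics typically apply. A quick check in pure Python:

```python
import math

def per_image_logloss(p_true):
    """Multiclass log loss for one image, given the probability on the true class."""
    return -math.log(p_true)

print(round(per_image_logloss(0.005), 2))   # 5.3   worst case after clipping
print(round(per_image_logloss(1e-15), 2))   # 34.54 worst case without clipping
```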
resnet50_train_generator, resnet50_valid_generator, resnet50_test_generator = get_data_generator(link_train_path, link_valid_path, test_link, resNet_input_shape[:2])
Found 17943 images belonging to 10 classes.
Found 4481 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
With the data ready, the next step is to build the model and open all layers for training.
resnet50_model = model_built(resnet50.ResNet50, resNet_input_shape, resnet50.preprocess_input, 10, None)
from keras import optimizers
sgd = optimizers.SGD(lr=0.0001, decay=1e-6, momentum=0.9, nesterov=True)
resnet50_model.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy'])
hist = resnet50_model.fit_generator(resnet50_train_generator, len(resnet50_train_generator), epochs=3,\
validation_data=resnet50_valid_generator, validation_steps=len(resnet50_valid_generator))
Epoch 1/3
281/281 [==============================] - 374s 1s/step - loss: 2.0012 - acc: 0.3205 - val_loss: 1.0765 - val_acc: 0.8246
Epoch 2/3
281/281 [==============================] - 112s 398ms/step - loss: 0.9246 - acc: 0.7555 - val_loss: 0.4652 - val_acc: 0.9362
Epoch 3/3
281/281 [==============================] - 112s 399ms/step - loss: 0.4605 - acc: 0.9049 - val_loss: 0.2505 - val_acc: 0.9616
hist = resnet50_model.fit_generator(resnet50_train_generator, len(resnet50_train_generator), epochs=5,\
validation_data=resnet50_valid_generator, validation_steps=len(resnet50_valid_generator))
Epoch 1/5
281/281 [==============================] - 115s 411ms/step - loss: 0.2670 - acc: 0.9494 - val_loss: 0.1625 - val_acc: 0.9728
Epoch 2/5
281/281 [==============================] - 113s 402ms/step - loss: 0.1811 - acc: 0.9661 - val_loss: 0.1189 - val_acc: 0.9788
Epoch 3/5
281/281 [==============================] - 113s 403ms/step - loss: 0.1265 - acc: 0.9781 - val_loss: 0.0949 - val_acc: 0.9828
Epoch 4/5
281/281 [==============================] - 112s 398ms/step - loss: 0.0971 - acc: 0.9834 - val_loss: 0.0788 - val_acc: 0.9850
Epoch 5/5
281/281 [==============================] - 111s 395ms/step - loss: 0.0777 - acc: 0.9868 - val_loss: 0.0683 - val_acc: 0.9853
For lack of time I will stop training here. As you can see, the "training loss stuck high" problem, the first of the two listed earlier, is solved. Everything looks perfect, so perfect that we have forgotten about the second problem, which the output above does not make obvious (though it is visible: after every epoch, the training loss is clearly higher than the validation loss). If you insist there is no problem, fine: let's run the model on the test set and submit the result to kaggle.
get_test_result(resnet50_model, resnet50_test_generator, model_name="resnet-50")
Now to predict
50/1246 [>.............................] - ETA: 6:25
KeyboardInterrupt raised inside `predict_generator` (the prediction run was interrupted here)
As expected, after submitting to kaggle the loss is an alarming 2.3. The next section analyzes the problem.
Now let's analyze where the problem lies. From the training runs above, the validation loss was consistently lower than the training loss, which is abnormal and means we have directly or indirectly overfit the validation set. Why? Once again the cause is in the data. As noted before, all images of a given driver are captured by the same camera in the same scene, so images of the same driver in the same state are extremely similar (effectively consecutive frames of a video stream). If some of those frames land in the training set and others in the validation set, the model has in effect already "seen" the validation data before validating on it, so the validation set is overfit (it amounts to validating on training data). Validation data must therefore never share a driver with the training data: the split should be made by driver ID, with all of a driver's images going entirely to one side. To make the validation set roughly 20% of the training data, I picked the images of four drivers (p021, p022, p024, p026), a bit over 4,800 images in total.
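The grouping logic of a driver-wise split can be sketched on a toy stand-in for `driver_imgs_list.csv` (the real cell below does the same thing with pandas' `isin`; the file names here are made up for illustration):

```python
import csv, io

# Toy stand-in for driver_imgs_list.csv (columns: subject, classname, img).
csv_text = """subject,classname,img
p021,c0,img_100.jpg
p021,c3,img_101.jpg
p002,c0,img_200.jpg
p022,c5,img_300.jpg
"""
validation_drivers = {"p021", "p022", "p024", "p026"}

valid_files, train_files = set(), set()
for row in csv.DictReader(io.StringIO(csv_text)):
    # All of a driver's images go to exactly one side of the split.
    (valid_files if row["subject"] in validation_drivers else train_files).add(row["img"])

print(sorted(valid_files))  # ['img_100.jpg', 'img_101.jpg', 'img_300.jpg']
print(sorted(train_files))  # ['img_200.jpg']
```

Using a set for the membership test also keeps the `file in validation_files` lookup constant-time, which matters when checking 22k file names.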
import matplotlib.pyplot as plt
def show_loss(hist, title='loss'):
# show the training and validation loss
plt.plot(hist.history['val_loss'], label="validation loss")
plt.plot(hist.history['loss'], label="train loss")
plt.ylabel('loss')
plt.xlabel('epoch')
plt.title(title)
plt.legend()
plt.show()
import pandas as pd
link_path = 'train_link'
link_train_path = 'train_link/train'
link_valid_path = 'train_link/validation'
test_link_path = 'test_link/data'
classes = ['c0', 'c1', 'c2', 'c3', 'c4', 'c5', 'c6', 'c7', 'c8', 'c9']
validation_drivers = ["p021","p022","p024","p026"]
if not isdir(test_link_path):
os.makedirs(test_link_path)
c_filenames = os.listdir(test_dir_name)
for file in c_filenames:
os.symlink('../../' + test_dir_name + '/' + file, test_link_path + '/' + file)
if not isdir(link_path):
os.makedirs(link_train_path)
os.makedirs(link_valid_path)
# get validation file names
df = pd.read_csv(driver_imgs_list_file)
validation_files = list(df[df['subject'].isin(validation_drivers)]['img'])
# make c0-c9
for c in classes:
os.makedirs(link_train_path + '/' + c)
os.makedirs(link_valid_path + '/' + c)
# create link from train dir
for c in classes:
# get path name
c_path = train_dir_name + '/' + c
train_dst_path = link_train_path + '/' + c
valid_dst_path = link_valid_path + '/' + c
# list all file name of this class
c_filenames = os.listdir(c_path)
# create validation data of this class
for file in c_filenames:
if file in validation_files:
os.symlink('../../../' + c_path + '/' + file, valid_dst_path + '/' + file)
else:
os.symlink('../../../' + c_path + '/' + file, train_dst_path + '/' + file)
def model_built(MODEL, input_shape, preprocess_input, classes, last_frozen_layer_name):
"""
MODEL: pretrained model
input_shape: pre-trained model's input shape
preprocess_input: pre-trained model's preprocessing function
last_frozen_layer_name: last layer to freeze (None to train everything)
"""
## get pretrained model
x = Input(shape=input_shape)
if preprocess_input:
x = Lambda(preprocess_input)(x)
notop_model = MODEL(include_top=False, weights='imagenet', input_tensor=x, input_shape=input_shape)
x = GlobalAveragePooling2D()(notop_model.output)
## build top layer
x = Dropout(0.5, name='dropout_1')(x)
out = Dense(classes, activation='softmax', name='dense_1')(x)
ret_model = Model(inputs=notop_model.input, outputs=out)
## Frozen some layer
#for layer in ret_model.layers:
#layer.trainable = False
#if layer.name == last_frozen_layer_name:
#break
return ret_model
from keras import optimizers
from keras.callbacks import ModelCheckpoint
from keras.applications import xception, resnet50
# get generator
resNet_input_shape = (224,224,3)
resnet50_train_generator, resnet50_valid_generator, resnet50_test_generator = get_data_generator(link_train_path, link_valid_path, test_link, resNet_input_shape[:2])
##xception_train_generator, xception_valid_generator, xception_test_generator = get_data_generator(link_train_path, link_valid_path, test_link, (299, 299))
# build model
resnet50_model = model_built(resnet50.ResNet50, resNet_input_shape, resnet50.preprocess_input, 10, None)
#xception_model = model_built(xception.Xception, (299, 299, 3), xception.preprocess_input, 10, None)
# trainmodel
ckpt = ModelCheckpoint('resnet50.weights.{epoch:02d}-{val_loss:.2f}.hdf5', verbose=1, save_best_only=True, save_weights_only=True)
adam = optimizers.Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
resnet50_model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
hist = resnet50_model.fit_generator(resnet50_train_generator, len(resnet50_train_generator), epochs=6,\
validation_data=resnet50_valid_generator, validation_steps=len(resnet50_valid_generator), callbacks=[ckpt])
#ckpt = ModelCheckpoint('xception.weights.{epoch:02d}-{val_loss:.2f}.hdf5', verbose=1, save_best_only=True, save_weights_only=True)
#adam = optimizers.Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
#xception_model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
#hist = xception_model.fit_generator(xception_train_generator, len(xception_train_generator), epochs=6,\
#validation_data=xception_valid_generator, validation_steps=len(xception_valid_generator), callbacks=[ckpt])
Found 17532 images belonging to 10 classes.
Found 4892 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/6
548/548 [==============================] - 406s 742ms/step - loss: 0.1727 - acc: 0.9488 - val_loss: 0.4518 - val_acc: 0.8581
Epoch 00001: val_loss improved from inf to 0.45175, saving model to resnet50.weights.01-0.45.hdf5
Epoch 2/6
548/548 [==============================] - 110s 201ms/step - loss: 0.0169 - acc: 0.9956 - val_loss: 0.3573 - val_acc: 0.8884
Epoch 00002: val_loss improved from 0.45175 to 0.35731, saving model to resnet50.weights.02-0.36.hdf5
Epoch 3/6
548/548 [==============================] - 111s 202ms/step - loss: 0.0214 - acc: 0.9942 - val_loss: 0.4290 - val_acc: 0.8581
Epoch 00003: val_loss did not improve
Epoch 4/6
548/548 [==============================] - 110s 201ms/step - loss: 0.0126 - acc: 0.9970 - val_loss: 0.3218 - val_acc: 0.8996
Epoch 00004: val_loss improved from 0.35731 to 0.32176, saving model to resnet50.weights.04-0.32.hdf5
Epoch 5/6
548/548 [==============================] - 110s 201ms/step - loss: 0.0170 - acc: 0.9958 - val_loss: 0.3336 - val_acc: 0.9064
Epoch 00005: val_loss did not improve
Epoch 6/6
548/548 [==============================] - 111s 202ms/step - loss: 0.0157 - acc: 0.9961 - val_loss: 0.3360 - val_acc: 0.8894
Epoch 00006: val_loss did not improve
#adam = optimizers.Adam(lr=1e-5, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
#xception_model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
#xception_model.load_weights('weights.06-0.25.hdf5')
#hist = xception_model.fit_generator(xception_train_generator, len(xception_train_generator), epochs=6,\
#validation_data=xception_valid_generator, validation_steps=len(xception_valid_generator), \
#callbacks=[ckpt])
adam = optimizers.Adam(lr=1e-5, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
resnet50_model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
hist = resnet50_model.fit_generator(resnet50_train_generator, len(resnet50_train_generator), epochs=6,\
validation_data=resnet50_valid_generator, validation_steps=len(resnet50_valid_generator), \
callbacks=[ckpt], initial_epoch=4)
Epoch 5/6
548/548 [==============================] - 117s 214ms/step - loss: 6.0906e-04 - acc: 0.9999 - val_loss: 0.3074 - val_acc: 0.9004
Epoch 00005: val_loss improved from 0.32176 to 0.30740, saving model to resnet50.weights.05-0.31.hdf5
Epoch 6/6
548/548 [==============================] - 110s 200ms/step - loss: 1.3819e-04 - acc: 1.0000 - val_loss: 0.2772 - val_acc: 0.9101
Epoch 00006: val_loss improved from 0.30740 to 0.27716, saving model to resnet50.weights.06-0.28.hdf5
hist = resnet50_model.fit_generator(resnet50_train_generator, len(resnet50_train_generator), epochs=6,\
validation_data=resnet50_valid_generator, validation_steps=len(resnet50_valid_generator), \
callbacks=[ckpt])
Epoch 1/6
548/548 [==============================] - 111s 203ms/step - loss: 8.4052e-05 - acc: 1.0000 - val_loss: 0.2922 - val_acc: 0.9099
Epoch 00001: val_loss did not improve
Epoch 2/6
548/548 [==============================] - 110s 200ms/step - loss: 6.5537e-04 - acc: 0.9999 - val_loss: 0.2701 - val_acc: 0.9129
Epoch 00002: val_loss improved from 0.27716 to 0.27013, saving model to resnet50.weights.02-0.27.hdf5
Epoch 3/6
548/548 [==============================] - 110s 201ms/step - loss: 2.7870e-05 - acc: 1.0000 - val_loss: 0.2763 - val_acc: 0.9135
Epoch 00003: val_loss did not improve
Epoch 4/6
548/548 [==============================] - 110s 201ms/step - loss: 2.7291e-05 - acc: 1.0000 - val_loss: 0.2937 - val_acc: 0.9099
Epoch 00004: val_loss did not improve
Epoch 5/6
548/548 [==============================] - 110s 202ms/step - loss: 1.7776e-05 - acc: 1.0000 - val_loss: 0.2863 - val_acc: 0.9121
Epoch 00005: val_loss did not improve
Epoch 6/6
548/548 [==============================] - 110s 200ms/step - loss: 1.9905e-05 - acc: 1.0000 - val_loss: 0.2925 - val_acc: 0.9119
Epoch 00006: val_loss did not improve
#get_test_result(xception_model, xception_test_generator, model_name="xception_test_result")
Now to predict
2345/2492 [===========================>..] - ETA: 1:00
As you can see, after the training above the validation loss converges to a fairly "honest" value, which resolves the second problem we described.
The training above took quite a bit of trial and error, mostly around choosing the learning rate. A rate that is too large or too small will not converge to a good value: too large a rate destroys the pretrained weights (training "blows up", as we put it), while too small a rate makes convergence unbearably slow.
In the end the ResNet-50 model converged to a validation loss of 0.27; after submitting to Kaggle the leaderboard loss was about 0.34, roughly the top 10%. But I was not satisfied with that result, and the natural way to improve the score is model ensembling.
The importance of model fusion in competitions is self-evident: look through any Kaggle competition and the top finishers, almost without exception, use some form of it. This so-called model fusion is essentially what is known as ensemble learning (at the very least the two are closely related), so next I will use the bagging approach from ensemble learning to combine models.
The core idea of ensemble learning is "three cobblers together beat a Zhuge Liang": many mediocre opinions can combine into a strong one. For example, suppose you want to know whether a stock will rise or fall. If you ask just one friend, you may not feel confident in the prediction; but if you ask several friends in turn and combine their advice, you can be much more at ease. That is the idea of ensemble learning; to learn more, see the chapter on ensemble learning in my study notes. One important prerequisite of ensembling is diversity: combining several models with no diversity between them is useless. As in the example above, if all the friends you ask about the stock always give identical answers, what difference does combining their opinions make? So diversity is what we need to engineer.
Diversity has two sources:
- the data itself: training the same model on two different datasets generally yields two different g(x);
- the model itself: training different models on the same data also generally yields different g(x).
So our plan is to exploit diversity in the data to obtain several diverse g(x), and then combine their predictions with uniform blending (simply summing the outputs of all the g(x) and taking the mean) to get the final prediction. Where does the data diversity come from? We can borrow the splitting scheme of K-fold cross-validation to partition the data and train the models. The difference is that K-fold cross-validation aims to give one model an objective score, whereas our goal is to produce different g(x) to blend, so we only borrow the splitting scheme rather than performing actual K-fold cross-validation. To learn what K-fold cross-validation is, see this article by 培神.
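To make uniform blending concrete, here is a minimal sketch with made-up probabilities from two hypothetical models (in the real pipeline these would be the rows of each model's test-result CSV):

```python
import numpy as np

# Made-up class probabilities for 3 samples from two hypothetical models
preds_a = np.array([[0.8, 0.2], [0.4, 0.6], [0.9, 0.1]])
preds_b = np.array([[0.6, 0.4], [0.2, 0.8], [0.7, 0.3]])

# Uniform blending: element-wise average of all models' outputs;
# each blended row is still a valid probability distribution.
blended = (preds_a + preds_b) / 2.0

print(blended)
```

Averaging probabilities (rather than hard votes) preserves each model's confidence, which matters because the competition metric is log loss on the predicted probabilities.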
The helper below generates the training and validation sets, given a list of driver IDs to hold out for validation (it also re-creates the test links). Note that it removes any previously created link directories.
import pandas as pd
import os
import shutil
classes = ['c0', 'c1', 'c2', 'c3', 'c4', 'c5', 'c6', 'c7', 'c8', 'c9']
def train_valid_split(validation_drivers):
    """
    validation_drivers: driver id list, like ["p021","p022","p024","p026"]
    warning: this function will remove the old link dirs
    """
    # validation_drivers = ["p021","p022","p024","p026"]
    if isdir(test_link_path):
        shutil.rmtree(test_link_path)
    if not isdir(test_link_path):
        os.makedirs(test_link_path)
    c_filenames = os.listdir(test_dir_name)
    for file in c_filenames:
        os.symlink('../../' + test_dir_name + '/' + file, test_link_path + '/' + file)
    ## check and remove the train dir
    if isdir(link_path):
        shutil.rmtree(link_path)
    if not isdir(link_path):
        os.makedirs(link_train_path)
        os.makedirs(link_valid_path)
    # get validation file names
    df = pd.read_csv(driver_imgs_list_file)
    validation_files = list(df[df['subject'].isin(validation_drivers)]['img'])
    # make c0-c9
    for c in classes:
        os.makedirs(link_train_path + '/' + c)
        os.makedirs(link_valid_path + '/' + c)
    # create links from the train dir
    for c in classes:
        # get path names
        c_path = train_dir_name + '/' + c
        train_dst_path = link_train_path + '/' + c
        valid_dst_path = link_valid_path + '/' + c
        # list all file names of this class
        c_filenames = os.listdir(c_path)
        # link each file into validation or train depending on its driver
        for file in c_filenames:
            if file in validation_files:
                os.symlink('../../../' + c_path + '/' + file, valid_dst_path + '/' + file)
            else:
                os.symlink('../../../' + c_path + '/' + file, train_dst_path + '/' + file)
Data generators: from the train/validation split produced above, we now build the generators.
def data_generator(train_dir, valid_dir, test_dir, image_size):
    """
    image_size: the output image size, like (224, 224)
    """
    # gen = ImageDataGenerator(shear_range=0.3, zoom_range=0.3, rotation_range=0.3)
    gen = ImageDataGenerator(rotation_range=10.,
                             width_shift_range=0.05,
                             height_shift_range=0.05,
                             shear_range=0.1,
                             zoom_range=0.1)
    gen_valid = ImageDataGenerator()
    test_gen = ImageDataGenerator()
    # create the train generator (with augmentation, shuffled)
    train_generator = gen.flow_from_directory(train_dir, image_size, color_mode='rgb',
                                              classes=classes, class_mode='categorical',
                                              shuffle=True, batch_size=32)
    # create the validation generator (no augmentation, fixed order)
    valid_generator = gen_valid.flow_from_directory(valid_dir, image_size, color_mode='rgb',
                                                    classes=classes, class_mode='categorical',
                                                    shuffle=False, batch_size=32)
    test_generator = test_gen.flow_from_directory(test_dir, image_size, color_mode='rgb',
                                                  class_mode=None, shuffle=False, batch_size=32)
    return train_generator, valid_generator, test_generator
def get_model(MODEL, input_shape, preprocess_input, output_num):
    """
    MODEL: pretrained model class
    input_shape: the pretrained model's input shape
    preprocess_input: the pretrained model's preprocessing function
    output_num: number of output classes
    """
    ## get the pretrained model without its top classifier
    x = Input(shape=input_shape)
    if preprocess_input:
        x = Lambda(preprocess_input)(x)
    notop_model = MODEL(include_top=False, weights='imagenet', input_tensor=x, input_shape=input_shape)
    x = GlobalAveragePooling2D()(notop_model.output)
    ## build the top layers
    x = Dropout(0.5, name='dropout_1')(x)
    out = Dense(output_num, activation='softmax', name='dense_1')(x)
    ret_model = Model(inputs=notop_model.input, outputs=out)
    return ret_model
def get_test_result(model_obj, generator, result_file_name="default"):
    print("Now to predict result!")
    pred_test = model_obj.predict_generator(generator, len(generator), verbose=1)
    pred_test = np.array(pred_test)
    # clip the probabilities to bound the per-sample log loss on Kaggle
    pred_test = pred_test.clip(min=0.005, max=0.995)
    print("Creating datasheet!")
    result = pd.DataFrame(pred_test, columns=classes)
    test_filenames = []
    # use the passed-in generator's filenames, not a global
    for f in generator.filenames:
        test_filenames.append(os.path.basename(f))
    result.loc[:, 'img'] = pd.Series(test_filenames, index=result.index)
    result.to_csv(result_file_name, index=None)
    print('Test result file %s generated!' % (result_file_name))
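A note on the `clip(min=0.005, max=0.995)` call above: Kaggle scores this competition with multi-class log loss, which grows without bound as the probability assigned to the true class approaches zero, so a single confidently wrong prediction can ruin a submission. Clipping caps the worst-case per-sample penalty. A minimal sketch of the effect:

```python
import numpy as np

# Probability the model assigned to the true class (confidently wrong)
p_true = 1e-6
loss_raw = -np.log(p_true)                              # unbounded penalty
loss_clipped = -np.log(np.clip(p_true, 0.005, 0.995))  # capped at -log(0.005)

print(loss_raw, loss_clipped)
```

The trade-off is a tiny extra loss on predictions the model got right with very high confidence, which is usually far outweighed by the protection against confident mistakes.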
import pandas as pd
def merge_test_results(file_name_list, result_file_name='result'):
    result = None
    # number of result files to blend
    file_num = len(file_name_list)
    # read all test results and accumulate the class probabilities
    img = []
    for i, name in enumerate(file_name_list):
        df = pd.read_csv(name)
        if i == 0:
            img = df['img']
            result = df.drop('img', axis=1)
        else:
            result = result + df.drop('img', axis=1)
    # uniform blending: average over all models
    result = result / float(file_num)
    result['img'] = img
    result.to_csv(result_file_name, index=None)
    print("Final result file: " + result_file_name)
I plan to blend four models trained with ResNet-50, each on a different train/validation split. From the analysis above we know the split must be made by driver ID; here I assume each validation set holds the data of four drivers. Let's create the splits these four models need.
from keras.applications import resnet50
from keras import optimizers
from keras.callbacks import ModelCheckpoint
res_drivers_id_list = [["p039","p047","p052","p066"],
["p015","p041","p049","p081"],
["p002","p016","p035","p050"],
["p024","p041","p049","p052"]]
res_image_size = (224, 224)
res_input_shape = (224, 224, 3)
csv_names_list = []
for i in range(4):
    print('------------------------------------------------------------------------------------------------------------------')
    """ 1. about data """
    # create links and remove the old ones
    train_valid_split(res_drivers_id_list[i])
    # get generators
    train_generator, valid_generator, test_generator = data_generator(link_train_path, link_valid_path, test_link, res_image_size)
    """ 2. about model """
    # get model
    resnet50_model = get_model(resnet50.ResNet50, res_input_shape, resnet50.preprocess_input, 10)
    # compile
    weights_file_name = 'resnet50.' + str(i) + '.weights.hdf5'
    ckpt = ModelCheckpoint(weights_file_name, verbose=1, save_best_only=True, save_weights_only=True)
    adam = optimizers.Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
    resnet50_model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
    # fit
    hist_1 = resnet50_model.fit_generator(train_generator, len(train_generator), epochs=1,
                                          validation_data=valid_generator, validation_steps=len(valid_generator), callbacks=[ckpt])
    # compile again with a smaller learning rate
    adam = optimizers.Adam(lr=1e-5, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
    resnet50_model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
    # fit using the small learning rate
    hist_2 = resnet50_model.fit_generator(train_generator, len(train_generator), epochs=3,
                                          validation_data=valid_generator, validation_steps=len(valid_generator), callbacks=[ckpt])
    # load the best weights saved by the checkpoint
    resnet50_model.load_weights(weights_file_name)
    """ 3. about result """
    result_file_name = 'resnet50.' + str(i) + '.test.result.hdf5'
    get_test_result(resnet50_model, test_generator, result_file_name=result_file_name)
    csv_names_list.append(result_file_name)
------------------------------------------------------------------------------------------------------------------
Found 19164 images belonging to 10 classes.
Found 3260 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
599/599 [==============================] - 839s 1s/step - loss: 0.2265 - acc: 0.9306 - val_loss: 0.5889 - val_acc: 0.7936
Epoch 00001: val_loss improved from inf to 0.58890, saving model to resnet50.0.weights.hdf5
Epoch 1/3
599/599 [==============================] - 306s 512ms/step - loss: 0.0173 - acc: 0.9955 - val_loss: 0.4069 - val_acc: 0.8663
Epoch 00001: val_loss improved from 0.58890 to 0.40685, saving model to resnet50.0.weights.hdf5
Epoch 2/3
599/599 [==============================] - 303s 506ms/step - loss: 0.0101 - acc: 0.9973 - val_loss: 0.5276 - val_acc: 0.8534
Epoch 00002: val_loss did not improve
Epoch 3/3
599/599 [==============================] - 299s 499ms/step - loss: 0.0052 - acc: 0.9987 - val_loss: 0.3894 - val_acc: 0.8840
Epoch 00003: val_loss improved from 0.40685 to 0.38938, saving model to resnet50.0.weights.hdf5
Now to predict result!
2492/2492 [==============================] - 1623s 651ms/step
Creating datasheet!
Test result file resnet50.0.test.result.hdf5 generated!
------------------------------------------------------------------------------------------------------------------
Found 19110 images belonging to 10 classes.
Found 3314 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
598/598 [==============================] - 318s 531ms/step - loss: 0.2368 - acc: 0.9269 - val_loss: 0.6308 - val_acc: 0.7948
Epoch 00001: val_loss improved from inf to 0.63077, saving model to resnet50.1.weights.hdf5
Epoch 1/3
598/598 [==============================] - 315s 527ms/step - loss: 0.0196 - acc: 0.9949 - val_loss: 0.3938 - val_acc: 0.8699
Epoch 00001: val_loss improved from 0.63077 to 0.39381, saving model to resnet50.1.weights.hdf5
Epoch 2/3
598/598 [==============================] - 308s 515ms/step - loss: 0.0114 - acc: 0.9971 - val_loss: 0.3073 - val_acc: 0.9001
Epoch 00002: val_loss improved from 0.39381 to 0.30732, saving model to resnet50.1.weights.hdf5
Epoch 3/3
598/598 [==============================] - 307s 514ms/step - loss: 0.0075 - acc: 0.9980 - val_loss: 0.4784 - val_acc: 0.8721
Epoch 00003: val_loss did not improve
Now to predict result!
2492/2492 [==============================] - 403s 162ms/step
Creating datasheet!
Test result file resnet50.1.test.result.hdf5 generated!
------------------------------------------------------------------------------------------------------------------
Found 18983 images belonging to 10 classes.
Found 3441 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
594/594 [==============================] - 322s 541ms/step - loss: 0.2399 - acc: 0.9267 - val_loss: 0.4065 - val_acc: 0.8750
Epoch 00001: val_loss improved from inf to 0.40654, saving model to resnet50.2.weights.hdf5
Epoch 1/3
594/594 [==============================] - 322s 542ms/step - loss: 0.0191 - acc: 0.9955 - val_loss: 0.3682 - val_acc: 0.9195
Epoch 00001: val_loss improved from 0.40654 to 0.36821, saving model to resnet50.2.weights.hdf5
Epoch 2/3
594/594 [==============================] - 309s 521ms/step - loss: 0.0081 - acc: 0.9978 - val_loss: 0.4680 - val_acc: 0.9183
Epoch 00002: val_loss did not improve
Epoch 3/3
594/594 [==============================] - 310s 522ms/step - loss: 0.0066 - acc: 0.9979 - val_loss: 0.3676 - val_acc: 0.9160
Epoch 00003: val_loss improved from 0.36821 to 0.36761, saving model to resnet50.2.weights.hdf5
Now to predict result!
2492/2492 [==============================] - 407s 163ms/step
Creating datasheet!
Test result file resnet50.2.test.result.hdf5 generated!
------------------------------------------------------------------------------------------------------------------
Found 18842 images belonging to 10 classes.
Found 3582 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
589/589 [==============================] - 327s 556ms/step - loss: 0.2486 - acc: 0.9233 - val_loss: 0.3011 - val_acc: 0.8867
Epoch 00001: val_loss improved from inf to 0.30106, saving model to resnet50.3.weights.hdf5
Epoch 1/3
589/589 [==============================] - 328s 556ms/step - loss: 0.0211 - acc: 0.9944 - val_loss: 0.1510 - val_acc: 0.9506
Epoch 00001: val_loss improved from 0.30106 to 0.15104, saving model to resnet50.3.weights.hdf5
Epoch 2/3
589/589 [==============================] - 305s 518ms/step - loss: 0.0101 - acc: 0.9971 - val_loss: 0.1875 - val_acc: 0.9369
Epoch 00002: val_loss did not improve
Epoch 3/3
589/589 [==============================] - 306s 520ms/step - loss: 0.0064 - acc: 0.9982 - val_loss: 0.2701 - val_acc: 0.9160
Epoch 00003: val_loss did not improve
Now to predict result!
2492/2492 [==============================] - 402s 161ms/step
Creating datasheet!
Test result file resnet50.3.test.result.hdf5 generated!
merge_test_results(csv_names_list, 'resnet50.finall.result.csv')
Final result file: resnet50.finall.result.csv
As you can see, after blending the ResNet-50 models and submitting to Kaggle, the loss dropped from 0.35 to 0.23, a clear improvement.
Next I blend four models trained with Xception, each on a different train/validation split; apart from the model, everything is identical to the ResNet-50 procedure.
from keras.applications import xception
from keras import optimizers
from keras.callbacks import ModelCheckpoint
xcp_drivers_id_list = [['p016', 'p072', 'p026', 'p066'],
['p042', 'p022', 'p045', 'p075'],
['p064', 'p047', 'p056', 'p061'],
['p039', 'p012', 'p015', 'p052']]
xcp_image_size = (299, 299)
xcp_input_shape = (299, 299, 3)
xcp_csv_names_list = []
for i in range(4):
    print('------------------------------------------------------------------------------------------------------------------')
    """ 1. about data """
    # create links and remove the old ones
    train_valid_split(xcp_drivers_id_list[i])
    # get generators
    train_generator, valid_generator, test_generator = data_generator(link_train_path, link_valid_path, test_link, xcp_image_size)
    """ 2. about model """
    # get model
    xception_model = get_model(xception.Xception, xcp_input_shape, xception.preprocess_input, 10)
    # compile
    weights_file_name = 'xception.' + str(i) + '.weights.hdf5'
    ckpt = ModelCheckpoint(weights_file_name, verbose=1, save_best_only=True, save_weights_only=True)
    adam = optimizers.Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
    xception_model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
    # fit
    hist_1 = xception_model.fit_generator(train_generator, len(train_generator), epochs=1,
                                          validation_data=valid_generator, validation_steps=len(valid_generator), callbacks=[ckpt])
    # compile again with a smaller learning rate
    adam = optimizers.Adam(lr=1e-5, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
    xception_model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
    # fit using the small learning rate
    hist_2 = xception_model.fit_generator(train_generator, len(train_generator), epochs=3,
                                          validation_data=valid_generator, validation_steps=len(valid_generator), callbacks=[ckpt])
    # load the best weights saved by the checkpoint
    xception_model.load_weights(weights_file_name)
    """ 3. about result """
    result_file_name = 'xception.' + str(i) + '.test.result.hdf5'
    get_test_result(xception_model, test_generator, result_file_name=result_file_name)
    xcp_csv_names_list.append(result_file_name)
------------------------------------------------------------------------------------------------------------------
Found 18770 images belonging to 10 classes.
Found 3654 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
587/587 [==============================] - 471s 803ms/step - loss: 0.2714 - acc: 0.9238 - val_loss: 0.3599 - val_acc: 0.8815
Epoch 00001: val_loss improved from inf to 0.35995, saving model to xception.0.weights.hdf5
Epoch 1/3
587/587 [==============================] - 477s 812ms/step - loss: 0.0140 - acc: 0.9966 - val_loss: 0.3484 - val_acc: 0.8878
Epoch 00001: val_loss improved from 0.35995 to 0.34837, saving model to xception.0.weights.hdf5
Epoch 2/3
587/587 [==============================] - 457s 778ms/step - loss: 0.0074 - acc: 0.9985 - val_loss: 0.3467 - val_acc: 0.8935
Epoch 00002: val_loss improved from 0.34837 to 0.34673, saving model to xception.0.weights.hdf5
Epoch 3/3
587/587 [==============================] - 458s 780ms/step - loss: 0.0047 - acc: 0.9990 - val_loss: 0.3608 - val_acc: 0.8886
Epoch 00003: val_loss did not improve
Now to predict result!
2492/2492 [==============================] - 456s 183ms/step
Creating datasheet!
Test result file xception.0.test.result.hdf5 generated!
------------------------------------------------------------------------------------------------------------------
Found 19062 images belonging to 10 classes.
Found 3362 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
596/596 [==============================] - 491s 824ms/step - loss: 0.2813 - acc: 0.9198 - val_loss: 0.2815 - val_acc: 0.9087
Epoch 00001: val_loss improved from inf to 0.28149, saving model to xception.1.weights.hdf5
Epoch 1/3
596/596 [==============================] - 486s 816ms/step - loss: 0.0164 - acc: 0.9960 - val_loss: 0.1917 - val_acc: 0.9337
Epoch 00001: val_loss improved from 0.28149 to 0.19174, saving model to xception.1.weights.hdf5
Epoch 2/3
596/596 [==============================] - 460s 771ms/step - loss: 0.0091 - acc: 0.9984 - val_loss: 0.1643 - val_acc: 0.9485
Epoch 00002: val_loss improved from 0.19174 to 0.16431, saving model to xception.1.weights.hdf5
Epoch 3/3
596/596 [==============================] - 491s 824ms/step - loss: 0.0055 - acc: 0.9989 - val_loss: 0.1670 - val_acc: 0.9465
Epoch 00003: val_loss did not improve
Now to predict result!
2492/2492 [==============================] - 510s 205ms/step
Creating datasheet!
Test result file xception.1.test.result.hdf5 generated!
------------------------------------------------------------------------------------------------------------------
Found 19166 images belonging to 10 classes.
Found 3258 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
599/599 [==============================] - 542s 904ms/step - loss: 0.2616 - acc: 0.9270 - val_loss: 0.5691 - val_acc: 0.8536
Epoch 00001: val_loss improved from inf to 0.56912, saving model to xception.2.weights.hdf5
Epoch 1/3
599/599 [==============================] - 538s 899ms/step - loss: 0.0131 - acc: 0.9971 - val_loss: 0.3631 - val_acc: 0.9033
Epoch 00001: val_loss improved from 0.56912 to 0.36315, saving model to xception.2.weights.hdf5
Epoch 2/3
599/599 [==============================] - 508s 848ms/step - loss: 0.0066 - acc: 0.9987 - val_loss: 0.3985 - val_acc: 0.8929
Epoch 00002: val_loss did not improve
Epoch 3/3
599/599 [==============================] - 504s 842ms/step - loss: 0.0040 - acc: 0.9991 - val_loss: 0.4135 - val_acc: 0.8895
Epoch 00003: val_loss did not improve
Now to predict result!
2492/2492 [==============================] - 494s 198ms/step
Creating datasheet!
Test result file xception.2.test.result.hdf5 generated!
------------------------------------------------------------------------------------------------------------------
Found 19335 images belonging to 10 classes.
Found 3089 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
605/605 [==============================] - 548s 905ms/step - loss: 0.2640 - acc: 0.9264 - val_loss: 0.3926 - val_acc: 0.8799
Epoch 00001: val_loss improved from inf to 0.39262, saving model to xception.3.weights.hdf5
Epoch 1/3
605/605 [==============================] - 547s 904ms/step - loss: 0.0147 - acc: 0.9963 - val_loss: 0.3121 - val_acc: 0.9009
Epoch 00001: val_loss improved from 0.39262 to 0.31212, saving model to xception.3.weights.hdf5
Epoch 2/3
605/605 [==============================] - 516s 853ms/step - loss: 0.0075 - acc: 0.9982 - val_loss: 0.3219 - val_acc: 0.8912
Epoch 00002: val_loss did not improve
Epoch 3/3
605/605 [==============================] - 515s 851ms/step - loss: 0.0047 - acc: 0.9992 - val_loss: 0.3226 - val_acc: 0.8880
Epoch 00003: val_loss did not improve
Now to predict result!
2492/2492 [==============================] - 515s 207ms/step
Creating datasheet!
Test result file xception.3.test.result.hdf5 generated!
merge_test_results(xcp_csv_names_list, 'xception.finall.result.csv')
Final result file: xception.finall.result.csv
After submitting the Xception blend to Kaggle, the loss was 0.27290.
Next I blend four models trained with InceptionV3, each on a different train/validation split; apart from the model, everything is identical to the ResNet-50 procedure.
from keras.applications import inception_v3
from keras import optimizers
from keras.callbacks import ModelCheckpoint
inc_drivers_id_list = [['p064', 'p056', 'p047', 'p041'],
['p051', 'p072', 'p052', 'p049'],
['p002', 'p050', 'p014', 'p075'],
['p035', 'p039', 'p012', 'p081']]
inc_image_size = (299, 299)
inc_input_shape = (299, 299, 3)
inc_csv_names_list = []
for i in range(4):
    print('------------------------------------------------------------------------------------------------------------------')
    """ 1. about data """
    # create links and remove the old ones
    train_valid_split(inc_drivers_id_list[i])
    # get generators
    train_generator, valid_generator, test_generator = data_generator(link_train_path, link_valid_path, test_link, inc_image_size)
    """ 2. about model """
    # get model
    inceptionv3_model = get_model(inception_v3.InceptionV3, inc_input_shape, inception_v3.preprocess_input, 10)
    # compile
    weights_file_name = 'inceptionv3.' + str(i) + '.weights.hdf5'
    ckpt = ModelCheckpoint(weights_file_name, verbose=1, save_best_only=True, save_weights_only=True)
    adam = optimizers.Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
    inceptionv3_model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
    # fit
    hist_1 = inceptionv3_model.fit_generator(train_generator, len(train_generator), epochs=1,
                                             validation_data=valid_generator, validation_steps=len(valid_generator), callbacks=[ckpt])
    # compile again with a smaller learning rate
    adam = optimizers.Adam(lr=1e-5, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
    inceptionv3_model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
    # fit using the small learning rate
    hist_2 = inceptionv3_model.fit_generator(train_generator, len(train_generator), epochs=2,
                                             validation_data=valid_generator, validation_steps=len(valid_generator), callbacks=[ckpt])
    # load the best weights saved by the checkpoint
    inceptionv3_model.load_weights(weights_file_name)
    """ 3. about result """
    result_file_name = 'inceptionv3.' + str(i) + '.test.result.hdf5'
    get_test_result(inceptionv3_model, test_generator, result_file_name=result_file_name)
    inc_csv_names_list.append(result_file_name)
------------------------------------------------------------------------------------------------------------------
Found 19370 images belonging to 10 classes.
Found 3054 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
606/606 [==============================] - 557s 920ms/step - loss: 0.2447 - acc: 0.9254 - val_loss: 0.6483 - val_acc: 0.8471
Epoch 00001: val_loss improved from inf to 0.64835, saving model to inceptionv3.0.weights.hdf5
Epoch 1/2
606/606 [==============================] - 560s 924ms/step - loss: 0.0203 - acc: 0.9944 - val_loss: 0.4773 - val_acc: 0.8661
Epoch 00001: val_loss improved from 0.64835 to 0.47726, saving model to inceptionv3.0.weights.hdf5
Epoch 2/2
606/606 [==============================] - 513s 846ms/step - loss: 0.0091 - acc: 0.9979 - val_loss: 0.4670 - val_acc: 0.8752
Epoch 00002: val_loss improved from 0.47726 to 0.46698, saving model to inceptionv3.0.weights.hdf5
Now to predict result!
2492/2492 [==============================] - 511s 205ms/step
Creating datasheet!
Test result file inceptionv3.0.test.result.hdf5 generated!
------------------------------------------------------------------------------------------------------------------
Found 19407 images belonging to 10 classes.
Found 3017 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
607/607 [==============================] - 572s 943ms/step - loss: 0.2530 - acc: 0.9239 - val_loss: 0.5091 - val_acc: 0.8472
Epoch 00001: val_loss improved from inf to 0.50909, saving model to inceptionv3.1.weights.hdf5
Epoch 1/2
607/607 [==============================] - 518s 853ms/step - loss: 0.0101 - acc: 0.9974 - val_loss: 0.3786 - val_acc: 0.8810
Epoch 00002: val_loss improved from 0.38029 to 0.37860, saving model to inceptionv3.1.weights.hdf5
Now to predict result!
2492/2492 [==============================] - 521s 209ms/step
Creating datasheet!
Test result file inceptionv3.1.test.result.hdf5 generated!
------------------------------------------------------------------------------------------------------------------
Found 19219 images belonging to 10 classes.
Found 3205 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
601/601 [==============================] - 568s 945ms/step - loss: 0.2453 - acc: 0.9248 - val_loss: 0.5244 - val_acc: 0.8537
Epoch 00001: val_loss improved from inf to 0.52442, saving model to inceptionv3.2.weights.hdf5
Epoch 1/2
601/601 [==============================] - 568s 945ms/step - loss: 0.0153 - acc: 0.9962 - val_loss: 0.4941 - val_acc: 0.8955
Epoch 00001: val_loss improved from 0.52442 to 0.49411, saving model to inceptionv3.2.weights.hdf5
Epoch 2/2
601/601 [==============================] - 509s 847ms/step - loss: 0.0069 - acc: 0.9982 - val_loss: 0.5401 - val_acc: 0.8880
Epoch 00002: val_loss did not improve
Now to predict result!
2492/2492 [==============================] - 513s 206ms/step
Creating datasheet!
Test result file inceptionv3.2.test.result.hdf5 generated!
------------------------------------------------------------------------------------------------------------------
Found 19279 images belonging to 10 classes.
Found 3145 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
603/603 [==============================] - 593s 983ms/step - loss: 0.2375 - acc: 0.9272 - val_loss: 0.5779 - val_acc: 0.8391
Epoch 00001: val_loss improved from inf to 0.57787, saving model to inceptionv3.3.weights.hdf5
Epoch 1/2
603/603 [==============================] - 582s 966ms/step - loss: 0.0224 - acc: 0.9938 - val_loss: 0.3650 - val_acc: 0.8979
Epoch 00001: val_loss improved from 0.57787 to 0.36498, saving model to inceptionv3.3.weights.hdf5
Epoch 2/2
603/603 [==============================] - 515s 853ms/step - loss: 0.0090 - acc: 0.9978 - val_loss: 0.3821 - val_acc: 0.8957
Epoch 00002: val_loss did not improve
Now to predict result!
2492/2492 [==============================] - 521s 209ms/step
Creating datasheet!
Test result file inceptionv3.3.test.result.hdf5 generated!
merge_test_results(inc_csv_names_list, 'inceptionv3.finall.result.csv')
Final result file: inceptionv3.finall.result.csv
After submitting to Kaggle, the score was 0.31467.
Finally I blend four models trained with InceptionResNetV2, each on a different train/validation split; apart from the model, everything is identical to the ResNet-50 procedure. I don't plan to do any further tuning.
from keras.applications import inception_resnet_v2
from keras import optimizers
from keras.callbacks import ModelCheckpoint
ire_drivers_id_list = [['p041', 'p026', 'p022', 'p002'],
['p061', 'p081', 'p015', 'p075'],
['p042', 'p045', 'p047', 'p012'],
['p064', 'p051', 'p072', 'p021']]
ire_image_size = (299, 299)
ire_input_shape = (299, 299, 3)
ire_csv_names_list = []
for i in range(4):
    print('------------------------------------------------------------------------------------------------------------------')
    """ 1. about data """
    # create links and remove the old ones
    train_valid_split(ire_drivers_id_list[i])
    # get generators
    train_generator, valid_generator, test_generator = data_generator(link_train_path, link_valid_path, test_link, ire_image_size)
    """ 2. about model """
    # get model
    inception_resnet_v2_model = get_model(inception_resnet_v2.InceptionResNetV2, ire_input_shape, inception_resnet_v2.preprocess_input, 10)
    # compile
    weights_file_name = 'inception_resnet_v2.' + str(i) + '.weights.hdf5'
    ckpt = ModelCheckpoint(weights_file_name, verbose=1, save_best_only=True, save_weights_only=True)
    adam = optimizers.Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
    inception_resnet_v2_model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
    # fit
    hist_1 = inception_resnet_v2_model.fit_generator(train_generator, len(train_generator), epochs=1,
                                                     validation_data=valid_generator, validation_steps=len(valid_generator), callbacks=[ckpt])
    # compile again with a smaller learning rate
    adam = optimizers.Adam(lr=1e-5, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
    inception_resnet_v2_model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
    # fit using the small learning rate
    hist_2 = inception_resnet_v2_model.fit_generator(train_generator, len(train_generator), epochs=3,
                                                     validation_data=valid_generator, validation_steps=len(valid_generator), callbacks=[ckpt])
    # load the best weights saved by the checkpoint
    inception_resnet_v2_model.load_weights(weights_file_name)
    """ 3. about result """
    result_file_name = 'inception_resnet_v2.' + str(i) + '.test.result.hdf5'
    get_test_result(inception_resnet_v2_model, test_generator, result_file_name=result_file_name)
    ire_csv_names_list.append(result_file_name)
------------------------------------------------------------------------------------------------------------------
Found 18665 images belonging to 10 classes.
Found 3759 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
584/584 [==============================] - 532s 912ms/step - loss: 0.2403 - acc: 0.9257 - val_loss: 0.5532 - val_acc: 0.8843
Epoch 00001: val_loss improved from inf to 0.55320, saving model to inception_resnet_v2.0.weights.hdf5
Epoch 1/3
584/584 [==============================] - 542s 928ms/step - loss: 0.0168 - acc: 0.9959 - val_loss: 0.2754 - val_acc: 0.9146
Epoch 00001: val_loss improved from 0.55320 to 0.27539, saving model to inception_resnet_v2.0.weights.hdf5
Epoch 2/3
584/584 [==============================] - 506s 867ms/step - loss: 0.0076 - acc: 0.9981 - val_loss: 0.3755 - val_acc: 0.9013
Epoch 00002: val_loss did not improve
Epoch 3/3
584/584 [==============================] - 506s 866ms/step - loss: 0.0037 - acc: 0.9991 - val_loss: 0.2906 - val_acc: 0.9197
Epoch 00003: val_loss did not improve
Now to predict result!
2492/2492 [==============================] - 498s 200ms/step
Creating datasheet!
Test result file inception_resnet_v2.0.test.result.hdf5 generated!
------------------------------------------------------------------------------------------------------------------
Found 19103 images belonging to 10 classes.
Found 3321 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
597/597 [==============================] - 565s 946ms/step - loss: 0.2267 - acc: 0.9294 - val_loss: 0.4564 - val_acc: 0.8811
Epoch 00001: val_loss improved from inf to 0.45637, saving model to inception_resnet_v2.1.weights.hdf5
Epoch 1/3
597/597 [==============================] - 563s 943ms/step - loss: 0.0129 - acc: 0.9970 - val_loss: 0.4455 - val_acc: 0.8817
Epoch 00001: val_loss improved from 0.45637 to 0.44553, saving model to inception_resnet_v2.1.weights.hdf5
Epoch 2/3
597/597 [==============================] - 524s 879ms/step - loss: 0.0063 - acc: 0.9985 - val_loss: 0.6219 - val_acc: 0.8470
Epoch 00002: val_loss did not improve
Epoch 3/3
597/597 [==============================] - 522s 875ms/step - loss: 0.0036 - acc: 0.9992 - val_loss: 0.5228 - val_acc: 0.8735
Epoch 00003: val_loss did not improve
Now to predict result!
2492/2492 [==============================] - 533s 214ms/step
Creating datasheet!
Test result file inception_resnet_v2.1.test.result.hdf5 generated!
------------------------------------------------------------------------------------------------------------------
Found 19451 images belonging to 10 classes.
Found 2973 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
14/608 [..............................] - ETA: 41:55 - loss: 2.2983 - acc: 0.1763
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-14-25af3ea26f2c> in <module>()
---> 39 hist_1 = inception_resnet_v2_model.fit_generator(train_generator, len(train_generator), epochs=1, validation_data=valid_generator, validation_steps=len(valid_generator), callbacks=[ckpt])
KeyboardInterrupt:
ire_csv_names_list = ['inception_resnet_v2.0.test.result.hdf5',
'inception_resnet_v2.1.test.result.hdf5',
'inception_resnet_v2.2.test.result.hdf5',
'inception_resnet_v2.3.test.result.hdf5']
merge_test_results(ire_csv_names_list, 'inception_resnet_v2.finall.result.csv')
Final result file: inception_resnet_v2.finall.result.csv
all_name = ['inception_resnet_v2.finall.result.csv', 'inceptionv3.finall.result.csv', 'resnet50.finall.result.csv', 'xception.finall.result.csv']
merge_test_results(all_name, 'four.model.finall.result.csv')
Final result file: four.model.finall.result.csv
After submitting to Kaggle, the loss was 0.29589.
Looking at the results above, resnet-50 is the strongest single model in the blend, probably because the other models never went through the same babysitting/tuning process and therefore underperform. Blending all four models risks letting the weaker ones drag the score down, so instead I hand-picked the models with the best validation performance and blended only those.
name_list = ['resnet50.3.test.result.hdf5',
'resnet50.1.test.result.hdf5',
'xception.1.test.result.hdf5',
'xception.3.test.result.hdf5',
'inception_resnet_v2.0.test.result.hdf5']
merge_test_results(name_list, 'choose_top_5_model_finall.result.csv')
Final result file: choose_top_5_model_finall.result.csv
After submitting to Kaggle, the loss came down to 0.23, which is already inside the top 10% of the leaderboard and meets my initial goal. Since training one model takes roughly 2 hours (about 8 hours for four models), I stopped optimizing here for lack of time.
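`merge_test_results`, used above, was defined earlier in the post; conceptually it just averages the per-class probabilities of the individual submission files. A minimal pandas sketch of that averaging step (the `img` column name, the per-class columns, and the function name are assumptions following the Kaggle submission format, not the exact implementation):

```python
import numpy as np
import pandas as pd

def average_submissions(csv_paths, out_path):
    # read every submission, aligned by image name
    dfs = [pd.read_csv(p).sort_values('img').reset_index(drop=True)
           for p in csv_paths]
    merged = dfs[0].copy()
    prob_cols = [c for c in merged.columns if c != 'img']
    # simple arithmetic mean of the class probabilities
    merged[prob_cols] = np.mean([df[prob_cols].values for df in dfs], axis=0)
    merged.to_csv(out_path, index=False)
```

Averaging probabilities directly like this is the simplest blending scheme; weighted averages based on validation loss are a common refinement.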
A few more possible improvements:
- Apart from the resnet50 model, the other models did not converge well, because I reused the optimizer settings tuned for resnet-50 on all of them. One improvement is therefore to tune the training schedule carefully for each model.
- Apply more data augmentation to the training set, borrowing a trick from top-ranked Kaggle competitors: stitch together the left and right halves of images of the same class. This improves generalization by teaching the model which regions it should and should not attend to.
- Add more models. One standard direction for model ensembling is simply to use more models, which normally yields a better score (unless one model is a serious drag).
The ensemble above achieves good predictive performance, but how can we see what the blended models are actually looking at? In Bolei Zhou's paper Learning Deep Features for Discriminative Localization, the authors propose adding a global average pooling layer to a CNN to obtain class activation maps (CAMs). For a particular class, its CAM shows the image regions the CNN attends to when predicting that class. Producing a class's CAM is simple: it is a linear combination of the feature maps of the layer just before global average pooling, where the combination coefficients are the output-layer weights of the neuron for that class.
Take resnet50 as an example: the layer before global average pooling outputs a 7x7x2048 tensor, so the CAM is a linear combination of those 2048 feature maps of size 7x7. The combination weights are the output-layer weights of the neuron for the chosen class (in this project the output has 10 classes, so the weight matrix is 2048x10, and each class owns a 2048-dimensional weight vector).
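The shape bookkeeping can be checked with a quick numpy sketch (random values, purely illustrative):

```python
import numpy as np

# resnet50: the layer before global average pooling outputs 7x7x2048
feature_maps = np.random.rand(7, 7, 2048)
# output layer: one 2048-dim weight vector per class (10 classes here)
dense_weights = np.random.rand(2048, 10)

class_index = 3                      # pretend this class was predicted
w_c = dense_weights[:, class_index]  # shape (2048,)

# CAM: per-location linear combination of the 2048 feature maps
cam = np.matmul(feature_maps, w_c)   # shape (7, 7)
```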
How do we map the CAM back onto the image? Each cell of a class's 7x7 CAM corresponds to a region of the original image (for example, the bottom-right CAM cell corresponds to the bottom-right of the image) and measures how much that region contributes to predicting the class. We can therefore render the CAM as a heatmap via cv2's colormap API and overlay it on the image to see what the model is focusing on.
Since computing class activation maps requires both the output-layer weights and the activations of the last convolutional layer (the one right before global average pooling), the model must also expose that convolutional layer as an output when it is built.
def cam_model(MODEL, input_shape, preprocess_input, output_num, weights_file_name):
"""
MODEL: pretrained model
input_shape: pre-trained model's input shape
preprocessing_input: pre-trained model's preprocessing function
weights_file_name: weights trained on driver datasheet
"""
## get pretrained model
x = Input(shape=input_shape)
if preprocess_input:
x = Lambda(preprocess_input)(x)
notop_model = MODEL(include_top=False, weights=None, input_tensor=x, input_shape=input_shape)
x = GlobalAveragePooling2D(name='global_average_2d_1')(notop_model.output)
## build top layer
x = Dropout(0.5, name='dropout_1')(x)
out = Dense(output_num, activation='softmax', name='dense_1')(x)
ret_model = Model(inputs=notop_model.input, outputs=[out, notop_model.layers[-2].output])
## load weights
ret_model.load_weights(weights_file_name)
## get the output layer weights
weights = ret_model.layers[-1].get_weights()
return ret_model, np.array(weights[0])
from keras.applications import resnet50
res_image_size = (224, 224)
res_input_shape = (224, 224, 3)
cam_model, cam_weights = cam_model(resnet50.ResNet50, res_input_shape, resnet50.preprocess_input, 10, 'resnet50.3.weights.hdf5')
print (cam_weights.shape)
(2048, 10)
import cv2
import numpy as np
image_path = 'train/c3/img_537.jpg'
resNet_input_shape = (224,224,3)
# read image
image = cv2.imread(image_path)
image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
image_input = cv2.resize(image, (resNet_input_shape[0], resNet_input_shape[1]))
image_input_m = np.expand_dims(image_input,axis=0)
# predict and get feature maps
predict_m, feature_maps_m = cam_model.predict(image_input_m)
predict = predict_m[0]
feature_maps = feature_maps_m[0]
# get the class result
class_index = np.argmax(predict)
# get the class_index unit's weights
cam_weights_c = cam_weights[:, class_index]
# get the class activation map
cam = np.matmul(feature_maps, cam_weights_c)
# normalize the cam
cam = (cam - cam.min())/(cam.max() - cam.min())
# do not care the low values
cam[np.where(cam<0.2)] = 0
cam = cv2.resize(cam, (resNet_input_shape[0], resNet_input_shape[1]))
cam = np.uint8(255*cam)
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
des = state_des['c'+str(class_index)]
# draw the hotmap
hotmap = cv2.applyColorMap(cam, cv2.COLORMAP_JET)
# linear combine the picture with cam
dis = cv2.addWeighted(image_input, 0.8, hotmap, 0.4, 0)
plt.title("Predict C" + str(class_index) + ':' + des)
plt.imshow(dis)
plt.axis("off")
(-0.5, 223.5, 223.5, -0.5)
import cv2
import numpy as np
def show_hot_map(image_path, model, cam_weights, input_shape):
""" 1. predict """
# read image
image = cv2.imread(image_path)
image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
image_input = cv2.resize(image, (input_shape[0], input_shape[1]))
image_input_m = np.expand_dims(image_input,axis=0)
# predict and get feature maps
predict_m, feature_maps_m = model.predict(image_input_m)
""" 2. get the calss activation maps """
predict = predict_m[0]
feature_maps = feature_maps_m[0]
# get the class result
class_index = np.argmax(predict)
# get the class_index unit's weights
cam_weights_c = cam_weights[:, class_index]
# get the class activation map
cam = np.matmul(feature_maps, cam_weights_c)
# normalize the cam
cam = (cam - cam.min())/(cam.max() - cam.min())
# do not care the low values
cam[np.where(cam<0.2)] = 0
cam = cv2.resize(cam, (input_shape[0], input_shape[1]))
cam = np.uint8(255*cam)
""" 3. show the hot map """
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
des = state_des['c'+str(class_index)]
# draw the hotmap
hotmap = cv2.applyColorMap(cam, cv2.COLORMAP_JET)
# linear combine the picture with cam
dis = cv2.addWeighted(image_input, 0.8, hotmap, 0.4, 0)
plt.title("Predict C" + str(class_index) + ':' + des)
plt.imshow(dis)
plt.axis("off")
image_path = 'train/c3/img_537.jpg'
resNet_input_shape = (224,224,3)
show_hot_map(image_path, cam_model, cam_weights, res_input_shape)
By swapping in different file names you can inspect what the model focuses on for each image; in the vast majority of cases, the model is doing the right thing.
In the CAM overlays, the blue areas are the regions with large CAM values, i.e. the regions the model relies on most for classification (most likely because `cv2.applyColorMap` returns BGR while `plt.imshow` expects RGB, so JET's usual "red = high" shows up as blue here).
Next comes a video demo with two sub-windows: the raw video and the heatmap video. Following the official documentation, producing video with OpenCV is straightforward.
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
def generate_hot_map(frame, cam_model, model_input_size, cam_weights, cam_size):
"""
image_input_m: CAM model's input
cam_model: CAM model
cam_weights: weights for CAM
cam_size: size of the output picture
"""
# resize frame for predict
img_for_model = cv2.resize(frame, model_input_size)
img_for_model = np.expand_dims(img_for_model,axis=0)
""" 1. predict """
# predict and get feature maps
predict_m, feature_maps_m = cam_model.predict(img_for_model)
""" 2. get the calss activation maps """
predict = predict_m[0]
feature_maps = feature_maps_m[0]
# get the class result
class_index = np.argmax(predict)
# get the class_index unit's weights
cam_weights_c = cam_weights[:, class_index]
# get the class activation map
cam = np.matmul(feature_maps, cam_weights_c)
# normalize the cam
cam = (cam - cam.min())/(cam.max() - cam.min())
# do not care the low values
cam[np.where(cam<0.2)] = 0
cam = cv2.resize(cam, cam_size)
cam = np.uint8(255*cam)
""" 3. show the hot map """
des = state_des['c'+str(class_index)]
# draw the hotmap
hotmap = cv2.applyColorMap(cam, cv2.COLORMAP_JET)
# linear combine the picture with cam
image_input = cv2.resize(frame, cam_size)
dis = cv2.addWeighted(image_input, 0.8, hotmap, 0.4, 0)
return dis,predict
import cv2
import numpy as np
## video width and height
main_video_width = 1280
main_video_high = 720
## subwindow size
sub_video_width = int(main_video_width/2)
sub_video_high = int(main_video_high*0.6)
## subwindow coordinate
sub1_coord_1 = int((main_video_high-sub_video_high)/2)
sub1_coord_2 = int((main_video_high-sub_video_high)/2) + sub_video_high
sub1_coord_3 = 0
sub1_coord_4 = sub1_coord_3 + sub_video_width
sub2_coord_1 = int((main_video_high-sub_video_high)/2)
sub2_coord_2 = int((main_video_high-sub_video_high)/2) + sub_video_high
sub2_coord_3 = sub1_coord_4
sub2_coord_4 = main_video_width
def generate_video_with_classfication(model, model_input_size, video_name_or_camera, cam_weights, generate_video_name='output.avi'):
"""
model: model to predict the video
model_input_size: image size of the model
video_name_or_camera: camera index or local video file to read from
cam_weights: weights for CAM
generate_video_name: the output video name
"""
"""0. create videl reader and writer, and get more video message """
cap = cv2.VideoCapture(video_name_or_camera)
video_width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
video_high = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
video_fps = cap.get(cv2.CAP_PROP_FPS)
print ("video size:({width}, {high}) fps:{fps}".format(width=video_width, high=video_high, fps=video_fps))
''' 0. create a new image '''
showBigImage = np.zeros((int(main_video_high), int(main_video_width), 3), np.uint8)
''' 1. create video writer '''
fourcc = cv2.VideoWriter_fourcc(*'XVID')
writer = cv2.VideoWriter(generate_video_name, fourcc, 20.0, (main_video_width, main_video_high))
if(cap.isOpened() == False):
print ("Failed to open " + video_name_or_camera)
return
while True:
""" 2. preprocessing and predict """
# get frame
ret, frame = cap.read()
# check if the video is over
if(ret != True):
print ("Ending!")
break
# get hot map
sub_frame_2, predict = generate_hot_map(frame, model, model_input_size, cam_weights, (sub_video_width, sub_video_high))
""" 3. add text to the image and show"""
class_index = np.argmax(predict)
text = 'Predicted: C{} {}'.format(class_index, state_des['c'+str(class_index)])
font = cv2.FONT_HERSHEY_SIMPLEX
showBigImage[:] = 0
cv2.putText(showBigImage, text, (10, sub1_coord_1-10), font, 0.5, (255, 255, 255), 1, cv2.LINE_AA)
""" 4. resize and fill 2 subwindow """
frame = cv2.resize(frame, (sub_video_width, sub_video_high))
showBigImage[sub1_coord_1:sub1_coord_2, sub1_coord_3:sub1_coord_4] = frame
showBigImage[sub2_coord_1:sub2_coord_2, sub2_coord_3:sub2_coord_4] = sub_frame_2
""" 5. show video """
cv2.imshow('image', showBigImage)
""" 6. save video if need """
writer.write(showBigImage)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
cap.release()
cv2.destroyAllWindows()
from keras.applications import xception
xce_image_size = (299, 299)
xce_input_shape = (299, 299, 3)
cam_model, cam_weights = cam_model(xception.Xception, xce_input_shape, xception.preprocess_input, 10, 'xception.fixed.weights.hdf5')
generate_video_with_classfication(cam_model, xce_image_size, 'real_fei.mp4', cam_weights, generate_video_name='output.avi')
video size:(544.0, 960.0) fps:30.03370088538598
Ending!
I shot a video of myself and ran the single xception model on it, but the results were disappointing, presumably because:
- The camera angle differs from the training/test sets: my video was filmed at eye level, while the dataset images are shot from above.
- The car interior and accessories in my video differ from those in the training videos, and the model's generalization is limited.
So I will hold off on publishing the demo video.
The models trained so far all used data augmentation, but perhaps a more "aggressive" kind of augmentation would give them even stronger generalization.
Browsing the drivers in each state together with the CAMs, the regions useful for distinguishing driving states sit mostly in the upper-left (the head) and lower-right (the hands). So we can combine the left half of one image with the right half of another image of the same class (later I also added top/bottom combinations) to form a new image. This keeps the key discriminative regions while stopping the non-key regions (hereafter, the "background") from dominating, since the background now changes constantly.
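The core crop-and-paste merge can be sketched on its own with PIL (a simplified stand-alone version of what the modified generator does; the function name is mine):

```python
import numpy as np
from PIL import Image

def merge_same_class(img_a, img_b):
    """Paste either the bottom half or the right half of img_b onto img_a
    (both PIL images of the same size and the same class)."""
    w, h = img_a.size
    if np.random.rand() < 0.5:
        box = (0, h // 2, w, h)   # bottom half of img_b
    else:
        box = (w // 2, 0, w, h)   # right half of img_b
    out = img_a.copy()
    out.paste(img_b.crop(box), box[:2])
    return out
```

The label stays unchanged because both source images belong to the same class.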
So I made a small modification to Keras's built-in data augmentation (the parts between the `# Start Code` / `# End Code` comments are my changes) so that it stitches together images of the same class. The code is shown below.
Next I use the modified data augmentation to generate data and retrain xception to see how much it helps.
from keras.preprocessing.image import *
from keras.preprocessing.image import _count_valid_files_in_directory
from keras.preprocessing.image import _list_valid_filenames_in_directory
from keras import backend as K
from functools import partial
import multiprocessing.pool
import os
import numpy as np
# new class, MergeImageDataGenerator, to generate the merge image
class MergeImageDataGenerator(ImageDataGenerator):
# redefine flow_from_directory method
def flow_from_directory(self, directory,
target_size=(256, 256), color_mode='rgb',
classes=None, class_mode='categorical',
batch_size=32, shuffle=True, seed=None,
save_to_dir=None,
save_prefix='',
save_format='png',
follow_links=False,
subset=None,
interpolation='nearest'):
# just return the iterator
return MergeDirectoryIterator(
directory, self,
target_size=target_size, color_mode=color_mode,
classes=classes, class_mode=class_mode,
data_format=self.data_format,
batch_size=batch_size, shuffle=shuffle, seed=seed,
save_to_dir=save_to_dir,
save_prefix=save_prefix,
save_format=save_format,
follow_links=follow_links,
subset=subset,
interpolation=interpolation)
# redefine the iterator of the generator
class MergeDirectoryIterator(Iterator):
"""Iterator capable of reading images from a directory on disk.
# Arguments
directory: Path to the directory to read images from.
Each subdirectory in this directory will be
considered to contain images from one class,
or alternatively you could specify class subdirectories
via the `classes` argument.
image_data_generator: Instance of `ImageDataGenerator`
to use for random transformations and normalization.
target_size: tuple of integers, dimensions to resize input images to.
color_mode: One of `"rgb"`, `"grayscale"`. Color mode to read images.
classes: Optional list of strings, names of subdirectories
containing images from each class (e.g. `["dogs", "cats"]`).
It will be computed automatically if not set.
class_mode: Mode for yielding the targets:
`"binary"`: binary targets (if there are only two classes),
`"categorical"`: categorical targets,
`"sparse"`: integer targets,
`"input"`: targets are images identical to input images (mainly
used to work with autoencoders),
`None`: no targets get yielded (only input images are yielded).
batch_size: Integer, size of a batch.
shuffle: Boolean, whether to shuffle the data between epochs.
seed: Random seed for data shuffling.
data_format: String, one of `channels_first`, `channels_last`.
save_to_dir: Optional directory where to save the pictures
being yielded, in a viewable format. This is useful
for visualizing the random transformations being
applied, for debugging purposes.
save_prefix: String prefix to use for saving sample
images (if `save_to_dir` is set).
save_format: Format to use for saving sample images
(if `save_to_dir` is set).
subset: Subset of data (`"training"` or `"validation"`) if
validation_split is set in ImageDataGenerator.
interpolation: Interpolation method used to resample the image if the
target size is different from that of the loaded image.
Supported methods are "nearest", "bilinear", and "bicubic".
If PIL version 1.1.3 or newer is installed, "lanczos" is also
supported. If PIL version 3.4.0 or newer is installed, "box" and
"hamming" are also supported. By default, "nearest" is used.
"""
def __init__(self, directory, image_data_generator,
target_size=(256, 256), color_mode='rgb',
classes=None, class_mode='categorical',
batch_size=32, shuffle=True, seed=None,
data_format=None,
save_to_dir=None, save_prefix='', save_format='png',
follow_links=False,
subset=None,
interpolation='nearest'):
if data_format is None:
data_format = K.image_data_format()
self.directory = directory
self.image_data_generator = image_data_generator
self.target_size = tuple(target_size)
if color_mode not in {'rgb', 'grayscale'}:
raise ValueError('Invalid color mode:', color_mode,
'; expected "rgb" or "grayscale".')
self.color_mode = color_mode
self.data_format = data_format
if self.color_mode == 'rgb':
if self.data_format == 'channels_last':
self.image_shape = self.target_size + (3,)
else:
self.image_shape = (3,) + self.target_size
else:
if self.data_format == 'channels_last':
self.image_shape = self.target_size + (1,)
else:
self.image_shape = (1,) + self.target_size
self.classes = classes
if class_mode not in {'categorical', 'binary', 'sparse',
'input', None}:
raise ValueError('Invalid class_mode:', class_mode,
'; expected one of "categorical", '
'"binary", "sparse", "input"'
' or None.')
self.class_mode = class_mode
self.save_to_dir = save_to_dir
self.save_prefix = save_prefix
self.save_format = save_format
self.interpolation = interpolation
if subset is not None:
validation_split = self.image_data_generator._validation_split
if subset == 'validation':
split = (0, validation_split)
elif subset == 'training':
split = (validation_split, 1)
else:
raise ValueError('Invalid subset name: ', subset,
'; expected "training" or "validation"')
else:
split = None
self.subset = subset
white_list_formats = {'png', 'jpg', 'jpeg', 'bmp', 'ppm', 'tif', 'tiff'}
# first, count the number of samples and classes
self.samples = 0
if not classes:
classes = []
for subdir in sorted(os.listdir(directory)):
if os.path.isdir(os.path.join(directory, subdir)):
classes.append(subdir)
self.num_classes = len(classes)
self.class_indices = dict(zip(classes, range(len(classes))))
pool = multiprocessing.pool.ThreadPool()
function_partial = partial(_count_valid_files_in_directory,
white_list_formats=white_list_formats,
follow_links=follow_links,
split=split)
self.samples = sum(pool.map(function_partial,
(os.path.join(directory, subdir)
for subdir in classes)))
print('!!!!!!!Found %d images belonging to %d classes.' % (self.samples, self.num_classes))
# second, build an index of the images in the different class subfolders
results = []
self.filenames = []
self.classes = np.zeros((self.samples,), dtype='int32')
i = 0
for dirpath in (os.path.join(directory, subdir) for subdir in classes):
results.append(pool.apply_async(_list_valid_filenames_in_directory,
(dirpath, white_list_formats, split,
self.class_indices, follow_links)))
"""
author: rikichou 2018,5,1 21:03:31
record the index range of each class
"""
# Start Code
classes_range = []
# End Code
for res in results:
classes, filenames = res.get()
"""
author: rikichou 2018,5,1 21:04:01
"""
# Start Code
start = i
end = i + len(classes) - 1
classes_range.append({'start':start, 'end':end})
# End Code
self.classes[i:i + len(classes)] = classes
self.filenames += filenames
i += len(classes)
"""
author: rikichou
"""
# Start Code
self.classes_range = classes_range
# End Code
pool.close()
pool.join()
super(MergeDirectoryIterator, self).__init__(self.samples, batch_size, shuffle, seed)
def _get_batches_of_transformed_samples(self, index_array):
batch_x = np.zeros((len(index_array),) + self.image_shape, dtype=K.floatx())
grayscale = self.color_mode == 'grayscale'
# build batch of image data
for i, j in enumerate(index_array):
fname = self.filenames[j]
img = load_img(os.path.join(self.directory, fname),
grayscale=grayscale,
target_size=self.target_size,
interpolation=self.interpolation)
if np.random.rand() < 0.8:
"""
author: rikichou 2018,5,1 21:03:31
before random transform and standardization
"""
# Start Code
""" 1. which iamge to merge with """
class_merge = self.classes[j]
class_range = self.classes_range[class_merge]
target_index = np.random.randint(class_range['start'], class_range['end'] + 1)  # 'end' is inclusive
""" 2. load target image """
target_name = self.filenames[target_index]
target_img = load_img(os.path.join(self.directory, target_name),
grayscale=grayscale,
target_size=self.target_size,
interpolation=self.interpolation)
""" 3. get the target image's start and end coordinate """
if np.random.rand() < 0.5:
start_x = 0
start_y = self.target_size[0]//2
end_x = self.target_size[1]
end_y = self.target_size[0]
else:
start_x = self.target_size[1]//2
start_y = 0
end_x = self.target_size[1]
end_y = self.target_size[0]
from_copy = (start_x, start_y, end_x, end_y)
to_paste = (start_x, start_y)
""" 4. copy and paste """
region = target_img.crop(from_copy)
img.paste(region, to_paste)
#EndCode
x = img_to_array(img, data_format=self.data_format)
x = self.image_data_generator.random_transform(x)
x = self.image_data_generator.standardize(x)
batch_x[i] = x
# optionally save augmented images to disk for debugging purposes
if self.save_to_dir:
for i, j in enumerate(index_array):
img = array_to_img(batch_x[i], self.data_format, scale=True)
fname = '{prefix}_{index}_{hash}.{format}'.format(prefix=self.save_prefix,
index=j,
hash=np.random.randint(1e7),
format=self.save_format)
img.save(os.path.join(self.save_to_dir, fname))
# build batch of labels
if self.class_mode == 'input':
batch_y = batch_x.copy()
elif self.class_mode == 'sparse':
batch_y = self.classes[index_array]
elif self.class_mode == 'binary':
batch_y = self.classes[index_array].astype(K.floatx())
elif self.class_mode == 'categorical':
batch_y = np.zeros((len(batch_x), self.num_classes), dtype=K.floatx())
for i, label in enumerate(self.classes[index_array]):
batch_y[i, label] = 1.
else:
return batch_x
return batch_x, batch_y
def next(self):
"""For python 2.x.
# Returns
The next batch.
"""
with self.lock:
index_array = next(self.index_generator)
# The transformation of images is not under thread lock
# so it can be done in parallel
return self._get_batches_of_transformed_samples(index_array)
def fixed_data_generator(train_dir, valid_dir, test_dir, image_size):
"""
image_size: the output of the image size, like (224, 224)
"""
gen = MergeImageDataGenerator(rotation_range=10.,
width_shift_range=0.05,
height_shift_range=0.05,
shear_range=0.1,
zoom_range=0.1)
gen_valid = ImageDataGenerator()
test_gen = ImageDataGenerator()
# create train generator
train_generator = gen.flow_from_directory(train_dir, image_size, color_mode='rgb', \
classes=classes, class_mode='categorical', shuffle=True, batch_size=32)
# create validation generator
valid_generator = gen_valid.flow_from_directory(valid_dir, image_size, color_mode='rgb', \
classes=classes, class_mode='categorical', shuffle=False, batch_size=32)
test_generator = test_gen.flow_from_directory(test_dir, image_size, color_mode='rgb', \
class_mode=None, shuffle=False, batch_size=32)
return train_generator, valid_generator, test_generator
from keras.applications import xception
from keras import optimizers
from keras.callbacks import ModelCheckpoint
xcp_drivers_id_list = [['p016', 'p072', 'p026', 'p066', 'p075']]
xcp_image_size = (299, 299)
xcp_input_shape = (299, 299, 3)
xcp_csv_names_list = []
#for i in range(2):
i = 0
print ('------------------------------------------------------------------------------------------------------------------')
""" 1. about data """
# create link and remove the old link
train_valid_split(xcp_drivers_id_list[i])
# get generator
train_generator, valid_generator, test_generator = fixed_data_generator(link_train_path, link_valid_path, test_link, xcp_image_size)
------------------------------------------------------------------------------------------------------------------
!!!!!!!Found 17956 images belonging to 10 classes.
Found 4468 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
""" 2. about model """
# get model
xception_model = get_model(xception.Xception, xcp_input_shape, xception.preprocess_input, 10)
from keras.optimizers import Adam
# compile
weights_file_name = 'xception.fixed.weights.hdf5'
ckpt = ModelCheckpoint(weights_file_name, verbose=1, save_best_only=True, save_weights_only=True)
adam = optimizers.Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
xception_model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
# fit
hist_1 = xception_model.fit_generator(train_generator, len(train_generator), epochs=1,\
validation_data=valid_generator, validation_steps=len(valid_generator), callbacks=[ckpt])
Epoch 1/1
562/562 [==============================] - 514s 915ms/step - loss: 0.3387 - acc: 0.8970 - val_loss: 0.3116 - val_acc: 0.8957
Epoch 00001: val_loss improved from inf to 0.31157, saving model to xception.fixed.weights.hdf5
adam = optimizers.Adam(lr=1e-5, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
xception_model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
hist_1 = xception_model.fit_generator(train_generator, len(train_generator), epochs=1,\
validation_data=valid_generator, validation_steps=len(valid_generator), callbacks=[ckpt])
Epoch 1/1
562/562 [==============================] - 514s 915ms/step - loss: 0.0338 - acc: 0.9914 - val_loss: 0.2759 - val_acc: 0.8988
Epoch 00001: val_loss improved from 0.31157 to 0.27592, saving model to xception.fixed.weights.hdf5
hist_1 = xception_model.fit_generator(train_generator, len(train_generator), epochs=1,\
validation_data=valid_generator, validation_steps=len(valid_generator), callbacks=[ckpt])
Epoch 1/1
562/562 [==============================] - 511s 910ms/step - loss: 0.0215 - acc: 0.9945 - val_loss: 0.2713 - val_acc: 0.9029
Epoch 00001: val_loss improved from 0.27592 to 0.27130, saving model to xception.fixed.weights.hdf5
As you can see, the single-model performance of xception clearly improves. For lack of time I won't train more models for blending, but this is certainly a viable path.
Walking through the whole process, we ended up with a solid model that reaches the top 10% of the Kaggle leaderboard. The key takeaways from this project are:
- When splitting the training and validation sets, because of the peculiar similarity of the images, the split must be done by driver ID, not by driving state.
- Because different states of the same driver look very similar, as many layers as possible should be opened for training (in this project I did not lock any layers).
- Since this is a Kaggle competition where the only metric is softmax cross entropy and prediction time does not matter, model ensembling is always a valid strategy for raising the score.
- To give the model stronger generalization, add some task-specific data augmentation: for two random images of the same state, combine the left and right halves, or the top and bottom halves, into a new image. This demonstrably improves single-model generalization, and combined with ensembling the score improves further.
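The first point, splitting by driver ID, can be sketched with pandas on `driver_imgs_list.csv` (the columns `subject`, `classname`, `img` follow the Kaggle file; the function name is mine, the actual split in this post is done via symlinks):

```python
import pandas as pd

def split_by_driver(csv_path, valid_drivers):
    """Hold out whole drivers for validation so that no driver
    appears in both the training and the validation set."""
    df = pd.read_csv(csv_path)  # columns: subject, classname, img
    mask = df['subject'].isin(valid_drivers)
    return df[~mask], df[mask]  # (train rows, validation rows)
```

Splitting by image instead would leak near-duplicate frames of the same driver into both sets and inflate validation accuracy.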