Distracted-driving detection: detect the driver's state and warn the driver when they are driving distracted.

rikichou/distracted-driving-detection

Driver State Detection

According to surveys by vehicle-safety authorities, one in five traffic accidents is caused by a distracted driver. Every year, distracted driving injures about 425,000 people and kills about 3,000. These numbers are staggering.

State Farm hopes to use in-car dashboard cameras to detect whether a driver is distracted and, if so, issue a warning.


This write-up follows my actual development process: rather than presenting the optimal solution right away, I document every "pitfall" I ran into along the way, and these pitfalls are real classics.

Because the amount of training involved is substantial, all training for this project was done on rented Amazon EC2 instances, which are reasonably priced; more on that below.

1.1 Obtaining the dataset

The dataset comes from Kaggle; there are three files to download, listed below. imgs.zip is the labeled set of driver-state images captured by the dashboard camera, and it is about 4 GB.

  • imgs.zip - a zip archive of all training/test images (you need to download it yourself)
  • sample_submission.csv - the required format for a Kaggle submission
  • driver_imgs_list.csv - image metadata: for each file name, the driver ID and the driver-state (class) ID of that image.

Now let's unzip imgs.zip and take a look.

from keras.layers import Input
from keras.layers.core import Lambda
from keras.layers import Dense, GlobalAveragePooling2D, Dropout
from keras.models import Model
from keras.preprocessing.image import ImageDataGenerator
from keras.applications import resnet50   # pretrained ResNet-50 (resnet50.ResNet50 / preprocess_input) used below

from urllib.request import urlretrieve
from os.path import isfile, isdir
from tqdm import tqdm
import zipfile
import os
import h5py
import numpy as np

zip_name = 'imgs.zip'
train_dir_name = 'train'
test_dir_name = 'test'

link_path = 'train_link'
link_train_path = 'train_link/train'
link_valid_path = 'train_link/validation'

test_link = 'test_link'
test_link_path = 'test_link/data'

resnet_50_model_save_name = 'model_resnet50.h5'
inceptionv3_model_save_name = 'model_inceptionv3.h5'
xception_model_save_name = 'model_xception.h5'

## check whether the train and test data already exist
if not isdir(train_dir_name) or not isdir(test_dir_name):
    if not isfile(zip_name):
        print ("Please download imgs.zip from kaggle!")
        assert(False)
    else:
        with zipfile.ZipFile(zip_name) as azip:
            print ("Now to extract %s " % (zip_name))
            azip.extractall()
print ("Data is ready!")
Data is ready!

We will detect the following 10 driver states:

  • c0: safe driving
  • c1: texting - right hand
  • c2: talking on the phone - right hand
  • c3: texting - left hand
  • c4: talking on the phone - left hand
  • c5: operating the radio
  • c6: drinking
  • c7: reaching behind
  • c8: hair and makeup
  • c9: talking to passenger

The images of each driving state are stored separately, i.e. folders c0-c9 each contain the images of the corresponding state. The directory structure now looks roughly like this:

|----imgs.zip
|----train
       |-----c0
       |-----c1
       |-----c2
       |-----c3
       |-----c4
       |-----c5
       |-----c6
       |-----c7
       |-----c8
       |-----c9
|----test

1.2 Basic dataset statistics

Next, let's look at some basic statistics of the dataset:

  1. Count the numbers of training and test samples
  2. Look at the number of training images per class
import os

## count the training and test files
train_class_dir_names = os.listdir(train_dir_name)
test_size = len(os.listdir(test_dir_name))

train_size = 0
train_class_size = {}

for dname in train_class_dir_names:
    file_names = os.listdir(train_dir_name + '/' + dname)
    train_class_size[dname] = len(file_names)
    train_size = train_class_size[dname] + train_size
    
print ("Test file numbers: ", test_size)
print ("Train file numbers: ", train_size)
Test file numbers:  79726
Train file numbers:  22424
import matplotlib.pyplot as plt
%matplotlib inline

fig = plt.figure()  
plt.bar(train_class_size.keys(), train_class_size.values(), 0.4, color="green")  
plt.xlabel("Classes")  
plt.ylabel("File nums")  
plt.title("Classes distribution")  
plt.show()

(figure: bar chart of the per-class image counts in the training set)

From the results above, we have 79,726 test images and 22,424 training images. In the training set, each state (class) contains roughly 2,000 images, so the class distribution is fairly balanced. The plot above is grouped by state; now let's look from another angle and see roughly how many images each driver contributes.

Next, unzip driver_imgs_list.csv.zip.

# get driver_imgs_list_file
driver_imgs_list_zip = 'driver_imgs_list.csv.zip'
driver_imgs_list_file = 'driver_imgs_list.csv'

if not isfile(driver_imgs_list_file):
    if not isfile(driver_imgs_list_zip):
        print ("Please download river_imgs_list.csv.zip from kaggle!")
        assert(False)
    else:
        with zipfile.ZipFile(driver_imgs_list_zip) as azip:
            print ("Now to extract %s " % (driver_imgs_list_zip))
            azip.extractall()
import pandas as pd

df = pd.read_csv(driver_imgs_list_file)

df.describe()
        subject  classname           img
count     22424      22424         22424
unique       26         10         22424
top        p021         c0  img_2362.jpg
freq       1237       2489             1
ts = df['subject'].value_counts()
print (ts)
fig = plt.figure(figsize=(30,10))  
plt.bar(ts.index.tolist(), ts.iloc[:].tolist(), 0.4, color="green")  
plt.xlabel("Driver ID")  
plt.ylabel("File nums")  
plt.title("Driver ID distribution")
plt.show()
p021    1237
p022    1233
p024    1226
p026    1196
p016    1078
p066    1034
p049    1011
p051     920
p014     876
p015     875
p035     848
p047     835
p012     823
p081     823
p064     820
p075     814
p061     809
p056     794
p050     790
p052     740
p002     725
p045     724
p039     651
p041     605
p042     591
p072     346
Name: subject, dtype: int64

(figure: bar chart of image counts per driver ID)

We can see that the training data comes from 26 different drivers, each of whom has images of every state. Driver p072 has the fewest images (346) and driver p021 has the most (1,237). So if we look at the data by driver ID, the distribution is not uniform. This does not affect our training, though, because what we mainly care about is whether the images per state are balanced, not the images per driver. Next, let's quickly visualize a few samples from one of the classes.

import cv2
import numpy as np

state_des = {'c0':'safe driving','c1':'texting - right hand','c2':'talking on the phone - right','c3':'texting - left hand',  \
             'c4':'talking on the phone - left hand','c5':'operating the radio','c6':'drinking','c7':'reaching behind','c8':'hair and makeup',  \
             'c9':'talking to passenger'};

## class that you want to display
c = 0

## randomly pick some filenames from that class
dis_dir = train_dir_name + '/c' + str(c)
dis_filenames = os.listdir(dis_dir)

dis_list = np.random.randint(len(dis_filenames), size=(6))
dis_list = [dis_filenames[index] for index in dis_list]

plt.figure(1, figsize=(13, 13))

for i,filename in enumerate(dis_list):
    image = cv2.imread(dis_dir +  '/' + str(filename))
    image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
    ax1=plt.subplot(3,3,i+1)
    plt.imshow(image)
    plt.axis("off")
    plt.title(state_des['c'+str(c)] + "\n" + str(image.shape))
    
plt.show()

(figure: grid of six sample images from the selected class, titled with the state description and image shape)

2. Transfer learning with a single model

After working through a few Kaggle projects, I have developed a rough intuition for computer-vision tasks: if the project is to be tackled with a popular CNN architecture, transfer learning is always worth a try. For a CNN, the "low-level information" in images (lines, edges, eyes, ears, cat faces, dog faces) is shared across tasks, so we can take the generic "knowledge" the big models learned on large vision datasets and transfer it to our project, helping us extract general-purpose features. Moreover:

  • If we only have a small dataset, we should train only the weights of the top layers (fully connected layers, output layer) and keep everything else frozen. With little data, insisting on training layers beyond the top is very likely to destroy the knowledge the pretrained model learned on its huge dataset and actually hurt performance.
  • If we have a moderately sized dataset, we can unfreeze a small number of convolutional layers for fine-tuning; to avoid large weight updates wrecking the previously learned knowledge, a small learning rate is recommended.
  • If we have a huge dataset and enough time to train, transfer learning is still useful, but we would unfreeze more layers than in the second case, possibly all of them. (A minimal sketch of freezing and unfreezing layers in Keras follows this list.)
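
To make these three regimes concrete, here is a minimal sketch of layer freezing in Keras. It is illustrative only: the base model, the number of unfrozen layers and the learning rate are assumptions, not the exact settings used later in this project.

## Sketch only: freezing/unfreezing layers of a pretrained base model in Keras
from keras.applications import resnet50
from keras.layers import Input, Dense, GlobalAveragePooling2D
from keras.models import Model
from keras.optimizers import Adam

base_model = resnet50.ResNet50(include_top=False, weights='imagenet',
                               input_tensor=Input(shape=(224, 224, 3)))
x = GlobalAveragePooling2D()(base_model.output)
out = Dense(10, activation='softmax')(x)
sketch_model = Model(inputs=base_model.input, outputs=out)

## Regime 1 (small dataset): freeze the whole pretrained base, train only the new top layer
for layer in base_model.layers:
    layer.trainable = False

## Regime 2 (medium dataset): additionally unfreeze the last few conv layers and use a small learning rate
# for layer in base_model.layers[-10:]:
#     layer.trainable = True

sketch_model.compile(optimizer=Adam(lr=1e-4), loss='categorical_crossentropy', metrics=['accuracy'])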

We already explored the dataset and found that the training set has a little over 20,000 images, so I place our situation in the first category (20-odd thousand images is still tiny compared to millions). So my plan is to do transfer learning on a single model and update only the top-layer weights (you will see later why I emphasize this).

2.1 Single-model transfer learning with ResNet-50 (train top layer only)

The first model I tried was ResNet-50, the winner of ILSVRC 2015. With its residual blocks it solves the degradation problem that appears as network depth grows: accuracy first rises, then saturates, and further depth makes it drop again (note: this refers to training-set accuracy, not validation accuracy, so it is not overfitting). I figured a 50-layer network could easily handle this project without worrying about degradation, so why not? Without further ado, let's load the pretrained weights and start tuning!

A note on hardware: my own GPU is a GTX 960, which is fine for small networks but underpowered for something the size of ResNet. So I rented an AWS p3.2xlarge instance; its spot price is roughly 1 USD per hour. If you have never used AWS, there are articles on setting it up for deep learning (inadvertently advertising Amazon again). Also, depending on where you are, you may need a proxy to reach AWS; if that is not an option, Alibaba Cloud offers similar instances, just somewhat more expensive. Weigh it up yourself!

As mentioned above, I will only train the top-layer weights, so to make the later tuning rounds cheaper I use the bottleneck-feature approach to avoid repeated forward passes: for every sample, extract the output just before the top layer (for ResNet-50 its shape is 1x1x2048; from now on I will call these feature vectors), save the vectors, and feed them directly to the top layer. It speeds up tuning by trading space for time.

## load pretrained resnet
resNet_input_shape = (224,224,3)
res_x = Input(shape=resNet_input_shape)
res_x = Lambda(resnet50.preprocess_input)(res_x)
res_model = resnet50.ResNet50(include_top=False, weights='imagenet', input_tensor=res_x, input_shape=resNet_input_shape)

res_model.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 224, 224, 3)  0                                            
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 224, 224, 3)  0           input_1[0][0]                    
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, 230, 230, 3)  0           lambda_1[0][0]                   
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, 112, 112, 64) 9472        conv1_pad[0][0]                  
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization)   (None, 112, 112, 64) 256         conv1[0][0]                      
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 112, 112, 64) 0           bn_conv1[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 55, 55, 64)   0           activation_1[0][0]               
__________________________________________________________________________________________________
res2a_branch2a (Conv2D)         (None, 55, 55, 64)   4160        max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
bn2a_branch2a (BatchNormalizati (None, 55, 55, 64)   256         res2a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 55, 55, 64)   0           bn2a_branch2a[0][0]              
__________________________________________________________________________________________________
res2a_branch2b (Conv2D)         (None, 55, 55, 64)   36928       activation_2[0][0]               
__________________________________________________________________________________________________
bn2a_branch2b (BatchNormalizati (None, 55, 55, 64)   256         res2a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_3 (Activation)       (None, 55, 55, 64)   0           bn2a_branch2b[0][0]              
__________________________________________________________________________________________________
res2a_branch2c (Conv2D)         (None, 55, 55, 256)  16640       activation_3[0][0]               
__________________________________________________________________________________________________
res2a_branch1 (Conv2D)          (None, 55, 55, 256)  16640       max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
bn2a_branch2c (BatchNormalizati (None, 55, 55, 256)  1024        res2a_branch2c[0][0]             
__________________________________________________________________________________________________
bn2a_branch1 (BatchNormalizatio (None, 55, 55, 256)  1024        res2a_branch1[0][0]              
__________________________________________________________________________________________________
add_1 (Add)                     (None, 55, 55, 256)  0           bn2a_branch2c[0][0]              
                                                                 bn2a_branch1[0][0]               
__________________________________________________________________________________________________
activation_4 (Activation)       (None, 55, 55, 256)  0           add_1[0][0]                      
__________________________________________________________________________________________________
res2b_branch2a (Conv2D)         (None, 55, 55, 64)   16448       activation_4[0][0]               
__________________________________________________________________________________________________
bn2b_branch2a (BatchNormalizati (None, 55, 55, 64)   256         res2b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_5 (Activation)       (None, 55, 55, 64)   0           bn2b_branch2a[0][0]              
__________________________________________________________________________________________________
res2b_branch2b (Conv2D)         (None, 55, 55, 64)   36928       activation_5[0][0]               
__________________________________________________________________________________________________
bn2b_branch2b (BatchNormalizati (None, 55, 55, 64)   256         res2b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_6 (Activation)       (None, 55, 55, 64)   0           bn2b_branch2b[0][0]              
__________________________________________________________________________________________________
res2b_branch2c (Conv2D)         (None, 55, 55, 256)  16640       activation_6[0][0]               
__________________________________________________________________________________________________
bn2b_branch2c (BatchNormalizati (None, 55, 55, 256)  1024        res2b_branch2c[0][0]             
__________________________________________________________________________________________________
add_2 (Add)                     (None, 55, 55, 256)  0           bn2b_branch2c[0][0]              
                                                                 activation_4[0][0]               
__________________________________________________________________________________________________
activation_7 (Activation)       (None, 55, 55, 256)  0           add_2[0][0]                      
__________________________________________________________________________________________________
res2c_branch2a (Conv2D)         (None, 55, 55, 64)   16448       activation_7[0][0]               
__________________________________________________________________________________________________
bn2c_branch2a (BatchNormalizati (None, 55, 55, 64)   256         res2c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 55, 55, 64)   0           bn2c_branch2a[0][0]              
__________________________________________________________________________________________________
res2c_branch2b (Conv2D)         (None, 55, 55, 64)   36928       activation_8[0][0]               
__________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizati (None, 55, 55, 64)   256         res2c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_9 (Activation)       (None, 55, 55, 64)   0           bn2c_branch2b[0][0]              
__________________________________________________________________________________________________
res2c_branch2c (Conv2D)         (None, 55, 55, 256)  16640       activation_9[0][0]               
__________________________________________________________________________________________________
bn2c_branch2c (BatchNormalizati (None, 55, 55, 256)  1024        res2c_branch2c[0][0]             
__________________________________________________________________________________________________
add_3 (Add)                     (None, 55, 55, 256)  0           bn2c_branch2c[0][0]              
                                                                 activation_7[0][0]               
__________________________________________________________________________________________________
activation_10 (Activation)      (None, 55, 55, 256)  0           add_3[0][0]                      
__________________________________________________________________________________________________
res3a_branch2a (Conv2D)         (None, 28, 28, 128)  32896       activation_10[0][0]              
__________________________________________________________________________________________________
bn3a_branch2a (BatchNormalizati (None, 28, 28, 128)  512         res3a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_11 (Activation)      (None, 28, 28, 128)  0           bn3a_branch2a[0][0]              
__________________________________________________________________________________________________
res3a_branch2b (Conv2D)         (None, 28, 28, 128)  147584      activation_11[0][0]              
__________________________________________________________________________________________________
bn3a_branch2b (BatchNormalizati (None, 28, 28, 128)  512         res3a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_12 (Activation)      (None, 28, 28, 128)  0           bn3a_branch2b[0][0]              
__________________________________________________________________________________________________
res3a_branch2c (Conv2D)         (None, 28, 28, 512)  66048       activation_12[0][0]              
__________________________________________________________________________________________________
res3a_branch1 (Conv2D)          (None, 28, 28, 512)  131584      activation_10[0][0]              
__________________________________________________________________________________________________
bn3a_branch2c (BatchNormalizati (None, 28, 28, 512)  2048        res3a_branch2c[0][0]             
__________________________________________________________________________________________________
bn3a_branch1 (BatchNormalizatio (None, 28, 28, 512)  2048        res3a_branch1[0][0]              
__________________________________________________________________________________________________
add_4 (Add)                     (None, 28, 28, 512)  0           bn3a_branch2c[0][0]              
                                                                 bn3a_branch1[0][0]               
__________________________________________________________________________________________________
activation_13 (Activation)      (None, 28, 28, 512)  0           add_4[0][0]                      
__________________________________________________________________________________________________
res3b_branch2a (Conv2D)         (None, 28, 28, 128)  65664       activation_13[0][0]              
__________________________________________________________________________________________________
bn3b_branch2a (BatchNormalizati (None, 28, 28, 128)  512         res3b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_14 (Activation)      (None, 28, 28, 128)  0           bn3b_branch2a[0][0]              
__________________________________________________________________________________________________
res3b_branch2b (Conv2D)         (None, 28, 28, 128)  147584      activation_14[0][0]              
__________________________________________________________________________________________________
bn3b_branch2b (BatchNormalizati (None, 28, 28, 128)  512         res3b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_15 (Activation)      (None, 28, 28, 128)  0           bn3b_branch2b[0][0]              
__________________________________________________________________________________________________
res3b_branch2c (Conv2D)         (None, 28, 28, 512)  66048       activation_15[0][0]              
__________________________________________________________________________________________________
bn3b_branch2c (BatchNormalizati (None, 28, 28, 512)  2048        res3b_branch2c[0][0]             
__________________________________________________________________________________________________
add_5 (Add)                     (None, 28, 28, 512)  0           bn3b_branch2c[0][0]              
                                                                 activation_13[0][0]              
__________________________________________________________________________________________________
activation_16 (Activation)      (None, 28, 28, 512)  0           add_5[0][0]                      
__________________________________________________________________________________________________
res3c_branch2a (Conv2D)         (None, 28, 28, 128)  65664       activation_16[0][0]              
__________________________________________________________________________________________________
bn3c_branch2a (BatchNormalizati (None, 28, 28, 128)  512         res3c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_17 (Activation)      (None, 28, 28, 128)  0           bn3c_branch2a[0][0]              
__________________________________________________________________________________________________
res3c_branch2b (Conv2D)         (None, 28, 28, 128)  147584      activation_17[0][0]              
__________________________________________________________________________________________________
bn3c_branch2b (BatchNormalizati (None, 28, 28, 128)  512         res3c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_18 (Activation)      (None, 28, 28, 128)  0           bn3c_branch2b[0][0]              
__________________________________________________________________________________________________
res3c_branch2c (Conv2D)         (None, 28, 28, 512)  66048       activation_18[0][0]              
__________________________________________________________________________________________________
bn3c_branch2c (BatchNormalizati (None, 28, 28, 512)  2048        res3c_branch2c[0][0]             
__________________________________________________________________________________________________
add_6 (Add)                     (None, 28, 28, 512)  0           bn3c_branch2c[0][0]              
                                                                 activation_16[0][0]              
__________________________________________________________________________________________________
activation_19 (Activation)      (None, 28, 28, 512)  0           add_6[0][0]                      
__________________________________________________________________________________________________
res3d_branch2a (Conv2D)         (None, 28, 28, 128)  65664       activation_19[0][0]              
__________________________________________________________________________________________________
bn3d_branch2a (BatchNormalizati (None, 28, 28, 128)  512         res3d_branch2a[0][0]             
__________________________________________________________________________________________________
activation_20 (Activation)      (None, 28, 28, 128)  0           bn3d_branch2a[0][0]              
__________________________________________________________________________________________________
res3d_branch2b (Conv2D)         (None, 28, 28, 128)  147584      activation_20[0][0]              
__________________________________________________________________________________________________
bn3d_branch2b (BatchNormalizati (None, 28, 28, 128)  512         res3d_branch2b[0][0]             
__________________________________________________________________________________________________
activation_21 (Activation)      (None, 28, 28, 128)  0           bn3d_branch2b[0][0]              
__________________________________________________________________________________________________
res3d_branch2c (Conv2D)         (None, 28, 28, 512)  66048       activation_21[0][0]              
__________________________________________________________________________________________________
bn3d_branch2c (BatchNormalizati (None, 28, 28, 512)  2048        res3d_branch2c[0][0]             
__________________________________________________________________________________________________
add_7 (Add)                     (None, 28, 28, 512)  0           bn3d_branch2c[0][0]              
                                                                 activation_19[0][0]              
__________________________________________________________________________________________________
activation_22 (Activation)      (None, 28, 28, 512)  0           add_7[0][0]                      
__________________________________________________________________________________________________
res4a_branch2a (Conv2D)         (None, 14, 14, 256)  131328      activation_22[0][0]              
__________________________________________________________________________________________________
bn4a_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_23 (Activation)      (None, 14, 14, 256)  0           bn4a_branch2a[0][0]              
__________________________________________________________________________________________________
res4a_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_23[0][0]              
__________________________________________________________________________________________________
bn4a_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_24 (Activation)      (None, 14, 14, 256)  0           bn4a_branch2b[0][0]              
__________________________________________________________________________________________________
res4a_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_24[0][0]              
__________________________________________________________________________________________________
res4a_branch1 (Conv2D)          (None, 14, 14, 1024) 525312      activation_22[0][0]              
__________________________________________________________________________________________________
bn4a_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4a_branch2c[0][0]             
__________________________________________________________________________________________________
bn4a_branch1 (BatchNormalizatio (None, 14, 14, 1024) 4096        res4a_branch1[0][0]              
__________________________________________________________________________________________________
add_8 (Add)                     (None, 14, 14, 1024) 0           bn4a_branch2c[0][0]              
                                                                 bn4a_branch1[0][0]               
__________________________________________________________________________________________________
activation_25 (Activation)      (None, 14, 14, 1024) 0           add_8[0][0]                      
__________________________________________________________________________________________________
res4b_branch2a (Conv2D)         (None, 14, 14, 256)  262400      activation_25[0][0]              
__________________________________________________________________________________________________
bn4b_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_26 (Activation)      (None, 14, 14, 256)  0           bn4b_branch2a[0][0]              
__________________________________________________________________________________________________
res4b_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_26[0][0]              
__________________________________________________________________________________________________
bn4b_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_27 (Activation)      (None, 14, 14, 256)  0           bn4b_branch2b[0][0]              
__________________________________________________________________________________________________
res4b_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_27[0][0]              
__________________________________________________________________________________________________
bn4b_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4b_branch2c[0][0]             
__________________________________________________________________________________________________
add_9 (Add)                     (None, 14, 14, 1024) 0           bn4b_branch2c[0][0]              
                                                                 activation_25[0][0]              
__________________________________________________________________________________________________
activation_28 (Activation)      (None, 14, 14, 1024) 0           add_9[0][0]                      
__________________________________________________________________________________________________
res4c_branch2a (Conv2D)         (None, 14, 14, 256)  262400      activation_28[0][0]              
__________________________________________________________________________________________________
bn4c_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_29 (Activation)      (None, 14, 14, 256)  0           bn4c_branch2a[0][0]              
__________________________________________________________________________________________________
res4c_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_29[0][0]              
__________________________________________________________________________________________________
bn4c_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_30 (Activation)      (None, 14, 14, 256)  0           bn4c_branch2b[0][0]              
__________________________________________________________________________________________________
res4c_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_30[0][0]              
__________________________________________________________________________________________________
bn4c_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4c_branch2c[0][0]             
__________________________________________________________________________________________________
add_10 (Add)                    (None, 14, 14, 1024) 0           bn4c_branch2c[0][0]              
                                                                 activation_28[0][0]              
__________________________________________________________________________________________________
activation_31 (Activation)      (None, 14, 14, 1024) 0           add_10[0][0]                     
__________________________________________________________________________________________________
res4d_branch2a (Conv2D)         (None, 14, 14, 256)  262400      activation_31[0][0]              
__________________________________________________________________________________________________
bn4d_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4d_branch2a[0][0]             
__________________________________________________________________________________________________
activation_32 (Activation)      (None, 14, 14, 256)  0           bn4d_branch2a[0][0]              
__________________________________________________________________________________________________
res4d_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_32[0][0]              
__________________________________________________________________________________________________
bn4d_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4d_branch2b[0][0]             
__________________________________________________________________________________________________
activation_33 (Activation)      (None, 14, 14, 256)  0           bn4d_branch2b[0][0]              
__________________________________________________________________________________________________
res4d_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_33[0][0]              
__________________________________________________________________________________________________
bn4d_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4d_branch2c[0][0]             
__________________________________________________________________________________________________
add_11 (Add)                    (None, 14, 14, 1024) 0           bn4d_branch2c[0][0]              
                                                                 activation_31[0][0]              
__________________________________________________________________________________________________
activation_34 (Activation)      (None, 14, 14, 1024) 0           add_11[0][0]                     
__________________________________________________________________________________________________
res4e_branch2a (Conv2D)         (None, 14, 14, 256)  262400      activation_34[0][0]              
__________________________________________________________________________________________________
bn4e_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4e_branch2a[0][0]             
__________________________________________________________________________________________________
activation_35 (Activation)      (None, 14, 14, 256)  0           bn4e_branch2a[0][0]              
__________________________________________________________________________________________________
res4e_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_35[0][0]              
__________________________________________________________________________________________________
bn4e_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4e_branch2b[0][0]             
__________________________________________________________________________________________________
activation_36 (Activation)      (None, 14, 14, 256)  0           bn4e_branch2b[0][0]              
__________________________________________________________________________________________________
res4e_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_36[0][0]              
__________________________________________________________________________________________________
bn4e_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4e_branch2c[0][0]             
__________________________________________________________________________________________________
add_12 (Add)                    (None, 14, 14, 1024) 0           bn4e_branch2c[0][0]              
                                                                 activation_34[0][0]              
__________________________________________________________________________________________________
activation_37 (Activation)      (None, 14, 14, 1024) 0           add_12[0][0]                     
__________________________________________________________________________________________________
res4f_branch2a (Conv2D)         (None, 14, 14, 256)  262400      activation_37[0][0]              
__________________________________________________________________________________________________
bn4f_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4f_branch2a[0][0]             
__________________________________________________________________________________________________
activation_38 (Activation)      (None, 14, 14, 256)  0           bn4f_branch2a[0][0]              
__________________________________________________________________________________________________
res4f_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_38[0][0]              
__________________________________________________________________________________________________
bn4f_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4f_branch2b[0][0]             
__________________________________________________________________________________________________
activation_39 (Activation)      (None, 14, 14, 256)  0           bn4f_branch2b[0][0]              
__________________________________________________________________________________________________
res4f_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_39[0][0]              
__________________________________________________________________________________________________
bn4f_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4f_branch2c[0][0]             
__________________________________________________________________________________________________
add_13 (Add)                    (None, 14, 14, 1024) 0           bn4f_branch2c[0][0]              
                                                                 activation_37[0][0]              
__________________________________________________________________________________________________
activation_40 (Activation)      (None, 14, 14, 1024) 0           add_13[0][0]                     
__________________________________________________________________________________________________
res5a_branch2a (Conv2D)         (None, 7, 7, 512)    524800      activation_40[0][0]              
__________________________________________________________________________________________________
bn5a_branch2a (BatchNormalizati (None, 7, 7, 512)    2048        res5a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_41 (Activation)      (None, 7, 7, 512)    0           bn5a_branch2a[0][0]              
__________________________________________________________________________________________________
res5a_branch2b (Conv2D)         (None, 7, 7, 512)    2359808     activation_41[0][0]              
__________________________________________________________________________________________________
bn5a_branch2b (BatchNormalizati (None, 7, 7, 512)    2048        res5a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_42 (Activation)      (None, 7, 7, 512)    0           bn5a_branch2b[0][0]              
__________________________________________________________________________________________________
res5a_branch2c (Conv2D)         (None, 7, 7, 2048)   1050624     activation_42[0][0]              
__________________________________________________________________________________________________
res5a_branch1 (Conv2D)          (None, 7, 7, 2048)   2099200     activation_40[0][0]              
__________________________________________________________________________________________________
bn5a_branch2c (BatchNormalizati (None, 7, 7, 2048)   8192        res5a_branch2c[0][0]             
__________________________________________________________________________________________________
bn5a_branch1 (BatchNormalizatio (None, 7, 7, 2048)   8192        res5a_branch1[0][0]              
__________________________________________________________________________________________________
add_14 (Add)                    (None, 7, 7, 2048)   0           bn5a_branch2c[0][0]              
                                                                 bn5a_branch1[0][0]               
__________________________________________________________________________________________________
activation_43 (Activation)      (None, 7, 7, 2048)   0           add_14[0][0]                     
__________________________________________________________________________________________________
res5b_branch2a (Conv2D)         (None, 7, 7, 512)    1049088     activation_43[0][0]              
__________________________________________________________________________________________________
bn5b_branch2a (BatchNormalizati (None, 7, 7, 512)    2048        res5b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_44 (Activation)      (None, 7, 7, 512)    0           bn5b_branch2a[0][0]              
__________________________________________________________________________________________________
res5b_branch2b (Conv2D)         (None, 7, 7, 512)    2359808     activation_44[0][0]              
__________________________________________________________________________________________________
bn5b_branch2b (BatchNormalizati (None, 7, 7, 512)    2048        res5b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_45 (Activation)      (None, 7, 7, 512)    0           bn5b_branch2b[0][0]              
__________________________________________________________________________________________________
res5b_branch2c (Conv2D)         (None, 7, 7, 2048)   1050624     activation_45[0][0]              
__________________________________________________________________________________________________
bn5b_branch2c (BatchNormalizati (None, 7, 7, 2048)   8192        res5b_branch2c[0][0]             
__________________________________________________________________________________________________
add_15 (Add)                    (None, 7, 7, 2048)   0           bn5b_branch2c[0][0]              
                                                                 activation_43[0][0]              
__________________________________________________________________________________________________
activation_46 (Activation)      (None, 7, 7, 2048)   0           add_15[0][0]                     
__________________________________________________________________________________________________
res5c_branch2a (Conv2D)         (None, 7, 7, 512)    1049088     activation_46[0][0]              
__________________________________________________________________________________________________
bn5c_branch2a (BatchNormalizati (None, 7, 7, 512)    2048        res5c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_47 (Activation)      (None, 7, 7, 512)    0           bn5c_branch2a[0][0]              
__________________________________________________________________________________________________
res5c_branch2b (Conv2D)         (None, 7, 7, 512)    2359808     activation_47[0][0]              
__________________________________________________________________________________________________
bn5c_branch2b (BatchNormalizati (None, 7, 7, 512)    2048        res5c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_48 (Activation)      (None, 7, 7, 512)    0           bn5c_branch2b[0][0]              
__________________________________________________________________________________________________
res5c_branch2c (Conv2D)         (None, 7, 7, 2048)   1050624     activation_48[0][0]              
__________________________________________________________________________________________________
bn5c_branch2c (BatchNormalizati (None, 7, 7, 2048)   8192        res5c_branch2c[0][0]             
__________________________________________________________________________________________________
add_16 (Add)                    (None, 7, 7, 2048)   0           bn5c_branch2c[0][0]              
                                                                 activation_46[0][0]              
__________________________________________________________________________________________________
activation_49 (Activation)      (None, 7, 7, 2048)   0           add_16[0][0]                     
__________________________________________________________________________________________________
avg_pool (AveragePooling2D)     (None, 1, 1, 2048)   0           activation_49[0][0]              
==================================================================================================
Total params: 23,587,712
Trainable params: 23,534,592
Non-trainable params: 53,120
__________________________________________________________________________________________________

We use a ResNet-50 pretrained on ImageNet and do not load its top layer, since we only need the network to extract bottleneck features. The summary shows that the model's output shape is 1x1x2048, but I want to store 1-D feature vectors, so a flatten-like operation is needed:

out = GlobalAveragePooling2D()(res_model.output)
res_vec_model = Model(inputs=res_model.input, outputs=out)

With the feature-extraction model ready, we can start extracting features. Since I may later use a similar procedure for other models, I wrapped it in a function (so the two cells above only illustrate how the extraction model is built; the function below is what is actually used). The function first builds the bottleneck-feature extraction model, exactly as above, and then uses an ImageDataGenerator to do the extraction. Why a generator? Because it does not load all images into memory at once; it streams them in, queue-like, as they are needed, which saves a lot of memory. Let's do the math: if we loaded everything at once, how much memory would that take? One ResNet-50 input image is 224x224x3, we have 22,424 training images, and loading them as uint8 gives 224*224*3*22424 = 3,375,439,872 bytes, about 3 GB. The test set (79,726 images) would take roughly another 12 GB, about 15 GB in total. Even though the Tesla V100 on my AWS instance has 16 GB of memory, that leaves no headroom for training, so let's not load data that way. (A quick back-of-the-envelope check of these numbers follows.)
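
A quick sanity check of the memory estimate above (pure arithmetic using the image size and the sample counts reported earlier):

## Back-of-the-envelope memory check for loading every image as uint8 at once
h, w, c = 224, 224, 3
n_train, n_test = 22424, 79726

train_bytes = h * w * c * n_train          # ~3.4e9 bytes  (~3 GB)
test_bytes  = h * w * c * n_test           # ~1.2e10 bytes (~12 GB)
print("train: %.1f GB, test: %.1f GB, total: %.1f GB"
      % (train_bytes / 1e9, test_bytes / 1e9, (train_bytes + test_bytes) / 1e9))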

def model_vector_catch(MODEL, image_size, vect_file_name, vec_dir, train_dir, test_dir, preprocessing=None):
    """
        MODEL:the model to extract bottleneck features
        image_size:MODEL input size(h, w, channels)
        vect_file_name:file to save vector
        preprocessing:whether or not need preprocessing
    """
    if isfile(vec_dir + '/' + vect_file_name):
        print ("%s already OK!" % (vect_file_name))
        return
    
    input_tensor = Input(shape=(image_size[0], image_size[1], 3))
    
    if preprocessing:
        ## check if need preprocessing
        input_tensor = Lambda(preprocessing)(input_tensor)
    
    model_no_top = MODEL(include_top=False, weights='imagenet', input_tensor=input_tensor, input_shape=(image_size[0], image_size[1], 3))
   
    ## pool the output to a 1-D feature vector and build the extraction model
    out = GlobalAveragePooling2D()(model_no_top.output)
    new_model = Model(inputs=model_no_top.input, outputs=out)
    
    ## image data generators (no augmentation needed for feature extraction)
    gen = ImageDataGenerator()
    test_gen = ImageDataGenerator()
    
    """
    classes = ['c0', 'c1', 'c2', 'c3', 'c4', 'c5', 'c6', 'c7', 'c8', 'c9', ] -- cat is 0, dog is 1, so we need write this
    class_mode = None -- i will not use like 'fit_fitgenerator', so i do not need labels
    shuffle = False -- it is unneccssary
    batch_size = 64 
    """
    class_list = ['c0', 'c1', 'c2', 'c3', 'c4', 'c5', 'c6', 'c7', 'c8', 'c9', ]
    train_generator = gen.flow_from_directory(train_dir, image_size, color_mode='rgb', \
                                              classes=class_list, class_mode=None, shuffle=False, batch_size=64)
    
    #test_generator = test_gen.flow_from_directory(test_dir, image_size, color_mode='rgb', \
                                          #class_mode=None, shuffle=False, batch_size=64)
   
    """
    steps = None, by default, the steps = len(generator)
    """
    train_vector = new_model.predict_generator(train_generator)
    #test_vector = new_model.predict_generator(test_generator)
    
    with h5py.File(vec_dir + "/" + (vect_file_name), 'w') as f: 
        f.create_dataset('x_train', data=train_vector)
        f.create_dataset("y_train", data=train_generator.classes)
        #f.create_dataset("test", data=test_vector)
    print ("Model %s vector cached complete!" % (vect_file_name))
vec_dir = 'vect'

if not isdir(vec_dir):
    os.mkdir(vec_dir)
res_vect_file_name = 'resnet50_vect.h5'

model_vector_catch(resnet50.ResNet50, resNet_input_shape[:2], res_vect_file_name, vec_dir, train_dir_name, test_dir_name, resnet50.preprocess_input)
Found 22424 images belonging to 10 classes.
Model resnet50_vect.h5 vector cached complete!

We have now cached the bottleneck features of the training set (I will not extract the test-set features yet; I will do that once I have verified that this approach works, since the test set is not needed for training, so there is no rush). Next we build the top layer to train. What should it look like? The original ResNet-50 has just a single 1000-unit output layer, so following that design we use a single output layer with 10 units (one per class) and a softmax activation.

driver_classes = 10

input_tensor = Input(shape=(2048,))
x = Dropout(0.5)(input_tensor)
x = Dense(driver_classes, activation='softmax', name='res_dense_1')(x)

resnet50_model = Model(inputs=input_tensor, outputs=x)

Before training, we need to load the cached bottleneck features from disk into memory for fitting. Also, the labels we saved are the integers 0-9, so we convert them to one-hot vectors.

import numpy as np

def convert_to_one_hot(Y, C):
    Y = np.eye(C)[Y.reshape(-1)]
    return Y
from sklearn.utils import shuffle

x_train = []
y_train = []

with h5py.File(vec_dir + '/' + res_vect_file_name, 'r') as f:
    x_train = np.array(f['x_train'])
    y_train = np.array(f['y_train'])
    
    #one-hot vector
    y_train = convert_to_one_hot(y_train, driver_classes)
    
    x_train, y_train = shuffle(x_train, y_train, random_state=0)

Pay attention to the code above!!! There is a pitfall here that I fell into in a previous project; it made my model refuse to converge no matter what, and cost me a long time to debug. It is the shuffle. You must shuffle: if your data is stored class by class and you do not shuffle, every mini-batch contains samples of a single class, and the model cannot learn anything useful. In one batch it learns the characteristics of one class; in the next batch it has to learn a completely different class, throwing away what it just learned, and this vicious cycle never converges. Since I made this mistake before, there is no need to reproduce and debug it again here; the toy example below just illustrates why unshuffled, class-sorted data produces single-class batches.
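
A toy illustration of the pitfall, using a made-up label array rather than the real dataset:

## Toy example: class-sorted labels produce single-class mini-batches if not shuffled
import numpy as np

y = np.repeat(np.arange(10), 2000)    # 2000 samples per class, stored class by class
first_batch = y[:64]                  # the first mini-batch without shuffling
print(np.unique(first_batch))         # -> [0]: every sample in the batch is class c0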

The next step is to set the compile options and start training. The optimizer is Adam with default parameters, the batch size is 32, we run 30 epochs to see how it goes, and the validation split is 0.2.

resnet50_model.compile(optimizer='Adam', loss='categorical_crossentropy', metrics=['accuracy'])

hist = resnet50_model.fit(x_train, y_train, batch_size=32, epochs=30, validation_split=0.2)
Train on 17939 samples, validate on 4485 samples
Epoch 1/30
17939/17939 [==============================] - 5s 254us/step - loss: 1.2879 - acc: 0.5758 - val_loss: 0.4706 - val_acc: 0.8783
Epoch 2/30
17939/17939 [==============================] - 2s 98us/step - loss: 0.6077 - acc: 0.7996 - val_loss: 0.2997 - val_acc: 0.9298
Epoch 3/30
17939/17939 [==============================] - 2s 96us/step - loss: 0.4998 - acc: 0.8343 - val_loss: 0.2223 - val_acc: 0.9483
Epoch 4/30
17939/17939 [==============================] - 2s 94us/step - loss: 0.4569 - acc: 0.8463 - val_loss: 0.2140 - val_acc: 0.9452
Epoch 5/30
17939/17939 [==============================] - 2s 95us/step - loss: 0.4240 - acc: 0.8575 - val_loss: 0.2061 - val_acc: 0.9465
Epoch 6/30
17939/17939 [==============================] - 2s 94us/step - loss: 0.4089 - acc: 0.8653 - val_loss: 0.1641 - val_acc: 0.9550
Epoch 7/30
17939/17939 [==============================] - 2s 94us/step - loss: 0.3981 - acc: 0.8657 - val_loss: 0.1955 - val_acc: 0.9394
Epoch 8/30
17939/17939 [==============================] - 2s 95us/step - loss: 0.3967 - acc: 0.8688 - val_loss: 0.1512 - val_acc: 0.9574
Epoch 9/30
17939/17939 [==============================] - 2s 95us/step - loss: 0.3788 - acc: 0.8712 - val_loss: 0.1469 - val_acc: 0.9601
Epoch 10/30
17939/17939 [==============================] - 2s 96us/step - loss: 0.3852 - acc: 0.8723 - val_loss: 0.1646 - val_acc: 0.9518
Epoch 11/30
17939/17939 [==============================] - 2s 95us/step - loss: 0.3733 - acc: 0.8766 - val_loss: 0.1443 - val_acc: 0.9550
Epoch 12/30
17939/17939 [==============================] - 2s 95us/step - loss: 0.3858 - acc: 0.8735 - val_loss: 0.1474 - val_acc: 0.9603
Epoch 13/30
17939/17939 [==============================] - 2s 98us/step - loss: 0.3707 - acc: 0.8773 - val_loss: 0.1286 - val_acc: 0.9608
Epoch 14/30
17939/17939 [==============================] - 2s 97us/step - loss: 0.3597 - acc: 0.8810 - val_loss: 0.1340 - val_acc: 0.9581
Epoch 15/30
17939/17939 [==============================] - 2s 102us/step - loss: 0.3754 - acc: 0.8762 - val_loss: 0.2203 - val_acc: 0.9280
Epoch 16/30
17939/17939 [==============================] - 2s 99us/step - loss: 0.3603 - acc: 0.8821 - val_loss: 0.1307 - val_acc: 0.9628
Epoch 17/30
17939/17939 [==============================] - 2s 95us/step - loss: 0.3683 - acc: 0.8817 - val_loss: 0.1338 - val_acc: 0.9608
Epoch 18/30
17939/17939 [==============================] - 2s 95us/step - loss: 0.3640 - acc: 0.8819 - val_loss: 0.1180 - val_acc: 0.9654
Epoch 19/30
17939/17939 [==============================] - 2s 95us/step - loss: 0.3670 - acc: 0.8812 - val_loss: 0.1216 - val_acc: 0.9621
Epoch 20/30
17939/17939 [==============================] - 2s 95us/step - loss: 0.3505 - acc: 0.8855 - val_loss: 0.1128 - val_acc: 0.9672
Epoch 21/30
17939/17939 [==============================] - 2s 98us/step - loss: 0.3689 - acc: 0.8801 - val_loss: 0.1253 - val_acc: 0.9632
Epoch 22/30
17939/17939 [==============================] - 2s 95us/step - loss: 0.3670 - acc: 0.8808 - val_loss: 0.1418 - val_acc: 0.9538
Epoch 23/30
17939/17939 [==============================] - 2s 95us/step - loss: 0.3567 - acc: 0.8866 - val_loss: 0.1127 - val_acc: 0.9672
Epoch 24/30
17939/17939 [==============================] - 2s 96us/step - loss: 0.3484 - acc: 0.8870 - val_loss: 0.1269 - val_acc: 0.9614
Epoch 25/30
17939/17939 [==============================] - 2s 97us/step - loss: 0.3588 - acc: 0.8845 - val_loss: 0.1109 - val_acc: 0.9650
Epoch 26/30
17939/17939 [==============================] - 2s 95us/step - loss: 0.3672 - acc: 0.8822 - val_loss: 0.1100 - val_acc: 0.9666
Epoch 27/30
17939/17939 [==============================] - 2s 94us/step - loss: 0.3477 - acc: 0.8882 - val_loss: 0.1082 - val_acc: 0.9683
Epoch 28/30
17939/17939 [==============================] - 2s 96us/step - loss: 0.3647 - acc: 0.8839 - val_loss: 0.1119 - val_acc: 0.9686
Epoch 29/30
17939/17939 [==============================] - 2s 94us/step - loss: 0.3504 - acc: 0.8885 - val_loss: 0.1083 - val_acc: 0.9677
Epoch 30/30
17939/17939 [==============================] - 2s 96us/step - loss: 0.3562 - acc: 0.8877 - val_loss: 0.0988 - val_acc: 0.9706

I tried many optimization algorithms (Adam, SGD and others), as well as low learning rates and learning-rate decay, but the behaviour was always strikingly similar to the run above. There are two main symptoms:

  • (training loss stays high) the training-set loss hovers around 0.35 and cannot be pushed down no matter what (for reference, the top 10 entries on the Kaggle leaderboard all have a loss below 0.13);
  • (validation loss far below training loss) val_loss is far lower than train_loss: after 30 epochs train_loss is 0.35 while val_loss is 0.09, and val_loss is still trending downward.

Remember these two points: they are the two biggest "pitfalls" I ran into.

Let's analyse these two "pitfalls" one by one and work out solutions.

2.1.1 Addressing the "training loss stays high" symptom

For now, let's set aside how low and seemingly perfect the validation loss is and focus only on the training loss, which is as high as 0.35. By normal reasoning this is underfitting, and the usual remedies are:

  • The training loss reflects the bias, and bias reflects the model's own fitting capacity, so when the model cannot fit the data well enough we can increase its complexity by adding more hidden layers or hidden units.

Result: I tried inserting two fully connected layers with 1024 hidden units each before the output layer (a sketch is shown after this list), but it made no difference at all.

  • Reduce regularization, which also effectively increases model capacity. For this project that means lowering the dropout rate (equivalently, raising the keep probability).

Result: no effect!
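
For reference, here is roughly what the deeper top layer from the first remedy looks like. It is a sketch only: the layer sizes and dropout placement are illustrative assumptions, not necessarily the exact configuration used in that experiment.

## Sketch of a deeper top layer on the cached 2048-d bottleneck features
deeper_input = Input(shape=(2048,))
h = Dense(1024, activation='relu')(deeper_input)
h = Dropout(0.5)(h)
h = Dense(1024, activation='relu')(h)
h = Dropout(0.5)(h)
deeper_out = Dense(driver_classes, activation='softmax')(h)

deeper_model = Model(inputs=deeper_input, outputs=deeper_out)
deeper_model.compile(optimizer='Adam', loss='categorical_crossentropy', metrics=['accuracy'])
# deeper_model.fit(x_train, y_train, batch_size=32, epochs=30, validation_split=0.2)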

With both standard remedies having failed, the usual playbook was clearly not working. Was I thinking about the problem from the wrong angle? Should I stop analysing only the model and look at the dataset itself instead? So I started analysing the dataset.

Since our strategy is transfer learning, the weights come from ResNet-50 pretrained on ImageNet; in other words, we rely on weights trained on ImageNet to extract the bottleneck features. But as we all saw, this approach is currently not working for our problem, which means the extracted bottleneck features are not helping the model learn. If that is the case, it tells us one thing: our driver dataset is somehow "different" from ImageNet. Let's figure out where that difference lies.

  1. ImageNet contains images from 1,000 categories, mostly taken in very different scenes with complex backgrounds. Our driver images, by contrast, all come from the same scene (a driver sitting in a car with a steering wheel in front), and the images of each driver come from the same video stream. In other words, the driver images themselves are extremely similar to one another!!!

So what happens when the pretrained ResNet-50 (without its top layer) extracts features from such similar driver images? It produces very similar bottleneck features!! (I imagine the pretrained model "reasoning": "I have extracted the information in this image: there is a face, there are hands, hands on a steering wheel...", but what we need is apparently more than that.) If images of different driver states yield nearly identical features, then for the driver problem those features are essentially useless ("useless" is of course too harsh, since the trained model still reaches 88% accuracy; they just do not carry enough of the information this project needs), and good training results are out of reach.

Now that the cause is known, the fix is clear: unfreeze more layers, because the model needs to learn more "knowledge" specific to the driver dataset. Once we unfreeze more layers, the bottleneck-feature trick is no longer practical, because intermediate layers usually have large output dimensions, for example 14x14x256, which for 20,000 samples already takes about 4 GB of memory; add the test set and memory runs out (and if you pick a K80 instance on AWS, it has only about 11 GB of memory, which is nowhere near enough). So from now on we train the network as a whole.

2.1.1.1 Training ResNet-50 with more convolutional layers unfrozen

Most of the following steps are similar to before; the difference is that we now unfreeze more convolutional layers for learning. Before that, let's prepare the training data. This time I again use an ImageDataGenerator to produce the training batches and fit_generator to train.

The validation data now has to be split out explicitly (fit_generator, which I will use for training, does not support a validation_split argument), and as before I use 20% of the samples as the validation set, taking 20% of the images from each class. To save disk space, I build the train_link dataset with symbolic links.

classes = ['c0', 'c1', 'c2', 'c3', 'c4', 'c5', 'c6', 'c7', 'c8', 'c9']

if not isdir(test_link_path):
    os.makedirs(test_link_path)
    c_filenames = os.listdir(test_dir_name)
    
    for file in c_filenames:
        os.symlink('../../' + test_dir_name + '/' + file, test_link_path + '/' + file)

if not isdir(link_path):
    os.makedirs(link_train_path)
    os.makedirs(link_valid_path)
    
    # make c0-c9
    for c in classes:
        os.makedirs(link_train_path + '/' + c)
        os.makedirs(link_valid_path + '/' + c)
        
    # create link from train dir
    for c in classes:
        # get path name
        c_path = train_dir_name + '/' + c
        train_dst_path = link_train_path + '/' + c
        valid_dst_path = link_valid_path + '/' + c
        
        # list all file name of this class
        c_filenames = os.listdir(c_path)
        valid_size = int (len(c_filenames)*0.2)
        
        # create validation data of this class
        for file in c_filenames[:valid_size]:
            os.symlink('../../../' + c_path + '/' + file, valid_dst_path + '/' + file)

        # create train data of this class
        for file in c_filenames[valid_size:]:
            os.symlink('../../../' + c_path + '/' + file, train_dst_path + '/' + file)
---------------------------------------------------------------------------

KeyboardInterrupt                         Traceback (most recent call last)

<ipython-input-9-8ff26d47ecaa> in <module>()
      6 
      7     for file in c_filenames:
----> 8         os.symlink('../../' + test_dir_name + '/' + file, test_link_path + '/' + file)
      9 
     10 if not isdir(link_path):


KeyboardInterrupt: 
def get_data_generator(train_dir, valid_dir, test_dir, image_size): 
    #gen = ImageDataGenerator(shear_range=0.3, zoom_range=0.3, rotation_range=0.3)
    gen = ImageDataGenerator()
    gen_valid = ImageDataGenerator()
    test_gen = ImageDataGenerator()
    
    """
    classes: pass the class list explicitly so that c0-c9 map to label indices 0-9 in a fixed order
    class_mode = 'categorical': labels are returned as one-hot vectors
    shuffle = True for the training generator, False for validation/test so predictions keep file order
    batch_size = 32
    """
    class_list = ['c0', 'c1', 'c2', 'c3', 'c4', 'c5', 'c6', 'c7', 'c8', 'c9']
    
    # create train generator
    train_generator = gen.flow_from_directory(train_dir, image_size, color_mode='rgb', \
                                              classes=class_list, class_mode='categorical', shuffle=True, batch_size=32)
    
    # create validation generator
    valid_generator = gen_valid.flow_from_directory(valid_dir, image_size, color_mode='rgb', \
                                              classes=class_list, class_mode='categorical', shuffle=False, batch_size=32)
    
    test_generator = test_gen.flow_from_directory(test_dir, image_size, color_mode='rgb', \
                                          class_mode=None, shuffle=False, batch_size=32)
    
    return train_generator, valid_generator, test_generator
def model_built(MODEL, input_shape, preprocess_input, classes, last_frozen_layer_name):
    """
        MODEL:                  pretrained model
        input_shape:            pre-trained model's input shape
        preprocessing_input:    pre-trained model's preprocessing function
        last_frozen_layer_name: last layer to frozen  
    """
    
    ## get pretrained model
    x = Input(shape=input_shape)
    
    if preprocess_input:
        x = Lambda(preprocess_input)(x)
    
    notop_model = MODEL(include_top=False, weights='imagenet', input_tensor=x, input_shape=input_shape)
    
    x = GlobalAveragePooling2D()(notop_model.output)

    ## build top layer
    x = Dropout(0.5, name='dropout_1')(x)
    out = Dense(classes, activation='softmax', name='dense_1')(x)
    
    ret_model = Model(inputs=notop_model.input, outputs=out)
    
    ## Optionally freeze everything up to last_frozen_layer_name (left commented out: all layers are trained here)
    #for layer in ret_model.layers:
    #    layer.trainable = False
    #    if layer.name == last_frozen_layer_name:
    #        break
    
    return ret_model
import pandas as pd
"""
def get_test_result(model_obj, test_generator, model_name="default"):
    pred_test = model_obj.predict_generator(test_generator, len(test_generator))
    pred_test = np.array(pred_test)
    pred_test = pred_test.clip(min=0.005, max=0.995)
    
    df = pd.read_csv("sample_submission.csv")

    for i, fname in enumerate(test_generator.filenames):
        df.loc[df["img"] == fname] = [fname] + list(pred_test[i])
    
    df.to_csv('%s.csv' % (model_name), index=None)
    print ('test result file %s.csv generated!' % (model_name))
    df.head(10)
"""

def get_test_result(model_obj, test_generator, model_name="default"):
    print("Now to predict")
    pred_test = model_obj.predict_generator(test_generator, len(test_generator), verbose=1)
    pred_test = np.array(pred_test)
    pred_test = pred_test.clip(min=0.005, max=0.995)
    print("create datasheet")
    result = pd.DataFrame(pred_test, columns=['c0', 'c1', 'c2', 'c3',
                                                 'c4', 'c5', 'c6', 'c7',
                                                 'c8', 'c9'])
    test_filenames = []
    for f in test_generator.filenames:
        test_filenames.append(os.path.basename(f))
    result.loc[:, 'img'] = pd.Series(test_filenames, index=result.index)
    result.to_csv('%s.csv' % (model_name), index=None)
    print ('test result file %s.csv generated!' % (model_name))
resnet50_train_generator, resnet50_valid_generator, resnet50_test_generator = get_data_generator(link_train_path, link_valid_path, test_link, resNet_input_shape[:2])
Found 17943 images belonging to 10 classes.
Found 4481 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.

Now that the data is ready, the next step is to build the model and train it with all layers unfrozen.

resnet50_model = model_built(resnet50.ResNet50, resNet_input_shape, resnet50.preprocess_input, 10, None)
from keras import optimizers

sgd = optimizers.SGD(lr=0.0001, decay=1e-6, momentum=0.9, nesterov=True)

resnet50_model.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy'])

hist = resnet50_model.fit_generator(resnet50_train_generator, len(resnet50_train_generator), epochs=3,\
                                    validation_data=resnet50_valid_generator, validation_steps=len(resnet50_valid_generator))
WARNING:tensorflow:Variable *= will be deprecated. Use variable.assign_mul if you want assignment to the variable value or 'x = x * y' if you want a new python Tensor object.
Epoch 1/3
281/281 [==============================] - 374s 1s/step - loss: 2.0012 - acc: 0.3205 - val_loss: 1.0765 - val_acc: 0.8246
Epoch 2/3
281/281 [==============================] - 112s 398ms/step - loss: 0.9246 - acc: 0.7555 - val_loss: 0.4652 - val_acc: 0.9362
Epoch 3/3
281/281 [==============================] - 112s 399ms/step - loss: 0.4605 - acc: 0.9049 - val_loss: 0.2505 - val_acc: 0.9616
hist = resnet50_model.fit_generator(resnet50_train_generator, len(resnet50_train_generator), epochs=5,\
                                    validation_data=resnet50_valid_generator, validation_steps=len(resnet50_valid_generator))
Epoch 1/5
281/281 [==============================] - 115s 411ms/step - loss: 0.2670 - acc: 0.9494 - val_loss: 0.1625 - val_acc: 0.9728
Epoch 2/5
281/281 [==============================] - 113s 402ms/step - loss: 0.1811 - acc: 0.9661 - val_loss: 0.1189 - val_acc: 0.9788
Epoch 3/5
281/281 [==============================] - 113s 403ms/step - loss: 0.1265 - acc: 0.9781 - val_loss: 0.0949 - val_acc: 0.9828
Epoch 4/5
281/281 [==============================] - 112s 398ms/step - loss: 0.0971 - acc: 0.9834 - val_loss: 0.0788 - val_acc: 0.9850
Epoch 5/5
281/281 [==============================] - 111s 395ms/step - loss: 0.0777 - acc: 0.9868 - val_loss: 0.0683 - val_acc: 0.9853

For lack of time I will not continue training. As you can see, the "training loss stays high" problem — the first of the two issues listed earlier — is now solved. Everything looks so perfect that it is easy to forget the second problem still exists, because it is much less obvious in this training log (it can in fact still be seen: after every epoch the training loss is noticeably higher than the validation loss). If you insist there is nothing wrong, fine — let's run this model on the test set and submit the result to kaggle to find out.

get_test_result(resnet50_model, resnet50_test_generator, model_name="resnet-50")
Now to predict
  50/1246 [>.............................] - ETA: 6:25


---------------------------------------------------------------------------

KeyboardInterrupt                         Traceback (most recent call last)

<ipython-input-27-89baaf9d67e9> in <module>()
----> 1 get_test_result(resnet50_model, resnet50_test_generator, model_name="resnet-50")


<ipython-input-26-b7ba56715628> in get_test_result(model_obj, test_generator, model_name)
     18 def get_test_result(model_obj, test_generator, model_name="default"):
     19     print("Now to predict")
---> 20     pred_test = model_obj.predict_generator(test_generator, len(test_generator), verbose=1)
     21     pred_test = np.array(pred_test)
     22     pred_test = pred_test.clip(min=0.005, max=0.995)


~/anaconda3/lib/python3.6/site-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name +
     90                               '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper


~/anaconda3/lib/python3.6/site-packages/keras/engine/training.py in predict_generator(self, generator, steps, max_queue_size, workers, use_multiprocessing, verbose)
   2490 
   2491             while steps_done < steps:
-> 2492                 generator_output = next(output_generator)
   2493                 if isinstance(generator_output, tuple):
   2494                     # Compatibility with the generators


~/anaconda3/lib/python3.6/site-packages/keras/utils/data_utils.py in get(self)
    576         try:
    577             while self.is_running():
--> 578                 inputs = self.queue.get(block=True).get()
    579                 self.queue.task_done()
    580                 if inputs is not None:


~/anaconda3/lib/python3.6/multiprocessing/pool.py in get(self, timeout)
    636 
    637     def get(self, timeout=None):
--> 638         self.wait(timeout)
    639         if not self.ready():
    640             raise TimeoutError


~/anaconda3/lib/python3.6/multiprocessing/pool.py in wait(self, timeout)
    633 
    634     def wait(self, timeout=None):
--> 635         self._event.wait(timeout)
    636 
    637     def get(self, timeout=None):


~/anaconda3/lib/python3.6/threading.py in wait(self, timeout)
    549             signaled = self._flag
    550             if not signaled:
--> 551                 signaled = self._cond.wait(timeout)
    552             return signaled
    553 


~/anaconda3/lib/python3.6/threading.py in wait(self, timeout)
    293         try:    # restore state no matter what (e.g., KeyboardInterrupt)
    294             if timeout is None:
--> 295                 waiter.acquire()
    296                 gotit = True
    297             else:


KeyboardInterrupt: 

Sure enough, after submitting to kaggle the loss was an astonishing 2.3 — roughly ln(10) ≈ 2.30, i.e. about what uniform random guessing over the 10 classes would score. The next subsection analyzes this problem.

2.1.2 Addressing the "validation loss far below training loss" problem

Now let's figure out where the problem is. From the earlier training runs, the validation loss is lower than the training loss at every epoch, which is abnormal: it means the model has directly or indirectly overfit the validation set. Why? The root cause is, again, the data. As mentioned before, the images of any one driver are captured by the same camera in the same scene (essentially consecutive frames of a video stream), so images of the same driver in the same state are extremely similar. If some of those frames land in the training set and others in the validation set, the model has effectively already "seen" the validation data during training, so validating on it is like validating on the training data — the validation set is overfit. Therefore the images of one driver must not be split across the training and validation sets; we should split by driver ID and place all images of a given driver entirely in one split. To make the validation set about 20% of the data, I picked the images of four drivers (p021, p022, p024, p026), a bit over 4,800 images in total.

import matplotlib.pyplot as plt

def show_loss(hist, title='loss'):
    # show the training and validation loss
    plt.plot(hist.history['val_loss'], label="validation loss")
    plt.plot(hist.history['loss'], label="train loss")
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.title(title)
    plt.legend()
    plt.show()
import pandas as pd

link_path = 'train_link'
link_train_path = 'train_link/train'
link_valid_path = 'train_link/validation'

test_link_path = 'test_link/data'

classes = ['c0', 'c1', 'c2', 'c3', 'c4', 'c5', 'c6', 'c7', 'c8', 'c9']

validation_drivers = ["p021","p022","p024","p026"]

if not isdir(test_link_path):
    os.makedirs(test_link_path)
    c_filenames = os.listdir(test_dir_name)
    
    for file in c_filenames:
        os.symlink('../../' + test_dir_name + '/' + file, test_link_path + '/' + file)

if not isdir(link_path):
    os.makedirs(link_train_path)
    os.makedirs(link_valid_path)
    
    # get validation file names
    df = pd.read_csv(driver_imgs_list_file)
    
    validation_files = list(df[df['subject'].isin(validation_drivers)]['img'])
    
    # make c0-c9
    for c in classes:
        os.makedirs(link_train_path + '/' + c)
        os.makedirs(link_valid_path + '/' + c)
        
    # create link from train dir
    for c in classes:
        # get path name
        c_path = train_dir_name + '/' + c
        train_dst_path = link_train_path + '/' + c
        valid_dst_path = link_valid_path + '/' + c
        
        # list all file name of this class
        c_filenames = os.listdir(c_path)
        
        # create validation data of this class
        for file in c_filenames:
            if file in validation_files:
                os.symlink('../../../' + c_path + '/' + file, valid_dst_path + '/' + file)
            else:
                os.symlink('../../../' + c_path + '/' + file, train_dst_path + '/' + file)
def model_built(MODEL, input_shape, preprocess_input, classes, last_frozen_layer_name):
    """
        MODEL:                  pretrained model
        input_shape:            pre-trained model's input shape
        preprocessing_input:    pre-trained model's preprocessing function
        last_frozen_layer_name: last layer to frozen  
    """
    
    ## get pretrained model
    x = Input(shape=input_shape)
    
    if preprocess_input:
        x = Lambda(preprocess_input)(x)
    
    notop_model = MODEL(include_top=False, weights='imagenet', input_tensor=x, input_shape=input_shape)
    
    x = GlobalAveragePooling2D()(notop_model.output)

    ## build top layer
    x = Dropout(0.5, name='dropout_1')(x)
    out = Dense(classes, activation='softmax', name='dense_1')(x)
    
    ret_model = Model(inputs=notop_model.input, outputs=out)
    
    ## Optionally freeze everything up to last_frozen_layer_name (left commented out: all layers are trained here)
    #for layer in ret_model.layers:
    #    layer.trainable = False
    #    if layer.name == last_frozen_layer_name:
    #        break
    
    return ret_model
from keras import optimizers
from keras.callbacks import ModelCheckpoint
from keras.applications import xception, resnet50

# get generator
resNet_input_shape = (224,224,3)
resnet50_train_generator, resnet50_valid_generator, resnet50_test_generator = get_data_generator(link_train_path, link_valid_path, test_link, resNet_input_shape[:2])
##xception_train_generator, xception_valid_generator, xception_test_generator = get_data_generator(link_train_path, link_valid_path, test_link, (299, 299))

# build model
resnet50_model = model_built(resnet50.ResNet50, resNet_input_shape, resnet50.preprocess_input, 10, None)
#xception_model = model_built(xception.Xception, (299, 299, 3), xception.preprocess_input, 10, None)

# trainmodel
ckpt = ModelCheckpoint('resnet50.weights.{epoch:02d}-{val_loss:.2f}.hdf5', verbose=1, save_best_only=True, save_weights_only=True)

adam = optimizers.Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
resnet50_model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])

hist = resnet50_model.fit_generator(resnet50_train_generator, len(resnet50_train_generator), epochs=6,\
                                    validation_data=resnet50_valid_generator, validation_steps=len(resnet50_valid_generator), callbacks=[ckpt])

#ckpt = ModelCheckpoint('xception.weights.{epoch:02d}-{val_loss:.2f}.hdf5', verbose=1, save_best_only=True, save_weights_only=True)

#adam = optimizers.Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
#xception_model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
#hist = xception_model.fit_generator(xception_train_generator, len(xception_train_generator), epochs=6,\
                                    #validation_data=xception_valid_generator, validation_steps=len(xception_valid_generator), callbacks=[ckpt])
Found 17532 images belonging to 10 classes.
Found 4892 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/6
548/548 [==============================] - 406s 742ms/step - loss: 0.1727 - acc: 0.9488 - val_loss: 0.4518 - val_acc: 0.8581

Epoch 00001: val_loss improved from inf to 0.45175, saving model to resnet50.weights.01-0.45.hdf5
Epoch 2/6
548/548 [==============================] - 110s 201ms/step - loss: 0.0169 - acc: 0.9956 - val_loss: 0.3573 - val_acc: 0.8884

Epoch 00002: val_loss improved from 0.45175 to 0.35731, saving model to resnet50.weights.02-0.36.hdf5
Epoch 3/6
548/548 [==============================] - 111s 202ms/step - loss: 0.0214 - acc: 0.9942 - val_loss: 0.4290 - val_acc: 0.8581

Epoch 00003: val_loss did not improve
Epoch 4/6
548/548 [==============================] - 110s 201ms/step - loss: 0.0126 - acc: 0.9970 - val_loss: 0.3218 - val_acc: 0.8996

Epoch 00004: val_loss improved from 0.35731 to 0.32176, saving model to resnet50.weights.04-0.32.hdf5
Epoch 5/6
548/548 [==============================] - 110s 201ms/step - loss: 0.0170 - acc: 0.9958 - val_loss: 0.3336 - val_acc: 0.9064

Epoch 00005: val_loss did not improve
Epoch 6/6
548/548 [==============================] - 111s 202ms/step - loss: 0.0157 - acc: 0.9961 - val_loss: 0.3360 - val_acc: 0.8894

Epoch 00006: val_loss did not improve
#adam = optimizers.Adam(lr=1e-5, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
#xception_model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])

#xception_model.load_weights('weights.06-0.25.hdf5')
#hist = xception_model.fit_generator(xception_train_generator, len(xception_train_generator), epochs=6,\
                                    #validation_data=xception_valid_generator, validation_steps=len(xception_valid_generator), \
                                    #callbacks=[ckpt])
adam = optimizers.Adam(lr=1e-5, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
resnet50_model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])

hist = resnet50_model.fit_generator(resnet50_train_generator, len(resnet50_train_generator), epochs=6,\
                                    validation_data=resnet50_valid_generator, validation_steps=len(resnet50_valid_generator), \
                                    callbacks=[ckpt], initial_epoch=4)
Epoch 5/6
548/548 [==============================] - 117s 214ms/step - loss: 6.0906e-04 - acc: 0.9999 - val_loss: 0.3074 - val_acc: 0.9004

Epoch 00005: val_loss improved from 0.32176 to 0.30740, saving model to resnet50.weights.05-0.31.hdf5
Epoch 6/6
548/548 [==============================] - 110s 200ms/step - loss: 1.3819e-04 - acc: 1.0000 - val_loss: 0.2772 - val_acc: 0.9101

Epoch 00006: val_loss improved from 0.30740 to 0.27716, saving model to resnet50.weights.06-0.28.hdf5
hist = resnet50_model.fit_generator(resnet50_train_generator, len(resnet50_train_generator), epochs=6,\
                                    validation_data=resnet50_valid_generator, validation_steps=len(resnet50_valid_generator), \
                                    callbacks=[ckpt])
Epoch 1/6
548/548 [==============================] - 111s 203ms/step - loss: 8.4052e-05 - acc: 1.0000 - val_loss: 0.2922 - val_acc: 0.9099

Epoch 00001: val_loss did not improve
Epoch 2/6
548/548 [==============================] - 110s 200ms/step - loss: 6.5537e-04 - acc: 0.9999 - val_loss: 0.2701 - val_acc: 0.9129

Epoch 00002: val_loss improved from 0.27716 to 0.27013, saving model to resnet50.weights.02-0.27.hdf5
Epoch 3/6
548/548 [==============================] - 110s 201ms/step - loss: 2.7870e-05 - acc: 1.0000 - val_loss: 0.2763 - val_acc: 0.9135

Epoch 00003: val_loss did not improve
Epoch 4/6
548/548 [==============================] - 110s 201ms/step - loss: 2.7291e-05 - acc: 1.0000 - val_loss: 0.2937 - val_acc: 0.9099

Epoch 00004: val_loss did not improve
Epoch 5/6
548/548 [==============================] - 110s 202ms/step - loss: 1.7776e-05 - acc: 1.0000 - val_loss: 0.2863 - val_acc: 0.9121

Epoch 00005: val_loss did not improve
Epoch 6/6
548/548 [==============================] - 110s 200ms/step - loss: 1.9905e-05 - acc: 1.0000 - val_loss: 0.2925 - val_acc: 0.9119

Epoch 00006: val_loss did not improve
#get_test_result(xception_model, xception_test_generator, model_name="xception_test_result")

As you can see, with the training above the validation loss now converges to a much more "honest" value, which resolves the second of the two problems.

The training above took quite a while to get right, mainly because of the choice of learning rate: if the learning rate is too large or too small, the loss will not converge to a reasonable value. Too large a rate destroys the pre-trained weights (what we called "blowing up" the training), while too small a rate makes convergence unbearably slow.

In the end the resnet-50 model converged to a validation loss of 0.27; after submitting to kaggle, the loss was about 0.34, which is roughly in the top 10% of the leaderboard. I was not satisfied with that yet, and the natural way to push the score further is model ensembling.

3. Improving the kaggle score with model ensembling

The importance of model fusion in competitions goes without saying: in pretty much every kaggle competition, the top finishers use some form of model fusion. What I call model fusion here is essentially ensemble learning (at the very least the two are close relatives). So next I will use the bagging idea from ensemble learning to combine models.

The core idea of ensemble learning is that several ordinary opinions combined can beat a single one. For example, if you want to know whether a stock will go up or down and you ask only one friend, you may not feel very confident about the answer; but if you ask several friends and combine their advice, you can be much more confident. That is the spirit of ensemble learning (for more background, see the ensemble learning chapter of my study notes). An important prerequisite for ensembling is diversity: if the models we combine are not diverse, the ensemble is useless. In the stock example, if all of your friends always give identical answers, combining them changes nothing. So diversity is what we need to engineer.

Diversity comes from two sources:

  • Diversity in the data: training the same model on two different datasets generally yields two different g(x).
  • Diversity in the model: training different models on the same data also generally yields two different g(x).

So our plan is to use diversity in the data to obtain several diverse g(x), and then combine their test predictions with uniform blending (simply summing the outputs of all g(x) and taking the average; see the small sketch below). How do we create the data diversity? We borrow the data-splitting idea from K-fold cross-validation. The difference is that K-fold cross-validation aims to produce an unbiased score for one model, whereas our goal is to produce different g(x) to blend — so we only borrow the splitting scheme rather than doing K-fold cross-validation proper. (For an introduction to K-fold cross-validation, see 培神's article on the topic.)
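Here is a minimal sketch of uniform blending itself, on in-memory prediction arrays; the shapes are illustrative, and in this project the blending actually happens over the submission csv files via merge_test_results defined below.

# Uniform blending: average the class-probability matrices of several models.
import numpy as np

def uniform_blend(pred_list):
    # pred_list holds one (num_samples, num_classes) array per model
    return np.mean(np.stack(pred_list, axis=0), axis=0)

# toy example with two models and two classes
p1 = np.array([[0.8, 0.2], [0.4, 0.6]])
p2 = np.array([[0.6, 0.4], [0.2, 0.8]])
print(uniform_blend([p1, p2]))   # -> [[0.7 0.3] [0.3 0.7]]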

3.1 Common helper functions

3.1.1 Data-related helper functions

A helper that builds the training/validation split (and the test link directory), given a list of driver IDs to hold out for validation. Note that this function removes any previously created link directories.

import pandas as pd
import os
import shutil

classes = ['c0', 'c1', 'c2', 'c3', 'c4', 'c5', 'c6', 'c7', 'c8', 'c9']

def train_valid_split(validation_drivers):
    """
        validation_drivers: driver id list, like:["p021","p022","p024","p026"]
        
        warning:this function will remove the old link dir
    """
    #validation_drivers = ["p021","p022","p024","p026"]
    if isdir(test_link_path):
        shutil.rmtree(test_link_path)
    
    if not isdir(test_link_path):
        os.makedirs(test_link_path)
        c_filenames = os.listdir(test_dir_name)

        for file in c_filenames:
            os.symlink('../../' + test_dir_name + '/' + file, test_link_path + '/' + file)
    
    ## check and remove the train dir
    if isdir(link_path):
        shutil.rmtree(link_path)
    
    if not isdir(link_path):
        os.makedirs(link_train_path)
        os.makedirs(link_valid_path)

        # get validation file names
        df = pd.read_csv(driver_imgs_list_file)

        validation_files = list(df[df['subject'].isin(validation_drivers)]['img'])

        # make c0-c9
        for c in classes:
            os.makedirs(link_train_path + '/' + c)
            os.makedirs(link_valid_path + '/' + c)

        # create link from train dir
        for c in classes:
            # get path name
            c_path = train_dir_name + '/' + c
            train_dst_path = link_train_path + '/' + c
            valid_dst_path = link_valid_path + '/' + c

            # list all file name of this class
            c_filenames = os.listdir(c_path)
            
            # create validation data of this class
            for file in c_filenames:
                if file in validation_files:
                    os.symlink('../../../' + c_path + '/' + file, valid_dst_path + '/' + file)
                else:
                    os.symlink('../../../' + c_path + '/' + file, train_dst_path + '/' + file)

Data generators: based on the training and validation sets produced above, we now build the generators (the training generator adds some data augmentation).

def data_generator(train_dir, valid_dir, test_dir, image_size): 
    """
        image_size: the output of the image size, like (224, 224)
    """
    #gen = ImageDataGenerator(shear_range=0.3, zoom_range=0.3, rotation_range=0.3)
    gen = ImageDataGenerator(rotation_range=10.,
    width_shift_range=0.05,
    height_shift_range=0.05,
    shear_range=0.1,
    zoom_range=0.1)
    gen_valid = ImageDataGenerator()
    test_gen = ImageDataGenerator()
    
    # create train generator
    train_generator = gen.flow_from_directory(train_dir, image_size, color_mode='rgb', \
                                              classes=classes, class_mode='categorical', shuffle=True, batch_size=32)
    
    # create validation generator
    valid_generator = gen_valid.flow_from_directory(valid_dir, image_size, color_mode='rgb', \
                                              classes=classes, class_mode='categorical', shuffle=False, batch_size=32)
    
    test_generator = test_gen.flow_from_directory(test_dir, image_size, color_mode='rgb', \
                                          class_mode=None, shuffle=False, batch_size=32)
    
    return train_generator, valid_generator, test_generator

3.1.2 Model-related helper functions

def get_model(MODEL, input_shape, preprocess_input, output_num):
    """
        MODEL:                  pretrained model
        input_shape:            pre-trained model's input shape
        preprocessing_input:    pre-trained model's preprocessing function
    """
    
    ## get pretrained model
    x = Input(shape=input_shape)
    
    if preprocess_input:
        x = Lambda(preprocess_input)(x)
    
    notop_model = MODEL(include_top=False, weights='imagenet', input_tensor=x, input_shape=input_shape)
    
    x = GlobalAveragePooling2D()(notop_model.output)

    ## build top layer
    x = Dropout(0.5, name='dropout_1')(x)
    out = Dense(output_num, activation='softmax', name='dense_1')(x)
    
    ret_model = Model(inputs=notop_model.input, outputs=out)
    
    return ret_model

3.1.3 Helper functions for generating and merging the test-result files

def get_test_result(model_obj, generator, result_file_name="default"):
    
    print("Now to predict result!")
    pred_test = model_obj.predict_generator(generator, len(generator), verbose=1)
    pred_test = np.array(pred_test)
    pred_test = pred_test.clip(min=0.005, max=0.995)
    
    print("Creating datasheet!")
    result = pd.DataFrame(pred_test, columns=classes)
    test_filenames = []
    for f in generator.filenames:
        test_filenames.append(os.path.basename(f))
    result.loc[:, 'img'] = pd.Series(test_filenames, index=result.index)
    result.to_csv(result_file_name, index=None)
    
    print ('Test result file %s generated!' % (result_file_name))
import pandas as pd

def merge_test_results(file_name_list, result_file_name='result'):
    result = None
    # get file number
    file_num = len(file_name_list)
    # read all test result
    img = []
    for i,name in enumerate(file_name_list):
        df = pd.read_csv(name)
        
        if i == 0:
            img = df['img']
            result = df.drop('img', axis=1)
        else:
            result = result + df.drop('img', axis=1)
    
    result = result / float(file_num)
    result['img'] = img
    
    result.to_csv(result_file_name, index=None)
    
    print ("Final result file: " + result_file_name)

3.2 Ensembling resnet-50 models

I plan to blend four models trained with resnet-50, each trained on a different train/validation split. As established above, the split must be made by driver ID; here each validation set contains the data of four drivers. Let's now define the splits the four models will use.

3.2.1 Training the models

from keras.applications import resnet50
from keras import optimizers
from keras.callbacks import ModelCheckpoint

res_drivers_id_list = [["p039","p047","p052","p066"],
                      ["p015","p041","p049","p081"],
                      ["p002","p016","p035","p050"],
                      ["p024","p041","p049","p052"]]

res_image_size = (224, 224)
res_input_shape = (224, 224, 3)

csv_names_list = []

for i in range(4):
    
    print ('------------------------------------------------------------------------------------------------------------------')
    
    """ 1. about data """
    
    # create link and remove the old link
    train_valid_split(res_drivers_id_list[i])
    
    # get generator
    train_generator, valid_generator, test_generator = data_generator(link_train_path, link_valid_path, test_link, res_image_size)

    """ 2. about model """

    # get model
    resnet50_model = get_model(resnet50.ResNet50, res_input_shape, resnet50.preprocess_input, 10)
    
    # compile
    weights_file_name = 'resnet50.'+ str(i) + '.weights.hdf5'
    ckpt = ModelCheckpoint(weights_file_name, verbose=1, save_best_only=True, save_weights_only=True)
    adam = optimizers.Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
    resnet50_model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
    
    # fit
    hist_1 = resnet50_model.fit_generator(train_generator, len(train_generator), epochs=1,\
                                        validation_data=valid_generator, validation_steps=len(valid_generator), callbacks=[ckpt])
    
    # recompile with a smaller learning rate
    adam = optimizers.Adam(lr=1e-5, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
    resnet50_model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
    
    # fit using a small learning rate
    hist_2 = resnet50_model.fit_generator(train_generator, len(train_generator), epochs=3,\
                                        validation_data=valid_generator, validation_steps=len(valid_generator), callbacks=[ckpt])
    
    # load weights
    resnet50_model.load_weights(weights_file_name)

    """ 3. about result """
    result_file_name = 'resnet50.'+ str(i) + '.test.result.hdf5'
    get_test_result(resnet50_model, test_generator, result_file_name=result_file_name)
    
    csv_names_list.append(result_file_name)
------------------------------------------------------------------------------------------------------------------
Found 19164 images belonging to 10 classes.
Found 3260 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
599/599 [==============================] - 839s 1s/step - loss: 0.2265 - acc: 0.9306 - val_loss: 0.5889 - val_acc: 0.7936

Epoch 00001: val_loss improved from inf to 0.58890, saving model to resnet50.0.weights.hdf5
Epoch 1/3
599/599 [==============================] - 306s 512ms/step - loss: 0.0173 - acc: 0.9955 - val_loss: 0.4069 - val_acc: 0.8663

Epoch 00001: val_loss improved from 0.58890 to 0.40685, saving model to resnet50.0.weights.hdf5
Epoch 2/3
599/599 [==============================] - 303s 506ms/step - loss: 0.0101 - acc: 0.9973 - val_loss: 0.5276 - val_acc: 0.8534

Epoch 00002: val_loss did not improve
Epoch 3/3
599/599 [==============================] - 299s 499ms/step - loss: 0.0052 - acc: 0.9987 - val_loss: 0.3894 - val_acc: 0.8840

Epoch 00003: val_loss improved from 0.40685 to 0.38938, saving model to resnet50.0.weights.hdf5
Now to predict result!
2492/2492 [==============================] - 1623s 651ms/step
Creating datasheet!
Test result file resnet50.0.test.result.hdf5 generated!
------------------------------------------------------------------------------------------------------------------
Found 19110 images belonging to 10 classes.
Found 3314 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
598/598 [==============================] - 318s 531ms/step - loss: 0.2368 - acc: 0.9269 - val_loss: 0.6308 - val_acc: 0.7948

Epoch 00001: val_loss improved from inf to 0.63077, saving model to resnet50.1.weights.hdf5
Epoch 1/3
598/598 [==============================] - 315s 527ms/step - loss: 0.0196 - acc: 0.9949 - val_loss: 0.3938 - val_acc: 0.8699

Epoch 00001: val_loss improved from 0.63077 to 0.39381, saving model to resnet50.1.weights.hdf5
Epoch 2/3
598/598 [==============================] - 308s 515ms/step - loss: 0.0114 - acc: 0.9971 - val_loss: 0.3073 - val_acc: 0.9001

Epoch 00002: val_loss improved from 0.39381 to 0.30732, saving model to resnet50.1.weights.hdf5
Epoch 3/3
598/598 [==============================] - 307s 514ms/step - loss: 0.0075 - acc: 0.9980 - val_loss: 0.4784 - val_acc: 0.8721

Epoch 00003: val_loss did not improve
Now to predict result!
2492/2492 [==============================] - 403s 162ms/step
Creating datasheet!
Test result file resnet50.1.test.result.hdf5 generated!
------------------------------------------------------------------------------------------------------------------
Found 18983 images belonging to 10 classes.
Found 3441 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
594/594 [==============================] - 322s 541ms/step - loss: 0.2399 - acc: 0.9267 - val_loss: 0.4065 - val_acc: 0.8750

Epoch 00001: val_loss improved from inf to 0.40654, saving model to resnet50.2.weights.hdf5
Epoch 1/3
594/594 [==============================] - 322s 542ms/step - loss: 0.0191 - acc: 0.9955 - val_loss: 0.3682 - val_acc: 0.9195

Epoch 00001: val_loss improved from 0.40654 to 0.36821, saving model to resnet50.2.weights.hdf5
Epoch 2/3
594/594 [==============================] - 309s 521ms/step - loss: 0.0081 - acc: 0.9978 - val_loss: 0.4680 - val_acc: 0.9183

Epoch 00002: val_loss did not improve
Epoch 3/3
594/594 [==============================] - 310s 522ms/step - loss: 0.0066 - acc: 0.9979 - val_loss: 0.3676 - val_acc: 0.9160

Epoch 00003: val_loss improved from 0.36821 to 0.36761, saving model to resnet50.2.weights.hdf5
Now to predict result!
2492/2492 [==============================] - 407s 163ms/step
Creating datasheet!
Test result file resnet50.2.test.result.hdf5 generated!
------------------------------------------------------------------------------------------------------------------
Found 18842 images belonging to 10 classes.
Found 3582 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
589/589 [==============================] - 327s 556ms/step - loss: 0.2486 - acc: 0.9233 - val_loss: 0.3011 - val_acc: 0.8867

Epoch 00001: val_loss improved from inf to 0.30106, saving model to resnet50.3.weights.hdf5
Epoch 1/3
589/589 [==============================] - 328s 556ms/step - loss: 0.0211 - acc: 0.9944 - val_loss: 0.1510 - val_acc: 0.9506

Epoch 00001: val_loss improved from 0.30106 to 0.15104, saving model to resnet50.3.weights.hdf5
Epoch 2/3
589/589 [==============================] - 305s 518ms/step - loss: 0.0101 - acc: 0.9971 - val_loss: 0.1875 - val_acc: 0.9369

Epoch 00002: val_loss did not improve
Epoch 3/3
589/589 [==============================] - 306s 520ms/step - loss: 0.0064 - acc: 0.9982 - val_loss: 0.2701 - val_acc: 0.9160

Epoch 00003: val_loss did not improve
Now to predict result!
2492/2492 [==============================] - 402s 161ms/step
Creating datasheet!
Test result file resnet50.3.test.result.hdf5 generated!

3.2.2 Merging the test results

merge_test_results(csv_names_list, 'resnet50.finall.result.csv')
Final result file: resnet50.finall.result.csv

As you can see, after ensembling the resnet50 models and submitting to kaggle, the loss dropped from 0.35 to 0.23 — a clear improvement.

3.3 Ensembling xception models

I plan to blend four models trained with xception, each trained on a different train/validation split; apart from the model itself, everything follows the same procedure as for resnet50.

3.3.1 Training the models

from keras.applications import xception
from keras import optimizers
from keras.callbacks import ModelCheckpoint

xcp_drivers_id_list = [['p016', 'p072', 'p026', 'p066'],
                      ['p042', 'p022', 'p045', 'p075'],
                      ['p064', 'p047', 'p056', 'p061'],
                      ['p039', 'p012', 'p015', 'p052']]

xcp_image_size = (299, 299)
xcp_input_shape = (299, 299, 3)

xcp_csv_names_list = []

for i in range(4):
    
    print ('------------------------------------------------------------------------------------------------------------------')
    
    """ 1. about data """
    
    # create link and remove the old link
    train_valid_split(xcp_drivers_id_list[i])
    
    # get generator
    train_generator, valid_generator, test_generator = data_generator(link_train_path, link_valid_path, test_link, xcp_image_size)

    """ 2. about model """

    # get model
    xception_model = get_model(xception.Xception, xcp_input_shape, xception.preprocess_input, 10)
    
    # compile
    weights_file_name = 'xception.'+ str(i) + '.weights.hdf5'
    ckpt = ModelCheckpoint(weights_file_name, verbose=1, save_best_only=True, save_weights_only=True)
    adam = optimizers.Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
    xception_model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
    
    # fit
    hist_1 = xception_model.fit_generator(train_generator, len(train_generator), epochs=1,\
                                        validation_data=valid_generator, validation_steps=len(valid_generator), callbacks=[ckpt])
    
    # recompile with a smaller learning rate
    adam = optimizers.Adam(lr=1e-5, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
    xception_model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
    
    # fit using a small learning rate
    hist_2 = xception_model.fit_generator(train_generator, len(train_generator), epochs=3,\
                                        validation_data=valid_generator, validation_steps=len(valid_generator), callbacks=[ckpt])
    
    # load weights
    xception_model.load_weights(weights_file_name)

    """ 3. about result """
    result_file_name = 'xception.'+ str(i) + '.test.result.hdf5'
    get_test_result(xception_model, test_generator, result_file_name=result_file_name)
    
    xcp_csv_names_list.append(result_file_name)
------------------------------------------------------------------------------------------------------------------
Found 18770 images belonging to 10 classes.
Found 3654 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
587/587 [==============================] - 471s 803ms/step - loss: 0.2714 - acc: 0.9238 - val_loss: 0.3599 - val_acc: 0.8815

Epoch 00001: val_loss improved from inf to 0.35995, saving model to xception.0.weights.hdf5
Epoch 1/3
587/587 [==============================] - 477s 812ms/step - loss: 0.0140 - acc: 0.9966 - val_loss: 0.3484 - val_acc: 0.8878

Epoch 00001: val_loss improved from 0.35995 to 0.34837, saving model to xception.0.weights.hdf5
Epoch 2/3
587/587 [==============================] - 457s 778ms/step - loss: 0.0074 - acc: 0.9985 - val_loss: 0.3467 - val_acc: 0.8935

Epoch 00002: val_loss improved from 0.34837 to 0.34673, saving model to xception.0.weights.hdf5
Epoch 3/3
587/587 [==============================] - 458s 780ms/step - loss: 0.0047 - acc: 0.9990 - val_loss: 0.3608 - val_acc: 0.8886

Epoch 00003: val_loss did not improve
Now to predict result!
2492/2492 [==============================] - 456s 183ms/step
Creating datasheet!
Test result file xception.0.test.result.hdf5 generated!
------------------------------------------------------------------------------------------------------------------
Found 19062 images belonging to 10 classes.
Found 3362 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
596/596 [==============================] - 491s 824ms/step - loss: 0.2813 - acc: 0.9198 - val_loss: 0.2815 - val_acc: 0.9087

Epoch 00001: val_loss improved from inf to 0.28149, saving model to xception.1.weights.hdf5
Epoch 1/3
596/596 [==============================] - 486s 816ms/step - loss: 0.0164 - acc: 0.9960 - val_loss: 0.1917 - val_acc: 0.9337

Epoch 00001: val_loss improved from 0.28149 to 0.19174, saving model to xception.1.weights.hdf5
Epoch 2/3
596/596 [==============================] - 460s 771ms/step - loss: 0.0091 - acc: 0.9984 - val_loss: 0.1643 - val_acc: 0.9485

Epoch 00002: val_loss improved from 0.19174 to 0.16431, saving model to xception.1.weights.hdf5
Epoch 3/3
596/596 [==============================] - 491s 824ms/step - loss: 0.0055 - acc: 0.9989 - val_loss: 0.1670 - val_acc: 0.9465

Epoch 00003: val_loss did not improve
Now to predict result!
2492/2492 [==============================] - 510s 205ms/step
Creating datasheet!
Test result file xception.1.test.result.hdf5 generated!
------------------------------------------------------------------------------------------------------------------
Found 19166 images belonging to 10 classes.
Found 3258 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
599/599 [==============================] - 542s 904ms/step - loss: 0.2616 - acc: 0.9270 - val_loss: 0.5691 - val_acc: 0.8536

Epoch 00001: val_loss improved from inf to 0.56912, saving model to xception.2.weights.hdf5
Epoch 1/3
599/599 [==============================] - 538s 899ms/step - loss: 0.0131 - acc: 0.9971 - val_loss: 0.3631 - val_acc: 0.9033

Epoch 00001: val_loss improved from 0.56912 to 0.36315, saving model to xception.2.weights.hdf5
Epoch 2/3
599/599 [==============================] - 508s 848ms/step - loss: 0.0066 - acc: 0.9987 - val_loss: 0.3985 - val_acc: 0.8929

Epoch 00002: val_loss did not improve
Epoch 3/3
599/599 [==============================] - 504s 842ms/step - loss: 0.0040 - acc: 0.9991 - val_loss: 0.4135 - val_acc: 0.8895

Epoch 00003: val_loss did not improve
Now to predict result!
2492/2492 [==============================] - 494s 198ms/step
Creating datasheet!
Test result file xception.2.test.result.hdf5 generated!
------------------------------------------------------------------------------------------------------------------
Found 19335 images belonging to 10 classes.
Found 3089 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
605/605 [==============================] - 548s 905ms/step - loss: 0.2640 - acc: 0.9264 - val_loss: 0.3926 - val_acc: 0.8799

Epoch 00001: val_loss improved from inf to 0.39262, saving model to xception.3.weights.hdf5
Epoch 1/3
605/605 [==============================] - 547s 904ms/step - loss: 0.0147 - acc: 0.9963 - val_loss: 0.3121 - val_acc: 0.9009

Epoch 00001: val_loss improved from 0.39262 to 0.31212, saving model to xception.3.weights.hdf5
Epoch 2/3
605/605 [==============================] - 516s 853ms/step - loss: 0.0075 - acc: 0.9982 - val_loss: 0.3219 - val_acc: 0.8912

Epoch 00002: val_loss did not improve
Epoch 3/3
605/605 [==============================] - 515s 851ms/step - loss: 0.0047 - acc: 0.9992 - val_loss: 0.3226 - val_acc: 0.8880

Epoch 00003: val_loss did not improve
Now to predict result!
2492/2492 [==============================] - 515s 207ms/step
Creating datasheet!
Test result file xception.3.test.result.hdf5 generated!
merge_test_results(xcp_csv_names_list, 'xception.finall.result.csv')
Final result file: xception.finall.result.csv

After submitting the blended xception result to kaggle, the loss was 0.27290.

3.4 Ensembling inceptionv3 models

I plan to blend four models trained with inceptionv3, each trained on a different train/validation split; apart from the model itself, everything follows the same procedure as for resnet50.

from keras.applications import inception_v3
from keras import optimizers
from keras.callbacks import ModelCheckpoint

inc_drivers_id_list = [['p064', 'p056', 'p047', 'p041'],
                      ['p051', 'p072', 'p052', 'p049'],
                      ['p002', 'p050', 'p014', 'p075'],
                      ['p035', 'p039', 'p012', 'p081']]

inc_image_size = (299, 299)
inc_input_shape = (299, 299, 3)

inc_csv_names_list = []

for i in range(4):
    
    print ('------------------------------------------------------------------------------------------------------------------')
    
    """ 1. about data """
    
    # create link and remove the old link
    train_valid_split(inc_drivers_id_list[i])
    
    # get generator
    train_generator, valid_generator, test_generator = data_generator(link_train_path, link_valid_path, test_link, inc_image_size)

    """ 2. about model """

    # get model
    inceptionv3_model = get_model(inception_v3.InceptionV3, inc_input_shape, inception_v3.preprocess_input, 10)
    
    # compile
    weights_file_name = 'inceptionv3.'+ str(i) + '.weights.hdf5'
    ckpt = ModelCheckpoint(weights_file_name, verbose=1, save_best_only=True, save_weights_only=True)
    adam = optimizers.Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
    inceptionv3_model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
    
    # fit
    hist_1 = inceptionv3_model.fit_generator(train_generator, len(train_generator), epochs=1,\
                                        validation_data=valid_generator, validation_steps=len(valid_generator), callbacks=[ckpt])
    
    # recompile with a smaller learning rate
    adam = optimizers.Adam(lr=1e-5, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
    inceptionv3_model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
    
    # fit using a small learning rate
    hist_2 = inceptionv3_model.fit_generator(train_generator, len(train_generator), epochs=2,\
                                        validation_data=valid_generator, validation_steps=len(valid_generator), callbacks=[ckpt])
    
    # load weights
    inceptionv3_model.load_weights(weights_file_name)

    """ 3. about result """
    result_file_name = 'inceptionv3.'+ str(i) + '.test.result.hdf5'
    get_test_result(inceptionv3_model, test_generator, result_file_name=result_file_name)
    
    inc_csv_names_list.append(result_file_name)
------------------------------------------------------------------------------------------------------------------
Found 19370 images belonging to 10 classes.
Found 3054 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
606/606 [==============================] - 557s 920ms/step - loss: 0.2447 - acc: 0.9254 - val_loss: 0.6483 - val_acc: 0.8471

Epoch 00001: val_loss improved from inf to 0.64835, saving model to inceptionv3.0.weights.hdf5
Epoch 1/2
606/606 [==============================] - 560s 924ms/step - loss: 0.0203 - acc: 0.9944 - val_loss: 0.4773 - val_acc: 0.8661

Epoch 00001: val_loss improved from 0.64835 to 0.47726, saving model to inceptionv3.0.weights.hdf5
Epoch 2/2
606/606 [==============================] - 513s 846ms/step - loss: 0.0091 - acc: 0.9979 - val_loss: 0.4670 - val_acc: 0.8752

Epoch 00002: val_loss improved from 0.47726 to 0.46698, saving model to inceptionv3.0.weights.hdf5
Now to predict result!
2492/2492 [==============================] - 511s 205ms/step
Creating datasheet!
Test result file inceptionv3.0.test.result.hdf5 generated!
------------------------------------------------------------------------------------------------------------------
Found 19407 images belonging to 10 classes.
Found 3017 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
607/607 [==============================] - 572s 943ms/step - loss: 0.2530 - acc: 0.9239 - val_loss: 0.5091 - val_acc: 0.8472

Epoch 00001: val_loss improved from inf to 0.50909, saving model to inceptionv3.1.weights.hdf5
Epoch 1/2
607/607 [==============================] - 518s 853ms/step - loss: 0.0101 - acc: 0.9974 - val_loss: 0.3786 - val_acc: 0.8810

Epoch 00002: val_loss improved from 0.38029 to 0.37860, saving model to inceptionv3.1.weights.hdf5
Now to predict result!
2492/2492 [==============================] - 521s 209ms/step
Creating datasheet!
Test result file inceptionv3.1.test.result.hdf5 generated!
------------------------------------------------------------------------------------------------------------------
Found 19219 images belonging to 10 classes.
Found 3205 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
601/601 [==============================] - 568s 945ms/step - loss: 0.2453 - acc: 0.9248 - val_loss: 0.5244 - val_acc: 0.8537

Epoch 00001: val_loss improved from inf to 0.52442, saving model to inceptionv3.2.weights.hdf5
Epoch 1/2
601/601 [==============================] - 568s 945ms/step - loss: 0.0153 - acc: 0.9962 - val_loss: 0.4941 - val_acc: 0.8955

Epoch 00001: val_loss improved from 0.52442 to 0.49411, saving model to inceptionv3.2.weights.hdf5
Epoch 2/2
601/601 [==============================] - 509s 847ms/step - loss: 0.0069 - acc: 0.9982 - val_loss: 0.5401 - val_acc: 0.8880

Epoch 00002: val_loss did not improve
Now to predict result!
2492/2492 [==============================] - 513s 206ms/step
Creating datasheet!
Test result file inceptionv3.2.test.result.hdf5 generated!
------------------------------------------------------------------------------------------------------------------
Found 19279 images belonging to 10 classes.
Found 3145 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
603/603 [==============================] - 593s 983ms/step - loss: 0.2375 - acc: 0.9272 - val_loss: 0.5779 - val_acc: 0.8391

Epoch 00001: val_loss improved from inf to 0.57787, saving model to inceptionv3.3.weights.hdf5
Epoch 1/2
603/603 [==============================] - 582s 966ms/step - loss: 0.0224 - acc: 0.9938 - val_loss: 0.3650 - val_acc: 0.8979

Epoch 00001: val_loss improved from 0.57787 to 0.36498, saving model to inceptionv3.3.weights.hdf5
Epoch 2/2
603/603 [==============================] - 515s 853ms/step - loss: 0.0090 - acc: 0.9978 - val_loss: 0.3821 - val_acc: 0.8957

Epoch 00002: val_loss did not improve
Now to predict result!
2492/2492 [==============================] - 521s 209ms/step
Creating datasheet!
Test result file inceptionv3.3.test.result.hdf5 generated!
merge_test_results(inc_csv_names_list, 'inceptionv3.finall.result.csv')
Final result file: inceptionv3.finall.result.csv

After submitting to kaggle, the score was 0.31467.

3.5 Ensembling inception_resnet_v2 models

I plan to blend four models trained with inception_resnet_v2, each trained on a different train/validation split; apart from the model itself, everything follows the same procedure as for resnet50, and I do not plan to do any further tuning for it.

from keras.applications import inception_resnet_v2
from keras import optimizers
from keras.callbacks import ModelCheckpoint

ire_drivers_id_list = [['p041', 'p026', 'p022', 'p002'],
                      ['p061', 'p081', 'p015', 'p075'],
                      ['p042', 'p045', 'p047', 'p012'],
                      ['p064', 'p051', 'p072', 'p021']]

ire_image_size = (299, 299)
ire_input_shape = (299, 299, 3)

ire_csv_names_list = []

for i in range(4):
    
    print ('------------------------------------------------------------------------------------------------------------------')
    
    """ 1. about data """
    
    # create link and remove the old link
    train_valid_split(ire_drivers_id_list[i])
    
    # get generator
    train_generator, valid_generator, test_generator = data_generator(link_train_path, link_valid_path, test_link, ire_image_size)

    """ 2. about model """

    # get model
    inception_resnet_v2_model = get_model(inception_resnet_v2.InceptionResNetV2, ire_input_shape, inception_resnet_v2.preprocess_input, 10)
    
    # compile
    weights_file_name = 'inception_resnet_v2.'+ str(i) + '.weights.hdf5'
    ckpt = ModelCheckpoint(weights_file_name, verbose=1, save_best_only=True, save_weights_only=True)
    adam = optimizers.Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
    inception_resnet_v2_model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
    
    # fit
    hist_1 = inception_resnet_v2_model.fit_generator(train_generator, len(train_generator), epochs=1,\
                                        validation_data=valid_generator, validation_steps=len(valid_generator), callbacks=[ckpt])
    
    # recompile with a smaller learning rate
    adam = optimizers.Adam(lr=1e-5, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
    inception_resnet_v2_model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
    
    # fit using a small learning rate
    hist_2 = inception_resnet_v2_model.fit_generator(train_generator, len(train_generator), epochs=3,\
                                        validation_data=valid_generator, validation_steps=len(valid_generator), callbacks=[ckpt])
    
    # load weights
    inception_resnet_v2_model.load_weights(weights_file_name)

    """ 3. about result """
    result_file_name = 'inception_resnet_v2.'+ str(i) + '.test.result.hdf5'
    get_test_result(inception_resnet_v2_model, test_generator, result_file_name=result_file_name)
    
    ire_csv_names_list.append(result_file_name)
------------------------------------------------------------------------------------------------------------------
Found 18665 images belonging to 10 classes.
Found 3759 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
584/584 [==============================] - 532s 912ms/step - loss: 0.2403 - acc: 0.9257 - val_loss: 0.5532 - val_acc: 0.8843

Epoch 00001: val_loss improved from inf to 0.55320, saving model to inception_resnet_v2.0.weights.hdf5
Epoch 1/3
584/584 [==============================] - 542s 928ms/step - loss: 0.0168 - acc: 0.9959 - val_loss: 0.2754 - val_acc: 0.9146

Epoch 00001: val_loss improved from 0.55320 to 0.27539, saving model to inception_resnet_v2.0.weights.hdf5
Epoch 2/3
584/584 [==============================] - 506s 867ms/step - loss: 0.0076 - acc: 0.9981 - val_loss: 0.3755 - val_acc: 0.9013

Epoch 00002: val_loss did not improve
Epoch 3/3
584/584 [==============================] - 506s 866ms/step - loss: 0.0037 - acc: 0.9991 - val_loss: 0.2906 - val_acc: 0.9197

Epoch 00003: val_loss did not improve
Now to predict result!
2492/2492 [==============================] - 498s 200ms/step
Creating datasheet!
Test result file inception_resnet_v2.0.test.result.hdf5 generated!
------------------------------------------------------------------------------------------------------------------
Found 19103 images belonging to 10 classes.
Found 3321 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
597/597 [==============================] - 565s 946ms/step - loss: 0.2267 - acc: 0.9294 - val_loss: 0.4564 - val_acc: 0.8811

Epoch 00001: val_loss improved from inf to 0.45637, saving model to inception_resnet_v2.1.weights.hdf5
Epoch 1/3
597/597 [==============================] - 563s 943ms/step - loss: 0.0129 - acc: 0.9970 - val_loss: 0.4455 - val_acc: 0.8817

Epoch 00001: val_loss improved from 0.45637 to 0.44553, saving model to inception_resnet_v2.1.weights.hdf5
Epoch 2/3
597/597 [==============================] - 524s 879ms/step - loss: 0.0063 - acc: 0.9985 - val_loss: 0.6219 - val_acc: 0.8470

Epoch 00002: val_loss did not improve
Epoch 3/3
597/597 [==============================] - 522s 875ms/step - loss: 0.0036 - acc: 0.9992 - val_loss: 0.5228 - val_acc: 0.8735

Epoch 00003: val_loss did not improve
Now to predict result!
2492/2492 [==============================] - 533s 214ms/step
Creating datasheet!
Test result file inception_resnet_v2.1.test.result.hdf5 generated!
------------------------------------------------------------------------------------------------------------------
Found 19451 images belonging to 10 classes.
Found 2973 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
Epoch 1/1
 14/608 [..............................] - ETA: 41:55 - loss: 2.2983 - acc: 0.1763


KeyboardInterrupt — this training run was stopped manually while fit_generator was running.
ire_csv_names_list = ['inception_resnet_v2.0.test.result.hdf5',
'inception_resnet_v2.1.test.result.hdf5',
'inception_resnet_v2.2.test.result.hdf5',
'inception_resnet_v2.3.test.result.hdf5']

merge_test_results(ire_csv_names_list, 'inception_resnet_v2.finall.result.csv')
Final result file: inception_resnet_v2.finall.result.csv
all_name = ['inception_resnet_v2.finall.result.csv', 'inceptionv3.finall.result.csv', 'resnet50.finall.result.csv', 'xception.finall.result.csv']

merge_test_results(all_name, 'four.model.finall.result.csv')
Final result file: four.model.finall.result.csv

After submitting to Kaggle, the loss is 0.29589.

3.5 Selective fusion of the four models

Judging from the results above, among the individual models the fused resnet-50 works best, probably because the other models never went through a proper babysitting/tuning pass and therefore underperform. If all four models are fused, some of them may drag the ensemble down, so instead I hand-picked the models that performed best on the validation set and fused only those.

name_list = ['resnet50.3.test.result.hdf5',
'resnet50.1.test.result.hdf5',
'xception.1.test.result.hdf5',
'xception.3.test.result.hdf5',
'inception_resnet_v2.0.test.result.hdf5']

merge_test_results(name_list, 'choose_top_5_model_finall.result.csv')
Final result file: choose_top_5_model_finall.result.csv

After submitting to Kaggle, the loss is 0.23, which already reaches the top 10% of the Kaggle leaderboard and meets the initial goal. Since training one model takes roughly 2 hours (about 8 hours for all four), I stopped optimizing at this point.

A few more possible improvements:

  • Apart from the resnet50 model, the other models did not converge well, because I reused the optimizer settings tuned for resnet-50 to train them. One improvement is therefore to tune each model carefully on its own.

  • Apply more data augmentation to the training set, borrowing an idea from the top Kaggle entries: splice images of the same class together side by side. This improves generalization by teaching the model which regions it should and should not pay attention to.

  • Add more models. One standard direction for model ensembling is simply to use more models, which normally yields a better score (unless one model badly drags the others down).

4. CAM visualization

The ensemble above achieves good predictions, but what if we want to see what the models used in the ensemble are actually looking at? In the paper Learning Deep Features for Discriminative Localization, Bolei Zhou et al. propose adding a global average pooling layer to a CNN in order to obtain a class activation map (CAM): for a particular class, that class's activation map shows the image regions the CNN attends to when predicting that class. Producing the class activation map is simple: it is a linear combination of the feature maps of the layer just before global average pooling, and the combination coefficients are the output-layer weights of the neuron for that class.
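
In symbols (notation mine, following the paper): let f_k(x, y) be the activation of feature map k at spatial position (x, y) in the last convolutional layer, and let w_k^c be the output-layer weight connecting the pooled feature map k to class c. Then the class activation map for class c is

M_c(x, y) = \sum_k w_k^c \, f_k(x, y)

and the pre-softmax score for class c is S_c = \sum_k w_k^c \cdot \frac{1}{Z} \sum_{x,y} f_k(x, y) = \frac{1}{Z} \sum_{x,y} M_c(x, y), where Z is the number of spatial positions. This is why M_c localizes the evidence the network uses for class c.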

Take resnet50 for example: the layer before global average pooling outputs a 7x7x2048 tensor, so the class activation map is a linear combination of these 2048 feature maps of size 7x7. The combination weights are the output-layer weights of the neuron for the chosen class (in this project the output has 10 classes, so the weight matrix is 2048x10, and each class corresponds to a vector of 2048 weights).

How do we map the class activation map back onto the image? We now have a 7x7 CAM for a particular class, and the value at each position measures how important the corresponding region of the original image (for example, the bottom-right element of the CAM corresponds to the bottom-right of the image) is for predicting that class. We can therefore resize the CAM and use OpenCV's colormap API to render it as a heat map; overlaid on the image, the heat map shows what the model is paying attention to.

4.1 resnet50 CAM visualization

4.1.1 Loading the model

Because we compute the class activation maps from the output-layer weights and the activations of the convolutional layer just before global average pooling, the model must be built so that this convolutional layer is exposed as an additional output.

def build_cam_model(MODEL, input_shape, preprocess_input, output_num, weights_file_name):
    """
        MODEL:                  pre-trained model class
        input_shape:            pre-trained model's input shape
        preprocess_input:       pre-trained model's preprocessing function
        output_num:             number of output classes
        weights_file_name:      weights trained on the driver dataset
    """
    
    ## get pretrained model
    x = Input(shape=input_shape)
    
    if preprocess_input:
        x = Lambda(preprocess_input)(x)
    
    notop_model = MODEL(include_top=False, weights=None, input_tensor=x, input_shape=input_shape)
    
    x = GlobalAveragePooling2D(name='global_average_2d_1')(notop_model.output)

    ## build top layer
    x = Dropout(0.5, name='dropout_1')(x)
    out = Dense(output_num, activation='softmax', name='dense_1')(x)
    
    ret_model = Model(inputs=notop_model.input, outputs=[out, notop_model.layers[-2].output])
    
    ## load weights
    ret_model.load_weights(weights_file_name)
    
    ## get the output layer weights
    weights = ret_model.layers[-1].get_weights()
    
    return ret_model, np.array(weights[0])
from keras.applications import resnet50

res_image_size = (224, 224)
res_input_shape = (224, 224, 3)

cam_model, cam_weights = build_cam_model(resnet50.ResNet50, res_input_shape, resnet50.preprocess_input, 10, 'resnet50.3.weights.hdf5')

print (cam_weights.shape)
(2048, 10)

4.1.2 Drawing the heat map

import cv2
import numpy as np

image_path = 'train/c3/img_537.jpg'
resNet_input_shape = (224,224,3)

# read image
image = cv2.imread(image_path)
image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
image_input = cv2.resize(image, (resNet_input_shape[0], resNet_input_shape[1]))

image_input_m = np.expand_dims(image_input,axis=0)

# predict and get feature maps
predict_m, feature_maps_m = cam_model.predict(image_input_m)
predict = predict_m[0]
feature_maps = feature_maps_m[0]

# get the class result
class_index = np.argmax(predict)

# get the class_index unit's weights
cam_weights_c = cam_weights[:, class_index]

# get the class activation map
cam = np.matmul(feature_maps, cam_weights_c)

# normalize the cam to [0, 1]
cam = (cam - cam.min())/(cam.max() - cam.min())

# ignore low activation values
cam[np.where(cam<0.2)] = 0

cam = cv2.resize(cam, (resNet_input_shape[0], resNet_input_shape[1]))
cam = np.uint8(255*cam)
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

des = state_des['c'+str(class_index)]

# draw the hotmap
hotmap = cv2.applyColorMap(cam, cv2.COLORMAP_JET)

# linear combine the picture with cam
dis = cv2.addWeighted(image_input, 0.8, hotmap, 0.4, 0)

plt.title("Predict C" + str(class_index) + ':' + des)
plt.imshow(dis)
plt.axis("off")
(-0.5, 223.5, 223.5, -0.5)

png

import cv2
import numpy as np

def show_hot_map(image_path, model, cam_weights, input_shape):
    """ 1. predict """
    # read image
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
    image_input = cv2.resize(image, (input_shape[0], input_shape[1]))

    image_input_m = np.expand_dims(image_input,axis=0)

    # predict and get feature maps
    predict_m, feature_maps_m = model.predict(image_input_m)
    
    """ 2. get the calss activation maps """
    predict = predict_m[0]
    feature_maps = feature_maps_m[0]

    # get the class result
    class_index = np.argmax(predict)

    # get the class_index unit's weights
    cam_weights_c = cam_weights[:, class_index]

    # get the class activation map
    cam = np.matmul(feature_maps, cam_weights_c)

    # normalize the cam to [0, 1]
    cam = (cam - cam.min())/(cam.max() - cam.min())

    # ignore low activation values
    cam[np.where(cam<0.2)] = 0

    cam = cv2.resize(cam, (input_shape[0], input_shape[1]))
    cam = np.uint8(255*cam)
    
    """ 3. show the hot map """
    import matplotlib.pyplot as plt

    des = state_des['c'+str(class_index)]

    # draw the hotmap
    hotmap = cv2.applyColorMap(cam, cv2.COLORMAP_JET)

    # linear combine the picture with cam
    dis = cv2.addWeighted(image_input, 0.8, hotmap, 0.4, 0)

    plt.title("Predict C" + str(class_index) + ':' + des)
    plt.imshow(dis)
    plt.axis("off")
image_path = 'train/c3/img_537.jpg'
resNet_input_shape = (224,224,3)

show_hot_map(image_path, cam_model, cam_weights, res_input_shape)

By changing the file name you can check what the model is paying attention to for each image; in the vast majority of cases the model is doing the right thing.

In the CAM overlay, the blue areas are the regions with the largest CAM values, i.e. the regions the model relies on most for its classification. (The JET colormap normally renders high values in red, but cv2.applyColorMap returns a BGR image while matplotlib displays RGB, so the channels are swapped and the hottest regions show up as blue here.)
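
If you prefer the conventional coloring (red = high activation), one option is to convert the heat map to RGB before blending. A minimal sketch, reusing the cam and image_input variables computed above:

import cv2
import matplotlib.pyplot as plt

# cv2.applyColorMap returns a BGR image; convert it to RGB so that, under plt.imshow,
# red marks the regions with the largest CAM values
hotmap_rgb = cv2.cvtColor(cv2.applyColorMap(cam, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
dis_rgb = cv2.addWeighted(image_input, 0.8, hotmap_rgb, 0.4, 0)

plt.imshow(dis_rgb)
plt.axis("off")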

5. Video demo

Next is a short video demo. The video has two sub-windows: the original footage and the heat-map view. Following the official documentation, putting videos together with OpenCV is quite convenient.

5.1 The heat-map generation interface

import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

def generate_hot_map(frame, cam_model, model_input_size, cam_weights, cam_size):
    """
        image_input_m:  CAM model's input
        cam_model:      CAM model
        cam_weights:    weights for CAM
        cam_size:       size of the output picture 
    """
    
    # resize frame for predict
    img_for_model = cv2.resize(frame, model_input_size)
    img_for_model = np.expand_dims(img_for_model,axis=0)
    
    """ 1. predict """
    # predict and get feature maps
    predict_m, feature_maps_m = cam_model.predict(img_for_model)
    
    """ 2. get the calss activation maps """
    predict = predict_m[0]
    feature_maps = feature_maps_m[0]

    # get the class result
    class_index = np.argmax(predict)

    # get the class_index unit's weights
    cam_weights_c = cam_weights[:, class_index]

    # get the class activation map
    cam = np.matmul(feature_maps, cam_weights_c)

    # normalize the cam to [0, 1]
    cam = (cam - cam.min())/(cam.max() - cam.min())

    # ignore low activation values
    cam[np.where(cam<0.2)] = 0

    cam = cv2.resize(cam, cam_size)
    cam = np.uint8(255*cam)
    
    """ 3. show the hot map """
    des = state_des['c'+str(class_index)]

    # draw the hotmap
    hotmap = cv2.applyColorMap(cam, cv2.COLORMAP_JET)

    # linear combine the picture with cam
    image_input = cv2.resize(frame, cam_size)
    dis = cv2.addWeighted(image_input, 0.8, hotmap, 0.4, 0)

    return dis,predict

5.2 The video interface

import cv2
import numpy as np

## video width and height
main_video_width = 1280
main_video_high = 720

## subwindow size
sub_video_width = int(main_video_width/2)
sub_video_high = int(main_video_high*0.6)

## subwindow coordinate
sub1_coord_1 = int((main_video_high-sub_video_high)/2)
sub1_coord_2 = int((main_video_high-sub_video_high)/2) + sub_video_high
sub1_coord_3 = 0
sub1_coord_4 = sub1_coord_3 + sub_video_width

sub2_coord_1 = int((main_video_high-sub_video_high)/2)
sub2_coord_2 = int((main_video_high-sub_video_high)/2) + sub_video_high
sub2_coord_3 = sub1_coord_4
sub2_coord_4 = main_video_width

def generate_video_with_classification(model, model_input_size, video_name_or_camera, cam_weights, generate_video_name='output.avi'):
    """
        model:                 model used to classify each frame
        model_input_size:      input image size of the model
        video_name_or_camera:  video file name, or camera index to read from
        cam_weights:           weights for CAM
        generate_video_name:   the output video name
    """

    
    """0. create videl reader and writer, and get more video message """
    cap = cv2.VideoCapture(video_name_or_camera)
    
    video_width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
    video_high = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
    video_fps = cap.get(cv2.CAP_PROP_FPS)
    print ("video size:({width}, {high})   fps:{fps}".format(width=video_width, high=video_high, fps=video_fps))
    
    ''' 0. create a new image '''
    showBigImage = np.zeros((int(main_video_high), int(main_video_width), 3), np.uint8)
    
    ''' 1. create video writer '''
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    writer = cv2.VideoWriter(generate_video_name, fourcc, 20.0, (main_video_width, main_video_high))
    
    if(cap.isOpened() == False):
        print ("Failed to open " + video_name_or_camera)
        return
    
    while True:
        
        """ 2. preprocessing and predict """
        # get frame
        ret, frame = cap.read()
        
        # check if the video is over
        if(ret != True):
            print ("Ending!")
            break
        
        # get hot map
        sub_frame_2, predict = generate_hot_map(frame, model, model_input_size, cam_weights, (sub_video_width, sub_video_high))
        
        """ 3. add text to the image and show"""
        class_index = np.argmax(predict)
        
        text = 'Predicted:  C{}  {}'.format(class_index, state_des['c'+str(class_index)])
        font = cv2.FONT_HERSHEY_SIMPLEX
        showBigImage[:] = 0
        cv2.putText(showBigImage, text, (10, sub1_coord_1-10), font, 0.5, (255, 255, 255), 1, cv2.LINE_AA)
        
        """ 4. resize and fill 2 subwindow """
        frame = cv2.resize(frame, (sub_video_width, sub_video_high))
        showBigImage[sub1_coord_1:sub1_coord_2, sub1_coord_3:sub1_coord_4] = frame
        showBigImage[sub2_coord_1:sub2_coord_2, sub2_coord_3:sub2_coord_4] = sub_frame_2
        
        """ 5. show video """
        cv2.imshow('image', showBigImage)
        
        """ 6. save video if need """
        writer.write(showBigImage)
        
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
        
    cap.release()
    cv2.destroyAllWindows()

5.3 The model interface

from keras.applications import xception

xce_image_size = (299, 299)
xce_input_shape = (299, 299, 3)

cam_model, cam_weights = build_cam_model(xception.Xception, xce_input_shape, xception.preprocess_input, 10, 'xception.fixed.weights.hdf5')

5.4 Generating the video

generate_video_with_classification(cam_model, xce_image_size, 'real_fei.mp4', cam_weights, generate_video_name='output.avi')
video size:(544.0, 960.0)   fps:30.03370088538598
Ending!

I shot a short video myself and ran the single xception model on it, but the results were not good, presumably because:

  • The camera angle differs from the training and test sets: my footage was shot level with the driver, while the dataset images are shot from above.

  • The interior and accessories of the car in my video differ from the cars in the training videos, and the model does not generalize well to them.

So I am not publishing the video demo for now.

6. A small attempt at improvement

The models trained so far all used data augmentation, but perhaps we can apply more "aggressive" augmentation to give the trained models even stronger generalization.

Browsing the images of each driving state together with their CAMs shows that the regions that actually help discriminate between states are mostly the upper-left part (the head) and the lower-right part (the hands). So we can combine the left half of one image with the right half of another image of the same class (later I also added combining the top and bottom halves) to form a new image. This keeps the key regions used to distinguish the state while preventing the non-key regions (referred to as background below) from dominating, since the background changes from image to image.
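
Before looking at the modified generator, here is a minimal standalone sketch of the splice operation itself, assuming two same-class PIL images img_a and img_b of equal size (the generator below does the same thing inside the Keras data pipeline):

from PIL import Image
import numpy as np

def splice_same_class(img_a, img_b, horizontal=True):
    """Keep one half of img_a and paste the matching half of img_b over the other half."""
    w, h = img_a.size
    merged = img_a.copy()
    if horizontal:
        # right half comes from img_b
        region = img_b.crop((w // 2, 0, w, h))
        merged.paste(region, (w // 2, 0))
    else:
        # bottom half comes from img_b
        region = img_b.crop((0, h // 2, w, h))
        merged.paste(region, (0, h // 2))
    return merged

# usage (the second file name is hypothetical):
# img_a = Image.open('train/c3/img_537.jpg')
# img_b = Image.open('train/c3/img_1000.jpg').resize(img_a.size)
# merged = splice_same_class(img_a, img_b, horizontal=np.random.rand() < 0.5)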

So I made a small modification to the data augmentation that ships with Keras so that it splices same-class images together; the modified parts are the sections delimited by the # Start Code / # End Code comments. The code is shown below.

Next I use the modified data augmentation to generate training data and retrain xception to see how much it helps.

from keras.preprocessing.image import *
from keras.preprocessing.image import _count_valid_files_in_directory
from keras.preprocessing.image import _list_valid_filenames_in_directory
# explicit imports for names used below (otherwise only available via the star import)
from keras import backend as K
from functools import partial
import multiprocessing.pool
import os
import numpy as np

# new class, MergeImageDataGenerator, to generate the merge image
class MergeImageDataGenerator(ImageDataGenerator):
    # redefine flow_from_directory method
    def flow_from_directory(self, directory,
                    target_size=(256, 256), color_mode='rgb',
                    classes=None, class_mode='categorical',
                    batch_size=32, shuffle=True, seed=None,
                    save_to_dir=None,
                    save_prefix='',
                    save_format='png',
                    follow_links=False,
                    subset=None,
                    interpolation='nearest'):
        # just return the iterator
        return MergeDirectoryIterator(
        directory, self,
        target_size=target_size, color_mode=color_mode,
        classes=classes, class_mode=class_mode,
        data_format=self.data_format,
        batch_size=batch_size, shuffle=shuffle, seed=seed,
        save_to_dir=save_to_dir,
        save_prefix=save_prefix,
        save_format=save_format,
        follow_links=follow_links,
        subset=subset,
        interpolation=interpolation)
    
# redefine the iterator of the generator
class MergeDirectoryIterator(Iterator):
    """Iterator capable of reading images from a directory on disk.

    # Arguments
        directory: Path to the directory to read images from.
            Each subdirectory in this directory will be
            considered to contain images from one class,
            or alternatively you could specify class subdirectories
            via the `classes` argument.
        image_data_generator: Instance of `ImageDataGenerator`
            to use for random transformations and normalization.
        target_size: tuple of integers, dimensions to resize input images to.
        color_mode: One of `"rgb"`, `"grayscale"`. Color mode to read images.
        classes: Optional list of strings, names of subdirectories
            containing images from each class (e.g. `["dogs", "cats"]`).
            It will be computed automatically if not set.
        class_mode: Mode for yielding the targets:
            `"binary"`: binary targets (if there are only two classes),
            `"categorical"`: categorical targets,
            `"sparse"`: integer targets,
            `"input"`: targets are images identical to input images (mainly
                used to work with autoencoders),
            `None`: no targets get yielded (only input images are yielded).
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        seed: Random seed for data shuffling.
        data_format: String, one of `channels_first`, `channels_last`.
        save_to_dir: Optional directory where to save the pictures
            being yielded, in a viewable format. This is useful
            for visualizing the random transformations being
            applied, for debugging purposes.
        save_prefix: String prefix to use for saving sample
            images (if `save_to_dir` is set).
        save_format: Format to use for saving sample images
            (if `save_to_dir` is set).
        subset: Subset of data (`"training"` or `"validation"`) if
            validation_split is set in ImageDataGenerator.
        interpolation: Interpolation method used to resample the image if the
            target size is different from that of the loaded image.
            Supported methods are "nearest", "bilinear", and "bicubic".
            If PIL version 1.1.3 or newer is installed, "lanczos" is also
            supported. If PIL version 3.4.0 or newer is installed, "box" and
            "hamming" are also supported. By default, "nearest" is used.
    """

    def __init__(self, directory, image_data_generator,
                 target_size=(256, 256), color_mode='rgb',
                 classes=None, class_mode='categorical',
                 batch_size=32, shuffle=True, seed=None,
                 data_format=None,
                 save_to_dir=None, save_prefix='', save_format='png',
                 follow_links=False,
                 subset=None,
                 interpolation='nearest'):
        if data_format is None:
            data_format = K.image_data_format()
        self.directory = directory
        self.image_data_generator = image_data_generator
        self.target_size = tuple(target_size)
        if color_mode not in {'rgb', 'grayscale'}:
            raise ValueError('Invalid color mode:', color_mode,
                             '; expected "rgb" or "grayscale".')
        self.color_mode = color_mode
        self.data_format = data_format
        if self.color_mode == 'rgb':
            if self.data_format == 'channels_last':
                self.image_shape = self.target_size + (3,)
            else:
                self.image_shape = (3,) + self.target_size
        else:
            if self.data_format == 'channels_last':
                self.image_shape = self.target_size + (1,)
            else:
                self.image_shape = (1,) + self.target_size
        self.classes = classes
        if class_mode not in {'categorical', 'binary', 'sparse',
                              'input', None}:
            raise ValueError('Invalid class_mode:', class_mode,
                             '; expected one of "categorical", '
                             '"binary", "sparse", "input"'
                             ' or None.')
        self.class_mode = class_mode
        self.save_to_dir = save_to_dir
        self.save_prefix = save_prefix
        self.save_format = save_format
        self.interpolation = interpolation

        if subset is not None:
            validation_split = self.image_data_generator._validation_split
            if subset == 'validation':
                split = (0, validation_split)
            elif subset == 'training':
                split = (validation_split, 1)
            else:
                raise ValueError('Invalid subset name: ', subset,
                                 '; expected "training" or "validation"')
        else:
            split = None
        self.subset = subset

        white_list_formats = {'png', 'jpg', 'jpeg', 'bmp', 'ppm', 'tif', 'tiff'}

        # first, count the number of samples and classes
        self.samples = 0

        if not classes:
            classes = []
            for subdir in sorted(os.listdir(directory)):
                if os.path.isdir(os.path.join(directory, subdir)):
                    classes.append(subdir)
        self.num_classes = len(classes)
        self.class_indices = dict(zip(classes, range(len(classes))))

        pool = multiprocessing.pool.ThreadPool()
        function_partial = partial(_count_valid_files_in_directory,
                                   white_list_formats=white_list_formats,
                                   follow_links=follow_links,
                                   split=split)
        self.samples = sum(pool.map(function_partial,
                                    (os.path.join(directory, subdir)
                                     for subdir in classes)))

        print('!!!!!!!Found %d images belonging to %d classes.' % (self.samples, self.num_classes))

        # second, build an index of the images in the different class subfolders
        results = []

        self.filenames = []
        self.classes = np.zeros((self.samples,), dtype='int32')
        i = 0
        for dirpath in (os.path.join(directory, subdir) for subdir in classes):
            results.append(pool.apply_async(_list_valid_filenames_in_directory,
                                            (dirpath, white_list_formats, split,
                                             self.class_indices, follow_links)))
        """
            author: rikichou 2018,5,1  21:03:31
            record the index range of each class
        """
        # Start Code
        classes_range = []
        # End Code
        
        for res in results:
            classes, filenames = res.get()
            
            """
                author: rikichou 2018,5,1  21:04:01
            """
            # Start Code
            start = i
            end = i + len(classes) - 1
            classes_range.append({'start':start, 'end':end})
            # End Code
            
            self.classes[i:i + len(classes)] = classes
            self.filenames += filenames
            i += len(classes)
        """
            author: rikichou
        """
        # Start Code
        self.classes_range = classes_range
        # End Code
        
        pool.close()
        pool.join()
        super(MergeDirectoryIterator, self).__init__(self.samples, batch_size, shuffle, seed)

    def _get_batches_of_transformed_samples(self, index_array):
        batch_x = np.zeros((len(index_array),) + self.image_shape, dtype=K.floatx())
        grayscale = self.color_mode == 'grayscale'
        # build batch of image data
        for i, j in enumerate(index_array):
            fname = self.filenames[j]
            img = load_img(os.path.join(self.directory, fname),
                           grayscale=grayscale,
                           target_size=self.target_size,
                           interpolation=self.interpolation)
            
            if np.random.rand() < 0.8:
                """
                    author: rikichou 2018,5,1  21:03:31
                    before random transform and standardize
                """
                # Start Code
                """ 1. which iamge to merge with """
                class_merge = self.classes[j]
                class_range = self.classes_range[class_merge]
                target_index = np.random.randint(class_range['start'], class_range['end'])

                """ 2. load target image """
                target_name = self.filenames[target_index]
                target_img = load_img(os.path.join(self.directory, target_name),
                                    grayscale=grayscale,
                                    target_size=self.target_size,
                                    interpolation=self.interpolation)

                """ 3. get the target image's start and end coordinate """
                if np.random.rand() < 0.5:
                    start_x = 0
                    start_y = self.target_size[0]//2
                    end_x = self.target_size[1]
                    end_y = self.target_size[0]
                else:
                    start_x = self.target_size[1]//2
                    start_y = 0
                    end_x = self.target_size[1]
                    end_y = self.target_size[0]
                from_copy = (start_x, start_y, end_x, end_y)
                to_paste = (start_x, start_y)

                """ 4. copy and paste """
                region = target_img.crop(from_copy)
                img.paste(region, to_paste)
                #EndCode
            
            x = img_to_array(img, data_format=self.data_format)
            x = self.image_data_generator.random_transform(x)
            x = self.image_data_generator.standardize(x)
            batch_x[i] = x
        # optionally save augmented images to disk for debugging purposes
        if self.save_to_dir:
            for i, j in enumerate(index_array):
                img = array_to_img(batch_x[i], self.data_format, scale=True)
                fname = '{prefix}_{index}_{hash}.{format}'.format(prefix=self.save_prefix,
                                                                  index=j,
                                                                  hash=np.random.randint(1e7),
                                                                  format=self.save_format)
                img.save(os.path.join(self.save_to_dir, fname))
        # build batch of labels
        if self.class_mode == 'input':
            batch_y = batch_x.copy()
        elif self.class_mode == 'sparse':
            batch_y = self.classes[index_array]
        elif self.class_mode == 'binary':
            batch_y = self.classes[index_array].astype(K.floatx())
        elif self.class_mode == 'categorical':
            batch_y = np.zeros((len(batch_x), self.num_classes), dtype=K.floatx())
            for i, label in enumerate(self.classes[index_array]):
                batch_y[i, label] = 1.
        else:
            return batch_x
        return batch_x, batch_y

    def next(self):
        """For python 2.x.

        # Returns
            The next batch.
        """
        with self.lock:
            index_array = next(self.index_generator)
        # The transformation of images is not under thread lock
        # so it can be done in parallel
        return self._get_batches_of_transformed_samples(index_array)
def fixed_data_generator(train_dir, valid_dir, test_dir, image_size): 
    """
        image_size: the output of the image size, like (224, 224)
    """
    gen = MergeImageDataGenerator(rotation_range=10.,
    width_shift_range=0.05,
    height_shift_range=0.05,
    shear_range=0.1,
    zoom_range=0.1)
    gen_valid = ImageDataGenerator()
    test_gen = ImageDataGenerator()
    
    # create train generator
    train_generator = gen.flow_from_directory(train_dir, image_size, color_mode='rgb', \
                                              classes=classes, class_mode='categorical', shuffle=True, batch_size=32)
    
    # create validation generator
    valid_generator = gen_valid.flow_from_directory(valid_dir, image_size, color_mode='rgb', \
                                              classes=classes, class_mode='categorical', shuffle=False, batch_size=32)
    
    test_generator = test_gen.flow_from_directory(test_dir, image_size, color_mode='rgb', \
                                          class_mode=None, shuffle=False, batch_size=32)
    
    return train_generator, valid_generator, test_generator
from keras.applications import xception
from keras import optimizers
from keras.callbacks import ModelCheckpoint

xcp_drivers_id_list = [['p016', 'p072', 'p026', 'p066', 'p075']]

xcp_image_size = (299, 299)
xcp_input_shape = (299, 299, 3)

xcp_csv_names_list = []

#for i in range(2):
i = 0
print ('------------------------------------------------------------------------------------------------------------------')

""" 1. about data """

# create link and remove the old link
train_valid_split(xcp_drivers_id_list[i])

# get generator
train_generator, valid_generator, test_generator = fixed_data_generator(link_train_path, link_valid_path, test_link, xcp_image_size)
------------------------------------------------------------------------------------------------------------------
!!!!!!!Found 17956 images belonging to 10 classes.
Found 4468 images belonging to 10 classes.
Found 79726 images belonging to 1 classes.
""" 2. about model """

# get model
xception_model = get_model(xception.Xception, xcp_input_shape, xception.preprocess_input, 10)
from keras.optimizers import Adam

# compile
weights_file_name = 'xception.fixed.weights.hdf5'
ckpt = ModelCheckpoint(weights_file_name, verbose=1, save_best_only=True, save_weights_only=True)

adam = optimizers.Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
xception_model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])

# fit
hist_1 = xception_model.fit_generator(train_generator, len(train_generator), epochs=1,\
                                    validation_data=valid_generator, validation_steps=len(valid_generator), callbacks=[ckpt])
Epoch 1/1
562/562 [==============================] - 514s 915ms/step - loss: 0.3387 - acc: 0.8970 - val_loss: 0.3116 - val_acc: 0.8957

Epoch 00001: val_loss improved from inf to 0.31157, saving model to xception.fixed.weights.hdf5
adam = optimizers.Adam(lr=1e-5, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
xception_model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])

hist_1 = xception_model.fit_generator(train_generator, len(train_generator), epochs=1,\
                                    validation_data=valid_generator, validation_steps=len(valid_generator), callbacks=[ckpt])
Epoch 1/1
562/562 [==============================] - 514s 915ms/step - loss: 0.0338 - acc: 0.9914 - val_loss: 0.2759 - val_acc: 0.8988

Epoch 00001: val_loss improved from 0.31157 to 0.27592, saving model to xception.fixed.weights.hdf5
hist_1 = xception_model.fit_generator(train_generator, len(train_generator), epochs=1,\
                                    validation_data=valid_generator, validation_steps=len(valid_generator), callbacks=[ckpt])
Epoch 1/1
562/562 [==============================] - 511s 910ms/step - loss: 0.0215 - acc: 0.9945 - val_loss: 0.2713 - val_acc: 0.9029

Epoch 00001: val_loss improved from 0.27592 to 0.27130, saving model to xception.fixed.weights.hdf5

As you can see, the single xception model improves noticeably. For time reasons I did not train more models for another round of fusion, but this is clearly a viable direction.

7. Summary

Going through the whole process, we end up with a solid model that reaches the top 10% of the Kaggle leaderboard. The key points of the project are:

  • When splitting the data into training and validation sets, images of the same driver are highly similar, so the split must be done by driver ID rather than by driving state (see the sketch at the end of this summary).

  • Because images of different states of the same driver look very similar, as many layers as possible should be left trainable (in this project I did not lock any layers).

  • Since this is a Kaggle competition, the only metric is the softmax cross entropy and there is no constraint on prediction time, so model ensembling is always a valid strategy for improving the score.

  • To give the model stronger generalization, add task-specific data augmentation: for images of the same state, combine the left and right halves, or the top and bottom halves, of two random images into a new image. This does improve single-model generalization, and combined with model ensembling the score would improve further.
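
As an illustration of the first point, a minimal sketch of a driver-ID-based split built directly from driver_imgs_list.csv (assuming its columns are subject, classname and img, and that scikit-learn is available; the project itself does this via train_valid_split and symbolic links):

import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

df = pd.read_csv('driver_imgs_list.csv')

# group by driver id (subject) so that no driver appears in both train and validation
splitter = GroupShuffleSplit(n_splits=1, test_size=0.15, random_state=0)
train_idx, valid_idx = next(splitter.split(df, groups=df['subject']))

train_df, valid_df = df.iloc[train_idx], df.iloc[valid_idx]
print('train drivers:', sorted(train_df['subject'].unique()))
print('valid drivers:', sorted(valid_df['subject'].unique()))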
