In [1]:
import numpy as np
#确保下载的权值文件和这个ipython文件再同一个文件夹里面，或者自己指定绝对路径
MODEL_ADDR=r'bvlc_alexnet.npy'
import tensorflow as tf
from scipy.misc import imread
from scipy.misc import imresize
from class_name import class_names
import time
#               conv1     conv2      conv3       conv4     conv5    fc6   fc7     fc8
variable_trable=[False,    False,     False,    False,    False,  False,  True,  True,]
variable_trable=[None]+variable_trable

In [2]:
def conv(input, kernel, biases, c_o, s_h, s_w, padding="VALID", group=1):
    '''From https://github.com/ethereon/caffe-tensorflow
    '''
    c_i = input.get_shape()[-1]
    convolve = lambda i, k: tf.nn.conv2d(i, k, [1, s_h, s_w, 1], padding=padding)
    if group == 1:
        #不对输入分组卷积
        conv = convolve(input, kernel)
    else:
        #将输入平分成group组，按[N,w,h,channel]->[0,1,2,3]也就是按输入的channel来分成两个矩阵
        input_groups = tf.split(input, group, 3)  # tf.split(3, group, input)
        #将卷积核平分成group组，按[w,h,in_channel,out_channel]->[0,1,2,3]也就是按输入的channel来分成两个矩阵
        kernel_groups = tf.split(kernel, group, 3)  # tf.split(3, group, kernel)
        #分组卷积
        output_groups = [convolve(i, k) for i, k in zip(input_groups, kernel_groups)]
        #连接卷积的结果
        conv = tf.concat(output_groups, 3)  # tf.concat(3, output_groups)
    return conv + biases

In [3]:
def lrn(x):
    #return x
    #lrn层，现在比较少用，一般用bn层代替
    return tf.nn.local_response_normalization(x,
                                              depth_radius=2,
                                              alpha=2e-05,
                                              beta=0.75,
                                              bias=1.0)
def maxpool(x):
    #因为alex net 用到的maxpool都是一样的参数，所以直接写以函数代替，不用填参数
    return tf.nn.max_pool(x, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='VALID')

In [4]:
def load_model_weight_and_biases():
    '''
    读取模型中的变量值，返回训练好的权重
    model_addr：模型的路径
    '''
    weights_dict = np.load(MODEL_ADDR, encoding='bytes').item()
    return weights_dict

In [5]:
def alexnet(x,net_data,keep_prob):#提取特征的时候train=False
    #layer_1 conv1-relu-lrn-maxpool
   
    with tf.name_scope('layer_1'):
        CONV1_W,CONV1_b=tf.Variable(net_data['conv1'][0],name='conv1_w',trainable=variable_trable[1]),\
        tf.Variable(net_data['conv1'][1],name='conv1_b',trainable=variable_trable[1])
        conv1_=conv(X, CONV1_W, CONV1_b, c_o=96, s_h=4, s_w=4, padding="VALID", group=1)
        relu1_=tf.nn.relu(conv1_)
        norm1=lrn(relu1_)
        maxpool1_=maxpool(norm1)
        
    #layer_2 conv2-relu-lrn-maxpool
    with tf.name_scope('layer_2'):
        CONV2_W,CONV2_b=tf.Variable(net_data['conv2'][0],name='conv2_w',trainable=variable_trable[2]), \
        tf.Variable(net_data['conv2'][1],name='conv2_b',trainable=variable_trable[2])
        conv2_=conv(maxpool1_, CONV2_W, CONV2_b, c_o=256, s_h=1, s_w=1, padding="SAME", group=2)#27*27*256
        relu2_=tf.nn.relu(conv2_)
        norm2=lrn(relu2_)
        maxpool2_=maxpool(norm2)
        
    #layer_3 conv3-relu
    with tf.name_scope('layer_3'):
        CONV3_W,CONV3_b=tf.Variable(net_data['conv3'][0],name='conv3_w',trainable=variable_trable[3]),\
        tf.Variable(net_data['conv3'][1],name='conv3_b',trainable=variable_trable[3])
        conv3_=conv(maxpool2_, CONV3_W, CONV3_b, c_o=384, s_h=1, s_w=1, padding="SAME", group=1)#13*13*384
        relu3_=tf.nn.relu(conv3_)
        
    #layer_4 conv4-relu
    with tf.name_scope('layer_4'):
        CONV4_W,CONV4_b=tf.Variable(net_data['conv4'][0],name='conv4_w',trainable=variable_trable[4]), \
        tf.Variable(net_data['conv4'][1],name='conv4_b',trainable=variable_trable[4])
        conv4_=conv(relu3_, CONV4_W, CONV4_b, c_o=384, s_h=1, s_w=1, padding="SAME", group=2)#13*13*384
        relu4_=tf.nn.relu(conv4_)
    
    #layer_5 conv5-relu-maxpool
    with tf.name_scope('layer_5'):
        CONV5_W,CONV5_b=tf.Variable(net_data['conv5'][0],name='conv5_w',trainable=variable_trable[5]), \
        tf.Variable(net_data['conv5'][1],name='conv5_b',trainable=variable_trable[5])
        conv5_=conv(relu4_, CONV5_W, CONV5_b, c_o=256, s_h=1, s_w=1, padding="SAME", group=2)
        relu5_=tf.nn.relu(conv5_)
        maxpool5_=maxpool(relu5_)
        
    with tf.name_scope('layer_6'):
        floatten_input=tf.reshape(maxpool5_,[-1,9216])#N*9216
        floatten_input=tf.nn.dropout(x=floatten_input,keep_prob=keep_prob)
        fc6_w,fc6_b=tf.Variable(net_data['fc6'][0],name='fc6_w',trainable=variable_trable[6]), \
        tf.Variable(net_data['fc6'][1],name='fc7_b',trainable=variable_trable[6])
        fc6_=tf.matmul(floatten_input,fc6_w)+fc6_b
        relu6_=tf.nn.relu(fc6_)#N*4096
    with tf.name_scope('layer_7'):
        relu6_=tf.nn.dropout(x=relu6_,keep_prob=keep_prob)
        fc7_w,fc7_b=tf.Variable(net_data['fc7'][0],name='fc7_w',trainable=variable_trable[7]),\
        tf.Variable(net_data['fc7'][1],name='fc7_b',trainable=variable_trable[7])
        fc7_=tf.matmul(relu6_,fc7_w)+fc7_b
        relu7_=tf.nn.relu(fc7_)#N*4096
    with tf.name_scope('layer_8'):
#         fc8_w,fc8_b=tf.Variable(net_data['fc8'][0],name='fc8_w',trainable=variable_trable[8]), \
#         tf.Variable(net_data['fc8'][1],name='fc8_b',trainable=variable_trable[8])
        relu7_=tf.nn.dropout(x=relu7_,keep_prob=keep_prob)
        #最后一层fc层必须要重新训练
        fc8_w=tf.Variable(tf.truncated_normal(shape=[4096,5],stddev=0.01),dtype=tf.float32,name='fc8_w',\
                          trainable=variable_trable[8])
        fc8_b=tf.Variable(tf.zeros(shape=[5]),dtype=tf.float32,name='fc8_b',trainable=variable_trable[8])
        fc8_=tf.matmul(relu7_,fc8_w)+fc8_b#N*1000
    return fc8_
    
    

In [6]:
reg_pen=tf.reduce_sum([tf.reduce_sum(tf.square(i)) for i in tf.trainable_variables()])

In [7]:
netdata=load_model_weight_and_biases()
with tf.name_scope('input'):
    X=tf.placeholder(dtype=tf.float32,shape=[None,227,227,3])
    Y=tf.placeholder(dtype=tf.float32,shape=[None,5])
    KEEP_PROB=tf.placeholder(dtype=tf.float32)
    LEARNRATE=tf.placeholder(dtype=tf.float32)
with tf.name_scope('predict'):
    y_pre=alexnet(X,netdata,KEEP_PROB)
    prob=tf.nn.softmax(y_pre)
#loss
with tf.name_scope('loss'):
    loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y_pre,labels=Y))
    loss+=1e-4*reg_pen
with tf.name_scope('trainer'):
    trainer=tf.train.AdamOptimizer(1e-3).minimize(loss)
with tf.name_scope('accuracy'):
    acc_c=tf.equal(tf.arg_max(y_pre,1),tf.arg_max(Y,1))
    accuracy=tf.reduce_mean(tf.cast(x=acc_c,dtype=tf.float32))
sess=tf.InteractiveSession()
#writer=tf.summary.FileWriter('./mylog',sess.graph)
init=tf.global_variables_initializer()
writer=tf.summary.FileWriter(r'D:\Data warehouse\temp_dump\mylog\train_fc7tofc8')
writer_=tf.summary.FileWriter(r'D:\Data warehouse\temp_dump\mylog\test_fc7tofc8')
tf.summary.scalar('loss',loss)
tf.summary.scalar('accuracy',accuracy)
merge=tf.summary.merge_all()

In [8]:
import pickle
fp=open(r'D:\Data warehouse\5 flower\flower_photos\flower.pkl','rb')
flower_dict=pickle.load(fp)

In [9]:
TR_IMG,TE_IMG=flower_dict['train']['image'].astype(np.float32),flower_dict['test']['image'].astype(np.float32)
TR_LAB,TE_LAB=flower_dict['train']['label'].astype(np.float32),flower_dict['test']['label'].astype(np.float32)

In [10]:
del flower_dict

In [11]:
kp_te,kp_tr=1.,0.5
sess.run(init)
lr=1e-3
for i in range(1000):
    mask=np.random.choice(TR_IMG.shape[0],128,replace=False)
    x_,y_=TR_IMG[mask],TR_LAB[mask]
    
    loss_,acc_,m_=sess.run([loss,accuracy,merge],feed_dict={X:x_,Y:y_,KEEP_PROB:kp_tr,LEARNRATE:lr})
    writer.add_summary(m_)
    print('epoch:{},loss:{},train accuracy:{}'.format(i,loss_,acc_))
    for j in range(10):
        sess.run(trainer,feed_dict={X:x_,Y:y_,KEEP_PROB:kp_tr,LEARNRATE:lr})
    if i%5==0:
        mask=np.random.choice(TE_IMG.shape[0],128,replace=False)
        x_,y_=TE_IMG[mask],TE_LAB[mask]
        loss_,acc_,m_=sess.run([loss,accuracy,merge],feed_dict={X:x_,Y:y_,KEEP_PROB:kp_te})
        writer_.add_summary(m_)
        if acc_>0.8:
            lr=max(0.99*lr,1e-5)
            print ('epoch {},learning rate:{}'.format(i,lr))
        print('--epoch:{},loss:{},test accuracy:{}'.format(i,loss_,acc_))

epoch:0,loss:2.0476887226104736,train accuracy:0.28125
--epoch:0,loss:2.6251590251922607,test accuracy:0.6953125
epoch:1,loss:2.828031539916992,train accuracy:0.6796875
epoch:2,loss:2.522392749786377,train accuracy:0.703125
epoch:3,loss:1.5997705459594727,train accuracy:0.765625
epoch:4,loss:2.030029058456421,train accuracy:0.703125
epoch:5,loss:1.185170292854309,train accuracy:0.8203125
--epoch:5,loss:0.8891972303390503,test accuracy:0.75
epoch:6,loss:1.7880452871322632,train accuracy:0.7421875
epoch:7,loss:0.7682040929794312,train accuracy:0.828125
epoch:8,loss:1.1914187669754028,train accuracy:0.78125
epoch:9,loss:1.07383131980896,train accuracy:0.8046875
epoch:10,loss:0.8696883916854858,train accuracy:0.828125
--epoch:10,loss:0.9965732097625732,test accuracy:0.7734375
epoch:11,loss:0.8289076089859009,train accuracy:0.828125
epoch:12,loss:1.2061749696731567,train accuracy:0.7734375
epoch:13,loss:0.7513988018035889,train accuracy:0.8359375
epoch:14,loss:0.9110493659973145,train accur

epoch 110,learning rate:0.0008514577710948754
--epoch:110,loss:1.0242254734039307,test accuracy:0.859375
epoch:111,loss:0.38196420669555664,train accuracy:0.9296875
epoch:112,loss:0.15797176957130432,train accuracy:0.9609375
epoch:113,loss:0.1037834882736206,train accuracy:0.9765625
epoch:114,loss:0.2536734938621521,train accuracy:0.921875
epoch:115,loss:0.2950151562690735,train accuracy:0.9375
epoch 115,learning rate:0.0008429431933839266
--epoch:115,loss:0.7152047753334045,test accuracy:0.8671875
epoch:116,loss:0.22080688178539276,train accuracy:0.953125
epoch:117,loss:0.6305350065231323,train accuracy:0.9140625
epoch:118,loss:0.30261728167533875,train accuracy:0.921875
epoch:119,loss:0.3490713834762573,train accuracy:0.9296875
epoch:120,loss:0.29002442955970764,train accuracy:0.9375
epoch 120,learning rate:0.0008345137614500873
--epoch:120,loss:0.9771071076393127,test accuracy:0.859375
epoch:121,loss:0.238882377743721,train accuracy:0.9453125
epoch:122,loss:0.3811500370502472,train 

epoch 215,learning rate:0.000710553227272292
--epoch:215,loss:1.1497387886047363,test accuracy:0.8359375
epoch:216,loss:0.06985322386026382,train accuracy:0.9765625
epoch:217,loss:0.17936383187770844,train accuracy:0.9765625
epoch:218,loss:0.11962266266345978,train accuracy:0.96875
epoch:219,loss:0.15759089589118958,train accuracy:0.9765625
epoch:220,loss:0.5550674200057983,train accuracy:0.9453125
epoch 220,learning rate:0.000703447694999569
--epoch:220,loss:0.6690495014190674,test accuracy:0.8515625
epoch:221,loss:0.1216120645403862,train accuracy:0.96875
epoch:222,loss:0.14704912900924683,train accuracy:0.9765625
epoch:223,loss:0.08323919773101807,train accuracy:0.9765625
epoch:224,loss:0.06535838544368744,train accuracy:0.984375
epoch:225,loss:0.1736755222082138,train accuracy:0.9609375
epoch 225,learning rate:0.0006964132180495733
--epoch:225,loss:1.3014391660690308,test accuracy:0.828125
epoch:226,loss:0.149857759475708,train accuracy:0.9765625
epoch:227,loss:0.5378906726837158,t

epoch:319,loss:0.1452026665210724,train accuracy:0.984375
epoch:320,loss:0.16614164412021637,train accuracy:0.9765625
epoch 320,learning rate:0.0005811664141181095
--epoch:320,loss:1.6058099269866943,test accuracy:0.8203125
epoch:321,loss:0.013082778081297874,train accuracy:0.9921875
epoch:322,loss:0.26342007517814636,train accuracy:0.9609375
epoch:323,loss:0.011431097984313965,train accuracy:0.9921875
epoch:324,loss:0.33977171778678894,train accuracy:0.9609375
epoch:325,loss:0.5407716035842896,train accuracy:0.96875
epoch 325,learning rate:0.0005753547499769285
--epoch:325,loss:1.0707999467849731,test accuracy:0.8359375
epoch:326,loss:0.004900130443274975,train accuracy:1.0
epoch:327,loss:0.1463104635477066,train accuracy:0.9765625
epoch:328,loss:0.6144694089889526,train accuracy:0.953125
epoch:329,loss:0.1759500354528427,train accuracy:0.9609375
epoch:330,loss:0.12241195142269135,train accuracy:0.96875
epoch 330,learning rate:0.0005696012024771592
--epoch:330,loss:1.3399865627288818,

epoch:421,loss:0.22764043509960175,train accuracy:0.984375
epoch:422,loss:0.004962262697517872,train accuracy:1.0
epoch:423,loss:0.3769400417804718,train accuracy:0.9765625
epoch:424,loss:0.10972348600625992,train accuracy:0.984375
epoch:425,loss:0.08770225197076797,train accuracy:0.9765625
epoch 425,learning rate:0.0004705866415856499
--epoch:425,loss:1.3638253211975098,test accuracy:0.8984375
epoch:426,loss:0.24507248401641846,train accuracy:0.9765625
epoch:427,loss:9.83139470918104e-05,train accuracy:1.0
epoch:428,loss:0.08447979390621185,train accuracy:0.9921875
epoch:429,loss:0.09348877519369125,train accuracy:0.9765625
epoch:430,loss:0.14603130519390106,train accuracy:0.984375
epoch 430,learning rate:0.0004658807751697934
--epoch:430,loss:2.5096983909606934,test accuracy:0.8359375
epoch:431,loss:0.3194428086280823,train accuracy:0.9609375
epoch:432,loss:0.11442673206329346,train accuracy:0.984375
epoch:433,loss:0.19478538632392883,train accuracy:0.96875
epoch:434,loss:0.029612444

epoch:525,loss:0.10310769826173782,train accuracy:0.9765625
epoch 525,learning rate:0.0003848960788934845
--epoch:525,loss:1.829229474067688,test accuracy:0.8671875
epoch:526,loss:0.1753864288330078,train accuracy:0.9765625
epoch:527,loss:0.048099640756845474,train accuracy:0.984375
epoch:528,loss:0.003127384465187788,train accuracy:1.0
epoch:529,loss:0.009235131554305553,train accuracy:0.9921875
epoch:530,loss:0.1949600875377655,train accuracy:0.9765625
epoch 530,learning rate:0.00038104711810454966
--epoch:530,loss:2.8609447479248047,test accuracy:0.8203125
epoch:531,loss:0.28448042273521423,train accuracy:0.9765625
epoch:532,loss:0.1332232654094696,train accuracy:0.9921875
epoch:533,loss:0.017975935712456703,train accuracy:0.9921875
epoch:534,loss:0.04766406863927841,train accuracy:0.984375
epoch:535,loss:0.03194030746817589,train accuracy:0.9921875
epoch 535,learning rate:0.00037723664692350416
--epoch:535,loss:2.954306125640869,test accuracy:0.8359375
epoch:536,loss:0.059475116431

epoch:628,loss:0.11744370311498642,train accuracy:0.9765625
epoch:629,loss:0.19622330367565155,train accuracy:0.984375
epoch:630,loss:0.11769367754459381,train accuracy:0.984375
--epoch:630,loss:3.8160436153411865,test accuracy:0.78125
epoch:631,loss:0.020655330270528793,train accuracy:0.9921875
epoch:632,loss:0.039888784289360046,train accuracy:0.9921875
epoch:633,loss:0.004210059996694326,train accuracy:1.0
epoch:634,loss:0.05061301589012146,train accuracy:0.9921875
epoch:635,loss:0.5598135590553284,train accuracy:0.984375
epoch 635,learning rate:0.0003116610814491426
--epoch:635,loss:3.3699593544006348,test accuracy:0.8203125
epoch:636,loss:5.411079473560676e-05,train accuracy:1.0
epoch:637,loss:0.04343274608254433,train accuracy:0.9921875
epoch:638,loss:0.21525804698467255,train accuracy:0.9921875
epoch:639,loss:0.07236488163471222,train accuracy:0.984375
epoch:640,loss:0.09733662009239197,train accuracy:0.984375
epoch 640,learning rate:0.00030854447063465116
--epoch:640,loss:3.617

epoch:731,loss:0.001699978020042181,train accuracy:1.0
epoch:732,loss:0.3918132781982422,train accuracy:0.9765625
epoch:733,loss:0.14611706137657166,train accuracy:0.9765625
epoch:734,loss:0.009467218071222305,train accuracy:0.9921875
epoch:735,loss:0.0036882830318063498,train accuracy:1.0
epoch 735,learning rate:0.00025490976069630927
--epoch:735,loss:2.095144033432007,test accuracy:0.859375
epoch:736,loss:0.08720417320728302,train accuracy:0.9921875
epoch:737,loss:0.388663649559021,train accuracy:0.984375
epoch:738,loss:1.8626440834168534e-08,train accuracy:1.0
epoch:739,loss:2.600793777673971e-06,train accuracy:1.0
epoch:740,loss:0.0007780772284604609,train accuracy:1.0
epoch 740,learning rate:0.00025236066308934616
--epoch:740,loss:1.9450596570968628,test accuracy:0.8984375
epoch:741,loss:0.2499639391899109,train accuracy:0.9765625
epoch:742,loss:0.035179127007722855,train accuracy:0.9921875
epoch:743,loss:0.04826473444700241,train accuracy:0.9921875
epoch:744,loss:0.09219756722450

epoch:836,loss:0.3257191479206085,train accuracy:0.9765625
epoch:837,loss:0.031184932217001915,train accuracy:0.9921875
epoch:838,loss:0.3203675150871277,train accuracy:0.96875
epoch:839,loss:0.016256945207715034,train accuracy:0.9921875
epoch:840,loss:0.33585402369499207,train accuracy:0.984375
epoch 840,learning rate:0.00020849246173476125
--epoch:840,loss:3.854454278945923,test accuracy:0.828125
epoch:841,loss:3.368236139067449e-05,train accuracy:1.0
epoch:842,loss:2.7640015105134808e-05,train accuracy:1.0
epoch:843,loss:9.313225191043273e-10,train accuracy:1.0
epoch:844,loss:0.01444405410438776,train accuracy:0.9921875
epoch:845,loss:0.24671274423599243,train accuracy:0.9921875
epoch 845,learning rate:0.00020640753711741362
--epoch:845,loss:2.369833469390869,test accuracy:0.8515625
epoch:846,loss:0.04567420855164528,train accuracy:0.9921875
epoch:847,loss:0.05051787942647934,train accuracy:0.984375
epoch:848,loss:0.03864132612943649,train accuracy:0.984375
epoch:849,loss:0.15173999

epoch:941,loss:0.003450306598097086,train accuracy:1.0
epoch:942,loss:0.0005844201077707112,train accuracy:1.0
epoch:943,loss:0.0003197983023710549,train accuracy:1.0
epoch:944,loss:0.00023654369579162449,train accuracy:1.0
epoch:945,loss:0.04183729737997055,train accuracy:0.9921875
epoch 945,learning rate:0.00017052743088958637
--epoch:945,loss:6.325770854949951,test accuracy:0.8046875
epoch:946,loss:0.40898868441581726,train accuracy:0.9921875
epoch:947,loss:0.2204410284757614,train accuracy:0.984375
epoch:948,loss:2.8959813789697364e-05,train accuracy:1.0
epoch:949,loss:0.18559543788433075,train accuracy:0.96875
epoch:950,loss:0.0818776786327362,train accuracy:0.9921875
epoch 950,learning rate:0.0001688221565806905
--epoch:950,loss:4.270170211791992,test accuracy:0.8359375
epoch:951,loss:0.2550261318683624,train accuracy:0.984375
epoch:952,loss:1.728393044686527e-06,train accuracy:1.0
epoch:953,loss:0.105906642973423,train accuracy:0.9921875
epoch:954,loss:0.4919532835483551,train a

In [None]:
writer.close()
writer_.close()