In [None]:
import sys
import numpy as np
caffe_root = '/opt/caffe/'  # this file should be run from {caffe_root}/examples (otherwise change this line)
sys.path.insert(0, caffe_root + 'python')
import caffe

# Run Caffe on the GPU and select GPU device 0.
caffe.set_mode_gpu()
caffe.set_device(0)

# ---- Configuration -------------------------------------------------------
model_def = './deploy_8inceptionv2_voting.prototxt'  # model definition
model_weights = './8_inceptionv2_ave.caffemodel'  # model parameters (trained weights)
mean = [117, 117, 117]  # mean used when the model was trained
scale = 0.0078125  # scale used when the model was trained (= 1/128)
# Backward-compatible alias: this value used to be defined only under the
# misspelled name `sacle`, while the extraction loop reads `scale`.
sacle = scale
height = 224  # model input image height
width = 224  # model input image width
store_size = 1000  # number of samples stored per saved batch
label_size = 17  # length of a label vector
teacher_logits_size = 61  # length of the model output
output_layer_name = 'teacher_logits'  # name of the model output layer
filename = './train.txt'  # training-set file name

In [None]:
# Load the teacher model: network structure from `model_def`, trained
# weights from `model_weights`, in test mode (e.g. no dropout).
net = caffe.Net(model_def, model_weights, caffe.TEST)

# Resize the input blob to (batch size 1, 3 channels, height, width).
net.blobs['data'].reshape(1, 3, height, width)

In [None]:
# Caffe image preprocessing, configured for the shape of the net's 'data' blob.
transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})

# Layout: move the channel axis first (H, W, C) -> (C, H, W), then swap
# the channels from RGB to BGR.
transformer.set_transpose('data', (2, 0, 1))
transformer.set_channel_swap('data', (2, 1, 0))

# Values: multiply raw pixel values by 255, then subtract the per-channel mean.
transformer.set_raw_scale('data', 255)
mu = np.array(mean)
transformer.set_mean('data', mu)

In [None]:
# Read the training set: one sample per line, whitespace-separated fields.

with open(filename) as f:
    content = f.readlines()

# Tokenise every line. Blank lines are dropped because an empty token list
# has no image path and would break any later `l[0]` lookup.
lines = [x.split() for x in content if x.strip()]

num_samples = len(lines)

# Number of HDF5 files needed to hold all samples at `store_size` samples per
# file (ceiling division; `//` keeps the result an int on Python 2 and 3).
total_h5_file_size = (num_samples + store_size - 1) // store_size

# Number of samples in the last file. When the sample count is an exact
# multiple of `store_size` the last file is a full one, not an empty one.
if num_samples:
    tail_len = num_samples - (total_h5_file_size - 1) * store_size
else:
    tail_len = 0

In [None]:
# Walk through the training set and save, in chunks of `store_size` samples,
# the teacher model's output, the ground-truth label and the preprocessed
# input image of every sample. Each chunk goes to its own HDF5 file so the
# whole data set never has to fit in memory at once.

import os

import h5py
from tqdm import tqdm

output_dir = './distilled_data'
# h5py does not create missing parent directories.
if not os.path.isdir(output_dir):
    os.makedirs(output_dir)

# Buffers holding one chunk; they are reused after every flush.
image = np.zeros((store_size, 3, height, width), dtype='f4')
teacher_logits = np.zeros((store_size, teacher_logits_size), dtype='f4')
ground_truth = np.zeros((store_size, label_size), dtype='f4')


def write_h5_chunk(chunk_index, count):
    """Write the first `count` buffered samples to train_<chunk_index>.h5.

    Args:
        chunk_index: 1-based number of the chunk, used in the file name.
        count: how many leading rows of the buffers hold valid samples.
    """
    path = os.path.join(output_dir, 'train_' + str(chunk_index) + '.h5')
    with h5py.File(path, 'w') as H:
        # NOTE(review): the dataset names presumably have to match the blob
        # names the student network reads -- verify against its prototxt.
        H.create_dataset('image', data=image[:count])
        H.create_dataset('teacher_logits', data=teacher_logits[:count])
        H.create_dataset('ground_truth', data=ground_truth[:count])  # also save the ground truth
    print(path)


num_h5_files = 0  # chunks written so far
filled = 0        # valid rows currently in the buffers

# `lines` comes from the cell that reads the training set: one token list per
# sample, image path first, label values after it.
for l in tqdm(lines):
    if not l:
        continue  # blank line in the list file

    img = caffe.io.load_image(l[0])
    # Preprocess as configured on the transformer, then apply the
    # training-time input scale.
    img = transformer.preprocess('data', img) * scale

    image[filled] = img
    ground_truth[filled] = [float(v) for v in l[1:]]

    net.blobs['data'].data[...] = img
    output = net.forward()
    teacher_logits[filled] = output[output_layer_name][0]

    filled += 1
    if filled == store_size:
        # Buffer full: flush it and start refilling from row 0.
        num_h5_files += 1
        write_h5_chunk(num_h5_files, filled)
        filled = 0

# Handle the tail: whatever is left after the last full chunk.
if filled:
    num_h5_files += 1
    write_h5_chunk(num_h5_files, filled)

# Write a list of the h5 files so Caffe can read them in during training.
# NOTE(review): the listed paths are absolute while the files are written
# relative to the working directory -- this assumes the notebook runs from
# /train/execute/distillation; confirm.
with open(os.path.join(output_dir, 'train_h5_list.txt'), 'w') as L:
    for i in range(1, num_h5_files + 1):
        L.write('/train/execute/distillation/distilled_data/train_' + str(i) + '.h5' + '\n')