## 准备数据
- 基本数据-原图 data\stone\raw\image\\*.png
- 基本数据-标签 data\stone\raw\label\\*.png
- 基本数据-测试 data\stone\raw\test\\*.png


## 路径设置


In [None]:
import os
import cv2

# Every path below is anchored at the directory the notebook is started from.
abspath = os.path.abspath('.')

# Raw inputs: original images, their label masks, and the raw test images.
raw_image_path = os.path.join(abspath, 'data', 'stone', 'raw', 'image')
raw_label_path = os.path.join(abspath, 'data', 'stone', 'raw', 'label')
raw_test_path  = os.path.join(abspath, 'data', 'stone', 'raw', 'test')

# Grayscale copies of the raw images and labels.
gray_path       = os.path.join(abspath, 'data', 'stone', 'gray')
gray_image_path = os.path.join(abspath, 'data', 'stone', 'gray', 'image')
gray_label_path = os.path.join(abspath, 'data', 'stone', 'gray', 'label')

# Training set, split into image/ and label/ sub-directories.
train_path = os.path.join(abspath, 'data', 'stone', 'train')
train_image_path = os.path.join(train_path, 'image')
train_label_path = os.path.join(train_path, 'label')

# Test set, split into image/ and label/ sub-directories.
test_image_path = os.path.join(abspath, 'data', 'stone', 'test', 'image')
test_label_path = os.path.join(abspath, 'data', 'stone', 'test', 'label')

# Directories that are created below when missing.
# Note that raw_test_path is not listed, so it is never created by this cell.
all_path = [raw_image_path, raw_label_path, 
            gray_path, gray_image_path, gray_label_path, 
            train_path, train_image_path, train_label_path,
            test_image_path, test_label_path]

# exist_ok=True keeps the cell safe to re-run and avoids the
# check-then-create race of testing os.path.exists() first.
for path in all_path:
    os.makedirs(path, exist_ok=True)


## 转换为8位灰度图

In [None]:
def color_to_gray(path_src, path_dst):
    """Write a grayscale copy of every image under ``path_src`` into ``path_dst``.

    ``path_src`` is walked recursively, but each output is written directly
    into ``path_dst`` under the original file name, so files with the same
    name in different sub-directories overwrite one another.

    Files that OpenCV cannot decode (``cv2.imread`` returns ``None``) are
    skipped with a message instead of being handed to ``cv2.imwrite``, which
    would raise on an empty image.

    Args:
        path_src: Directory containing the source images.
        path_dst: Directory receiving the grayscale images.

    Returns:
        None.
    """
    for root, _dirs, files in os.walk(path_src):
        for file_name in files:
            file_path = os.path.join(root, file_name)
            img = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
            if img is None:
                # Not an image, or unreadable: nothing to convert.
                print('color_to_gray: skipped unreadable file ' + file_path)
                continue
            new_file_path = os.path.join(path_dst, file_name)
            # imwrite reports failure through its return value, not an exception.
            if not cv2.imwrite(new_file_path, img):
                print('color_to_gray: failed to write ' + new_file_path)

# Convert each raw directory into its grayscale counterpart:
# raw images -> gray images, raw labels -> gray labels,
# raw test images -> test images.
for src_dir, dst_dir in (
    (raw_image_path, gray_image_path),
    (raw_label_path, gray_label_path),
    (raw_test_path, test_image_path),
):
    color_to_gray(src_dir, dst_dir)

## 数据增强，生成数据存放在 train_path

In [None]:
from data import *

# Augmentation settings handed to trainGenerator (imported from data.py).
data_gen_args = dict(rotation_range=0.2,
                    width_shift_range=0.05,
                    height_shift_range=0.05,
                    shear_range=0.05,
                    zoom_range=0.05,
                    horizontal_flip=True,
                    fill_mode='nearest')
# Presumably each batch drawn from the generator is written to train_path via
# save_to_dir, and 20 is the batch size -- confirm against data.trainGenerator.
myGenerator = trainGenerator(20, gray_path, 'image', 'label', data_gen_args, save_to_dir=train_path)
num_batch = 3
# Draw exactly num_batch batches. zip() consults range() first, so the
# generator is not advanced once the count is reached. The previous
# enumerate()/break loop fetched one extra batch before it broke out.
for _ in zip(range(num_batch), myGenerator):
    pass

## 调整训练集的位置

In [None]:
import shutil

# Sort the generated files sitting directly in train_path into the image/
# and label/ sub-directories according to their file-name prefix.
# Entries with neither prefix are left where they are.
prefix_to_dir = {
    'image_': train_image_path,
    'mask_': train_label_path,
}

for file_name in os.listdir(train_path):
    for prefix, target_dir in prefix_to_dir.items():
        if file_name.startswith(prefix):
            shutil.move(os.path.join(train_path, file_name), target_dir)
            break


## 开始训练

In [None]:
from model import *
from data import *


In [None]:
# Augmentation settings applied on the fly while training.
data_gen_args = {
    'rotation_range': 0.2,
    'width_shift_range': 0.05,
    'height_shift_range': 0.05,
    'shear_range': 0.05,
    'zoom_range': 0.05,
    'horizontal_flip': True,
    'fill_mode': 'nearest',
}

# Training batches come from train_path/image and train_path/label.
# save_to_dir=None is passed, so no output directory is given for the batches.
myGene = trainGenerator(
    2,
    train_path,
    'image',
    'label',
    data_gen_args,
    save_to_dir=None,
)

model = unet()

# Checkpoint callback: monitors 'loss' and, with save_best_only=True,
# writes to unet_membrane.hdf5.
model_checkpoint = ModelCheckpoint(
    'unet_membrane.hdf5',
    monitor='loss',
    verbose=1,
    save_best_only=True,
)


In [None]:
# Train for 5 epochs of 2000 generator steps each, with the checkpoint
# callback attached.
# NOTE(review): fit_generator is deprecated in newer TensorFlow/Keras
# releases in favour of fit -- confirm the installed version still has it.
model.fit_generator(
    myGene,
    steps_per_epoch=2000,
    epochs=5,
    callbacks=[model_checkpoint],
)

## Train with npy file

In [None]:
#imgs_train,imgs_mask_train = geneTrainNpy("data/membrane/train/aug/","data/membrane/train/aug/")
#model.fit(imgs_train, imgs_mask_train, batch_size=2, nb_epoch=10, verbose=1,validation_split=0.2, shuffle=True, callbacks=[model_checkpoint])

## Test your model and save the predicted results

In [None]:
# Predict on the test images this notebook prepared in test_image_path and
# store the predictions in test_label_path. The previous hard-coded
# "data/membrane/test" directory is not produced anywhere in this notebook.
# NOTE(review): confirm that data.testGenerator accepts the file names found
# in test_image_path, and that the step count of 30 below matches the number
# of test images.
testGene = testGenerator(test_image_path)
model = unet()
model.load_weights("unet_membrane.hdf5")
results = model.predict_generator(testGene,30,verbose=1)
saveResult(test_label_path,results)