
Commit

trivial
yanzhicong committed Sep 14, 2018
1 parent 52bee4c commit c3005a1
Showing 5 changed files with 190 additions and 14 deletions.
85 changes: 85 additions & 0 deletions cfgs/cla/mnist_test.json
@@ -0,0 +1,85 @@
{
"config name" : "mnist_classification",

"dataset" : "mnist",
"dataset params" : {
"semi-supervised" : true,
"labelled indices filepath" : "./mnist_temp_data_dir/method_1_9.pkl",
"output shape" : [28, 28, 1],
"batch_size" : 128
},

"assets dir" : "assets/mnist/tests",

"model" : "classification",
"model params" : {
"name" : "mnist",

"input shape" : [28, 28, 1],
"nb classes" : 10,

"optimizer" : "adam",
"optimizer params" : {
"lr" : 0.001,
"lr scheme" : "exponential",
"lr params" : {
"decay_steps" : 10000,
"decay_rate" : 0.1
},
"beta1" : 0.5,
"beta2" : 0.9
},
"classification loss" : "cross entropy",

"summary" : true,

"classifier" : "classifier",
"classifier params" : {
"batch_norm" : "fused_batch_norm",

"including_conv" : true,
"conv_nb_blocks" : 3,
"conv_nb_layers" : [2, 2, 2],
"conv_nb_filters" : [32, 64, 128],
"conv_nb_ksize" : [3, 3, 3],
"no_maxpooling" : true,

"including_top" : true,
"fc_nb_nodes" : [600, 600],

"output_dims" : 10,
"output_activation" : "none",

"debug" : true
}
},

"trainer" : "supervised",
"trainer params" : {

"summary hyperparams string" : "method_1_9",

"continue train" : false,
"multi thread" : true,
"batch_size" : 32,
"train steps" : 20000,
"summary steps" : 1000,
"log steps" : 100,
"save checkpoint steps" : 10000,

"debug" : true,

"validators" : [
{
"validator" : "dataset_validator",
"validate steps" : 1000,
"has summary" : true,
"validator params" : {
"metric" : "accuracy",
"metric type" : "top1"
}
}
]
}
}
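
Note: the new "labelled indices filepath" entry points the dataset at a pre-built pickle of labelled train-set indices (here ./mnist_temp_data_dir/method_1_9.pkl). That file is not part of this commit; the following is a minimal sketch of how such a pickle could be produced, assuming it simply stores one flat array of train-set indices, as the loading code in base_simple_dataset.py below expects (the helper name and sampling scheme are hypothetical):

import pickle
import numpy as np

def save_labelled_indices(y_train, nb_per_class, out_path):
    # hypothetical helper: pick nb_per_class labelled examples per class and
    # dump their train-set indices as one flat array
    indices = []
    for c in np.unique(y_train):
        class_indices = np.where(y_train == c)[0]
        indices.append(np.random.choice(class_indices, size=nb_per_class, replace=False))
    pickle.dump(np.concatenate(indices), open(out_path, 'wb'))

# e.g. save_labelled_indices(y_train, 100, './mnist_temp_data_dir/method_1_9.pkl')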

14 changes: 9 additions & 5 deletions dataset/base_simple_dataset.py
@@ -61,7 +61,6 @@ def __init__(self, config):
self.y_test = None
self.nb_classes = None


def build_dataset(self):
assert(self.x_train is not None and self.y_train is not None)
assert(self.x_test is not None and self.y_test is not None)
@@ -77,8 +76,7 @@ def build_dataset(self):
os.makedirs(self.extra_file_path)

# if training is semi-supervised, prepare the labelled train set indices
self.nb_labelled_images_per_class = self.config.get('nb_labelled_images_per_class', 100)
self.labelled_image_indices = self._get_labelled_image_indices(self.nb_labelled_images_per_class)
self.labelled_image_indices = self._get_labelled_image_indices()

# unlabelled train set
self.x_train_u = self.x_train
@@ -93,8 +91,14 @@ def build_dataset(self):
self.x_train_u = self.x_train


def _get_labelled_image_indices(self, nb_images_per_class):
pickle_filepath = os.path.join(self.extra_file_path, 'labelled_image_indices_%d.pkl'%nb_images_per_class)
def _get_labelled_image_indices(self):

if 'labelled indices filepath' in self.config:
pickle_filepath = self.config['labelled indices filepath']
else:
nb_images_per_class = self.config.get('nb_labelled_images_per_class', 100)
pickle_filepath = os.path.join(self.extra_file_path, 'labelled_image_indices_%d.pkl'%nb_images_per_class)

if os.path.exists(pickle_filepath):
return pickle.load(open(pickle_filepath, 'rb'))
else:
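Note: with this change a semi-supervised dataset config can either reuse a saved split or fall back to the old per-class sampling. A sketch of the two alternative parameter styles (keys taken from the config above and from the lookup order in _get_labelled_image_indices):

# explicit, reproducible split: load indices from an existing pickle
params_a = {"semi-supervised": True,
            "labelled indices filepath": "./mnist_temp_data_dir/method_1_9.pkl"}

# implicit split: sample and cache nb_labelled_images_per_class indices per class (default 100)
params_b = {"semi-supervised": True,
            "nb_labelled_images_per_class": 100}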
84 changes: 84 additions & 0 deletions train_batch.py
@@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
# MIT License
#
# Copyright (c) 2018 ZhicongYan
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ==============================================================================

import os
import sys
import argparse
import time
from datetime import datetime
from shutil import copyfile

import tensorflow as tf

sys.path.append('./')
sys.path.append('./lib')
sys.path.append('../')

from cfgs.networkconfig import get_config
from dataset.dataset import get_dataset
from model.model import get_model
from trainer.trainer import get_trainer

parser = argparse.ArgumentParser(description='')
parser.add_argument('--gpu', type=str, default='0')
parser.add_argument('--config_file', type=str, default='cvae1') # target config file, stored in ./cfgs
parser.add_argument('--disp_config', type=bool, default=False) # if there is an error in the config file, set this to True to print the config with line numbers

args = parser.parse_args()

if __name__ == '__main__':
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu
tf.reset_default_graph()

# load config file
config = get_config(args.config_file, args.disp_config)

# make the assets directory and copy the config file to it
# so that, to reproduce the result in the assets dir, you only need to
# copy the saved config json back to the ./cfgs folder and run python3 train.py --config=(config_file)
if not os.path.exists(config['assets dir']):
os.makedirs(config['assets dir'])
cfg_filename = datetime.now().strftime('config_file_%y-%m-%d_%H-%M-%S.json')
copyfile(os.path.join('./cfgs', args.config_file + '.json'),
os.path.join(config['assets dir'], cfg_filename))

# prepare dataset
dataset = get_dataset(config['dataset'], config['dataset params'])

tfconfig = tf.ConfigProto()
tfconfig.gpu_options.allow_growth = True

with tf.Session(config=tfconfig) as sess:

# build model
config['model params']['assets dir'] = config['assets dir']
model = get_model(config['model'], config['model params'])

# start training
config['trainer params']['assets dir'] = config['assets dir']
trainer = get_trainer(config['trainer'], config['trainer params'], model, sess)

trainer.train(sess, dataset, model)
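
Note: assuming the new config above is kept at cfgs/cla/mnist_test.json, the script would presumably be launched with the config name given relative to ./cfgs (flag names are taken from the argparse definition above; exactly how get_config resolves the name is an assumption):

python3 train_batch.py --config_file cla/mnist_test --gpu 0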

13 changes: 8 additions & 5 deletions validator/base_validator.py
@@ -58,28 +58,31 @@ def validate(self, model, dataset, sess, step):
def parallel_data_reading(self, dataset, indices, phase, method, buffer_depth, nb_threads=4):

self.t_should_stop = False

data_queue = queue.Queue(maxsize=buffer_depth)

def read_data_inner_loop(dataset, data_inner_queue, indices, t_ind, nb_threads):
def read_data_inner_loop(dataset, data_queue, indices, t_ind, nb_threads):
for i, ind in enumerate(indices):
if i % nb_threads == t_ind:
# read img and label by its index
img, label = dataset.read_image_by_index(ind, 'val', 'supervised')
if isinstance(img, list) and isinstance(label, list):
for _img, _label in zip(img, label):
data_inner_queue.put((img, label))
data_queue.put((_img, _label))
elif img is not None:
data_inner_queue.put((img, label))
data_queue.put((img, label))


def read_data_loop(indices, dataset, data_inner_queue, nb_threads):
def read_data_loop(indices, dataset, data_queue, nb_threads):
threads = [threading.Thread(target=read_data_inner_loop,
args=(dataset, data_inner_queue, indices, t_ind, nb_threads)) for t_ind in range(nb_threads)]
args=(dataset, data_queue, indices, t_ind, nb_threads)) for t_ind in range(nb_threads)]
for t in threads:
t.start()
for t in threads:
t.join()
self.t_should_stop = True


t = threading.Thread(target=read_data_loop, args=(indices, dataset, data_queue, nb_threads))
t.start()

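Note: parallel_data_reading starts nb_threads reader threads that take every nb_threads-th index and feed a bounded queue, and it flips self.t_should_stop only after all readers have joined; the caller is expected to keep draining the queue until the flag is set and the queue is empty (as dataset_validator.py does below). A self-contained sketch of the same producer/consumer shape, with hypothetical names, just to illustrate the stop condition:

import queue
import threading
import time

def demo_producer_consumer(items, nb_threads=4, depth=8):
    should_stop = [False]                # shared flag, flipped by the outer thread
    q = queue.Queue(maxsize=depth)

    def inner(t_ind):
        # each reader thread takes every nb_threads-th item, like read_data_inner_loop
        for i, item in enumerate(items):
            if i % nb_threads == t_ind:
                q.put(item)

    def outer():
        # like read_data_loop: start the readers, wait for all of them, then signal the consumer
        threads = [threading.Thread(target=inner, args=(t_ind,)) for t_ind in range(nb_threads)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        should_stop[0] = True

    threading.Thread(target=outer).start()

    # consumer: keep draining until the flag is set AND the queue is empty
    out = []
    while not should_stop[0] or not q.empty():
        if not q.empty():
            out.append(q.get())
        else:
            time.sleep(0.01)
    return out

# e.g. demo_producer_consumer(list(range(100))) returns all 100 items (order not guaranteed)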
8 changes: 4 additions & 4 deletions validator/dataset_validator.py
@@ -100,19 +100,18 @@ def validate(self, model, dataset, sess, step):
nb_samples = np.minimum(len(indices), self.nb_samples)
indices = np.random.choice(indices, size=nb_samples, replace=False)


self.t_should_stop = False
t, data_queue = self.parallel_data_reading(dataset, indices, 'val', 'supervised', self.batch_size(self.buffer_depth))
t, data_queue = self.parallel_data_reading(dataset, indices, 'val', 'supervised', self.batch_size*self.buffer_depth)

batch_x = []
batch_y = []

while not self.t_should_stop:
while not self.t_should_stop or not data_queue.empty():
if not data_queue.empty():
img, label = data_queue.get()
batch_x.append(img)
batch_y.append(label)
else:
time.sleep(1)

if len(batch_x) == self.batch_size:
batch_p = model.predict(sess, np.array(batch_x))
@@ -128,6 +127,7 @@ def validate(self, model, dataset, sess, step):

t.join()


label_list = np.concatenate(label_list, axis=0)
pred_list = np.concatenate(pred_list, axis=0)

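Note: the code after this hunk (not shown) presumably turns pred_list and label_list into the "top1" accuracy requested by the validator params; a minimal sketch of that computation, assuming pred_list holds per-class scores and label_list either integer or one-hot labels (the function name is hypothetical):

import numpy as np

def top1_accuracy(pred_list, label_list):
    # pred_list: (N, nb_classes) scores; label_list: (N,) integer labels or (N, nb_classes) one-hot
    pred_cls = np.argmax(pred_list, axis=1)
    true_cls = label_list if label_list.ndim == 1 else np.argmax(label_list, axis=1)
    return float(np.mean(pred_cls == true_cls))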
