Skip to content

Commit

Permalink
fix the problem that network config with the rectifier activation fun…
Browse files Browse the repository at this point in the history
…ction cannot be saved
  • Loading branch information
yajiemiao committed Mar 11, 2015
1 parent 600cbe5 commit 6e6e840
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 1 deletion.
1 change: 1 addition & 0 deletions cmds/run_Extract_Feats.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
numpy_rng = numpy.random.RandomState(89677)
theano_rng = RandomStreams(numpy_rng.randint(2 ** 30))
cfg = cPickle.load(open(nnet_cfg,'r'))
cfg.init_activation()
model = None
if cfg.model_type == 'DNN':
model = DNN(numpy_rng=numpy_rng, theano_rng = theano_rng, cfg = cfg)
Expand Down
1 change: 1 addition & 0 deletions cmds2/run_FeatExt_Kaldi.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@

# load network configuration
cfg = cPickle.load(open(nnet_cfg,'r'))
cfg.init_activation()

# set up the model with model config
log('> ... setting up the model and loading parameters')
Expand Down
3 changes: 3 additions & 0 deletions io_func/model_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def _cfg2file(cfg, filename='cfg.out'):
cfg.lrate = None
cfg.train_sets = None; cfg.train_xy = None; cfg.train_x = None; cfg.train_y = None
cfg.valid_sets = None; cfg.valid_xy = None; cfg.valid_x = None; cfg.valid_y = None
cfg.activation = None # saving the rectifier function causes errors; thus we don't save the activation function
# the activation function is initialized from the activation text ("sigmoid") when the network
# configuration is loaded
with open(filename, "wb") as output:
cPickle.dump(cfg, output, cPickle.HIGHEST_PROTOCOL)

Expand Down
4 changes: 4 additions & 0 deletions utils/network_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ def init_data_reading_test(self, data_spec):
dataset, dataset_args = read_data_args(data_spec)
self.test_sets, self.test_xy, self.test_x, self.test_y = read_dataset(dataset, dataset_args)

# initialize the activation function
def init_activation(self):
self.activation = parse_activation(self.activation_text)

def parse_config_common(self, arguments):
# parse batch_size, momentum and learning rate
if arguments.has_key('batch_size'):
Expand Down
6 changes: 5 additions & 1 deletion utils/rbm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,11 @@ def __init__(self):
def init_data_reading(self, train_data_spec):
train_dataset, train_dataset_args = read_data_args(train_data_spec)
self.train_sets, self.train_xy, self.train_x, self.train_y = read_dataset(train_dataset, train_dataset_args)


# initialize the activation function
def init_activation(self):
self.activation = parse_activation(self.activation_text)

# parse the arguments to get the values for various variables
def parse_config_common(self, arguments):
if arguments.has_key('gbrbm_learning_rate'):
Expand Down
4 changes: 4 additions & 0 deletions utils/sda_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ def init_data_reading(self, train_data_spec):
train_dataset, train_dataset_args = read_data_args(train_data_spec)
self.train_sets, self.train_xy, self.train_x, self.train_y = read_dataset(train_dataset, train_dataset_args)

# initialize the activation function
def init_activation(self):
self.activation = parse_activation(self.activation_text)

# parse the arguments to get the values for various variables
def parse_config_common(self, arguments):
if arguments.has_key('corruption_level'):
Expand Down

0 comments on commit 6e6e840

Please sign in to comment.