Skip to content

Commit

Permalink
conf
Browse files Browse the repository at this point in the history
  • Loading branch information
pchavanne committed Dec 10, 2016
1 parent 7614c21 commit 0abdb8c
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 15 deletions.
19 changes: 7 additions & 12 deletions examples/updates_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,12 @@
with open('model.json', 'w') as f:
json.dump(model.to_json(),f)

for layer in layers:
l = getattr(yadll.layers, 'InputLayer')(vars(model.network.layers[0]))


class A(object):
def __init__(self, a, b=1, **kwargs):
self.a=a
self.b=b

def p(self):
print self.a

net.load_params('net_params.yp')
model_2.network.load_params('net_params.yp')
predicted_values_2 = model_2.predict(test_set_x[:30])
print predicted_values_2
print test_set_y[:30]

model_2.data = data

model_2.train()
10 changes: 8 additions & 2 deletions yadll/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
from theano.tensor.signal import pool
from theano.tensor.nnet import conv

import yadll
import logging

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -213,6 +213,8 @@ def __init__(self, incoming, nb_units, W=glorot_uniform, b=constant,
self.b = initializer(b, shape=(self.shape[1],), name='b')
self.params.append(self.b)
self.activation = activation
if isinstance(activation, basestring):
self.activation = getattr(yadll.activation, activation)
self.l1 = l1
self.l2 = l2
if l1 and l1 != 0:
Expand Down Expand Up @@ -283,7 +285,11 @@ def __init__(self, incoming, nb_class, W=constant, activation=softmax, **kwargs)
super(LogisticRegression, self).__init__(incoming, nb_class, W=W,
activation=activation, **kwargs)

def
def to_conf(self):
conf = super(LogisticRegression, self).to_conf()
conf['nb_class'] = conf.pop('nb_units')
return conf


class Dropout(Layer):
"""
Expand Down
2 changes: 1 addition & 1 deletion yadll/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def from_conf(self, conf):
self.network = yadll.network.Network()
self.network.from_conf(_conf['network'])
self.updates = getattr(yadll.updates, _conf['updates'])
self.report = _conf['file']
self.report = _conf['report']
self.file = _conf['file']
pass

Expand Down

0 comments on commit 0abdb8c

Please sign in to comment.