Merge pull request #2 from TreezzZ/master
Fix JMLH bugs
ymcidence committed Apr 3, 2020
Commit 778dd1c (2 parents: 4be0237 + 2cccdc2)
Showing 2 changed files with 3 additions and 3 deletions.
model/jmlh.py (2 additions, 2 deletions)

@@ -17,13 +17,13 @@ def call(self, inputs, training=None, mask=None):
         batch_size = tf.shape(inputs[1])[0]
         fc_1 = self.fc_1(inputs[1])
         eps = tf.ones([batch_size, self.bbn_dim]) / 2.
-        code = binary_activation.binary_activation(fc_1, eps)
+        code, _ = binary_activation.binary_activation(fc_1, eps)
         cls = self.fc_2(code)
         return code, tf.nn.sigmoid(fc_1), cls
 
     def run(self, feat_in):
         batch_size = tf.shape(feat_in)[0]
         fc_1 = self.fc_1(feat_in)
         eps = tf.ones([batch_size, self.bbn_dim]) / 2.
-        code = binary_activation.binary_activation(fc_1, eps)
+        code, _ = binary_activation.binary_activation(fc_1, eps)
         return code
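
Note on this fix: binary_activation returns a tuple, so both call sites now unpack it and keep only the binary code. A minimal sketch of a function with that interface is below, assuming a threshold binarization against eps with a straight-through gradient that returns (code, prob); the repository's actual model/binary_activation.py may differ.

import tensorflow as tf

@tf.custom_gradient
def binary_activation(logits, eps):
    # Hypothetical sketch: threshold each unit's probability against
    # eps (0.5 here gives a deterministic step; uniform noise would
    # give stochastic sampling) and return both the hard code and the
    # underlying probability.
    prob = tf.nn.sigmoid(logits)
    code = tf.cast(prob > eps, tf.float32)

    def grad(d_code, d_prob):
        # Straight-through estimator: treat the hard threshold as
        # identity and route both upstream gradients through the
        # sigmoid; eps receives no gradient.
        sigmoid_grad = prob * (1. - prob)
        return (d_code + d_prob) * sigmoid_grad, tf.zeros_like(eps)

    return (code, prob), grad

With this interface, code, _ = binary_activation(fc_1, eps) discards the probability at the call site, matching the patched lines above.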
train/jmlh_train.py (1 addition, 1 deletion)

@@ -26,7 +26,7 @@ def train_step(model: JMLH, batch_data, opt: tf.optimizers.Optimizer):
     model_input = batch_data
     code, prob, cls_prob = model(model_input, training=True)
 
-    loss = jmlh_loss(prob, cls_prob, label=batch_data[3])
+    loss = jmlh_loss(prob, cls_prob, label=batch_data[2])
 
     gradient = tape.gradient(loss, sources=model.trainable_variables)
     opt.apply_gradients(zip(gradient, model.trainable_variables))
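
Note on this fix: only the index of the label inside the batch tuple changes; the label sits at position 2 of batch_data, not 3. For orientation, a minimal sketch of a loss with the jmlh_loss signature is below, assuming the JMLH-style objective of a classification cross-entropy (treating cls_prob as logits) plus a KL term pulling each bit's Bernoulli probability toward the 0.5 prior; the actual loss in this repository, and the kl_weight hyperparameter, are assumptions here.

import tensorflow as tf

def jmlh_loss(prob, cls_prob, label, kl_weight=0.1):
    # Hypothetical sketch. cls_prob is assumed to be classifier
    # logits and label a one-hot vector.
    cls_loss = tf.keras.losses.categorical_crossentropy(
        label, cls_prob, from_logits=True)
    # KL(Bernoulli(prob) || Bernoulli(0.5)) per bit, summed over bits;
    # the small constant guards the log at prob = 0 or 1.
    kl = prob * tf.math.log(2. * prob + 1e-8) \
        + (1. - prob) * tf.math.log(2. * (1. - prob) + 1e-8)
    return tf.reduce_mean(cls_loss) \
        + kl_weight * tf.reduce_mean(tf.reduce_sum(kl, axis=-1))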
