Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NadamAccum does not support Keras v2 #2

Closed
vfdev-5 opened this issue Jun 6, 2017 · 1 comment
Closed

NadamAccum does not support Keras v2 #2

vfdev-5 opened this issue Jun 6, 2017 · 1 comment

Comments

@vfdev-5
Copy link

vfdev-5 commented Jun 6, 2017

Hi,

I tried NadamAccum with Keras v2, and some minor adaptation is needed:

  • Remove useless imports
  • Add names to variables similar to Keras Nadam code
  • Cast accum_iters to float
diff --git a/optimizers.py b/optimizers.py
index 3ab053e..b168255 100644
--- a/optimizers.py
+++ b/optimizers.py
@@ -1,7 +1,5 @@
 from __future__ import absolute_import
 from keras import backend as K
-import numpy as np
-from keras.utils.generic_utils import get_from_module
 from keras.optimizers import Optimizer
 from six.moves import zip
 
@@ -28,26 +26,28 @@ class NadamAccum(Optimizer):
                  epsilon=1e-8, schedule_decay=0.004, accum_iters=1, **kwargs):
         super(NadamAccum, self).__init__(**kwargs)
         self.__dict__.update(locals())
-        self.iterations = K.variable(0.)
-        self.m_schedule = K.variable(1.)
-        self.lr = K.variable(lr)
-        self.beta_1 = K.variable(beta_1)
-        self.beta_2 = K.variable(beta_2)
+        self.iterations = K.variable(0., name='iterations')
+        self.m_schedule = K.variable(1., name='m_schedule')
+        self.lr = K.variable(lr, name='lr')
+        self.beta_1 = K.variable(beta_1, name='beta_1')
+        self.beta_2 = K.variable(beta_2, name='beta_2')
         self.schedule_decay = schedule_decay
-        self.accum_iters = K.variable(accum_iters)
+        self.epsilon = epsilon
+        self.accum_iters = K.variable(accum_iters, name='accum_iters')
 
     def get_updates(self, params, constraints, loss):
         grads = self.get_gradients(loss, params)
         self.updates = [K.update_add(self.iterations, 1)]
 
         t = (self.iterations + 1.)/self.accum_iters
-        accum_switch = K.equal((self.iterations + 1.) % self.accum_iters, 0)
+        accum_switch = K.cast(K.equal((self.iterations + 1.) % self.accum_iters, 0), dtype=K.floatx())
+
         # Due to the recommendations in [2], i.e. warming momentum schedule
         momentum_cache_t = self.beta_1 * (1. - 0.5 * (K.pow(0.96, t * self.schedule_decay)))
         momentum_cache_t_1 = self.beta_1 * (1. - 0.5 * (K.pow(0.96, (t + 1) * self.schedule_decay)))
         m_schedule_new = self.m_schedule * momentum_cache_t
         m_schedule_next = self.m_schedule * momentum_cache_t * momentum_cache_t_1
-        self.updates.append((self.m_schedule, accum_switch*m_schedule_new + (1-accum_switch)*self.m_schedule))
+        self.updates.append((self.m_schedule, accum_switch*m_schedule_new + (1. - accum_switch)*self.m_schedule))
 
         shapes = [x.shape for x in K.batch_get_value(params)]
         ms = [K.zeros(shape) for shape in shapes]
@@ -67,9 +67,9 @@ class NadamAccum(Optimizer):
             v_t_prime = v_t / (1. - K.pow(self.beta_2, t))
             m_t_bar = (1. - momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime
 
-            self.updates.append(K.update(m, (1-accum_switch)*m + accum_switch*m_t))
-            self.updates.append(K.update(v, (1-accum_switch)*v + accum_switch*v_t))
-            self.updates.append(K.update(ga, (1-accum_switch)*(ga + gp)))
+            self.updates.append(K.update(m, (1. - accum_switch)*m + accum_switch*m_t))
+            self.updates.append(K.update(v, (1. - accum_switch)*v + accum_switch*v_t))
+            self.updates.append(K.update(ga, (1. - accum_switch)*(ga + gp)))
 
             p_t = p - self.lr * m_t_bar / (K.sqrt(v_t_prime) + self.epsilon)
             new_p = p_t
@@ -89,4 +89,4 @@ class NadamAccum(Optimizer):
                   'schedule_decay': self.schedule_decay,
                   'accum_iters': self.accum_iters}
         base_config = super(NadamAccum, self).get_config()
-        return dict(list(base_config.items()) + list(config.items()))
\ No newline at end of file
+        return dict(list(base_config.items()) + list(config.items()))

HTH

@the-moliver
Copy link
Owner

fixed thanks

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants