Add SPSA optimizer #653

Merged
merged 39 commits on Jan 20, 2022
Conversation

lockwo (Contributor) commented Jan 5, 2022

SPSA is a common optimizer that plays a prominent role in many quantum optimization procedures. It already exists in other popular packages such as qiskit. Although there are external implementations (e.g. noisyopt), providing a built-in version offers a number of advantages (not the least of which is ease of use). The code is short and not complex (and shares the same format as the rotosolve optimizer). I would also like to answer two questions which may arise (and which I had when implementing):

Why SPSA?
There are a number of black-box optimization algorithms; however, SPSA is probably one of the most common in the field of quantum optimization. It is a go-to optimizer for VQE (see https://www.nature.com/articles/nature23879) as it can be somewhat robust to noise and requires only O(1) forward passes per gradient estimate, independent of the number of parameters (unlike parameter shift or other gradient approaches). Additionally, the most common black-box optimization package (scipy optimize minimize) does not include it.
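As a rough illustration of the O(1) cost: a single SPSA gradient estimate only ever needs two evaluations of the objective, regardless of the number of parameters. A minimal sketch (not the exact code in this PR):

```python
import tensorflow as tf

def spsa_gradient_estimate(func, x, c=0.1):
    """One SPSA gradient estimate: two evaluations of `func`, any dimension."""
    # Random Rademacher (+1/-1) perturbation direction.
    delta = tf.sign(tf.random.uniform(tf.shape(x), minval=-1.0, maxval=1.0))
    # Simultaneous perturbation: two forward passes no matter how long x is.
    df = func(x + c * delta) - func(x - c * delta)
    return df / (2.0 * c * delta)
```

Compare with parameter shift, where the number of circuit evaluations grows with the number of parameters.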

Why as an optimizer?
Theoretically, the gradient approach SPSA uses could be implemented as a differentiator (i.e. do the gradient estimation as a differentiation approach). However, I think there are two primary drawbacks to this. One, there is well-documented knowledge about the hyperparameters to be used in SPSA (see https://ieeexplore.ieee.org/document/705889). If the gradient update were implemented as a differentiator, these would have to be incorporated into a TF optimizer somehow, and that would likely not be a friendly interface (adjusting how the differentiation is done, e.g. the perturbation parameter, during optimization from outside the class doesn't sound good). Two (and this may not actually be true; it is based purely on intuition and some toy examples), the noisiness of the gradient estimate (which is fine on average) may be problematic for hybrid systems (e.g. when using the gradients and backpropagating through other layers), so offering it as an optimizer more naturally encourages non-hybrid use.
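For reference, the well-documented hyperparameter guidance amounts to simple power-law gain schedules. A sketch using the conventional symbols from Spall's paper (a, c, A, alpha, gamma), which are not necessarily the parameter names used in this PR:

```python
def spsa_gains(k, a=0.2, c=0.1, A=10.0, alpha=0.602, gamma=0.101):
    """Standard SPSA gain sequences at iteration k (Spall's recommended exponents)."""
    a_k = a / (k + 1.0 + A) ** alpha  # step size for the parameter update
    c_k = c / (k + 1.0) ** gamma      # size of the simultaneous perturbation
    return a_k, c_k
```

Baking these schedules into a dedicated optimizer keeps them adjustable per iteration, without having to tweak a differentiator from outside its class.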

I recognize this is an unsolicited PR, so if it isn't a desired inclusion that's fine, but I have been using SPSA extensively in qiskit and I think TFQ would benefit from its inclusion. This PR is a simple implementation I drafted this afternoon, but there is room for advancement (e.g. second-order approximations).

lockwo (Contributor Author) commented Jan 5, 2022

I see there are some formatting problems, I will fix those shortly.

MichaelBroughton (Collaborator) left a comment

Great first pass! I think with just a few tweaks here and there, along with some more baseline test coverage to ensure the namedtuple results are behaving correctly, we should be good to merge.

Comment on lines 148 to 154
expectation_value_function: A Python callable that accepts
a point as a real `tf.Tensor` and returns a `tf.Tensor`s
of real dtype containing the value of the function.
The function to be minimized. The input is of shape `[n]`,
where `n` is the size of the trainable parameters.
The return value is a real `tf.Tensor` Scalar (matching shape
`[1]`). This must be a linear combination of quantum
Collaborator:

Can we tighten up the wording here to something like:

"Python callable that accepts a real valued tf.Tensor with shape [n] where n is the number of function parameters,....."

Contributor Author:

Changed

tensorflow_quantum/python/optimizers/spsa_minimizer.py (outdated; thread resolved)
Comment on lines 119 to 120
a=1.0,
c=1.0,
Collaborator:

Instead of calling these "a" and "c", would we be able to use something like "lr" or "lr_scaling", just to have slightly more descriptive names? We also might want to throw in a seed parameter for this optimizer, since there is some randomness involved.

Contributor Author:

Changed to lr and perturb, respectively. Added seed parameter.
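For illustration, the updated call signature might look roughly like the following; this is a hypothetical sketch, and the exact parameter list and defaults in the PR may differ:

```python
def minimize(expectation_value_function,
             initial_position,
             tolerance=1e-5,
             max_iterations=200,
             lr=1.0,       # learning-rate scale (previously `a`)
             perturb=1.0,  # perturbation size (previously `c`)
             blocking=False,
             seed=None,    # optional seed for the random perturbation directions
             name=None):
    ...
```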

Comment on lines 136 to 145
>>> n = 10 # Number of sinusoids
>>> coefficient = tf.random.uniform(shape=[n])
>>> min_value = -tf.math.reduce_sum(tf.abs(coefficient))
>>> func = lambda x: tf.math.reduce_sum(tf.sin(x) * coefficient)
>>> # Optimize the function with SPSA, start with random parameters
>>> result = tfq.optimizers.SPSA_minimize(func, np.random.random(n))
>>> result.converged
tf.Tensor(True, shape=(), dtype=bool)
>>> result.objective_value
tf.Tensor(-4.7045116, shape=(), dtype=float32)
Collaborator:

Can we add a unit test that covers the use case shown here ?

Contributor Author:

I modified the example and added a test case in the same format.
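For reference, a test in that same format might look roughly like this (a sketch that assumes the public entry point matches the docstring example above; the actual test added in the PR may differ):

```python
import numpy as np
import tensorflow as tf
import tensorflow_quantum as tfq


class SumOfSinusoidsTest(tf.test.TestCase):

    def test_sum_of_sinusoids(self):
        # Mirrors the docstring example: minimize a random sum of sinusoids.
        n = 10
        coefficient = tf.random.uniform(shape=[n])
        min_value = -tf.math.reduce_sum(tf.abs(coefficient))
        func = lambda x: tf.math.reduce_sum(tf.sin(x) * coefficient)
        result = tfq.optimizers.SPSA_minimize(func, np.random.random(n))
        self.assertTrue(result.converged.numpy())
        # Should land reasonably close to the analytic minimum -sum(|coefficient|).
        self.assertLessEqual(result.objective_value.numpy(),
                             0.9 * min_value.numpy())
```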

tensorflow_quantum/python/optimizers/spsa_minimizer.py (outdated; thread resolved)
tensorflow_quantum/python/optimizers/spsa_minimizer.py (outdated; thread resolved)
class SPSAMinimizerTest(tf.test.TestCase, parameterized.TestCase):
    """Tests for the SPSA optimization algorithm."""

    def test_nonlinear_function_optimization(self):
Collaborator:

Could we add three (or more) tests with simple learning functions (maybe more sinusoids?) to verify a few other key baseline functionalities:

  1. If it fails to converge, does the number of function iterations in the result line up with the supplied arguments?
  2. If it fails to converge due to tolerance misses, is that in fact reflected in a case where convergence to the tolerance is very close?
  3. Do blocking and allow_increase interact correctly with one another and produce expected results when they are enabled/disabled and changed?

Contributor Author:

I added tests for 1 and 3, but I'm not sure what you mean by 2. Do you mean it fails to converge because it evaluates twice and the difference is smaller than the tolerance?
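Roughly what I had in mind for test 1, as a method inside SPSAMinimizerTest (a sketch; it assumes the result namedtuple exposes `converged` and `num_iterations` fields and that the minimizer takes a `max_iterations` argument, so the actual names may differ):

```python
    def test_iteration_count_when_not_converged(self):
        # A target far from the start, so SPSA cannot converge in 3 iterations.
        func = lambda x: tf.math.reduce_sum(tf.square(x - 100.0))
        result = tfq.optimizers.SPSA_minimize(func,
                                              tf.random.uniform(shape=[5]),
                                              max_iterations=3)
        self.assertFalse(result.converged.numpy())
        # The reported iteration count should line up with the supplied cap.
        self.assertEqual(result.num_iterations.numpy(), 3)
```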

        self.assertAlmostEqual(func(result.position).numpy(), 0, delta=1e-4)
        self.assertTrue(result.converged)

    def test_keras_model_optimization(self):
Collaborator:

Nice job fuzzing out a more complex use case.

Contributor Author:

It's a much smaller tolerance than the standard implementation uses :) (see https://github.com/andim/noisyopt/blob/master/noisyopt/tests/test_noisyopt.py#L97); I compared with existing implementations. Additionally, I added a test case that is basically just the one from qiskit (https://github.com/Qiskit/qiskit-terra/blob/main/test/python/algorithms/optimizers/test_spsa.py), and they check <= -0.95 when the minimum is -1 (which it doesn't reach even with exact expectation values), so it doesn't seem uncommon.

lockwo (Contributor Author) commented Jan 13, 2022

How is "//tensorflow_quantum/python/differentiators:gradient_test" failing when I changed 3 unrelated files? It wasn't failing before.

MichaelBroughton (Collaborator) commented Jan 13, 2022

Unfortunately, it's a little flaky (it's on the todo list to fix). Have you got things in a good state where you'd like another review?

lockwo (Contributor Author) commented Jan 14, 2022

I addressed most things. If you could just go over my conversation responses: I had a question or two, and with those answered I'll fix everything up and be ready for round 2.

MichaelBroughton (Collaborator) left a comment

This is looking pretty good now, just a few more nits and we should be good to merge.

Comment on lines 184 to 185
    if seed is not None:
        tf.random.set_seed(seed)
Collaborator:

This is setting the global seed, which we don't want (it will impact code outside of this function). Something like this might be a little better:

g = tf.random.Generator.from_seed(1234)
g.normal(shape=(2, 3)) # produces same output as tf.random.normal, but from the seed in g.

Ideally all random ops in this function should draw from g

Contributor Author:

Updated. There is only 1 random op (for the deltas), but it now draws from the generator.
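For the record, the delta sampling now looks conceptually like this (a sketch, not the literal diff; the helper name is illustrative):

```python
import tensorflow as tf

def sample_deltas(generator, n):
    """Draw a Rademacher (+1/-1) perturbation vector from a local generator."""
    # Bernoulli(0.5) bits in {0, 1}, mapped to {-1, +1}.
    bits = tf.cast(generator.uniform(shape=[n], maxval=2, dtype=tf.int32),
                   tf.float32)
    return 2.0 * bits - 1.0

# The generator is seeded locally, so the global TensorFlow seed is untouched.
g = tf.random.Generator.from_seed(1234)
deltas = sample_deltas(g, 5)
```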

Comment on lines 162 to 165
a: Scalar `tf.Tensor` of real dtype. Specifies the learning rate
alpha: Scalar `tf.Tensor` of real dtype. Specifies scaling of the
learning rate.
c: Scalar `tf.Tensor` of real dtype. Specifies the size of the
Collaborator:

Nit: Need to update the names here.

Contributor Author:

Updated

Comment on lines 172 to 174
allowable increase in objective function (only applies if blocking
is true).
name: (Optional) Python `str`. The name prefixed to the ops created
Collaborator:

Nit: Need to add seed.

Contributor Author:

Added

lockwo (Contributor Author) commented Jan 18, 2022

Alright, I think I made all the updates (and the gradient tests didn't flake this time)

MichaelBroughton (Collaborator) left a comment

LGTM
