Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions tensorflow_quantum/core/ops/cirq_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,8 +414,8 @@ def test_sampling_output_padding(self, op, all_n_qubits, n_samples):
this_expected_output[:, :max(all_n_qubits) - n_qubits] = -2
expected_outputs.append(this_expected_output)
circuits.append(
cirq.Circuit(
*cirq.X.on_each(*cirq.GridQubit.rect(1, n_qubits))))
cirq.Circuit(*cirq.X.on_each(
*cirq.GridQubit.rect(1, n_qubits))))
results = op(util.convert_to_tensor(circuits), [], [[]] * len(circuits),
[n_samples]).numpy()
self.assertAllClose(expected_outputs, results)
Expand Down Expand Up @@ -461,8 +461,8 @@ def run_sweep(self, program, params, repetitions):
circuits = []
for n_qubits in all_n_qubits:
circuits.append(
cirq.Circuit(
*cirq.X.on_each(*cirq.GridQubit.rect(1, n_qubits))))
cirq.Circuit(*cirq.X.on_each(
*cirq.GridQubit.rect(1, n_qubits))))
test_results = this_op(util.convert_to_tensor(circuits), [],
[[]] * len(circuits), [n_samples]).numpy()

Expand Down
14 changes: 13 additions & 1 deletion tensorflow_quantum/core/ops/math_ops/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ cc_binary(
name = "_tfq_math_ops.so",
srcs = [
"tfq_inner_product.cc",
"tfq_inner_product_grad.cc",
],
copts = select({
":windows": [
Expand Down Expand Up @@ -58,8 +59,9 @@ cc_binary(
deps = [
"//tensorflow_quantum/core/ops:parse_context",
"//tensorflow_quantum/core/ops:tfq_simulate_utils",
"//tensorflow_quantum/core/src:util_qsim",
"//tensorflow_quantum/core/src:adj_util",
"//tensorflow_quantum/core/src:circuit_parser_qsim",
"//tensorflow_quantum/core/src:util_qsim",
"@qsim//lib:qsim_lib",
],
)
Expand All @@ -82,3 +84,13 @@ py_test(
"//tensorflow_quantum/python:util",
],
)

py_test(
name = "inner_product_grad_test",
srcs = ["inner_product_grad_test.py"],
python_version = "PY3",
deps = [
":inner_product_op_py",
"//tensorflow_quantum/python:util",
],
)
386 changes: 386 additions & 0 deletions tensorflow_quantum/core/ops/math_ops/inner_product_grad_test.py

Large diffs are not rendered by default.

67 changes: 63 additions & 4 deletions tensorflow_quantum/core/ops/math_ops/inner_product_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,55 @@
MATH_OP_MODULE = load_module(os.path.join("math_ops", "_tfq_math_ops.so"))


def _inner_product_grad(programs, symbol_names, symbol_values, other_programs,
prev_grad):
"""Calculate the adjoint gradients of the inner product between circuits.

Compute the gradients of the (potentially many) inner products between
the given circuits and the symbol free comparison circuits.

Calculates out[i][j][k] = $ \frac{\langle \psi_{\text{programs[i]}} \\
(\text{symbol_values[i]})}{\partial \text{symbol_names[k]}} | \\
\psi_{\text{other_programs[j]}} \rangle $


Note: `other_programs` must not contain any free symbols. These can
be resolved beforehand with `tfq.resolve_parameters`.

Note: len(symbol_names) (=n_params) should be a positive integer.

Args:
programs: `tf.Tensor` of strings with shape [batch_size] containing
the string representations of the circuits
symbol_names: `tf.Tensor` of strings with shape [n_params], which
is used to specify the order in which the values in
`symbol_values` should be placed inside of the circuits in
`programs`.
symbol_values: `tf.Tensor` of real numbers with shape
[batch_size, n_params] specifying parameter values to resolve
into the circuits specificed by programs, following the ordering
dictated by `symbol_names`.
other_programs: `tf.Tensor` of strings with shape [batch_size, n_others]
containing the string representations of the circuits with which to
compute the overlap on `programs` with. Must not contain any free
symbols.
prev_grad: `tf.Tensor` of real numbers with shape [batch_size, n_ops]
backprop of values from downstream in the compute graph.

Returns:
tf.Tensor` with shape [batch_size, n_symbols] where `out[i][j]` is equal
to the gradient of the inner product between programs[i] and all
other_programs[i] w.r.t. `symbol_names[j]` and `programs[i]` is resolved
with `symbol_values[i]`.
"""
# Due to TF gradient scheme, we return complex conjugate derivative.
return tf.math.conj(
MATH_OP_MODULE.tfq_inner_product_grad(
programs, symbol_names, tf.cast(symbol_values, tf.float32),
other_programs, tf.cast(prev_grad, tf.float32)))


@tf.custom_gradient
def inner_product(programs, symbol_names, symbol_values, other_programs):
"""Calculate the inner product between circuits.

Expand Down Expand Up @@ -61,8 +110,6 @@ def inner_product(programs, symbol_names, symbol_values, other_programs):
Note: `other_programs` must not contain any free symbols. These can
be resolved beforehand with `tfq.resolve_parameters`.

Note: Currently this op is not differentiable.

Args:
programs: `tf.Tensor` of strings with shape [batch_size] containing
the string representations of the circuits
Expand All @@ -82,8 +129,20 @@ def inner_product(programs, symbol_names, symbol_values, other_programs):
`tf.Tensor` with shape [batch_size, n_others] where `out[i][j]` is equal
to the inner product of `programs[i]` with `symbol_values[i]`
resolved in and `other_programs[i][j]`.

"""

def grad(dy):

def _true_grad():
return _inner_product_grad(programs, symbol_names, symbol_values,
other_programs, dy)

ret_zero = tf.equal(tf.size(symbol_names), 0)
inner_prod_grad = tf.cond(ret_zero,
lambda: tf.zeros_like(symbol_values),
_true_grad)
return [None, None, inner_prod_grad, None]

return MATH_OP_MODULE.tfq_inner_product(programs, symbol_names,
tf.cast(symbol_values, tf.float32),
other_programs)
other_programs), grad
161 changes: 152 additions & 9 deletions tensorflow_quantum/core/ops/math_ops/inner_product_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests that specifically target tfq_simulate_ops."""
"""Tests that specifically target tfq_inner_product."""
import copy
import numpy as np
from absl.testing import parameterized
import tensorflow as tf
Expand All @@ -26,7 +27,7 @@ class InnerProductTest(tf.test.TestCase, parameterized.TestCase):
"""Tests tfq_inner_product."""

def test_inner_product_inputs(self):
"""Make sure that inner_product fails gracefully on bad inputs."""
"""Makes sure that inner_product fails gracefully on bad inputs."""
n_qubits = 5
batch_size = 5
symbol_names = ['alpha']
Expand Down Expand Up @@ -206,6 +207,11 @@ def test_inner_product_inputs(self):
self.assertDTypeEqual(res, np.complex64)

@parameterized.parameters([
{
'n_qubits': 5,
'batch_size': 1,
'inner_dim_size': 5
},
{
'n_qubits': 5,
'batch_size': 10,
Expand All @@ -224,7 +230,7 @@ def test_inner_product_inputs(self):
])
def test_correctness_with_symbols(self, n_qubits, batch_size,
inner_dim_size):
"""Test that inner_product works with symbols."""
"""Tests that inner_product works with symbols."""
symbol_names = ['alpha', 'beta', 'gamma']
qubits = cirq.GridQubit.rect(1, n_qubits)
circuit_batch, resolver_batch = \
Expand Down Expand Up @@ -264,12 +270,17 @@ def test_correctness_with_symbols(self, n_qubits, batch_size,
@parameterized.parameters([
{
'n_qubits': 5,
'batch_size': 10,
'batch_size': 1,
'inner_dim_size': 5
},
{
'n_qubits': 5,
'batch_size': 2,
'inner_dim_size': 1
},
{
'n_qubits': 10,
'batch_size': 10,
'batch_size': 3,
'inner_dim_size': 2
},
{
Expand All @@ -280,7 +291,7 @@ def test_correctness_with_symbols(self, n_qubits, batch_size,
])
def test_correctness_without_symbols(self, n_qubits, batch_size,
inner_dim_size):
"""Test that inner_product works with symbols."""
"""Tests that inner_product works without symbols."""
qubits = cirq.GridQubit.rect(1, n_qubits)
circuit_batch, _ = \
util.random_circuit_resolver_batch(
Expand Down Expand Up @@ -309,18 +320,135 @@ def test_correctness_without_symbols(self, n_qubits, batch_size,
self.assertAllClose(out, out_arr, atol=1e-5)

def test_correctness_empty(self):
"""Test the inner product between two empty circuits."""
"""Tests the inner product with empty circuits."""

empty_cicuit = util.convert_to_tensor([cirq.Circuit()])
empty_circuit = util.convert_to_tensor([cirq.Circuit()])
empty_symbols = tf.convert_to_tensor([], dtype=tf.dtypes.string)
empty_values = tf.convert_to_tensor([[]])
other_program = util.convert_to_tensor([[cirq.Circuit()]])

out = inner_product_op.inner_product(empty_cicuit, empty_symbols,
out = inner_product_op.inner_product(empty_circuit, empty_symbols,
empty_values, other_program)
expected = np.array([[1.0]], dtype=np.complex64)
self.assertAllClose(out, expected)

qubit = cirq.GridQubit(0, 0)
non_empty_circuit = util.convert_to_tensor(
[cirq.Circuit(cirq.X(qubit))])
empty_symbols = tf.convert_to_tensor([], dtype=tf.dtypes.string)
empty_values = tf.convert_to_tensor([[]])
other_program = util.convert_to_tensor([[cirq.Circuit()]])

with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
'qubits not found'):
inner_product_op.inner_product(non_empty_circuit, empty_symbols,
empty_values, other_program)

@parameterized.parameters([
{
'n_qubits': 5,
'batch_size': 1,
'inner_dim_size': 5
},
{
'n_qubits': 5,
'batch_size': 3,
'inner_dim_size': 2
},
])
def test_tf_gradient_correctness_with_symbols(self, n_qubits, batch_size,
inner_dim_size):
"""Tests that tf.gradient of inner_product works with symbols."""
symbol_names = ['alpha', 'beta', 'gamma']
n_params = len(symbol_names)
qubits = cirq.GridQubit.rect(1, n_qubits)
circuit_batch, resolver_batch = \
util.random_symbol_circuit_resolver_batch(
qubits, symbol_names, batch_size)

other_batch = [
util.random_circuit_resolver_batch(qubits, inner_dim_size)[0]
for i in range(batch_size)
]

symbol_values_array = np.array(
[[resolver[symbol]
for symbol in symbol_names]
for resolver in resolver_batch])

programs = util.convert_to_tensor(circuit_batch)
other_programs = util.convert_to_tensor(other_batch)
symbol_names_tensor = tf.convert_to_tensor(symbol_names,
dtype=tf.dtypes.string)
symbol_values = tf.convert_to_tensor(symbol_values_array)

with tf.GradientTape() as tape:
tape.watch(symbol_values)
ip = inner_product_op.inner_product(programs, symbol_names_tensor,
symbol_values, other_programs)
out = tape.gradient(ip, symbol_values)

out_arr = np.zeros((batch_size, n_params), dtype=np.complex64)
# dx came from _GRAD_EPS of core/src/adj_util.cc
dx = 5e-3
for i in range(batch_size):
for k, name in enumerate(symbol_names):
if name in resolver_batch[i].param_dict:
new_resolver = copy.deepcopy(resolver_batch[i])
new_resolver.param_dict[name] += dx
final_circuit_p = cirq.resolve_parameters(
circuit_batch[i], new_resolver)
new_resolver = copy.deepcopy(resolver_batch[i])
new_resolver.param_dict[name] -= dx
final_circuit_m = cirq.resolve_parameters(
circuit_batch[i], new_resolver)
final_wf_p = cirq.final_state_vector(final_circuit_p)
final_wf_m = cirq.final_state_vector(final_circuit_m)
# Performs central finite difference.
final_wf_grad = 0.5 * (final_wf_p - final_wf_m) / dx
for j in range(inner_dim_size):
internal_wf = cirq.final_state_vector(other_batch[i][j])
out_arr[i][k] += np.vdot(final_wf_grad, internal_wf)

self.assertAllClose(out, np.conj(out_arr), atol=1e-3)

@parameterized.parameters([
{
'n_qubits': 5,
'batch_size': 1,
'inner_dim_size': 5
},
{
'n_qubits': 5,
'batch_size': 3,
'inner_dim_size': 2
},
])
def test_tf_gradient_correctness_without_symbols(self, n_qubits, batch_size,
inner_dim_size):
"""Tests that tf.gradient of inner_product works without symbols."""
qubits = cirq.GridQubit.rect(1, n_qubits)
circuit_batch, _ = \
util.random_circuit_resolver_batch(
qubits, batch_size)

other_batch = [
util.random_circuit_resolver_batch(qubits, inner_dim_size)[0]
for i in range(batch_size)
]

programs = util.convert_to_tensor(circuit_batch)
other_programs = util.convert_to_tensor(other_batch)
symbol_names = tf.convert_to_tensor([], dtype=tf.dtypes.string)
symbol_values = tf.convert_to_tensor([[] for _ in range(batch_size)])

with tf.GradientTape() as tape:
tape.watch(symbol_values)
ip = inner_product_op.inner_product(programs, symbol_names,
symbol_values, other_programs)
out = tape.gradient(ip, symbol_values)
self.assertAllClose(out, tf.zeros_like(symbol_values), atol=1e-3)

def test_correctness_no_circuit(self):
"""Test the inner product between no circuits."""

Expand All @@ -333,6 +461,21 @@ def test_correctness_no_circuit(self):
empty_values, other_program)
self.assertShapeEqual(np.zeros((0, 0)), out)

def test_tf_gradient_correctness_no_circuit(self):
"""Test the inner product grad between no circuits."""

empty_circuit = tf.raw_ops.Empty(shape=(0,), dtype=tf.string)
empty_symbols = tf.raw_ops.Empty(shape=(0,), dtype=tf.string)
empty_values = tf.raw_ops.Empty(shape=(0, 0), dtype=tf.float32)
other_program = tf.raw_ops.Empty(shape=(0, 0), dtype=tf.string)

with tf.GradientTape() as tape:
tape.watch(empty_values)
out = inner_product_op.inner_product(empty_circuit, empty_symbols,
empty_values, other_program)

self.assertShapeEqual(np.zeros((0, 0)), out)


if __name__ == "__main__":
tf.test.main()
Loading