Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,7 @@ venv/*

# ignore emacs temp files
*#

# vscode
.vscode/*
*~
1 change: 1 addition & 0 deletions release/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ sh_binary(
"//tensorflow_quantum/core/ops:tfq_utility_ops_py",
"//tensorflow_quantum/core/ops:tfq_simulate_ops_py",
"//tensorflow_quantum/core/ops/math_ops:inner_product_op_py",
"//tensorflow_quantum/core/ops/math_ops:fidelity_op_py",
"//tensorflow_quantum/core/ops/noise:noisy_samples_op_py",
"//tensorflow_quantum/core/ops/noise:noisy_expectation_op_py",
"//tensorflow_quantum/core/ops/noise:noisy_sampled_expectation_op_py",
Expand Down
1 change: 1 addition & 0 deletions scripts/import_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def test_imports():

# Math ops.
_ = tfq.math.inner_product
_ = tfq.math.fidelity

# Noisy simulation ops.
_ = tfq.noise.expectation
Expand Down
1 change: 1 addition & 0 deletions tensorflow_quantum/core/ops/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ py_library(
":tfq_utility_ops_py",
# test addons
"//tensorflow_quantum/core/ops/math_ops:inner_product_op_py",
"//tensorflow_quantum/core/ops/math_ops:fidelity_op_py",
"//tensorflow_quantum/core/ops/noise:noisy_expectation_op_py",
],
)
Expand Down
17 changes: 17 additions & 0 deletions tensorflow_quantum/core/ops/math_ops/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ cc_binary(
],
)

py_library(
name = "fidelity_op_py",
srcs = ["fidelity_op.py"],
deps = [
":inner_product_op_py",
],
)

py_library(
name = "inner_product_op_py",
srcs = ["inner_product_op.py"],
Expand All @@ -91,6 +99,15 @@ py_test(
],
)

py_test(
name = "fidelity_op_test",
srcs = ["fidelity_op_test.py"],
deps = [
":fidelity_op_py",
"//tensorflow_quantum/python:util",
],
)

py_test(
name = "inner_product_grad_test",
srcs = ["inner_product_grad_test.py"],
Expand Down
1 change: 1 addition & 0 deletions tensorflow_quantum/core/ops/math_ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@
"""Module for tfq.core.ops.math_ops.*"""

from tensorflow_quantum.core.ops.math_ops.inner_product_op import inner_product
from tensorflow_quantum.core.ops.math_ops.fidelity_op import fidelity
84 changes: 84 additions & 0 deletions tensorflow_quantum/core/ops/math_ops/fidelity_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright 2020 The TensorFlow Quantum Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module for tfq.math.fidelity op."""
import tensorflow as tf
from tensorflow_quantum.core.ops.math_ops import inner_product_op


@tf.function
def fidelity(programs, symbol_names, symbol_values, other_programs):
"""Calculate the fidelity between circuits.

Compute (potentially many) fidelities between the given circuits and
the symbol free comparison circuits.

Calculates out[i][j] = $ | \langle \psi_{\text{programs[i]}} \\
(\text{symbol_values[i]}) | \psi_{\text{other_programs[j]}} \rangle \\
|^2 $


>>> symbols = sympy.symbols('alpha beta')
>>> qubits = cirq.GridQubit.rect(1, 2)
>>> reference_circuits = [
... cirq.Circuit((cirq.H**symbols[0]).on_each(qubits)),
... cirq.Circuit(
... cirq.X(qubits[0]) ** symbols[0],
... cirq.Y(qubits[1]) ** symbols[1])
... ]
>>> other_circuits = [
... cirq.Circuit(cirq.X.on_each(qubits)),
... cirq.Circuit((cirq.Y**0.125).on_each(qubits)),
... cirq.Circuit((cirq.X**0.5).on_each(qubits))
... ]
>>> reference_tensor = tfq.convert_to_tensor(reference_circuits)
>>> symbol_tensor = tf.convert_to_tensor([s.name for s in symbols])
>>> values_tensor = tf.convert_to_tensor(np.arange(4).reshape(2, 2))
>>> other_tensor = tfq.convert_to_tensor([other_circuits, other_circuits])
>>> fid = tfq.math.fidelity(reference_tensor, symbol_tensor,
... values_tensor, other_tensor)
>>> fid
tf.Tensor(
[[ 0., 0.925, 0.25],
[ 0., 0.036, 0.25]],shape=(2, 3), dtype=float32)



Note: `other_programs` must not contain any free symbols. These can
be resolved beforehand with `tfq.resolve_parameters`.

Args:
programs: `tf.Tensor` of strings with shape [batch_size] containing
the string representations of the circuits
symbol_names: `tf.Tensor` of strings with shape [n_params], which
is used to specify the order in which the values in
`symbol_values` should be placed inside of the circuits in
`programs`.
symbol_values: `tf.Tensor` of real numbers with shape
[batch_size, n_params] specifying parameter values to resolve
into the circuits specificed by programs, following the ordering
dictated by `symbol_names`.
other_programs: `tf.Tensor` of strings with shape [batch_size, n_others]
containing the string representations of the circuits with which to
compute the overlap on `programs` with. Must not contain any free
symbols.
Returns:
`tf.Tensor` with shape [batch_size, n_others] where `out[i][j]` is equal
to the fidelity of `programs[i]` with `symbol_values[i]`
resolved in and `other_programs[i][j]`.
"""
ip = inner_product_op.inner_product(programs, symbol_names,
tf.cast(symbol_values, tf.float32),
other_programs)
return tf.math.abs(ip)**2
Loading