diff --git a/release/BUILD b/release/BUILD index 103820b4e..dff6e9b7d 100644 --- a/release/BUILD +++ b/release/BUILD @@ -12,6 +12,7 @@ sh_binary( # Core module. "//tensorflow_quantum/core:__init__.py", "//tensorflow_quantum/core/ops:__init__.py", + "//tensorflow_quantum/core/ops/math_ops:__init__.py", "//tensorflow_quantum/core/proto:__init__.py", "//tensorflow_quantum/core/serialize:__init__.py", @@ -37,6 +38,7 @@ sh_binary( "//tensorflow_quantum/core/ops:tfq_unitary_op_py", "//tensorflow_quantum/core/ops:tfq_utility_ops_py", "//tensorflow_quantum/core/ops:tfq_simulate_ops_py", + "//tensorflow_quantum/core/ops/math_ops:inner_product_op_py", "//tensorflow_quantum/core/serialize:serializer", "//tensorflow_quantum/datasets:cluster_state", "//tensorflow_quantum/datasets:spin_system", diff --git a/scripts/import_test.py b/scripts/import_test.py index 581af3add..1d975d551 100644 --- a/scripts/import_test.py +++ b/scripts/import_test.py @@ -34,6 +34,9 @@ def test_imports(): _ = tfq.padded_to_ragged2d _ = tfq.resolve_parameters + # Math ops. + _ = tfq.math.inner_product + # Util functions. _ = tfq.convert_to_tensor _ = tfq.get_quantum_concurrent_op_mode diff --git a/tensorflow_quantum/__init__.py b/tensorflow_quantum/__init__.py index 614e75eb8..1a16cf6c7 100644 --- a/tensorflow_quantum/__init__.py +++ b/tensorflow_quantum/__init__.py @@ -21,6 +21,9 @@ get_unitary_op, padded_to_ragged, padded_to_ragged2d, resolve_parameters) +# Import math ops. +from tensorflow_quantum.core import math_ops as math + # Re-label python module as layers module. import tensorflow_quantum.python.layers as layers diff --git a/tensorflow_quantum/core/__init__.py b/tensorflow_quantum/core/__init__.py index acb40a707..2e60d6927 100644 --- a/tensorflow_quantum/core/__init__.py +++ b/tensorflow_quantum/core/__init__.py @@ -24,3 +24,6 @@ # Special case for append op which we didn't name well. from tensorflow_quantum.core.ops import \ tfq_append_circuit as append_circuit + +# Import math ops. 
+from tensorflow_quantum.core.ops import math_ops \ No newline at end of file diff --git a/tensorflow_quantum/core/ops/__init__.py b/tensorflow_quantum/core/ops/__init__.py index 1dd8a2199..65e1938d7 100644 --- a/tensorflow_quantum/core/ops/__init__.py +++ b/tensorflow_quantum/core/ops/__init__.py @@ -24,3 +24,6 @@ padded_to_ragged2d, resolve_parameters, tfq_append_circuit) + +# Import math_ops. +from tensorflow_quantum.core.ops import math_ops diff --git a/tensorflow_quantum/core/ops/math_ops/BUILD b/tensorflow_quantum/core/ops/math_ops/BUILD new file mode 100644 index 000000000..5db4c7d0c --- /dev/null +++ b/tensorflow_quantum/core/ops/math_ops/BUILD @@ -0,0 +1,84 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +# Export for the PIP package. +exports_files(["__init__.py"]) + +config_setting( + name = "windows", + constraint_values = ["@bazel_tools//platforms:windows"], +) + +cc_binary( + name = "_tfq_math_ops.so", + srcs = [ + "tfq_inner_product.cc", + ], + copts = select({ + ":windows": [ + "/D__CLANG_SUPPORT_DYN_ANNOTATION__", + "/D_USE_MATH_DEFINES", + "/DEIGEN_MPL2_ONLY", + "/DEIGEN_MAX_ALIGN_BYTES=64", + "/DEIGEN_HAS_TYPE_TRAITS=0", + "/DTF_USE_SNAPPY", + "/showIncludes", + "/MD", + "/O2", + "/DNDEBUG", + "/w", + "-DWIN32_LEAN_AND_MEAN", + "-DNOGDI", + "/d2ReducedOptimizeHugeFunctions", + "/arch:AVX", + "/std:c++14", + "-DTENSORFLOW_MONOLITHIC_BUILD", + "/DPLATFORM_WINDOWS", + "/DEIGEN_HAS_C99_MATH", + "/DTENSORFLOW_USE_EIGEN_THREADPOOL", + "/DEIGEN_AVOID_STL_ARRAY", + "/Iexternal/gemmlowp", + "/wd4018", + "/wd4577", + "/DNOGDI", + "/UTF_COMPILE_LIBRARY", + ], + "//conditions:default": [ + "-pthread", + "-std=c++11", + "-D_GLIBCXX_USE_CXX11_ABI=0", + ], + }), + features = select({ + ":windows": ["windows_export_all_symbols"], + "//conditions:default": [], + }), + linkshared = 1, + deps = [ + "//tensorflow_quantum/core/ops:parse_context", + "//tensorflow_quantum/core/ops:tfq_simulate_utils", + 
"//tensorflow_quantum/core/src:util_qsim", + "//tensorflow_quantum/core/src:circuit_parser_qsim", + "@qsim//lib:qsim_lib", + ], +) + +py_library( + name = "inner_product_op_py", + srcs = ["inner_product_op.py"], + data = [":_tfq_math_ops.so"], + deps = [ + "//tensorflow_quantum/core/ops:load_module", + ], +) + +py_test( + name = "inner_product_op_test", + srcs = ["inner_product_op_test.py"], + python_version = "PY3", + deps = [ + ":inner_product_op_py", + "//tensorflow_quantum/python:util", + ], +) diff --git a/tensorflow_quantum/core/ops/math_ops/__init__.py b/tensorflow_quantum/core/ops/math_ops/__init__.py new file mode 100644 index 000000000..7842b8784 --- /dev/null +++ b/tensorflow_quantum/core/ops/math_ops/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2020 The TensorFlow Quantum Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Module for tfq.core.ops.math_ops.*""" + +from tensorflow_quantum.core.ops.math_ops.inner_product_op import inner_product diff --git a/tensorflow_quantum/core/ops/math_ops/inner_product_op.py b/tensorflow_quantum/core/ops/math_ops/inner_product_op.py new file mode 100644 index 000000000..d808ef45d --- /dev/null +++ b/tensorflow_quantum/core/ops/math_ops/inner_product_op.py @@ -0,0 +1,55 @@ +# Copyright 2020 The TensorFlow Quantum Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Module for the tfq inner product op.""" +import os +import tensorflow as tf +from tensorflow_quantum.core.ops.load_module import load_module + +MATH_OP_MODULE = load_module(os.path.join("math_ops", "_tfq_math_ops.so")) + + +def inner_product(programs, symbol_names, symbol_values, other_programs): + """Calculate the inner product between circuits. + + Calculates out[i][j] = \langle \psi_{\text{programs[i]}} \\ + (\text{symbol_values[i]}) | \psi_{\text{other_programs[i][j]}} \rangle + + Note: `other_programs` must not contain any free symbols. These can be + resolved beforehand with `tfq.resolve_parameters`. + + Args: + programs: `tf.Tensor` of strings with shape [batch_size] containing + the string representations of the circuits + symbol_names: `tf.Tensor` of strings with shape [n_params], which + is used to specify the order in which the values in + `symbol_values` should be placed inside of the circuits in + `programs`. + symbol_values: `tf.Tensor` of real numbers with shape + [batch_size, n_params] specifying parameter values to resolve + into the circuits specified by programs, following the ordering + dictated by `symbol_names`. + other_programs: `tf.Tensor` of strings with shape [batch_size, n_others] + containing the string representations of the circuits with which to + compute the overlap on `programs` with. 
Must not contain any free + symbols. + Returns: + `tf.Tensor` with shape [batch_size, n_others] where `out[i][j]` is equal + to the inner product of `programs[i]` with `symbol_values[i]` + resolved in and `other_programs[i][j]`. + + """ + return MATH_OP_MODULE.tfq_inner_product(programs, symbol_names, + tf.cast(symbol_values, tf.float32), + other_programs) diff --git a/tensorflow_quantum/core/ops/math_ops/inner_product_op_test.py b/tensorflow_quantum/core/ops/math_ops/inner_product_op_test.py new file mode 100644 index 000000000..b9be44080 --- /dev/null +++ b/tensorflow_quantum/core/ops/math_ops/inner_product_op_test.py @@ -0,0 +1,317 @@ +# Copyright 2020 The TensorFlow Quantum Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests that specifically target tfq_inner_product.""" +import numpy as np +from absl.testing import parameterized +import tensorflow as tf +import cirq + +from tensorflow_quantum.core.ops.math_ops import inner_product_op +from tensorflow_quantum.python import util + + +class InnerProductTest(tf.test.TestCase, parameterized.TestCase): + """Tests tfq_inner_product.""" + + def test_inner_product_inputs(self): + """Make sure that inner_product fails gracefully on bad inputs.""" + n_qubits = 5 + batch_size = 5 + symbol_names = ['alpha'] + qubits = cirq.GridQubit.rect(1, n_qubits) + circuit_batch, resolver_batch = \ + util.random_symbol_circuit_resolver_batch( + qubits, symbol_names, batch_size) + + symbol_values_array = np.array( + [[resolver[symbol] + for symbol in symbol_names] + for resolver in resolver_batch]) + + other_batch = [ + util.random_circuit_resolver_batch(qubits, 3)[0] + for i in range(batch_size) + ] + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'programs must be rank 1'): + # Circuit tensor has too many dimensions. + inner_product_op.inner_product( + util.convert_to_tensor([circuit_batch]), symbol_names, + symbol_values_array, util.convert_to_tensor(other_batch)) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'symbol_names must be rank 1.'): + # symbol_names tensor has too many dimensions. + inner_product_op.inner_product( + util.convert_to_tensor(circuit_batch), np.array([symbol_names]), + symbol_values_array, util.convert_to_tensor(other_batch)) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'symbol_values must be rank 2.'): + # symbol_values_array tensor has too many dimensions. 
+ inner_product_op.inner_product( + util.convert_to_tensor(circuit_batch), symbol_names, + np.array([symbol_values_array]), + util.convert_to_tensor(other_batch)) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'symbol_values must be rank 2.'): + # symbol_values_array tensor has too few dimensions. + inner_product_op.inner_product( + util.convert_to_tensor(circuit_batch), symbol_names, + symbol_values_array[0], util.convert_to_tensor(other_batch)) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'other_programs must be rank 2.'): + # other_programs tensor has too few dimensions. + inner_product_op.inner_product( + util.convert_to_tensor(circuit_batch), symbol_names, + symbol_values_array, util.convert_to_tensor(circuit_batch)) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'other_programs must be rank 2.'): + # other_programs tensor has too many dimensions. + inner_product_op.inner_product( + util.convert_to_tensor(circuit_batch), symbol_names, + symbol_values_array, + util.convert_to_tensor([[x] for x in other_batch])) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'Unparseable proto'): + # circuit tensor has the right type but invalid values. + inner_product_op.inner_product(['junk'] * batch_size, symbol_names, + symbol_values_array, + util.convert_to_tensor(other_batch)) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'Could not find symbol in parameter map'): + # symbol_names tensor has the right type but invalid values. + inner_product_op.inner_product( + util.convert_to_tensor(circuit_batch), ['junk'], + symbol_values_array, util.convert_to_tensor(other_batch)) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'not found in reference circuit'): + # other_programs tensor has the right type but operates on + # qubits that the reference circuit doesn't have. 
+ new_qubits = [cirq.GridQubit(5, 5), cirq.GridQubit(9, 9)] + new_circuits, _ = util.random_circuit_resolver_batch( + new_qubits, batch_size) + inner_product_op.inner_product( + util.convert_to_tensor(circuit_batch), symbol_names, + symbol_values_array, + util.convert_to_tensor([[x] for x in new_circuits])) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'not found in paired circuit'): + # other_programs tensor has the right type but is built on + # fewer qubits than the reference circuit. + new_qubits = cirq.GridQubit.rect(1, n_qubits - 1) + new_circuits, _ = util.random_circuit_resolver_batch( + new_qubits, batch_size) + inner_product_op.inner_product( + util.convert_to_tensor(circuit_batch), symbol_names, + symbol_values_array, + util.convert_to_tensor([[x] for x in new_circuits])) + + with self.assertRaisesRegex(TypeError, 'Cannot convert'): + # circuits tensor has the wrong type. + inner_product_op.inner_product([1.0] * batch_size, symbol_names, + symbol_values_array, + util.convert_to_tensor(other_batch)) + + with self.assertRaisesRegex(TypeError, 'Cannot convert'): + # symbol_names tensor has the wrong type. + inner_product_op.inner_product( + util.convert_to_tensor(circuit_batch), [0.1234], + symbol_values_array, util.convert_to_tensor(other_batch)) + + with self.assertRaisesRegex(tf.errors.UnimplementedError, ''): + # symbol_values tensor has the wrong type. + inner_product_op.inner_product( + util.convert_to_tensor(circuit_batch), symbol_names, + [['junk']] * batch_size, util.convert_to_tensor(other_batch)) + + with self.assertRaisesRegex(TypeError, 'Cannot convert'): + # other_programs tensor has the wrong type. + inner_product_op.inner_product( + util.convert_to_tensor(circuit_batch), symbol_names, + symbol_values_array, [[1.0]] * batch_size) + + with self.assertRaisesRegex(TypeError, 'missing'): + # we are missing an argument. 
+ # pylint: disable=no-value-for-parameter + inner_product_op.inner_product( + util.convert_to_tensor(circuit_batch), symbol_names, + symbol_values_array) + # pylint: enable=no-value-for-parameter + + with self.assertRaisesRegex(TypeError, 'positional arguments'): + # pylint: disable=too-many-function-args + inner_product_op.inner_product( + util.convert_to_tensor(circuit_batch), symbol_names, + symbol_values_array, util.convert_to_tensor(other_batch), []) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + expected_regex='do not match'): + # other_programs has wrong batch size. + inner_product_op.inner_product( + util.convert_to_tensor(circuit_batch), symbol_names, + symbol_values_array, + util.convert_to_tensor(other_batch[:int(batch_size * 0.5)])) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + expected_regex='do not match'): + # symbol_values has wrong batch size. + inner_product_op.inner_product( + util.convert_to_tensor(circuit_batch), symbol_names, + symbol_values_array[::int(batch_size * 0.5)], + util.convert_to_tensor(other_batch)) + + with self.assertRaisesRegex( + tf.errors.InvalidArgumentError, + expected_regex='Found symbols in other_programs'): + # other_programs has symbols. 
+ inner_product_op.inner_product( + util.convert_to_tensor(circuit_batch), symbol_names, + symbol_values_array, + util.convert_to_tensor([[x] for x in circuit_batch])) + + res = inner_product_op.inner_product( + util.convert_to_tensor(circuit_batch), symbol_names, + symbol_values_array.astype(np.float64), + util.convert_to_tensor(other_batch)) + self.assertDTypeEqual(res, np.complex64) + + @parameterized.parameters([ + { + 'n_qubits': 5, + 'batch_size': 10, + 'inner_dim_size': 1 + }, + { + 'n_qubits': 10, + 'batch_size': 10, + 'inner_dim_size': 2 + }, + { + 'n_qubits': 5, + 'batch_size': 10, + 'inner_dim_size': 5 + }, + ]) + def test_correctness_with_symbols(self, n_qubits, batch_size, + inner_dim_size): + """Test that inner_product works with symbols.""" + symbol_names = ['alpha', 'beta', 'gamma'] + qubits = cirq.GridQubit.rect(1, n_qubits) + circuit_batch, resolver_batch = \ + util.random_symbol_circuit_resolver_batch( + qubits, symbol_names, batch_size) + + other_batch = [ + util.random_circuit_resolver_batch(qubits, inner_dim_size)[0] + for i in range(batch_size) + ] + + symbol_values_array = np.array( + [[resolver[symbol] + for symbol in symbol_names] + for resolver in resolver_batch]) + + programs = util.convert_to_tensor(circuit_batch) + other_programs = util.convert_to_tensor(other_batch) + symbol_names = tf.convert_to_tensor(symbol_names, + dtype=tf.dtypes.string) + symbol_values = tf.convert_to_tensor(symbol_values_array) + + out = inner_product_op.inner_product(programs, symbol_names, + symbol_values, other_programs) + + out_arr = np.empty((batch_size, inner_dim_size), dtype=np.complex64) + for i in range(batch_size): + final_circuit = cirq.resolve_parameters(circuit_batch[i], + resolver_batch[i]) + final_wf = cirq.final_wavefunction(final_circuit) + for j in range(inner_dim_size): + internal_wf = cirq.final_wavefunction(other_batch[i][j]) + out_arr[i][j] = np.vdot(final_wf, internal_wf) + + self.assertAllClose(out, out_arr) + + 
@parameterized.parameters([ + { + 'n_qubits': 5, + 'batch_size': 10, + 'inner_dim_size': 1 + }, + { + 'n_qubits': 10, + 'batch_size': 10, + 'inner_dim_size': 2 + }, + { + 'n_qubits': 5, + 'batch_size': 10, + 'inner_dim_size': 5 + }, + ]) + def test_correctness_without_symbols(self, n_qubits, batch_size, + inner_dim_size): + """Test that inner_product works without symbols.""" + qubits = cirq.GridQubit.rect(1, n_qubits) + circuit_batch, _ = \ + util.random_circuit_resolver_batch( + qubits, batch_size) + + other_batch = [ + util.random_circuit_resolver_batch(qubits, inner_dim_size)[0] + for i in range(batch_size) + ] + + programs = util.convert_to_tensor(circuit_batch) + other_programs = util.convert_to_tensor(other_batch) + symbol_names = tf.convert_to_tensor([], dtype=tf.dtypes.string) + symbol_values = tf.convert_to_tensor([[] for _ in range(batch_size)]) + + out = inner_product_op.inner_product(programs, symbol_names, + symbol_values, other_programs) + + out_arr = np.empty((batch_size, inner_dim_size), dtype=np.complex64) + for i in range(batch_size): + final_wf = cirq.final_wavefunction(circuit_batch[i]) + for j in range(inner_dim_size): + internal_wf = cirq.final_wavefunction(other_batch[i][j]) + out_arr[i][j] = np.vdot(final_wf, internal_wf) + + self.assertAllClose(out, out_arr) + + def test_correctness_empty(self): + """Test the inner product between two empty circuits.""" + + empty_cicuit = util.convert_to_tensor([cirq.Circuit()]) + empty_symbols = tf.convert_to_tensor([], dtype=tf.dtypes.string) + empty_values = tf.convert_to_tensor([[]]) + other_program = util.convert_to_tensor([[cirq.Circuit()]]) + + out = inner_product_op.inner_product(empty_cicuit, empty_symbols, + empty_values, other_program) + expected = np.array([[1.0]], dtype=np.complex64) + self.assertAllClose(out, expected) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_quantum/core/ops/math_ops/tfq_inner_product.cc b/tensorflow_quantum/core/ops/math_ops/tfq_inner_product.cc 
new file mode 100644 index 000000000..09c091a95 --- /dev/null +++ b/tensorflow_quantum/core/ops/math_ops/tfq_inner_product.cc @@ -0,0 +1,314 @@ +/* Copyright 2020 The TensorFlow Quantum Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "../qsim/lib/circuit.h" +#include "../qsim/lib/gate_appl.h" +#include "../qsim/lib/gates_cirq.h" +#include "../qsim/lib/seqfor.h" +#include "../qsim/lib/simmux.h" +#include "cirq/google/api/v2/program.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow_quantum/core/ops/parse_context.h" +#include "tensorflow_quantum/core/src/util_qsim.h" + +namespace tfq { + +using ::cirq::google::api::v2::Program; +using ::tensorflow::Status; +using ::tfq::proto::PauliSum; + +typedef qsim::Cirq::GateCirq QsimGate; +typedef qsim::Circuit QsimCircuit; +typedef std::vector> QsimFusedCircuit; + +class TfqInnerProductOp : public tensorflow::OpKernel { + public: + explicit TfqInnerProductOp(tensorflow::OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(tensorflow::OpKernelContext* context) override { + // TODO (mbbrough): add more 
dimension checks for other inputs here. + const int num_inputs = context->num_inputs(); + OP_REQUIRES(context, num_inputs == 4, + tensorflow::errors::InvalidArgument(absl::StrCat( + "Expected 4 inputs, got ", num_inputs, " inputs."))); + + // Create the output Tensor. + const int output_dim_batch_size = context->input(0).dim_size(0); + const int output_dim_internal_size = context->input(3).dim_size(1); + tensorflow::TensorShape output_shape; + output_shape.AddDim(output_dim_batch_size); + output_shape.AddDim(output_dim_internal_size); + + tensorflow::Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + auto output_tensor = output->matrix>(); + + // Parse program protos. + std::vector programs; + std::vector num_qubits; + std::vector> other_programs; + OP_REQUIRES_OK(context, + GetProgramsAndNumQubits(context, &programs, &num_qubits, + &other_programs)); + + std::vector maps; + OP_REQUIRES_OK(context, GetSymbolMaps(context, &maps)); + + OP_REQUIRES(context, programs.size() == maps.size(), + tensorflow::errors::InvalidArgument(absl::StrCat( + "Number of circuits and symbol_values do not match. Got ", + programs.size(), " circuits and ", maps.size(), + " symbol values."))); + + // Construct qsim circuits for programs. + std::vector qsim_circuits(programs.size(), QsimCircuit()); + std::vector fused_circuits(programs.size(), + QsimFusedCircuit({})); + + auto construct_f = [&](int start, int end) { + for (int i = start; i < end; i++) { + OP_REQUIRES_OK(context, QsimCircuitFromProgram( + programs[i], maps[i], num_qubits[i], + &qsim_circuits[i], &fused_circuits[i])); + } + }; + + const int num_cycles = 1000; + context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor( + output_dim_batch_size, num_cycles, construct_f); + + // Construct qsim circuits for other_programs. 
+ std::vector> other_qsim_circuits( + output_dim_batch_size, + std::vector(output_dim_internal_size, QsimCircuit())); + std::vector> other_fused_circuits( + output_dim_batch_size, + std::vector(output_dim_internal_size, + QsimFusedCircuit({}))); + + auto construct_f2 = [&](int start, int end) { + for (int i = start; i < end; i++) { + int ii = i / output_dim_internal_size; + int jj = i % output_dim_internal_size; + Status status = QsimCircuitFromProgram( + other_programs[ii][jj], {}, num_qubits[ii], + &other_qsim_circuits[ii][jj], &other_fused_circuits[ii][jj]); + OP_REQUIRES(context, status.ok(), + tensorflow::errors::InvalidArgument(absl::StrCat( + "Found symbols in other_programs.", + "No symbols are allowed in these circuits."))); + } + }; + + context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor( + output_dim_batch_size * output_dim_internal_size, num_cycles, + construct_f2); + + int max_num_qubits = 0; + for (const int num : num_qubits) { + max_num_qubits = std::max(max_num_qubits, num); + } + + // Cross reference with standard google cloud compute instances + // Memory ~= 2 * num_threads * (2 * 64 * 2 ** num_qubits in circuits) + // e2s2 = 2 CPU, 8GB -> Can safely do 25 since Memory = 4GB + // e2s4 = 4 CPU, 16GB -> Can safely do 25 since Memory = 8GB + // ... + if (max_num_qubits >= 26 || output_dim_batch_size == 1) { + ComputeLarge(num_qubits, fused_circuits, other_fused_circuits, context, + &output_tensor); + } else { + ComputeSmall(num_qubits, max_num_qubits, fused_circuits, + other_fused_circuits, context, &output_tensor); + } + } + + private: + void ComputeLarge( + const std::vector& num_qubits, + const std::vector& fused_circuits, + const std::vector>& other_fused_circuits, + tensorflow::OpKernelContext* context, + tensorflow::TTypes, 1>::Matrix* output_tensor) { + // Instantiate qsim objects. 
+ const auto tfq_for = tfq::QsimFor(context); + using Simulator = qsim::Simulator; + using StateSpace = Simulator::StateSpace; + using State = StateSpace::State; + + // Begin simulation. + int largest_nq = 1; + State sv = StateSpace(largest_nq, tfq_for).CreateState(); + State scratch = StateSpace(largest_nq, tfq_for).CreateState(); + + // Simulate programs one by one. Parallelizing over wavefunctions + // we no longer parallelize over circuits. Each time we encounter a + // a larger circuit we will grow the Statevector as necessary. + for (int i = 0; i < fused_circuits.size(); i++) { + int nq = num_qubits[i]; + Simulator sim = Simulator(nq, tfq_for); + StateSpace ss = StateSpace(nq, tfq_for); + if (nq > largest_nq) { + // need to switch to larger statespace. + largest_nq = nq; + sv = ss.CreateState(); + scratch = ss.CreateState(); + } + // TODO: add heuristic here so that we do not always recompute + // the state if there is a possibility that circuit[i] and + // circuit[i + 1] produce the same state. 
+ ss.SetStateZero(sv); + for (int j = 0; j < fused_circuits[i].size(); j++) { + qsim::ApplyFusedGate(sim, fused_circuits[i][j], sv); + } + for (int j = 0; j < other_fused_circuits[i].size(); j++) { + // (#679) Just ignore empty program + if (fused_circuits[i].size() == 0) { + (*output_tensor)(i, j) = std::complex(1, 0); + continue; + } + + ss.SetStateZero(scratch); + for (int k = 0; k < other_fused_circuits[i][j].size(); k++) { + qsim::ApplyFusedGate(sim, other_fused_circuits[i][j][k], scratch); + } + + std::complex result = ss.InnerProduct(sv, scratch); + (*output_tensor)(i, j) = + std::complex(static_cast(result.real()), + static_cast(result.imag())); + } + } + } + + void ComputeSmall( + const std::vector& num_qubits, const int max_num_qubits, + const std::vector& fused_circuits, + const std::vector>& other_fused_circuits, + tensorflow::OpKernelContext* context, + tensorflow::TTypes, 1>::Matrix* output_tensor) { + const auto tfq_for = qsim::SequentialFor(1); + using Simulator = qsim::Simulator; + using StateSpace = Simulator::StateSpace; + using State = StateSpace::State; + + const int output_dim_internal_size = output_tensor->dimension(1); + + auto DoWork = [&](int start, int end) { + int old_batch_index = -2; + int cur_batch_index = -1; + int largest_nq = 1; + int cur_internal_index; + + State sv = StateSpace(largest_nq, tfq_for).CreateState(); + State scratch = StateSpace(largest_nq, tfq_for).CreateState(); + for (int i = start; i < end; i++) { + cur_batch_index = i / output_dim_internal_size; + cur_internal_index = i % output_dim_internal_size; + + const int nq = num_qubits[cur_batch_index]; + Simulator sim = Simulator(nq, tfq_for); + StateSpace ss = StateSpace(nq, tfq_for); + + // (#679) Just ignore empty program + if (fused_circuits[cur_batch_index].size() == 0) { + (*output_tensor)(cur_batch_index, cur_internal_index) = + std::complex(1, 0); + continue; + } + + if (cur_batch_index != old_batch_index) { + // We've run into a new wavefunction we must 
compute. + // Only compute a new wavefunction when we have to. + if (nq > largest_nq) { + sv = ss.CreateState(); + scratch = ss.CreateState(); + largest_nq = nq; + } + // no need to update scratch_state since ComputeExpectation + // will take care of things for us. + ss.SetStateZero(sv); + for (int j = 0; j < fused_circuits[cur_batch_index].size(); j++) { + qsim::ApplyFusedGate(sim, fused_circuits[cur_batch_index][j], sv); + } + } + + ss.SetStateZero(scratch); + for (int k = 0; + k < + other_fused_circuits[cur_batch_index][cur_internal_index].size(); + k++) { + qsim::ApplyFusedGate( + sim, other_fused_circuits[cur_batch_index][cur_internal_index][k], + scratch); + } + + std::complex result = ss.InnerProduct(sv, scratch); + (*output_tensor)(cur_batch_index, cur_internal_index) = + std::complex(static_cast(result.real()), + static_cast(result.imag())); + + old_batch_index = cur_batch_index; + } + }; + + const int64_t num_cycles = + 200 * (int64_t(1) << static_cast(max_num_qubits)); + context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor( + fused_circuits.size() * output_dim_internal_size, num_cycles, DoWork); + } +}; + +REGISTER_KERNEL_BUILDER(Name("TfqInnerProduct").Device(tensorflow::DEVICE_CPU), + TfqInnerProductOp); + +REGISTER_OP("TfqInnerProduct") + .Input("programs: string") + .Input("symbol_names: string") + .Input("symbol_values: float") + .Input("other_programs: string") + .Output("inner_products: complex64") + .SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) { + tensorflow::shape_inference::ShapeHandle programs_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &programs_shape)); + + tensorflow::shape_inference::ShapeHandle symbol_names_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &symbol_names_shape)); + + tensorflow::shape_inference::ShapeHandle symbol_values_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &symbol_values_shape)); + + tensorflow::shape_inference::ShapeHandle other_programs_shape; 
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &other_programs_shape)); + + tensorflow::shape_inference::DimensionHandle output_rows = + c->Dim(programs_shape, 0); + tensorflow::shape_inference::DimensionHandle output_cols = + c->Dim(other_programs_shape, 1); + c->set_output(0, c->Matrix(output_rows, output_cols)); + + return tensorflow::Status::OK(); + }); + +} // namespace tfq diff --git a/tensorflow_quantum/core/ops/parse_context.cc b/tensorflow_quantum/core/ops/parse_context.cc index 7126b7340..2d8f5b573 100644 --- a/tensorflow_quantum/core/ops/parse_context.cc +++ b/tensorflow_quantum/core/ops/parse_context.cc @@ -89,6 +89,43 @@ Status ParsePrograms(OpKernelContext* context, const std::string& input_name, return Status::OK(); } +Status ParsePrograms2D(OpKernelContext* context, const std::string& input_name, + std::vector>* programs) { + const tensorflow::Tensor* input; + Status status = context->input(input_name, &input); + if (!status.ok()) { + return status; + } + + if (input->dims() != 2) { + // Never parse anything other than a 2d array of circuits. + return Status(tensorflow::error::INVALID_ARGUMENT, + absl::StrCat("other_programs must be rank 2. Got rank ", + input->dims(), ".")); + } + + const auto program_strings = input->matrix(); + const int num_programs = program_strings.dimension(0); + const int num_entries = program_strings.dimension(1); + programs->assign(num_programs, std::vector(num_entries, Program())); + + auto DoWork = [&](int start, int end) { + for (int i = start; i < end; i++) { + OP_REQUIRES_OK( + context, + ParseProto(program_strings(i / num_entries, i % num_entries), + &programs->at(i / num_entries).at(i % num_entries))); + } + }; + + // TODO(mbbrough): Determine if this is a good cycle estimate. 
+ const int cycle_estimate = 1000; + context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor( + num_programs * num_entries, cycle_estimate, DoWork); + + return Status::OK(); +} + Status GetProgramsAndProgramsToAppend( OpKernelContext* context, std::vector* programs, std::vector* programs_to_append) { @@ -160,6 +197,50 @@ Status GetProgramsAndNumQubits( return Status::OK(); } +tensorflow::Status GetProgramsAndNumQubits( + OpKernelContext* context, std::vector* programs, + std::vector* num_qubits, + std::vector>* other_programs) { + // 1. Parse input programs + // 2. Parse other_programs + // 3. Convert GridQubit locations to integers and ensure exact matching. + Status status = ParsePrograms(context, "programs", programs); + if (!status.ok()) { + return status; + } + + status = ParsePrograms2D(context, "other_programs", other_programs); + if (!status.ok()) { + return status; + } + + if (programs->size() != other_programs->size()) { + return Status(tensorflow::error::INVALID_ARGUMENT, + absl::StrCat("programs and other_programs batch dimension", + " do not match. Foud: ", programs->size(), + " and ", other_programs->size())); + } + + // Resolve qubit ID's in parallel. + num_qubits->assign(programs->size(), -1); + auto DoWork = [&](int start, int end) { + for (int i = start; i < end; i++) { + Program& program = (*programs)[i]; + unsigned int this_num_qubits; + OP_REQUIRES_OK(context, ResolveQubitIds(&program, &this_num_qubits, + &(*other_programs)[i])); + (*num_qubits)[i] = this_num_qubits; + } + }; + + // TODO(mbbrough): Determine if this is a good cycle estimate. + const int cycle_estimate = 1000 * (*other_programs)[0].size(); + context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor( + num_qubits->size(), cycle_estimate, DoWork); + + return Status::OK(); +} + Status GetPauliSums(OpKernelContext* context, std::vector>* p_sums) { // 1. Parses PauliSum proto. 
diff --git a/tensorflow_quantum/core/ops/parse_context.h b/tensorflow_quantum/core/ops/parse_context.h index 37fdd6927..510598755 100644 --- a/tensorflow_quantum/core/ops/parse_context.h +++ b/tensorflow_quantum/core/ops/parse_context.h @@ -32,6 +32,11 @@ tensorflow::Status ParsePrograms( tensorflow::OpKernelContext* context, const std::string& input_name, std::vector* programs); +// Simplest Program proto parsing in 2D. +tensorflow::Status ParsePrograms2D( + tensorflow::OpKernelContext* context, const std::string& input_name, + std::vector>* programs); + // Parses a vector of programs along with another vector of programs to append tensorflow::Status GetProgramsAndProgramsToAppend( tensorflow::OpKernelContext* context, @@ -53,6 +58,16 @@ tensorflow::Status GetProgramsAndNumQubits( std::vector* num_qubits, std::vector>* p_sums = nullptr); +// Parses Cirq Program protos out of the 'programs' input Tensor. Also +// resolves the QubitIds inside of the Program. This overload also parses and +// resolves other_programs, ensuring all qubits found in programs[i] are also +// found in other_programs[i][j] for all j. +tensorflow::Status GetProgramsAndNumQubits( + tensorflow::OpKernelContext* context, + std::vector* programs, + std::vector* num_qubits, + std::vector>* other_programs); + // Parses PauliSum protos out of the 'pauli_sums' input tensor. Note this // function does NOT resolve QubitID's as any paulisum needs a reference // program to "discover" all of the active qubits and define the ordering.