Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions tensorflow_quantum/core/serialize/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,37 @@ def _phase_flip_channel_deserializer():
args=args)


def _bit_flip_channel_serializer():
"""Make standard serializer for BitFlip channel."""
args = [
# cirq channels can't contain symbols.
cirq.google.SerializingArg(serialized_name="p",
serialized_type=float,
op_getter=lambda x: x.gate.p),
cirq.google.SerializingArg(serialized_name="control_qubits",
serialized_type=str,
op_getter=lambda x: ''),
cirq.google.SerializingArg(serialized_name="control_values",
serialized_type=str,
op_getter=lambda x: '')
]
return cirq.google.GateOpSerializer(gate_type=cirq.BitFlipChannel,
serialized_gate_id="BF",
args=args,
can_serialize_predicate=_CONSTANT_TRUE)


def _bit_flip_channel_deserializer():
"""Make standard deserializer for BitFlip channel."""
args = [
cirq.google.DeserializingArg(serialized_name="p",
constructor_arg_name="p")
]
return cirq.google.GateOpDeserializer(serialized_gate_id="BF",
gate_constructor=cirq.BitFlipChannel,
args=args)


# Gates.
def _eigen_gate_serializer(gate_type, serialized_id):
"""Make standard serializer for eigen gates."""
Expand Down Expand Up @@ -699,6 +730,7 @@ def _scalar_combiner(exponent, global_shift, exponent_scalar,
] + [
_amplitude_damp_channel_serializer(),
_asymmetric_depolarize_serializer(),
_bit_flip_channel_serializer(),
_depolarize_channel_serializer(),
_fsim_gate_serializer(),
_gad_channel_serializer(),
Expand All @@ -717,6 +749,7 @@ def _scalar_combiner(exponent, global_shift, exponent_scalar,
] + [
_amplitude_damp_channel_deserializer(),
_asymmetric_depolarize_deserializer(),
_bit_flip_channel_deserializer(),
_depolarize_channel_deserializer(),
_fsim_gate_deserializer(),
_gad_channel_deserializer(),
Expand Down
6 changes: 5 additions & 1 deletion tensorflow_quantum/core/serialize/serializer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,11 @@ def _get_noise_proto_pairs():

# Phase flip.
(cirq.Circuit(cirq.phase_flip(p=0.1)(q0)),
_build_op_proto("PF", ['p'], [0.1], ['0_0']))
_build_op_proto("PF", ['p'], [0.1], ['0_0'])),

# Bit flip.
(cirq.Circuit(cirq.bit_flip(p=0.1)(q0)),
_build_op_proto("BF", ['p'], [0.1], ['0_0']))
]
return pairs

Expand Down
22 changes: 21 additions & 1 deletion tensorflow_quantum/core/src/circuit_parser_qsim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,26 @@ inline Status PhaseFlipChannel(const Operation& op,
return Status::OK();
}

inline Status BitFlipChannel(const Operation& op, const unsigned int num_qubits,
const unsigned int time,
NoisyQsimCircuit* ncircuit) {
int q;
bool unused;
float p;
Status u;
unused = absl::SimpleAtoi(op.qubits(0).id(), &q);

u = ParseProtoArg(op, "p", {}, &p);
if (!u.ok()) {
return u;
}

auto chan =
qsim::Cirq::BitFlipChannel<float>::Create(time, num_qubits - q - 1, p);
ncircuit->channels.push_back(chan);
return Status::OK();
}

tensorflow::Status ParseAppendChannel(const Operation& op,
const unsigned int num_qubits,
const unsigned int time,
Expand All @@ -750,7 +770,7 @@ tensorflow::Status ParseAppendChannel(const Operation& op,
{"DP", &DepolarizingChannel}, {"ADP", &AsymmetricDepolarizingChannel},
{"GAD", &GADChannel}, {"AD", &AmplitudeDampingChannel},
{"RST", &ResetChannel}, {"PD", &PhaseDampingChannel},
{"PF", &PhaseFlipChannel}};
{"PF", &PhaseFlipChannel}, {"BF", &BitFlipChannel}};

auto build_f = chan_func_map.find(op.gate().id());
if (build_f == chan_func_map.end()) {
Expand Down
36 changes: 36 additions & 0 deletions tensorflow_quantum/core/src/circuit_parser_qsim_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1495,6 +1495,42 @@ TEST(QsimCircuitParserTest, PhaseFlip) {
ASSERT_EQ(test_circuit.num_qubits, 1);
}

TEST(QsimCircuitParserTest, BitFlip) {
float p = 0.1234;
auto reference = qsim::Cirq::BitFlipChannel<float>::Create(0, 0, p);
Program program_proto;
Circuit* circuit_proto = program_proto.mutable_circuit();
circuit_proto->set_scheduling_strategy(circuit_proto->MOMENT_BY_MOMENT);
Moment* moments_proto = circuit_proto->add_moments();

// Add channel.
Operation* operations_proto = moments_proto->add_operations();
Gate* gate_proto = operations_proto->mutable_gate();
gate_proto->set_id("BF");

// Set the args.
google::protobuf::Map<std::string, Arg>* args_proto =
operations_proto->mutable_args();
(*args_proto)["p"] = MakeArg(p);

// Set the control args.
(*args_proto)["control_qubits"] = MakeControlArg("");
(*args_proto)["control_values"] = MakeControlArg("");

// Set the qubits.
Qubit* qubits_proto = operations_proto->add_qubits();
qubits_proto->set_id("0");

NoisyQsimCircuit test_circuit;

ASSERT_EQ(
NoisyQsimCircuitFromProgram(program_proto, {}, 1, false, &test_circuit),
tensorflow::Status::OK());
AssertChannelEqual(test_circuit.channels[0], reference);
ASSERT_EQ(test_circuit.channels.size(), 1);
ASSERT_EQ(test_circuit.num_qubits, 1);
}

TEST(QsimCircuitParserTest, NoisyEmpty) {
Program program_proto;
Circuit* circuit_proto = program_proto.mutable_circuit();
Expand Down
6 changes: 6 additions & 0 deletions tensorflow_quantum/python/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
cirq.ResetChannel,
cirq.PhaseDampingChannel,
cirq.PhaseFlipChannel,
cirq.BitFlipChannel,
]


Expand Down Expand Up @@ -84,6 +85,7 @@ def get_supported_channels():
channel_mapping[cirq.ResetChannel()] = 1
channel_mapping[cirq.PhaseDampingChannel(0.01)] = 1
channel_mapping[cirq.PhaseFlipChannel(0.01)] = 1
channel_mapping[cirq.BitFlipChannel(0.01)] = 1

return channel_mapping

Expand Down Expand Up @@ -534,6 +536,10 @@ def _channel_approx_eq(op_true, op_deser, atol=1e-5):
if isinstance(op_deser, cirq.PhaseFlipChannel):
return abs(op_true.p - op_deser.p) < atol

if isinstance(op_true, cirq.BitFlipChannel):
if isinstance(op_deser, cirq.BitFlipChannel):
return abs(op_true.p - op_deser.p) < atol

return False


Expand Down