Skip to content

Commit

Permalink
Address feedback.
Browse files Browse the repository at this point in the history
  • Loading branch information
cgranade committed Sep 7, 2022
1 parent ce6f9e4 commit 6356d40
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 54 deletions.
120 changes: 87 additions & 33 deletions src/qutip_qip/qir.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
from tempfile import NamedTemporaryFile
from typing import Union, overload, TYPE_CHECKING

if TYPE_CHECKING:
from typing_extensions import Literal

Expand All @@ -25,15 +26,14 @@
from qutip_qip.circuit import QubitCircuit
from qutip_qip.operations import Gate, Measurement

__all__ = [
"circuit_to_qir",
"QirFormat"
]
__all__ = ["circuit_to_qir", "QirFormat"]


class QirFormat(Enum):
"""
Specifies the format used to serialize QIR.
"""

#: Specifies that QIR should be encoded as LLVM bitcode (typically, files
#: ending in `.bc`).
BITCODE = auto()
Expand All @@ -44,7 +44,9 @@ class QirFormat(Enum):
MODULE = auto()

@classmethod
def ensure(cls, val : Union[Literal["bitcode", "text", "module"], QirFormat]) -> QirFormat:
def ensure(
cls, val: Union[Literal["bitcode", "text", "module"], QirFormat]
) -> QirFormat:
"""
Given a value, returns a value ensured to be of type `QirFormat`,
attempting to convert if needed.
Expand All @@ -56,38 +58,85 @@ def ensure(cls, val : Union[Literal["bitcode", "text", "module"], QirFormat]) ->

return cls(val)


# Specify return types for each different format, so that IDE tooling and type
# checkers can resolve the return type based on arguments.
@overload
def circuit_to_qir(circuit : QubitCircuit, format : Union[Literal[QirFormat.BITCODE], Literal["bitcode"]], module_name : str) -> bytes: ...
def circuit_to_qir(
circuit: QubitCircuit,
format: Union[Literal[QirFormat.BITCODE], Literal["bitcode"]],
module_name: str,
) -> bytes:
...


@overload
def circuit_to_qir(circuit : QubitCircuit, format : Union[Literal[QirFormat.TEXT], Literal["text"]], module_name : str) -> str: ...
def circuit_to_qir(
circuit: QubitCircuit,
format: Union[Literal[QirFormat.TEXT], Literal["text"]],
module_name: str,
) -> str:
...


@overload
def circuit_to_qir(circuit : QubitCircuit, format : Union[Literal[QirFormat.MODULE], Literal["module"]], module_name : str) -> pqp.QirModule: ...
def circuit_to_qir(
circuit: QubitCircuit,
format: Union[Literal[QirFormat.MODULE], Literal["module"]],
module_name: str,
) -> pqp.QirModule:
...


def circuit_to_qir(circuit, format, module_name="qutip_circuit"):
"""Converts a qubit circuit to its representation in QIR.
def circuit_to_qir(circuit, format, module_name = "qutip_circuit"):
"""
Given a circuit acting on qubits, generates a representation of that
circuit using Quantum Intermediate Representation (QIR).
:param circuit: The circuit to be translated to QIR.
:param format: The QIR serialization to be used. If `"text"`, returns a
Parameters
----------
circuit
The circuit to be translated to QIR.
format
The QIR serialization to be used. If `"text"`, returns a
plain-text representation using LLVM IR. If `"bitcode"`, returns a
dense binary representation ideal for use with other compilation tools.
If `"module"`, returns a PyQIR module object that can be manipulated
further before generating QIR.
:param module_name: The name of the module to be emitted.
module_name
The name of the module to be emitted.
Returns
-------
A QIR representation of `circuit`, encoded using the format specified by
`format`.
"""
# Define as an inner function to make it easier to call from conditional
# branches.
def append_operation(module: pqg.SimpleModule, builder: pqg.BasicQisBuilder, op: Gate):
def append_operation(
module: pqg.SimpleModule, builder: pqg.BasicQisBuilder, op: Gate
):
if op.classical_controls:
result = op.classical_controls[0]
value = "zero" if op.classical_control_value == 0 else "one"
# Pull off the first control and recurse.
op_with_less_controls = Gate(**op.__dict__)
op_with_less_controls.classical_controls = op_with_less_controls.classical_controls[1:]
branch_body = {value: (lambda: append_operation(module, builder, op_with_less_controls))}
op_with_less_controls.classical_controls = (
op_with_less_controls.classical_controls[1:]
)
op_with_less_controls.classical_control_value = (
(op_with_less_controls.classical_control_value[1:])
if op_with_less_controls.classical_control_value is not None
else None
)
branch_body = {
value: (
lambda: append_operation(
module, builder, op_with_less_controls
)
)
}
builder.if_result(module.results[result], **branch_body)
return

Expand All @@ -110,7 +159,13 @@ def append_operation(module: pqg.SimpleModule, builder: pqg.BasicQisBuilder, op:
elif op.name == "SNOT":
builder.h(module.qubits[op.targets[0]])
elif op.name in ("CNOT", "CX"):
builder.cx(module.qubits[op.controls[0]], module.qubits[op.targets[0]])
builder.cx(
module.qubits[op.controls[0]], module.qubits[op.targets[0]]
)
elif op.name == "CZ":
builder.cz(
module.qubits[op.controls[0]], module.qubits[op.targets[0]]
)
elif op.name == "RX":
builder.rx(op.control_value, module.qubits[op.targets[0]])
elif op.name == "RY":
Expand All @@ -119,14 +174,15 @@ def append_operation(module: pqg.SimpleModule, builder: pqg.BasicQisBuilder, op:
builder.rz(op.control_value, module.qubits[op.targets[0]])
elif op.name in ("CRZ", "TOFFOLI"):
raise NotImplementedError(
"Decomposition of CRZ and Toffoli gates into base " +
"profile instructions is not yet implemented."
"Decomposition of CRZ and Toffoli gates into base "
+ "profile instructions is not yet implemented."
)
else:
raise ValueError(
f"Gate {op.name} not supported by the basic QIR builder, " +
"and may require a custom declaration."
f"Gate {op.name} not supported by the basic QIR builder, "
+ "and may require a custom declaration."
)

fmt = QirFormat.ensure(format)

module = pqg.SimpleModule(module_name, circuit.N, circuit.num_cbits or 0)
Expand All @@ -141,19 +197,15 @@ def append_operation(module: pqg.SimpleModule, builder: pqg.BasicQisBuilder, op:
append_operation(module, builder, op)

elif isinstance(op, Measurement):
# TODO: Validate indices.
if op.name == "Z":
builder.m(module.qubits[op.targets[0]], module.results[op.classical_store])
else:
raise ValueError(
f"Measurement kind {op.name} not supported by the QIR " +
"base profile, and may require a custom declaration."
)
builder.m(
module.qubits[op.targets[0]],
module.results[op.classical_store],
)

else:
raise NotImplementedError(
f"Instruction {op} is not implemented in the QIR base " +
"profile and may require a custom declaration."
f"Instruction {op} is not implemented in the QIR base "
+ "profile and may require a custom declaration."
)

if fmt == QirFormat.TEXT:
Expand All @@ -162,7 +214,7 @@ def append_operation(module: pqg.SimpleModule, builder: pqg.BasicQisBuilder, op:
return module.bitcode()
elif fmt == QirFormat.MODULE:
bitcode = module.bitcode()
f = NamedTemporaryFile(suffix='.bc', delete=False)
f = NamedTemporaryFile(suffix=".bc", delete=False)
try:
f.write(bitcode)
finally:
Expand All @@ -174,4 +226,6 @@ def append_operation(module: pqg.SimpleModule, builder: pqg.BasicQisBuilder, op:
pass
return module
else:
assert False, "Internal error; should have caught invalid format enum earlier."
assert (
False
), "Internal error; should have caught invalid format enum earlier."
71 changes: 50 additions & 21 deletions tests/test_qir.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,48 +16,68 @@

T = typing.TypeVar("T")


def _assert_is_single(collection: typing.List[T]) -> T:
assert len(collection) == 1
return collection[0]

def _assert_arg_is_qubit(arg: pyqir.parser.QirOperand, idx: typing.Optional[int] = None):

def _assert_arg_is_qubit(
arg: pyqir.parser.QirOperand, idx: typing.Optional[int] = None
):
assert isinstance(arg, pyqir.parser.QirQubitConstant)
if idx is not None:
assert arg.id == idx

def _assert_arg_is_result(arg: pyqir.parser.QirOperand, idx: typing.Optional[int] = None):

def _assert_arg_is_result(
arg: pyqir.parser.QirOperand, idx: typing.Optional[int] = None
):
assert isinstance(arg, pyqir.parser.QirResultConstant)
if idx is not None:
assert arg.id == idx

def _assert_arg_is_double(arg: pyqir.parser.QirOperand, angle: typing.Optional[float] = None):

def _assert_arg_is_double(
arg: pyqir.parser.QirOperand, angle: typing.Optional[float] = None
):
assert isinstance(arg, pyqir.parser.QirDoubleConstant)
if angle is not None:
np.testing.assert_allclose(arg.value, angle)

def _assert_is_simple_qis_call(inst: pyqir.parser.QirInstr, gate_name: str, targets: typing.List[int]):

def _assert_is_simple_qis_call(
inst: pyqir.parser.QirInstr, gate_name: str, targets: typing.List[int]
):
assert isinstance(inst, pyqir.parser.QirQisCallInstr)
assert inst.func_name == f"__quantum__qis__{gate_name}__body"
assert len(inst.func_args) == len(targets)
for target, arg in zip(targets, inst.func_args):
_assert_arg_is_qubit(arg, target)

def _assert_is_rotation_qis_call(inst: pyqir.parser.QirInstr, gate_name: str, angle: float, target: int):

def _assert_is_rotation_qis_call(
inst: pyqir.parser.QirInstr, gate_name: str, angle: float, target: int
):
assert isinstance(inst, pyqir.parser.QirQisCallInstr)
assert inst.func_name == f"__quantum__qis__{gate_name}__body"
assert len(inst.func_args) == 2
angle_arg, target_arg = inst.func_args
_assert_arg_is_double(angle_arg, angle)
_assert_arg_is_qubit(target_arg, target)

def _assert_is_measurement_qis_call(inst: pyqir.parser.QirInstr, gate_name: str, target: int, result: int):

def _assert_is_measurement_qis_call(
inst: pyqir.parser.QirInstr, gate_name: str, target: int, result: int
):
assert isinstance(inst, pyqir.parser.QirQisCallInstr)
assert inst.func_name == f"__quantum__qis__{gate_name}__body"
assert len(inst.func_args) == 2
target_arg, result_arg = inst.func_args
_assert_arg_is_qubit(target_arg, target)
_assert_arg_is_result(result_arg, result)


class TestConverter:
"""
Test suite that checks that conversions from circuits to QIR produce
Expand All @@ -75,13 +95,15 @@ def test_simple_x_circuit(self):
"""
circuit = QubitCircuit(1)
circuit.add_gate("X", targets=[0])
parsed_qir_module: pyqir.parser.QirModule = qir.circuit_to_qir(circuit, format=qir.QirFormat.MODULE)
parsed_qir_module: pyqir.parser.QirModule = qir.circuit_to_qir(
circuit, format=qir.QirFormat.MODULE
)
parsed_func = _assert_is_single(parsed_qir_module.entrypoint_funcs)
assert parsed_func.required_qubits == 1
assert parsed_func.required_results == 0
parsed_block = _assert_is_single(parsed_func.blocks)
inst = _assert_is_single(parsed_block.instructions)
_assert_is_simple_qis_call(inst, 'x', [0])
_assert_is_simple_qis_call(inst, "x", [0])

def test_simple_cnot_circuit(self):
"""
Expand All @@ -90,13 +112,15 @@ def test_simple_cnot_circuit(self):
"""
circuit = QubitCircuit(2)
circuit.add_gate("CX", targets=[1], controls=[0])
parsed_qir_module: pyqir.parser.QirModule = qir.circuit_to_qir(circuit, format=qir.QirFormat.MODULE)
parsed_qir_module: pyqir.parser.QirModule = qir.circuit_to_qir(
circuit, format=qir.QirFormat.MODULE
)
parsed_func = _assert_is_single(parsed_qir_module.entrypoint_funcs)
assert parsed_func.required_qubits == 2
assert parsed_func.required_results == 0
parsed_block = _assert_is_single(parsed_func.blocks)
inst = _assert_is_single(parsed_block.instructions)
_assert_is_simple_qis_call(inst, 'cnot', [0, 1])
_assert_is_simple_qis_call(inst, "cnot", [0, 1])

def test_simple_rz_circuit(self):
"""
Expand All @@ -105,13 +129,15 @@ def test_simple_rz_circuit(self):
"""
circuit = QubitCircuit(1)
circuit.add_gate("RZ", targets=[0], control_value=0.123)
parsed_qir_module: pyqir.parser.QirModule = qir.circuit_to_qir(circuit, format=qir.QirFormat.MODULE)
parsed_qir_module: pyqir.parser.QirModule = qir.circuit_to_qir(
circuit, format=qir.QirFormat.MODULE
)
parsed_func = _assert_is_single(parsed_qir_module.entrypoint_funcs)
assert parsed_func.required_qubits == 1
assert parsed_func.required_results == 0
parsed_block = _assert_is_single(parsed_func.blocks)
inst = _assert_is_single(parsed_block.instructions)
_assert_is_rotation_qis_call(inst, 'rz', 0.123, 0)
_assert_is_rotation_qis_call(inst, "rz", 0.123, 0)

def test_teleport_circuit(self):
# NB: this test is a bit detailed, as it checks metadata throughout
Expand All @@ -128,7 +154,9 @@ def test_teleport_circuit(self):
circuit.add_gate("X", targets=[there], classical_controls=[0])
circuit.add_gate("Z", targets=[there], classical_controls=[1])

parsed_qir_module: pyqir.parser.QirModule = qir.circuit_to_qir(circuit, format=qir.QirFormat.MODULE)
parsed_qir_module: pyqir.parser.QirModule = qir.circuit_to_qir(
circuit, format=qir.QirFormat.MODULE
)
parsed_func = _assert_is_single(parsed_qir_module.entrypoint_funcs)
assert parsed_func.required_qubits == 3
assert parsed_func.required_results == 2
Expand All @@ -141,7 +169,6 @@ def assert_readresult(inst, result: int):
_assert_arg_is_result(arg, result)
return inst.output_name


entry = parsed_func.blocks[0]
then = parsed_func.blocks[1]
else_ = parsed_func.blocks[2]
Expand All @@ -155,13 +182,15 @@ def assert_readresult(inst, result: int):
# others names are semantically relevant, and thus can change
# without that being a breaking change.
assert entry.name == "entry"
_assert_is_rotation_qis_call(entry.instructions[0], 'rz', 0.123, msg)
_assert_is_simple_qis_call(entry.instructions[1], 'h', [here])
_assert_is_simple_qis_call(entry.instructions[2], 'cnot', [here, there])
_assert_is_simple_qis_call(entry.instructions[3], 'cnot', [msg, here])
_assert_is_simple_qis_call(entry.instructions[4], 'h', [msg])
_assert_is_measurement_qis_call(entry.instructions[5], 'mz', msg, 0)
_assert_is_measurement_qis_call(entry.instructions[6], 'mz', here, 1)
_assert_is_rotation_qis_call(entry.instructions[0], "rz", 0.123, msg)
_assert_is_simple_qis_call(entry.instructions[1], "h", [here])
_assert_is_simple_qis_call(
entry.instructions[2], "cnot", [here, there]
)
_assert_is_simple_qis_call(entry.instructions[3], "cnot", [msg, here])
_assert_is_simple_qis_call(entry.instructions[4], "h", [msg])
_assert_is_measurement_qis_call(entry.instructions[5], "mz", msg, 0)
_assert_is_measurement_qis_call(entry.instructions[6], "mz", here, 1)
cond_label = assert_readresult(entry.instructions[7], 0)
term = entry.terminator
assert isinstance(term, pyqir.parser.QirCondBrTerminator)
Expand Down

0 comments on commit 6356d40

Please sign in to comment.