In [42]:
import numpy as np
import pyqir
import qualtran
from qualtran.bloqs.basic_gates import Hadamard, XGate, YGate, ZGate, Rx, Ry, Rz, TGate, TwoBitSwap, CNOT, Toffoli
from qualtran.bloqs.qft.qft_text_book import QFTTextBook
import pyqir
from pyqir import (
    BasicBlock,
    Builder,
    Context,
    Function,
    Linkage,
    Module,
    ModuleFlagBehavior,
)

In [60]:
import attrs
from typing import Dict
from qualtran import Bloq, BloqBuilder, SoquetT, Signature, Register


@attrs.frozen
class Swap(Bloq):
    n: int

    @property
    def signature(self):
        return Signature.build(x=self.n, y=self.n)

    def build_composite_bloq(
            self, bb: BloqBuilder, *, x: SoquetT, y: SoquetT
    ) -> Dict[str, SoquetT]:
        xs = bb.split(x)
        ys = bb.split(y)

        for i in range(self.n):
            xs[i], ys[i] = bb.add(CNOT(), ctrl=xs[i], target=ys[i])
        return {
            'x': bb.join(xs),
            'y': bb.join(ys),
        }

@attrs.frozen
class ExampleBaseBloq(Bloq):
    @property
    def signature(self):
        return Signature.build(x=1, y=1)

@attrs.frozen
class ExampleHighLevelBloq(Bloq):
    n: int

    @property
    def signature(self):
        return Signature.build(x=self.n, y=self.n)

    def build_composite_bloq(
            self, bb: BloqBuilder, *, x: SoquetT, y: SoquetT
    ) -> Dict[str, SoquetT]:
        xs = bb.split(x)
        ys = bb.split(y)

        for i in range(self.n):
            xs[i], ys[i] = bb.add(ExampleBaseBloq(), x=xs[i], y=ys[i])
        return {
            'x': bb.join(xs),
            'y': bb.join(ys),
        }

@attrs.frozen
class ExampleNonTrivialShapeBloq(Bloq):
    n: int

    @property
    def signature(self):
        return Signature([
            Register('x', bitsize=1, shape=(self.n,)),
            Register('y', bitsize=1, shape=(self.n,)),
        ])

    def build_composite_bloq(
            self, bb: BloqBuilder, *, x: SoquetT, y: SoquetT
    ) -> Dict[str, SoquetT]:
        for i in range(self.n):
            x[i], y[i] = bb.add(ExampleBaseBloq(), x=x[i], y=y[i])
        return {'x': x, 'y': y}

In [3]:
PYQIR_OP_MAP = {
    # Single-Qubit Clifford Gates
    Hadamard: pyqir._native.h,
    XGate: pyqir._native.x,
    YGate: pyqir._native.y,
    ZGate: pyqir._native.z,
    # Single-Qubit Rotation Gates
    Rx: pyqir._native.rx,
    Ry: pyqir._native.ry,
    Rz: pyqir._native.rz,
    # Single-Qubit Non-Clifford Gates
    TGate: pyqir._native.t,
    # Two-Qubit Gates
    TwoBitSwap: pyqir._native.swap,
    CNOT: pyqir._native.cx,
    # Three-Qubit Gates
    Toffoli: pyqir._native.ccx,
}

In [110]:
def get_num_qubits_for_bloq(bloq: qualtran.Bloq):
    num_qubits = 0
    for register in bloq.signature.lefts():
        shape = register.shape[0] if len(register.shape) != 0 else 1
        num_qubits += register.bitsize*shape
    return num_qubits

def create_func_for_bloq(bloq: qualtran.Bloq, name, qubit_type, void_type, mod: pyqir.Module):
    num_qubits = get_num_qubits_for_bloq(bloq)
    return Function(
    pyqir.FunctionType(void_type, [qubit_type]*num_qubits),
    Linkage.EXTERNAL,
    name,
    mod
    )

def create_ir_map(bloq: qualtran.Bloq):
    param_counter = 0
    ir_map = {}
    # Loop through all registers in signature
    for register in bloq.signature.lefts():
        shape = register.shape[0] if len(register.shape) != 0 else 1 # get the shape as an int (we are asumming its 1d for simplicity)
        # map the (register_name, index_in_register) to the overall index
        ir_map.update({(register.name, i*register.bitsize + j): param_counter + i*register.bitsize + j for i in range(shape) for j in range(register.bitsize)})
        param_counter += register.bitsize * shape
    return ir_map


def find_ir_for_index(i, ir_map):
    for key, val in ir_map.items():
        if i == val:
            return key
    raise Exception(f"Error in find_ir_for_index: No IR for index {i}")

def get_indexes_from_reg(ir_map, target_reg_name):
    return [ir_map[(reg_name, i)] for (reg_name, i) in ir_map.keys() if reg_name == target_reg_name]

def generate_irs_from_soquet(soq):
    reg_name = soq.reg.name
    starting_index = soq.idx[0] if len(soq.idx) != 0 else 0
    return [(reg_name, starting_index*soq.reg.bitsize + i) for i in range(soq.reg.bitsize)]

def map_soquet_to_params(soq, soq_map, ir_map):
    if isinstance(soq, np.ndarray):
        # renaming to soqs for clarity
        soqs = soq
        return [el for soq in soqs for el in map_soquet_to_params_helper(soq, soq_map, ir_map)]
    return map_soquet_to_params_helper(soq, soq_map, ir_map)

def map_soquet_to_params_helper(soquet, soq_map, ir_map):
    if soquet in soq_map:
        irs = soq_map[soquet]
        return [ir_map[ir] for ir in irs]
    irs = generate_irs_from_soquet(soquet)
    return [ir_map[ir] for ir in irs]

def compile_bloq(bloq: qualtran.Bloq, qubit_type, void_type, module, context, builder, func_dict=dict()):
    func_name = f"{bloq.pretty_name()}_{get_num_qubits_for_bloq(bloq)}"
    if func_name in func_dict:
        return func_dict[func_name]
    bloq_func = create_func_for_bloq(bloq, func_name, qubit_type, void_type, module)
    ir_map = create_ir_map(bloq)
    soq_map = {}

    # It seems like QFTTextBook and PhaseGradientUnitary can be decomposed even though supports_decompose_bloq returns false
    if not bloq.supports_decompose_bloq() and bloq.pretty_name() not in ('QFTTextBook', 'PhaseGradientUnitary'):
        return bloq_func, ir_map

    # Create a block to insert instructions in
    basic_block = BasicBlock(context, "block1", bloq_func)

    # iterate through the DAG of sub bloqs
    for sub_bloq in bloq.decompose_bloq().iter_bloqsoqs():
        bloq_instance, inputs, outputs = sub_bloq
        if bloq_instance.bloq.short_name() == 'Split':
            reg_name = inputs['reg'].reg.name
            for i in range(len(outputs[0])):
                ir = soq_map[inputs['reg']][i] if inputs['reg'] in soq_map else (reg_name, i)
                soq_map[outputs[0][i]] = [ir]
            continue
        if bloq_instance.bloq.short_name() == 'Join':
            irs = [el for soq in inputs['reg'] for el in soq_map[soq]]
            soq_map[outputs[0]] = irs
            continue
        if type(bloq_instance.bloq) in PYQIR_OP_MAP:
            builder.insert_at_end(basic_block)
            param_indexes = [map_soquet_to_params(soquet, soq_map, ir_map)[0] for soquet in inputs.values()] # always one element soquets in elem bloqs
            params = [bloq_func.params[i] for i in param_indexes]
            PYQIR_OP_MAP[type(bloq_instance.bloq)](builder, *params)
        else:
            param_indexes = [None for _ in range(get_num_qubits_for_bloq(bloq_instance.bloq))]

            sub_bloq_func, sub_bloq_ir_map = compile_bloq(bloq_instance.bloq, qubit_type, void_type, module, context, builder, func_dict)
            sub_func_name = f"{bloq_instance.bloq.pretty_name()}_{get_num_qubits_for_bloq(bloq_instance.bloq)}"
            func_dict[sub_func_name] = sub_bloq_func, sub_bloq_ir_map
            for key in inputs.keys():
                caller_param_indices_for_key = map_soquet_to_params(inputs[key], soq_map, ir_map)
                callee_param_indices_for_key = get_indexes_from_reg(sub_bloq_ir_map, key)
                for param_index, qubit_param_index in list(zip(callee_param_indices_for_key, caller_param_indices_for_key)):
                    param_indexes[param_index] = qubit_param_index

            builder.insert_at_end(basic_block)
            params = [bloq_func.params[i] for i in param_indexes]
            builder.call(sub_bloq_func, params)

        # map soquets to ir
        param_index = 0
        for soq in outputs:
            if not isinstance(soq, np.ndarray):
                soq = np.array([soq])
            for single_soq in soq:
                irs = []
                for i in range(single_soq.reg.bitsize):
                    irs.append(find_ir_for_index(param_indexes[param_index], ir_map))
                    param_index+=1
                soq_map[single_soq] = irs

    return bloq_func, ir_map


def convert_qualtran(bloq: qualtran.Bloq):
    context = Context()
    mod = pyqir.qir_module(
        context,
        "Main"
    )
    builder = Builder(context)
    qubit_type = pyqir.qubit_type(context)
    void_type = pyqir.Type.void(context)
    entry = pyqir.entry_point(
            mod, "main", get_num_qubits_for_bloq(bloq), 0
        )
    entry_block = BasicBlock(context, "entry", entry)
    builder.insert_at_end(entry_block)
    qubits = [pyqir.qubit(context, n) for n in range(get_num_qubits_for_bloq(bloq))]
    bloq_func, _ = compile_bloq(bloq, qubit_type, void_type, mod, context, builder)
    builder.insert_at_end(entry_block)
    builder.call(bloq_func, qubits)
    return mod

qft_bloq = QFTTextBook(20)
qft_mod = convert_qualtran(qft_bloq)

In [111]:
print(qft_mod)

; ModuleID = 'Main'
source_filename = "Main"

%Qubit = type opaque

define void @main() #0 {
entry:
  call void @QFTTextBook_20(%Qubit* null, %Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 2 to %Qubit*), %Qubit* inttoptr (i64 3 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*), %Qubit* inttoptr (i64 5 to %Qubit*), %Qubit* inttoptr (i64 6 to %Qubit*), %Qubit* inttoptr (i64 7 to %Qubit*), %Qubit* inttoptr (i64 8 to %Qubit*), %Qubit* inttoptr (i64 9 to %Qubit*), %Qubit* inttoptr (i64 10 to %Qubit*), %Qubit* inttoptr (i64 11 to %Qubit*), %Qubit* inttoptr (i64 12 to %Qubit*), %Qubit* inttoptr (i64 13 to %Qubit*), %Qubit* inttoptr (i64 14 to %Qubit*), %Qubit* inttoptr (i64 15 to %Qubit*), %Qubit* inttoptr (i64 16 to %Qubit*), %Qubit* inttoptr (i64 17 to %Qubit*), %Qubit* inttoptr (i64 18 to %Qubit*), %Qubit* inttoptr (i64 19 to %Qubit*))
}

define void @QFTTextBook_20(%Qubit* %0, %Qubit* %1, %Qubit* %2, %Qubit* %3, %Qubit* %4, %Qubit* %5, %Qubit* %6, %Qubit* %7, %Qubit* %8, %Qubit