from tket.circuit import Tk2Circuit
from tket.pattern import Rule, RuleMatcher

import sympy
from pytket import Circuit as Tk1Circuit
from typing import Any
from tket.passes import lower_to_pytket

from guppylang import guppy
from guppylang.std.quantum import qubit, cx, rz
from guppylang.std.qsystem import zz_phase
from guppylang.std.angles import angle
from guppylang.std.builtins import owned

def guppy_to_circuit(func_def: Any) -> Tk2Circuit:
    """Compile a Guppy function definition into a lowered ``Tk2Circuit``.

    The function is compiled to a package and serialised to bytes. The bytes
    are loaded back as a ``Tk2Circuit`` rooted at the entrypoint function of
    the package's first module, and the result is passed through
    ``lower_to_pytket``.

    Args:
        func_def: A ``@guppy``-decorated function. Only its
            ``compile_function()`` method is used here.

    Returns:
        The circuit produced by ``lower_to_pytket``.
    """
    package = func_def.compile_function()
    # The HUGR function name is needed to select which function in the
    # serialised package becomes the circuit.
    entry_name = package.modules[0].entrypoint_op().f_name
    hugr_circuit = Tk2Circuit.from_bytes(package.to_bytes(), entry_name)
    return lower_to_pytket(hugr_circuit)

@guppy
def cnots_to_zzphase_lhs(
    q0: qubit @ owned, q1: qubit @ owned, angle: angle
) -> tuple[qubit, qubit]:
    # Left-hand side of the CX-Rz-CX -> ZZPhase rewrite rule: an Rz on q1
    # sandwiched between two CX gates (q0 control, q1 target). It is paired
    # with `cnots_to_zzphase_rhs` in a module-level `Rule`.
    # This body is compiled by Guppy, and the gate order defines the pattern
    # that is matched, so do not reorder these statements.
    # NOTE: the parameter `angle` shadows the imported `angle` type. In the
    # signature the annotation refers to the type; in the body, `angle` is
    # the rotation argument.
    # Both qubits are taken as `owned` and returned, giving the circuit two
    # qubit inputs and two qubit outputs.
    cx(q0, q1)
    rz(q1, angle)
    cx(q0, q1)
    return (q0, q1)

@guppy
def cnots_to_zzphase_rhs(
    q0: qubit @ owned, q1: qubit @ owned, angle: angle
) -> tuple[qubit, qubit]:
    # Right-hand side of the CX-Rz-CX -> ZZPhase rewrite rule: a single
    # two-qubit `zz_phase` on (q0, q1) using the same rotation argument.
    # The signature mirrors `cnots_to_zzphase_lhs`, so both circuits have the
    # same inputs and outputs (presumably required for the Rule replacement).
    # NOTE: inside the body, the parameter `angle` shadows the imported
    # `angle` type.
    zz_phase(q0, q1, angle)
    return (q0, q1)

# Rule circuits compiled from the Guppy kernels via `guppy_to_circuit`
# (left-hand side first, then right-hand side).
guppy_cnots_to_zzphase_lhs, guppy_cnots_to_zzphase_rhs = map(
    guppy_to_circuit, (cnots_to_zzphase_lhs, cnots_to_zzphase_rhs)
)

# The same rule written directly in pytket, parameterised by a symbolic angle.
theta = sympy.Symbol("theta")
pattern_lhs = Tk1Circuit(2).CX(0, 1).Rz(theta, 1).CX(0, 1)
pattern_rhs = Tk1Circuit(2).ZZPhase(theta, 0, 1)

tket_cnots_to_zzphase_lhs, tket_cnots_to_zzphase_rhs = (
    Tk2Circuit(pattern_lhs),
    Tk2Circuit(pattern_rhs),
)

# One single-rule matcher for each way of building the rule circuits.
guppy_cnots_to_zzphase = Rule(
    l=guppy_cnots_to_zzphase_lhs,
    r=guppy_cnots_to_zzphase_rhs,
)
guppy_cnots_matcher = RuleMatcher([guppy_cnots_to_zzphase])

tket_cnots_to_zzphase = Rule(
    l=tket_cnots_to_zzphase_lhs,
    r=tket_cnots_to_zzphase_rhs,
)
tket_cnots_matcher = RuleMatcher([tket_cnots_to_zzphase])

def _apply_rewrites_to_fixpoint(matcher: RuleMatcher, circ: Tk2Circuit) -> None:
    """Repeatedly apply matches found by *matcher* to *circ* until none remain.

    ``circ`` is modified via ``apply_rewrite``. ``find_match`` is queried
    again after every rewrite because each rewrite changes the circuit being
    searched.
    """
    # Compare with None explicitly instead of relying on the truthiness of
    # the returned rewrite object.
    while (rewrite := matcher.find_match(circ)) is not None:
        circ.apply_rewrite(rewrite)


original_circ = (
    Tk1Circuit(3).CX(1, 2).Rz(1 / 4, 2).CX(1, 2).CX(0, 1).Rz(1 / 2, 1).CX(0, 1)
)
merged_circ = Tk2Circuit(original_circ)
print("Original circuit:")
print(merged_circ.render_mermaid())

_apply_rewrites_to_fixpoint(guppy_cnots_matcher, merged_circ)

print("Rewritten circuit (guppy):")
print(merged_circ.render_mermaid())

# Rebuild from the original pytket circuit, so the tket rule is applied to
# the unrewritten input rather than to the guppy-rewritten circuit.
merged_circ = Tk2Circuit(original_circ)
_apply_rewrites_to_fixpoint(tket_cnots_matcher, merged_circ)

print("Rewritten circuit (tket):")
print(merged_circ.render_mermaid())
