# In [96]:
from pyquil import Program,get_qc
from pyquil.api import WavefunctionSimulator
from pyquil.wavefunction import get_bitstring_from_index
import numpy as np
from IPython.display import Latex,display_latex
from ipywidgets import interact
from ipywidgets.widgets.interaction import interactive
from ipywidgets import IntSlider

class ProgramOutput:
    """Compile a Quil program for "9q-square-qvm" and inspect it line by line.

    A prefix of the compiled program can be executed with ``run`` / ``step``.
    The resulting wavefunction and the classical registers ``s0`` .. ``s31``
    can then be shown as text (``repr``) or as LaTeX (``display``).

    Layout of ``self.Program`` (a list of Quil source lines):
        [``I 0`` .. ``I 8``] + compiled program lines.
    The compiled lines are assumed to begin with the 32 ``DECLARE`` lines
    injected in ``__init__`` (TODO confirm this holds for every compiler
    version). ``HEADER_LINES`` lines therefore precede the user's instructions,
    and ``end_line`` counts instructions after that header.
    """

    # Qubits on the "9q-square-qvm" lattice. One identity gate per qubit is
    # prepended, presumably so the simulated wavefunction always spans all
    # 2**N_QUBITS basis states that __repr__/display index.
    N_QUBITS = 9
    # Classical registers s0 .. s{N_REGISTERS-1}, each declared BIT[REGISTER_BITS].
    N_REGISTERS = 32
    REGISTER_BITS = 32
    # Identity gates + DECLARE lines that precede the user's instructions (== 41).
    HEADER_LINES = N_QUBITS + N_REGISTERS

    def __init__(self, P=""):
        """Compile Quil source ``P`` (with the s-register declarations prepended)."""
        self.qc = get_qc("9q-square-qvm")
        self.wfs = WavefunctionSimulator()
        declarations = [
            f"DECLARE s{i} BIT[{self.REGISTER_BITS}]" for i in range(self.N_REGISTERS)
        ]
        # Join with an explicit separator: the original `+` glued the first line
        # of P onto the last DECLARE whenever P did not start with a newline.
        source = "\n".join(declarations + [P])
        compiled_lines = str(self.qc.compile(Program(source))).split("\n")
        self.Program = [f"I {q}" for q in range(self.N_QUBITS)] + compiled_lines
        # Default position: every instruction after the header (what run(None) uses).
        self.end_line = len(self.Program) - self.HEADER_LINES

    def run(self, end_line=None):
        """Execute the header plus the first ``end_line`` instructions.

        Stores the QVM register map in ``self.outqc`` and the simulated
        wavefunction in ``self.outwfs``. ``end_line=None`` runs everything.
        """
        if end_line is None:
            end_line = len(self.Program) - self.HEADER_LINES
        prefix = Program("\n".join(self.Program[:self.HEADER_LINES + end_line]))
        self.outqc = self.qc.run(prefix).get_register_map()
        self.outwfs = self.wfs.wavefunction(prefix)
        self.end_line = end_line

    def step(self):
        """Advance one instruction (clamped to the program end) and re-run."""
        self.end_line = min(self.end_line + 1, len(self.Program) - self.HEADER_LINES)
        self.run(self.end_line)

    def _register_value(self, name):
        """Return ``(binary_string, signed_int)`` for row 0 of register ``name``.

        Row 0 of the register map is presumably the first shot. Index 0 of the
        register is treated as the most significant bit, and the bit pattern is
        read as a two's-complement signed integer.
        """
        bits = self.outqc[name][0]
        binary = "".join(str(b) for b in bits)
        # int() each bit so numpy small-int dtypes cannot overflow (NEP 50).
        value = sum(
            int(b) << (self.REGISTER_BITS - 1 - k) for k, b in enumerate(bits)
        )
        # Two's complement: a set top bit means negative. The original test
        # `> 2**32` could never be true, so negatives were never produced.
        if value >= 2 ** (self.REGISTER_BITS - 1):
            value -= 2 ** self.REGISTER_BITS
        return binary, value

    def __repr__(self):
        """Text dump of the last run: state, raw amplitudes and all registers."""
        s = [""]
        state = str(self.outwfs)
        if len(state) < 100:
            s.append("State :\t" + state)
        else:
            # Long states are elided to their first and last 30 characters.
            s.append("State :\t" + state[:30] + " .... " + state[-30:])
        psi = [self.outwfs[i] for i in range(2 ** self.N_QUBITS)]
        s.append("Psi :\t " + str(psi))
        s.append("\nRegisters\n")
        for i in range(self.N_REGISTERS):
            reg = f"s{i}"
            binary, value = self._register_value(reg)
            s.append(reg + "\t : 0b" + binary + " = " + str(value))
        return "\n".join(s)

    def display(self, states=None, regs=range(N_REGISTERS)):
        """Build a LaTeX view of the program, the wavefunction and the registers.

        Args:
            states: optional iterable of basis-state indices to show. When None,
                every amplitude that is non-zero after rounding to 4 decimals is
                shown; beyond 32 such rows, only the first and last 15 are kept.
            regs: indices of the ``s`` registers to tabulate.

        Returns:
            An ``IPython.display.Latex`` matrix with three columns: the program
            (lines before ``end_line`` wrapped in ``\\red``), the state vector,
            and the register table.
        """
        n_states = 2 ** self.N_QUBITS
        amplitudes = [np.round(self.outwfs[i], 4) for i in range(n_states)]
        labels = [str(a) for a in amplitudes]
        if states is None:
            # Compare numerically rather than against the string "0j", so that
            # amplitudes which round to -0 are hidden too.
            psi = [
                labels[i] + r"& \ket {" + str(i) + "}"
                for i in range(n_states)
                if amplitudes[i] != 0
            ]
            if len(psi) > 32:
                psi = psi[:15] + [r"\vdots"] * 2 + psi[-15:]
        else:
            psi = [labels[i] + r"& \ket {" + str(i) + "}" for i in states] + [r"\vdots"]
        psi = r"\begin{bmatrix}" + r" \\ ".join(psi) + r"\end{bmatrix}"

        rows = [r"\text{register} & \text{binary value} & \text{signed decimal}"]
        for i in regs:
            reg = f"s{i}"
            binary, value = self._register_value(reg)
            rows.append(reg + " & " + binary + " & " + str(value))
        R = r"\begin{bmatrix}" + r"\\ ".join(rows) + r" \end{bmatrix}"

        body = self.Program[self.HEADER_LINES:]
        ran = [r"\red{\text{" + x + r"}}" for x in body[:self.end_line]]
        not_ran = [r"\text{" + x + r"}" for x in body[self.end_line:]]
        P = r"\begin{bmatrix}" + r"\\ ".join(ran + not_ran) + r"\end{bmatrix}"

        Full = (
            r"$$ \begin{bmatrix}"
            + r"P & \ket\Psi & C \\"
            + r" &".join([P, psi, R])
            + r"\end{bmatrix} $$"
        )
        return Latex(Full)

    def run_and_display(self, i=None):
        """Run the first ``i`` instructions (all if None) and render the result."""
        self.run(i)
        display_latex(self.display())

    def interact(self):
        """Return a manual-trigger widget whose slider picks how many lines to run."""
        slider = IntSlider(min=0, max=len(self.Program) - self.HEADER_LINES, step=1)
        return interactive(self.run_and_display, {"manual": True}, i=slider)



# In [98]:
# Demo: apply H to qubits 0, 1 and 2, then measure qubit 0 into bit 1 of
# register s0. Render the state after the first 5 compiled instruction lines.
P = "\n".join(["", "H 0", "H 1", "H 2", "MEASURE 0 s0[1]", ""])
PO = ProgramOutput(P)
PO.run_and_display(5)