In [1]:
import vpython
from vpython import canvas, color, vector, sphere, box, cylinder, arrow, cone, text

<IPython.core.display.Javascript object>

In [2]:

def draw_axes(L=100,w=1):
    """Draw the x, y and z coordinate axes as arrows starting at the origin.

    Args:
        L: Length of each arrow.
        w: Shaft width of each arrow.

    NOTE(review): `label` is passed through to vpython's `arrow`; this function
    does not draw any text for it — confirm whether visible labels were intended.
    """
    axis_specs = (
        (vector(1, 0, 0), color.red, "x"),
        (vector(0, 1, 0), color.green, "y"),
        (vector(0, 0, 1), color.blue, "z"),
    )
    for direction, axis_color, axis_name in axis_specs:
        # A fresh origin vector per arrow, as in a separate call for each axis.
        arrow(pos=vector(0, 0, 0), axis=direction, color=axis_color, shaftwidth=w, length=L, label=axis_name)

def print_camera_info(scene):
    """Print camera position, axis, up vector and scene centre, one per line.

    NOTE(review): the "axis" line reads `scene.axis`, while "pos" and "up" are
    read from `scene.camera`. The recorded notebook output shows a unit vector
    here even though generate_code sets camera.axis to (400, 200, -200) —
    confirm which quantity is intended.
    """
    cam = scene.camera
    print(f"pos = {cam.pos}")
    print(f"axis = {scene.axis}")
    print(f"up = {cam.up}")
    print(f"center = {scene.center}")
    

def draw_line(start, end, color, radius=0.6):
    """Draw a straight segment from `start` to `end` as a thin cylinder."""
    direction = end - start
    cylinder(pos=start, axis=direction, radius=radius, color=color)
    

def cylinder_arrow(pos, axis, color, radius=0.9):
    """Draw an arrow from `pos` along `axis` as a cylinder shaft plus cone head.

    The shaft spans the first 80% of `axis`. The cone covers the remaining 20%
    up to the tip, and its radius is 1.8x the shaft radius.
    """
    tip = pos + axis
    shaft = 0.8*axis
    head_base = pos + shaft
    cylinder(pos=pos, axis=shaft, radius=radius, color=color)
    cone(pos=head_base, axis=tip - head_base, radius=1.8*radius, color=color)


In [3]:
# create square coupler surface code
# Colour scheme shared by the drawing functions below.

# Coupling / E-field colours: one colour per coupling direction, plus the constant field.
flag_data_color = color.yellow
data_flag_color = color.blue
measure_flag_color = color.red
flag_measure_color = color.green
const_E_color = color.black

# Qubit colours.
data_color = color.orange
flag_color = color.white
measure_color = color.black

# Label colour (white).
# NOTE(review): the canvases created below use a white background — confirm labels are legible.
text_color = vector(1, 1, 1)


def draw_E_field_directions(pos, size = 30):
    """Draw a legend of E-field directions: five arrows from `pos` labelled E0-E4.

    E0 points along +z in const_E_color. E1, E2, E3 and E4 point along +x, +y,
    -x and -y in flag_measure_color, data_flag_color, measure_flag_color and
    flag_data_color respectively.

    Args:
        pos: Common origin of all arrows.
        size: Arrow length.
    """
    directions = (
        (vector(0, 0, size), const_E_color),
        (vector(size, 0, 0), flag_measure_color),
        (vector(0, size, 0), data_flag_color),
        (vector(-size, 0, 0), measure_flag_color),
        (vector(0, -size, 0), flag_data_color),
    )
    for direction, arrow_color in directions:
        cylinder_arrow(pos=pos, axis=direction, color=arrow_color)

    text_height = 6
    # (label, offset from pos, text up-vector); E0 stands upright, the rest lie in the x-y plane.
    labels = (
        ("E0", vector(0, 0, size + 1.5), vector(0, 0, 1)),
        ("E1", vector(size, -text_height/2, 0), vector(0, 1, 0)),
        ("E2", vector(0, size, 0), vector(0, 1, 0)),
        ("E3", vector(-size - text_height, -text_height/2, 0), vector(0, 1, 0)),
        ("E4", vector(0, -size - text_height, 0), vector(0, 1, 0)),
    )
    for label, offset, label_up in labels:
        text(text=label, pos=pos + offset, height=text_height, up=label_up, color=text_color)

    

def draw_qubit_legend(pos, q_radius=4.5, sep = 30, inactive_measure=False):
    """Draw a row of labelled example qubits along x.

    From +x to -x the row is: Data at pos + sep, Measure at pos, Flag at
    pos - sep. When `inactive_measure` is set, a faded (opacity 0.2)
    "Measure (1P+)" sphere is added at pos - 2*sep.

    Args:
        pos: Position of the measure-qubit entry.
        q_radius: Sphere radius.
        sep: Spacing between entries along x.
        inactive_measure: Whether to add the faded measure-qubit entry.
    """
    label_up = vector(1, 0, 0)
    label_height = 6
    label_offset = vector(-6, -9, 0)

    def _entry(center, fill, caption, **sphere_opts):
        # One legend item: a sphere plus its caption just below-left of it.
        sphere(pos=center, color=fill, radius=q_radius, **sphere_opts)
        text(text=caption, pos=center + label_offset, up=label_up, height=label_height, color=text_color)

    _entry(pos + vector(sep, 0, 0), data_color, "Data (1P)")
    _entry(pos, measure_color, "Measure (1P)")
    _entry(pos + vector(-sep, 0, 0), flag_color, "Flag (2P)")
    if inactive_measure:
        _entry(pos + vector(-2*sep, 0, 0), measure_color, "Measure (1P+)", opacity=0.2)

def draw_SET(x, y, z, length = 13, height=3):
    """Draw a single SET as a box centred at (x, y, z).

    `length` is used for both the box's `length` and `height` attributes, and
    `height` is passed as the box's `width`. In vpython's default orientation,
    this is presumably a `length` x `length` square in the x-y plane, `height`
    thick along z — confirm.
    """
    centre = vector(x, y, z)
    box(pos=centre, length=length, width=height, height=length)
    
    
def draw_SETs(N, sep=5, side_length=13, height=2):
    """Draw SETs at (i*sep, j*sep, 0) for i, j in range(N-1).

    A SET is placed where i is odd and i % 4 == j % 4, i.e. where both indices
    are 1 mod 4 or both are 3 mod 4.
    """
    for i in range(N - 1):
        if i % 2 == 0:
            # Even rows never hold a SET.
            continue
        for j in range(N - 1):
            if j % 4 == i % 4:
                draw_SET(i * sep, j * sep, 0, side_length, height=height)
            
            
def draw_wire(orientation, pos, length, radius=5, color=color.blue, z=20):
    """Draw one translucent (opacity 0.3) control wire as a cylinder.

    Args:
        orientation: 'x' for a wire running along +x, starting at (0, pos, +z).
            'y' for a wire running along +y, starting at (pos, 0, -z).
        pos: Offset of the wire perpendicular to its running direction.
        length: Length of the wire.
        radius: Cylinder radius.
        color: Wire colour.
        z: Height offset. x-wires sit at +z and y-wires at -z (presumably so
            the two wire layers do not intersect — confirm).

    Raises:
        ValueError: If `orientation` is neither 'x' nor 'y'. This is a subclass
            of Exception, so existing `except Exception` handlers still catch it.
    """
    if orientation == 'x':
        start = vector(0, pos, z)
        # Keep the axis in the z = const plane so the wire stays parallel to the
        # lattice; a z component here would tilt the wire by `z` over its length.
        axis = vector(length, 0, 0)
    elif orientation == 'y':
        start = vector(pos, 0, -z)
        axis = vector(0, length, 0)
    else:
        raise ValueError("Invalid orientation: must be 'x' or 'y'.")
    cylinder(pos = start, axis = axis, radius = radius, color=color, opacity=0.3)
    
            
def draw_wires(N, sep, radius=2.5):
    """Draw white control wires on every odd lattice line, in both directions.

    For each odd k < N, a y-wire at x = k*sep and an x-wire at y = k*sep are
    drawn. Each spans (N-1)*sep.
    """
    span = (N - 1) * sep
    for k in range(1, N, 2):
        offset = k * sep
        draw_wire('y', offset, span, radius=radius, color=color.white)
        draw_wire('x', offset, span, radius=radius, color=color.white)
            
    
def draw_control_apparatus(N, sep = 5, wire_radius=5, SET_length=9, SET_height=1.5):
    """Draw the control-hardware layer: the SET boxes first, then the wires."""
    draw_SETs(N, sep, SET_length, SET_height)
    draw_wires(N, sep=sep, radius=wire_radius)
    
def draw_data_and_flag_qubits(N, sep, qubit_radius=3):
    """Draw the flag and data qubits of an N x N site grid.

    Site (i, j) sits at (i*sep, j*sep, z). Only even-i columns are drawn here:
    - (even i, even j): flag qubit at z = +sep/6.
    - (even i, odd j): data qubit at z = -sep/6.
    Odd-i sites are not drawn by this function.
    """
    height = sep/6
    for i in range(0, N, 2):
        for j in range(N):
            if j % 2 == 0:
                z, qubit_color = height, flag_color
            else:
                z, qubit_color = -height, data_color
            sphere(pos=vector(i * sep, j * sep, z), radius=qubit_radius, color=qubit_color)
            
def draw_measure_1_qubits(N, sep, qubit_radius=3, qubit_color = measure_color, opacity=1):
    """Draw the "M1" measure qubits.

    Measure sites are (odd i, even j). This draws the subset with
    (i + j) % 4 == 1, as spheres at (i*sep, j*sep, -sep/6).
    """
    depth = -sep/6
    for i in range(1, N, 2):
        for j in range(0, N, 2):
            if (i + j) % 4 != 1:
                continue
            sphere(pos=vector(i * sep, j * sep, depth), radius=qubit_radius, color=qubit_color, opacity=opacity)
                
def draw_measure_2_qubits(N, sep, qubit_radius=3, qubit_color = measure_color, opacity=1):
    """Draw the "M2" measure qubits.

    Measure sites are (odd i, even j). This draws the subset with
    (i + j) % 4 == 3, as spheres at (i*sep, j*sep, -sep/6). That subset is the
    complement of draw_measure_1_qubits' (i + j) % 4 == 1 selection and matches
    the 'MF2' / 'FM2' selections in draw_edges. Without this filter, every
    measure site would be drawn, so M1 sites would get a duplicate sphere.

    Args:
        N (int): Number of sites per side.
        sep (float): Lattice spacing.
        qubit_radius (float): Sphere radius.
        qubit_color: Sphere colour.
        opacity (float): Sphere opacity (e.g. 0.15 for a faded, inactive family).
    """
    depth = -sep/6
    for i in range(N):
        for j in range(N):
            if i % 2 == 1 and j % 2 == 0 and (i + j) % 4 == 3:
                pos = vector(i*sep, j*sep, depth)
                sphere(pos=pos, radius=qubit_radius, color=qubit_color, opacity=opacity)
    
    
    
def draw_qubits(N, sep, qubits = ['M', 'F', 'D'], qubit_radius=3):
    """Draw the qubit families selected by the tags in `qubits`.

    Args:
        N (int): Number of sites per side.
        sep (float): Lattice spacing.
        qubits (List[str]): Selection tags:
            - 'F' and 'D' together draw the flag and data qubits. Either one
              alone draws neither.
            - 'M' draws both measure families (M1 and M2) opaque.
            - 'M1' without 'M2' draws M1 opaque and M2 faded (opacity 0.15).
            - Any other case draws M2 opaque and M1 faded. This includes 'M2',
              'M1' together with 'M2', and no measure tag at all.
        qubit_radius (float): Sphere radius for every qubit.
    """
    faded = 0.15

    if all(tag in qubits for tag in ('F', 'D')):
        draw_data_and_flag_qubits(N, sep, qubit_radius=qubit_radius)

    if 'M' in qubits:
        draw_measure_1_qubits(N, sep, qubit_radius=qubit_radius)
        draw_measure_2_qubits(N, sep, qubit_radius=qubit_radius)
    elif 'M1' in qubits and 'M2' not in qubits:
        draw_measure_1_qubits(N, sep, qubit_radius=qubit_radius)
        draw_measure_2_qubits(N, sep, qubit_radius=qubit_radius, opacity=faded)
    else:
        # NOTE(review): this fallback also runs when no measure tag is given, and
        # when both 'M1' and 'M2' are given — confirm that is intended.
        draw_measure_2_qubits(N, sep, qubit_radius=qubit_radius)
        draw_measure_1_qubits(N, sep, qubit_radius=qubit_radius, opacity=faded)

    
def draw_edges(N, sep, edges=['MF', 'FM', 'FD', 'DF']):
    """
    Draws edges between qubits.

    Site (i, j) of the N x N grid sits at (i*sep, j*sep, z). Flag qubits are at
    (even i, even j) with z = +sep/6. Data qubits are at (even i, odd j) and
    measure qubits at (odd i, even j), both with z = -sep/6. (Odd i, odd j)
    sites are empty. Each edge is drawn with draw_line from a site to its
    neighbour one step along +x or +y.

    Args:
        N (int): Length of code in qubits
        sep (float): qubit separation
        edges (List[str]): Edge types to draw:
            'FM'  - flag (i, j) -> measure (i+1, j), in flag_measure_color.
                    'FM1' / 'FM2' restrict this to (i + j) % 4 == 0 / == 2,
                    i.e. links to M1 / M2 measure qubits.
            'MF'  - measure (i, j) -> flag (i+1, j), in measure_flag_color.
                    'MF1' / 'MF2' restrict this to (i + j) % 4 == 1 / == 3
                    (M1 / M2 measure qubits).
            'DF'  - flag (i, j) -> data (i, j+1), in data_flag_color.
            'FD'  - data (i, j) -> flag (i, j+1), in flag_data_color.
    """
    height = sep/6
    for i in range(N):
        for j in range(N):
            if i%2 == 1 and j%2 == 1:
                # empty plaquette-centre site
                continue
            x = i*sep
            y = j*sep
            if i%2 == 0 and j%2 == 0:
                # flag qubit (raised to +sep/6)
                z = height
                if i < N-1:
                    # link along +x to the measure qubit at (i+1, j)
                    if ('FM' in edges
                            or ('FM1' in edges and (i+j)%4 == 0)
                            or ('FM2' in edges and (i+j)%4 == 2)):
                        draw_line(vector(x, y, z), vector(x+sep, y, -z), color=flag_measure_color)
                if j < N-1 and 'DF' in edges:
                    # link along +y to the data qubit at (i, j+1)
                    draw_line(vector(x, y, z), vector(x, y+sep, -z), color=data_flag_color)
            elif i%2 == 0:
                # data qubit (even i, odd j; lowered to -sep/6)
                z = -height
                if j < N-1 and 'FD' in edges:
                    # link along +y to the flag qubit at (i, j+1)
                    draw_line(vector(x, y, z), vector(x, y+sep, -z), color=flag_data_color)
            else:
                # measure qubit (odd i, even j; lowered to -sep/6)
                z = -height
                if i < N-1:
                    # link along +x to the flag qubit at (i+1, j)
                    if ('MF' in edges
                            or ('MF1' in edges and (i+j)%4 == 1)
                            or ('MF2' in edges and (i+j)%4 == 3)):
                        draw_line(vector(x, y, z), vector(x+sep, y, -z), color=measure_flag_color)
            
    
    
def generate_code(N, sep = 15, radius=2.5, qubits = ['M', 'F', 'D'], edges=['MF', 'FM', 'DF', 'FD']):
    """Render the alternating-height heavy-square code in a new vpython canvas.

    Draws the E-field direction legend and the qubit legend, then the selected
    qubits and edges. Finally it positions the camera and prints its state.

    Args:
        N (int): Number of lattice sites per side.
        sep (float): Lattice spacing.
        radius (float): Qubit sphere radius, forwarded to draw_qubits (the
            default 2.5 matches the value previously hard-coded there).
        qubits (List[str]): Qubit selection tags forwarded to draw_qubits.
        edges (List[str]): Edge selection tags forwarded to draw_edges.
    """
    E_size = 24
    scene = canvas(title="Alternating height heavy square", width=2400, height=2400, background=vpython.color.white)

    # E-field legend to the left of the lattice, at y = N*sep/2.
    draw_E_field_directions(vector(-2*E_size, N*sep/2, 0), size=E_size)

    # The legend adds a faded "inactive" measure entry unless every measure qubit is active ('M').
    draw_qubit_legend(vector(N*sep/2, -18, 0), inactive_measure='M' not in qubits)

    # Control-hardware layer, disabled: draw_control_apparatus(N, sep=sep, wire_radius=1)

    draw_qubits(N, sep, qubits=qubits, qubit_radius=radius)
    draw_edges(N, sep, edges=edges)

    scene.camera.pos = vector(-280,-300,300)
    scene.camera.axis = vector(400,200,-200)
    scene.camera.up = vector(0.5,0.5,4)
    # Alternative top-down view:
    #     scene.camera.pos = vector(200,100,300)
    #     scene.camera.axis = vector(0,0,-200)

    print_camera_info(scene)
    
    

In [4]:
generate_code(
    9,
    qubits=['M1', 'F', 'D'],
    edges=['MF1', 'FM1', 'DF', 'FD'],
)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

pos = <-280, -300, 300>
axis = <0.816497, 0.408248, -0.408248>
up = <0.5, 0.5, 4>
center = <120, -100, 100>


In [None]:
text_height = 5

def draw_stabilizer_couplings(N=7, sep=15):
    """Render an N x N lattice highlighting only the 'MF1' couplings.

    Draws flag/data qubits, M1 measure qubits opaque and M2 measure qubits
    faded (opacity 0.15), and only the 'MF1' edges. It also draws an arrow
    along -x in measure_flag_color labelled "E3". The label size comes from the
    module-level `text_height`.
    """
    view = canvas(title="Stabilizer couplings", width=2400, height=2400, background=vpython.color.white)

    # Qubits: M2 sites are faded so the M1 couplings stand out.
    draw_data_and_flag_qubits(N, sep)
    draw_measure_1_qubits(N, sep)
    draw_measure_2_qubits(N, sep, qubit_radius=3, opacity=0.15)

    draw_edges(N, sep, ['MF1'])

    # Field-direction marker: arrow from x=60 towards x=20 at y=-15, label to its left.
    cylinder_arrow(vector(60, -15, 0), vector(-40, 0, 0), color=measure_flag_color, radius=0.9)
    text(text="E3", pos=vector(10, -15 - text_height/2, 0), height=text_height, up=vector(0, 1, 0), color=text_color)

    view.camera.pos = vector(620, 150, 1600)
    view.camera.axis = vector(0, 0, -200)
    print_camera_info(view)
    
draw_stabilizer_couplings()

# Show the coupling colours so the rendered edges can be matched to their labels.
print(
    f"data_flag_color = {data_flag_color}",
    f"flag_data_color = {flag_data_color}",
    f"flag_measure_color = {flag_measure_color}",
    f"measure_flag_color = {measure_flag_color}",
    sep="\n",
)
