# %% In [10]:
import plotly.graph_objects as go

def add_circle(fig, radius, centre_x, centre_y, fill="white", line="black", line_width=1, hover_text=None):
    """
    Draw a filled circle on a plotly figure, in data coordinates.

    Parameters:
    fig (go.Figure): The figure to draw on.
    radius (float): Radius of the circle in data units.
    centre_x (float): x coordinate of the centre.
    centre_y (float): y coordinate of the centre.
    fill (str): Fill colour of the circle.
    line (str): Outline colour.
    line_width (float): Outline width.
    hover_text (str | None): When truthy, an invisible marker is placed at the
        centre so that this text is shown on hover; the shape itself is not
        hoverable. The marker trace is kept out of the legend.
    """
    # Bounding box of the circle in data coordinates.
    x0, y0 = centre_x - radius, centre_y - radius
    x1, y1 = centre_x + radius, centre_y + radius

    fig.add_shape(
        type="circle",
        xref="x", yref="y",
        x0=x0, y0=y0, x1=x1, y1=y1,
        fillcolor=fill,
        line=dict(color=line, width=line_width),
        opacity=1.0,  # Fully opaque fill
    )

    if hover_text:
        fig.add_trace(go.Scatter(
            x=[centre_x],
            y=[centre_y],
            mode='markers',
            marker=dict(size=0, opacity=0),
            text=[hover_text],
            hoverinfo='text',
            # Helper trace only: without this every hover marker would appear
            # in the legend as an unnamed "trace N" entry.
            showlegend=False,
        ))

def add_square(fig, radius, centre_x, centre_y, fill="white", line="black", line_width=1, hover_text=None):
    """
    Draw a filled, axis-aligned square on a plotly figure, in data coordinates.

    Parameters:
    fig (go.Figure): The figure to draw on.
    radius (float): Half the side length of the square in data units
        (named to match add_circle).
    centre_x (float): x coordinate of the centre.
    centre_y (float): y coordinate of the centre.
    fill (str): Fill colour of the square.
    line (str): Outline colour.
    line_width (float): Outline width.
    hover_text (str | None): When truthy, an invisible marker is placed at the
        centre so that this text is shown on hover; the shape itself is not
        hoverable. The marker trace is kept out of the legend.
    """
    # Corners of the square in data coordinates.
    x0, y0 = centre_x - radius, centre_y - radius
    x1, y1 = centre_x + radius, centre_y + radius

    fig.add_shape(
        type="rect",
        xref="x", yref="y",
        x0=x0, y0=y0, x1=x1, y1=y1,
        fillcolor=fill,
        line=dict(color=line, width=line_width),
        opacity=1.0,  # Fully opaque fill
    )

    if hover_text:
        fig.add_trace(go.Scatter(
            x=[centre_x],
            y=[centre_y],
            mode='markers',
            marker=dict(size=0, opacity=0),
            text=[hover_text],
            hoverinfo='text',
            # Helper trace only: without this every hover marker would appear
            # in the legend as an unnamed "trace N" entry.
            showlegend=False,
        ))

def draw_line(fig, coord1, coord2, color='blue', thickness=2):
    """
    Draw a line on a plotly figure between two coordinates.

    The line is a two-point Scatter trace that is excluded from the legend, so
    drawing many connectors does not fill the legend with "trace N" entries.

    Parameters:
    fig (go.Figure): The figure to draw on.
    coord1 (tuple): The first coordinate as a tuple (x1, y1).
    coord2 (tuple): The second coordinate as a tuple (x2, y2).
    color (str): The color of the line.
    thickness (int): The thickness of the line.
    """
    fig.add_trace(
        go.Scatter(
            x=[coord1[0], coord2[0]],
            y=[coord1[1], coord2[1]],
            mode='lines',
            line=dict(color=color, width=thickness),
            showlegend=False,
        )
    )

# Configuration for the figure (hiding axes, setting background, etc.)
def configure_figure(fig):
    """
    Give a plotly figure a plain white canvas with hidden axes and equal scaling.

    Parameters:
    fig (go.Figure): The figure to configure (modified in place via update_layout).
    """
    # Grid, zero line, axis line and tick labels are all switched off.
    hidden_axis = dict(
        showgrid=False,
        zeroline=False,
        showline=False,
        showticklabels=False,
    )
    fig.update_layout(
        paper_bgcolor='white',  # Background outside the plot area
        plot_bgcolor='white',   # Background of the plot area
        xaxis=dict(hidden_axis),
        # Tie the y scale to x so one data unit has the same length on both
        # axes (circles stay round). A single anchor is enough: anchoring x
        # back to y as well forms a loop that plotly ignores with a warning.
        yaxis=dict(hidden_axis, scaleanchor='x', scaleratio=1),
    )

# # Example usage
# fig = go.Figure()
# add_circle(fig, radius=0.2, centre_x=0, centre_y=0, fill="white", line="black", line_width=3, hover_text="This is a circle")
# add_square(fig, radius=0.2, centre_x=0, centre_y=1, fill="white", line="black", line_width=3, hover_text="This is a square")

# configure_figure(fig)
# fig.update_xaxes(range=[-2, 2])
# fig.update_yaxes(range=[-2, 2])
# # fig.show()


import qec.codes

# Shared figure: draw_surface_code() fills it in further below, after which it
# is configured and written out to surface_code.html.
fig = go.Figure()

class plotly_shape:
    """Plain record describing one marker to draw on a plotly figure.

    All fields default to ``None`` so a record can be filled in one attribute
    at a time, or the values can be passed straight to the constructor.
    """

    def __init__(self, shape_type=None, x=None, y=None, label=None, id=None):  # noqa: A002
        self.shape_type = shape_type  # kind of marker, e.g. "circle" or "square"
        self.x = x                    # centre x in data coordinates
        self.y = y                    # centre y in data coordinates
        self.label = label            # text shown on hover
        self.id = id                  # identifier assigned by the caller

    def __repr__(self):
        return (
            f"{type(self).__name__}(shape_type={self.shape_type!r}, x={self.x!r}, "
            f"y={self.y!r}, label={self.label!r}, id={self.id!r})"
        )

def _surface_code_shape(shape_type, x, y, label, shape_id):
    """Return a ``plotly_shape`` with every field populated."""
    ps = plotly_shape()
    ps.id = shape_id
    ps.shape_type = shape_type
    ps.x = x
    ps.y = y
    ps.label = label
    return ps


def _surface_code_qubits(dx, dz, spacing):
    """Return the data-qubit markers, numbered from 1.

    First comes a dz-row by dx-column grid. Then comes a (dz-1) by (dx-1) grid
    offset by half a spacing in both directions. Row 0 is the top row.
    """
    top = dz * spacing
    half = 0.5 * spacing
    positions = [(j * spacing, top - i * spacing)
                 for i in range(dz) for j in range(dx)]
    positions += [(j * spacing + half, top - i * spacing - half)
                  for i in range(dz - 1) for j in range(dx - 1)]
    return [_surface_code_shape("circle", x, y, f"Qubit ({n})", n)
            for n, (x, y) in enumerate(positions, start=1)]


def _surface_code_checks(dx, dz, spacing):
    """Return ``(z_checks, x_checks)`` stabiliser markers, numbered from 1 across both.

    Z checks sit half a spacing to the right of each grid qubit that has a
    right-hand neighbour. X checks sit half a spacing below each grid qubit
    that has a neighbour beneath it.
    """
    top = dz * spacing
    half = 0.5 * spacing
    z_positions = [(j * spacing + half, top - i * spacing)
                   for i in range(dz) for j in range(dx - 1)]
    x_positions = [(j * spacing, top - i * spacing - half)
                   for i in range(dz - 1) for j in range(dx)]
    z_checks = [_surface_code_shape("square", x, y, f"Z Stabiliser ({n})", n)
                for n, (x, y) in enumerate(z_positions, start=1)]
    x_checks = [_surface_code_shape("square", x, y, f"X Stabiliser ({n})", n)
                for n, (x, y) in enumerate(x_positions, start=len(z_positions) + 1)]
    return z_checks, x_checks


def draw_surface_code(fig, dx: int, dz: int, spacing: float = 1, radius: float = 1):
    """Draw ``qec.codes.SurfaceCode(dx, dz)`` onto a plotly figure.

    Data qubits are white circles, Z stabilisers are white squares and X
    stabilisers are blue squares. Each shape carries hover text. Black lines
    join every Z stabiliser to the qubits in its row of ``hz``.

    Parameters:
    fig (go.Figure): The figure to draw on (modified in place).
    dx (int): Passed to qec.codes.SurfaceCode; the number of grid columns.
    dz (int): Passed to qec.codes.SurfaceCode; the number of grid rows.
    spacing (float): Distance between neighbouring grid qubits, in data units.
    radius (float): Circle radius / square half-side, in data units.
    """
    sc = qec.codes.SurfaceCode(dx, dz)
    hz = sc.hz.toarray()
    # TODO(review): sc.hx is not drawn, so the X stabilisers have no edges.
    # Adding them needs hx's row order confirmed against the X-check layout.

    qubits = _surface_code_qubits(dx, dz, spacing)
    z_checks, x_checks = _surface_code_checks(dx, dz, spacing)
    checks = z_checks + x_checks

    # Z edges: hz[r, c] == 1 links check r to qubit c. This assumes hz columns
    # follow the qubit order above and hz rows follow the Z-check order.
    # NOTE(review): confirm against qec.codes.SurfaceCode.
    n_rows, n_cols = hz.shape
    for col_idx in range(n_cols):
        qubit = qubits[col_idx]
        for row_idx in range(n_rows):
            if hz[row_idx, col_idx] == 1:
                check = checks[row_idx]
                draw_line(fig, (qubit.x, qubit.y), (check.x, check.y), color='black', thickness=1)

    for qubit in qubits:
        add_circle(fig, radius=radius, centre_x=qubit.x, centre_y=qubit.y,
                   fill="white", line="black", line_width=2, hover_text=qubit.label)

    # Colour by stabiliser kind, so the fill always agrees with the hover label.
    for fill, group in (("white", z_checks), ("blue", x_checks)):
        for check in group:
            add_square(fig, radius=radius, centre_x=check.x, centre_y=check.y,
                       fill=fill, line="black", line_width=2, hover_text=check.label)



# Draw a 5x5 surface code, hide the axes and export a standalone HTML page.
draw_surface_code(fig, dx=5, dz=5, spacing=1, radius=0.1)
configure_figure(fig)
fig.write_html("surface_code.html")

# %% In [46]:
import svgwrite
import qec.codes

class Shape:
    """Kind, position and label of one element in the SVG surface-code drawing."""

    def __init__(self, shape_type, x, y, label):
        # Plain mutable attributes; callers may shift x and y in place afterwards.
        self.shape_type, self.x, self.y, self.label = shape_type, x, y, label

def draw_surface_code_svg(dx: int, dz: int, spacing: float = 50, radius: float = 5,
                          filename: str = 'surface_code.svg'):
    """Render ``qec.codes.SurfaceCode(dx, dz)`` to an SVG file.

    Data qubits are white circles, Z stabilisers are white squares and X
    stabilisers are blue squares. Black lines join every Z stabiliser to the
    qubits in its row of ``hz``.

    Parameters:
    dx (int): Passed to qec.codes.SurfaceCode; the number of grid columns.
    dz (int): Passed to qec.codes.SurfaceCode; the number of grid rows.
    spacing (float): Distance between neighbouring grid qubits, in SVG user units.
    radius (float): Circle radius / square half-side, in SVG user units.
    filename (str): Output path; defaults to the previously hard-coded name.
    """
    stroke_width = 2  # outline width of circles and squares

    sc = qec.codes.SurfaceCode(dx, dz)
    hz = sc.hz.toarray()
    # TODO(review): sc.hx is not drawn, so the X stabilisers have no edges.
    # Adding them needs hx's row order confirmed against x_checks below.

    top = dz * spacing
    half = 0.5 * spacing

    draw_qubits = []
    for i in range(dz):
        for j in range(dx):
            draw_qubits.append(Shape('circle', j * spacing, top - i * spacing,
                                     f"Qubit ({len(draw_qubits) + 1})"))
    for i in range(dz - 1):
        for j in range(dx - 1):
            draw_qubits.append(Shape('circle', j * spacing + half, top - i * spacing - half,
                                     f"Qubit ({len(draw_qubits) + 1})"))

    # Stabilisers are numbered consecutively: Z checks first, then X checks.
    z_checks = []
    for i in range(dz):
        for j in range(dx - 1):
            z_checks.append(Shape('rect', j * spacing + half, top - i * spacing,
                                  f"Z Stabiliser ({len(z_checks) + 1})"))
    x_checks = []
    for i in range(dz - 1):
        for j in range(dx):
            x_checks.append(Shape('rect', j * spacing, top - i * spacing - half,
                                  f"X Stabiliser ({len(z_checks) + len(x_checks) + 1})"))
    draw_checks = z_checks + x_checks

    # Shift everything into positive coordinates. The margin covers the shape
    # radius plus the half of each outline stroke that lies outside the shape;
    # with only the radius, the outermost outlines were clipped by the canvas edge.
    pad = radius + stroke_width / 2
    shapes = draw_qubits + draw_checks
    min_x = min(s.x for s in shapes) - pad
    min_y = min(s.y for s in shapes) - pad
    for s in shapes:
        s.x -= min_x
        s.y -= min_y
    width = max(s.x for s in shapes) + pad
    height = max(s.y for s in shapes) + pad

    dwg = svgwrite.Drawing(filename, profile='tiny', size=(width, height))

    # Edges first, so the qubit and stabiliser shapes added later paint over them.
    # Assumes hz columns follow draw_qubits and hz rows follow z_checks.
    # NOTE(review): confirm against qec.codes.SurfaceCode.
    for col_idx in range(hz.shape[1]):
        qubit = draw_qubits[col_idx]
        for row_idx in range(hz.shape[0]):
            if hz[row_idx, col_idx] == 1:
                check = draw_checks[row_idx]
                dwg.add(dwg.line(start=(qubit.x, qubit.y), end=(check.x, check.y),
                                 stroke='black', stroke_width=1))

    for qubit in draw_qubits:
        dwg.add(dwg.circle(center=(qubit.x, qubit.y), r=radius, fill='white',
                           stroke='black', stroke_width=stroke_width))

    # Colour by stabiliser kind rather than by inspecting the label text.
    for fill_color, group in (('white', z_checks), ('blue', x_checks)):
        for check in group:
            dwg.add(dwg.rect(insert=(check.x - radius, check.y - radius),
                             size=(2 * radius, 2 * radius), fill=fill_color,
                             stroke='black', stroke_width=stroke_width))

    dwg.save()

# Write a 5x5 surface code to surface_code.svg (spacing 100, radius 5).
draw_surface_code_svg(dx=5, dz=5, spacing=100, radius=5)
