Skip to content

Commit

Permalink
SVG circuit drawing updates (#2414)
Browse files Browse the repository at this point in the history
 - Vertical spacing
 - Color qubit lines blue (note: this requires a hack / assumption)
 - Add xmlns so images show up when notebooks are rendered with nbsphinx
   (it is a miracle I figured out this was the problem...)
  • Loading branch information
mpharrigan authored and CirqBot committed Nov 1, 2019
1 parent 1df0979 commit 93a9312
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 11 deletions.
125 changes: 114 additions & 11 deletions cirq/contrib/svg/svg.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import TYPE_CHECKING, List, Tuple, cast
from typing import TYPE_CHECKING, List, Tuple, cast, Dict

import matplotlib.textpath

if TYPE_CHECKING:
import cirq

QBLUE = '#1967d2'


def _get_text_width(t: str) -> float:
tp = matplotlib.textpath.TextPath((0, 0), t, size=14, prop='Arial')
Expand Down Expand Up @@ -60,45 +62,135 @@ def _fit_horizontal(tdd: 'cirq.TextDiagramDrawer',
return col_starts, col_widths


def _fit_vertical(tdd: 'cirq.TextDiagramDrawer',
ref_boxheight: float, row_padding: float) \
-> Tuple[List[float], List[float], Dict[float, int]]:
"""Return data structures used to turn tdd vertical coordinates into
well-spaced SVG coordinates.
The eagle eyed coder may notice that this function is very
similar to _fit_horizonal. That function was written first
because horizontal spacing is very important for being able
to see all the gates but vertical spacing is just for aesthetics.
It wasn't until this function was written that I (mpharrigan)
noticed that -- unlike the x-coordinates (which are all integers) --
y-coordinates come in half-integers. Please use yi_map to convert
TextDiagramDrawer y-values to y indices which can be used to index
into row_starts and row_heights.
See gh-2313 to track this (and other) hacks that could be improved.
Returns:
row_starts: A list that maps y indices to the starting y position
(in SVG px)
row_heights: A list that maps y indices to the height of each row
(in SVG px). Y-index `yi` goes from row_starts[yi] to
row_starts[yi] + row_heights[yi]
yi_map:
A mapping from half-integer TextDiagramDrawer coordinates
to integer y indices. Apply this mapping before indexing into
the former two return values (ie row_starts and row_heights)
"""
# Note: y values come as half integers. Map to integers
all_yis = sorted({yi for _, yi in tdd.entries.keys()} |
{yi1 for _, yi1, _, _ in tdd.vertical_lines} |
{yi2 for _, _, yi2, _ in tdd.vertical_lines} |
{yi for yi, _, _, _ in tdd.horizontal_lines})
yi_map = {yi: i for i, yi in enumerate(all_yis)}

max_yi = max(yi_map[yi] for yi in all_yis)
row_heights = [0.0] * (max_yi + 2)
for (_, yi), _ in tdd.entries.items():
yi = yi_map[yi]
row_heights[yi] = max(ref_boxheight, row_heights[yi])

for yi_float in all_yis:
row_heights[yi_map[yi_float]] += row_padding

row_starts = [0.0]
for i in range(1, max_yi + 3):
row_starts.append(row_starts[i - 1] + row_heights[i - 1])

return row_starts, row_heights, yi_map


def _debug_spacing(col_starts, row_starts):
"""Return a string suitable for inserting inside an <svg> tag that
draws green lines where columns and rows start. This is very useful
if you're developing this code and are debugging spacing issues.
"""
# coverage: ignore
t = ''
for i, cs in enumerate(col_starts):
t += f'<line id="cs-{i}" ' \
f'x1="{cs}" x2="{cs}" y1="0" y2="{row_starts[-1]}" ' \
f'stroke="green" stroke-width="1" />'
for i, rs in enumerate(row_starts):
t += f'<line id="rs-{i}" ' \
f'x1="0" x2="{col_starts[-1]}" y1="{rs}" y2="{rs}" ' \
f'stroke="green" stroke-width="1" />'
return t


def tdd_to_svg(
tdd: 'cirq.TextDiagramDrawer',
ref_rowheight: float = 60,
ref_boxwidth: float = 40,
ref_boxheight: float = 40,
col_padding: float = 20,
y_top_pad: float = 5,
row_padding: float = 10,
) -> str:
height = tdd.height() * ref_rowheight
row_starts, row_heights, yi_map = _fit_vertical(tdd=tdd,
ref_boxheight=ref_boxheight,
row_padding=row_padding)
col_starts, col_widths = _fit_horizontal(tdd=tdd,
ref_boxwidth=ref_boxwidth,
col_padding=col_padding)

t = f'<svg width="{col_starts[-1]}" height="{height}">'
t = f'<svg xmlns="http://www.w3.org/2000/svg" ' \
f'width="{col_starts[-1]}" height="{row_starts[-1]}">'

# Developers: uncomment below to draw green lines to debug
# col_starts and row_starts
# t += _debug_spacing(col_starts, row_starts)

for yi, xi1, xi2, _ in tdd.horizontal_lines:
xi1 = cast(int, xi1)
xi2 = cast(int, xi2)
y = yi * ref_rowheight + y_top_pad + ref_boxheight / 2
x1 = col_starts[xi1] + col_widths[xi1] / 2
x2 = col_starts[xi2] + col_widths[xi2] / 2

yi = yi_map[yi]
y = row_starts[yi] + row_heights[yi] / 2

if xi1 == 0:
# qubits start at far left and their wires shall be blue
stroke = QBLUE
else:
# coverage: ignore
stroke = 'black'
t += f'<line x1="{x1}" x2="{x2}" y1="{y}" y2="{y}" ' \
f'stroke="black" stroke-width="1" />'
f'stroke="{stroke}" stroke-width="1" />'

for xi, yi1, yi2, _ in tdd.vertical_lines:
y1 = yi1 * ref_rowheight + y_top_pad + ref_boxheight / 2
y2 = yi2 * ref_rowheight + y_top_pad + ref_boxheight / 2
yi1 = yi_map[yi1]
yi2 = yi_map[yi2]
y1 = row_starts[yi1] + row_heights[yi1] / 2
y2 = row_starts[yi2] + row_heights[yi2] / 2

xi = cast(int, xi)
x = col_starts[xi] + col_widths[xi] / 2
t += f'<line x1="{x}" x2="{x}" y1="{y1}" y2="{y2}" ' \
f'stroke="black" stroke-width="3" />'

for (xi, yi), v in tdd.entries.items():
xi = cast(int, xi)
yi = yi_map[yi]

x = col_starts[xi] + col_widths[xi] / 2
y = yi * ref_rowheight + y_top_pad + ref_boxheight / 2
y = row_starts[yi] + row_heights[yi] / 2

boxheight = ref_boxheight
boxwidth = max(ref_boxwidth, _get_text_width(v.text))
boxwidth = col_widths[xi] - tdd.horizontal_padding.get(xi, col_padding)
boxx = x - boxwidth / 2
boxy = y - boxheight / 2

Expand All @@ -122,6 +214,15 @@ def tdd_to_svg(
return t


def _validate_circuit(circuit: 'cirq.Circuit'):
if len(circuit) == 0:
raise ValueError("Can't draw SVG diagram for empty circuits")

if any(len(mom) == 0 for mom in circuit.moments):
raise ValueError("Can't draw SVG diagram for circuits with empty "
"moments. Run it through cirq.DropEmptyMoments()")


class SVGCircuit:
"""A wrapper around cirq.Circuit to enable rich display in a Jupyter
notebook.
Expand All @@ -137,11 +238,13 @@ def __init__(self, circuit: 'cirq.Circuit'):

def _repr_svg_(self) -> str:
# coverage: ignore
_validate_circuit(self.circuit)
tdd = self.circuit.to_text_diagram_drawer(transpose=False)
return tdd_to_svg(tdd)


def circuit_to_svg(circuit: 'cirq.Circuit') -> str:
"""Render a circuit as SVG."""
_validate_circuit(circuit)
tdd = circuit.to_text_diagram_drawer(transpose=False)
return tdd_to_svg(tdd)
13 changes: 13 additions & 0 deletions cirq/contrib/svg/svg_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

import cirq
from cirq.contrib.svg import circuit_to_svg

Expand All @@ -12,3 +14,14 @@ def test_svg():
cirq.Z(a), cirq.measure(a, b, c, key='z')))
assert '<svg' in svg_text
assert '</svg>' in svg_text


def test_validation():
with pytest.raises(ValueError):
circuit_to_svg(cirq.Circuit())

q0 = cirq.LineQubit(0)
with pytest.raises(ValueError):
circuit_to_svg(
cirq.Circuit([cirq.Moment([cirq.X(q0)]),
cirq.Moment([])]))

0 comments on commit 93a9312

Please sign in to comment.