# SSD Visualisations for Recent Benchmark Circuits

This notebook examines a selection of recently added benchmark circuits and uses the new ``Circuit.to_networkx_ssd`` helper to expose each circuit's subsystem descriptor (SSD).
The resulting graphs highlight how partitions, conversion layers and backend assignments interact across the workload.


## Circuits covered

* **W-state preparation:** Exercises the linear-depth construction that spreads a single excitation across all qubits.
* **Grover search:** Uses the ancilla-free multi-controlled-X decomposition recently added to the benchmarks module.
* **QFT on a GHZ state:** Combines entangling preparation with a global Fourier transform to stress conversion planning.

Each circuit is rebuilt with classical simplification disabled, so that the planner operates on the full quantum workload before the SSD is rendered.


In [None]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
from IPython.display import display

# Resolve '..' first and then step up once more, so the import root sits two
# levels above the working directory (presumably the repository root when the
# notebook runs from its own folder — TODO confirm).
PROJECT_ROOT = Path('..').resolve().parent
# Register the root only when no identical entry exists yet, so re-running the
# cell does not keep adding duplicates to the import path.
if all(entry != str(PROJECT_ROOT) for entry in sys.path):
    sys.path.append(str(PROJECT_ROOT))

from benchmarks import circuits as benchmark_circuits
from quasar.circuit import Circuit

# Shared Matplotlib defaults for every figure rendered in this notebook.
plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['axes.titlesize'] = 14
plt.rcParams['axes.labelsize'] = 12

# Show full cell contents in tables and format floats with three decimals.
pd.set_option('display.max_colwidth', None)
pd.set_option('display.float_format', '{:.3f}'.format)


In [None]:
# Each entry is (display name, benchmark factory, keyword arguments passed to
# the factory when the circuit is instantiated).
RECENT_CIRCUIT_SPECS = [
    ('W-state preparation (8 qubits)', benchmark_circuits.w_state_circuit, dict(n_qubits=8)),
    (
        'Grover search (5 qubits, 2 iterations)',
        benchmark_circuits.grover_circuit,
        dict(n_qubits=5, n_iterations=2),
    ),
    ('QFT on GHZ (6 qubits)', benchmark_circuits.qft_on_ghz_circuit, dict(n_qubits=6)),
]

def rebuild_without_simplification(circuit: Circuit) -> Circuit:
    """Construct a new ``Circuit`` from *circuit*'s gates with classical simplification off.

    Every gate is serialised with ``to_dict`` and the resulting specs are
    passed to the ``Circuit`` constructor with
    ``use_classical_simplification=False``.
    """

    specs = [g.to_dict() for g in circuit.gates]
    rebuilt = Circuit(specs, use_classical_simplification=False)
    return rebuilt

def instantiate_recent_benchmarks():
    """Lazily generate ``(name, circuit)`` tuples for ``RECENT_CIRCUIT_SPECS``.

    Each factory is called with its keyword arguments, and the resulting
    circuit is rebuilt with classical simplification disabled before it is
    yielded.
    """

    for spec in RECENT_CIRCUIT_SPECS:
        label, build, options = spec
        yield label, rebuild_without_simplification(build(**options))


In [None]:
def partition_table(graph: nx.MultiDiGraph) -> pd.DataFrame:
    """Return a summary of partition metadata recorded in the graph.

    Only nodes whose ``kind`` attribute equals ``'partition'`` are included.
    Qubit collections are rendered as comma-separated strings, and missing
    attributes appear as ``None``.

    Parameters
    ----------
    graph:
        SSD graph, e.g. as produced by ``Circuit.to_networkx_ssd``.

    Returns
    -------
    pd.DataFrame
        One row per partition, sorted by partition index. When the graph has
        no partition nodes, an empty frame with the same columns is returned,
        matching the behaviour of ``conversion_table``.
    """

    columns = [
        'Partition', 'Backend', 'Multiplicity', 'Qubits', 'Boundary',
        'Rank', 'Frontier', 'Time cost', 'Memory cost',
    ]
    rows = []
    for _, data in graph.nodes(data=True):
        if data.get('kind') != 'partition':
            continue
        rows.append({
            'Partition': data.get('index'),
            'Backend': data.get('backend'),
            'Multiplicity': data.get('multiplicity'),
            'Qubits': ', '.join(str(q) for q in data.get('qubits', ())),
            'Boundary': ', '.join(str(q) for q in data.get('boundary_qubits', ())),
            'Rank': data.get('rank'),
            'Frontier': data.get('frontier'),
            'Time cost': data.get('cost_time'),
            'Memory cost': data.get('cost_memory'),
        })
    if not rows:
        # Keep the schema even when empty, so callers can rely on the columns.
        return pd.DataFrame(columns=columns)
    return pd.DataFrame(rows, columns=columns).sort_values('Partition').reset_index(drop=True)

def conversion_table(graph: nx.MultiDiGraph) -> pd.DataFrame:
    """Tabulate the conversion-layer nodes of an SSD graph.

    A node is treated as a conversion layer when its ``kind`` attribute is
    ``'conversion'``. The frame is sorted by conversion index. If no such
    node exists, an empty frame that still carries the expected columns is
    returned.
    """

    columns = [
        'Conversion', 'Boundary', 'Source', 'Target',
        'Rank', 'Frontier', 'Primitive', 'Time cost', 'Memory cost',
    ]

    def _row(attrs):
        # Flatten one node's attribute dict into a display row.
        boundary = ', '.join(str(qubit) for qubit in attrs.get('boundary', ()))
        return {
            'Conversion': attrs.get('index'),
            'Boundary': boundary,
            'Source': attrs.get('source'),
            'Target': attrs.get('target'),
            'Rank': attrs.get('rank'),
            'Frontier': attrs.get('frontier'),
            'Primitive': attrs.get('primitive'),
            'Time cost': attrs.get('cost_time'),
            'Memory cost': attrs.get('cost_memory'),
        }

    records = [
        _row(attrs)
        for _, attrs in graph.nodes(data=True)
        if attrs.get('kind') == 'conversion'
    ]
    if records:
        table = pd.DataFrame(records, columns=columns)
        return table.sort_values('Conversion').reset_index(drop=True)
    return pd.DataFrame(columns=columns)

def draw_ssd_graph(
    graph: nx.MultiDiGraph,
    *,
    title: str,
    seed: int = 2024,
    ax: plt.Axes | None = None,
) -> plt.Axes:
    """Visualise partitions, conversions and backends for the SSD graph.

    Nodes are styled by their ``kind`` attribute (``'partition'``,
    ``'conversion'`` or ``'backend'``). Edges are styled by their own
    ``kind`` attribute. Nodes and edges whose kind is not listed in the style
    tables below are not drawn.

    Parameters
    ----------
    graph:
        SSD graph, e.g. as produced by ``Circuit.to_networkx_ssd``.
    title:
        Title shown above the plot.
    seed:
        Seed for ``nx.spring_layout``, so repeated calls give the same layout.
    ax:
        Axes to draw into. A new 10x6 figure is created when omitted.

    Returns
    -------
    plt.Axes
        The axes the graph was drawn on.
    """

    if ax is None:
        _, ax = plt.subplots(figsize=(10, 6))
    pos = nx.spring_layout(graph, seed=seed)

    node_styles = {
        'partition': {'color': '#4e79a7', 'shape': 's', 'size': 1400, 'label': 'Partition'},
        'conversion': {'color': '#e15759', 'shape': 'D', 'size': 1100, 'label': 'Conversion'},
        'backend': {'color': '#59a14f', 'shape': 'o', 'size': 1000, 'label': 'Backend'},
    }

    for kind, style in node_styles.items():
        nodes = [node for node, data in graph.nodes(data=True) if data.get('kind') == kind]
        if not nodes:
            continue
        nx.draw_networkx_nodes(
            graph,
            pos,
            nodelist=nodes,
            node_color=style['color'],
            node_shape=style['shape'],
            node_size=style['size'],
            label=style['label'],
            ax=ax,
        )

    # draw_networkx_edges uses node_size to shrink each arrow to the node
    # boundary. With its default of 300, the arrowheads of these much larger
    # markers end up hidden underneath them, so pass the real per-node sizes
    # (300 for unstyled kinds, matching networkx's default).
    nodelist = []
    node_sizes = []
    for node, data in graph.nodes(data=True):
        nodelist.append(node)
        node_sizes.append(node_styles.get(data.get('kind'), {}).get('size', 300))

    edge_styles = {
        'dependency': {'style': 'solid', 'color': '#1f77b4', 'arrows': True},
        'entanglement': {'style': 'dashed', 'color': '#ff7f0e', 'arrows': False},
        'conversion_boundary': {'style': 'dotted', 'color': '#76b7b2', 'arrows': True, 'connectionstyle': 'arc3,rad=0.1'},
        'backend_assignment': {'style': 'solid', 'color': '#b07aa1', 'arrows': True, 'connectionstyle': 'arc3,rad=0.15'},
        'conversion_source': {'style': 'solid', 'color': '#9c755f', 'arrows': True, 'connectionstyle': 'arc3,rad=0.2'},
        'conversion_target': {'style': 'solid', 'color': '#f28e2b', 'arrows': True, 'connectionstyle': 'arc3,rad=-0.2'},
    }

    for kind, style in edge_styles.items():
        edges = [
            (u, v)
            for u, v, data in graph.edges(data=True)
            if data.get('kind') == kind
        ]
        if not edges:
            continue
        nx.draw_networkx_edges(
            graph,
            pos,
            edgelist=edges,
            ax=ax,
            edge_color=style['color'],
            style=style['style'],
            arrows=style.get('arrows', True),
            # 'arc3' is networkx's own default (a straight line). Any other
            # value triggers a UserWarning when arrows=False, because the
            # LineCollection code path cannot honour it.
            connectionstyle=style.get('connectionstyle', 'arc3'),
            nodelist=nodelist,
            node_size=node_sizes,
        )

    node_labels = {}
    for node, data in graph.nodes(data=True):
        kind = data.get('kind')
        if kind == 'partition':
            node_labels[node] = f"P{data.get('index')}\n{data.get('backend')}"
        elif kind == 'backend':
            node_labels[node] = data.get('label', data.get('backend', ''))
        elif kind == 'conversion':
            node_labels[node] = f"C{data.get('index')}\n{data.get('source')}→{data.get('target')}"
    nx.draw_networkx_labels(graph, pos, labels=node_labels, font_size=10, ax=ax)

    ax.set_axis_off()
    ax.set_title(title)

    handles, legend_labels = ax.get_legend_handles_labels()
    if handles:
        # Deduplicate by label text, keeping the last handle seen per label.
        unique = dict(zip(legend_labels, handles))
        ax.legend(
            list(unique.values()),
            list(unique.keys()),
            loc='upper center',
            bbox_to_anchor=(0.5, -0.05),
            ncol=len(unique),
        )

    return ax


## Visualising the SSDs

The loop below instantiates each benchmark, extracts its SSD graph and renders the structure alongside tabular summaries for partitions and conversion layers.


In [None]:
# Render each selected benchmark's SSD as a graph, followed by partition and
# conversion summary tables.
for name, circuit in instantiate_recent_benchmarks():
    graph = circuit.to_networkx_ssd()
    # Graph-level metadata, printed as-is (``None`` if a key is absent).
    total_qubits = graph.graph.get('total_qubits')
    num_partitions = graph.graph.get('num_partitions')
    print(f'{name}: {num_partitions} partitions covering {total_qubits} qubits')

    _, ax = plt.subplots(figsize=(10, 6))
    draw_ssd_graph(graph, title=name, ax=ax, seed=42)
    plt.show()

    partitions = partition_table(graph)
    if not partitions.empty:
        display(partitions)

    conversions = conversion_table(graph)
    if not conversions.empty:
        display(conversions)

    # Blank lines between benchmarks. The newline must be escaped: a raw line
    # break inside a string literal is a SyntaxError.
    print('\n')
