In [1]:
import numpy as np
import matplotlib.pyplot as plt
from mrmustard.lab import Circuit, SqueezedVacuum, Number
from mrmustard.lab.transformations import BSgate
from mrmustard.physics.wigner import wigner_discretized

In [2]:
def draw_params(N : int, r_limits=(0,0.8), theta_limits=(0,np.pi/2),
                     phi_limits=(0,2*np.pi)):
    """
    Draw parameters for the quantum optical setup.
    Parameters
    ----------
    N : int, number of samples to draw
    r_limits : tuple, limits for the squeezing parameter r (inclusive)
    theta_limits : tuple, limits for the angle theta (inclusive)
    phi_limits : tuple, limits for the angle phi (exclusive)
    Returns
    -------
    r : np.ndarray, array of shape (N, 9) containing the drawn parameters
    """
    
    rng = np.random.default_rng()
    r = rng.uniform(r_limits[0], r_limits[1], size=(N, 3))
    theta = rng.uniform(theta_limits[0], theta_limits[1], size=(N, 3))
    phi = rng.uniform(phi_limits[0], phi_limits[1], size=(N, 3))
    
    return np.hstack([r, theta, phi])

In [11]:
N = 3

In [22]:
params = draw_params(N, r_limits=(0,0.8), theta_limits=(0,np.pi/2), phi_limits=(0,2*np.pi))

In [None]:
def generate_wigner_sample(params, grid_size=15, x_max=4.0, plot=False):
    """
    Generate a discretized Wigner function for a single parameter set.

    Returns
    -------
    wigner_grid : np.ndarray of shape (grid_size, grid_size)
        Discretized Wigner function evaluated on a square phase-space grid.
    """

    # --- Input squeezed states ---
    input_state = [
        SqueezedVacuum(
            i,
            params[i],
            phi=(0 if i % 2 == 1 else np.pi / 2)
        )
        for i in range(3)
    ]

    # --- Interferometer ---
    BS1 = BSgate([0, 1], params[3], params[6])
    BS2 = BSgate([1, 2], params[4], params[7])
    BS3 = BSgate([0, 1], params[5], params[8])
    interferometer = BS1 >> BS2 >> BS3

    # --- Post-selection measurements ---
    measurement = [
        Number(i, params[i]).dual for i in params[10:11]
    ]

    # --- Build and contract circuit ---
    c = Circuit(input_state) >> interferometer >> Circuit(measurement)
    out = c.contract().normalize()

    # --- Phase-space grid ---
    xvec = np.linspace(-x_max, x_max, grid_size)
    pvec = np.linspace(-x_max, x_max, grid_size)

    # --- Wigner function ---
    wigner, _, _ = wigner_discretized(
        out.dm().ansatz.array,
        xvec,
        pvec
    )

    # --- Optional plot ---
    if plot:
        plt.figure(figsize=(5, 4))
        plt.imshow(
            wigner.T,
            origin="lower",
            extent=[-x_max, x_max, -x_max, x_max],
            cmap="RdBu",
            vmin=-1 / np.pi,
            vmax=1 / np.pi
        )
        plt.colorbar(label="Wigner function")
        plt.xlabel("x")
        plt.ylabel("p")
        plt.title("Conditional Wigner function")
        plt.tight_layout()
        plt.show()

    return wigner

In [4]:
params = [0.5, 0.1, 0.2, np.pi/2, np.pi/6, np.pi/3, 0.0, np.pi/2, np.pi, 1, 0]

wigner = generate_wigner_sample(
    params,
    plot=True
)

  warn(


TypingError: Failed in nopython mode pipeline (step: nopython frontend)
[1m[1m[1m[1mNo implementation of function Function(<function diag at 0x00000138FF297240>) found for signature:
 
 >>> diag(array(complex128, 4d, C), int64)
 
There are 2 candidate implementations:
[1m - Of which 2 did not match due to:
 Overload in function 'impl_np_diag': File: numba\np\arrayobj.py: Line 4673.
   With argument(s): '(array(complex128, 4d, C), int64)':[0m
[1m  Rejected as the implementation raised a specific error:
    NumbaTypeError: [1mInput must be 1- or 2-d.[0m[0m
  raised from c:\Users\mathi\Documents\Cours\M2\DigiQ\QST Hackathon Copenhagen\qst-hack-2026\Case 6\.venv\Lib\site-packages\numba\np\arrayobj.py:4680
[0m
[0m[1mDuring: resolving callee type: Function(<function diag at 0x00000138FF297240>)[0m
[0m[1mDuring: typing of call at c:\Users\mathi\Documents\Cours\M2\DigiQ\QST Hackathon Copenhagen\qst-hack-2026\Case 6\.venv\Lib\site-packages\mrmustard\physics\wigner.py (138)[0m
[1m
File "..\Case 6\.venv\Lib\site-packages\mrmustard\physics\wigner.py", line 138:[0m
[1mdef _wigner_discretized_clenshaw(rho, q_vec, p_vec, hbar):  # pragma: no cover
    <source elided>
    for j in range(1, cutoff):
[1m        c_L = _wig_laguerre_val(L - j, B, np.diag(rho2, L - j))
[0m        [1m^[0m[0m

[0m[1mDuring: Pass nopython_type_inference[0m