# Generating Conditional Wigner Functions with MrMustard

## Objective

In this notebook, we demonstrate how to:

- Construct a continuous-variable (CV) quantum circuit using squeezed vacuum states and beam splitters
- Apply photon-number-resolving measurements to generate non-Gaussian conditional states
- Compute and visualize the Wigner function of the resulting conditional state

**Your Challenge:** Use this starter code to generate a training dataset of Wigner functions and their corresponding circuit parameters. Then build and train a neural network that can perform **inverse design** — given a target Wigner function as input, predict the circuit parameters needed to generate it.

## Background

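**Wigner Functions** are quasi-probability distributions over phase space: they describe a quantum state as a function of position x and momentum p simultaneously. Unlike classical probability distributions, they can take negative values, and such negative regions are a signature of non-classical quantum behavior.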
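
As a concrete illustration of Wigner negativity, the sketch below evaluates the analytic Wigner function of a Fock state |n⟩ with NumPy. It assumes the convention ħ = 1 (consistent with the ±1/π color range used in the plotting cell), and `fock_wigner` is a hypothetical helper, not part of MrMustard.

```python
import numpy as np
from numpy.polynomial.laguerre import lagval


def fock_wigner(n, x, p):
    """Wigner function of the Fock state |n>, assuming hbar = 1.

    W_n(x, p) = (-1)^n / pi * exp(-(x^2 + p^2)) * L_n(2 (x^2 + p^2)),
    where L_n is the n-th Laguerre polynomial.
    """
    r2 = x**2 + p**2
    coeffs = np.zeros(n + 1)
    coeffs[n] = 1.0  # selects L_n in the Laguerre series
    return (-1) ** n / np.pi * np.exp(-r2) * lagval(2 * r2, coeffs)


# The vacuum (n = 0) is Gaussian and positive everywhere,
# while the single-photon state (n = 1) is negative at the origin.
print(fock_wigner(0, 0.0, 0.0))  # 1/pi, about 0.318
print(fock_wigner(1, 0.0, 0.0))  # -1/pi, about -0.318
```
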

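**Non-Gaussian States** are essential for universal quantum computing, since circuits that use only Gaussian states, Gaussian operations, and Gaussian measurements can be simulated efficiently on a classical computer. They are, however, difficult to generate. This circuit produces them by: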
1. Starting with Gaussian squeezed vacuum states
2. Mixing them with parameterized beam splitters
3. Post-selecting on photon-number measurements to create non-Gaussian conditional states
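
The post-selection step can be illustrated without MrMustard on a toy two-mode state stored as a Fock-basis array `psi[m, n]`. The hypothetical helper `herald` projects mode 0 onto |k⟩; the squared norm of what remains is the heralding probability, and dividing by its square root plays a role analogous to `.normalize()` in the circuit code. The toy input is a two-mode squeezed vacuum (an assumption chosen because its heralded states are exactly Fock states), so each detection of k ≥ 1 photons yields a non-Gaussian state, with a probability that falls as k grows.

```python
import numpy as np


def herald(psi, k):
    """Post-select mode 0 of a two-mode pure state on k photons.

    psi[m, n] holds the Fock amplitudes <m, n|psi>. Returns the normalized
    conditional state of mode 1 and the heralding probability.
    """
    cond = psi[k, :]                          # unnormalized (<k| x I)|psi>
    prob = float(np.sum(np.abs(cond) ** 2))   # probability of detecting k photons
    return cond / np.sqrt(prob), prob


# Toy input: two-mode squeezed vacuum, psi[m, n] = sqrt(1 - lam^2) lam^n delta_mn
r, cutoff = 0.5, 30
lam = np.tanh(r)
psi = np.diag(np.sqrt(1 - lam**2) * lam ** np.arange(cutoff))

for k in range(4):
    state, prob = herald(psi, k)
    # The conditional state is the Fock state |k>; its probability shrinks with k
    print(k, round(prob, 4), np.argmax(np.abs(state)))
```
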

**Your Task:** Generate diverse Wigner functions by varying circuit parameters, then train a neural network to solve the inverse problem: Wigner function → circuit parameters.

## Circuit Parameters

The circuit is controlled by:
- **Squeezing parameters** `r_i`: How much each input state is squeezed (max 0.8 for tractability)
- **Beam splitter angles** `θ`: Mixing ratios, bounded [0, π/2]
- **Beam splitter phases** `φ`: Relative phases, bounded [0, 2π)
- **Photon numbers** `n_i`: Photon counts detected on the two measured modes (max 3 per detector)

**Note:** Higher squeezing and photon numbers create rarer measurement events and noisier states.
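
To see how the squeezing cap keeps photon numbers small, the sketch below evaluates the photon-number distribution of a single-mode squeezed vacuum, P(2m) = tanh(r)^(2m) (2m)! / (4^m (m!)² cosh r), with zero probability for odd photon numbers and mean photon number sinh²(r). At the cap r = 0.8 the mean is about 0.79 photons per input mode. This is a standalone NumPy check of textbook formulas; `squeezed_vacuum_pn` is a hypothetical helper.

```python
import numpy as np
from math import comb


def squeezed_vacuum_pn(r, cutoff):
    """Photon-number distribution of a single-mode squeezed vacuum.

    P(2m) = tanh(r)^(2m) * C(2m, m) / (4^m * cosh r); odd photon numbers
    have zero probability. Entries are returned for n = 0 .. cutoff - 1.
    """
    probs = np.zeros(cutoff)
    for n in range(0, cutoff, 2):
        m = n // 2
        probs[n] = np.tanh(r) ** n * comb(n, m) / (4**m * np.cosh(r))
    return probs


for r in (0.2, 0.5, 0.8):
    probs = squeezed_vacuum_pn(r, cutoff=60)
    mean_n = np.sum(np.arange(60) * probs)
    print(f"r={r}: P(0)={probs[0]:.3f}, <n>={mean_n:.3f} (sinh^2 r = {np.sinh(r)**2:.3f})")
```
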

In [None]:
import numpy as np
from mrmustard.lab import Circuit, SqueezedVacuum, Number
from mrmustard.lab.transformations import BSgate
from mrmustard.physics.wigner import wigner_discretized
import matplotlib.pyplot as plt

## Generate a state with random circuit parameters

In [None]:
seed = 42
rng = np.random.default_rng(seed)

# Set number of modes, maximal squeezing and maximal number of photons detected
# Note: larger squeezing and higher photon numbers lead to rarer events
# and noisier conditional states
N = 3
max_r = 0.8  # recommended bound for tractable simulation
max_n = 3

# Input squeezed vacua; the squeezing phase alternates between pi/2 (even modes) and 0 (odd modes)
squeezing_params = rng.uniform(0, max_r, size=N)
input_state = [
    SqueezedVacuum(i, r, phi=(0 if i % 2 == 1 else np.pi/2))
    for i, r in enumerate(squeezing_params)
]

# Build a beam splitter with a randomly drawn mixing angle and phase
def random_BS(mode1, mode2):
    # Mixing angle
    theta = rng.uniform(0, np.pi/2)
    # Relative phase
    phi = rng.uniform(0, 2*np.pi)
    BS = BSgate([mode1, mode2], theta, phi)
    return BS, [theta, phi]

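# Three beam splitters acting on (0,1), (1,2), (0,1) match the N(N-1)/2 = 3 two-mode
# elements of standard mesh decompositions for N = 3 modes. They provide 6 real
# parameters, while a general 3-mode interferometer has 9, so additional phase
# shifters would be needed for full generality. Students are encouraged to
# experiment with fewer or more beam splitters.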
BS1, BS1_params = random_BS(0, 1)
BS2, BS2_params = random_BS(1, 2)
BS3, BS3_params = random_BS(0, 1)
BS_params = [BS1_params, BS2_params, BS3_params]

interferometer = BS1 >> BS2 >> BS3

# Random photon-number outcomes for the detectors on modes 0 and 1 (mode 2 is left unmeasured)
photon_numbers = rng.integers(0, max_n + 1, size=N-1)

params = {
    "squeezing": squeezing_params,
    "beam_splitters": BS_params,
    "photon_numbers": photon_numbers,
}

measurement = [Number(i, n).dual for i, n in enumerate(photon_numbers)]

# Build the circuit, contract it, and normalize the conditional state of mode 2 (out)
c = Circuit(input_state) >> interferometer >> Circuit(measurement)
out = c.contract().normalize()

# Print the sampled circuit parameters
print("Circuit parameters:")
for k, v in params.items():
    print(f"  {k}: {v}")
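
The interferometer's action on the mode operators can be pictured as a product of 3×3 unitary matrices, one per beam splitter. The sketch below assumes one common beam-splitter convention, [[cos θ, −e^(−iφ) sin θ], [e^(iφ) sin θ, cos θ]]; it has not been checked against MrMustard's internal `BSgate` convention, and `bs_matrix` is a hypothetical helper. Because `BS1 >> BS2 >> BS3` applies BS1 first, the matrices are multiplied right to left.

```python
import numpy as np


def bs_matrix(theta, phi, mode1, mode2, num_modes=3):
    """Embed a 2x2 beam-splitter unitary on (mode1, mode2) into a num_modes identity.

    Assumed convention: [[cos t, -e^{-i phi} sin t], [e^{i phi} sin t, cos t]].
    """
    U = np.eye(num_modes, dtype=complex)
    c, s = np.cos(theta), np.sin(theta)
    U[mode1, mode1] = c
    U[mode1, mode2] = -np.exp(-1j * phi) * s
    U[mode2, mode1] = np.exp(1j * phi) * s
    U[mode2, mode2] = c
    return U


demo_rng = np.random.default_rng(42)
layout = [(0, 1), (1, 2), (0, 1)]  # same mode pairs as BS1, BS2, BS3 above
angles = [(demo_rng.uniform(0, np.pi / 2), demo_rng.uniform(0, 2 * np.pi)) for _ in layout]

# BS1 is applied first, so each new matrix multiplies from the left
U = np.eye(3, dtype=complex)
for (m1, m2), (theta, phi) in zip(layout, angles):
    U = bs_matrix(theta, phi, m1, m2) @ U

print(np.allclose(U.conj().T @ U, np.eye(3)))  # True: the mesh is unitary
```
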

In [None]:
display(out)

In [None]:
# Phase-space grid for the discretized (pixelized) Wigner function
x = 3.0      # phase-space cutoff
Np = 15      # number of pixels per axis
xvec = np.linspace(-x, x, Np)
pvec = np.linspace(-x, x, Np)
wig, X, P = wigner_discretized(out.dm().ansatz.array, xvec, pvec)

# Visualize Wigner function
plt.figure(figsize=(6, 5))
plt.imshow(wig.T, vmin=-1/np.pi, vmax=1/np.pi, cmap='RdBu', extent=[-x, x, -x, x], origin='lower')
plt.colorbar(label='Wigner function')
plt.xlabel('Position (x)')
plt.ylabel('Momentum (p)')
plt.title('Wigner Function of Conditional State')
plt.show()
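
Two quantities are handy when inspecting or filtering generated samples: the Riemann-sum normalization of the pixelized Wigner function (close to 1 when the grid covers the state) and its negativity volume, (∫|W| dx dp − 1)/2, which is zero for non-negative Wigner functions. The hypothetical helper `grid_metrics` below computes both, using the grid normalization in place of 1. It is sanity-checked on the analytic single-photon Wigner function (ħ = 1 assumed) sampled on the same 15 × 15 grid over [−3, 3]; on such a coarse grid the negativity is only approximate. Assuming `wig` is a NumPy array, it could also be applied as `grid_metrics(wig, xvec, pvec)`.

```python
import numpy as np


def grid_metrics(wig, xvec, pvec):
    """Riemann-sum normalization and negativity volume of a discretized Wigner function."""
    dA = (xvec[1] - xvec[0]) * (pvec[1] - pvec[0])  # area of one pixel
    norm = np.sum(wig) * dA
    # Half of (integral of |W| minus integral of W) = total weight of negative pixels
    negativity = (np.sum(np.abs(wig)) * dA - norm) / 2
    return norm, negativity


# Analytic single-photon Wigner function (hbar = 1) on a 15 x 15 grid over [-3, 3]
grid = np.linspace(-3.0, 3.0, 15)
GX, GP = np.meshgrid(grid, grid, indexing="ij")
r2 = GX**2 + GP**2
wig_fock1 = (2 * r2 - 1) * np.exp(-r2) / np.pi

norm, neg = grid_metrics(wig_fock1, grid, grid)
# Exact continuum values: norm 1, negativity 2/sqrt(e) - 1 (about 0.213)
print(f"norm = {norm:.3f}, negativity volume = {neg:.3f}")
```
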

## Next Steps / Your Challenge

### 1. Explore the Parameter Space
- Try changing the squeezing parameters, beam splitter angles, or photon numbers and observe how the Wigner function changes
- Use the interactive widget below to build intuition

### 2. Generate Training Data
- Think about how you could generate a dataset of many such Wigner functions and store the corresponding circuit parameters for NN training
- Consider: What happens with rare measurement outcomes? Should you filter them?
- How will you structure and save your data?
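
One way to structure the data is to stack images into an array of shape `(num_samples, Np, Np)` and parameter vectors into `(num_samples, 11)`, then save both in a single compressed `.npz` file. The sketch below assumes three user-supplied callables, all hypothetical: `simulate` (for example, wrapping the circuit and `wigner_discretized` code above), `sample_params`, and an optional `keep` filter for discarding unwanted samples such as those from rare outcomes. A stand-in Gaussian simulator is used so the sketch runs without MrMustard.

```python
import numpy as np


def build_dataset(simulate, sample_params, num_samples, rng, keep=None, max_attempts=None):
    """Collect (Wigner image, parameter vector) pairs into stacked float32 arrays.

    simulate(params) -> 2D Wigner array, sample_params(rng) -> 1D parameter
    vector, keep(wig, params) -> bool (optional filter). All are user-supplied.
    """
    max_attempts = max_attempts or 10 * num_samples
    images, targets = [], []
    for _ in range(max_attempts):
        if len(images) == num_samples:
            break
        params = sample_params(rng)
        wig = simulate(params)
        if not np.all(np.isfinite(wig)):
            continue  # drop failed or numerically unstable simulations
        if keep is not None and not keep(wig, params):
            continue  # drop samples rejected by the user's filter
        images.append(np.asarray(wig, dtype=np.float32))
        targets.append(np.asarray(params, dtype=np.float32))
    if not images:
        raise RuntimeError("no samples passed the filters")
    return np.stack(images), np.stack(targets)


# Stand-in simulator so the sketch runs without MrMustard (replace with the real
# circuit): a Gaussian blob whose width grows with the first parameter.
grid = np.linspace(-3.0, 3.0, 15)
GX, GP = np.meshgrid(grid, grid, indexing="ij")


def fake_simulate(params):
    return np.exp(-(GX**2 + GP**2) / (1 + params[0])) / np.pi


def fake_sample(rng):
    return rng.uniform(0, 1, size=11)


X_train, y_train = build_dataset(fake_simulate, fake_sample, 100, np.random.default_rng(0))
np.savez_compressed("wigner_dataset.npz", wigner=X_train, params=y_train)
print(X_train.shape, y_train.shape)  # (100, 15, 15) (100, 11)
```
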

### 3. Build Your Neural Network
- **Input**: Wigner function (consider: raw pixels? flattened? normalized?)
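- **Output**: Circuit parameters (11 values total: 3 squeezing values, 3 beam splitter angles, 3 beam splitter phases, and 2 photon numbers; note that the photon numbers are discrete)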
- Consider using PyTorch, TensorFlow, or JAX
- What network architecture makes sense for this problem?
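
A fixed flattening order makes the 11-value target explicit. The hypothetical helpers below map the `params` dictionary built earlier to a vector ordered `(r0, r1, r2, theta1, phi1, theta2, phi2, theta3, phi3, n0, n1)`, the same order as the arguments of `circuit(...)` in the interactive explorer, and back again. Since photon numbers are integers, one design option (an assumption, not the only choice) is to treat them as classification targets rather than regressing them.

```python
import numpy as np


def encode_params(params):
    """Flatten the params dict into (r0, r1, r2, theta1, phi1, ..., theta3, phi3, n0, n1)."""
    return np.concatenate([
        np.asarray(params["squeezing"], dtype=float),
        np.asarray(params["beam_splitters"], dtype=float).ravel(),  # [[theta, phi], ...]
        np.asarray(params["photon_numbers"], dtype=float),
    ])


def decode_params(vec):
    """Inverse of encode_params; photon numbers are rounded back to integers."""
    vec = np.asarray(vec, dtype=float)
    return {
        "squeezing": vec[0:3],
        "beam_splitters": vec[3:9].reshape(3, 2).tolist(),
        "photon_numbers": np.rint(vec[9:11]).astype(int),
    }


example = {
    "squeezing": np.array([0.1, 0.5, 0.7]),
    "beam_splitters": [[0.3, 1.0], [0.8, 2.0], [1.2, 4.0]],
    "photon_numbers": np.array([1, 2]),
}
vec = encode_params(example)
print(vec.shape)  # (11,)
print(decode_params(vec)["photon_numbers"])  # [1 2]
```
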

### 4. Important Considerations
- Remember to constrain NN outputs to valid parameter ranges
- Think about normalization strategies
- What loss function should you use?
- How will you evaluate your model's performance?
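
A sketch of how network outputs could be constrained, written in NumPy so it runs standalone; equivalent operations exist in PyTorch and JAX. Assumptions: squeezing values and angles pass through a scaled sigmoid, each phase is predicted as an unnormalized (cos φ, sin φ) pair and recovered with `arctan2` so that 0 and 2π are not treated as far apart, and photon numbers are predicted as class logits over 0 to 3. The periodic loss 1 − cos(Δφ) is one option for phases; for the photon-number logits, cross-entropy during training (with argmax only at inference) would be a natural choice.

```python
import numpy as np

MAX_R, MAX_N = 0.8, 3


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def constrain_outputs(raw_r, raw_theta, raw_phase, n_logits):
    """Map unconstrained network outputs to valid circuit parameters.

    raw_r: (B, 3) -> r in (0, MAX_R) via a scaled sigmoid
    raw_theta: (B, 3) -> theta in (0, pi/2) via a scaled sigmoid
    raw_phase: (B, 3, 2) -> phi wrapped to [0, 2 pi) from an unnormalized (cos, sin) pair
    n_logits: (B, 2, MAX_N + 1) -> photon numbers via argmax (inference only)
    """
    r = MAX_R * sigmoid(raw_r)
    theta = (np.pi / 2) * sigmoid(raw_theta)
    phi = np.mod(np.arctan2(raw_phase[..., 1], raw_phase[..., 0]), 2 * np.pi)
    n = np.argmax(n_logits, axis=-1)
    return r, theta, phi, n


def phase_loss(phi_pred, phi_true):
    """Periodic loss: 0 when phases agree modulo 2 pi, maximal (2) when opposite."""
    return np.mean(1.0 - np.cos(phi_pred - phi_true))


demo_rng = np.random.default_rng(0)
B = 4
r, theta, phi, n = constrain_outputs(
    demo_rng.normal(size=(B, 3)), demo_rng.normal(size=(B, 3)),
    demo_rng.normal(size=(B, 3, 2)), demo_rng.normal(size=(B, 2, MAX_N + 1)),
)
print(r.shape, theta.shape, phi.shape, n.shape)  # (4, 3) (4, 3) (4, 3) (4, 2)
```
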

## Extra: Interactive Circuit Explorer

Play with the circuit parameters interactively to see how they affect the output state.

In [None]:
import ipywidgets as widgets
from IPython.display import display

# Number of modes and maximal parameter values (same as in the first example)
N = 3
max_r = 0.8
max_n = 3

In [None]:
# Build, run, and display the circuit for the current slider values
def circuit(r0, r1, r2, theta1, phi1, theta2, phi2, theta3, phi3, n0, n1):
    
    # Input squeezed states
    squeezing_params = [r0, r1, r2]
    input_state = [
        SqueezedVacuum(i, r, phi=(0 if i % 2 == 1 else np.pi/2))
        for i, r in enumerate(squeezing_params)
    ]
    
    # Interferometer
    BS1 = BSgate([0,1], theta1, phi1)
    BS2 = BSgate([1,2], theta2, phi2)
    BS3 = BSgate([0,1], theta3, phi3)
    interferometer = BS1 >> BS2 >> BS3
    
    # Measurement (post-selection)
    photon_numbers = [n0, n1]
    measurement = [Number(i, n).dual for i, n in enumerate(photon_numbers)]
    
    # Build and run circuit
    c = Circuit(input_state) >> interferometer >> Circuit(measurement)
    out = c.contract().normalize()
    
    display(out)

# Sliders for squeezing
r_sliders = [widgets.FloatSlider(min=0, max=max_r, step=0.01, value=0.5, description=f'r{i}') for i in range(3)]

# Sliders for beam splitters
theta_sliders = [widgets.FloatSlider(min=0, max=np.pi/2, step=0.01*np.pi, value=np.pi/4, description=f'theta{i+1}') for i in range(3)]
phi_sliders   = [widgets.FloatSlider(min=0, max=2*np.pi, step=0.01*np.pi, value=0, description=f'phi{i+1}') for i in range(3)]

# Sliders for photon numbers
n_sliders = [widgets.IntSlider(min=0, max=max_n, step=1, value=0, description=f'n{i}') for i in range(2)]

# Interactive widget
interactive_plot = widgets.interactive(
    circuit,
    r0=r_sliders[0], r1=r_sliders[1], r2=r_sliders[2],
    theta1=theta_sliders[0], phi1=phi_sliders[0],
    theta2=theta_sliders[1], phi2=phi_sliders[1],
    theta3=theta_sliders[2], phi3=phi_sliders[2],
    n0=n_sliders[0], n1=n_sliders[1]
)

display(interactive_plot)