In [14]:
import numpy as np
import plotly.graph_objects as go
import ipywidgets as widgets
from ipywidgets import HBox, VBox, Tab, HTML, Layout, Label, FloatProgress
from qiskit import QuantumCircuit
from qiskit.quantum_info import Statevector, Operator
from collections import deque
from IPython.display import display
# IPython magic: only valid inside Jupyter/IPython (plain Python raises a SyntaxError on it). Nothing in this cell draws with matplotlib (plots are Plotly), so the line is optional.
%matplotlib inline

# ------------------- Styling -------------------
BACKGROUND_COLOR = "#000000"   # scene axis backgrounds and figure paper
TEXT_COLOR = "#ffffff"         # figure font and header text
ACCENT_COLOR = "#007aff"       # slider handles and the history trace line
GRID_COLOR = "#888888"         # sphere grid, projection lines, borders
FONT_FAMILY = "-apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif"
# Default 3-D viewpoint; also the target of the "Reset Rotation & Zoom" button.
initial_camera = {"eye": {"x": 1.5, "y": 1.5, "z": 1}}

# ------------------- State -------------------
# Bloch vectors recorded before each gate click; bounded so only the most
# recent points are kept for the on-sphere trace.
vector_history = deque(maxlen=3)

def get_bloch_vector_coordinates(theta_rad, phi_rad):
    """Map spherical angles to a Cartesian point on the unit Bloch sphere.

    Args:
        theta_rad: Polar angle from +z, in radians.
        phi_rad: Azimuthal angle from +x in the xy-plane, in radians.

    Returns:
        Tuple ``(x, y, z)``. NumPy arrays are handled element-wise.
    """
    sin_theta = np.sin(theta_rad)
    return (
        sin_theta * np.cos(phi_rad),
        sin_theta * np.sin(phi_rad),
        np.cos(theta_rad),
    )

def state_to_bloch(state_vector):
    """Return the Bloch vector (x, y, z) of a single-qubit pure state.

    Each component is the Pauli expectation value <psi|sigma_k|psi> divided
    by <psi|psi>. The result is therefore a unit vector even when
    ``state_vector`` is not normalized, e.g. after accumulated rounding.

    Args:
        state_vector: The two complex amplitudes (a, b) of |0> and |1>. Any
            array-like holding exactly two elements is accepted, e.g. shape
            (2,) or (2, 1).

    Returns:
        Tuple of three real NumPy floats ``(x, y, z)``.

    Raises:
        ValueError: If the input does not hold exactly two amplitudes, or is
            the zero vector (which has no Bloch representation).
    """
    psi = np.asarray(state_vector, dtype=complex).reshape(-1)
    if psi.size != 2:
        raise ValueError(f"expected 2 amplitudes, got {psi.size}")
    a, b = psi
    norm_sq = abs(a) ** 2 + abs(b) ** 2
    if norm_sq == 0:
        raise ValueError("zero vector has no Bloch representation")
    # Closed forms of Tr(|psi><psi| sigma_k):
    #   <X> = 2 Re(a* b),  <Y> = 2 Im(a* b),  <Z> = |a|^2 - |b|^2
    cross = np.conj(a) * b
    x = 2 * cross.real / norm_sq
    y = 2 * cross.imag / norm_sq
    z = (abs(a) ** 2 - abs(b) ** 2) / norm_sq
    return x, y, z

# ------------------- Figure -------------------
# Build the interactive figure once. Every later callback addresses traces by
# their position in fig.data, so the order of add_trace calls below matters:
# each *_INDEX / *_INDICES value is captured as len(fig.data) right before
# its trace is appended.
fig = go.FigureWidget()
# Translucent unit sphere: u is the azimuth (0..2pi), v the polar angle (0..pi),
# sampled on a 100x100 grid.
u, v = np.mgrid[0:2*np.pi:100j, 0:np.pi:100j]
x_sphere, y_sphere, z_sphere = np.cos(u)*np.sin(v), np.sin(u)*np.sin(v), np.cos(v)
fig.add_trace(go.Surface(x=x_sphere, y=y_sphere, z=z_sphere, colorscale=[[0, '#aaaaaa'], [1, '#dddddd']], opacity=0.15, showscale=False))

# Grid lines
# Latitude circles every 30 degrees, starting at -90 deg (a zero-radius circle
# at the south pole); i is the latitude in radians.
for i in np.arange(-np.pi/2, np.pi/2, np.pi/6):
    t = np.linspace(0, 2*np.pi, 100)
    x_line, y_line = np.cos(t) * np.cos(i), np.sin(t) * np.cos(i)
    fig.add_trace(go.Scatter3d(x=x_line, y=y_line, z=np.sin(i) * np.ones(100), mode='lines', line=dict(color=GRID_COLOR, width=1)))
# Meridians (pole-to-pole half circles) every 30 degrees of azimuth i.
for i in np.arange(0, 2*np.pi, np.pi/6):
    s = np.linspace(0, np.pi, 100)
    x_line, y_line = np.cos(i) * np.sin(s), np.sin(i) * np.sin(s)
    fig.add_trace(go.Scatter3d(x=x_line, y=y_line, z=np.cos(s), mode='lines', line=dict(color=GRID_COLOR, width=1)))

# Trace & projections
# Empty polyline; filled with the history points plus the current tip on redraw.
TRACE_INDEX = len(fig.data)
fig.add_trace(go.Scatter3d(x=[], y=[], z=[], mode='lines', line=dict(color=ACCENT_COLOR, width=4), name='trace'))

# Three two-point drop lines from the arrow tip to the xy, xz and yz planes.
# They start collapsed at the origin and are positioned on the first redraw.
PROJECTION_INDICES = {}
PROJECTION_INDICES['xy'] = len(fig.data); fig.add_trace(go.Scatter3d(x=[0,0], y=[0,0], z=[0,0], mode='lines', line=dict(color=GRID_COLOR, width=1)))
PROJECTION_INDICES['xz'] = len(fig.data); fig.add_trace(go.Scatter3d(x=[0,0], y=[0,0], z=[0,0], mode='lines', line=dict(color=GRID_COLOR, width=1)))
PROJECTION_INDICES['yz'] = len(fig.data); fig.add_trace(go.Scatter3d(x=[0,0], y=[0,0], z=[0,0], mode='lines', line=dict(color=GRID_COLOR, width=1)))

# Arrow
# State arrow: a line shaft from the origin plus a cone whose tip sits on the
# sphere. It starts at |0> (+z). The single-colour colorscale makes the cone
# solid red, matching the arrow_color_picker default.
ARROW_INDEX = len(fig.data); fig.add_trace(go.Scatter3d(x=[0,0], y=[0,0], z=[0,1], mode='lines', line=dict(color="#ff3b30", width=8), name='arrow'))
ARROWHEAD_INDEX = len(fig.data); fig.add_trace(go.Cone(x=[0], y=[0], z=[1], u=[0], v=[0], w=[0.1], sizemode="absolute", sizeref=0.15, anchor="tip", colorscale=[[0, "#ff3b30"], [1, "#ff3b30"]], showscale=False))

# Axes
# Optional axis-helper lines (+x red, +y green, +z blue), toggled by a checkbox.
AXES_INDICES = {}
AXES_INDICES['x'] = len(fig.data); fig.add_trace(go.Scatter3d(x=[0,1.2], y=[0,0], z=[0,0], mode='lines', line=dict(color='red', width=5)))
AXES_INDICES['y'] = len(fig.data); fig.add_trace(go.Scatter3d(x=[0,0], y=[0,1.2], z=[0,0], mode='lines', line=dict(color='green', width=5)))
AXES_INDICES['z'] = len(fig.data); fig.add_trace(go.Scatter3d(x=[0,0], y=[0,0], z=[0,1.2], mode='lines', line=dict(color='blue', width=5)))

# Square, dark 3-D scene. Equal axis ranges and aspect ratio keep the sphere
# round, and the z-axis ticks label the poles as |0> (top) and |1> (bottom).
fig.update_layout(
    width=550, height=550, showlegend=False, transition_duration=100,
    scene=dict(
        xaxis=dict(title='X', showticklabels=False, backgroundcolor=BACKGROUND_COLOR, gridcolor=GRID_COLOR, zerolinecolor=GRID_COLOR, range=[-1.2, 1.2]),
        yaxis=dict(title='Y', showticklabels=False, backgroundcolor=BACKGROUND_COLOR, gridcolor=GRID_COLOR, zerolinecolor=GRID_COLOR, range=[-1.2, 1.2]),
        zaxis=dict(title='', tickvals=[-1,1], ticktext=['|1⟩','|0⟩'], backgroundcolor=BACKGROUND_COLOR, gridcolor=GRID_COLOR, zerolinecolor=GRID_COLOR, range=[-1.2,1.2]),
        aspectratio=dict(x=1, y=1, z=1), camera=initial_camera
    ),
    margin=dict(l=0,r=0,b=0,t=0), paper_bgcolor=BACKGROUND_COLOR, font=dict(color=TEXT_COLOR, family=FONT_FAMILY)
)

# ------------------- Widgets -------------------
header_html = HTML(
    f"<h1 style='text-align:center;color:{TEXT_COLOR}; font-family:{FONT_FAMILY};'>Live Bloch Sphere Visualization</h1>"
)

# Angle sliders, in degrees. continuous_update=True redraws while dragging.
theta_slider = widgets.FloatSlider(
    value=0, min=0, max=180, step=1,
    description=r'$\theta$ (deg):',
    continuous_update=True,
    style={'description_width': 'initial', 'handle_color': ACCENT_COLOR},
)
phi_slider = widgets.FloatSlider(
    value=0, min=0, max=360, step=1,
    description=r'$\phi$ (deg):',
    continuous_update=True,
    style={'description_width': 'initial', 'handle_color': ACCENT_COLOR},
)

# State-preset buttons.
reset_button = widgets.Button(description="Reset to |0⟩", button_style='info')
random_button = widgets.Button(description="Random State", button_style='success')
plus_ket_button = widgets.Button(description="Set to |+⟩", button_style='primary')
minus_ket_button = widgets.Button(description="Set to |-⟩", button_style='primary')

# One narrow button per single-qubit gate, keyed and labelled by gate name.
gate_buttons = {
    gate: widgets.Button(description=gate, layout=Layout(width='50px'))
    for gate in 'XYZHST'
}

# Readouts. The initial values describe |0> and are overwritten on redraw.
coord_readout = HTML(value="<b>Bloch Vector:</b> (0.00, 0.00, 1.00)")
state_readout = HTML(value="<b>State |ψ⟩:</b> 1.00 |0⟩ + (0.00+0.00j) |1⟩")
p0_bar = FloatProgress(
    value=1.0, min=0.0, max=1.0, description='P(|0⟩):', bar_style='info'
)
p1_bar = FloatProgress(
    value=0.0, min=0.0, max=1.0, description='P(|1⟩):', bar_style='danger'
)

# Settings-tab controls.
reset_camera_button = widgets.Button(
    description="Reset Rotation & Zoom", button_style='warning'
)
show_axes_checkbox = widgets.Checkbox(
    value=True, description="Show axes helper", indent=False
)
arrow_color_picker = widgets.ColorPicker(
    concise=False, description='Arrow Color:', value='#ff3b30'
)
# Starts at 3, matching the initial maxlen of vector_history.
trace_length_input = widgets.IntText(value=3, description="Trace last N gates:")

# ------------------- Core Functions -------------------
def update_readouts_and_figure(theta_rad, phi_rad):
    """Redraw every state-dependent widget and trace for the given angles.

    Updates the text readouts and probability bars first. Then, in one
    batched figure update, it updates the projection lines, the arrow shaft
    and cone, and the history trace. It never assigns slider values, so
    calling it does not fire the slider observers.

    Args:
        theta_rad: Polar angle in radians.
        phi_rad: Azimuthal angle in radians.
    """
    bx, by, bz = get_bloch_vector_coordinates(theta_rad, phi_rad)

    # Amplitudes of |psi> = cos(theta/2)|0> + e^{i phi} sin(theta/2)|1>.
    amp0 = np.cos(theta_rad/2)
    amp1 = np.exp(1j*phi_rad)*np.sin(theta_rad/2)
    prob0 = abs(amp0)**2
    prob1 = abs(amp1)**2

    coord_readout.value = f"<b>Bloch Vector:</b> ({bx:.2f},{by:.2f},{bz:.2f})"
    state_readout.value = f"<b>State |ψ⟩:</b> {amp0:.2f}|0⟩ + ({amp1.real:.2f}{amp1.imag:+.2f}j)|1⟩"
    p0_bar.value = prob0
    p1_bar.value = prob1
    p0_bar.description = f'P(|0⟩) {prob0:.1%}:'
    p1_bar.description = f'P(|1⟩) {prob1:.1%}:'

    def set_xyz(index, xs, ys, zs):
        # Assign all three coordinate arrays of one trace.
        trace = fig.data[index]
        trace.x, trace.y, trace.z = xs, ys, zs

    # The trace polyline runs through the recorded points, then the current tip.
    path = list(vector_history) + [(bx, by, bz)]
    path_x, path_y, path_z = (list(axis) for axis in zip(*path))

    with fig.batch_update():
        # Drop lines from the tip to the xy-, xz- and yz-planes.
        set_xyz(PROJECTION_INDICES['xy'], [bx, bx], [by, by], [bz, 0])
        set_xyz(PROJECTION_INDICES['xz'], [bx, bx], [by, 0], [bz, bz])
        set_xyz(PROJECTION_INDICES['yz'], [bx, 0], [by, by], [bz, bz])
        set_xyz(ARROW_INDEX, [0, bx], [0, by], [0, bz])
        head = fig.data[ARROWHEAD_INDEX]
        head.x, head.y, head.z = [bx], [by], [bz]
        # The cone points along the Bloch vector itself.
        head.u, head.v, head.w = [bx], [by], [bz]
        set_xyz(TRACE_INDEX, path_x, path_y, path_z)

def update_all_from_sliders(change=None):
    """Redraw from the current slider angles and discard the gate trace.

    Registered as the ``value`` observer of both angle sliders, so any slider
    move starts a fresh trace. It is also safe to call directly with no
    argument, e.g. for the initial draw.

    Args:
        change: Traitlets change notification; unused (``None`` for direct
            calls).
    """
    vector_history.clear()
    update_readouts_and_figure(
        np.deg2rad(theta_slider.value),
        np.deg2rad(phi_slider.value),
    )

def apply_gate_and_update(gate_name):
    """Apply a single-qubit gate to the slider state, record the trace, redraw.

    The state vector implied by the current (theta, phi) slider angles is
    multiplied by the gate's 2x2 unitary, taken from a one-gate Qiskit
    circuit. The pre-gate Bloch vector is appended to ``vector_history`` so
    the trace shows the path, and the sliders are moved to the post-gate
    angles.

    Args:
        gate_name: One of 'X', 'Y', 'Z', 'H', 'S', 'T'. Any other name leaves
            the circuit empty, so the identity is applied.

    Any exception is printed rather than raised. This keeps one failed click
    from propagating out of the widget callback.
    """
    try:
        theta_rad = np.deg2rad(theta_slider.value)
        phi_rad = np.deg2rad(phi_slider.value)
        current_vector = get_bloch_vector_coordinates(theta_rad, phi_rad)

        # |psi> = cos(theta/2)|0> + e^{i phi} sin(theta/2)|1>
        current_state_vector = np.array([
            np.cos(theta_rad / 2),
            np.exp(1j * phi_rad) * np.sin(theta_rad / 2),
        ])

        gate_circuit = QuantumCircuit(1)
        gate_map = {
            'X': gate_circuit.x, 'Y': gate_circuit.y, 'Z': gate_circuit.z,
            'H': gate_circuit.h, 'S': gate_circuit.s, 'T': gate_circuit.t,
        }
        if gate_name in gate_map:
            gate_map[gate_name](0)
        new_state_vector = Operator(gate_circuit).data @ current_state_vector

        x, y, z = state_to_bloch(new_state_vector)
        new_theta_rad = np.arccos(np.clip(z, -1, 1))
        # arctan2 yields (-pi, pi]; wrap into [0, 2pi) to fit the phi slider.
        new_phi_rad = np.arctan2(y, x) % (2 * np.pi)

        # Assigning slider .value DOES notify observers (traitlets fires on
        # every change, programmatic or not). The slider observer
        # update_all_from_sliders clears vector_history and redraws. Left
        # attached, it would erase the trace and briefly draw a mixed
        # new-theta/old-phi state, so detach it while syncing the sliders.
        # This assumes it is registered on both sliders, as done at module
        # level; the finally block guarantees it is attached afterwards.
        sliders = (theta_slider, phi_slider)
        try:
            for slider in sliders:
                slider.unobserve(update_all_from_sliders, names='value')
            theta_slider.value = np.rad2deg(new_theta_rad)
            phi_slider.value = np.rad2deg(new_phi_rad)
        finally:
            for slider in sliders:
                slider.observe(update_all_from_sliders, names='value')

        # Record the pre-gate point only after everything above succeeded.
        vector_history.append(current_vector)
        update_readouts_and_figure(new_theta_rad, new_phi_rad)

    except Exception as e:
        print(f"Error in apply_gate_and_update('{gate_name}'): {e}")

# ------------------- Other Widget Callbacks -------------------
def _jump_to(theta_deg, phi_deg, theta_rad, phi_rad):
    """Clear the trace, move both sliders, then redraw at the given angles.

    The slider values (degrees) and the redraw angles (radians) are passed
    separately. Each value therefore reaches its consumer exactly as the
    caller computed it.
    """
    vector_history.clear()
    theta_slider.value, phi_slider.value = theta_deg, phi_deg
    update_readouts_and_figure(theta_rad, phi_rad)


def reset_state(b):
    """Button callback: return to |0> (theta = 0, phi = 0)."""
    _jump_to(0, 0, 0, 0)


def random_state(b):
    """Button callback: jump to a random pure state, uniform over the sphere."""
    # Sampling cos(theta) (not theta) uniformly gives uniform area density.
    phi_rad, cos_theta = np.random.uniform(0,2*np.pi), np.random.uniform(-1,1)
    theta_rad = np.arccos(cos_theta)
    _jump_to(np.rad2deg(theta_rad), np.rad2deg(phi_rad), theta_rad, phi_rad)


def set_plus_state(b):
    """Button callback: move to |+> (theta = 90 deg, phi = 0 deg)."""
    _jump_to(90, 0, np.deg2rad(90), np.deg2rad(0))


def set_minus_state(b):
    """Button callback: move to |-> (theta = 90 deg, phi = 180 deg)."""
    _jump_to(90, 180, np.deg2rad(90), np.deg2rad(180))
    
def reset_camera(b):
    """Button callback: restore the default 3-D camera (rotation and zoom)."""
    fig.layout.scene.camera = initial_camera


def toggle_axes(change):
    """Checkbox observer: show or hide the three axis-helper lines."""
    visible = change.new
    with fig.batch_update():
        for trace_idx in AXES_INDICES.values():
            fig.data[trace_idx].visible = visible
def update_trace_length(change):
    """Observer for the trace-length input: resize the history buffer.

    The most recent history entries that fit the new length are kept and the
    figure is redrawn so the trace reflects them. A negative length is
    rejected by restoring the previous widget value, so the input never
    shows a length that is not in effect.

    Args:
        change: Traitlets change notification with ``new``, ``old`` and
            ``owner``.
    """
    global vector_history
    new_len = int(change.new)
    if new_len < 0:
        # Re-fires this observer with the restored, valid value.
        change.owner.value = max(int(change.old), 0)
        return
    # deque(iterable, maxlen=n) keeps the last n items, i.e. the newest.
    vector_history = deque(vector_history, maxlen=new_len)
    # Redraw directly. update_all_from_sliders() would clear vector_history
    # and discard the entries just carried over.
    update_readouts_and_figure(
        np.deg2rad(theta_slider.value), np.deg2rad(phi_slider.value)
    )
def update_arrow_color(change):
    """Color-picker observer: repaint the arrow shaft and its cone head."""
    picked = change.new
    # Identical colours at both stops render the cone in one flat colour.
    solid_scale = [[0, picked], [1, picked]]
    with fig.batch_update():
        fig.data[ARROW_INDEX].line.color = picked
        fig.data[ARROWHEAD_INDEX].colorscale = solid_scale

# ------------------- Attach Observers -------------------
# Slider observers fire on every change to .value, whether from the UI or
# from Python code. Each change clears the gate trace and redraws.
theta_slider.observe(update_all_from_sliders, names='value')
phi_slider.observe(update_all_from_sliders, names='value')
# State-preset buttons.
reset_button.on_click(reset_state)
random_button.on_click(random_state)
plus_ket_button.on_click(set_plus_state)
minus_ket_button.on_click(set_minus_state)
# g=name binds the gate name at definition time. Without the default, every
# lambda would see the loop's final value ('T').
for name, button in gate_buttons.items():
    button.on_click(lambda b, g=name: apply_gate_and_update(g))
# Settings-tab controls.
reset_camera_button.on_click(reset_camera)
show_axes_checkbox.observe(toggle_axes, names='value')
trace_length_input.observe(update_trace_length, names='value')
arrow_color_picker.observe(update_arrow_color, names='value')

# Initial UI Draw
# Positions the arrow, projections and readouts for the starting slider values.
update_all_from_sliders()

# ------------------- UI Layout -------------------
# Controls tab: state sliders, presets, gate buttons, live readouts, probabilities.
controls_tab_content = VBox([
    Label("Qubit State Controls", style={'font_weight': 'bold'}), theta_slider, phi_slider,
    HBox([reset_button, random_button, plus_ket_button, minus_ket_button]),
    VBox([Label("Quantum Gates", style={'font_weight':'bold'}), HBox(list(gate_buttons.values()))]),
    VBox([Label("Live Readouts", style={'font_weight':'bold'}), coord_readout, state_readout]),
    VBox([Label("Measurement Probabilities", style={'font_weight':'bold'}), p0_bar, p1_bar])
])

# Settings tab: camera reset, axis-helper toggle, arrow colour, trace length.
settings_tab_content = VBox([
    Label("View Controls", style={'font_weight':'bold'}),
    reset_camera_button, show_axes_checkbox, arrow_color_picker,
    VBox([Label("History", style={'font_weight':'bold'}), trace_length_input])
])

tabs = Tab(children=[controls_tab_content, settings_tab_content])
tabs.set_title(0,'Controls & Simulation')
tabs.set_title(1,'Settings')

# Header on top; figure and control panel side by side below it.
# NOTE: ipywidgets `Layout` has no `border_radius` or `background` trait.
# Passing them only emitted a DeprecationWarning and they were dropped, so
# they are omitted here. A black panel background would also need lighter
# text for the widgets inside, which keep their default colours here.
ui = VBox([
    header_html,
    HBox([
        fig,
        VBox(
            [tabs],
            layout=Layout(
                width='450px',
                height='auto',
                border=f'1px solid {GRID_COLOR}',
                padding='10px',
            )
        )
    ])
],
layout=Layout(padding='20px', align_items='center'))

# Display the final UI
display(ui)



VBox(children=(HTML(value="<h1 style='text-align:center;color:#ffffff; font-family:-apple-system, BlinkMacSyst…