In [2]:
import numpy as np
import plotly.graph_objects as go
import ipywidgets as widgets
from IPython.display import display
import math

# Constants
resolution = 200
x_dim = 850
y_dim = x_dim
x = np.linspace(0, x_dim, resolution)
y = np.linspace(0, y_dim, resolution)
X, Y = np.meshgrid(x, y)

center_x, center_y = x_dim / 2, y_dim / 2
inner_radius, outer_radius = 250, 400
ref_x, ref_y = center_x, center_y + (inner_radius + outer_radius) / 2
opponent_x, opponent_y = ref_x, ref_y
ego_x, ego_y = center_x, center_y - (inner_radius + outer_radius) / 2

def out_of_bounds_cost(x, y, weight=1, spread=100):
    dx = x - center_x
    dy = y - center_y
    distance_to_center = np.hypot(dx, dy)
    radius = (outer_radius + inner_radius) / 2
    distance_to_perimeter = abs(distance_to_center - radius)
    cost = weight * (1 - np.exp(-(2 / spread * distance_to_perimeter) ** 2))
    return cost

def relative_progress_cost(x1, y1, weight=1):
    def angle(x, y):
        return (np.arctan2(y - center_y, x - center_x) + 2 * np.pi) % (2 * np.pi)
    angle1 = angle(x1, y1)
    angle2 = angle(ref_x, ref_y)
    delta = angle1 - angle2
    delta = np.where(delta < -np.pi, delta + 2 * np.pi, delta)
    delta = np.where(delta > np.pi, delta - 2 * np.pi, delta)
    return weight * delta

def proximity_cost(x1, y1, threshold=50, weight=1.0):
    dx = x1 - opponent_x
    dy = y1 - opponent_y
    dist = np.sqrt(dx ** 2 + dy ** 2)
    cost = np.zeros_like(dist)
    within = dist < threshold
    cost[within] = weight * np.exp(-(2 * 1 / threshold * dist[within]))
    return cost

def update_plot(bounds_weight, progress_weight, proximity_weight, bounds_spread, proximity_threshold):
    Z1 = out_of_bounds_cost(X, Y, bounds_weight, spread=bounds_spread)
    Z2 = relative_progress_cost(X, Y, progress_weight)
    Z3 = proximity_cost(X, Y, threshold=proximity_threshold, weight=proximity_weight)
    Z = Z1 + Z2 + Z3

    # Interpolate z from grid (nearest for now)
    ix = np.abs(x - opponent_x).argmin()
    iy = np.abs(y - opponent_y).argmin()
    pz = Z[iy, ix]

    ix = np.abs(x - ego_x).argmin()
    iy = np.abs(y - ego_y).argmin()
    ez = Z[ix, iy]

    def generate_circle(cx, cy, radius, z_level=0, resolution=200):
        theta = np.linspace(0, 2 * np.pi, resolution)
        x_circle = cx + radius * np.cos(theta)
        y_circle = cy + radius * np.sin(theta)
        z_circle = np.full_like(x_circle, z_level)
        return x_circle, y_circle, z_circle

    inner_x, inner_y, inner_z = generate_circle(center_x, center_y, inner_radius)
    outer_x, outer_y, outer_z = generate_circle(center_x, center_y, outer_radius)

    # fig = go.Figure()
    fig = go.FigureWidget()
    fig.data = []


    fig.add_trace(go.Surface(z=Z, x=X, y=Y, colorscale='Viridis', opacity=0.95, showlegend=False))
    fig.add_trace(go.Scatter3d(x=inner_x, y=inner_y, z=inner_z,
                               mode='lines', line=dict(color='black', width=4), showlegend=False))
    fig.add_trace(go.Scatter3d(x=outer_x, y=outer_y, z=outer_z,
                               mode='lines', line=dict(color='black', width=4), showlegend=False))
    fig.add_trace(go.Scatter3d(
        x=[opponent_x], y=[opponent_y], z=[pz],
        mode='markers+text',
        marker=dict(size=6, color='red'),
        text=["Opponent"],
        textposition='top center', showlegend=False
    ))
    fig.add_trace(go.Scatter3d(
        x=[ego_x], y=[ego_y], z=[ez],
        mode='markers+text',
        marker=dict(size=6, color='green'),
        text=["Ego"],
        textposition='top center', showlegend=False
    ))

    fig.update_layout(
        title='Combined Cost Surface with Track Boundaries',
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Cost',
            aspectmode='manual',
            aspectratio=dict(x=1, y=y_dim / x_dim, z=0.3)
        ),
        width=800,
        height=600
    )

    # fig.show()

# Existing sliders
bounds_slider = widgets.FloatSlider(value=1, min=0, max=1, step=.1, description='Bounds:')
progress_slider = widgets.FloatSlider(value=1, min=0, max=10, step=1, description='Progress:')
proximity_slider = widgets.FloatSlider(value=1, min=0, max=1, step=.1, description='Proximity:')

# New sliders for spread and threshold
bounds_spread_slider = widgets.FloatSlider(value=100, min=1, max=200, step=1, description='Bounds Spread:')
proximity_threshold_slider = widgets.FloatSlider(value=50, min=1, max=100, step=1, description='Proximity Threshold:')

ui = widgets.VBox([
    bounds_slider,
    progress_slider,
    proximity_slider,
    bounds_spread_slider,
    proximity_threshold_slider
])

out = widgets.interactive_output(update_plot, {
    'bounds_weight': bounds_slider,
    'progress_weight': progress_slider,
    'proximity_weight': proximity_slider,
    'bounds_spread': bounds_spread_slider,
    'proximity_threshold': proximity_threshold_slider
})

# display(ui, out)
display(ui, fig)




NameError: name 'fig' is not defined

In [2]:
import numpy as np
import plotly.graph_objects as go
import ipywidgets as widgets
from IPython.display import display
import math

# Constants
resolution = 200
x_dim = 850
y_dim = x_dim
x = np.linspace(0, x_dim, resolution)
y = np.linspace(0, y_dim, resolution)
X, Y = np.meshgrid(x, y)

center_x, center_y = x_dim / 2, y_dim / 2
inner_radius, outer_radius = 250, 400
ref_x, ref_y = center_x, center_y + (inner_radius + outer_radius) / 2
opponent_x, opponent_y = ref_x, ref_y
ego_x, ego_y = center_x, center_y - (inner_radius + outer_radius) / 2

# Create FigureWidget once
fig = go.FigureWidget()

def out_of_bounds_cost(x, y, weight=1, spread=100):
    dx = x - center_x
    dy = y - center_y
    distance_to_center = np.hypot(dx, dy)
    radius = (outer_radius + inner_radius) / 2
    distance_to_perimeter = abs(distance_to_center - radius)
    cost = weight * (1 - np.exp(-(2 / spread * distance_to_perimeter) ** 2))
    return cost

def relative_progress_cost(x1, y1, weight=1):
    def angle(x, y):
        return (np.arctan2(y - center_y, x - center_x) + 2 * np.pi) % (2 * np.pi)
    angle1 = angle(x1, y1)
    angle2 = angle(ref_x, ref_y)
    delta = angle1 - angle2
    delta = np.where(delta < -np.pi, delta + 2 * np.pi, delta)
    delta = np.where(delta > np.pi, delta - 2 * np.pi, delta)
    return weight * delta

def proximity_cost(x1, y1, threshold=50, weight=1.0):
    dx = x1 - opponent_x
    dy = y1 - opponent_y
    dist = np.sqrt(dx ** 2 + dy ** 2)
    cost = np.zeros_like(dist)
    within = dist < threshold
    cost[within] = weight * np.exp(-(2 * 1 / threshold * dist[within]))
    return cost

def update_plot(bounds_weight, progress_weight, proximity_weight, bounds_spread, proximity_threshold):
    fig.data = []  # Clear previous traces

    Z1 = out_of_bounds_cost(X, Y, bounds_weight, spread=bounds_spread)
    Z2 = relative_progress_cost(X, Y, progress_weight)
    Z3 = proximity_cost(X, Y, threshold=proximity_threshold, weight=proximity_weight)
    Z = Z1 + Z2 + Z3

    # Interpolate z from grid (nearest for now)
    ix = np.abs(x - opponent_x).argmin()
    iy = np.abs(y - opponent_y).argmin()
    pz = Z[iy, ix]

    ix = np.abs(x - ego_x).argmin()
    iy = np.abs(y - ego_y).argmin()
    ez = Z[iy, ix]

    def generate_circle(cx, cy, radius, z_level=0, resolution=200):
        theta = np.linspace(0, 2 * np.pi, resolution)
        x_circle = cx + radius * np.cos(theta)
        y_circle = cy + radius * np.sin(theta)
        z_circle = np.full_like(x_circle, z_level)
        return x_circle, y_circle, z_circle

    inner_x, inner_y, inner_z = generate_circle(center_x, center_y, inner_radius)
    outer_x, outer_y, outer_z = generate_circle(center_x, center_y, outer_radius)

    fig.add_trace(go.Surface(z=Z, x=X, y=Y, colorscale='Viridis', opacity=0.95, showlegend=False))
    fig.add_trace(go.Scatter3d(x=inner_x, y=inner_y, z=inner_z,
                               mode='lines', line=dict(color='black', width=4), showlegend=False))
    fig.add_trace(go.Scatter3d(x=outer_x, y=outer_y, z=outer_z,
                               mode='lines', line=dict(color='black', width=4), showlegend=False))
    fig.add_trace(go.Scatter3d(
        x=[opponent_x], y=[opponent_y], z=[pz],
        mode='markers+text',
        marker=dict(size=6, color='red'),
        text=["Opponent"],
        textposition='top center', showlegend=False
    ))
    fig.add_trace(go.Scatter3d(
        x=[ego_x], y=[ego_y], z=[ez],
        mode='markers+text',
        marker=dict(size=6, color='green'),
        text=["Ego"],
        textposition='top center', showlegend=False
    ))

    fig.update_layout(
        title='Combined Cost Surface with Track Boundaries',
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Cost',
            aspectmode='manual',
            aspectratio=dict(x=1, y=y_dim / x_dim, z=0.3)
        ),
        width=800,
        height=600
    )

# Sliders
bounds_slider = widgets.FloatSlider(value=1, min=0, max=1, step=.1, description='Bounds:')
progress_slider = widgets.FloatSlider(value=1, min=0, max=1, step=.1, description='Progress:')
proximity_slider = widgets.FloatSlider(value=1, min=0, max=1, step=.1, description='Proximity:')
bounds_spread_slider = widgets.FloatSlider(value=100, min=1, max=500, step=1, description='Bounds Spread:')
proximity_threshold_slider = widgets.FloatSlider(value=50, min=1, max=100, step=1, description='Proximity Threshold:')

ui = widgets.VBox([
    bounds_slider,
    progress_slider,
    proximity_slider,
    bounds_spread_slider,
    proximity_threshold_slider
])

# Trigger updates
widgets.interactive_output(update_plot, {
    'bounds_weight': bounds_slider,
    'progress_weight': progress_slider,
    'proximity_weight': proximity_slider,
    'bounds_spread': bounds_spread_slider,
    'proximity_threshold': proximity_threshold_slider
})

# Display UI and plot
display(ui, fig)


VBox(children=(FloatSlider(value=1.0, description='Bounds:', max=1.0), FloatSlider(value=1.0, description='Pro…

FigureWidget({
    'data': [{'colorscale': [[0.0, '#440154'], [0.1111111111111111, '#482878'],
                             [0.2222222222222222, '#3e4989'], [0.3333333333333333,
                             '#31688e'], [0.4444444444444444, '#26828e'],
                             [0.5555555555555556, '#1f9e89'], [0.6666666666666666,
                             '#35b779'], [0.7777777777777778, '#6ece58'],
                             [0.8888888888888888, '#b5de2b'], [1.0, '#fde725']],
              'opacity': 0.95,
              'showlegend': False,
              'type': 'surface',
              'uid': 'a5642693-9373-490f-8de0-45ace839c9c7',
              'x': {'bdata': ('AAAAAAAAAAD97YGN3hURQP3tgY3eFS' ... 'mFqEuKQCP85ELUbYpAAAAAAACQikA='),
                    'dtype': 'f8',
                    'shape': '200, 200'},
              'y': {'bdata': ('AAAAAAAAAAAAAAAAAAAAAAAAAAAAAA' ... 'AAAJCKQAAAAAAAkIpAAAAAAACQikA='),
                    'dtype': 'f8',
                    'shape': '200,