In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib notebook

# Standard deviation of the Gaussian observation and transition costs used by
# get_score (every squared distance there is divided by 2*sigma**2).
sigma = 1

def get_csm(X, Y):
    """
    Cross-similarity matrix of squared Euclidean distances between two point sets.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, so no
    (M, K, d) array of pairwise differences is ever built.

    Parameters
    ----------
    X : array_like, shape (M, d)
    Y : array_like, shape (K, d)

    Returns
    -------
    ndarray, shape (M, K)
        Entry [i, j] is the squared distance between X[i] and Y[j]. Small
        negative values caused by floating-point cancellation are set to 0.
    """
    A = np.array(X)
    B = np.array(Y)
    sq_norms_a = (A * A).sum(axis=1)
    sq_norms_b = (B * B).sum(axis=1)
    dist_sq = sq_norms_a.reshape(-1, 1) + sq_norms_b.reshape(1, -1) - 2 * (A @ B.T)
    # Clamp round-off negatives to zero; NaN and -0.0 pass through unchanged.
    return np.where(dist_sq < 0, 0, dist_sq)

def get_score(X, Y):
    """
    Score a trajectory against the observations.

    Both terms are Gaussian log-likelihoods (up to additive constants) using
    the module-level ``sigma`` as the standard deviation:

    - observation cost: -sum_i ||X[i] - Y[i]||^2 / (2 sigma^2), pairing each
      trajectory point with the observation at the same index;
    - transition cost: -sum_i ||X[i+1] - X[i]||^2 / (2 sigma^2).

    Parameters
    ----------
    X : ndarray, shape (M, d)
        Trajectory points.
    Y : ndarray, shape (K, d)
        Observations; only the first M rows are used (K must be >= M).

    Returns
    -------
    (obs_cost, tx_cost)
        Accumulated observation cost and transition cost (both <= 0).
    """
    scale = 2*sigma**2

    def _accumulate(diff):
        # Squared norm of each row, negated and scaled, then summed over rows.
        per_row = -np.square(diff).sum(axis=1) / scale
        return per_row.sum()

    residual = X - Y[:X.shape[0]]
    step = X[1:] - X[:-1]
    return _accumulate(residual), _accumulate(step)

# Create observation function: a noisy 2-D curve sampled over time and scaled
# to the N x N grid built below.
np.random.seed(0)  # fixed seed so the observations (and the game) are reproducible
N = 20  # grid side length; grid coordinates are the integers 0..N-1
t = np.linspace(0, 1, 200)  # 200 evenly spaced sample times in [0, 1]
# Lissajous-style curve x = cos(4*pi*t), y = sin(6*pi*t), shifted/scaled so each
# coordinate lies in [0, N]; shape (200, 2), one (x, y) row per sample.
obs = 0.5*N*(1+np.array([np.cos(4*np.pi*t), np.sin(6*np.pi*t)])).T
# Heavy-tailed noise: U**20 with U ~ Uniform[0, 1) is close to 0 for most samples
# but occasionally large, giving sparse outlier jumps of up to N.
obs += N*np.random.rand(obs.shape[0], 2)**20
# Rescale each coordinate so its maximum over all 200 samples is exactly N.
# NOTE(review): values can therefore reach N, one unit beyond the last grid
# coordinate N-1 -- confirm whether scaling to N-1 was intended.
obs /= np.max(obs, axis=0, keepdims=True)
obs *= N
# Keep only the first 100 samples (t in [0, ~0.5]); the normalization above
# still used the maximum over all 200 samples.
obs = obs[0:100, :]

# Create grid: every integer point (x, y) with 0 <= x, y < N, one row per point.
pix = np.arange(N)
x, y = (axis_vals.ravel() for axis_vals in np.meshgrid(pix, pix))
X = np.column_stack((x, y))

# Compute greedy trajectory: match every observation to its nearest grid point
# (minimum squared distance), ignoring transition costs entirely.
D = get_csm(obs, X)
greedy = D.argmin(axis=1)
Xg = X[greedy]

# Initialize variable for my trajectory (update() appends one [x, y] per chosen point)
XMy = []



## Have user choose their trajectory
fig = plt.figure(figsize=(8, 8))
# Large green marker highlighting one observation; it starts at obs[0] and
# update() moves it via set_offsets.
obs_loc = plt.scatter(*obs[0], c='C2', s=200)


def update(x, y):
    """
    Add a point to the user's trajectory and redraw the interactive figure.

    The point is rounded to integer coordinates and appended to the global
    ``XMy``. The new point and the segment from the previous point are drawn.
    The marker ``obs_loc`` moves to the observation that the next point will
    be scored against. The running costs from ``get_score`` go in the axes title.

    Parameters
    ----------
    x, y : float
        Data coordinates of the chosen point (e.g. a mouse click).

    Notes
    -----
    ``get_score`` pairs ``XMy[i]`` with ``obs[i]``. Once every observation has
    a matching point, further calls are ignored; otherwise ``get_score``
    would fail on mismatched array lengths.
    """
    n_obs = obs.shape[0]
    if len(XMy) >= n_obs:
        return
    x = np.round(x)
    y = np.round(y)
    XMy.append([x, y])
    idx = len(XMy)  # points chosen so far == index of the next observation to match
    XPlot = np.array(XMy)

    plt.scatter(x, y, c='C1')
    # Segment joining the previous point to the new one.
    plt.plot(XPlot[-2:, 0], XPlot[-2:, 1], c='C0')
    if idx < n_obs:
        # The next click becomes XMy[idx], which is scored against obs[idx].
        obs_loc.set_offsets([obs[idx, 0], obs[idx, 1]])
    # Observation path up to and including the highlighted observation.
    plt.plot(obs[0:idx+1, 0], obs[0:idx+1, 1], c='C2')

    obs_cost, tx_cost = get_score(XPlot, obs)

    plt.gca().set_title("Observation Cost: {:.3f}, Transition Cost: {:.3f}".format(obs_cost, tx_cost))
    fig.canvas.draw()
    fig.canvas.flush_events()

def onclick(event):
    """
    Matplotlib ``button_press_event`` handler: add the clicked point to the trajectory.

    Clicks outside the axes arrive with ``xdata``/``ydata`` set to ``None``.
    They are ignored, instead of letting ``np.round(None)`` raise inside the
    callback.
    """
    if event.xdata is None or event.ydata is None:
        return
    update(event.xdata, event.ydata)

# Mark the start of the greedy trajectory and use it as the user's first point,
# so obs[0] is already matched when interaction begins.
plt.scatter(X[greedy[0], 0], X[greedy[0], 1], c='C0', s=50, zorder=20)
update(X[greedy[0], 0], X[greedy[0], 1])

# Draw grid. X holds (x, y) rows, so column 0 goes on the horizontal axis
# (swapping the columns would only look the same because the grid is square).
plt.scatter(X[:, 0], X[:, 1], c='C1')
plt.axis("equal")
plt.gca().set_xticks([])
plt.gca().set_yticks([])

# Register the click handler before show() so it is attached even on backends
# where show() blocks. Keep the connection id, so the handler can be removed
# with fig.canvas.mpl_disconnect(cid) and the id is not echoed as cell output.
cid = fig.canvas.mpl_connect('button_press_event', onclick)
plt.show()

In [None]:
# Plot the full observation path in its own figure. With the interactive
# `%matplotlib notebook` backend the game figure from the previous cell stays
# current, so a bare plt.plot would draw this path onto it.
plt.figure()
plt.plot(obs[:, 0], obs[:, 1])
plt.title("Observation trajectory")
plt.xlabel("x")
plt.ylabel("y");