In [None]:
from stochman.curves import CubicSpline
import torch 
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# --- configuration ---
# N sets how many nodes the cubic-spline approximation uses; the data cell
# below also reuses it as the number of sample points along the curve.
N: int = 100

In [None]:
# generate points that roughly follow a sin curve between 0 and 2*pi
NUM_CURVES = 2    # batch size: number of independent noisy copies of the curve
NOISE_STD = 0.05  # std-dev of the additive Gaussian noise on y
SEED = 0          # fixed seed so Restart & Run All reproduces the same data/fit

# A dedicated seeded generator makes the noise reproducible without touching
# torch's global RNG state.
noise_rng = torch.Generator().manual_seed(SEED)

t = torch.linspace(0, 1, N)  # curve parameter in [0, 1], N samples
x = (t * torch.pi * 2).view(1, N, 1).repeat(NUM_CURVES, 1, 1)  # (NUM_CURVES, N, 1)
y = torch.sin(x) + torch.normal(0.0, NOISE_STD, (NUM_CURVES, N, 1), generator=noise_rng)  # add noise

In [None]:
# Stack x and y into (batch, N, 2) point sets and take each batch's first/last
# sample as the fixed end points of its spline.
pts = torch.cat((x, y), dim=2)
begin = pts[:, 0, :]
end = pts[:, -1, :]

# plot samples, one scatter series per batch entry
fig, axis = plt.subplots(1, 1)
for b in range(pts.shape[0]):
    axis.scatter(pts[b, :, 0], pts[b, :, 1], label=f"samples {b}")

# fit cubic spline with end points pinned to `begin` / `end`
cubic_spline = CubicSpline(begin=begin, end=end, num_nodes=N)
cubic_spline.fit(t, pts)

# The boundary constraints must still hold on the *fitted* curve, not merely at
# initialization, so check them after `fit`.
with torch.no_grad():
    start_fit = cubic_spline(torch.tensor([0.0])).squeeze(1)
    end_fit = cubic_spline(torch.tensor([1.0])).squeeze(1)
assert torch.allclose(begin, start_fit, atol=1e-5), "fitted spline does not start at `begin`"
assert torch.allclose(end, end_fit, atol=1e-5), "fitted spline does not end at `end`"

# NOTE(review): confirm the installed stochman's `plot` accepts `ax=`; if it
# forwards kwargs to pyplot.plot, drop `ax=` (pyplot draws on the current axes).
cubic_spline.plot(ax=axis)
axis.set(title="Cubic-spline fit to noisy sin samples", xlabel="x", ylabel="y")
axis.legend()
plt.show()