<a href="https://colab.research.google.com/github/profteachkids/CHE2064/blob/master/Chebyshev_Roots_Polynomial_Lagrange_Interpolation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [75]:
# Imports and global configuration shared by every cell below.
import jax
import jax.numpy as jnp
# NOTE(review): recent JAX releases no longer support importing the
# `jax.config` submodule this way; `jax.config.update(...)` is the current
# spelling — confirm against the installed JAX version.
from jax.config import config
# NOTE(review): `id_print` is not used in the visible cells, and
# `jax.experimental.host_callback` has been deprecated/removed in recent JAX
# (`jax.debug.print` is the usual replacement) — confirm before relying on it.
from jax.experimental.host_callback import id_print
from functools import partial  # used to pass static_argnums to jax.jit on a method
config.update("jax_enable_x64", True) # JAX defaults to 32-bit floats; enable 64-bit (double precision)

from plotly.subplots import make_subplots
import plotly.io as pio
# Every plotly figure in the notebook uses the dark template.
pio.templates.default='plotly_dark'

In [76]:
# Runge's function 1/(1+25x^2) interpolated by a degree-(N-1) polynomial
# through N equally spaced nodes on [-1, 1]; the monomial coefficients come
# from an LU-factored Vandermonde system.
N = 20
x = jnp.linspace(-1, 1, N)
y = 1 / (1 + 25 * x**2)

vander = jnp.vander(x, N)  # decreasing powers (jnp.vander default)
lu, piv = jax.scipy.linalg.lu_factor(vander)
c = jax.scipy.linalg.lu_solve((lu, piv), y)

# Dense grid on which the fitted polynomial is drawn.
x01 = jnp.linspace(-1, 1, 200)
fig = make_subplots()
(
    fig.add_scatter(x=x, y=y, mode='markers')
       .add_scatter(x=x01, y=jnp.vander(x01, N) @ c, mode='lines')
       .update_layout(width=800, height=500)
)

In [77]:
# Same fit of Runge's function, but at the N Chebyshev roots. The Vandermonde
# system is built in the reference variable on [-1, 1]; only the abscissae
# shown in the plot are mapped onto [a, b].
N = 20
a = -1.
b = 1.

# cos(pi*(k + 1/2)/N) for k = N-1 down to 0, i.e. roots in ascending order.
chebx = jnp.cos(jnp.pi * (jnp.arange(N - 1, -1, -1) + .5) / N)
x = chebx * (b - a) / 2 + (b + a) / 2
y = 1 / (1 + 25 * x**2)

vander = jnp.vander(chebx, N)
lu, piv = jax.scipy.linalg.lu_factor(vander)
c = jax.scipy.linalg.lu_solve((lu, piv), y)

fig = make_subplots()
fig.add_scatter(x=x, y=y, mode='markers')

# Dense reference grid; x is rebound to its image on [a, b] for plotting.
x01 = jnp.linspace(-1, 1, 200)
x = x01 * (b - a) / 2 + (b + a) / 2
fig.add_scatter(x=x, y=jnp.vander(x01, N) @ c, mode='lines').update_layout(width=800, height=500)

In [96]:
class Lagrange():
  """Lagrange-form interpolating polynomial through the points (x[i], y[i]).

  The interpolant is

      p(t) = sum_i y[i] * prod_{j != i} (t - x[j]) / prod_{j != i} (x[i] - x[j])

  The denominators depend only on the nodes, so they are computed once in
  __init__; evaluation only forms the numerator products.

  Attributes:
    N: number of interpolation nodes.
    xs: (N, N-1) array; row i holds every node except x[i], in original order.
    den: (N,) array of denominators prod_{j != i} (x[i] - x[j]).
    y: data values at the nodes (as a JAX array).
    Lvec: jnp.vectorize wrapper around L for element-wise evaluation.

  Nodes must be distinct: a repeated node makes a denominator zero and the
  interpolant evaluates to inf/nan.
  """

  def __init__(self, x, y):
    """Precompute the excluded-node table and the denominators.

    Args:
      x: 1-D array-like of distinct interpolation nodes.
      y: array-like of data values at those nodes.

    Raises:
      ValueError: if x is empty.
    """
    x = jnp.asarray(x)
    self.N = x.size
    if self.N == 0:
      raise ValueError("Lagrange interpolation needs at least one node")
    # others[i, j] is the j-th index of 0..N-1 with i skipped:
    # j when j < i, j + 1 otherwise.
    cols = jnp.arange(self.N - 1)
    others = cols[None, :] + (cols[None, :] >= jnp.arange(self.N)[:, None])
    self.xs = x[others]
    self.den = jnp.prod(x[:, None] - self.xs, axis=1)
    self.y = jnp.asarray(y)
    self.Lvec = jnp.vectorize(self.L)

  # self is static (hashed by identity); mutating its arrays after the first
  # call will not trigger a retrace.
  @partial(jax.jit, static_argnums=(0,))
  def L(self, x):
    """Evaluate the interpolating polynomial at x.

    Args:
      x: scalar or array of evaluation points; arrays are evaluated
        element-wise and the result has the same shape as x.

    Returns:
      The interpolant value(s) p(x).
    """
    x = jnp.asarray(x)
    # diffs[..., i, j] = x - xs[i, j]; shape (..., N, N-1).
    diffs = x[..., None, None] - self.xs
    numer = jnp.prod(diffs, axis=-1)
    return jnp.sum(self.y * numer / self.den, axis=-1)
  



In [101]:
# Small hand-made data set to illustrate the Lagrange interpolant.
xdata = jnp.array([1, 3, 8, 10])
ydata = jnp.array([2, 0.5, 3, 1.5])
L = Lagrange(xdata, ydata)

# Dense grid from the first to the last node for drawing the interpolant.
xplot = jnp.linspace(xdata[0], xdata[-1], 100)
fig = make_subplots()
(
    fig.add_scatter(x=xdata, y=ydata, mode='markers')
       .add_scatter(x=xplot, y=L.Lvec(xplot), mode='lines')
       .update_layout(width=600, height=400)
)

In [102]:
# Runge's function again, now through the Lagrange class at the N Chebyshev
# roots mapped onto [a, b].
N = 20
a, b = -1., 1.
chebx = jnp.cos(jnp.pi * (jnp.arange(N - 1, -1, -1) + .5) / N)
x = chebx * (b - a) / 2 + (b + a) / 2
y = 1 / (1 + 25 * x**2)
L = Lagrange(x, y)

# The drawn curve spans the outermost nodes x[0]..x[-1], not all of [a, b].
xplot = jnp.linspace(x[0], x[-1], 100)
fig = make_subplots()
(
    fig.add_scatter(x=x, y=y, mode='markers')
       .add_scatter(x=xplot, y=L.Lvec(xplot), mode='lines')
       .update_layout(width=800, height=500)
)

In [103]:
# Interpolate sin(x/10) on [a, b] = [0, 100] from N = 10 Chebyshev roots.
N = 10
a = 0.
b = 100.

chebx = jnp.cos(jnp.pi * (jnp.arange(N - 1, -1, -1) + .5) / N)
x = chebx * (b - a) / 2 + (b + a) / 2
y = jnp.sin(x / 10)
L = Lagrange(x, y)
# Sanity check: an interpolant must reproduce the data at every node.
assert jnp.allclose(L.Lvec(x), y)

# The drawn curve spans the outermost nodes x[0]..x[-1].
xplot = jnp.linspace(x[0], x[-1], 100)
fig = make_subplots()
fig.add_scatter(x=x, y=y, mode='markers')
fig.add_scatter(x=xplot, y=L.Lvec(xplot), mode='lines')
fig.update_layout(width=800, height=500)