In [1]:
import numpy as np

import matplotlib.pyplot as plt
import numpy as np
import torch
from ipywidgets import interact

# Global matplotlib defaults for every figure in this notebook: small 6x4
# figures, one common font size for all text elements, tight saved bounding
# boxes and constrained layout.
params = {
    "figure.figsize": (6, 4),
    "figure.dpi": 72,
    # All text-size settings share the same value.
    **dict.fromkeys(
        (
            "axes.titlesize",
            "axes.labelsize",
            "font.size",
            "xtick.labelsize",
            "ytick.labelsize",
            "legend.fontsize",
        ),
        14,
    ),
    "savefig.bbox": "tight",
    "figure.constrained_layout.use": True,
}
plt.rcParams.update(params)

In [44]:
# Discretisation: n_x grid nodes on x in [0, 1].  The two end nodes are pinned
# to the heights y0 and y1; the n_x - 2 interior heights are the unknowns.
n_x = 101
x = torch.linspace(0, 1, n_x)
dx = torch.diff(x)  # horizontal width of each of the n_x - 1 segments

# Fixed boundary heights of the ray (0-d tensors).
y0, y1 = torch.tensor(0.0), torch.tensor(1.0)

# Learnable interior heights, starting from a flat line at height 0.
y = torch.nn.Parameter(torch.zeros(n_x - 2))

# Number of optimiser steps performed by train().
n_epoch = 20000

def refraction_index(y):
    """Refraction index profile ``n(y) = 1 + y**2``, clamped to be non-negative.

    The clamp never activates for this quadratic profile, which is always
    >= 1.  It only matters if the formula is swapped for one that can go
    negative.  An alternative step profile that was tried is
    ``torch.where(y < 0.5, 1, 0.5)``.

    Parameters
    ----------
    y : torch.Tensor
        Heights at which to evaluate the index (any shape).

    Returns
    -------
    torch.Tensor
        Index values with the same shape as ``y``.
    """
    index = 1 + y ** 2
    return index.clamp(min=0)

def compute_time(y):
    """Discrete optical path length of the ray through the interior heights ``y``.

    The interior heights are padded with the fixed endpoints ``y0`` and ``y1``
    to form the full polyline over the grid ``x``.  Each straight segment
    contributes ``n * ds``, where ``ds`` is its Euclidean length and ``n`` is
    the refraction index evaluated at the segment's mid-height (midpoint
    rule).  Up to the constant factor ``1 / c`` this is the travel time that
    Fermat's principle minimises.

    Parameters
    ----------
    y : torch.Tensor
        1-D tensor of interior heights.  Its length must be ``len(x) - 2`` so
        that the padded path's segment count matches the global ``dx``.

    Returns
    -------
    torch.Tensor
        0-d tensor holding the summed ``n * ds``; differentiable w.r.t. ``y``.
    """
    path = torch.hstack([y0, y, y1])
    dy = torch.diff(path)
    ds = torch.sqrt(dx ** 2 + dy ** 2)  # length of each straight segment
    # The index depends on height only, so the segment mid-height is all that
    # is needed; no x-midpoint is required.
    y_mid = 0.5 * (path[:-1] + path[1:])
    return (refraction_index(y_mid) * ds).sum()

def train(n_epochs=None, log_every=100):
    """Optimise the global interior heights ``y`` with Rprop and record snapshots.

    Each step minimises ``compute_time(y)``.  The global parameter ``y`` is
    updated in place, so calling ``train()`` again continues from the current
    path rather than from the flat initial line (a fresh optimiser is created,
    though, so Rprop's step sizes restart).

    Parameters
    ----------
    n_epochs : int, optional
        Number of optimisation steps.  Defaults to the global ``n_epoch``.
    log_every : int, optional
        Print the loss every ``log_every`` steps.  Use 0 or None to disable
        logging.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_epochs, len(y))``.  Row ``k`` is the path after the
        ``k``-th update.  Note that the loss printed for epoch ``k`` is the one
        computed *before* that update.
    """
    if n_epochs is None:
        n_epochs = n_epoch

    y_history = []
    optimizer = torch.optim.Rprop([y])
    for epoch in range(n_epochs):
        optimizer.zero_grad()
        loss = compute_time(y)
        loss.backward()
        optimizer.step()

        if log_every and epoch % log_every == 0:
            print(f"Epoch {epoch:d}, loss {loss.item():g}")

        # Snapshot every step.  Detach first so the copy is not recorded in the
        # autograd graph, then clone so later in-place updates do not alias it.
        y_history.append(y.detach().clone().numpy())

    return np.asarray(y_history)

In [45]:
# Run the full optimisation (n_epoch Rprop steps, one snapshot per step).
# train() updates the global parameter y in place, so re-running this cell
# continues from the current path instead of restarting from the flat line.
y_history = train()

Epoch 0, loss 2.24006
Epoch 100, loss 2.61512
Epoch 200, loss 2.67604
Epoch 300, loss 2.52688
Epoch 400, loss 2.39677
Epoch 500, loss 2.29788
Epoch 600, loss 2.22019
Epoch 700, loss 2.16447
Epoch 800, loss 2.11411
Epoch 900, loss 2.06749
Epoch 1000, loss 2.02628
Epoch 1100, loss 1.99122
Epoch 1200, loss 1.96475
Epoch 1300, loss 1.94373
Epoch 1400, loss 1.92775
Epoch 1500, loss 1.9151
Epoch 1600, loss 1.90363
Epoch 1700, loss 1.89431
Epoch 1800, loss 1.8863
Epoch 1900, loss 1.87984
Epoch 2000, loss 1.87492
Epoch 2100, loss 1.87052
Epoch 2200, loss 1.86686
Epoch 2300, loss 1.86365
Epoch 2400, loss 1.86102
Epoch 2500, loss 1.85888
Epoch 2600, loss 1.85726
Epoch 2700, loss 1.85557
Epoch 2800, loss 1.85421
Epoch 2900, loss 1.85284
Epoch 3000, loss 1.85185
Epoch 3100, loss 1.85117
Epoch 3200, loss 1.85055
Epoch 3300, loss 1.85003
Epoch 3400, loss 1.84952
Epoch 3500, loss 1.84917
Epoch 3600, loss 1.84883
Epoch 3700, loss 1.84854
Epoch 3800, loss 1.84828
Epoch 3900, loss 1.84801
Epoch 4000, lo

In [46]:
def plot(epoch):
    """Draw the ray path recorded in ``y_history`` at index ``epoch``.

    The fixed endpoints ``y0`` and ``y1`` are re-attached to the recorded
    interior heights so the curve spans the whole grid ``x``.  This function
    is driven by ``interact`` with an integer slider value.  It deliberately
    takes no other defaulted arguments, because ``interact`` would try to build
    a widget for each of them.

    Parameters
    ----------
    epoch : int
        Row of ``y_history`` to display.
    """
    # Convert the 0-d torch endpoints explicitly instead of relying on numpy's
    # implicit __array__ conversion of tensors.
    path = np.concatenate(([float(y0)], y_history[epoch], [float(y1)]))
    fig, ax = plt.subplots()
    ax.plot(x.numpy(), path)
    ax.set(title=f"Epoch {epoch:d}", xlabel="x", ylabel="y")
    ax.grid()
    plt.show()

# Slider over every recorded snapshot: train() stores one row of y_history per
# epoch, so the valid indices are 0 .. n_epoch - 1.  The trailing ';' hides the
# returned function's repr.
interact(plot, epoch=(0, n_epoch - 1));

interactive(children=(IntSlider(value=9999, description='epoch', max=19999), Output()), _dom_classes=('widget-…