# nanotorch: Step-Through Training (Interactive)

This notebook lets you scrub through training steps and see:
- The model line vs data
- Parameter and gradient values
- Loss curve over time

It also caches each scenario's training trace, so moving the slider only re-renders stored states instead of re-running training.


In [None]:
from nanotorch import manual_gradient, train_iter
from nanotorch.scenarios import get_scenario, list_scenarios
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output

# Memo of finished training runs, keyed by scenario name. Each value is the
# (scenario, states) tuple returned by build_trace, so a scenario is only
# trained the first time it is requested through get_trace.
trace_cache = {}

def build_trace(name: str):
    """Run the whole training loop for one scenario and keep every state.

    Args:
        name: Scenario identifier handed to ``get_scenario``.

    Returns:
        A ``(scenario, states)`` tuple, where ``states`` is the list of all
        items produced by ``train_iter`` for that scenario.
    """
    sc = get_scenario(name)
    update_rule = manual_gradient(sc.grad)
    training_run = train_iter(
        sc.data,
        sc.params,
        sc.predict,
        sc.loss,
        update_rule,
        steps=sc.steps,
        lr=sc.lr,
    )
    # Materialize the run so any step can later be looked up by index.
    return sc, list(training_run)

def get_trace(name: str):
    """Return the ``(scenario, states)`` trace for ``name``, building it on a miss.

    The result is stored in the module-level ``trace_cache`` so later calls
    with the same name reuse it.
    """
    cached = trace_cache.get(name)
    if cached is None:
        # build_trace always returns a tuple, so None can only mean "not cached".
        cached = trace_cache[name] = build_trace(name)
    return cached

def dict_to_table(title: str, d: dict) -> str:
    """Render a name -> number mapping as a small HTML table.

    Args:
        title: Heading text placed in an ``<h4>`` above the table.
        d: Mapping to display; every value must accept the ``.6f`` format
            spec (values that do not will raise from the f-string).

    Returns:
        An HTML fragment: the heading followed by a two-column
        ``name`` / ``value`` table with one row per entry of ``d``, values
        formatted to six decimal places.
    """
    # Local import keeps this cell self-contained.
    from html import escape

    # Names and the title are text, not markup: escape &, < and > so a key
    # such as "a<b" cannot break the table structure. quote=False leaves
    # quotes untouched because the text is element content, not an attribute.
    rows = ''.join(
        f"<tr><td>{escape(str(k), quote=False)}</td><td>{v:.6f}</td></tr>"
        for k, v in d.items()
    )
    return (
        f"<h4>{escape(str(title), quote=False)}</h4>"
        "<table>"
        "<thead><tr><th>name</th><th>value</th></tr></thead>"
        f"<tbody>{rows}</tbody>"
        "</table>"
    )


In [None]:
# Scenario picker: one entry per name reported by list_scenarios(),
# starting on 'single_point'.
scenario_selector = widgets.Dropdown(
    options=list_scenarios(),
    value='single_point',
    description='Scenario:',
)

# Step scrubber. max=1 is only a placeholder; on_change resets it to the
# last index of the selected scenario's trace.
step_slider = widgets.IntSlider(
    min=0,
    max=1,
    step=1,
    value=0,
    description='Step:',
)

# HTML read-outs for the current step's parameters and gradients.
params_table, grads_table = widgets.HTML(), widgets.HTML()

# Output area that the figure is rendered into.
out = widgets.Output()

def render(scenario, states, step_idx):
    """Draw the fit and loss panels for one step and refresh both tables.

    Args:
        scenario: Object exposing ``data`` (iterable of ``(x, y)`` pairs) and
            ``predict(x, params)``.
        states: Sequence of training states; each one is read for ``step``,
            ``loss``, ``params`` and ``grads``.
        step_idx: Index into ``states`` of the step to show.

    Side effects:
        Shows a two-panel matplotlib figure and overwrites the HTML of the
        module-level ``params_table`` and ``grads_table`` widgets.
    """
    current = states[step_idx]

    data_x = [px for px, _ in scenario.data]
    data_y = [py for _, py in scenario.data]

    # Model curve: 51 evenly spaced samples over the data's x-range, padded
    # by one unit on each side.
    lo = min(data_x) - 1
    hi = max(data_x) + 1
    curve_x = [lo + k * (hi - lo) / 50 for k in range(51)]
    curve_y = [scenario.predict(cx, current.params) for cx in curve_x]

    _, (fit_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 4))

    # Left: raw data points with the model's current prediction line.
    fit_ax.scatter(data_x, data_y, color='black', label='data')
    fit_ax.plot(curve_x, curve_y, color='blue', label='model')
    fit_ax.set(
        title=f'Step {current.step} | loss={current.loss:.4f}',
        xlabel='x',
        ylabel='y',
    )
    fit_ax.legend()

    # Right: full loss history with the displayed step marked in red.
    loss_history = [st.loss for st in states]
    loss_ax.plot(range(len(loss_history)), loss_history, color='purple')
    loss_ax.scatter([step_idx], [current.loss], color='red')
    loss_ax.set(title='Loss over steps', xlabel='step', ylabel='loss')

    plt.tight_layout()
    plt.show()

    params_table.value = dict_to_table('params', current.params)
    grads_table.value = dict_to_table('grads', current.grads)

def on_change(_):
    """Switch to the selected scenario and show its first step.

    Used as the observer for ``scenario_selector.value`` and also called
    directly for the initial render; the argument is ignored.

    Resizes ``step_slider`` to the new trace (``max`` = last state index),
    rewinds it to step 0, and renders that step once into ``out``.
    """
    scenario, states = get_trace(scenario_selector.value)

    # Reconfiguring the slider would otherwise fire on_step_change -- once if
    # lowering max clamps the current value, and again when the value is
    # reset to 0 -- each time drawing a figure that the render below
    # immediately replaces. Detach the handler while the slider is updated so
    # a scenario switch renders exactly once.
    # NOTE: assumes on_step_change is already observing the slider, which
    # holds when the observers are registered before the first call.
    step_slider.unobserve(on_step_change, names='value')
    try:
        step_slider.max = len(states) - 1
        step_slider.value = 0
    finally:
        # Always re-attach, even if a slider assignment raised.
        step_slider.observe(on_step_change, names='value')

    with out:
        clear_output(wait=True)
        render(scenario, states, step_slider.value)

def on_step_change(change):
    """Redraw the currently selected scenario at the slider's new step.

    Args:
        change: Change notification for ``step_slider.value``; only its
            ``new`` attribute (the step index to display) is read.
    """
    picked_scenario, history = get_trace(scenario_selector.value)
    with out:
        # wait=True defers clearing until the replacement output is ready.
        clear_output(wait=True)
        render(picked_scenario, history, change.new)

# Re-render whenever the user picks another scenario or moves the slider.
# Only changes to each widget's 'value' trait trigger the handlers.
scenario_selector.observe(on_change, names='value')
step_slider.observe(on_step_change, names='value')

# Initial render: run the scenario handler once by hand (it ignores its
# argument) so the default scenario is drawn before any interaction, then
# show the controls, the figure output, and the two tables in that order.
on_change(None)
display(scenario_selector, step_slider, out, params_table, grads_table)
