In [None]:
%matplotlib inline
%config InlineBackend.figure_format ='retina'

import torch
import socialforce

# Fit Diamond

This is an extension of the {ref}`fit-2d` example to study the robustness
of the inference process to potentials with gradients that change orientation.
We use a modified $V(b)$ potential that could be described as a "diamond"
of height $V_0$ and with a half-width of $\sigma$:
\begin{align}
    V(b) &= V_0 \max\left(0, 1 - \frac{L1}{2\sigma} \right)
\end{align}
with its two parameters $V_0$ and $\sigma$.

In [None]:
# Ground-truth potential: a diamond with half-width sigma = 0.5 (V_0 is left
# at the class default). `.double()` is applied here as it is for the MLP below.
V = socialforce.potentials.PedPedPotentialDiamond(sigma=0.5).double()
with socialforce.show.canvas(figsize=(12, 6), ncols=2) as axes:
    value_ax, grad_ax = axes
    # First panel: the potential itself; second panel: its gradient.
    socialforce.show.potential2D(V, value_ax)
    socialforce.show.potential2D_grad(V, grad_ax)

## Scenarios

We generate {ref}`Circle and ParallelOvertake scenarios <scenarios>`.

In [None]:
# Both scenario generators are driven by the potential currently bound to `V`
# (the diamond, when the notebook is run top to bottom).
circle = socialforce.scenarios.Circle(ped_ped=V)
parallel = socialforce.scenarios.ParallelOvertake(ped_ped=V)
circle_scenes = circle.generate(20, seed=42)
parallel_scenes = parallel.generate(20, seed=42)
scenarios = circle_scenes + parallel_scenes
# Convert the generated scenes into the training targets for the MLP fit below.
true_experience = socialforce.Trainer.scenes_to_experience(
    scenarios,
    radius=3.0,
)

In [None]:
# HIDE CODE
# Overlay all generated scenes: the last scenario at full opacity, every
# other scenario faded (alpha=0.1) as context.
with socialforce.show.track_canvas() as ax:
    highlighted_scene = scenarios[-1]
    background_scenes = scenarios[:-1]
    socialforce.show.states(ax, highlighted_scene)
    for faded_scene in background_scenes:
        socialforce.show.states(ax, faded_scene, alpha=0.1)

## MLP

We construct a coordinate-based MLP with Fourier
Features {cite}`tancik2020fourier,rahimi2007random` and train it on the
experience generated from the diamond potential above.

In [None]:
# The MLP's parameters (and presumably its random Fourier-feature projection)
# are sampled when it is constructed. Seed torch so that a fresh
# "Restart & Run All" starts training from the same initialisation, matching
# the seed=42 used for the scenarios above.
torch.manual_seed(42)

# NOTE: this rebinds `V`. From here on, `V` is the trainable MLP and the
# diamond potential defined earlier is no longer reachable by name.
V = socialforce.potentials.PedPedPotentialMLP2D(
    hidden_units=64,
    n_fourier_features=256,
    fourier_scale=3.0,
).double()
with socialforce.show.canvas(figsize=(12, 6), ncols=2) as (ax1, ax2):
    # Untrained MLP potential (first panel) and its gradient (second panel).
    socialforce.show.potential2D(V, ax1)
    socialforce.show.potential2D_grad(V, ax2)

In [None]:
def simulator_factory(initial_state, ped_ped=V):
    """Build a simulator for one scene, driven by the trainable potential.

    Args:
        initial_state: initial state of the scene, forwarded unchanged to
            ``socialforce.Simulator``.
        ped_ped: pedestrian-pedestrian potential. Defaults to the object bound
            to ``V`` when this cell runs, i.e. the same object whose parameters
            ``opt`` updates below.

    Binding the potential at definition time, rather than looking up the
    global ``V`` on every call, keeps the simulator and the optimizer pointed
    at the same module. Otherwise, re-running an earlier cell that rebinds
    ``V`` would make the trainer simulate one potential while optimizing
    another.
    """
    return socialforce.Simulator(initial_state, ped_ped=ped_ped)


opt = torch.optim.SGD(V.parameters(), lr=0.3)
# Fit the MLP to the experience generated from the diamond potential.
socialforce.Trainer(simulator_factory, opt, true_experience, loss=torch.nn.L1Loss()).loop(10)

In [None]:
# After training: the learned MLP potential and its gradient, for comparison
# with the diamond plotted at the top of the notebook.
with socialforce.show.canvas(figsize=(12, 6), ncols=2) as (learned_ax, learned_grad_ax):
    socialforce.show.potential2D(V, learned_ax)
    socialforce.show.potential2D_grad(V, learned_grad_ax)