In [None]:
import copy
import cProfile
import io
import logging
import math
import pstats
import random

import numpy as np
import socialforce
import torch
from IPython.display import SVG

logging.basicConfig(level=logging.INFO)

# Seed every RNG the notebook may touch. The MLP's initial weights are presumably
# random, and scenario generation/training may draw random numbers too, so
# fixing the seeds keeps Restart & Run All reproducible and profiles comparable.
# torch.manual_seed is called first because it returns a Generator object that
# Jupyter would otherwise echo as the cell's output.
SEED = 42
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

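# Performance

Profile `socialforce.Trainer` while it fits a `PedPedPotentialMLP` to a single Circle scenario, then inspect the hotspots as a `pstats` table and as a flame graph.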

## Fitting a single scenario

In [None]:
# Build the Circle scenario whose trajectories serve as ground truth for the
# training cell. `generate(1)` presumably yields a single scene — confirm
# against the socialforce API.
circle_generator = socialforce.scenarios.Circle()
scenario = circle_generator.generate(1)

In [None]:
LEARNING_RATE = 3.0  # SGD step size
BATCH_SIZE = 1
N_LOOPS = 10  # argument to Trainer.loop (presumably the number of epochs — confirm)

true_experience = socialforce.Trainer.scenes_to_experience(scenario)
V = socialforce.PedPedPotentialMLP().double()

# Snapshot the untrained weights. `state_dict()` returns tensors that share
# storage with the live parameters, so without a deep copy the in-place SGD
# updates below would silently overwrite this "initial" snapshot.
initial_parameters = copy.deepcopy(V.state_dict())


def simulator_factory(initial_state):
    """Create a ``socialforce.Simulator`` for ``initial_state`` using the potential ``V``.

    Handed to ``socialforce.Trainer``. Every simulator it builds references the
    same module-level ``V`` whose parameters ``opt`` optimizes.
    """
    return socialforce.Simulator(initial_state, ped_ped=V)


opt = torch.optim.SGD(V.parameters(), lr=LEARNING_RATE)

# Profile the whole training run; `pr` holds the collected call statistics.
with cProfile.Profile() as pr:
    trainer = socialforce.Trainer(
        simulator_factory, opt, true_experience, batch_size=BATCH_SIZE,
    )
    trainer.loop(N_LOOPS)

In [None]:
# Rank functions by 'tottime' (time spent in the function body itself, excluding
# callees). Save the stats to disk for the flame-graph step, and print only the
# top entries, because the full table for a torch training loop can be very long.
N_TOP_FUNCTIONS = 30

ps = pstats.Stats(pr).strip_dirs().sort_stats('tottime')
ps.dump_stats('simulator.prof')
# print_stats() returns the Stats object itself; the trailing ';' stops Jupyter
# from echoing its repr after the table.
ps.print_stats(N_TOP_FUNCTIONS);

In [None]:
# Convert the dumped profile into a flame-graph SVG using the external
# `flameprof` command-line tool. It must be installed where this shell command
# runs (e.g. `pip install flameprof`).
!flameprof simulator.prof > simulator_flame.svg

In [None]:
# Render the flame graph inline; the widest bars are where the profiled run
# spent most of its time.
SVG(filename='simulator_flame.svg')