In [1]:
import pybamm
import casadi
import numpy as np
from scipy.sparse import eye, linalg, csr_matrix

In [2]:
model = pybamm.lithium_ion.DFN()
param = model.default_parameter_values
# Override the porosity of all three regions.
param["Negative electrode porosity"] = 0.3
param["Separator porosity"] = 0.3
param["Positive electrode porosity"] = 0.3
# Make the cation transference number an input named "t" so we can
# differentiate with respect to it (not to be confused with time `t` below).
param["Cation transference number"] = pybamm.InputParameter("t")

# Minimal alternative model for debugging (uncomment to use instead of the DFN):
# param = pybamm.ParameterValues({})
# model = pybamm.BaseModel()
# v = pybamm.Variable("v")
# t = pybamm.InputParameter("t")
# model.rhs = {v: t}
# model.initial_conditions = {v: 1}
# model.variables = {"Terminal voltage [V]": v}

# Baseline solve without solver-provided sensitivities; they are computed by
# hand in the cells below.
solver = pybamm.CasadiSolver(mode="fast")
sim = pybamm.Simulation(model, parameter_values=param, solver=solver)
t_eval = np.linspace(0, 3600, 50)
sol = sim.solve(t_eval, inputs={"t": 0.5})

var = sol["Terminal voltage [V]"]

# Symbolic time, state vector (differential + algebraic) and input parameter.
t = casadi.MX.sym("t")
y = casadi.MX.sym("y", sim.built_model.len_rhs_and_alg)
p = casadi.MX.sym("p")

# Full residual: differential rows first, then algebraic rows.
rhs = casadi.vertcat(
    sim.built_model.casadi_rhs(t, y, p),
    sim.built_model.casadi_algebraic(t, y, p),
)

# Jacobians of the residual w.r.t. the state (J_x) and the parameter (J_p).
jac_x_func = casadi.Function("jac_x", [t, y, p], [casadi.jacobian(rhs, y)])
jac_p_func = casadi.Function("jac_p", [t, y, p], [casadi.jacobian(rhs, p)])

In [3]:
# Reference solution using pybamm's built-in sensitivity option (disabled).
# The comparison cells further down use `sol_with_sens`, so they only run
# after this cell has been uncommented and executed.
# sim_with_sens = pybamm.Simulation(model, parameter_values=param, 
# #                         solver=pybamm.CasadiSolver(mode="fast", sensitivity=True)
#                         solver=pybamm.CasadiSolver(mode="fast", sensitivity=True)
# )
# sol_with_sens = sim_with_sens.solve(
#     np.linspace(0,3600,50), 
#     inputs={"t": 0.5}, 
# )

In [4]:
# Solve time reported by pybamm for the baseline solve (no sensitivities)
sol.solve_time

0.10660402900000054

In [5]:
# Solve time with solver-provided sensitivities (needs the disabled
# `sol_with_sens` cell above to be re-enabled).
# sol_with_sens.solve_time

In [6]:
# Parameter value at which all sensitivities are evaluated; it must match the
# inputs={"t": 0.5} used to produce `sol`.
inp = 0.5

# Initial state as a symbolic function of p, and its derivative d y0 / dp,
# evaluated at `inp` to give the initial sensitivity S_0.
x0 = sim.built_model.init_eval(p)
initial_sens_func = casadi.Function("S_0", [p], [casadi.jacobian(x0, p)])
S_0 = initial_sens_func(inp)

In [7]:
%%time
n = sim.built_model.len_rhs_and_alg
for idx in range(len(sol.t)):
    ti = sol.t[idx]
    ui = sol.y[:, idx]
    next_jac_x_eval = jac_x_func(ti, ui, inp)
    next_jac_p_eval = jac_p_func(ti, ui, inp)
    if idx == 0:
        jac_x_eval = next_jac_x_eval
        jac_p_eval = next_jac_p_eval
    else:
        jac_x_eval = casadi.diagcat(jac_x_eval, next_jac_x_eval)
        jac_p_eval = casadi.vertcat(jac_p_eval, next_jac_p_eval)

CPU times: user 170 ms, sys: 21.4 ms, total: 191 ms
Wall time: 190 ms


In [8]:
# Block-diagonal stack of per-output-point state Jacobians: (n * N, n * N)
jac_x_eval.shape

(68050, 68050)

In [9]:
# Vertical stack of per-output-point parameter Jacobians: (n * N, 1)
jac_p_eval.shape

(68050, 1)

In [10]:
%%time
i=0
jac_x_eval[n*(i+1):n*(i+2),n*(i+1):n*(i+2)]

CPU times: user 301 µs, sys: 6 µs, total: 307 µs
Wall time: 312 µs


DM(sparse: 1361-by-1361, 4868 nnz
 (1, 1) -> -31728.4
 (2, 1) -> 3525.38
 (1, 2) -> 31728.4
 ...
 (1300, 1360) -> -1.21814
 (1359, 1360) -> -1670.37
 (1360, 1360) -> 1671.59)

In [12]:
%%time
# Solve for sensitivities symbolically
# Forward Euler
# Sx_all = Sx_0.full()
# S_x = Sx_0
# n = sim.built_model.len_rhs
# for i in range(len(sol.t)-1):
#     dt = sol.t[i+1] - sol.t[i]
#     S_x = dt * S_x + jac_x_eval[n*i:n*(i+1),n*i:n*(i+1)] @ S_x + jac_p_eval[n*i:n*(i+1)]
#     Sx_all = np.hstack([Sx_all, S_x.full()])
# Backward Euler
# Sx_all = Sx_0.full()
# S_x = Sx_0
# n = sim.built_model.len_rhs
# for i in range(len(sol.t)-1):
#     dt = sol.t[i+1] - sol.t[i]
#     A = np.eye(n) - dt * jac_x_eval[n*(i+1):n*(i+2),n*(i+1):n*(i+2)]
#     b = dt * jac_p_eval[n*(i+1):n*(i+2)] + S_x
#     S_x = np.linalg.solve(A,b)
#     Sx_all = np.hstack([Sx_all, S_x])
# Crank-Nicolson
Sx_all = S_0.full()
S_x = S_0

timer = pybamm.Timer()
I = casadi.DM.eye(n)
I2 = np.eye(n)
# jxf = jac_x_eval.full()
for i in range(len(sol.t)-1):
#     print(1, timer.time())
#     timer.reset()
    dt = sol.t[i+1] - sol.t[i]
#     print(2, timer.time())
#     timer.reset()
    A = (
#         I2 - dt / 2 * jac_x_eval[n*(i+1):n*(i+2),n*(i+1):n*(i+2)].full()
        I - dt / 2 * jac_x_eval[n*(i+1):n*(i+2),n*(i+1):n*(i+2)]
    ).full()
#     print(3, timer.time())
#     timer.reset()
    b = (
        dt / 2 * (jac_p_eval[n*i:n*(i+1)] + jac_p_eval[n*(i+1):n*(i+2)])
        + (I + dt / 2 * jac_x_eval[n*i:n*(i+1),n*i:n*(i+1)]) @ S_x
    ).full()
#     print(4, timer.time())
#     timer.reset()
    S_x = np.linalg.solve(A,b)
#     print(5, timer.time())
#     timer.reset()
    Sx_all = np.hstack([Sx_all, S_x])

CPU times: user 42.1 s, sys: 6.11 s, total: 48.2 s
Wall time: 6.39 s


## Solve the sensitivity equations with a CasADi integrator (work in progress)

In [18]:
# Work in progress: build the forward-sensitivity right-hand side
# symbolically, J_x(t, y, p) S + J_p(t, y, p), so it could be integrated with a
# CasADi integrator instead of the hand-written Crank-Nicolson loop.
# MX is used because t, y, p are MX symbols (SX and MX cannot be mixed), and
# the symbol gets its own Python name so the computed `S_x` is not overwritten.
# TODO: for the algebraic rows this must be imposed as a constraint (0 = ...),
# not as dS/dt = ..., and the integrator call itself is still to be written.
S_sym = casadi.MX.sym("S_x", n)
ode = casadi.mtimes(jac_x_func(t, y, p), S_sym) + jac_p_func(t, y, p)

In [14]:
# State Jacobian at the last output point (left over from the evaluation loop)
next_jac_x_eval

DM(sparse: 1361-by-1361, 4868 nnz
 (1, 1) -> -31728.4
 (2, 1) -> 3525.38
 (1, 2) -> 31728.4
 ...
 (1300, 1360) -> -1.30835
 (1359, 1360) -> -1643.65
 (1360, 1360) -> 1644.96)

In [None]:
%%time
np.linalg.solve(A,b)
b.shape

In [None]:
# State-Jacobian block for output point i + 1, with i left over from the
# sensitivity loop.
block = slice(n * (i + 1), n * (i + 2))
jac_x_eval[block, block]

In [None]:
# Ratio of hand-computed to solver-provided sensitivities at output point k.
# Assumes sensitivity["t"] stacks the n states of each output point one after
# another, so point k occupies rows n*k:n*(k+1); a hard-coded [121:242] is
# only correct when n == 121.
# NOTE(review): the [61:] offset is carried over unchanged; its meaning is not
# recorded here -- confirm it still applies to this model.
k = 1
Sx_all[:, k][61:] / (sol_with_sens.sensitivity["t"][n*k:n*(k+1)])[61:].T

In [None]:
# Difference between hand-computed and solver-provided sensitivities at output
# point k. Assumes sensitivity["t"] stacks the n states of each output point
# one after another, so point k occupies rows n*k:n*(k+1); a hard-coded
# [121:242] is only correct when n == 121.
# NOTE(review): the [61:] offset is carried over unchanged; its meaning is not
# recorded here -- confirm it still applies to this model.
k = 1
Sx_all[:, k][61:] - (sol_with_sens.sensitivity["t"][n*k:n*(k+1)])[61:].T

In [None]:
%%time
# Convert variable to casadi format for differentiating
var_casadi = var.base_variable.to_casadi(t, y, inputs={"t": p})
dvar_dy = casadi.jacobian(var_casadi, y)
dvar_dp = casadi.jacobian(var_casadi, p)

# Convert to functions and evaluate index-by-index
dvar_dy_func = casadi.Function(
    "dvar_dy", [t, y, p], [dvar_dy]
)
dvar_dp_func = casadi.Function(
    "dvar_dp", [t, y, p], [dvar_dp]
)
for idx in range(len(var.t_sol)):
    ti = var.t_sol[idx]
    ui = var.u_sol[:, idx]
    next_dvar_dy_eval = dvar_dy_func(ti, ui, inp)
    next_dvar_dp_eval = dvar_dp_func(ti, ui, inp)
    if idx == 0:
        dvar_dy_eval = next_dvar_dy_eval
        dvar_dp_eval = next_dvar_dp_eval
    else:
        dvar_dy_eval = casadi.vertcat(dvar_dy_eval, next_dvar_dy_eval)
        dvar_dp_eval = casadi.vertcat(dvar_dp_eval, next_dvar_dp_eval)

In [None]:
# Chain rule: d var / dp (t_k) = d var / dy (t_k) @ S(t_k) + d var / dp (t_k).
# Computed as a full (N x N) product for brevity; only the diagonal (row k of
# dvar_dy paired with column k of Sx_all) is meaningful.
# Converted to numpy first: casadi does not broadcast the (N, 1) term across
# the N columns, numpy does (and adds dvar_dp[k] to the diagonal entry k).
S_var = dvar_dy_eval.full() @ Sx_all + dvar_dp_eval.full()

In [None]:
# One row per time point (for a scalar variable), one column per state
dvar_dy_eval.shape

In [None]:
# Shape of the solver-provided reference sensitivities (needs `sol_with_sens`)
sol_with_sens["Terminal voltage [V]"].sensitivity["all"].shape

In [None]:
# Compare against the reference; np.diag keeps only the entries of S_var that
# pair time point k of d var / dy with time point k of S.
np.diag(S_var - sol_with_sens["Terminal voltage [V]"].sensitivity["all"])

## Finite difference for comparison

In [17]:
%%time
h = 1e-8
sol_fd = (
    sim.solve(t_eval, inputs={"t": 0.5+h})["Terminal voltage [V]"].data
    - sim.solve(t_eval, inputs={"t": 0.5})["Terminal voltage [V]"].data
) / h

CPU times: user 227 ms, sys: 6.25 ms, total: 233 ms
Wall time: 231 ms


In [None]:
# Finite-difference estimate minus the solver-provided sensitivities
sol_fd- sol_with_sens["Terminal voltage [V]"].sensitivity["all"]

In [16]:
%%time
sim.solve(t_eval, inputs={"t": 0.5})

CPU times: user 122 ms, sys: 6.74 ms, total: 129 ms
Wall time: 127 ms


<pybamm.solvers.solution.Solution at 0x142ad69d0>

In [17]:
from scipy.interpolate import interp1d

In [19]:
%%time
n = sim.built_model.len_rhs_and_alg
for idx in range(len(sol.t)):
    ti = sol.t[idx]
    ui = sol.y[:, idx]
    next_jac_x_eval = jac_x_func(ti, ui, inp)
    next_jac_p_eval = jac_p_func(ti, ui, inp)
    if idx == 0:
        jac_x_eval = next_jac_x_eval
        jac_p_eval = next_jac_p_eval
    else:
        jac_x_eval = casadi.diagcat(jac_x_eval, next_jac_x_eval)
        jac_p_eval = casadi.vertcat(jac_p_eval, next_jac_p_eval)

CPU times: user 165 ms, sys: 99.4 ms, total: 264 ms
Wall time: 263 ms


In [22]:
# Sanity check of casadi.solve on a small dense system. Seeded so the result is
# reproducible, and given distinct names so it does not overwrite the A, b
# left by the sensitivity loop.
rng = np.random.default_rng(0)
A_demo = casadi.DM(rng.random((5, 5)))
b_demo = casadi.DM([1, 2, 3, 4, 5])
casadi.solve(A_demo, b_demo)

DM([3.00095, -3.09973, 0.730214, 0.427567, 2.73288])