In [1]:
import os
import sys

import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.func import jacrev, jacfwd, hessian, vmap


from nnbma.networks import NeuralNetwork

%load_ext autoreload
%autoreload 2

In [2]:
# Load the dense Meudon PDR emulator, cast its parameters to float64 and switch it to eval mode.
net = NeuralNetwork.load("meudon_pdr_model_dense")
net.double().eval()

# Output lines kept for this study (16 in total):
#   - H2 v=0 transitions J -> J-2 for upper levels J = 2..7
#   - CO v=0 transitions J -> J-1 for upper levels J = 4..13
list_lines = [f"h2_v0_j{j}__v0_j{j - 2}" for j in range(2, 8)] + [
    f"co_v0_j{j}__v0_j{j - 1}" for j in range(4, 14)
]
net.restrict_to_output_subset(list_lines)
print(net)

PolynomialNetwork:
	input_features: 4
	order: 3
	subnetwork: DenselyConnected:
	input_features: 34
	output_features: 5375
	n_layers: 10
	growing_factor: 0.5
	activation: GELU(approximate='none')
	batch_norm: False
	inputs_names: None
	outputs_names: ['h2_v0_j2__v0_j0', 'h2_v0_j3__v0_j1', 'h2_v0_j4__v0_j2', 'h2_v0_j5__v0_j3', 'h2_v0_j6__v0_j4', 'h2_v0_j7__v0_j5', '...']
	inputs_transformer: None
	outputs_transformer: None
	device: cpu
	last_restrictable: True

	inputs_names: ['P', 'radm', 'Avmax', 'angle']
	outputs_names: ['h2_v0_j2__v0_j0', 'h2_v0_j3__v0_j1', 'h2_v0_j4__v0_j2', 'h2_v0_j5__v0_j3', 'h2_v0_j6__v0_j4', 'h2_v0_j7__v0_j5', '...']
	inputs_transformer: SequentialOperator: ["ColumnwiseOperator: ['log10', 'log10', 'log10', 'id']", 'Normalizer: NormTypes.MEAN0STD1']
	outputs_transformer: Operator: id
	device: cpu



In [3]:
# Example input point in physical units, columns ordered as the network's
# inputs_names: (P, radm, Avmax, angle). Kept under its own name so that the
# `x` used by the shape checks and benchmarks below does not silently overwrite it.
x_example = np.array([[1e5, 1e0, 1.0, 0.0]])

In [4]:
# Per-sample derivative transforms, vectorised over the leading batch axis with vmap:
#   jacobian_f: (batch, n_inputs) -> (batch, n_outputs, n_inputs)
#   hessian_f:  (batch, n_inputs) -> (batch, n_outputs, n_inputs, n_inputs)
jacobian_f, hessian_f = (vmap(transform(net)) for transform in (jacrev, hessian))

In [5]:
# Evaluate the network on a small batch of identical points to check output shapes.
n_points = 10
n_inputs = 4  # P, radm, Avmax, angle

x = np.ones((n_points, n_inputs))

print("x.shape:", x.shape)

y = net(torch.from_numpy(x)).detach().numpy()
print("y.shape:", y.shape)
# One value per restricted output line for every input point.
assert y.shape == (n_points, len(list_lines))

x.shape: (10, 4)
y.shape: (10, 16)


In [6]:
# Batched Jacobian: d(output line) / d(input) for every point of `x`.
dy = jacobian_f(torch.from_numpy(x)).detach().numpy()
print("dy.shape:", dy.shape)
assert dy.shape == (n_points, len(list_lines), n_inputs)

dy.shape: (10, 16, 4)


In [7]:
# Batched Hessian: second derivatives of every output line w.r.t. each pair of inputs.
ddy = hessian_f(torch.from_numpy(x)).detach().numpy()
print("ddy.shape:", ddy.shape)
assert ddy.shape == (n_points, len(list_lines), n_inputs, n_inputs)

ddy.shape: (10, 16, 4, 4)


# Measuring durations

In [21]:
# Batched (vmap) derivative builders to benchmark.
# The inline timings were recorded on a batch of 10 points with the network
# restricted to 10 output lines. NOTE(review): the network above is now restricted
# to 16 lines, and the cells below time a 10,000-point batch, so re-measure
# before comparing against those results.

# Jacobian matrix: reverse mode (jacrev) vs forward mode (jacfwd)
jacr = vmap(jacrev(net))  # 11 ms ± 460 µs
jacf = vmap(jacfwd(net))  # 15.9 ms ± 876 µs

# Hessian matrix: torch.func.hessian plus every forward/reverse composition.
# torch.func.hessian is documented as jacfwd(jacrev(f)), so `hess` and `jacfr`
# perform the same computation.
hess = vmap(hessian(net))  # 42.5 ms ± 3.55 ms
jacrr = vmap(jacrev(jacrev(net)))  # 112 ms ± 2.64 ms
jacrf = vmap(jacrev(jacfwd(net)))  # 69.3 ms ± 2.33 ms
jacfr = vmap(jacfwd(jacrev(net)))  # 40.3 ms ± 362 µs
jacff = vmap(jacfwd(jacfwd(net)))  # 49.2 ms ± 1.16 ms

In [33]:
# Benchmark inputs: n_batchs points with 4 features each, drawn i.i.d. from N(0, 1)
# directly in float64 (the network was converted with `net.double()`).
# Seeded so the benchmark inputs are reproducible across kernel restarts.
# NOTE(review): the network's inputs transformer applies log10 to the first three
# features. If that transformer is applied on these calls, negative samples yield
# NaN. Confirm which input space `net(...)` / `net.evaluate(...)` expect.
torch.manual_seed(0)
n_batchs = 10_000
x = torch.randn(n_batchs, 4, dtype=torch.float64)
x_numpy = x.numpy()  # shares memory with `x`, so no copy is made

In [34]:
%%timeit
net.evaluate(x.numpy())

204 ms ± 71.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [24]:
%%timeit
# Overhead of wrapping the NumPy input as a tensor (torch.from_numpy shares memory rather than copying).
torch.from_numpy(x_numpy)

1.41 µs ± 66.7 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [25]:
%%timeit
# Overhead of viewing the CPU tensor `x` as a NumPy array.
x.numpy()

2.35 µs ± 27.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


## Jacobian computation

In [26]:
%%timeit
# Reverse-mode Jacobian, vmapped over the 10,000-point batch `x`.
jacr(x)

308 ms ± 6.06 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [35]:
%%timeit
# Forward-mode Jacobian on the same batch (compare with `jacr`).
jacf(x)

1.1 s ± 28.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


## Hessian computation

In [28]:
%%timeit
# torch.func.hessian, documented as jacfwd(jacrev(f)); expected to match `jacfr` below.
hess(x)

2.21 s ± 86 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [29]:
%%timeit
jacrr(x)

KeyboardInterrupt: 

In [30]:
%%timeit
jacrf(x)

7.39 s ± 599 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [31]:
%%timeit
jacfr(x)

2.08 s ± 132 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [36]:
%%timeit
# Forward-over-forward Hessian.
jacff(x)

6.93 s ± 251 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
