In [1]:
import torch
from torch import nn
from functorch import vmap, grad, hessian
# Device used for every tensor / model in this notebook.
# NOTE(review): the saved outputs below show device='cuda:0', i.e. they were produced with a
# different setting than this one -- restart and re-run all cells to refresh them.
device = 'cpu'

In [2]:
# Non-batched, non-vectorized: one autograd.grad call per coordinate for the Hessian diagonal.
# Create x directly on `device` so it is a leaf; `.to(device)` after requires_grad=True
# would produce a non-leaf copy when device is a GPU.
x = torch.tensor([1., 2., 3.], device=device, requires_grad=True)
u = torch.prod(x**2+x+1, -1, True)
d1 = torch.autograd.grad(u, x, create_graph=True)[0]
# d2[k] = d^2u/dx_k^2: differentiate the k-th first partial and keep only its k-th component
d2 = torch.stack([torch.autograd.grad(d1[k], x, create_graph=True)[0][k] for k in range(len(d1))])
laplacian = torch.sum(d2, -1, True)

print(f'u = {u}\nd1 = {d1}\nd2 = {d2}\nlaplacian = {laplacian}')

u = tensor([273.], device='cuda:0', grad_fn=<ProdBackward1>)
d1 = tensor([273., 195., 147.], device='cuda:0', grad_fn=<AddBackward0>)
d2 = tensor([182.0000,  78.0000,  42.0000], device='cuda:0',
       grad_fn=<StackBackward0>)
laplacian = tensor([302.], device='cuda:0', grad_fn=<SumBackward1>)


In [3]:
# Non-batched, vectorized: vmap one vector-Jacobian product over the standard basis
# to get every Hessian row at once, then keep the diagonal.
x = torch.tensor([1., 2., 3.], device=device, requires_grad=True)
u = torch.prod(x**2+x+1, -1, True)
d1 = torch.autograd.grad(u, x, create_graph=True)[0]
basis = torch.eye(len(d1), device=device)

def get_vjp(v):
    """Return (v^T d(d1)/dx,) -- the Hessian row selected by basis vector v."""
    return torch.autograd.grad(d1, x, v, create_graph=True)

d2 = torch.diag(vmap(get_vjp)(basis)[0])
laplacian = torch.sum(d2, -1, True)

print(f'u = {u}\nd2 = {d2}\nlaplacian = {laplacian}')

u = tensor([273.], device='cuda:0', grad_fn=<ProdBackward1>)
d2 = tensor([182.0000,  78.0000,  42.0000], device='cuda:0',
       grad_fn=<DiagBackward0>)
laplacian = tensor([302.], device='cuda:0', grad_fn=<SumBackward1>)


In [4]:
# Batched, non-vectorized: samples are rows; summing over the batch via grad_outputs=ones is
# valid because each u[b] depends only on x[b].
x = torch.tensor([[1., 2., 3.], [4., 5., 6.]], device=device, requires_grad=True)
u = torch.prod(x**2+x+1, -1, True)
d1 = torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u), create_graph=True)[0]
# column k of d2 holds d^2u/dx_k^2 for every sample
d2 = torch.stack([
    torch.autograd.grad(d1[:, k], x, grad_outputs=torch.ones_like(d1[:, k]), create_graph=True)[0][:, k]
    for k in range(d1.shape[1])
], 1)
laplacian = torch.sum(d2, -1, True)

print(f'u = {u}\nd1 = {d1}\nd2 = {d2}\nlaplacian = {laplacian}')

u = tensor([[  273.],
        [27993.]], device='cuda:0', grad_fn=<ProdBackward1>)
d1 = tensor([[  273.,   195.,   147.],
        [11997.,  9933.,  8463.]], device='cuda:0', grad_fn=<AddBackward0>)
d2 = tensor([[ 182.0000,   78.0000,   42.0000],
        [2666.0000, 1806.0000, 1302.0000]], device='cuda:0',
       grad_fn=<StackBackward0>)
laplacian = tensor([[ 302.],
        [5774.]], device='cuda:0', grad_fn=<SumBackward1>)


In [5]:
# Batched, vectorized (hessian): full per-sample Hessian via functorch, then its diagonal.
x = torch.tensor([[1., 2., 3.], [4., 5., 6.]]).to(device)

def u(x):
    # keepdim=True gives each sample a (1,)-shaped output, so each per-sample Hessian carries a leading 1
    return torch.prod(x**2+x+1, -1, True)

hessians = vmap(hessian(u))(x)
d2 = hessians.diagonal(dim1=-2, dim2=-1)
laplacian = d2.sum(-1)

print(f'u = {u(x)}\nd2 = {d2}\nlaplacian = {laplacian}')

u = tensor([[  273.],
        [27993.]], device='cuda:0')
d2 = tensor([[[ 182.,   78.,   42.]],

        [[2666., 1806., 1302.]]], device='cuda:0')
laplacian = tensor([[ 302.],
        [5774.]], device='cuda:0')


In [6]:
# Batched, vectorized (grad): nested scalar grads give each unmixed second partial directly.
x = torch.tensor([[1., 2., 3.], [4., 5., 6.]]).to(device)

def u(x):
    return torch.prod(x**2+x+1, -1, True)

def second_partials(sample):
    """Stack d^2u/dx_k^2 for every coordinate k of one sample."""
    def u_of_coords(*coords):
        # grad differentiates w.r.t. positional arguments, so pass coordinates as separate scalars
        return u(torch.stack(coords)).squeeze()
    return torch.stack([
        grad(grad(u_of_coords, argnums=k), argnums=k)(*sample)
        for k in range(len(sample))
    ])

d2 = vmap(second_partials)(x)
laplacian = torch.sum(d2, -1, True)

print(f'u = {u(x)}\nd2 = {d2}\nlaplacian = {laplacian}')

u = tensor([[  273.],
        [27993.]], device='cuda:0')
d2 = tensor([[ 182.,   78.,   42.],
        [2666., 1806., 1302.]], device='cuda:0')
laplacian = tensor([[ 302.],
        [5774.]], device='cuda:0')


In [131]:
# Manual laplacian
class HelmholtzSolver(nn.Module):
    """Fully connected network whose output equals g on the faces of the box [a, b]^ndims.

    forward(x) = g(x) + prod_j (x_j - a)(b - x_j) * MLP(x); the product factor is zero
    whenever any coordinate equals a or b.

    Args:
        ndims: number of input coordinates (last dimension of x).
        N: width of every hidden layer.
        L: number of (Linear, activation) pairs before the final Linear(N, 1).
        activation: activation module inserted after every hidden Linear. The same instance
            is reused, which assumes it is stateless.
        bounds: (a, b) bounds applied to every coordinate.
        g: boundary-value function added to the network output (default: 0).
    """

    def __init__(self, ndims, N, L, activation, bounds, g=lambda x: 0):
        super().__init__()
        # Build a *fresh* Linear for every hidden layer. `[nn.Linear(N, N), activation] * (L-1)`
        # repeats the same Linear object, silently tying all hidden-layer weights together.
        hidden = []
        for _ in range(L - 1):
            hidden += [nn.Linear(N, N), activation]
        self.layers = nn.Sequential(
            nn.Linear(ndims, N), activation,
            *hidden,
            nn.Linear(N, 1),
        )
        self.bounds = bounds
        self.g = g

    def forward(self, x):
        # enforce boundary condition: the product vanishes when any x_j hits bounds[0] or bounds[1]
        return self.g(x) + torch.prod((x-self.bounds[0])*(self.bounds[1]-x), -1, True)*self.layers(x)


class Sin(nn.Module):
    """Element-wise sine activation packaged as a module so it can sit inside nn.Sequential."""

    def forward(self, x):
        return torch.sin(x)


def laplacian(x, net=None):
    """Laplacian of a sine-activated HelmholtzSolver at a single point, by forward propagation.

    Propagates the hidden activations together with their first and (diagonal) second
    derivatives w.r.t. each input coordinate, then applies the product rule to
    u(x) = g(x) + prod_j p_j(x_j) * y(x) with p_j = (x_j - a)(b - x_j), assuming g has zero Laplacian.

    Args:
        x: one input point; the matmuls below assume shape (ndims,). Use vmap for batches.
        net: object exposing `.layers` (Linear, act, ..., Linear(N, 1)) and `.bounds` (a, b).
            Defaults to the notebook-global `model`.

    Returns:
        Tensor of shape (1,) holding sum_i d^2u/dx_i^2.

    The activation is hard-coded as sin; see laplacian_multiactivation for other activations.
    """
    if net is None:
        net = model
    layers = net.layers
    # Linear layers sit at even indices; derive the depth from the model instead of a global.
    n_hidden = (len(layers) - 1) // 2

    W0, b0 = layers[0].weight, layers[0].bias
    z = [W0 @ x + b0]
    y = [torch.sin(z[0])]
    # dy1[k][n, i] = d y_n / d x_i ; dy2[k][n, i] = d^2 y_n / d x_i^2 (unmixed second derivatives)
    dy1 = [torch.cos(z[0]).unsqueeze(1) * W0]
    dy2 = [-torch.sin(z[0]).unsqueeze(1) * W0**2]
    for i in range(1, n_hidden):
        W, bias = layers[2*i].weight, layers[2*i].bias
        z.append(W @ y[i-1] + bias)
        y.append(torch.sin(z[i]))
        Wdy1 = W @ dy1[i-1]
        dy1.append(torch.cos(z[i]).unsqueeze(1) * Wdy1)
        dy2.append(-torch.sin(z[i]).unsqueeze(1) * Wdy1**2 + torch.cos(z[i]).unsqueeze(1) * (W @ dy2[i-1]))
    W_out, b_out = layers[-1].weight, layers[-1].bias
    y.append(W_out @ y[-1] + b_out)
    dy1.append((W_out @ dy1[-1]).squeeze(0))
    dy2.append((W_out @ dy2[-1]).squeeze(0))

    a = net.bounds[0]
    b = net.bounds[1]
    p = (x - a) * (b - x)   # boundary factor per coordinate; p'' = -2
    dp = -2*x + a + b       # p'
    n = x.shape[-1]
    # Row i holds p with its i-th entry replaced by 1, so its product is prod_{j != i} p_j.
    # Built with torch.where instead of in-place fill_diagonal_ so it also works under vmap.
    off_diag = torch.where(torch.eye(n, dtype=torch.bool, device=x.device), torch.ones_like(p), p.expand(n, n))
    prod_others = torch.prod(off_diag, 1)
    dg2 = 0  # Laplacian of g, assumed zero
    du2 = dg2 + prod_others * (-2*y[-1] + 2*dp*dy1[-1] + p*dy2[-1])
    return torch.sum(du2, -1, True)


def laplacian_alt(x, net=None):
    """Manual Laplacian of a sine-activated HelmholtzSolver at one point (explicit-loop variant).

    Same forward propagation of first and unmixed second derivatives as laplacian(), but the
    leave-one-out product prod_{j != i} p_j is built by overwriting the diagonal of a
    repeated matrix element by element.

    Args:
        x: one input point; the matmuls below assume shape (ndims,). Use vmap for batches.
        net: object exposing `.layers` (Linear, act, ..., Linear(N, 1)) and `.bounds` (a, b).
            Defaults to the notebook-global `model`.

    Returns:
        Tensor of shape (1,) holding sum_i d^2u/dx_i^2, assuming g has zero Laplacian.
    """
    if net is None:
        net = model
    layers = net.layers
    # Linear layers sit at even indices; derive the depth from the model instead of a global.
    n_hidden = (len(layers) - 1) // 2

    z = [layers[0].weight @ x + layers[0].bias]
    y = [torch.sin(z[0])]
    dy1 = [torch.cos(z[0]).unsqueeze(1) * layers[0].weight]
    dy2 = [-torch.sin(z[0]).unsqueeze(1) * layers[0].weight**2]
    for i in range(1, n_hidden):
        W = layers[2*i].weight
        z.append(W @ y[i-1] + layers[2*i].bias)
        y.append(torch.sin(z[i]))
        # (diag(cos z) W) @ dy1 -- `*` and `@` bind left to right
        dy1.append(torch.cos(z[i]).unsqueeze(1) * W @ dy1[i-1])
        dy2.append(-torch.sin(z[i]).unsqueeze(1) * (W @ dy1[i-1])**2 + torch.cos(z[i]).unsqueeze(1) * W @ dy2[i-1])
    y.append(layers[-1].weight @ y[-1] + layers[-1].bias)
    dy1.append((layers[-1].weight @ dy1[-1]).squeeze(0))
    dy2.append((layers[-1].weight @ dy2[-1]).squeeze(0))

    a = net.bounds[0]
    b = net.bounds[1]
    # row i = p with its i-th entry replaced by 1, so its product is prod_{j != i} p_j
    prod_matrix = ((x-a)*(b-x)).repeat(len(x), 1)
    for i in range(len(x)):
        prod_matrix[i, i] = 1
    dg2 = 0  # Laplacian of g, assumed zero
    du2 = dg2 + torch.prod(prod_matrix, 1)*(-2*y[-1] + 2*(-2*x+a+b)*dy1[-1] + (-x**2+(a+b)*x-a*b)*dy2[-1])
    return torch.sum(du2, -1, True)


def laplacian_multiactivation(x, net=None):
    """Manual Laplacian of a HelmholtzSolver at one point, for tanh, sigmoid or Sin activations.

    The activation is read from the network itself (net.layers[1]) so the derivative
    formulas always match the forward pass. Propagates first and unmixed second derivatives
    through every layer, then applies the product rule to the boundary factor
    prod_j (x_j - b1)(b2 - x_j), assuming g has zero Laplacian.

    Args:
        x: one input point; the matmuls below assume shape (ndims,). Use vmap for batches.
        net: object exposing `.layers` (Linear, act, ..., Linear(N, 1)) and `.bounds`.
            Defaults to the notebook-global `model`.

    Returns:
        Tensor of shape (1,) holding sum_i d^2u/dx_i^2.

    Raises:
        ValueError: if the activation is not nn.Tanh, nn.Sigmoid or Sin.
    """
    if net is None:
        net = model
    act = net.layers[1]
    if isinstance(act, nn.Tanh):
        a = torch.tanh

        def da1(t):
            return 1 - torch.tanh(t)**2

        def da2(t):
            th = torch.tanh(t)
            return 2*th*(th**2 - 1)
    elif isinstance(act, nn.Sigmoid):
        a = torch.sigmoid

        # Expressed via sigmoid itself: the exp(x)/exp(-x) forms overflow to inf/inf = nan for large |x|.
        def da1(t):
            s = torch.sigmoid(t)
            return s*(1 - s)

        def da2(t):
            s = torch.sigmoid(t)
            return s*(1 - s)*(1 - 2*s)
    elif isinstance(act, Sin):
        a = torch.sin
        da1 = torch.cos

        def da2(t):
            return -torch.sin(t)
    else:
        raise ValueError(f'laplacian_multiactivation: unsupported activation {type(act).__name__}')

    layers = net.layers
    # Linear layers sit at even indices; derive the depth from the model instead of a global.
    n_hidden = (len(layers) - 1) // 2

    z = [layers[0].weight @ x + layers[0].bias]
    y = [a(z[0])]
    # dy1[k][n, i] = d y_n / d x_i ; dy2[k][n, i] = d^2 y_n / d x_i^2 (unmixed second derivatives)
    dy1 = [da1(z[0]).unsqueeze(1) * layers[0].weight]
    dy2 = [da2(z[0]).unsqueeze(1) * layers[0].weight**2]
    for i in range(1, n_hidden):
        W = layers[2*i].weight
        z.append(W @ y[i-1] + layers[2*i].bias)
        y.append(a(z[i]))
        Wdy1 = W @ dy1[i-1]
        dy1.append(da1(z[i]).unsqueeze(1) * Wdy1)
        dy2.append(da2(z[i]).unsqueeze(1) * Wdy1**2 + da1(z[i]).unsqueeze(1) * (W @ dy2[i-1]))
    y.append(layers[-1].weight @ y[-1] + layers[-1].bias)
    dy1.append((layers[-1].weight @ dy1[-1]).squeeze(0))
    dy2.append((layers[-1].weight @ dy2[-1]).squeeze(0))

    b1 = net.bounds[0]
    b2 = net.bounds[1]
    # row i = boundary factors with the i-th entry replaced by 1 -> product over j != i
    prod_matrix = ((x-b1)*(b2-x)).repeat(len(x), 1)
    for i in range(len(x)):
        prod_matrix[i, i] = 1
    dg2 = 0  # Laplacian of g, assumed zero
    du2 = dg2 + torch.prod(prod_matrix, 1)*(-2*y[-1] + 2*(-2*x+b1+b2)*dy1[-1] + (-x**2+(b1+b2)*x-b1*b2)*dy2[-1])
    return torch.sum(du2, -1, True)


ndims = 3
N = 10
L = 5
activation = Sin()  # the Sin module defined above; torch.nn has no Sin class
bounds = [0, 1]
torch.manual_seed(0)  # make the random network initialisation (and printed values) reproducible
model = HelmholtzSolver(ndims, N, L, activation, bounds).to(device)
# A single point, e.g. torch.tensor([1., 2., 3.]), can also be passed to the manual laplacian without vmap.
x = torch.tensor([[1., 2., 3.], [4., 5., 6.]], device=device, requires_grad=True)

laplacian_manual = vmap(laplacian_multiactivation)(x)
print(f'Manual laplacian = {laplacian_manual}')

u = model(x)
d1 = torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u), create_graph=True)[0]
d2 = torch.stack([
    torch.autograd.grad(d1[:, k], x, grad_outputs=torch.ones_like(d1[:, k]), create_graph=True)[0][:, k]
    for k in range(d1.shape[1])
], 1)
# Named laplacian_auto so it does not shadow the laplacian() function defined above.
laplacian_auto = torch.sum(d2, -1, True)
print(f'Automatic laplacian = {laplacian_auto}')

assert torch.allclose(laplacian_manual, laplacian_auto, rtol=1e-4, atol=1e-3), 'manual and automatic laplacians disagree'

Manual laplacian = tensor([[  1.0093],
        [100.8718]], grad_fn=<SumBackward1>)
Automatic laplacian = tensor([[  1.0093],
        [100.8718]], grad_fn=<SumBackward1>)
