In [1]:
from latent_geometry.mapping import TorchModelMapping
import numpy as np

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Small CNN: two (conv -> ReLU -> 2x2 max-pool) stages, then three linear layers.

    The layer sizes assume a single-channel 16x16 input. Each 5x5 convolution
    shrinks the spatial size by 4 and each 2x2 pool halves it
    (16 -> 12 -> 6 -> 2 -> 1). The flattened feature vector therefore has
    16 * 1 * 1 = 16 entries, matching ``fc1``'s ``in_features``.

    ``forward`` returns raw, un-normalized scores of shape (N, 10); no softmax
    is applied.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor: 1 -> 6 -> 16 channels, 5x5 kernels.
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Fully connected head: 16 -> 120 -> 84 -> 10.
        self.fc1 = nn.Linear(16, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Feature extraction: each stage is conv -> ReLU -> 2x2 max-pool.
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2)
        # Keep the batch dimension, flatten channels and spatial dims together.
        x = x.flatten(1)
        for fc in (self.fc1, self.fc2):
            x = F.relu(fc(x))
        # Final layer has no activation: the outputs are raw scores.
        return self.fc3(x)


# Seed torch's global RNG so the random weights (and the random input drawn
# with torch.randn in the next cell) are reproducible on Restart & Run All.
torch.manual_seed(0)
net = Net()
print(net)

Net(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=16, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)


In [3]:
# One random single-channel 16x16 input (batch of one) pushed through `net`;
# `out_tensor` is kept as the reference output for the mapping checks below.
in_tensor = torch.randn((1, 1, 16, 16))
out_tensor = net(in_tensor)
print(out_tensor)

tensor([[ 0.0323, -0.0169,  0.0084, -0.0244, -0.0995,  0.0110, -0.1254,  0.2002,
         -0.0330, -0.0190]], grad_fn=<AddmmBackward0>)


In [4]:
# Derive the shapes from the example tensors instead of repeating them by hand,
# so they cannot drift out of sync with `in_tensor` or with `net`'s output.
torch_mapping = TorchModelMapping(
    net, in_shape=tuple(in_tensor.shape), out_shape=tuple(out_tensor.shape)
)

In [5]:
# Flatten the example input into a 1-D numpy vector (256 entries for a
# 1x1x16x16 tensor); this flat vector is what `torch_mapping` is called with.
in_numpy = np.ravel(in_tensor.numpy())

In [6]:
# The numpy-in / numpy-out mapping should reproduce the direct forward pass
# `out_tensor`; check that explicitly instead of comparing printed numbers by eye.
mapped_out = torch_mapping(in_numpy)
np.testing.assert_allclose(
    mapped_out, out_tensor.detach().numpy().reshape(-1), rtol=1e-5, atol=1e-6
)
mapped_out

array([ 0.03230598, -0.01692359,  0.00841565, -0.02437422, -0.09945782,
        0.01103655, -0.12544382,  0.20019071, -0.03297206, -0.01902856],
      dtype=float32)

In [7]:
J = torch_mapping.jacobian(in_numpy)
print(f"{J.shape=}")
print(f"{in_tensor.shape=}, {out_tensor.shape=}")
print(f"{np.prod(in_tensor.shape)=}, {np.prod(out_tensor.shape)=}")
# The Jacobian has one row per flattened output coordinate and one column per
# flattened input coordinate; assert it so a shape regression fails loudly.
assert J.shape == (out_tensor.numel(), in_tensor.numel())
J

J.shape=(10, 256)
in_tensor.shape=torch.Size([1, 1, 16, 16]), out_tensor.shape=torch.Size([1, 10])
np.prod(in_tensor.shape)=256, np.prod(out_tensor.shape)=10


array([[-2.05029628e-05,  8.89252915e-05, -5.45427087e-04, ...,
        -2.04503769e-03,  5.41817222e-04,  0.00000000e+00],
       [-1.10362154e-04,  4.78661852e-04,  7.44492165e-04, ...,
        -8.21662194e-04,  6.60911552e-04,  0.00000000e+00],
       [-4.98314330e-05,  2.16128494e-04,  6.32408541e-04, ...,
         1.08441069e-04,  1.96131237e-04,  0.00000000e+00],
       ...,
       [ 1.62661454e-05, -7.05494022e-05,  1.15997740e-03, ...,
        -3.24699358e-04,  2.41960850e-04,  0.00000000e+00],
       [-3.81119098e-05,  1.65298683e-04,  3.93068651e-04, ...,
        -3.00431100e-04,  8.78821738e-05,  0.00000000e+00],
       [-2.59192461e-06,  1.12416756e-05, -8.87307280e-04, ...,
         1.92417225e-04,  2.49991077e-04,  0.00000000e+00]], dtype=float32)

In [8]:
H = torch_mapping.second_derivative(in_numpy)
print(f"{H.shape=}")
print(f"{in_tensor.shape=}, {out_tensor.shape=}")
print(f"{np.prod(in_tensor.shape)=}, {np.prod(out_tensor.shape)=}")
# One (inputs x inputs) second-derivative matrix per flattened output coordinate.
assert H.shape == (out_tensor.numel(), in_tensor.numel(), in_tensor.numel())
# `Net` uses only conv/linear layers, ReLU and max-pooling, so it is piecewise
# linear in its input. Its exact second derivative vanishes almost everywhere,
# so an (almost) all-zero H is expected here, not a bug.
H

H.shape=(10, 256, 256)
in_tensor.shape=torch.Size([1, 1, 16, 16]), out_tensor.shape=torch.Size([1, 10])
np.prod(in_tensor.shape)=256, np.prod(out_tensor.shape)=10


array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       ...,

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0.