In [1]:
from latent_geometry.mapping.torch import TorchModelMapping
import numpy as np

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Small LeNet-style CNN producing 10 output scores per sample.

    Two "conv -> ReLU -> 2x2 max-pool" stages are followed by three fully
    connected layers. The first linear layer takes 16 features, i.e. the
    16 channels of a 1x1 map. That is what a (N, 1, 16, 16) input reduces to:
    16 -> 12 -> 6 -> 2 -> 1 spatially.
    """

    def __init__(self):
        super().__init__()
        # Sub-modules are registered in this order; the printed layout follows it.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # 16 channels x 1 x 1 spatial positions after the second pooling stage.
        self.fc1 = nn.Linear(in_features=16, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Map a batch to unnormalised scores of shape (N, 10); no softmax is applied."""
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), kernel_size=2)
        features = x.flatten(start_dim=1)
        hidden = F.relu(self.fc2(F.relu(self.fc1(features))))
        return self.fc3(hidden)


# Seed torch's global RNG so that both the weight initialisation here and the
# random input drawn in the next cell are reproducible on Restart & Run All.
torch.manual_seed(0)
net = Net()
print(net)

Net(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=16, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)


In [3]:
# One random sample: batch of 1, 1 channel (matching conv1), 16x16 pixels.
in_tensor = torch.randn((1, 1, 16, 16))
out_tensor = net(in_tensor)  # forward pass; autograd history is kept
print(out_tensor)

tensor([[ 0.0584,  0.0179,  0.1008,  0.0430, -0.0197, -0.0136,  0.0521,  0.0640,
          0.0145,  0.0269]], grad_fn=<AddmmBackward0>)


In [4]:
# Take the input shape from the sample tensor instead of repeating the literal,
# so the mapping and the sample input cannot drift apart.
torch_mapping = TorchModelMapping(net, in_shape=tuple(in_tensor.shape))

In [5]:
# Flat 1-D NumPy copy of the sample input: the form torch_mapping is called with below.
in_numpy = in_tensor.numpy().ravel()

In [6]:
mapped_out = torch_mapping(in_numpy)
# Sanity check: evaluating the mapping on the flattened input should reproduce
# the direct forward pass `out_tensor` (flattened to 1-D). `detach()` is needed
# because out_tensor carries autograd history.
np.testing.assert_allclose(
    mapped_out, out_tensor.detach().numpy().reshape(-1), rtol=1e-5, atol=1e-6
)
mapped_out

array([ 0.05837061,  0.01793842,  0.10083365,  0.04300383, -0.01971399,
       -0.01355287,  0.05213169,  0.06402797,  0.01449347,  0.02694278],
      dtype=float32)

In [7]:
J = torch_mapping.jacobian(in_numpy)
print(f"{J.shape=}")
print(f"{in_tensor.shape=}, {out_tensor.shape=}")
print(f"{np.prod(in_tensor.shape)=}, {np.prod(out_tensor.shape)=}")
# Expected layout: one row per output coordinate, one column per input coordinate.
assert J.shape == (out_tensor.numel(), in_tensor.numel()), J.shape
J

J.shape=(10, 256)
in_tensor.shape=torch.Size([1, 1, 16, 16]), out_tensor.shape=torch.Size([1, 10])
np.prod(in_tensor.shape)=256, np.prod(out_tensor.shape)=10


array([[ 2.43085655e-04,  1.55304078e-04, -2.90749536e-04, ...,
         1.40571457e-04, -1.71413594e-05, -1.53054116e-05],
       [ 3.86251440e-06, -1.28810672e-04, -4.23813144e-05, ...,
        -4.27645748e-04, -3.66685010e-04, -3.27410729e-04],
       [-9.04537137e-06, -2.55093502e-04, -6.08948612e-05, ...,
        -2.22120652e-04, -1.87315727e-05, -1.67253038e-05],
       ...,
       [ 9.68107634e-06,  4.29770676e-04,  1.10262525e-04, ...,
         1.52338189e-05, -4.84039083e-05, -4.32195448e-05],
       [ 1.41447483e-04, -1.75945970e-04, -2.45786185e-04, ...,
        -1.48724197e-04, -1.12677364e-04, -1.00608908e-04],
       [ 7.91204438e-05, -1.01443329e-04, -1.38353949e-04, ...,
        -4.21404780e-04, -1.49548403e-04, -1.33530833e-04]], dtype=float32)

In [8]:
DM = torch_mapping.metric_matrix_derivative(in_numpy)
print(f"{DM.shape=}")
print(f"{in_tensor.shape=}, {out_tensor.shape=}")
print(f"{np.prod(in_tensor.shape)=}, {np.prod(out_tensor.shape)=}")
# One n_in x n_in slice per input coordinate (n_in = 256 -> ~16.8M entries).
n_in = in_tensor.numel()
assert DM.shape == (n_in, n_in, n_in), DM.shape
# Every layer of Net is piecewise linear (conv/linear, ReLU, max-pool), so the
# Jacobian is locally constant. The derivative of a metric built from it is
# therefore expected to vanish almost everywhere, which the summary below shows.
# NOTE(review): assumes the metric is the pullback J^T J; confirm.
print(f"{np.abs(DM).max()=}")
DM

DM.shape=(256, 256, 256)
in_tensor.shape=torch.Size([1, 1, 16, 16]), out_tensor.shape=torch.Size([1, 10])
np.prod(in_tensor.shape)=256, np.prod(out_tensor.shape)=10


array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       ...,

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0.