In [2]:
import torch
import coremltools as ct

scikit-learn version 1.5.1 is not supported. Minimum required version: 0.17. Maximum required version: 1.1.2. Disabling scikit-learn conversion API.
Torch version 2.3.1 has not been tested with coremltools. You may run into unexpected errors. Torch 2.2.0 is the most recent version that has been tested.


In [3]:
class SuperPointNet(torch.nn.Module):
  """ Pytorch definition of SuperPoint Network. """
  def __init__(self):
    super(SuperPointNet, self).__init__()
    self.relu = torch.nn.ReLU(inplace=True)
    self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
    
    # Shared Encoder.
    self.conv1a = torch.nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
    self.conv1b = torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
    self.conv2a = torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
    self.conv2b = torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
    self.conv3a = torch.nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
    self.conv3b = torch.nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
    self.conv4a = torch.nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
    self.conv4b = torch.nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
    # Detector Head.
    self.convPa = torch.nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
    self.convPb = torch.nn.Conv2d(256, 65, kernel_size=1, stride=1, padding=0)
    # Descriptor Head.
    self.convDa = torch.nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
    self.convDb = torch.nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)

  def forward(self, x):
    """ Forward pass that jointly computes unprocessed point and descriptor
    tensors.
    Input
      x: Image pytorch tensor shaped N x 1 x H x W.
    Output
      semi: Output point pytorch tensor shaped N x 65 x H/8 x W/8.
      desc: Output descriptor pytorch tensor shaped N x 256 x H/8 x W/8.
    """
    # Shared Encoder.
    x = self.relu(self.conv1a(x))
    x = self.relu(self.conv1b(x))
    x = self.pool(x)
    x = self.relu(self.conv2a(x))
    x = self.relu(self.conv2b(x))
    x = self.pool(x)
    x = self.relu(self.conv3a(x))
    x = self.relu(self.conv3b(x))
    x = self.pool(x)
    x = self.relu(self.conv4a(x))
    x = self.relu(self.conv4b(x))
    # Detector Head.
    cPa = self.relu(self.convPa(x))
    semi = self.convPb(cPa)
    # Descriptor Head.
    cDa = self.relu(self.convDa(x))
    desc = self.convDb(cDa)
    dn = torch.norm(desc, p=2, dim=1) # Compute the norm.
    desc = desc.div(torch.unsqueeze(dn, 1)) # Divide by norm to normalize.
    return semi, desc

In [4]:
model = SuperPointNet()
model.load_state_dict(torch.load('models/superpoint_v1.pth'))

<All keys matched successfully>

In [9]:
# script model with input of width 240, height 320
scripted = torch.jit.trace(model, torch.rand(1, 1, 240, 320))


# Convert the PyTorch model to CoreML format
coreml_model = ct.convert(
    scripted,
    convert_to='mlprogram',
    inputs=[ct.ImageType(name="image", shape=(1, 1, 240, 320), color_layout=ct.colorlayout.GRAYSCALE)],
    outputs=[ct.TensorType(name="semi"), ct.TensorType(name="desc")]
)
coreml_model.save("superpoint.mlpackage")

Model is not in eval mode. Consider calling '.eval()' on your model prior to conversion
Tuple detected at graph output. This will be flattened in the converted model.
Converting PyTorch Frontend ==> MIL Ops:  99%|█████████▊| 153/155 [00:00<00:00, 9181.98 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 844.13 passes/s]
Running MIL default pipeline: 100%|██████████| 78/78 [00:00<00:00, 445.13 passes/s]
Running MIL backend_mlprogram pipeline: 100%|██████████| 12/12 [00:00<00:00, 878.48 passes/s]
