## Environment

Prerequisite:

- PyTorch (tested on 1.8.1/1.10.1)
- Pyro (tested on 1.6.0)

We recommend PyTorch 1.8.1, on which the current implementation of the PnP solver runs significantly faster than on PyTorch 1.10.1.

Install the python packages:

In [None]:
# CUDA 11.1
%pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
%pip install pyro-ppl==1.6.0

Clone and enter this project:

In [None]:
!git clone https://github.com/tjiiv-cprg/EPro-PnP
%cd EPro-PnP

## Fit the Identity Function

Here we demonstrate the usage of EPro-PnP by fitting a simple model `out_pose = EProPnP(MLP(in_pose))` to data points generated from the identity function $I: SE(3) \to SE(3)$. The model takes `in_pose = [x, y, z, w, i, j, k]` as input, which a plain MLP converts into a 2D-3D correspondence set, and outputs the probabilistic pose through the EPro-PnP layer. 65536 data points are generated with additional noise.

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data

from epropnp.epropnp import EProPnP6DoF
from epropnp.levenberg_marquardt import LMSolver, RSLMSolver
from epropnp.camera import PerspectiveCamera
from epropnp.cost_fun import AdaptiveHuberPnPCost
from epropnp.common import quaternion_to_rot_mat

In [None]:
# Run on the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
n_data = 65536  # number of generated data points
batch_size = 256  # number of samples per training step
n_epoch = 10  # number of passes over the generated data
noise = 0.01  # scale of the Gaussian noise added to the target poses

In [None]:
class Model(nn.Module):
    """MLP that predicts a weighted 2D-3D correspondence set, followed by an EPro-PnP layer.

    The MLP maps a 7-dimensional input pose to ``num_points`` correspondences,
    each made of a 3D point (3 values), a 2D point (2 values) and a 2D weight
    (2 values). The correspondences are then handed to ``epropnp`` together
    with ``camera`` and ``cost_fun``.

    Args:
        num_points: Number of 2D-3D pairs predicted per sample.
        mlp_layers: Widths of the hidden layers of the MLP (any iterable of
            ints). The input width (7) and the output width
            (``num_points * 7``) are added automatically.
        epropnp: The EPro-PnP layer. When ``None``, a fresh ``EProPnP6DoF``
            with the demo configuration is built for this instance.
        camera: Camera model passed to the EPro-PnP layer. When ``None``, a
            fresh ``PerspectiveCamera`` is built for this instance.
        cost_fun: Cost function passed to the EPro-PnP layer. When ``None``,
            a fresh ``AdaptiveHuberPnPCost(relative_delta=0.5)`` is built for
            this instance.
    """

    def __init__(
            self,
            num_points=64,  # number of 2D-3D pairs
            mlp_layers=(1024,),  # a single hidden layer
            epropnp=None,
            camera=None,
            cost_fun=None):
        super().__init__()
        # The default components are created per instance rather than in the
        # signature: objects built in the signature are evaluated only once and
        # would be shared (including the state written by `set_param`) by every
        # Model that relies on the defaults.
        if epropnp is None:
            epropnp = EProPnP6DoF(
                mc_samples=512,
                num_iter=4,
                solver=LMSolver(
                    dof=6,
                    num_iter=10,
                    initial_trust_region_radius=1e3,
                    init_solver=RSLMSolver(
                        dof=6,
                        num_points=8,
                        num_proposals=128,
                        num_iter=5,
                        initial_trust_region_radius=1e3)))
        if camera is None:
            camera = PerspectiveCamera()
        if cost_fun is None:
            cost_fun = AdaptiveHuberPnPCost(relative_delta=0.5)

        self.num_points = num_points
        # 7 input features: [x, y, z, w, i, j, k]
        widths = [7, *mlp_layers]
        mlp = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            mlp.append(nn.Linear(fan_in, fan_out))
            mlp.append(nn.LeakyReLU())
        # per point: 3 (x3d) + 2 (x2d) + 2 (w2d) output values
        mlp.append(nn.Linear(widths[-1], num_points * (3 + 2 + 2)))
        self.mlp = nn.Sequential(*mlp)
        # Here we use static weight_scale because the data noise is homoscedastic
        self.log_weight_scale = nn.Parameter(torch.zeros(2))
        self.epropnp = epropnp
        self.camera = camera
        self.cost_fun = cost_fun

    def forward_correspondence(self, in_pose):
        """Predict the 2D-3D correspondences and their weights.

        Args:
            in_pose: Tensor whose last dimension is 7 (the MLP input width).

        Returns:
            Tuple ``(x3d, x2d, w2d)`` with shapes ``(N, num_points, 3)``,
            ``(N, num_points, 2)`` and ``(N, num_points, 2)``. ``w2d`` is
            positive; for each sample and each of its 2 channels it sums to
            ``exp(log_weight_scale)`` over the point dimension.
        """
        x3d, x2d, w2d = self.mlp(in_pose).reshape(-1, self.num_points, 7).split([3, 2, 2], dim=-1)
        w2d = (w2d.log_softmax(dim=-2) + self.log_weight_scale).exp()
        # equivalent to:
        # w2d = w2d.softmax(dim=-2) * self.log_weight_scale.exp()
        # alternatively we can use mean subtraction instead of log_softmax; both serve the
        # purpose of normalizing the scale of the weights, e.g.:
        # w2d = (w2d - w2d.mean(dim=-2, keepdim=True) - math.log(w2d.size(-2))
        #        + self.log_weight_scale).exp()
        return x3d, x2d, w2d

    def forward_train(self, in_pose, cam_mats, out_pose):
        """Run the training-time forward pass through the EPro-PnP layer.

        Args:
            in_pose: Input poses fed to the MLP, last dimension 7.
            cam_mats: Camera parameters forwarded to ``camera.set_param``
                (the demo passes one 3x3 matrix per sample).
            out_pose: Target poses; passed to the layer as ``pose_init``.

        Returns:
            The six outputs of ``epropnp.monte_carlo_forward``
            ``(pose_opt, cost, pose_opt_plus, pose_samples,
            pose_sample_logweights, cost_tgt)`` followed by ``norm_factor``,
            the detached mean of ``exp(log_weight_scale)`` of this model.
        """
        x3d, x2d, w2d = self.forward_correspondence(in_pose)
        self.camera.set_param(cam_mats)
        self.cost_fun.set_param(x2d.detach(), w2d)  # compute dynamic delta
        pose_opt, cost, pose_opt_plus, pose_samples, pose_sample_logweights, cost_tgt = self.epropnp.monte_carlo_forward(
            x3d,
            x2d,
            w2d,
            self.camera,
            self.cost_fun,
            pose_init=out_pose,
            force_init_solve=True,
            with_pose_opt_plus=True)  # True for derivative regularization loss
        # Use this instance's own parameter (not a module-level `model` variable)
        # so the method works for any Model object.
        norm_factor = self.log_weight_scale.detach().exp().mean()
        return pose_opt, cost, pose_opt_plus, pose_samples, pose_sample_logweights, cost_tgt, norm_factor

In [None]:
class MonteCarloPoseLoss(nn.Module):
    """Monte Carlo pose loss normalized by a running-average scale factor.

    The buffer ``norm_factor`` starts at ``init_norm_factor`` and, in training
    mode, is updated on every call as an exponential moving average of the
    ``norm_factor`` argument with the given ``momentum``.
    """

    def __init__(self, init_norm_factor=1.0, momentum=0.1):
        super().__init__()
        self.momentum = momentum
        self.register_buffer(
            'norm_factor', torch.tensor(init_norm_factor, dtype=torch.float))

    def forward(self, pose_sample_logweights, cost_target, norm_factor):
        """
        Args:
            pose_sample_logweights: Shape (mc_samples, num_obj)
            cost_target: Shape (num_obj, )
            norm_factor: Shape ()

        Returns:
            Scalar loss: mean over objects of ``cost_target + logsumexp of the
            sample log-weights``, with NaN entries counted as 0, divided by the
            running ``norm_factor`` buffer.
        """
        if self.training:
            # running average is bookkeeping only; keep it out of the autograd graph
            with torch.no_grad():
                keep = 1 - self.momentum
                self.norm_factor.mul_(keep).add_(self.momentum * norm_factor)

        # log of the summed sample weights, reduced over the sample dimension
        log_sum_weights = torch.logsumexp(pose_sample_logweights, dim=0)  # (num_obj, )
        per_obj_loss = cost_target + log_sum_weights  # (num_obj, )
        # objects whose loss is NaN contribute zero instead of poisoning the mean
        per_obj_loss = per_obj_loss.masked_fill(torch.isnan(per_obj_loss), 0)

        normalized = per_obj_loss.mean() / self.norm_factor
        return normalized.mean()

In [None]:
# generate data points
# Each row is a pose [x, y, z, w, i, j, k]: columns 0-2 are the translation,
# columns 3-6 the quaternion.
in_pose = torch.randn([n_data, 7], device=device)
in_pose[:, 2] += 5  # positive z, avoid points falling behind the camera plane
in_pose[:, 3:] = F.normalize(in_pose[:, 3:], dim=-1)  # normalize to unit quaternion

# Targets are the inputs (identity function) perturbed by Gaussian noise of scale `noise`.
out_pose = in_pose + torch.randn([n_data, 7], device=device) * noise
out_pose[:, 3:] = F.normalize(out_pose[:, 3:], dim=-1)  # normalize to unit quaternion

# A single 3x3 identity matrix used as the camera matrix for every sample.
cam_mats = torch.eye(3, device=device)

# Pair each input pose with its noisy target and iterate in shuffled mini-batches.
dataset = Data.TensorDataset(in_pose, out_pose)
loader = Data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True)

# setup model
model = Model().to(device)
mc_loss_fun = MonteCarloPoseLoss().to(device)

# Two parameter groups: the MLP uses the default learning rate (1e-4), while
# log_weight_scale gets its own, larger learning rate (1e-2).
optimizer = torch.optim.Adam([
                {'params': model.mlp.parameters()},
                {'params': model.log_weight_scale, 'lr': 1e-2}
            ], lr=1e-4)

In [8]:
# start training
for epoch_id in range(n_epoch):
    for iter_id, (batch_in_pose, batch_out_pose) in enumerate(loader):  # for each training step
        # expand the shared camera matrix to one matrix per sample in the batch
        batch_cam_mats = cam_mats.expand(batch_in_pose.size(0), -1, -1)
        _, _, pose_opt_plus, _, pose_sample_logweights, cost_tgt, norm_factor = model.forward_train(
            batch_in_pose, batch_cam_mats, batch_out_pose)

        # monte carlo pose loss
        loss_mc = mc_loss_fun(pose_sample_logweights, cost_tgt, norm_factor)

        # derivative regularization
        # translation: quadratic below `beta`, linear above it
        beta = 1.0
        dist_t = (pose_opt_plus[:, :3] - batch_out_pose[:, :3]).norm(dim=-1)
        loss_t = torch.where(
            dist_t < beta,
            0.5 * dist_t.square() / beta,
            dist_t - 0.5 * beta).mean()

        # rotation: based on the squared dot product of the two quaternions,
        # so q and -q are treated the same
        dot_quat = (pose_opt_plus[:, None, 3:] @ batch_out_pose[:, 3:, None])[:, 0, 0]
        loss_r = ((1 - dot_quat.square()) * 2).mean()

        loss = loss_mc + 0.1 * loss_t + 0.1 * loss_r

        optimizer.zero_grad()
        loss.backward()

        # global gradient norm over all parameters that received a gradient (for logging only)
        grad_norm = torch.norm(torch.stack([
            torch.norm(p.grad.detach())
            for p in model.parameters()
            if p.requires_grad and p.grad is not None]))

        optimizer.step()

        print('Epoch {}: {}/{} - loss_mc={:.4f}, loss_t={:.4f}, loss_r={:.4f}, loss={:.4f}, norm_factor={:.4f}, grad_norm={:.4f}'.format(
            epoch_id + 1, iter_id + 1, len(loader), loss_mc, loss_t, loss_r, loss, norm_factor, grad_norm))



Epoch 1: 1/256 - loss_mc=33.2071, loss_t=1.6915, loss_r=1.4542, loss=33.5217, norm_factor=1.0000, grad_norm=3.9409
Epoch 1: 2/256 - loss_mc=34.8328, loss_t=2.1324, loss_r=1.5369, loss=35.1998, norm_factor=1.0101, grad_norm=1.4019
Epoch 1: 3/256 - loss_mc=32.6467, loss_t=1.7092, loss_r=1.5730, loss=32.9749, norm_factor=1.0188, grad_norm=2.4692
Epoch 1: 4/256 - loss_mc=33.1540, loss_t=1.3633, loss_r=1.5264, loss=33.4429, norm_factor=1.0277, grad_norm=2.2266
Epoch 1: 5/256 - loss_mc=33.8417, loss_t=1.1837, loss_r=1.4941, loss=34.1095, norm_factor=1.0368, grad_norm=0.8866
Epoch 1: 6/256 - loss_mc=33.7920, loss_t=1.4051, loss_r=1.4958, loss=34.0820, norm_factor=1.0457, grad_norm=1.0419
Epoch 1: 7/256 - loss_mc=33.4611, loss_t=1.2668, loss_r=1.5197, loss=33.7398, norm_factor=1.0550, grad_norm=0.9438
Epoch 1: 8/256 - loss_mc=32.9867, loss_t=1.0734, loss_r=1.4590, loss=33.2400, norm_factor=1.0640, grad_norm=0.8413
Epoch 1: 9/256 - loss_mc=32.8414, loss_t=0.9253, loss_r=1.5326, loss=33.0872, no