# Bundle Adjustment Example using PyPose and the BAL dataset

```
The dataset is from the following paper:  
Sameer Agarwal, Noah Snavely, Steven M. Seitz, and Richard Szeliski.  
Bundle adjustment in the large.  
In European Conference on Computer Vision (ECCV), 2010.  
```

Link to the dataset: https://grail.cs.washington.edu/projects/bal/

# Sparse-Jac

In [1]:
from bal_loader import build_pipeline

# Alternative problem configuration; swap it with the active block below to switch problems.
# TARGET_DATASET = "trafalgar"
# TARGET_PROBLEM = "problem-257-65132-pre"
# MAX_PIXELS = 225911

# Active problem configuration.
TARGET_DATASET = "ladybug"  # dataset name handed to build_pipeline
TARGET_PROBLEM = "problem-49-7776-pre"  # problem_name that filter_problem selects
# Number of 2-D observations kept by trim_dataset. reprojerr_gen also compares a
# tensor's leading dimension against this value to choose the batched code path,
# so it has to equal the number of observations actually passed to the model.
MAX_PIXELS = 31843

# Device that trim_dataset moves every dataset tensor to.
DEVICE = 'cuda'

def filter_problem(x):
  """Return True when sample ``x`` carries the configured ``TARGET_PROBLEM`` name."""
  problem_name = x['problem_name']
  return problem_name == TARGET_PROBLEM

# Stream the chosen dataset and keep only the sample whose name matches TARGET_PROBLEM.
dataset_pipeline = build_pipeline(dataset=TARGET_DATASET, cache_dir='bal_data').filter(filter_problem)
dataset_iterator = iter(dataset_pipeline)
# Use a default so a missing problem produces an explicit message instead of a
# bare StopIteration traceback.
dataset = next(dataset_iterator, None)
if dataset is None:
  raise RuntimeError(f'{TARGET_PROBLEM} not found in {TARGET_DATASET}')

print(f'Fetched {TARGET_PROBLEM} from {TARGET_DATASET}')

import torch
import torch.nn as nn
import pypose as pp

def trim_dataset(dataset, max_pixels, device=None):
  """Keep the first ``max_pixels`` observations and move every field to one device.

  Args:
    dataset: mapping with the keys ``points_2d``, ``point_index_of_observations``,
      ``camera_index_of_observations``, ``camera_extrinsics``, ``camera_intrinsics``,
      ``camera_distortions`` and ``points_3d``. Values are torch tensors or NumPy arrays.
    max_pixels: number of leading observations to keep. Only the three
      per-observation fields are sliced; cameras and 3-D points are kept whole.
    device: target device. ``None`` (the default) falls back to the notebook-wide
      ``DEVICE`` constant, which is what the original implementation always used.

  Returns:
    A new dict with the same seven keys whose values are tensors on ``device``.
    The input mapping itself is not modified.
  """
  if device is None:
    device = DEVICE

  per_observation_keys = ('points_2d',
                          'point_index_of_observations',
                          'camera_index_of_observations')
  untouched_keys = ('camera_extrinsics',
                    'camera_intrinsics',
                    'camera_distortions',
                    'points_3d')

  def as_device_tensor(value):
    # Non-tensor values are converted with torch.from_numpy before the move.
    if not isinstance(value, torch.Tensor):
      value = torch.from_numpy(value)
    return value.to(device)

  trimmed_dataset = dict()
  for key in per_observation_keys:
    trimmed_dataset[key] = as_device_tensor(dataset[key][:max_pixels])
  for key in untouched_keys:
    trimmed_dataset[key] = as_device_tensor(dataset[key])
  return trimmed_dataset

def reprojerr(pose, points, pixels, intrinsics, distortions, camera_index, point_index):
  """Per-observation reprojection error (Euclidean pixel distance).

  Args:
    pose: camera poses, indexed by ``camera_index``; ``pose @ point`` must apply
      the pose to a 3-D point (a LieTensor according to this notebook's comments).
    points: 3-D points, indexed by ``point_index``.
    pixels: observed 2-D points, one row per observation.
    intrinsics: per-observation intrinsic matrices, already gathered by the
      caller; only element ``[0, 0]`` (focal length) is read.
    distortions: per-observation distortion coefficients, already gathered by the
      caller; columns 0 and 1 are used as k1 and k2.
    camera_index: camera id of each observation.
    point_index: 3-D point id of each observation.

  Returns:
    1-D tensor holding the norm of (projected - observed) for every observation.
  """
  # Pair each observation's camera pose with its 3-D point and transform the point.
  cam_pose = pose[camera_index].unsqueeze(-2)
  world_pts = points[point_index, None]
  cam_pts = (cam_pose @ world_pts).squeeze(-2)

  # Perspective division (note the negation of x and y).
  xy = -cam_pts[:, :2] / cam_pts[:, -1:]

  # Radial distortion followed by scaling with the focal length.
  focal = intrinsics[:, 0, 0]
  k1, k2 = distortions[:, 0], distortions[:, 1]
  sq_radius = (xy**2).sum(dim=-1)
  radial = 1.0 + k1 * sq_radius + k2 * sq_radius**2
  predicted = focal[:, None] * radial[:, None] * xy

  return (predicted - pixels).norm(dim=-1)

trimmed_dataset = trim_dataset(dataset, max_pixels=MAX_PIXELS)

# Reference only: direct call of reprojerr. Intrinsics and distortions have to be
# gathered per observation first, and camera_index precedes point_index.
# cam_idx = trimmed_dataset['camera_index_of_observations']  # camera behind each 2-D point
# errors = reprojerr(trimmed_dataset['camera_extrinsics'],            # [49, 7], a LieTensor
#                    trimmed_dataset['points_3d'],                    # [7776, 3]
#                    trimmed_dataset['points_2d'],                    # [31843, 2] when untrimmed
#                    trimmed_dataset['camera_intrinsics'][cam_idx],   # gathered from [49, 3, 3]
#                    trimmed_dataset['camera_distortions'][cam_idx],  # gathered from [49, 2]
#                    cam_idx,
#                    trimmed_dataset['point_index_of_observations'])  # 3-D point behind each 2-D point


Streaming data for ladybug...
Fetched problem-49-7776-pre from ladybug


In [2]:
def reprojerr_vmap(pose, point, pixel, intrinsic, distortion):
  """Reprojection error of ONE observation (unbatched; intended for use under vmap).

  Args:
    pose: raw pose tensor of a single camera; it is re-wrapped as an SE3 LieTensor.
    point: the 3-D point of this observation.
    pixel: the observed 2-D point.
    intrinsic: intrinsic matrix of the camera; only ``[0, 0]`` (focal length) is read.
    distortion: distortion coefficients; entries 0 and 1 are used as k1 and k2.

  Returns:
    Norm of (projected - observed) for this observation.
  """
  # pose loses its ltype when passed through vmap, so restore it here (temporary fix).
  se3_pose = pp.LieTensor(pose, ltype=pp.SE3_type)
  cam_pt = (se3_pose.unsqueeze(-2) @ point).squeeze(-2)

  # Perspective division (note the negation of x and y).
  xy = -cam_pt[:2] / cam_pt[-1:]

  # Radial distortion followed by scaling with the focal length.
  focal = intrinsic[0, 0]
  k1, k2 = distortion[0], distortion[1]
  sq_radius = (xy**2).sum(dim=-1)
  radial = 1.0 + k1 * sq_radius + k2 * sq_radius**2
  predicted = focal * radial * xy

  return (predicted - pixel).norm(dim=-1)

# jac_function_vmap = torch.vmap(pp.func.jacrev(reprojerr_vmap))
# jac_from_vmap = jac_function_vmap(trimmed_dataset['camera_extrinsics'][trimmed_dataset['camera_index_of_observations']],
#                       trimmed_dataset['points_3d'][trimmed_dataset['point_index_of_observations'], None],
#                       trimmed_dataset['points_2d'],
#                       trimmed_dataset['camera_intrinsics'][trimmed_dataset['camera_index_of_observations']],
#                       trimmed_dataset['camera_distortions'][trimmed_dataset['camera_index_of_observations']])
# print(jac_from_vmap.shape)

def reprojerr_gen(*args):
    """Dispatch to the batched or the single-observation reprojection error.

    ``args`` follows reprojerr's parameter order: pose, points, pixels, intrinsics,
    distortions, camera_index, point_index. When the fifth argument has
    ``MAX_PIXELS`` rows the call is treated as batched and goes to ``reprojerr``;
    otherwise the two trailing index arguments are dropped and ``reprojerr_vmap``
    handles it.

    NOTE(review): the decision relies on the global MAX_PIXELS matching the number
    of observations exactly - confirm this holds for every caller.
    """
    distortions = args[4]
    if distortions.shape[0] != MAX_PIXELS:
        # single-observation call: the index arguments are not needed
        return reprojerr_vmap(*args[:-2])
    return reprojerr(*args)

# sparse version
class ReprojNonBatched(nn.Module):
    """Reprojection model whose forward pass goes through ``reprojerr_gen``.

    Camera poses and 3-D points are both trainable. The poses are registered with
    ``pp.Parameter`` and the points with ``nn.Parameter``; ``pose`` is registered
    first, which fixes the order of ``named_parameters()``.
    """
    def __init__(self, camera_extrinsics, points_3d):
        """Store the initial camera poses and 3-D points as module parameters."""
        super().__init__()
        self.pose = pp.Parameter(camera_extrinsics)
        self.points_3d = nn.Parameter(points_3d)

    def forward(self, *args):
        """Evaluate ``reprojerr_gen`` with the stored pose/points followed by ``args``."""
        return reprojerr_gen(self.pose, self.points_3d, *args)


In [3]:
# dense version
class ReprojBatched(nn.Module):
    """Reprojection model that always evaluates the batched ``reprojerr``.

    Camera poses and 3-D points are both trainable. The poses are registered with
    ``pp.Parameter``, the same wrapper ``ReprojNonBatched`` uses, so that they stay
    a LieTensor parameter; the points use a plain ``nn.Parameter``. ``pose`` is
    registered first, which fixes the order of ``named_parameters()``.
    """
    def __init__(self, camera_extrinsics, points_3d):
        """Store the initial camera poses and 3-D points as module parameters."""
        super().__init__()
        # pp.Parameter is PyPose's parameter type for LieTensors.
        self.pose = pp.Parameter(camera_extrinsics)
        self.points_3d = nn.Parameter(points_3d)

    def forward(self, *args):
        """Return ``reprojerr`` for the stored pose/points followed by ``args``."""
        return reprojerr(self.pose, self.points_3d, *args)


In [4]:
from pypose.optim.solver import Krylov
# Positional arguments forwarded to the models' forward(). Their order must match
# reprojerr's parameters after (pose, points): pixels, intrinsics, distortions,
# camera_index, point_index. Intrinsics and distortions are gathered per observation here.
# NOTE(review): this name shadows the builtin input() for the rest of the notebook.
input = [trimmed_dataset['points_2d'],
         trimmed_dataset['camera_intrinsics'][trimmed_dataset['camera_index_of_observations']],
         trimmed_dataset['camera_distortions'][trimmed_dataset['camera_index_of_observations']],
         trimmed_dataset['camera_index_of_observations'],
         trimmed_dataset['point_index_of_observations']]


In [5]:
def least_square_error(pose, points, pixels, intrinsics, distortions, camera_index, point_index):
  """Half the sum of squared reprojection residuals over all observations.

  Takes the same arguments as ``reprojerr`` and projects the points the same way,
  but reduces the 2-D residuals to a single scalar instead of one norm per
  observation.

  Returns:
    Scalar tensor ``0.5 * sum((projected - observed) ** 2)``.
  """
  # Pair each observation's camera pose with its 3-D point and transform the point.
  cam_pose = pose[camera_index].unsqueeze(-2)
  world_pts = points[point_index, None]
  cam_pts = (cam_pose @ world_pts).squeeze(-2)

  # Perspective division (note the negation of x and y).
  xy = -cam_pts[:, :2] / cam_pts[:, -1:]

  # Radial distortion followed by scaling with the focal length.
  focal = intrinsics[:, 0, 0]
  k1, k2 = distortions[:, 0], distortions[:, 1]
  sq_radius = (xy**2).sum(dim=-1)
  radial = 1.0 + k1 * sq_radius + k2 * sq_radius**2
  predicted = focal[:, None] * radial[:, None] * xy

  residual = predicted - pixels
  return residual.flatten().pow(2).sum() / 2

In [None]:
%%time
# Optimize cameras and points with the sparse model. Cloned inputs keep the
# tensors stored in trimmed_dataset separate from the model parameters.
model_non_batched = ReprojNonBatched(trimmed_dataset['camera_extrinsics'].clone(),
                                     trimmed_dataset['points_3d'].clone())

model_non_batched = model_non_batched.to(DEVICE)

strategy_sparse = pp.optim.strategy.Adaptive(damping=1e-6)
# Alternative configuration using the Krylov solver imported above:
#optimizer_sparse = pp.optim.LM(model_non_batched, strategy=strategy_sparse, solver=Krylov(rtol=3e-4, iterations=5000), reject=1)
# NOTE(review): the meaning of dense_A is not visible in this notebook - confirm against the optimizer in use.
optimizer_sparse = pp.optim.LM(model_non_batched, strategy=strategy_sparse, dense_A=True)

print('starting loss:', least_square_error(model_non_batched.pose, model_non_batched.points_3d, *input).item())
# Up to 100 LM steps. The printed value is recomputed with least_square_error,
# while early stopping uses the value returned by step().
# NOTE(review): in the recorded output the printed loss rises from 850912.75 to
# 1277107.0 after the first step - worth checking before trusting the result.
for idx in range(100):
    loss = optimizer_sparse.step(input)
    # if (idx + 1) % 10 == 0:
    print('Least square loss %.3f @ %d it'%(least_square_error(model_non_batched.pose, model_non_batched.points_3d, *input).item(), idx))
    if loss < 1e-5:
        print('Early Stopping with loss:', loss.item())
        break
print('ending loss:', least_square_error(model_non_batched.pose, model_non_batched.points_3d, *input).item())


starting loss: 850912.75
Least square loss 1277107.000 @ 0 it
Least square loss 1276469.250 @ 1 it
Least square loss 1275283.250 @ 2 it
Least square loss 1269541.500 @ 3 it
Least square loss 1260798.625 @ 4 it
Least square loss 1254494.375 @ 5 it
Least square loss 1247155.125 @ 6 it
Least square loss 1239018.500 @ 7 it
Least square loss 1234315.750 @ 8 it
Least square loss 1231778.250 @ 9 it
Least square loss 1229343.875 @ 10 it
Least square loss 1227006.000 @ 11 it
Least square loss 1224723.000 @ 12 it
Least square loss 1222579.500 @ 13 it
Least square loss 1220205.375 @ 14 it
Least square loss 1217971.500 @ 15 it
Least square loss 1215594.500 @ 16 it
Least square loss 1213365.125 @ 17 it
Least square loss 1211166.000 @ 18 it
Least square loss 1208939.250 @ 19 it
Least square loss 1206652.250 @ 20 it
Least square loss 1204486.000 @ 21 it
Least square loss 1202202.875 @ 22 it
Least square loss 1200186.750 @ 23 it
Least square loss 1197985.250 @ 24 it
Least square loss 1195808.250 @ 25 

In [25]:
# Collect the model parameters by name and copy the optimized 3-D points to a
# NumPy array on the host.
params = {name: param for name, param in model_non_batched.named_parameters()}
points_3d_opt = params['points_3d'].detach().cpu().numpy()

In [28]:
Initial error: 2.45619e+07
points_3d_opt

array([[ 1.7799599e+00,  5.3873158e-01, -3.1275871e+00],
       [ 1.8246309e+00,  5.4241651e-01, -3.1123295e+00],
       [ 5.5786234e-01,  2.3620661e-01, -2.7201819e+00],
       ...,
       [ 1.3664684e+01,  2.2823713e+00,  2.1599278e+00],
       [-8.1716841e-01,  4.9323239e-03, -2.6544741e-01],
       [-5.9021598e-01,  2.1373612e-01, -1.3738620e+00]], dtype=float32)

In [27]:
import numpy as np
# Name the file after the problem that was actually optimized, so the saved
# points are never labelled with a different problem than TARGET_PROBLEM.
with open(f'points_3d_opt_{TARGET_PROBLEM}.npy', 'wb') as f:
    np.save(f, points_3d_opt)

In [None]:

# Dense counterpart of the optimization above: same initial values, batched model.
model_batched = ReprojBatched(trimmed_dataset['camera_extrinsics'].clone(), trimmed_dataset['points_3d'].clone())

strategy_dense = pp.optim.strategy.Adaptive(damping=1e-6)
# NOTE(review): the meaning of sparse=False is not visible in this notebook - confirm against the optimizer in use.
optimizer_dense = pp.optim.LM(model_batched, strategy=strategy_dense, sparse=False)
# Only three steps are run here.
# NOTE(review): the label says "Pose Inversion" although the model computes reprojection errors.
for idx in range(3):
    loss = optimizer_dense.step(input)
    print('Pose Inversion loss %.7f @ %d it'%(loss, idx))
    if loss < 1e-5:
        print('Early Stopping with loss:', loss.item())
        break


# Construct SBT from vmap

In [17]:
import pypose as pp
# Toy example: four 1x3 blocks at block coordinates (0,0), (0,2), (1,1), (1,2)
# of a 3x3 block grid.
i = [[0, 0, 1, 1], [0, 2, 1, 2]]
v = torch.arange(12).view((-1, 1, 3)).to(dtype=torch.float32)
x = pp.sbktensor(i, v, size=(3, 3), dtype=torch.float32)
storage_t = x._s  # private attribute; presumably the underlying hybrid sparse storage

print(f'shape: {storage_t.shape}, dense dim: {storage_t.dense_dim()}, \
sparse: {storage_t.sparse_dim()}')
flattened_coo = pp.hybrid2coo(storage_t) # the flattened tensor
print(flattened_coo.shape)
flattened_coo_T = flattened_coo.T
print(flattened_coo_T.shape)
# Sparse product J^T J of the flattened matrix with itself.
A = torch.sparse.mm(flattened_coo_T, flattened_coo)
print(A.shape)
print(A.type())
# Unused experiment: product of the transposed matrix with a dense column vector.
# pdt = torch.matmul(flattened_coo_T, torch.arange(1000).view(-1, 1).to(dtype=torch.float32))
# print(pdt.shape)
# print(pdt.type())


shape: torch.Size([3, 3, 1, 3]), dense dim: 2, sparse: 2
torch.Size([3, 9])
torch.Size([9, 3])
torch.Size([9, 9])
torch.sparse.FloatTensor


In [19]:
# Display A converted to CSR layout (the conversion result is not stored).
A.to_sparse_csr()


tensor(crow_indices=tensor([ 0,  6, 12, 18, 24, 30, 36, 45, 54, 63]),
       col_indices=tensor([0, 1, 2, 6, 7, 8, 0, 1, 2, 6, 7, 8, 0, 1, 2, 6, 7, 8,
                           3, 4, 5, 6, 7, 8, 3, 4, 5, 6, 7, 8, 3, 4, 5, 6, 7, 8,
                           0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 6, 7, 8,
                           0, 1, 2, 3, 4, 5, 6, 7, 8]),
       values=tensor([  0.,   0.,   0.,   0.,   0.,   0.,   0.,   1.,   2.,
                        3.,   4.,   5.,   0.,   2.,   4.,   6.,   8.,  10.,
                       36.,  42.,  48.,  54.,  60.,  66.,  42.,  49.,  56.,
                       63.,  70.,  77.,  48.,  56.,  64.,  72.,  80.,  88.,
                        0.,   3.,   6.,  54.,  63.,  72.,  90., 102., 114.,
                        0.,   4.,   8.,  60.,  70.,  80., 102., 116., 130.,
                        0.,   5.,  10.,  66.,  77.,  88., 114., 130., 146.]),
       size=(9, 9), nnz=63, layout=torch.sparse_csr)

In [None]:
# Check that the product of two sparse COO tensors is itself sparse.
# The values are random (unseeded); only type and layout are printed.
i = [[0, 0, 1, 1], [0, 2, 1, 2]]
# Block-valued variant, not used:
# v = torch.randn(4 * 2 * 2).view((4, 2, 2)).to(dtype=torch.float32)
v = torch.randn(4)
sct = torch.sparse_coo_tensor(i, v)
v2 = torch.randn(4)
sct2 = torch.sparse_coo_tensor(i, v2)
# Other experiments, not used:
# torch.cat([sct, sct2], dim=1)
# pdt = torch.sparse.mm(v, v2)
pdt = sct @ sct2.T
print(pdt.type())
print(pdt.is_sparse)


In [None]:
# confirm that the jacobian values from jacrev and vmap match
batch_dim = jac_from_jacrev.shape[0]
index = trimmed_dataset['camera_index_of_observations']
jac_from_jacrev_condensed = jac_from_jacrev[torch.arange(batch_dim), index]
print(jac_from_jacrev_condensed.shape)
if torch.allclose(jac_from_vmap, jac_from_jacrev_condensed):
  print("Jacobian structure sanity check: ok!")
else:
  print("Jacobian structure sanity check: failed!")


In [None]:
# manually construct SBT from vmap result
import torch
import pypose as pp
from pypose.sparse import sbktensor

def construct_sbt(jac_from_vmap, num_cameras, camera_index):
  """Assemble per-observation Jacobian rows into a sparse block tensor.

  Observation ``k`` contributes one block at block coordinate
  ``(k, camera_index[k])``.

  Args:
    jac_from_vmap: per-observation Jacobians; a block axis is inserted at
      position 1, so a ``[n, d]`` input yields ``n`` blocks of shape ``[1, d]``.
    num_cameras: number of block columns.
    camera_index: 1-D integer tensor with the camera id of every observation.

  Returns:
    The tensor built by ``pp.sbktensor`` with block-grid size
    ``(n, num_cameras)`` and dtype float32.
  """
  n = camera_index.shape[0]  # number of observations, one block row each
  # Create the row indices on camera_index's device: a CPU arange cannot be
  # stacked with an index tensor that lives on the GPU.
  rows = torch.arange(n, device=camera_index.device)
  i = torch.stack([rows, camera_index])
  v = jac_from_vmap[:, None, :]  # add the block-row axis expected by the constructor
  # Same constructor call that the toy example in this notebook uses.
  return pp.sbktensor(i, v, size=(n, num_cameras), dtype=torch.float32)

# Build the sparse Jacobian, one block column per camera, and densify it for inspection.
# NOTE(review): jac_from_vmap is only assigned in commented-out code earlier in the notebook.
sparse_jac = construct_sbt(jac_from_vmap, len(trimmed_dataset['camera_extrinsics']), trimmed_dataset['camera_index_of_observations'])
dense_jac = sparse_jac.to_dense()
# Disabled comparison against the dense jacrev result:
# if torch.allclose(jac_from_jacrev, dense_jac):
#   print("Dense Jacobian <-> Sparse Jacobian: ok!")
# else:
#   print("Dense Jacobian <-> Sparse Jacobian: failed!")


In [14]:
# Same steps as the toy example, applied to the real sparse Jacobian.
# NOTE(review): the recorded output of this cell is a NameError because
# sparse_jac was not defined when it ran.
storage_t = sparse_jac._s
print(f'shape: {storage_t.shape}, dense dim: {storage_t.dense_dim()}, \
sparse: {storage_t.sparse_dim()}')
flattened_coo = pp.hybrid2coo(storage_t) # the flattened tensor
print(flattened_coo.shape)
flattened_coo_T = flattened_coo.T
print(flattened_coo_T.shape)
# Sparse product J^T J of the flattened Jacobian with itself.
A = torch.sparse.mm(flattened_coo_T, flattened_coo)
print(A.shape)
print(A.type())
# Unused experiment: product of the transposed matrix with a dense column vector.
# pdt = torch.matmul(flattened_coo_T, torch.arange(1000).view(-1, 1).to(dtype=torch.float32))
# print(pdt.shape)
# print(pdt.type())


NameError: name 'sparse_jac' is not defined

In [None]:
def sparse_coo_diagonal(t: torch.Tensor):
    """Return the stored values of sparse COO tensor ``t`` that sit on the diagonal.

    Only explicitly stored entries are considered, so the result has one element
    per stored entry whose row index equals its column index.
    ``t`` is read through ``indices()`` / ``values()``.
    """
    entry_indices = t.indices()
    on_diagonal = torch.eq(entry_indices[0], entry_indices[1])
    return t.values()[on_diagonal]

def sparse_coo_diagonal_clamp_(t: torch.Tensor, min_value, max_value):
    """Clamp the stored diagonal values of sparse COO tensor ``t`` in place.

    Every stored entry whose row index equals its column index is limited to
    ``[min_value, max_value]``; all other entries are left alone. Returns None.
    """
    entry_indices = t.indices()
    on_diagonal = entry_indices[0] == entry_indices[1]
    # Masked indexing yields a copy: clamp the copy, then write it back.
    diagonal_values = t.values()[on_diagonal]
    diagonal_values.clamp_(min_value, max_value)
    t.values()[on_diagonal] = diagonal_values


def sparse_coo_diagonal_add_(t: torch.Tensor, other: torch.Tensor):
    """Add ``other`` to the stored diagonal values of sparse COO tensor ``t`` in place.

    ``other`` is added to the entries whose row index equals their column index,
    in the order those entries are stored. Returns None.
    """
    entry_indices = t.indices()
    on_diagonal = entry_indices[0] == entry_indices[1]
    # Masked indexing yields a copy: update the copy, then write it back.
    diagonal_values = t.values()[on_diagonal]
    diagonal_values.add_(other)
    t.values()[on_diagonal] = diagonal_values


In [None]:
# Positions, within A's stored values, of the entries on the main diagonal.
diag_indices = (A.indices()[0] == A.indices()[1]).nonzero(as_tuple=True)
sparse_coo_diagonal_clamp_(A, -100, 100)
print(A.values()[diag_indices[0][:10]])
# Add twice the current diagonal to itself. The former multiplication by a
# fixed-length torch.ones(301) only worked for exactly 301 diagonal entries on
# the CPU and changed nothing when it did work.
sparse_coo_diagonal_add_(A, sparse_coo_diagonal(A) * 2)
print(A.values()[diag_indices[0][:10]])


In [None]:
def jacrev_custom(func, argnums):
  """Wrap ``func`` so that calling it returns its Jacobian as a sparse block tensor.

  Args:
    func: function of a single, unbatched observation (e.g. ``reprojerr_vmap``).
    argnums: argument index to differentiate with respect to; forwarded to
      ``pp.func.jacrev``. NOTE(review): ``construct_sbt`` is handed the jacrev
      result as-is, so a single integer is presumably required - confirm.

  Returns:
    ``wrapper(*args, num_cols=..., index=...)``. The positional arguments are the
    batched inputs of ``func``; ``num_cols`` and ``index`` are passed to
    ``construct_sbt`` as the number of block columns and the block-column index
    of every row. A missing keyword raises ``KeyError``.
  """
  # Build the vmapped Jacobian function once rather than on every call, and
  # honour argnums instead of silently using jacrev's default.
  jac_vmap = torch.vmap(pp.func.jacrev(func, argnums))

  def wrapper(*args, **kwargs):
    gradients = jac_vmap(*args)
    return construct_sbt(gradients, kwargs['num_cols'], kwargs['index'])
  return wrapper


In [None]:
# Jacobian of the single-observation reprojection error w.r.t. its first argument
# (the pose), returned as a sparse block tensor with one block column per camera.
jac_function_custom = jacrev_custom(reprojerr_vmap, argnums=0)
# All positional inputs are gathered per observation so that vmap can map over axis 0.
jac_from_custom = jac_function_custom(trimmed_dataset['camera_extrinsics'][trimmed_dataset['camera_index_of_observations']],
                      trimmed_dataset['points_3d'][trimmed_dataset['point_index_of_observations'], None],
                      trimmed_dataset['points_2d'],
                      trimmed_dataset['camera_intrinsics'][trimmed_dataset['camera_index_of_observations']],
                      trimmed_dataset['camera_distortions'][trimmed_dataset['camera_index_of_observations']],
                      num_cols=len(trimmed_dataset['camera_extrinsics']),
                      index=trimmed_dataset['camera_index_of_observations'])
