# Bundle Adjustment Example using SparsePyBA and the BAL dataset

```
The dataset is from the following paper:  
Sameer Agarwal, Noah Snavely, Steven M. Seitz, and Richard Szeliski.  
Bundle adjustment in the large.  
In European Conference on Computer Vision (ECCV), 2010.  
```

Link to the dataset: https://grail.cs.washington.edu/projects/bal/

# Fetch data

In [1]:
from bal_loader import build_pipeline

# TARGET_DATASET = "ladybug"
# TARGET_PROBLEM = "problem-49-7776-pre"
# MAX_PIXELS = 31843

TARGET_DATASET = "trafalgar"
TARGET_PROBLEM = "problem-257-65132-pre"
MAX_PIXELS = 225911

DEVICE = 'cuda' # change device to CPU if needed

def filter_problem(x):
  return x['problem_name'] == TARGET_PROBLEM

dataset_pipeline = build_pipeline(dataset=TARGET_DATASET, cache_dir='bal_data').filter(filter_problem)
dataset_iterator = iter(dataset_pipeline)
dataset = next(dataset_iterator)

print(f'Fetched {TARGET_PROBLEM} from {TARGET_DATASET}')

import torch
import torch.nn as nn
import pypose as pp

def trim_dataset(dataset, max_pixels):
  trimmed_dataset = dict()
  trimmed_dataset['points_2d'] = dataset['points_2d'][:max_pixels]
  trimmed_dataset['point_index_of_observations'] = dataset['point_index_of_observations'][:max_pixels]
  trimmed_dataset['camera_index_of_observations'] = dataset['camera_index_of_observations'][:max_pixels]
  # other fields are not changed
  trimmed_dataset['camera_extrinsics'] = dataset['camera_extrinsics']
  trimmed_dataset['camera_intrinsics'] = dataset['camera_intrinsics']
  trimmed_dataset['camera_distortions'] = dataset['camera_distortions']
  trimmed_dataset['points_3d'] = dataset['points_3d']

  for k in trimmed_dataset.keys():
    if not isinstance(trimmed_dataset[k], torch.Tensor):
      trimmed_dataset[k] = torch.from_numpy(trimmed_dataset[k])
    trimmed_dataset[k] = trimmed_dataset[k].to(DEVICE)
  return trimmed_dataset

trimmed_dataset = trim_dataset(dataset, max_pixels=MAX_PIXELS)

Streaming data for trafalgar...
Fetched problem-257-65132-pre from trafalgar


# Declare helper functions

In [2]:
def reprojerr(pose, points, intrinsics, distortions, pixels, camera_index, point_index):
  points = points[point_index, None] # [1000, 1, 3]
  pose = pose[camera_index] # [1000, 7]
  points = pose.unsqueeze(-2) @ points
  points = points.squeeze(-2)

  # perspective division
  points_proj = -points[:, :2] / points[:, -1:]

  # convert to pixel coordinates
  intrinsics = intrinsics[camera_index]
  distortions = distortions[camera_index]
  f = intrinsics[:, 0, 0]
  k1 = distortions[:, 0]
  k2 = distortions[:, 1]
  n = torch.sum(points_proj**2, dim=-1)
  r = 1.0 + k1 * n + k2 * n**2
  img_repj = f[:, None] * r[:, None] * points_proj

  # calculate the reprojection error
  loss = (img_repj - pixels).norm(dim=-1)

  return loss
    
def reprojerr_vmap(pose, point, intrinsic, distortion, pixel):
  # reprojerr_vmap is not batched, it operates on a single 3D point and camera
  pose = pp.LieTensor(pose, ltype=pp.SE3_type) # pose will lose its ltype through vmap, temporary fix
  point = pose.unsqueeze(-2) @ point
  point = point.squeeze(-2)

  # perspective division
  point_proj = -point[:2] / point[-1:]

  # convert to pixel coordinates
  f = intrinsic[0, 0]
  k1 = distortion[0]
  k2 = distortion[1]
  n = torch.sum(point_proj**2, dim=-1)
  r = 1.0 + k1 * n + k2 * n**2
  img_repj = f * r * point_proj

  # calculate the reprojection error
  loss = (img_repj - pixel).norm(dim=-1)

  return loss

def reprojerr_gen(*args):
    if args[4].shape[0] == MAX_PIXELS:
        return reprojerr(*args)
    else:
        args = list(args)
        return reprojerr_vmap(*args[:-2])

# sparse version
class ReprojNonBatched(nn.Module):
    def __init__(self, camera_extrinsics, points_3d):
        super().__init__()
        self.pose = nn.Parameter(camera_extrinsics)
        self.points_3d = nn.Parameter(points_3d)

    def forward(self, *args):
        return reprojerr_gen(self.pose, self.points_3d, *args)

def least_square_error(pose, points, intrinsics, distortions, pixels, camera_index, point_index):
  points = points[point_index, None] # [1000, 1, 3]
  pose = pose[camera_index] # [1000, 7]
  points = pose.unsqueeze(-2) @ points
  points = points.squeeze(-2)

  # perspective division
  points_proj = -points[:, :2] / points[:, -1:]

  # convert to pixel coordinates
  intrinsics = intrinsics[camera_index]
  distortions = distortions[camera_index]
  f = intrinsics[:, 0, 0]
  k1 = distortions[:, 0]
  k2 = distortions[:, 1]
  n = torch.sum(points_proj**2, dim=-1)
  r = 1.0 + k1 * n + k2 * n**2
  img_repj = f[:, None] * r[:, None] * points_proj

  # calculate the least square error
  return (torch.flatten(img_repj - pixels) ** 2).sum() / 2

def mean_square_error(pose, points, intrinsics, distortions, pixels, camera_index, point_index):
  points = points[point_index, None] # [1000, 1, 3]
  pose = pose[camera_index] # [1000, 7]
  points = pose.unsqueeze(-2) @ points
  points = points.squeeze(-2)

  # perspective division
  points_proj = -points[:, :2] / points[:, -1:]

  # convert to pixel coordinates
  intrinsics = intrinsics[camera_index]
  distortions = distortions[camera_index]
  f = intrinsics[:, 0, 0]
  k1 = distortions[:, 0]
  k2 = distortions[:, 1]
  n = torch.sum(points_proj**2, dim=-1)
  r = 1.0 + k1 * n + k2 * n**2
  img_repj = f[:, None] * r[:, None] * points_proj

  # calculate the mean suqared error
  return (img_repj - pixels).norm(dim=-1).mean()

# Run optimization

In [8]:
from pypose.optim.solver import CG
input = [trimmed_dataset['camera_intrinsics'][trimmed_dataset['camera_index_of_observations']],
         trimmed_dataset['camera_distortions'][trimmed_dataset['camera_index_of_observations']],
         trimmed_dataset['points_2d'],
         trimmed_dataset['camera_index_of_observations'],
         trimmed_dataset['point_index_of_observations']]


model_non_batched = ReprojNonBatched(trimmed_dataset['camera_extrinsics'].clone(),
                                     trimmed_dataset['points_3d'].clone())

model_non_batched = model_non_batched.to(DEVICE)

strategy_sparse = pp.optim.strategy.Adaptive(damping=1e-6)
optimizer_sparse = pp.optim.LM(model_non_batched, strategy=strategy_sparse, solver=CG(tol=1e-5), reject=1, dense_A=False)

print('Starting loss:', least_square_error(model_non_batched.pose, model_non_batched.points_3d, *input).item())
for idx in range(30):
    loss = optimizer_sparse.step(input)
    print('Least square loss %.3f @ %d it'%(least_square_error(model_non_batched.pose, model_non_batched.points_3d, *input).item(), idx))
    if loss < 1e-5:
        print('Early Stopping with loss:', loss.item())
        break
print('Ending loss:', least_square_error(model_non_batched.pose, model_non_batched.points_3d, *input).item())

Starting loss: 21205098496.0
Least square loss 36421129732096.000 @ 0 it
Least square loss 62222017495040.000 @ 1 it
Least square loss 53383566196736.000 @ 2 it
Least square loss 12110813724672.000 @ 3 it
Least square loss 1642698375168.000 @ 4 it
Least square loss 3151167270000722999949393920.000 @ 5 it
CG failed to converge 
Linear solver failed. Breaking optimization step...


RuntimeError: CG failed to converge