In [1]:
from typing import List
import warnings
import torch
from torch.library import Library, impl


@torch.jit.script
def _bsr_diag(input, offset: int=0):
    crow_indices = input.crow_indices() # b + 1 dimensional
    col_indices = input.col_indices() # b + 1 dimensional
    bsr_values = input.values() # 1 + 2 dimensional
    m, n = input.shape[-2], input.shape[-1]
    dense_m, dense_n = (bsr_values.shape[-2],
                                 bsr_values.shape[-1])
    sparse_m, sparse_n = m // dense_m, n // dense_n

    #simple case(block is square and offset is 0)
    if dense_m == dense_n and offset == 0:
        dummy_val = torch.zeros(bsr_values.shape[0])
        dummy = torch.sparse_csr_tensor(crow_indices=crow_indices,
                                        col_indices=col_indices,
                                        values=dummy_val)
        dummy_coo = dummy.to_sparse(layout=torch.sparse_coo).coalesce()

        indices = dummy_coo.indices()
        diag_indices = indices[0] == indices[1]
        values = bsr_values[diag_indices]
        n_diag_blocks = sparse_m if sparse_m < sparse_n else sparse_n
        results_shape = (n_diag_blocks, dense_m)
        results = torch.zeros(results_shape, dtype=values.dtype, device=values.device)
        results[indices[0, diag_indices]] = torch.diagonal(values, dim1=-2, dim2=-1)
        results = torch.flatten(results)
        return results

def _sparse_csr_mm(mat1, mat2):
    if isinstance(mat1, torch.Tensor) and mat1.layout == torch.sparse_bsr:
        if isinstance(mat2, torch.Tensor) and mat2.layout == torch.sparse_bsc:
            return bsr_bsc_matmul(mat1, mat2)
    elif isinstance(mat1, torch.Tensor) and mat1.layout == torch.sparse_bsc:
        if isinstance(mat2, torch.Tensor) and mat2.layout == torch.sparse_bsr:
            raise NotImplemented
    #https://github.com/pytorch/pytorch/blob/3fa3ed4923c19a2b8d2da69e994169b4c8ac5fe3/
    #aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp#L789
    if mat1.is_sparse_csr and mat2.is_sparse_csr:
        return torch.addmm(
            torch.zeros([mat1.size(0), mat2.size(1)], dtype=mat2.dtype,
                device=mat2.device, layout=mat2.layout),
            mat1,
            mat2,
            beta=0.0,
            alpha=1.0)
    if (mat1.layout == torch.sparse_csc or mat1.layout == torch.sparse_csr) and\
        (mat2.layout == torch.sparse_csc or mat2.layout == torch.sparse_csr):
        return _sparse_csr_mm(mat1.to_sparse_csr(), mat2.to_sparse_csr())
    if mat1.layout == torch.sparse_csc and mat2.layout == torch.strided:
        return _sparse_csr_mm(mat1.to_sparse_csr(), mat2)
    if mat2.layout == torch.strided:
        return torch.addmm(
            torch.zeros([mat1.size(0), mat2.size(1)], dtype=mat1.dtype,
                device=mat1.device, layout=mat2.layout),
            mat1,
            mat2,
            beta=0.0,
            alpha=1.0)
    return torch.addmm(
        torch.zeros([mat1.size(0), mat2.size(1)], dtype=mat1.dtype, device=mat1.device,
        layout=mat1.layout),
        mat1,
        mat2,
        beta=0.0,
        alpha=1.0)

@torch.jit.script
def bsr_bsc_matmul(bsr:torch.Tensor, bsc:torch.Tensor):
    assert bsr.shape[-1] == bsc.shape[-2]
    assert bsr.layout == torch.sparse_bsr or bsr.layout == torch.sparse_csr
    assert bsc.layout == torch.sparse_bsc or bsc.layout == torch.sparse_csc
    crow_indices = bsr.crow_indices() # b + 1 dimensional
    col_indices = bsr.col_indices() # b + 1 dimensional
    csr_values = bsr.values() # 1 + 2 dimensional

    ccol_indices = bsc.ccol_indices() # b + 1 dimensional
    row_indices = bsc.row_indices() # b + 1 dimensional
    csc_values = bsc.values() # 1 + 2 dimensional

    idx_dtype = crow_indices.dtype

    assert bsr.ndim == 2 and bsc.ndim == 2, "bsr and bsc must be 2 dimensional. \
    batch dimension is yet not supported."
    m, n, p = bsr.shape[-2], bsr.shape[-1], bsc.shape[-1]
    dense_m, dense_n, dense_p = (csr_values.shape[-2],
                                 csr_values.shape[-1],
                                 csc_values.shape[-1])
    sparse_m, sparse_n, sparse_p = m // dense_m, n // dense_n, p // dense_p
    assert dense_m * sparse_m == m
    assert dense_n * sparse_n == n
    assert dense_p * sparse_p == p

    result_step: int = 0
    coo_indices: List[int] = list()
    index: List[int] = list()
    source: List[int] = list()
    for i in range(sparse_m):
        for j in range(sparse_p):
            nz: bool = False
            k2 = int(ccol_indices[j].item())
            for k1 in range(int(crow_indices[i].item()), int(crow_indices[i+1].item())):
                if k2 == ccol_indices[j+1]:
                    break
                while row_indices[k2] < col_indices[k1] and k2 < ccol_indices[j+1] - 1:
                    k2 += 1
                if row_indices[k2] == col_indices[k1]:
                    index.append(result_step)
                    source.append(k1)
                    source.append(k2)
                    nz = True
            if nz:
                result_step += 1
                coo_indices.append(i)
                coo_indices.append(j)
    source = torch.tensor(source, dtype=idx_dtype, device=bsr.device).view(-1, 2)
    index = torch.tensor(index, dtype=idx_dtype, device=bsr.device)
    prod = torch.bmm(csr_values[source[:, 0]], csc_values[source[:, 1]])
    values_shape = (result_step, dense_m, dense_p)
    reduced = torch.zeros(values_shape, dtype=prod.dtype, device=prod.device)
    reduced.scatter_add_(0, index.unsqueeze(-1).unsqueeze(-1).expand_as(prod), prod)
    coo_indices = torch.tensor(coo_indices, dtype=idx_dtype, device=bsr.device)
    coo_indices = coo_indices.view(-1, 2).T
    # use fake coo
    dummy_val = torch.zeros(coo_indices.shape[-1], dtype=prod.dtype, device=prod.device)
    dummy = torch.sparse_coo_tensor(indices=coo_indices,
                                    values=dummy_val,
                                    size=(sparse_m, sparse_p)).coalesce()
    dummy_csr = dummy.to_sparse_csr()
    return torch.sparse_bsr_tensor(dummy_csr.crow_indices(),
                                   dummy_csr.col_indices(),
                                   reduced,
                                   size=(m, p), dtype=reduced.dtype)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    sparse_lib = Library('aten', 'IMPL')
    sparse_lib.impl('mm', _sparse_csr_mm, 'SparseCsrCPU')
    sparse_lib.impl('mm', _sparse_csr_mm, 'SparseCsrCUDA')
    sparse_lib.impl('diagonal', _bsr_diag, 'SparseCsrCPU')
    sparse_lib.impl('diagonal', _bsr_diag, 'SparseCsrCUDA')

# Bundle Adjustment Example using SparsePyBA and the BAL dataset

```
The dataset is from the following paper:  
Sameer Agarwal, Noah Snavely, Steven M. Seitz, and Richard Szeliski.  
Bundle adjustment in the large.  
In European Conference on Computer Vision (ECCV), 2010.  
```

Link to the dataset: https://grail.cs.washington.edu/projects/bal/

# Fetch data

In [2]:
from bal_loader import build_pipeline

TARGET_DATASET = "ladybug"
TARGET_PROBLEM = "problem-49-7776-pre"
MAX_PIXELS = 31843

# TARGET_DATASET = "trafalgar"
# TARGET_PROBLEM = "problem-257-65132-pre"
# MAX_PIXELS = 225911

DEVICE = 'cpu' # change device to CPU if needed

def filter_problem(x):
  return x['problem_name'] == TARGET_PROBLEM

dataset_pipeline = build_pipeline(dataset=TARGET_DATASET, cache_dir='bal_data').filter(filter_problem)
dataset_iterator = iter(dataset_pipeline)
dataset = next(dataset_iterator)

print(f'Fetched {TARGET_PROBLEM} from {TARGET_DATASET}')

import torch
import torch.nn as nn
import pypose as pp

def trim_dataset(dataset, max_pixels):
  trimmed_dataset = dict()
  trimmed_dataset['points_2d'] = dataset['points_2d'][:max_pixels]
  trimmed_dataset['point_index_of_observations'] = dataset['point_index_of_observations'][:max_pixels]
  trimmed_dataset['camera_index_of_observations'] = dataset['camera_index_of_observations'][:max_pixels]
  # other fields are not changed
  trimmed_dataset['camera_extrinsics'] = dataset['camera_extrinsics']
  trimmed_dataset['camera_intrinsics'] = dataset['camera_intrinsics']
  trimmed_dataset['camera_distortions'] = dataset['camera_distortions']
  trimmed_dataset['points_3d'] = dataset['points_3d']

  for k in trimmed_dataset.keys():
    if not isinstance(trimmed_dataset[k], torch.Tensor):
      trimmed_dataset[k] = torch.from_numpy(trimmed_dataset[k])
    trimmed_dataset[k] = trimmed_dataset[k].to(DEVICE)
  return trimmed_dataset

trimmed_dataset = trim_dataset(dataset, max_pixels=MAX_PIXELS)

Streaming data for ladybug...
Fetched problem-49-7776-pre from ladybug


# Declare helper functions

In [3]:
def reprojerr(pose, points, intrinsics, distortions, pixels, camera_index, point_index):
  points = points[point_index, None] # [1000, 1, 3]
  pose = pose[camera_index] # [1000, 7]
  points = pose.unsqueeze(-2) @ points
  points = points.squeeze(-2)

  # perspective division
  points_proj = -points[:, :2] / points[:, -1:]

  # convert to pixel coordinates
  intrinsics = intrinsics[camera_index]
  distortions = distortions[camera_index]
  f = intrinsics[:, 0, 0]
  k1 = distortions[:, 0]
  k2 = distortions[:, 1]
  n = torch.sum(points_proj**2, dim=-1)
  r = 1.0 + k1 * n + k2 * n**2
  img_repj = f[:, None] * r[:, None] * points_proj

  # calculate the reprojection error
  loss = (img_repj - pixels)

  return loss
    
def reprojerr_vmap(pose, point, intrinsic, distortion, pixel):
  # reprojerr_vmap is not batched, it operates on a single 3D point and camera
  pose = pp.LieTensor(pose, ltype=pp.SE3_type) # pose will lose its ltype through vmap, temporary fix
  point = pose.unsqueeze(-2) @ point
  point = point.squeeze(-2)

  # perspective division
  point_proj = -point[:2] / point[-1:]

  # convert to pixel coordinates
  f = intrinsic[0, 0]
  k1 = distortion[0]
  k2 = distortion[1]
  n = torch.sum(point_proj**2, dim=-1)
  r = 1.0 + k1 * n + k2 * n**2
  img_repj = f * r * point_proj

  # calculate the reprojection error
  loss = (img_repj - pixel)

  return loss

def reprojerr_gen(*args):
    if args[4].shape[0] == MAX_PIXELS:
        return reprojerr(*args)
    else:
        args = list(args)
        return reprojerr_vmap(*args[:-2])

# sparse version
class ReprojNonBatched(nn.Module):
    def __init__(self, camera_extrinsics, points_3d):
        super().__init__()
        self.pose = nn.Parameter(camera_extrinsics)
        self.points_3d = nn.Parameter(points_3d)

    def forward(self, *args):
        return reprojerr_gen(self.pose, self.points_3d, *args)

def least_square_error(pose, points, intrinsics, distortions, pixels, camera_index, point_index):
  points = points[point_index, None] # [1000, 1, 3]
  pose = pose[camera_index] # [1000, 7]
  points = pose.unsqueeze(-2) @ points
  points = points.squeeze(-2)

  # perspective division
  points_proj = -points[:, :2] / points[:, -1:]

  # convert to pixel coordinates
  intrinsics = intrinsics[camera_index]
  distortions = distortions[camera_index]
  f = intrinsics[:, 0, 0]
  k1 = distortions[:, 0]
  k2 = distortions[:, 1]
  n = torch.sum(points_proj**2, dim=-1)
  r = 1.0 + k1 * n + k2 * n**2
  img_repj = f[:, None] * r[:, None] * points_proj

  # calculate the least square error
  return (torch.flatten(img_repj - pixels) ** 2).sum() / 2

def mean_square_error(pose, points, intrinsics, distortions, pixels, camera_index, point_index):
  points = points[point_index, None] # [1000, 1, 3]
  pose = pose[camera_index] # [1000, 7]
  points = pose.unsqueeze(-2) @ points
  points = points.squeeze(-2)

  # perspective division
  points_proj = -points[:, :2] / points[:, -1:]

  # convert to pixel coordinates
  intrinsics = intrinsics[camera_index]
  distortions = distortions[camera_index]
  f = intrinsics[:, 0, 0]
  k1 = distortions[:, 0]
  k2 = distortions[:, 1]
  n = torch.sum(points_proj**2, dim=-1)
  r = 1.0 + k1 * n + k2 * n**2
  img_repj = f[:, None] * r[:, None] * points_proj

  # calculate the mean suqared error
  return (img_repj - pixels).norm(dim=-1).mean()

# Run optimization

In [4]:
from pypose.optim.solver import CG
input = [trimmed_dataset['camera_intrinsics'][trimmed_dataset['camera_index_of_observations']],
         trimmed_dataset['camera_distortions'][trimmed_dataset['camera_index_of_observations']],
         trimmed_dataset['points_2d'],
         trimmed_dataset['camera_index_of_observations'],
         trimmed_dataset['point_index_of_observations']]


model_non_batched = ReprojNonBatched(trimmed_dataset['camera_extrinsics'].clone(),
                                     trimmed_dataset['points_3d'].clone())

model_non_batched = model_non_batched.to(DEVICE)

strategy_sparse = pp.optim.strategy.Adaptive(damping=1e-6)
optimizer_sparse = pp.optim.LM(model_non_batched, strategy=strategy_sparse, solver=CG(tol=1e-2), reject=1, dense_A=False, scipy=False)

print('Starting loss:', mean_square_error(model_non_batched.pose, model_non_batched.points_3d, *input).item())
for idx in range(30):
    loss = optimizer_sparse.step(input)
    print('Least square loss %.3f @ %d it'%(mean_square_error(model_non_batched.pose, model_non_batched.points_3d, *input).item(), idx))
    if loss < 1e-5:
        print('Early Stopping with loss:', loss.item())
        break
print('Ending loss:', mean_square_error(model_non_batched.pose, model_non_batched.points_3d, *input).item())

Starting loss: 5.668307304382324


  dummy_csc = dummy_coo.coalesce().to_sparse_csc()


addmm: computation on CPU is not implemented for Strided + SparseBsr @ Strided without MKL. PyTorch built with MKL has better support for addmm with sparse CPU tensors. 
Linear solver failed. Breaking optimization step...


RuntimeError: addmm: computation on CPU is not implemented for Strided + SparseBsr @ Strided without MKL. PyTorch built with MKL has better support for addmm with sparse CPU tensors.