Skip to content

Commit 1fb0a0b

Browse files
committed
Initial upstreaming of Rust ADBench implementations
1 parent 6159a44 commit 1fb0a0b

File tree

15 files changed

+2004
-0
lines changed

15 files changed

+2004
-0
lines changed
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat --crate-type=staticlib
2+
//@ build-pass
3+
//@ no-prefer-dynamic
4+
//@ needs-enzyme
5+
#![feature(autodiff)]
6+
#![allow(non_snake_case)]
7+
8+
use std::autodiff::autodiff_reverse;
9+
pub mod safe;
10+
pub mod r#unsafe;
11+
12+
static BA_NCAMPARAMS: usize = 11;
13+
14+
#[no_mangle]
15+
pub extern "C" fn rust_dcompute_zach_weight_error(
16+
w: *const f64,
17+
dw: *mut f64,
18+
err: *mut f64,
19+
derr: *mut f64,
20+
) {
21+
dcompute_zach_weight_error(w, dw, err, derr);
22+
}
23+
24+
#[autodiff_reverse(dcompute_zach_weight_error, Duplicated, Duplicated)]
25+
pub fn compute_zach_weight_error(w: *const f64, err: *mut f64) {
26+
let w = unsafe { *w };
27+
unsafe {
28+
*err = 1. - w * w;
29+
}
30+
}
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
//@ ignore-auxiliary lib.rs
2+
use crate::compute_zach_weight_error;
3+
use crate::BA_NCAMPARAMS;
4+
use std::autodiff::autodiff_reverse;
5+
use std::convert::TryInto;
6+
7+
fn sqsum(x: &[f64]) -> f64 {
8+
x.iter().map(|&v| v * v).sum()
9+
}
10+
11+
#[inline]
12+
fn cross(a: &[f64; 3], b: &[f64; 3]) -> [f64; 3] {
13+
[
14+
a[1] * b[2] - a[2] * b[1],
15+
a[2] * b[0] - a[0] * b[2],
16+
a[0] * b[1] - a[1] * b[0],
17+
]
18+
}
19+
20+
fn radial_distort(rad_params: &[f64], proj: &mut [f64]) {
21+
let rsq = sqsum(proj);
22+
let l = 1. + rad_params[0] * rsq + rad_params[1] * rsq * rsq;
23+
proj[0] = proj[0] * l;
24+
proj[1] = proj[1] * l;
25+
}
26+
27+
fn rodrigues_rotate_point(rot: &[f64; 3], pt: &[f64; 3], rotated_pt: &mut [f64; 3]) {
28+
let sqtheta = sqsum(rot);
29+
if sqtheta != 0. {
30+
let theta = sqtheta.sqrt();
31+
let costheta = theta.cos();
32+
let sintheta = theta.sin();
33+
let theta_inverse = 1. / theta;
34+
let mut w = [0.; 3];
35+
for i in 0..3 {
36+
w[i] = rot[i] * theta_inverse;
37+
}
38+
let w_cross_pt = cross(&w, &pt);
39+
let tmp = (w[0] * pt[0] + w[1] * pt[1] + w[2] * pt[2]) * (1. - costheta);
40+
for i in 0..3 {
41+
rotated_pt[i] = pt[i] * costheta + w_cross_pt[i] * sintheta + w[i] * tmp;
42+
}
43+
} else {
44+
let rot_cross_pt = cross(&rot, &pt);
45+
for i in 0..3 {
46+
rotated_pt[i] = pt[i] + rot_cross_pt[i];
47+
}
48+
}
49+
}
50+
51+
fn project(cam: &[f64; 11], X: &[f64; 3], proj: &mut [f64; 2]) {
52+
let C = &cam[3..6];
53+
let mut Xo = [0.; 3];
54+
let mut Xcam = [0.; 3];
55+
56+
Xo[0] = X[0] - C[0];
57+
Xo[1] = X[1] - C[1];
58+
Xo[2] = X[2] - C[2];
59+
60+
rodrigues_rotate_point(cam.first_chunk::<3>().unwrap(), &Xo, &mut Xcam);
61+
62+
proj[0] = Xcam[0] / Xcam[2];
63+
proj[1] = Xcam[1] / Xcam[2];
64+
65+
radial_distort(&cam[9..], proj);
66+
67+
proj[0] = proj[0] * cam[6] + cam[7];
68+
proj[1] = proj[1] * cam[6] + cam[8];
69+
}
70+
71+
#[no_mangle]
72+
pub extern "C" fn rust_dcompute_reproj_error(
73+
cam: *const [f64; 11],
74+
dcam: *mut [f64; 11],
75+
x: *const [f64; 3],
76+
dx: *mut [f64; 3],
77+
w: *const [f64; 1],
78+
wb: *mut [f64; 1],
79+
feat: *const [f64; 2],
80+
err: *mut [f64; 2],
81+
derr: *mut [f64; 2],
82+
) {
83+
unsafe { dcompute_reproj_error(cam, dcam, x, dx, w, wb, feat, err, derr) };
84+
}
85+
86+
#[autodiff_reverse(
87+
dcompute_reproj_error,
88+
Duplicated,
89+
Duplicated,
90+
Duplicated,
91+
Const,
92+
DuplicatedOnly
93+
)]
94+
pub fn compute_reproj_error(
95+
cam: *const [f64; 11],
96+
x: *const [f64; 3],
97+
w: *const [f64; 1],
98+
feat: *const [f64; 2],
99+
err: *mut [f64; 2],
100+
) {
101+
let cam = unsafe { &*cam };
102+
let w = unsafe { *(*w).get_unchecked(0) };
103+
let x = unsafe { &*x };
104+
let feat = unsafe { &*feat };
105+
let err = unsafe { &mut *err };
106+
let mut proj = [0.; 2];
107+
project(cam, x, &mut proj);
108+
err[0] = w * (proj[0] - feat[0]);
109+
err[1] = w * (proj[1] - feat[1]);
110+
}
111+
112+
// n number of cameras
113+
// m number of points
114+
// p number of observations
115+
// cams: 11*n cameras in format [r1 r2 r3 C1 C2 C3 f u0 v0 k1 k2]
116+
// r1, r2, r3 are angle - axis rotation parameters(Rodrigues)
117+
// [C1 C2 C3]' is the camera center
118+
// f is the focal length in pixels
119+
// [u0 v0]' is the principal point
120+
// k1, k2 are radial distortion parameters
121+
// X: 3*m points
122+
// obs: 2*p observations (pairs cameraIdx, pointIdx)
123+
// feats: 2*p features (x,y coordinates corresponding to observations)
124+
// reproj_err: 2*p errors of observations
125+
// w_err: p weight "error" terms
126+
fn rust_ba_objective(
127+
n: usize,
128+
m: usize,
129+
p: usize,
130+
cams: &[f64],
131+
x: &[f64],
132+
w: &[f64],
133+
obs: &[i32],
134+
feats: &[f64],
135+
reproj_err: &mut [f64],
136+
w_err: &mut [f64],
137+
) {
138+
assert_eq!(cams.len(), n * 11);
139+
assert_eq!(x.len(), m * 3);
140+
assert_eq!(w.len(), p);
141+
assert_eq!(obs.len(), p * 2);
142+
assert_eq!(feats.len(), p * 2);
143+
assert_eq!(reproj_err.len(), p * 2);
144+
assert_eq!(w_err.len(), p);
145+
146+
for i in 0..p {
147+
let cam_idx = obs[i * 2 + 0] as usize;
148+
let pt_idx = obs[i * 2 + 1] as usize;
149+
let start = cam_idx * BA_NCAMPARAMS;
150+
let cam: &[f64; 11] = unsafe {
151+
cams[start..]
152+
.get_unchecked(..11)
153+
.try_into()
154+
.unwrap_unchecked()
155+
};
156+
let x: &[f64; 3] = unsafe {
157+
x[pt_idx * 3..]
158+
.get_unchecked(..3)
159+
.try_into()
160+
.unwrap_unchecked()
161+
};
162+
let w: &[f64; 1] = unsafe { w[i..].get_unchecked(..1).try_into().unwrap_unchecked() };
163+
let feat: &[f64; 2] = unsafe {
164+
feats[i * 2..]
165+
.get_unchecked(..2)
166+
.try_into()
167+
.unwrap_unchecked()
168+
};
169+
let reproj_err: &mut [f64; 2] = unsafe {
170+
reproj_err[i * 2..]
171+
.get_unchecked_mut(..2)
172+
.try_into()
173+
.unwrap_unchecked()
174+
};
175+
compute_reproj_error(cam, x, w, feat, reproj_err);
176+
}
177+
178+
for i in 0..p {
179+
let w_err: &mut f64 = unsafe { w_err.get_unchecked_mut(i) };
180+
compute_zach_weight_error(w[i..].as_ptr(), w_err as *mut f64);
181+
}
182+
}
183+
184+
#[no_mangle]
185+
extern "C" fn rust2_ba_objective(
186+
n: i32,
187+
m: i32,
188+
p: i32,
189+
cams: *const f64,
190+
x: *const f64,
191+
w: *const f64,
192+
obs: *const i32,
193+
feats: *const f64,
194+
reproj_err: *mut f64,
195+
w_err: *mut f64,
196+
) {
197+
let n = n as usize;
198+
let m = m as usize;
199+
let p = p as usize;
200+
let cams = unsafe { std::slice::from_raw_parts(cams, n * 11) };
201+
let x = unsafe { std::slice::from_raw_parts(x, m * 3) };
202+
let w = unsafe { std::slice::from_raw_parts(w, p) };
203+
let obs = unsafe { std::slice::from_raw_parts(obs, p * 2) };
204+
let feats = unsafe { std::slice::from_raw_parts(feats, p * 2) };
205+
let reproj_err = unsafe { std::slice::from_raw_parts_mut(reproj_err, p * 2) };
206+
let w_err = unsafe { std::slice::from_raw_parts_mut(w_err, p) };
207+
rust_ba_objective(n, m, p, cams, x, w, obs, feats, reproj_err, w_err);
208+
}
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
//@ ignore-auxiliary lib.rs
2+
use crate::compute_zach_weight_error;
3+
use crate::BA_NCAMPARAMS;
4+
use std::autodiff::autodiff_reverse;
5+
use std::convert::TryInto;
6+
7+
unsafe fn sqsum(x: *const f64, n: usize) -> f64 {
8+
let mut sum = 0.;
9+
for i in 0..n {
10+
let v = unsafe { *x.add(i) };
11+
sum += v * v;
12+
}
13+
sum
14+
}
15+
16+
#[inline]
17+
unsafe fn cross(a: *const f64, b: *const f64, out: *mut f64) {
18+
*out.add(0) = *a.add(1) * *b.add(2) - *a.add(2) * *b.add(1);
19+
*out.add(1) = *a.add(2) * *b.add(0) - *a.add(0) * *b.add(2);
20+
*out.add(2) = *a.add(0) * *b.add(1) - *a.add(1) * *b.add(0);
21+
}
22+
23+
unsafe fn radial_distort(rad_params: *const f64, proj: *mut f64) {
24+
let rsq = sqsum(proj, 2);
25+
let l = 1. + *rad_params.add(0) * rsq + *rad_params.add(1) * rsq * rsq;
26+
*proj.add(0) = *proj.add(0) * l;
27+
*proj.add(1) = *proj.add(1) * l;
28+
}
29+
30+
unsafe fn rodrigues_rotate_point(rot: *const f64, pt: *const f64, rotated_pt: *mut f64) {
31+
let sqtheta = sqsum(rot, 3);
32+
if sqtheta != 0. {
33+
let theta = sqtheta.sqrt();
34+
let costheta = theta.cos();
35+
let sintheta = theta.sin();
36+
let theta_inverse = 1. / theta;
37+
let mut w = [0.; 3];
38+
for i in 0..3 {
39+
w[i] = *rot.add(i) * theta_inverse;
40+
}
41+
let mut w_cross_pt = [0.; 3];
42+
cross(w.as_ptr(), pt, w_cross_pt.as_mut_ptr());
43+
let tmp = (w[0] * *pt.add(0) + w[1] * *pt.add(1) + w[2] * *pt.add(2)) * (1. - costheta);
44+
for i in 0..3 {
45+
*rotated_pt.add(i) = *pt.add(i) * costheta + w_cross_pt[i] * sintheta + w[i] * tmp;
46+
}
47+
} else {
48+
let mut rot_cross_pt = [0.; 3];
49+
cross(rot, pt, rot_cross_pt.as_mut_ptr());
50+
for i in 0..3 {
51+
*rotated_pt.add(i) = *pt.add(i) + rot_cross_pt[i];
52+
}
53+
}
54+
}
55+
56+
unsafe fn project(cam: *const f64, X: *const f64, proj: *mut f64) {
57+
let C = cam.add(3);
58+
let mut Xo = [0.; 3];
59+
let mut Xcam = [0.; 3];
60+
61+
Xo[0] = *X.add(0) - *C.add(0);
62+
Xo[1] = *X.add(1) - *C.add(1);
63+
Xo[2] = *X.add(2) - *C.add(2);
64+
65+
rodrigues_rotate_point(cam, Xo.as_ptr(), Xcam.as_mut_ptr());
66+
67+
*proj.add(0) = Xcam[0] / Xcam[2];
68+
*proj.add(1) = Xcam[1] / Xcam[2];
69+
70+
radial_distort(cam.add(9), proj);
71+
*proj.add(0) = *proj.add(0) * *cam.add(6) + *cam.add(7);
72+
*proj.add(1) = *proj.add(1) * *cam.add(6) + *cam.add(8);
73+
}
74+
75+
#[no_mangle]
76+
pub unsafe extern "C" fn rust_unsafe_dcompute_reproj_error(
77+
cam: *const f64,
78+
dcam: *mut f64,
79+
x: *const f64,
80+
dx: *mut f64,
81+
w: *const f64,
82+
wb: *mut f64,
83+
feat: *const f64,
84+
err: *mut f64,
85+
derr: *mut f64,
86+
) {
87+
unsafe { dcompute_reproj_error(cam, dcam, x, dx, w, wb, feat, err, derr) };
88+
}
89+
90+
#[autodiff_reverse(
91+
dcompute_reproj_error,
92+
Duplicated,
93+
Duplicated,
94+
Duplicated,
95+
Const,
96+
DuplicatedOnly
97+
)]
98+
pub unsafe fn compute_reproj_error(
99+
cam: *const f64,
100+
x: *const f64,
101+
w: *const f64,
102+
feat: *const f64,
103+
err: *mut f64,
104+
) {
105+
let mut proj = [0.; 2];
106+
project(cam, x, proj.as_mut_ptr());
107+
*err.add(0) = *w * (proj[0] - *feat.add(0));
108+
*err.add(1) = *w * (proj[1] - *feat.add(1));
109+
}
110+
111+
#[no_mangle]
112+
unsafe extern "C" fn rust2_unsafe_ba_objective(
113+
n: i32,
114+
m: i32,
115+
p: i32,
116+
cams: *const f64,
117+
x: *const f64,
118+
w: *const f64,
119+
obs: *const i32,
120+
feats: *const f64,
121+
reproj_err: *mut f64,
122+
w_err: *mut f64,
123+
) {
124+
let n = n as usize;
125+
let m = m as usize;
126+
let p = p as usize;
127+
for i in 0..p {
128+
let cam_idx = *obs.add(i * 2 + 0) as usize;
129+
let pt_idx = *obs.add(i * 2 + 1) as usize;
130+
let start = cam_idx * BA_NCAMPARAMS;
131+
132+
let cam: *const f64 = cams.add(start);
133+
let x: *const f64 = x.add(pt_idx * 3);
134+
let w: *const f64 = w.add(i);
135+
let feat: *const f64 = feats.add(i * 2);
136+
let reproj_err: *mut f64 = reproj_err.add(i * 2);
137+
compute_reproj_error(cam, x, w, feat, reproj_err);
138+
}
139+
140+
for i in 0..p {
141+
compute_zach_weight_error(w.add(i), w_err.add(i));
142+
}
143+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat --crate-type=staticlib
2+
//@ build-pass
3+
//@ no-prefer-dynamic
4+
//@ needs-enzyme
5+
#![feature(slice_swap_unchecked)]
6+
#![feature(autodiff)]
7+
8+
pub mod safe;
9+
pub mod unsf;

0 commit comments

Comments
 (0)