Skip to content

Commit 5968770

Browse files
committed
Initial upstreaming of Rust ADBench implementation for
ba,fft,gmm,lstm,ode-real
1 parent 6159a44 commit 5968770

File tree

15 files changed

+1716
-0
lines changed

15 files changed

+1716
-0
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat --crate-type=staticlib
2+
//@ build-pass
3+
//@ no-prefer-dynamic
4+
//@ needs-enzyme
5+
#![feature(autodiff)]
6+
#![allow(non_snake_case)]
7+
8+
use std::autodiff::autodiff_reverse;
9+
pub mod safe;
10+
pub mod r#unsafe;
11+
12+
static BA_NCAMPARAMS: usize = 11;
13+
14+
#[no_mangle]
15+
pub extern "C" fn rust_dcompute_zach_weight_error(
16+
w: *const f64,
17+
dw: *mut f64,
18+
err: *mut f64,
19+
derr: *mut f64,
20+
) {
21+
dcompute_zach_weight_error(w, dw, err, derr);
22+
}
23+
24+
#[autodiff_reverse(dcompute_zach_weight_error, Duplicated, Duplicated)]
25+
pub fn compute_zach_weight_error(w: *const f64, err: *mut f64) {
26+
let w = unsafe { *w };
27+
unsafe { *err = 1. - w * w; }
28+
}
29+
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
use crate::BA_NCAMPARAMS;
2+
use crate::compute_zach_weight_error;
3+
use std::autodiff::autodiff_reverse;
4+
use std::convert::TryInto;
5+
6+
fn sqsum(x: &[f64]) -> f64 {
7+
x.iter().map(|&v| v * v).sum()
8+
}
9+
10+
#[inline]
11+
fn cross(a: &[f64; 3], b: &[f64; 3]) -> [f64; 3] {
12+
[
13+
a[1] * b[2] - a[2] * b[1],
14+
a[2] * b[0] - a[0] * b[2],
15+
a[0] * b[1] - a[1] * b[0],
16+
]
17+
}
18+
19+
fn radial_distort(rad_params: &[f64], proj: &mut [f64]) {
20+
let rsq = sqsum(proj);
21+
let l = 1. + rad_params[0] * rsq + rad_params[1] * rsq * rsq;
22+
proj[0] = proj[0] * l;
23+
proj[1] = proj[1] * l;
24+
}
25+
26+
fn rodrigues_rotate_point(rot: &[f64; 3], pt: &[f64; 3], rotated_pt: &mut [f64; 3]) {
27+
let sqtheta = sqsum(rot);
28+
if sqtheta != 0. {
29+
let theta = sqtheta.sqrt();
30+
let costheta = theta.cos();
31+
let sintheta = theta.sin();
32+
let theta_inverse = 1. / theta;
33+
let mut w = [0.; 3];
34+
for i in 0..3 {
35+
w[i] = rot[i] * theta_inverse;
36+
}
37+
let w_cross_pt = cross(&w, &pt);
38+
let tmp = (w[0] * pt[0] + w[1] * pt[1] + w[2] * pt[2]) * (1. - costheta);
39+
for i in 0..3 {
40+
rotated_pt[i] = pt[i] * costheta + w_cross_pt[i] * sintheta + w[i] * tmp;
41+
}
42+
} else {
43+
let rot_cross_pt = cross(&rot, &pt);
44+
for i in 0..3 {
45+
rotated_pt[i] = pt[i] + rot_cross_pt[i];
46+
}
47+
}
48+
}
49+
50+
fn project(cam: &[f64; 11], X: &[f64; 3], proj: &mut [f64; 2]) {
51+
let C = &cam[3..6];
52+
let mut Xo = [0.; 3];
53+
let mut Xcam = [0.; 3];
54+
55+
Xo[0] = X[0] - C[0];
56+
Xo[1] = X[1] - C[1];
57+
Xo[2] = X[2] - C[2];
58+
59+
rodrigues_rotate_point(cam.first_chunk::<3>().unwrap(), &Xo, &mut Xcam);
60+
61+
proj[0] = Xcam[0] / Xcam[2];
62+
proj[1] = Xcam[1] / Xcam[2];
63+
64+
radial_distort(&cam[9..], proj);
65+
66+
proj[0] = proj[0] * cam[6] + cam[7];
67+
proj[1] = proj[1] * cam[6] + cam[8];
68+
}
69+
70+
#[no_mangle]
71+
pub extern "C" fn rust_dcompute_reproj_error(
72+
cam: *const [f64; 11],
73+
dcam: *mut [f64; 11],
74+
x: *const [f64; 3],
75+
dx: *mut [f64; 3],
76+
w: *const [f64; 1],
77+
wb: *mut [f64; 1],
78+
feat: *const [f64; 2],
79+
err: *mut [f64; 2],
80+
derr: *mut [f64; 2],
81+
) {
82+
unsafe {dcompute_reproj_error(cam, dcam, x, dx, w, wb, feat, err, derr)};
83+
}
84+
85+
#[autodiff_reverse(
86+
dcompute_reproj_error,
87+
Duplicated,
88+
Duplicated,
89+
Duplicated,
90+
Const,
91+
DuplicatedOnly
92+
)]
93+
pub fn compute_reproj_error(
94+
cam: *const [f64; 11],
95+
x: *const [f64; 3],
96+
w: *const [f64; 1],
97+
feat: *const [f64; 2],
98+
err: *mut [f64; 2],
99+
) {
100+
let cam = unsafe { &*cam };
101+
let w = unsafe { *(*w).get_unchecked(0) };
102+
let x = unsafe { &*x };
103+
let feat = unsafe { &*feat };
104+
let err = unsafe { &mut *err };
105+
let mut proj = [0.; 2];
106+
project(cam, x, &mut proj);
107+
err[0] = w * (proj[0] - feat[0]);
108+
err[1] = w * (proj[1] - feat[1]);
109+
}
110+
111+
// n number of cameras
112+
// m number of points
113+
// p number of observations
114+
// cams: 11*n cameras in format [r1 r2 r3 C1 C2 C3 f u0 v0 k1 k2]
115+
// r1, r2, r3 are angle - axis rotation parameters(Rodrigues)
116+
// [C1 C2 C3]' is the camera center
117+
// f is the focal length in pixels
118+
// [u0 v0]' is the principal point
119+
// k1, k2 are radial distortion parameters
120+
// X: 3*m points
121+
// obs: 2*p observations (pairs cameraIdx, pointIdx)
122+
// feats: 2*p features (x,y coordinates corresponding to observations)
123+
// reproj_err: 2*p errors of observations
124+
// w_err: p weight "error" terms
125+
fn rust_ba_objective(
126+
n: usize,
127+
m: usize,
128+
p: usize,
129+
cams: &[f64],
130+
x: &[f64],
131+
w: &[f64],
132+
obs: &[i32],
133+
feats: &[f64],
134+
reproj_err: &mut [f64],
135+
w_err: &mut [f64],
136+
) {
137+
assert_eq!(cams.len(), n * 11);
138+
assert_eq!(x.len(), m * 3);
139+
assert_eq!(w.len(), p);
140+
assert_eq!(obs.len(), p * 2);
141+
assert_eq!(feats.len(), p * 2);
142+
assert_eq!(reproj_err.len(), p * 2);
143+
assert_eq!(w_err.len(), p);
144+
145+
for i in 0..p {
146+
let cam_idx = obs[i * 2 + 0] as usize;
147+
let pt_idx = obs[i * 2 + 1] as usize;
148+
let start = cam_idx * BA_NCAMPARAMS;
149+
let cam: &[f64; 11] = unsafe {
150+
cams[start..]
151+
.get_unchecked(..11)
152+
.try_into()
153+
.unwrap_unchecked()
154+
};
155+
let x: &[f64; 3] = unsafe {
156+
x[pt_idx * 3..]
157+
.get_unchecked(..3)
158+
.try_into()
159+
.unwrap_unchecked()
160+
};
161+
let w: &[f64; 1] = unsafe { w[i..].get_unchecked(..1).try_into().unwrap_unchecked() };
162+
let feat: &[f64; 2] = unsafe {
163+
feats[i * 2..]
164+
.get_unchecked(..2)
165+
.try_into()
166+
.unwrap_unchecked()
167+
};
168+
let reproj_err: &mut [f64; 2] = unsafe {
169+
reproj_err[i * 2..]
170+
.get_unchecked_mut(..2)
171+
.try_into()
172+
.unwrap_unchecked()
173+
};
174+
compute_reproj_error(cam, x, w, feat, reproj_err);
175+
}
176+
177+
for i in 0..p {
178+
let w_err: &mut f64 = unsafe { w_err.get_unchecked_mut(i) };
179+
compute_zach_weight_error(w[i..].as_ptr(), w_err as *mut f64);
180+
}
181+
}
182+
183+
#[no_mangle]
184+
extern "C" fn rust2_ba_objective(
185+
n: i32,
186+
m: i32,
187+
p: i32,
188+
cams: *const f64,
189+
x: *const f64,
190+
w: *const f64,
191+
obs: *const i32,
192+
feats: *const f64,
193+
reproj_err: *mut f64,
194+
w_err: *mut f64,
195+
) {
196+
let n = n as usize;
197+
let m = m as usize;
198+
let p = p as usize;
199+
let cams = unsafe { std::slice::from_raw_parts(cams, n * 11) };
200+
let x = unsafe { std::slice::from_raw_parts(x, m * 3) };
201+
let w = unsafe { std::slice::from_raw_parts(w, p) };
202+
let obs = unsafe { std::slice::from_raw_parts(obs, p * 2) };
203+
let feats = unsafe { std::slice::from_raw_parts(feats, p * 2) };
204+
let reproj_err = unsafe { std::slice::from_raw_parts_mut(reproj_err, p * 2) };
205+
let w_err = unsafe { std::slice::from_raw_parts_mut(w_err, p) };
206+
rust_ba_objective(n, m, p, cams, x, w, obs, feats, reproj_err, w_err);
207+
}
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
use crate::BA_NCAMPARAMS;
2+
use crate::compute_zach_weight_error;
3+
use std::autodiff::autodiff_reverse;
4+
use std::convert::TryInto;
5+
6+
unsafe fn sqsum(x: *const f64, n: usize) -> f64 {
7+
let mut sum = 0.;
8+
for i in 0..n {
9+
let v = unsafe { *x.add(i) };
10+
sum += v * v;
11+
}
12+
sum
13+
}
14+
15+
#[inline]
16+
unsafe fn cross(a: *const f64, b: *const f64, out: *mut f64) {
17+
*out.add(0) = *a.add(1) * *b.add(2) - *a.add(2) * *b.add(1);
18+
*out.add(1) = *a.add(2) * *b.add(0) - *a.add(0) * *b.add(2);
19+
*out.add(2) = *a.add(0) * *b.add(1) - *a.add(1) * *b.add(0);
20+
}
21+
22+
unsafe fn radial_distort(rad_params: *const f64, proj: *mut f64) {
23+
let rsq = sqsum(proj, 2);
24+
let l = 1. + *rad_params.add(0) * rsq + *rad_params.add(1) * rsq * rsq;
25+
*proj.add(0) = *proj.add(0) * l;
26+
*proj.add(1) = *proj.add(1) * l;
27+
}
28+
29+
unsafe fn rodrigues_rotate_point(rot: *const f64, pt: *const f64, rotated_pt: *mut f64) {
30+
let sqtheta = sqsum(rot, 3);
31+
if sqtheta != 0. {
32+
let theta = sqtheta.sqrt();
33+
let costheta = theta.cos();
34+
let sintheta = theta.sin();
35+
let theta_inverse = 1. / theta;
36+
let mut w = [0.; 3];
37+
for i in 0..3 {
38+
w[i] = *rot.add(i) * theta_inverse;
39+
}
40+
let mut w_cross_pt = [0.; 3];
41+
cross(w.as_ptr(), pt, w_cross_pt.as_mut_ptr());
42+
let tmp = (w[0] * *pt.add(0) + w[1] * *pt.add(1) + w[2] * *pt.add(2)) * (1. - costheta);
43+
for i in 0..3 {
44+
*rotated_pt.add(i) = *pt.add(i) * costheta + w_cross_pt[i] * sintheta + w[i] * tmp;
45+
}
46+
} else {
47+
let mut rot_cross_pt = [0.; 3];
48+
cross(rot, pt, rot_cross_pt.as_mut_ptr());
49+
for i in 0..3 {
50+
*rotated_pt.add(i) = *pt.add(i) + rot_cross_pt[i];
51+
}
52+
}
53+
}
54+
55+
unsafe fn project(cam: *const f64, X: *const f64, proj: *mut f64) {
56+
let C = cam.add(3);
57+
let mut Xo = [0.; 3];
58+
let mut Xcam = [0.; 3];
59+
60+
Xo[0] = *X.add(0) - *C.add(0);
61+
Xo[1] = *X.add(1) - *C.add(1);
62+
Xo[2] = *X.add(2) - *C.add(2);
63+
64+
rodrigues_rotate_point(cam, Xo.as_ptr(), Xcam.as_mut_ptr());
65+
66+
*proj.add(0) = Xcam[0] / Xcam[2];
67+
*proj.add(1) = Xcam[1] / Xcam[2];
68+
69+
radial_distort(cam.add(9), proj);
70+
*proj.add(0) = *proj.add(0) * *cam.add(6) + *cam.add(7);
71+
*proj.add(1) = *proj.add(1) * *cam.add(6) + *cam.add(8);
72+
}
73+
74+
#[no_mangle]
75+
pub unsafe extern "C" fn rust_unsafe_dcompute_reproj_error(
76+
cam: *const f64,
77+
dcam: *mut f64,
78+
x: *const f64,
79+
dx: *mut f64,
80+
w: *const f64,
81+
wb: *mut f64,
82+
feat: *const f64,
83+
err: *mut f64,
84+
derr: *mut f64,
85+
) {
86+
unsafe {dcompute_reproj_error(cam, dcam, x, dx, w, wb, feat, err, derr)};
87+
}
88+
89+
90+
#[autodiff_reverse(
91+
dcompute_reproj_error,
92+
Duplicated,
93+
Duplicated,
94+
Duplicated,
95+
Const,
96+
DuplicatedOnly
97+
)]
98+
pub unsafe fn compute_reproj_error(
99+
cam: *const f64,
100+
x: *const f64,
101+
w: *const f64,
102+
feat: *const f64,
103+
err: *mut f64,
104+
) {
105+
let mut proj = [0.; 2];
106+
project(cam, x, proj.as_mut_ptr());
107+
*err.add(0) = *w * (proj[0] - *feat.add(0));
108+
*err.add(1) = *w * (proj[1] - *feat.add(1));
109+
}
110+
111+
#[no_mangle]
112+
unsafe extern "C" fn rust2_unsafe_ba_objective(
113+
n: i32,
114+
m: i32,
115+
p: i32,
116+
cams: *const f64,
117+
x: *const f64,
118+
w: *const f64,
119+
obs: *const i32,
120+
feats: *const f64,
121+
reproj_err: *mut f64,
122+
w_err: *mut f64,
123+
) {
124+
let n = n as usize;
125+
let m = m as usize;
126+
let p = p as usize;
127+
for i in 0..p {
128+
let cam_idx = *obs.add(i * 2 + 0) as usize;
129+
let pt_idx = *obs.add(i * 2 + 1) as usize;
130+
let start = cam_idx * BA_NCAMPARAMS;
131+
132+
let cam: *const f64 = cams.add(start);
133+
let x: *const f64 = x.add(pt_idx * 3);
134+
let w: *const f64 = w.add(i);
135+
let feat: *const f64 = feats.add(i * 2);
136+
let reproj_err: *mut f64 = reproj_err.add(i * 2);
137+
compute_reproj_error(cam, x, w, feat, reproj_err);
138+
}
139+
140+
for i in 0..p {
141+
compute_zach_weight_error(w.add(i), w_err.add(i));
142+
}
143+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat --crate-type=staticlib
2+
//@ build-pass
3+
//@ no-prefer-dynamic
4+
//@ needs-enzyme
5+
#![feature(slice_swap_unchecked)]
6+
#![feature(autodiff)]
7+
8+
pub mod safe;
9+
pub mod unsf;

0 commit comments

Comments
 (0)