diff --git a/tests/ui/autodiff/adbench/ba/src/lib.rs b/tests/ui/autodiff/adbench/ba/src/lib.rs new file mode 100644 index 0000000000000..e6d031767f343 --- /dev/null +++ b/tests/ui/autodiff/adbench/ba/src/lib.rs @@ -0,0 +1,29 @@ +//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat --crate-type=staticlib +//@ build-pass +//@ no-prefer-dynamic +//@ needs-enzyme +#![feature(autodiff)] +#![allow(non_snake_case)] + +use std::autodiff::autodiff_reverse; +pub mod safe; +pub mod r#unsafe; + +static BA_NCAMPARAMS: usize = 11; + +#[no_mangle] +pub extern "C" fn rust_dcompute_zach_weight_error( + w: *const f64, + dw: *mut f64, + err: *mut f64, + derr: *mut f64, +) { + dcompute_zach_weight_error(w, dw, err, derr); +} + +#[autodiff_reverse(dcompute_zach_weight_error, Duplicated, Duplicated)] +pub fn compute_zach_weight_error(w: *const f64, err: *mut f64) { + let w = unsafe { *w }; + unsafe { *err = 1. - w * w; } +} + diff --git a/tests/ui/autodiff/adbench/ba/src/safe.rs b/tests/ui/autodiff/adbench/ba/src/safe.rs new file mode 100644 index 0000000000000..5bcecd202b2f9 --- /dev/null +++ b/tests/ui/autodiff/adbench/ba/src/safe.rs @@ -0,0 +1,207 @@ +use crate::BA_NCAMPARAMS; +use crate::compute_zach_weight_error; +use std::autodiff::autodiff_reverse; +use std::convert::TryInto; + +fn sqsum(x: &[f64]) -> f64 { + x.iter().map(|&v| v * v).sum() +} + +#[inline] +fn cross(a: &[f64; 3], b: &[f64; 3]) -> [f64; 3] { + [ + a[1] * b[2] - a[2] * b[1], + a[2] * b[0] - a[0] * b[2], + a[0] * b[1] - a[1] * b[0], + ] +} + +fn radial_distort(rad_params: &[f64], proj: &mut [f64]) { + let rsq = sqsum(proj); + let l = 1. + rad_params[0] * rsq + rad_params[1] * rsq * rsq; + proj[0] = proj[0] * l; + proj[1] = proj[1] * l; +} + +fn rodrigues_rotate_point(rot: &[f64; 3], pt: &[f64; 3], rotated_pt: &mut [f64; 3]) { + let sqtheta = sqsum(rot); + if sqtheta != 0. { + let theta = sqtheta.sqrt(); + let costheta = theta.cos(); + let sintheta = theta.sin(); + let theta_inverse = 1. / theta; + let mut w = [0.; 3]; + for i in 0..3 { + w[i] = rot[i] * theta_inverse; + } + let w_cross_pt = cross(&w, &pt); + let tmp = (w[0] * pt[0] + w[1] * pt[1] + w[2] * pt[2]) * (1. - costheta); + for i in 0..3 { + rotated_pt[i] = pt[i] * costheta + w_cross_pt[i] * sintheta + w[i] * tmp; + } + } else { + let rot_cross_pt = cross(&rot, &pt); + for i in 0..3 { + rotated_pt[i] = pt[i] + rot_cross_pt[i]; + } + } +} + +fn project(cam: &[f64; 11], X: &[f64; 3], proj: &mut [f64; 2]) { + let C = &cam[3..6]; + let mut Xo = [0.; 3]; + let mut Xcam = [0.; 3]; + + Xo[0] = X[0] - C[0]; + Xo[1] = X[1] - C[1]; + Xo[2] = X[2] - C[2]; + + rodrigues_rotate_point(cam.first_chunk::<3>().unwrap(), &Xo, &mut Xcam); + + proj[0] = Xcam[0] / Xcam[2]; + proj[1] = Xcam[1] / Xcam[2]; + + radial_distort(&cam[9..], proj); + + proj[0] = proj[0] * cam[6] + cam[7]; + proj[1] = proj[1] * cam[6] + cam[8]; +} + +#[no_mangle] +pub extern "C" fn rust_dcompute_reproj_error( + cam: *const [f64; 11], + dcam: *mut [f64; 11], + x: *const [f64; 3], + dx: *mut [f64; 3], + w: *const [f64; 1], + wb: *mut [f64; 1], + feat: *const [f64; 2], + err: *mut [f64; 2], + derr: *mut [f64; 2], +) { + unsafe {dcompute_reproj_error(cam, dcam, x, dx, w, wb, feat, err, derr)}; +} + +#[autodiff_reverse( + dcompute_reproj_error, + Duplicated, + Duplicated, + Duplicated, + Const, + DuplicatedOnly +)] +pub fn compute_reproj_error( + cam: *const [f64; 11], + x: *const [f64; 3], + w: *const [f64; 1], + feat: *const [f64; 2], + err: *mut [f64; 2], +) { + let cam = unsafe { &*cam }; + let w = unsafe { *(*w).get_unchecked(0) }; + let x = unsafe { &*x }; + let feat = unsafe { &*feat }; + let err = unsafe { &mut *err }; + let mut proj = [0.; 2]; + project(cam, x, &mut proj); + err[0] = w * (proj[0] - feat[0]); + err[1] = w * (proj[1] - feat[1]); +} + +// n number of cameras +// m number of points +// p number of observations +// cams: 11*n cameras in format [r1 r2 r3 C1 C2 C3 f u0 v0 k1 k2] +// r1, r2, r3 are angle - axis rotation parameters(Rodrigues) +// [C1 C2 C3]' is the camera center +// f is the focal length in pixels +// [u0 v0]' is the principal point +// k1, k2 are radial distortion parameters +// X: 3*m points +// obs: 2*p observations (pairs cameraIdx, pointIdx) +// feats: 2*p features (x,y coordinates corresponding to observations) +// reproj_err: 2*p errors of observations +// w_err: p weight "error" terms +fn rust_ba_objective( + n: usize, + m: usize, + p: usize, + cams: &[f64], + x: &[f64], + w: &[f64], + obs: &[i32], + feats: &[f64], + reproj_err: &mut [f64], + w_err: &mut [f64], +) { + assert_eq!(cams.len(), n * 11); + assert_eq!(x.len(), m * 3); + assert_eq!(w.len(), p); + assert_eq!(obs.len(), p * 2); + assert_eq!(feats.len(), p * 2); + assert_eq!(reproj_err.len(), p * 2); + assert_eq!(w_err.len(), p); + + for i in 0..p { + let cam_idx = obs[i * 2 + 0] as usize; + let pt_idx = obs[i * 2 + 1] as usize; + let start = cam_idx * BA_NCAMPARAMS; + let cam: &[f64; 11] = unsafe { + cams[start..] + .get_unchecked(..11) + .try_into() + .unwrap_unchecked() + }; + let x: &[f64; 3] = unsafe { + x[pt_idx * 3..] + .get_unchecked(..3) + .try_into() + .unwrap_unchecked() + }; + let w: &[f64; 1] = unsafe { w[i..].get_unchecked(..1).try_into().unwrap_unchecked() }; + let feat: &[f64; 2] = unsafe { + feats[i * 2..] + .get_unchecked(..2) + .try_into() + .unwrap_unchecked() + }; + let reproj_err: &mut [f64; 2] = unsafe { + reproj_err[i * 2..] + .get_unchecked_mut(..2) + .try_into() + .unwrap_unchecked() + }; + compute_reproj_error(cam, x, w, feat, reproj_err); + } + + for i in 0..p { + let w_err: &mut f64 = unsafe { w_err.get_unchecked_mut(i) }; + compute_zach_weight_error(w[i..].as_ptr(), w_err as *mut f64); + } +} + +#[no_mangle] +extern "C" fn rust2_ba_objective( + n: i32, + m: i32, + p: i32, + cams: *const f64, + x: *const f64, + w: *const f64, + obs: *const i32, + feats: *const f64, + reproj_err: *mut f64, + w_err: *mut f64, +) { + let n = n as usize; + let m = m as usize; + let p = p as usize; + let cams = unsafe { std::slice::from_raw_parts(cams, n * 11) }; + let x = unsafe { std::slice::from_raw_parts(x, m * 3) }; + let w = unsafe { std::slice::from_raw_parts(w, p) }; + let obs = unsafe { std::slice::from_raw_parts(obs, p * 2) }; + let feats = unsafe { std::slice::from_raw_parts(feats, p * 2) }; + let reproj_err = unsafe { std::slice::from_raw_parts_mut(reproj_err, p * 2) }; + let w_err = unsafe { std::slice::from_raw_parts_mut(w_err, p) }; + rust_ba_objective(n, m, p, cams, x, w, obs, feats, reproj_err, w_err); +} diff --git a/tests/ui/autodiff/adbench/ba/src/unsafe.rs b/tests/ui/autodiff/adbench/ba/src/unsafe.rs new file mode 100644 index 0000000000000..c1060d67a7457 --- /dev/null +++ b/tests/ui/autodiff/adbench/ba/src/unsafe.rs @@ -0,0 +1,143 @@ +use crate::BA_NCAMPARAMS; +use crate::compute_zach_weight_error; +use std::autodiff::autodiff_reverse; +use std::convert::TryInto; + +unsafe fn sqsum(x: *const f64, n: usize) -> f64 { + let mut sum = 0.; + for i in 0..n { + let v = unsafe { *x.add(i) }; + sum += v * v; + } + sum +} + +#[inline] +unsafe fn cross(a: *const f64, b: *const f64, out: *mut f64) { + *out.add(0) = *a.add(1) * *b.add(2) - *a.add(2) * *b.add(1); + *out.add(1) = *a.add(2) * *b.add(0) - *a.add(0) * *b.add(2); + *out.add(2) = *a.add(0) * *b.add(1) - *a.add(1) * *b.add(0); +} + +unsafe fn radial_distort(rad_params: *const f64, proj: *mut f64) { + let rsq = sqsum(proj, 2); + let l = 1. + *rad_params.add(0) * rsq + *rad_params.add(1) * rsq * rsq; + *proj.add(0) = *proj.add(0) * l; + *proj.add(1) = *proj.add(1) * l; +} + +unsafe fn rodrigues_rotate_point(rot: *const f64, pt: *const f64, rotated_pt: *mut f64) { + let sqtheta = sqsum(rot, 3); + if sqtheta != 0. { + let theta = sqtheta.sqrt(); + let costheta = theta.cos(); + let sintheta = theta.sin(); + let theta_inverse = 1. / theta; + let mut w = [0.; 3]; + for i in 0..3 { + w[i] = *rot.add(i) * theta_inverse; + } + let mut w_cross_pt = [0.; 3]; + cross(w.as_ptr(), pt, w_cross_pt.as_mut_ptr()); + let tmp = (w[0] * *pt.add(0) + w[1] * *pt.add(1) + w[2] * *pt.add(2)) * (1. - costheta); + for i in 0..3 { + *rotated_pt.add(i) = *pt.add(i) * costheta + w_cross_pt[i] * sintheta + w[i] * tmp; + } + } else { + let mut rot_cross_pt = [0.; 3]; + cross(rot, pt, rot_cross_pt.as_mut_ptr()); + for i in 0..3 { + *rotated_pt.add(i) = *pt.add(i) + rot_cross_pt[i]; + } + } +} + +unsafe fn project(cam: *const f64, X: *const f64, proj: *mut f64) { + let C = cam.add(3); + let mut Xo = [0.; 3]; + let mut Xcam = [0.; 3]; + + Xo[0] = *X.add(0) - *C.add(0); + Xo[1] = *X.add(1) - *C.add(1); + Xo[2] = *X.add(2) - *C.add(2); + + rodrigues_rotate_point(cam, Xo.as_ptr(), Xcam.as_mut_ptr()); + + *proj.add(0) = Xcam[0] / Xcam[2]; + *proj.add(1) = Xcam[1] / Xcam[2]; + + radial_distort(cam.add(9), proj); + *proj.add(0) = *proj.add(0) * *cam.add(6) + *cam.add(7); + *proj.add(1) = *proj.add(1) * *cam.add(6) + *cam.add(8); +} + +#[no_mangle] +pub unsafe extern "C" fn rust_unsafe_dcompute_reproj_error( + cam: *const f64, + dcam: *mut f64, + x: *const f64, + dx: *mut f64, + w: *const f64, + wb: *mut f64, + feat: *const f64, + err: *mut f64, + derr: *mut f64, +) { + unsafe {dcompute_reproj_error(cam, dcam, x, dx, w, wb, feat, err, derr)}; +} + + +#[autodiff_reverse( + dcompute_reproj_error, + Duplicated, + Duplicated, + Duplicated, + Const, + DuplicatedOnly +)] +pub unsafe fn compute_reproj_error( + cam: *const f64, + x: *const f64, + w: *const f64, + feat: *const f64, + err: *mut f64, +) { + let mut proj = [0.; 2]; + project(cam, x, proj.as_mut_ptr()); + *err.add(0) = *w * (proj[0] - *feat.add(0)); + *err.add(1) = *w * (proj[1] - *feat.add(1)); +} + +#[no_mangle] +unsafe extern "C" fn rust2_unsafe_ba_objective( + n: i32, + m: i32, + p: i32, + cams: *const f64, + x: *const f64, + w: *const f64, + obs: *const i32, + feats: *const f64, + reproj_err: *mut f64, + w_err: *mut f64, +) { + let n = n as usize; + let m = m as usize; + let p = p as usize; + for i in 0..p { + let cam_idx = *obs.add(i * 2 + 0) as usize; + let pt_idx = *obs.add(i * 2 + 1) as usize; + let start = cam_idx * BA_NCAMPARAMS; + + let cam: *const f64 = cams.add(start); + let x: *const f64 = x.add(pt_idx * 3); + let w: *const f64 = w.add(i); + let feat: *const f64 = feats.add(i * 2); + let reproj_err: *mut f64 = reproj_err.add(i * 2); + compute_reproj_error(cam, x, w, feat, reproj_err); + } + + for i in 0..p { + compute_zach_weight_error(w.add(i), w_err.add(i)); + } +} diff --git a/tests/ui/autodiff/adbench/fft/src/lib.rs b/tests/ui/autodiff/adbench/fft/src/lib.rs new file mode 100644 index 0000000000000..6230b8973a67a --- /dev/null +++ b/tests/ui/autodiff/adbench/fft/src/lib.rs @@ -0,0 +1,9 @@ +//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat --crate-type=staticlib +//@ build-pass +//@ no-prefer-dynamic +//@ needs-enzyme +#![feature(slice_swap_unchecked)] +#![feature(autodiff)] + +pub mod safe; +pub mod unsf; diff --git a/tests/ui/autodiff/adbench/fft/src/safe.rs b/tests/ui/autodiff/adbench/fft/src/safe.rs new file mode 100644 index 0000000000000..89c9726a711f1 --- /dev/null +++ b/tests/ui/autodiff/adbench/fft/src/safe.rs @@ -0,0 +1,104 @@ +use std::autodiff::autodiff_reverse; +use std::f64::consts::PI; +use std::slice; + +fn bitreversal_perm(data: &mut [T]) { + let len = data.len() / 2; + let mut j = 1; + + for i in (1..data.len()).step_by(2) { + if j > i { + //dbg!(&i, &j); + data.swap(j-1, i-1); + data.swap(j, i); + //unsafe { + // data.swap_unchecked(j - 1, i - 1); + // data.swap_unchecked(j, i); + //} + } + + let mut m = len; + while m >= 2 && j > m { + j -= m; + m >>= 1; + } + + j += m; + } +} + +fn radix2(data: &mut [f64], i_sign: i32) { + let n = data.len() / 2; + if n == 1 { + return; + } + + let (a, b) = data.split_at_mut(n); + // assert_eq!(a.len(), b.len()); + radix2(a, i_sign); + radix2(b, i_sign); + + let wtemp = i_sign as f64 * (PI / n as f64).sin(); + let wpi = -i_sign as f64 * (2.0 * (PI / n as f64)).sin(); + let wpr = -2.0 * wtemp * wtemp; + let mut wr = 1.0; + let mut wi = 0.0; + + let (achunks, _) = a.as_chunks_mut(); + let (bchunks, _) = b.as_chunks_mut(); + for ([ax, ay], [bx, by]) in achunks.iter_mut().zip(bchunks.iter_mut()) { + let tempr = *bx * wr - *by * wi; + let tempi = *bx * wi + *by * wr; + + *bx = *ax - tempr; + *by = *ay - tempi; + *ax += tempr; + *ay += tempi; + + let wtemp_new = wr; + wr = wr * (wpr + 1.0) - wi * wpi; + wi = wi * (wpr + 1.0) + wtemp_new * wpi; + } +} + +fn rescale(data: &mut [f64], scale: usize) { + let scale = 1. / scale as f64; + for elm in data { + *elm *= scale; + } +} + +fn fft(data: &mut [f64]) { + bitreversal_perm(data); + radix2(data, 1); +} + +fn ifft(data: &mut [f64]) { + bitreversal_perm(data); + radix2(data, -1); + rescale(data, data.len() / 2); +} + +#[autodiff_reverse(dfoobar, DuplicatedOnly)] +pub fn foobar(data: &mut [f64]) { + fft(data); + ifft(data); +} + +#[no_mangle] +pub extern "C" fn rust_dfoobar(n: usize, data: *mut f64, ddata: *mut f64) { + let (data, ddata) = unsafe { + ( + slice::from_raw_parts_mut(data, n * 2), + slice::from_raw_parts_mut(ddata, n * 2), + ) + }; + + unsafe { dfoobar(data, ddata) }; +} + +#[no_mangle] +pub extern "C" fn rust_foobar(n: usize, data: *mut f64) { + let data = unsafe { slice::from_raw_parts_mut(data, n * 2) }; + foobar(data); +} diff --git a/tests/ui/autodiff/adbench/fft/src/unsf.rs b/tests/ui/autodiff/adbench/fft/src/unsf.rs new file mode 100644 index 0000000000000..693259da04050 --- /dev/null +++ b/tests/ui/autodiff/adbench/fft/src/unsf.rs @@ -0,0 +1,92 @@ +use std::autodiff::autodiff_reverse; +use std::f64::consts::PI; + +unsafe fn bitreversal_perm(data: *mut f64, len: usize) { + let mut j = 1; + + for i in (1..2 * len).step_by(2) { + if j > i { + std::ptr::swap(data.add(j - 1), data.add(i - 1)); + std::ptr::swap(data.add(j), data.add(i)); + } + + let mut m = len; + while m >= 2 && j > m { + j -= m; + m >>= 1; + } + + j += m; + } +} + +unsafe fn radix2(data: *mut f64, n: usize, i_sign: i32) { + if n == 1 { + return; + } + radix2(data, n / 2, i_sign); + radix2(data.add(n), n / 2, i_sign); + + let wtemp = i_sign as f64 * (PI / n as f64).sin(); + let wpi = -i_sign as f64 * (2.0 * (PI / n as f64)).sin(); + let wpr = -2.0 * wtemp * wtemp; + let mut wr = 1.0; + let mut wi = 0.0; + + for i in (0..n).step_by(2) { + let in_n = i + n; + let ax = &mut *data.add(i); + let ay = &mut *data.add(i + 1); + let bx = &mut *data.add(in_n); + let by = &mut *data.add(in_n + 1); + let tempr = *bx * wr - *by * wi; + let tempi = *bx * wi + *by * wr; + + *bx = *ax - tempr; + *by = *ay - tempi; + *ax += tempr; + *ay += tempi; + + let wtemp_new = wr; + wr = wr * (wpr + 1.0) - wi * wpi; + wi = wi * (wpr + 1.0) + wtemp_new * wpi; + } +} + +unsafe fn rescale(data: *mut f64, n: usize) { + let scale = 1. / n as f64; + for i in 0..2 * n { + *data.add(i) = *data.add(i) * scale; + } +} + +unsafe fn fft(data: *mut f64, n: usize) { + bitreversal_perm(data, n); + radix2(data, n, 1); +} + +unsafe fn ifft(data: *mut f64, n: usize) { + bitreversal_perm(data, n); + radix2(data, n, -1); + rescale(data, n); +} + +#[autodiff_reverse(unsafe_dfoobar, Const, DuplicatedOnly)] +pub unsafe fn unsafe_foobar(n: usize, data: *mut f64) { + fft(data, n); + ifft(data, n); +} + +#[no_mangle] +pub extern "C" fn rust_unsafe_dfoobar(n: usize, data: *mut f64, ddata: *mut f64) { + unsafe { + unsafe_dfoobar(n, data, ddata); + } +} + +#[no_mangle] +pub extern "C" fn rust_unsafe_foobar(n: usize, data: *mut f64) { + unsafe { + unsafe_foobar(n, data); + } +} diff --git a/tests/ui/autodiff/adbench/gmm/src/lib.rs b/tests/ui/autodiff/adbench/gmm/src/lib.rs new file mode 100644 index 0000000000000..9be8f561b62a2 --- /dev/null +++ b/tests/ui/autodiff/adbench/gmm/src/lib.rs @@ -0,0 +1,19 @@ +//@ revisions: LooseTypes default +//@[LooseTypes] compile-flags: -Zautodiff=Enable,LooseTypes -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat --crate-type=staticlib +//@[default] compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat --crate-type=staticlib +//@[LooseTypes] build-pass +//@[default] build-fail +//@ dont-check-compiler-stderr +//@ dont-check-compiler-stdout +//@ no-prefer-dynamic +//@ needs-enzyme +#![feature(autodiff)] +pub mod safe; +pub mod r#unsafe; + +#[derive(Clone, Copy)] +#[repr(C)] +pub struct Wishart { + pub gamma: f64, + pub m: i32, +} diff --git a/tests/ui/autodiff/adbench/gmm/src/safe.rs b/tests/ui/autodiff/adbench/gmm/src/safe.rs new file mode 100644 index 0000000000000..966dd3c3d3d39 --- /dev/null +++ b/tests/ui/autodiff/adbench/gmm/src/safe.rs @@ -0,0 +1,296 @@ +use crate::Wishart; +use std::f64::consts::PI; +use std::autodiff::autodiff_reverse; + +//#[cfg(feature = "libm")] +//use libm::lgamma; + +//#[cfg(not(feature = "libm"))] +mod cmath { + extern "C" { + pub fn lgamma(x: f64) -> f64; + } +} +//#[cfg(not(feature = "libm"))] +#[inline] +fn lgamma(x: f64) -> f64 { + unsafe { cmath::lgamma(x) } +} + +#[no_mangle] +pub extern "C" fn rust_dgmm_objective( + d: usize, + k: usize, + n: usize, + alphas: *const f64, + dalphas: *mut f64, + means: *const f64, + dmeans: *mut f64, + icf: *const f64, + dicf: *mut f64, + x: *const f64, + wishart: *const Wishart, + err: *mut f64, + derr: *mut f64, +) { + let alphas = unsafe { std::slice::from_raw_parts(alphas, k) }; + let means = unsafe { std::slice::from_raw_parts(means, k * d) }; + let icf = unsafe { std::slice::from_raw_parts(icf, k * d * (d + 1) / 2) }; + let x = unsafe { std::slice::from_raw_parts(x, n * d) }; + let wishart: Wishart = unsafe { *wishart }; + let mut my_err = unsafe { *err }; + + let d_alphas = unsafe { std::slice::from_raw_parts_mut(dalphas, k) }; + let d_means = unsafe { std::slice::from_raw_parts_mut(dmeans, k * d) }; + let d_icf = unsafe { std::slice::from_raw_parts_mut(dicf, k * d * (d + 1) / 2) }; + let mut my_derr = unsafe { *derr }; + let (mut qdiags, mut sum_qs, mut xcentered, mut qxcentered, mut main_term) = + get_workspace(d, k); + let (mut bqdiags, mut bsum_qs, mut bxcentered, mut bqxcentered, mut bmain_term) = + get_workspace(d, k); + + unsafe { dgmm_objective( + d, + k, + n, + alphas, + d_alphas, + means, + d_means, + icf, + d_icf, + x, + wishart.gamma, + wishart.m, + &mut my_err, + &mut my_derr, + &mut qdiags, + &mut bqdiags, + &mut sum_qs, + &mut bsum_qs, + &mut xcentered, + &mut bxcentered, + &mut qxcentered, + &mut bqxcentered, + &mut main_term, + &mut bmain_term, + )}; + + unsafe { *err = my_err }; + unsafe { *derr = my_derr }; +} + +#[no_mangle] +pub extern "C" fn rust_gmm_objective( + d: usize, + k: usize, + n: usize, + alphas: *const f64, + means: *const f64, + icf: *const f64, + x: *const f64, + wishart: *const Wishart, + err: *mut f64, +) { + let alphas = unsafe { std::slice::from_raw_parts(alphas, k) }; + let means = unsafe { std::slice::from_raw_parts(means, k * d) }; + let icf = unsafe { std::slice::from_raw_parts(icf, k * d * (d + 1) / 2) }; + let x = unsafe { std::slice::from_raw_parts(x, n * d) }; + let wishart: Wishart = unsafe { *wishart }; + let mut my_err = unsafe { *err }; + let (mut qdiags, mut sum_qs, mut xcentered, mut qxcentered, mut main_term) = + get_workspace(d, k); + gmm_objective( + d, + k, + n, + alphas, + means, + icf, + x, + wishart.gamma, + wishart.m, + &mut my_err, + &mut qdiags, + &mut sum_qs, + &mut xcentered, + &mut qxcentered, + &mut main_term, + ); + unsafe { *err = my_err }; +} + +fn get_workspace(d: usize, k: usize) -> (Vec, Vec, Vec, Vec, Vec) { + let qdiags = vec![0.; d * k]; + let sum_qs = vec![0.; k]; + let xcentered = vec![0.; d]; + let qxcentered = vec![0.; d]; + let main_term = vec![0.; k]; + (qdiags, sum_qs, xcentered, qxcentered, main_term) +} + +#[autodiff_reverse( + dgmm_objective, + Const, + Const, + Const, + Duplicated, + Duplicated, + Duplicated, + Const, + Const, + Const, + DuplicatedOnly, + Duplicated, + Duplicated, + Duplicated, + Duplicated, + Duplicated +)] +pub fn gmm_objective( + d: usize, + k: usize, + n: usize, + alphas: &[f64], + means: &[f64], + icf: &[f64], + x: &[f64], + gamma: f64, + m: i32, + err: &mut f64, + qdiags: &mut [f64], + sum_qs: &mut [f64], + xcentered: &mut [f64], + qxcentered: &mut [f64], + main_term: &mut [f64], +) { + let wishart: Wishart = Wishart { gamma, m }; + let constant = -(n as f64) * d as f64 * 0.5 * (2.0 * PI).ln(); + let icf_sz = d * (d + 1) / 2; + + // Let the compiler know sizes so it can eliminate bounds checks + assert_eq!(qdiags.len(), d * k); + assert_eq!(sum_qs.len(), k); + assert_eq!(xcentered.len(), d); + assert_eq!(qxcentered.len(), d); + assert_eq!(main_term.len(), k); + + preprocess_qs(d, k, icf, sum_qs, qdiags); + + let mut slse = 0.; + for ix in 0..n { + for ik in 0..k { + subtract( + d, + &x[ix as usize * d as usize..], + &means[ik as usize * d as usize..], + xcentered, + ); + qtimesx( + d, + &qdiags[ik as usize * d as usize..], + &icf[ik as usize * icf_sz as usize + d as usize..], + &*xcentered, + qxcentered, + ); + main_term[ik as usize] = + alphas[ik as usize] + sum_qs[ik as usize] - 0.5 * sqnorm(&*qxcentered); + } + + slse = slse + log_sum_exp(k, &main_term); + } + + let lse_alphas = log_sum_exp(k, alphas); + + *err = constant + slse - n as f64 * lse_alphas + + log_wishart_prior(d, k, wishart, &sum_qs, &*qdiags, icf); +} + +fn arr_max(n: usize, x: &[f64]) -> f64 { + let mut max = f64::NEG_INFINITY; + for i in 0..n { + if max < x[i] { + max = x[i]; + } + } + max +} + +fn preprocess_qs(d: usize, k: usize, icf: &[f64], sum_qs: &mut [f64], qdiags: &mut [f64]) { + let icf_sz = d * (d + 1) / 2; + for ik in 0..k { + sum_qs[ik as usize] = 0.; + for id in 0..d { + let q = icf[ik as usize * icf_sz as usize + id as usize]; + sum_qs[ik as usize] = sum_qs[ik as usize] + q; + qdiags[ik as usize * d as usize + id as usize] = q.exp(); + } + } +} +fn subtract(d: usize, x: &[f64], y: &[f64], out: &mut [f64]) { + assert!(x.len() >= d); + assert!(y.len() >= d); + assert!(out.len() >= d); + for i in 0..d { + out[i] = x[i] - y[i]; + } +} + +fn qtimesx(d: usize, q_diag: &[f64], ltri: &[f64], x: &[f64], out: &mut [f64]) { + assert!(out.len() >= d); + assert!(q_diag.len() >= d); + assert!(x.len() >= d); + for i in 0..d { + out[i] = q_diag[i] * x[i]; + } + + for i in 0..d { + let mut lparamsidx = i * (2 * d - i - 1) / 2; + for j in i + 1..d { + out[j] = out[j] + ltri[lparamsidx] * x[i]; + lparamsidx += 1; + } + } +} + +fn log_sum_exp(n: usize, x: &[f64]) -> f64 { + let mx = arr_max(n, x); + let semx: f64 = x.iter().map(|x| (x - mx).exp()).sum(); + semx.ln() + mx +} +fn log_gamma_distrib(a: f64, p: f64) -> f64 { + 0.25 * p * (p - 1.) * PI.ln() + + (1..=p as usize) + .map(|j| lgamma(a + 0.5 * (1. - j as f64))) + .sum::() +} + +fn log_wishart_prior( + p: usize, + k: usize, + wishart: Wishart, + sum_qs: &[f64], + qdiags: &[f64], + icf: &[f64], +) -> f64 { + let n = p + wishart.m as usize + 1; + let icf_sz = p * (p + 1) / 2; + + let c = n as f64 * p as f64 * (wishart.gamma.ln() - 0.5 * 2f64.ln()) + - log_gamma_distrib(0.5 * n as f64, p as f64); + + let out = (0..k) + .map(|ik| { + let frobenius = sqnorm(&qdiags[ik * p as usize..][..p]) + + sqnorm(&icf[ik * icf_sz as usize + p as usize..][..icf_sz - p]); + 0.5 * wishart.gamma * wishart.gamma * (frobenius) + - (wishart.m as f64) * sum_qs[ik as usize] + }) + .sum::(); + + out - k as f64 * c +} + +fn sqnorm(x: &[f64]) -> f64 { + x.iter().map(|x| x * x).sum() +} diff --git a/tests/ui/autodiff/adbench/gmm/src/unsafe.rs b/tests/ui/autodiff/adbench/gmm/src/unsafe.rs new file mode 100644 index 0000000000000..439b949627975 --- /dev/null +++ b/tests/ui/autodiff/adbench/gmm/src/unsafe.rs @@ -0,0 +1,143 @@ +use std::f64::consts::PI; +use crate::Wishart; +use std::autodiff::autodiff_reverse; + +//#[cfg(feature = "libm")] +//use libm::lgamma; +// +//#[cfg(not(feature = "libm"))] +mod cmath { + extern "C" { + pub fn lgamma(x: f64) -> f64; + } +} +//#[cfg(not(feature = "libm"))] +#[inline] +fn lgamma(x: f64) -> f64 { + unsafe { cmath::lgamma(x) } +} + +#[no_mangle] +pub extern "C" fn rust_unsafe_dgmm_objective(d: i32, k: i32, n: i32, alphas: *const f64, dalphas: *mut f64, means: *const f64, dmeans: *mut f64, icf: *const f64, dicf: *mut f64, x: *const f64, wishart: *const Wishart, err: *mut f64, derr: *mut f64) { + let k = k as usize; + let n = n as usize; + let d = d as usize; + unsafe { dgmm_objective(d, k, n, alphas, dalphas, means, dmeans, icf, dicf, x, wishart, err, derr); } +} + +#[no_mangle] +pub extern "C" fn rust_unsafe_gmm_objective(d: i32, k: i32, n: i32, alphas: *const f64, means: *const f64, icf: *const f64, x: *const f64, wishart: *const Wishart, err: *mut f64) { + let k = k as usize; + let n = n as usize; + let d = d as usize; + unsafe {gmm_objective(d, k, n, alphas, means, icf, x, wishart, err); } +} + +#[autodiff_reverse(dgmm_objective, Const, Const, Const, Duplicated, Duplicated, Duplicated, Const, Const, DuplicatedOnly)] +pub unsafe fn gmm_objective(d: usize, k: usize, n: usize, alphas: *const f64, means: *const f64, icf: *const f64, x: *const f64, wishart: *const Wishart, err: *mut f64) { + let constant = -(n as f64) * d as f64 * 0.5 * (2.0 * PI).ln(); + let icf_sz = d * (d + 1) / 2; + let mut qdiags = vec![0.; d * k]; + let mut sum_qs = vec![0.; k]; + let mut xcentered = vec![0.; d]; + let mut qxcentered = vec![0.; d]; + let mut main_term = vec![0.; k]; + + preprocess_qs(d, k, icf, sum_qs.as_mut_ptr(), qdiags.as_mut_ptr()); + + let mut slse = 0.; + for ix in 0..n { + for ik in 0..k { + subtract(d, x.add(ix * d), means.add(ik * d), xcentered.as_mut_ptr()); + qtimesx(d, qdiags.as_mut_ptr().add(ik * d), icf.add(ik * icf_sz + d), xcentered.as_ptr(), qxcentered.as_mut_ptr()); + main_term[ik] = *alphas.add(ik) + sum_qs[ik] - 0.5 * sqnorm(d, qxcentered.as_ptr()); + //main_term[ik] = alphas[ik] + sum_qs[ik] - 0.5 * sqnorm(d, &Qxcentered[0]); + } + + slse = slse + log_sum_exp(k, main_term.as_ptr()); + } + + let lse_alphas = log_sum_exp(k, alphas); + + *err = constant + slse - n as f64 * lse_alphas + log_wishart_prior(d, k, *wishart, sum_qs.as_ptr(), qdiags.as_ptr(), icf); +} + +unsafe fn arr_max(n: usize, x: *const f64) -> f64 { + let mut max = f64::NEG_INFINITY; + for i in 0..n { + if max < *x.add(i) { + max = *x.add(i); + } + } + max +} + +unsafe fn preprocess_qs(d: usize, k: usize, icf: *const f64, sum_qs: *mut f64, qdiags: *mut f64) { + let icf_sz = d * (d + 1) / 2; + for ik in 0..k { + *sum_qs.add(ik) = 0.; + for id in 0..d { + let q = *icf.add(ik * icf_sz + id); + *sum_qs.add(ik) = *sum_qs.add(ik) + q; + *qdiags.add(ik * d + id) = q.exp(); + } + } +} + +unsafe fn subtract(d: usize, x: *const f64, y: *const f64, out: *mut f64) { + for i in 0..d { + *out.add(i) = *x.add(i) - *y.add(i); + } +} + +unsafe fn qtimesx(d: usize, q_diag: *const f64, ltri: *const f64, x: *const f64, out: *mut f64) { + for i in 0..d { + *out.add(i) = *q_diag.add(i) * *x.add(i); + } + + for i in 0..d { + let mut lparamsidx = i*(2*d-i-1)/2; + for j in i + 1..d { + *out.add(j) = *out.add(j) + *ltri.add(lparamsidx) * *x.add(i); + lparamsidx += 1; + } + } +} + +unsafe fn log_sum_exp(n: usize, x: *const f64) -> f64 { + let mx = arr_max(n, x); + let mut semx: f64 = 0.0; + + for i in 0..n { + semx = semx + (*x.add(i) - mx).exp(); + } + semx.ln() + mx +} + +fn log_gamma_distrib(a: f64, p: f64) -> f64 { + 0.25 * p * (p - 1.) * PI.ln() + (1..=p as usize).map(|j| lgamma(a + 0.5 * (1. - j as f64))).sum::() +} + +unsafe fn log_wishart_prior(p: usize, k: usize, wishart: Wishart, sum_qs: *const f64, qdiags: *const f64, icf: *const f64) -> f64 { + let n = p + wishart.m as usize + 1; + let icf_sz = p * (p + 1) / 2; + + let c = n as f64 * p as f64 * (wishart.gamma.ln() - 0.5 * 2f64.ln()) - log_gamma_distrib(0.5 * n as f64, p as f64); + + let mut out = 0.; + + for ik in 0..k { + let frobenius = sqnorm(p, qdiags.add(ik * p)) + sqnorm(icf_sz - p, icf.add(ik * icf_sz + p)); + out = out + 0.5 * wishart.gamma * wishart.gamma * (frobenius) - wishart.m as f64 * *sum_qs.add(ik); + } + + out - k as f64 * c +} + +unsafe fn sqnorm(n: usize, x: *const f64) -> f64 { + let mut sum = 0.; + for i in 0..n { + sum += *x.add(i) * *x.add(i); + } + sum +} diff --git a/tests/ui/autodiff/adbench/lstm/src/lib.rs b/tests/ui/autodiff/adbench/lstm/src/lib.rs new file mode 100644 index 0000000000000..1910314f2ef2f --- /dev/null +++ b/tests/ui/autodiff/adbench/lstm/src/lib.rs @@ -0,0 +1,62 @@ +//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat --crate-type=staticlib +//@ build-pass +//@ dont-check-compiler-stderr +//@ dont-check-compiler-stdout +//@ no-prefer-dynamic +//@ needs-enzyme +#![feature(autodiff)] + +pub (crate) mod unsf; +pub (crate) mod safe; +use std::slice; + + +#[no_mangle] +pub extern "C" fn rust_unsafe_lstm_objective(l: i32, c: i32, b: i32, main_params: *const f64, extra_params: *const f64, state: *mut f64, sequence: *const f64, loss: *mut f64) { + let l = l as usize; + let c = c as usize; + let b = b as usize; + unsafe {unsf::lstm_unsafe_objective(l,c,b,main_params,extra_params,state,sequence, loss);} +} +#[no_mangle] +pub extern "C" fn rust_safe_lstm_objective(l: i32, c: i32, b: i32, main_params: *const f64, extra_params: *const f64, state: *mut f64, sequence: *const f64, loss: *mut f64) { + let l = l as usize; + let c = c as usize; + let b = b as usize; + let (main_params, extra_params, state, sequence) = unsafe {( + slice::from_raw_parts(main_params, 2*l*4*b), + slice::from_raw_parts(extra_params, 3*b), + slice::from_raw_parts_mut(state, 2*l*b), + slice::from_raw_parts(sequence, c*b) + )}; + + unsafe { + safe::lstm_objective(l,c,b,main_params,extra_params,state,sequence, &mut *loss); + } +} + +#[no_mangle] +pub extern "C" fn rust_unsafe_dlstm_objective(l: i32, c: i32, b: i32, main_params: *const f64, d_main_params: *mut f64, extra_params: *const f64, d_extra_params: *mut f64, state: *mut f64, sequence: *const f64, res: *mut f64, d_res: *mut f64) { + let l = l as usize; + let c = c as usize; + let b = b as usize; + unsafe {unsf::d_lstm_unsafe_objective(l,c,b,main_params,d_main_params, extra_params,d_extra_params, state,sequence, res, d_res);} +} +#[no_mangle] +pub extern "C" fn rust_safe_dlstm_objective(l: i32, c: i32, b: i32, main_params: *const f64, d_main_params: *mut f64, extra_params: *const f64, d_extra_params: *mut f64, state: *mut f64, sequence: *const f64, res: *mut f64, d_res: *mut f64) { + let l = l as usize; + let c = c as usize; + let b = b as usize; + let (main_params, d_main_params, extra_params, d_extra_params, state, sequence) = unsafe {( + slice::from_raw_parts(main_params, 2*l*4*b), + slice::from_raw_parts_mut(d_main_params, 2*l*4*b), + slice::from_raw_parts(extra_params, 3*b), + slice::from_raw_parts_mut(d_extra_params, 3*b), + slice::from_raw_parts_mut(state, 2*l*b), + slice::from_raw_parts(sequence, c*b) + )}; + + unsafe { + safe::d_lstm_objective(l,c,b,main_params,d_main_params, extra_params,d_extra_params, state,sequence, &mut *res, &mut *d_res); + } +} diff --git a/tests/ui/autodiff/adbench/lstm/src/safe.rs b/tests/ui/autodiff/adbench/lstm/src/safe.rs new file mode 100644 index 0000000000000..199cc2475afe7 --- /dev/null +++ b/tests/ui/autodiff/adbench/lstm/src/safe.rs @@ -0,0 +1,237 @@ +use std::slice; +use std::autodiff::autodiff_reverse; +//use std::hint::assert_unchecked; + +// Sigmoid on scalar +fn sigmoid(x: f64) -> f64 { + 1.0 / (1.0 + (-x).exp()) +} + +// log(sum(exp(x), 2)) +#[inline] +fn logsumexp(vect: &[f64]) -> f64 { + let mut sum = 0.0; + for &val in vect { + sum += val.exp(); + } + sum += 2.0; // Adding 2 to sum + sum.ln() +} + +// LSTM OBJECTIVE +// The LSTM model +fn lstm_model( + hsize: usize, + weight: &[f64], + bias: &[f64], + hidden: &mut [f64], + cell: &mut [f64], + input: &[f64], +) { + let mut gates = vec![0.0; 4 * hsize]; + let gates = &mut gates[..4 * hsize]; + let (a, b) = gates.split_at_mut(2 * hsize); + let ((forget, ingate), (outgate, change)) = (a.split_at_mut(hsize), b.split_at_mut(hsize)); + + // unsafe {assert_unchecked(weight.len()== 4 * hsize)}; + // unsafe {assert_unchecked(bias.len()== 4 * hsize)}; + // unsafe {assert_unchecked(hidden.len()== hsize)}; + // unsafe {assert_unchecked(cell.len() >= hsize)}; + // unsafe {assert_unchecked(input.len() >= hsize)}; + // caching input + for i in 0..hsize { + forget[i] = sigmoid(input[i] * weight[i] + bias[i]); + ingate[i] = sigmoid(hidden[i] * weight[hsize + i] + bias[hsize + i]); + outgate[i] = sigmoid(input[i] * weight[2 * hsize + i] + bias[2 * hsize + i]); + change[i] = (hidden[i] * weight[3 * hsize + i] + bias[3 * hsize + i]).tanh(); + } + + // caching cell + for i in 0..hsize { + cell[i] = cell[i] * forget[i] + ingate[i] * change[i]; + } + + for i in 0..hsize { + hidden[i] = outgate[i] * cell[i].tanh(); + } +} + +// Predict LSTM output given an input +fn lstm_predict( + l: usize, + b: usize, + w: &[f64], + w2: &[f64], + s: &mut [f64], + x: &[f64], + x2: &mut [f64], +) { + for i in 0..b { + x2[i] = x[i] * w2[i]; + } + + let mut i = 0; + while i <= 2 * l * b - 1 { + // make borrow-checker happy with non-overlapping mutable references + let (xp, s1, s2) = if i == 0 { + let (s1, s2) = s.split_at_mut(b); + (x2.as_mut(), s1, s2) + } else { + let tmp = &mut s[i - 2 * b..]; + let (a, d) = tmp.split_at_mut(2 * b); + let (d, c) = d.split_at_mut(b); + + (a, d, c) + }; + + lstm_model( + b, + &w[i * 4..(i + b) * 4], + &w[(i + b) * 4..(i + 2 * b) * 4], + s1, + s2, + xp, + ); + + i += 2 * b; + } + + let xp = &s[i - 2 * b..]; + + for i in 0..b { + x2[i] = xp[i] * w2[b + i] + w2[2 * b + i]; + } +} + +// LSTM objective (loss function) +#[autodiff_reverse( + d_lstm_objective, + Const, + Const, + Const, + Duplicated, + Duplicated, + Const, + Const, + DuplicatedOnly +)] +pub(crate) fn lstm_objective( + l: usize, + c: usize, + b: usize, + main_params: &[f64], + extra_params: &[f64], + state: &mut [f64], + sequence: &[f64], + loss: &mut f64, +) { + let mut total = 0.0; + + let mut input = &sequence[..b]; + let mut ypred = vec![0.0; b]; + let mut ynorm = vec![0.0; b]; + + // unsafe{assert_unchecked(b > 0)}; + + let limit = (c - 1) * b; + for j in 0..(c - 1) { + let t = j * b; + lstm_predict(l, b, main_params, extra_params, state, input, &mut ypred); + let lse = logsumexp(&ypred); + for i in 0..b { + ynorm[i] = ypred[i] - lse; + } + + let ygold = &sequence[t + b..]; + for i in 0..b { + total += ygold[i] * ynorm[i]; + } + + input = ygold; + } + let count = (c - 1) * b; + + *loss = -total / count as f64; +} + +#[no_mangle] +pub extern "C" fn rust_lstm_objective( + l: i32, + c: i32, + b: i32, + main_params: *const f64, + extra_params: *const f64, + state: *mut f64, + sequence: *const f64, + loss: *mut f64, +) { + let l = l as usize; + let c = c as usize; + let b = b as usize; + let (main_params, extra_params, state, sequence) = unsafe { + ( + slice::from_raw_parts(main_params, 2 * l * 4 * b), + slice::from_raw_parts(extra_params, 3 * b), + slice::from_raw_parts_mut(state, 2 * l * b), + slice::from_raw_parts(sequence, c * b), + ) + }; + + unsafe { + lstm_objective( + l, + c, + b, + main_params, + extra_params, + state, + sequence, + &mut *loss, + ); + } +} + +#[no_mangle] +pub extern "C" fn rust_dlstm_objective( + l: i32, + c: i32, + b: i32, + main_params: *const f64, + d_main_params: *mut f64, + extra_params: *const f64, + d_extra_params: *mut f64, + state: *mut f64, + sequence: *const f64, + res: *mut f64, + d_res: *mut f64, +) { + let l = l as usize; + let c = c as usize; + let b = b as usize; + let (main_params, d_main_params, extra_params, d_extra_params, state, sequence) = unsafe { + ( + slice::from_raw_parts(main_params, 2 * l * 4 * b), + slice::from_raw_parts_mut(d_main_params, 2 * l * 4 * b), + slice::from_raw_parts(extra_params, 3 * b), + slice::from_raw_parts_mut(d_extra_params, 3 * b), + slice::from_raw_parts_mut(state, 2 * l * b), + slice::from_raw_parts(sequence, c * b), + ) + }; + + unsafe { + d_lstm_objective( + l, + c, + b, + main_params, + d_main_params, + extra_params, + d_extra_params, + state, + sequence, + &mut *res, + &mut *d_res, + ); + } +} diff --git a/tests/ui/autodiff/adbench/lstm/src/unsf.rs b/tests/ui/autodiff/adbench/lstm/src/unsf.rs new file mode 100644 index 0000000000000..fcad30a7c4b83 --- /dev/null +++ b/tests/ui/autodiff/adbench/lstm/src/unsf.rs @@ -0,0 +1,116 @@ +use std::autodiff::autodiff_reverse; + +// Sigmoid on scalar +fn sigmoid(x: f64) -> f64 { + 1.0 / (1.0 + (-x).exp()) +} + +// log(sum(exp(x), 2)) +unsafe fn logsumexp(vect: *const f64, sz: usize) -> f64 { + let mut sum: f64 = 0.0; + for i in 0..sz { + sum += (*vect.add(i)).exp(); + } + sum += 2.0; // Adding 2 to sum + sum.ln() +} + +// LSTM OBJECTIVE +// The LSTM model +unsafe fn lstm_model( + hsize: usize, + weight: *const f64, + bias: *const f64, + hidden: *mut f64, + cell: *mut f64, + input: *const f64, +) { +// // TODO NOTE THIS +// //__builtin_assume(hsize > 0); + let mut gates = vec![0.0; 4 * hsize]; + let forget: *mut f64 = gates.as_mut_ptr(); + let ingate: *mut f64 = gates[hsize..].as_mut_ptr(); + let outgate: *mut f64 = gates[2 * hsize..].as_mut_ptr(); + let change: *mut f64 = gates[3 * hsize..].as_mut_ptr(); + //let (a,b) = gates.split_at_mut(2*hsize); + //let ((forget, ingate), (outgate, change)) = ( + // a.split_at_mut(hsize), b.split_at_mut(hsize)); + + // caching input + for i in 0..hsize { + *forget.add(i) = sigmoid(*input.add(i) * *weight.add(i) + *bias.add(i)); + *ingate.add(i) = sigmoid(*hidden.add(i) * *weight.add(hsize + i) + *bias.add(hsize + i)); + *outgate.add(i) = sigmoid(*input.add(i) * *weight.add(2 * hsize + i) + *bias.add(2 * hsize + i)); + *change.add(i) = (*hidden.add(i) * *weight.add(3 * hsize + i) + *bias.add(3 * hsize + i)).tanh(); + } + + // caching cell + for i in 0..hsize { + *cell.add(i) = *cell.add(i) * *forget.add(i) + *ingate.add(i) * *change.add(i); + } + + for i in 0..hsize { + *hidden.add(i) = *outgate.add(i) * (*cell.add(i)).tanh(); + } +} + +// Predict LSTM output given an input +unsafe fn lstm_predict( + l: usize, + b: usize, + w: *const f64, + w2: *const f64, + s: *mut f64, + x: *const f64, + x2: *mut f64, +) { + for i in 0..b { + *x2.add(i) = *x.add(i) * *w2.add(i); + } + + let mut xp = x2; + let stop = 2 * l * b; + for i in (0..=stop - 1).step_by(2 * b) { + lstm_model(b, w.add(i * 4), w.add((i + b) * 4), s.add(i), s.add(i + b), xp); + xp = s.add(i); + } + + for i in 0..b { + *x2.add(i) = *xp.add(i) * *w2.add(b + i) + *w2.add(2 * b + i); + } +} + +// LSTM objective (loss function) +#[autodiff_reverse(d_lstm_unsafe_objective, Const, Const, Const, Duplicated, Duplicated, Const, Const, DuplicatedOnly)] +pub (crate) unsafe fn lstm_unsafe_objective(l: usize, c: usize, b: usize, main_params: *const f64, extra_params: *const f64, state: *mut f64, sequence: *const f64, loss: *mut f64) { + let mut total = 0.0; + let mut count = 0; + + //const double* input = &(sequence[0]); + let mut input = sequence; + let mut ypred = vec![0.0; b]; + let mut ynorm = vec![0.0; b]; + let mut lse; + + assert!(b > 0); + + let stop = (c - 1) * b; + for t in (0..=stop - 1).step_by(b) { + lstm_predict(l, b, main_params, extra_params, state, input, ypred.as_mut_ptr()); + lse = logsumexp(ypred.as_mut_ptr(), b); + for i in 0..b { + ynorm[i] = ypred[i] - lse; + } + + //let ygold = &sequence[t + b..]; + let ygold = sequence.add(t + b); + for i in 0..b { + total += *ygold.add(i) * ynorm[i]; + } + + count += b; + input = ygold; + } + + *loss = -total / count as f64; +} diff --git a/tests/ui/autodiff/adbench/ode-real/src/lib.rs b/tests/ui/autodiff/adbench/ode-real/src/lib.rs new file mode 100644 index 0000000000000..be1a6a903a4aa --- /dev/null +++ b/tests/ui/autodiff/adbench/ode-real/src/lib.rs @@ -0,0 +1,103 @@ +//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat --crate-type=staticlib +//@ build-pass +//@ no-prefer-dynamic +//@ needs-enzyme +#![feature(autodiff)] +#![feature(iter_next_chunk)] +#![feature(array_ptr_get)] +#![allow(non_snake_case)] +#![allow(non_camel_case_types)] +#![allow(non_upper_case_globals)] + +pub mod safe; +pub mod unsf; + +type StateType = [f64; 2 * N * N]; + +const N: usize = 32; + + +#[no_mangle] +pub extern "C" fn rust_lorenz_unsf(x: *const StateType, dxdt: *mut StateType, t: f64) { + let x: &StateType = unsafe { &*x }; + let dxdt: &mut StateType = unsafe { &mut *dxdt }; + unsafe {unsf::lorenz(x, dxdt, t)}; +} + + +#[no_mangle] +pub extern "C" fn rust_lorenz_safe(x: *const StateType, dxdt: *mut StateType, t: f64) { + let x: &StateType = unsafe { &*x }; + let dxdt: &mut StateType = unsafe { &mut *dxdt }; + safe::lorenz(x, dxdt, t); +} + +#[no_mangle] +pub extern "C" fn rust_dbrusselator_2d_loop_unsf(adjoint: *mut StateType, x: *const StateType, dx: *mut StateType, p: *const [f64;3], dp: *mut [f64;3], t: f64) { + let mut null1 = [0.; 1 * N * N]; + let mut null2 = [0.; 1 * N * N]; + let dx1: *mut f64 = dx.as_mut_ptr(); + let dx2: *mut f64 = unsafe { dx.as_mut_ptr().add(N*N) }; + let dadj1: *mut f64 = adjoint.as_mut_ptr(); + let dadj2: *mut f64 = unsafe { adjoint.as_mut_ptr().add(N*N) }; + let x1: *const f64 = x.as_ptr(); + let x2: *const f64 = unsafe { x.as_ptr().add(N*N) }; + + unsafe {unsf::dbrusselator_2d_loop_unsf(null1.as_mut_ptr(), dadj1, + null2.as_mut_ptr(), dadj2, + x1, dx1, + x2, dx2, + p as *mut f64, dp as *mut f64, t)}; +} + +#[no_mangle] +pub extern "C" fn rust_dbrusselator_2d_loop_safe(adjoint: *mut StateType, x: *const StateType, dx: *mut StateType, p: *const [f64;3], dp: *mut [f64;3], t: f64) { + let x: &StateType = unsafe { &*x }; + let dx: &mut StateType = unsafe { &mut *dx }; + let adjoint: &mut StateType = unsafe { &mut *adjoint }; + + let p: &[f64;3] = unsafe { &*p }; + let dp: &mut [f64;3] = unsafe { &mut *dp }; + + assert!(p[0] == 3.4); + assert!(p[1] == 1.); + assert!(p[2] == 10.); + assert!(t == 2.1); + + //let mut x1 = [0.; 2 * N * N]; + //let mut dx1 = [0.; 2 *N * N]; + //let (tmp1, tmp2) = x1.split_at_mut(N * N); + //let mut x1: [f64; N * N] = tmp1.try_into().unwrap(); + //let mut x2: [f64; N * N] = tmp2.try_into().unwrap(); + //init_brusselator(&mut x1, &mut x2); + //for i in 0..N*N { + // let tmp = (x1[i] - x[i]).abs(); + // if (tmp / x[i] > 1e-5) { + // dbg!(tmp); + // dbg!(tmp / x[i]); + // dbg!(i); + // dbg!(x1[i]); + // dbg!(x[i]); + // println!("x1[{}] = {} != x[{}] = {}", i, x1[i], i, x[i]); + // panic!(); + // } + //} + + // Alternative ways to split the inputs + //let [ mut dx1, mut dx2]: [[f64; N*N]; 2] = unsafe { *std::mem::transmute::<*mut StateType, &mut [[f64; N*N]; 2]>(dx) }; + //let [dx1, dx2]: &mut [[f64; N*N];2] = unsafe { dx.cast::<[[f64; N*N]; 2]>().as_mut().unwrap() }; + + // https://discord.com/channels/273534239310479360/273541522815713281/1236945105601040446 + let ([dx1, dx2], []): (&mut [[f64; N*N]], &mut [f64]) = dx.as_chunks_mut() else { unreachable!() }; + let ([dadj1, dadj2], []): (&mut [[f64; N*N]], &mut [f64])= adjoint.as_chunks_mut() else { unreachable!() }; + let ([x1, x2], []): (&[[f64; N*N]], &[f64])= x.as_chunks() else { unreachable!() }; + + let mut null1 = [0.; 1 * N * N]; + let mut null2 = [0.; 1 * N * N]; + safe::dbrusselator_2d_loop(&mut null1, dadj1, + &mut null2, dadj2, + x1, dx1, + x2, dx2, + p, dp, t); + return; +} diff --git a/tests/ui/autodiff/adbench/ode-real/src/safe.rs b/tests/ui/autodiff/adbench/ode-real/src/safe.rs new file mode 100644 index 0000000000000..3ccf8921d872e --- /dev/null +++ b/tests/ui/autodiff/adbench/ode-real/src/safe.rs @@ -0,0 +1,76 @@ +use std::autodiff::autodiff_reverse; +use std::convert::TryInto; + +const N: usize = 32; +const xmin: f64 = 0.; +const xmax: f64 = 1.; +const ymin: f64 = 0.; +const ymax: f64 = 1.; + +#[inline(always)] +fn range(min: f64, max: f64, i: usize, N_var: usize) -> f64 { + (max - min) / (N_var as f64 - 1.) * i as f64 + min +} + +fn brusselator_f(x: f64, y: f64, t: f64) -> f64 { + let eq1 = (x - 0.3) * (x - 0.3) + (y - 0.6) * (y - 0.6) <= 0.1 * 0.1; + let eq2 = t >= 1.1; + if eq1 && eq2 { + 5.0 + } else { + 0.0 + } +} + +#[expect(unused)] +fn init_brusselator(u: &mut [f64], v: &mut [f64]) { + assert!(u.len() == N * N); + assert!(v.len() == N * N); + for i in 0..N { + for j in 0..N { + let x = range(xmin, xmax, i, N); + let y = range(ymin, ymax, j, N); + u[N * i + j] = 22.0 * (y * (1.0 - y)) * (y * (1.0 - y)).sqrt(); + v[N * i + j] = 27.0 * (x * (1.0 - x)) * (x * (1.0 - x)).sqrt(); + } + } +} + +#[no_mangle] +#[autodiff_reverse(dbrusselator_2d_loop, Duplicated, Duplicated, Duplicated, Duplicated, Duplicated, Const)] +pub fn brusselator_2d_loop(d_u: &mut [f64;N*N], d_v: &mut [f64;N*N], u: &[f64;N*N], v: &[f64;N*N], p: &[f64;3], t: f64) { + let A = p[0]; + let B = p[1]; + let alpha = p[2]; + let dx = 1. / (N - 1) as f64; + let alpha = alpha / (dx * dx); + for i in 0..N { + for j in 0..N { + let x = range(xmin, xmax, i, N); + let y = range(ymin, ymax, j, N); + let ip1 = if i == N - 1 { i } else { i + 1 }; + let im1 = if i == 0 { i } else { i - 1 }; + let jp1 = if j == N - 1 { j } else { j + 1 }; + let jm1 = if j == 0 { j } else { j - 1 }; + let u2v = u[N * i + j] * u[N * i + j] * v[N * i + j]; + d_u[N * i + j] = alpha * (u[N * im1 + j] + u[N * ip1 + j] + u[N * i + jp1] + u[N * i + jm1] - 4. * u[N * i + j]) + + B + u2v - (A + 1.) * u[N * i + j] + brusselator_f(x, y, t); + d_v[N * i + j] = alpha * (v[N * im1 + j] + v[N * ip1 + j] + v[N * i + jp1] + v[N * i + jm1] - 4. * v[N * i + j]) + + A * u[N * i + j] - u2v; + } + } +} + +pub type StateType = [f64; 2 * N * N]; + +pub fn lorenz(x: &StateType, dxdt: &mut StateType, t: f64) { + let p = [3.4, 1., 10.]; + let (tmp1, tmp2) = dxdt.split_at_mut(N * N); + let mut dxdt1: [f64; N * N] = tmp1.try_into().unwrap(); + let mut dxdt2: [f64; N * N] = tmp2.try_into().unwrap(); + let (tmp1, tmp2) = x.split_at(N * N); + let u: [f64; N * N] = tmp1.try_into().unwrap(); + let v: [f64; N * N] = tmp2.try_into().unwrap(); + brusselator_2d_loop(&mut dxdt1, &mut dxdt2, &u, &v, &p, t); +} + diff --git a/tests/ui/autodiff/adbench/ode-real/src/unsf.rs b/tests/ui/autodiff/adbench/ode-real/src/unsf.rs new file mode 100644 index 0000000000000..7db673ed54ad6 --- /dev/null +++ b/tests/ui/autodiff/adbench/ode-real/src/unsf.rs @@ -0,0 +1,80 @@ +use std::autodiff::autodiff_reverse; +use std::convert::TryInto; + +const N: usize = 32; +const xmin: f64 = 0.; +const xmax: f64 = 1.; +const ymin: f64 = 0.; +const ymax: f64 = 1.; + +#[inline(always)] +fn range(min: f64, max: f64, i: usize, N_var: usize) -> f64 { + (max - min) / (N_var as f64 - 1.) * i as f64 + min +} + +fn brusselator_f(x: f64, y: f64, t: f64) -> f64 { + let eq1 = (x - 0.3) * (x - 0.3) + (y - 0.6) * (y - 0.6) <= 0.1 * 0.1; + let eq2 = t >= 1.1; + if eq1 && eq2 { + 5.0 + } else { + 0.0 + } +} + +#[expect(unused)] +unsafe fn init_brusselator(u: *mut f64, v: *mut f64) { + for i in 0..N { + for j in 0..N { + let x = range(xmin, xmax, i, N); + let y = range(ymin, ymax, j, N); + *u.add(N * i + j) = 22.0 * (y * (1.0 - y)) * (y * (1.0 - y)).sqrt(); + *v.add(N * i + j) = 27.0 * (x * (1.0 - x)) * (x * (1.0 - x)).sqrt(); + } + } +} + +#[no_mangle] +#[autodiff_reverse(dbrusselator_2d_loop_unsf, Duplicated, Duplicated, Duplicated, Duplicated, Duplicated, Const)] +pub unsafe fn brusselator_2d_loop_unsf(d_u: *mut f64, d_v: *mut f64, u: *const f64, v: *const f64, p: *const f64, t: f64) { + let A = *p.add(0); + let B = *p.add(1); + let alpha = *p.add(2); + let dx = 1. / (N - 1) as f64; + let alpha = alpha / (dx * dx); + for i in 0..N { + for j in 0..N { + let x = range(xmin, xmax, i, N); + let y = range(ymin, ymax, j, N); + let ip1 = if i == N - 1 { i } else { i + 1 }; + let im1 = if i == 0 { i } else { i - 1 }; + let jp1 = if j == N - 1 { j } else { j + 1 }; + let jm1 = if j == 0 { j } else { j - 1 }; + let u2v = *u.add(N * i + j) * *u.add(N * i + j) * *v.add(N * i + j); + *d_u.add(N * i + j) = alpha * (*u.add(N * im1 + j) + *u.add(N * ip1 + j) + *u.add(N * i + jp1) + *u.add(N * i + jm1) - 4. * *u.add(N * i + j)) + + B + u2v - (A + 1.) * *u.add(N * i + j) + brusselator_f(x, y, t); + *d_v.add(N * i + j) = alpha * (*v.add(N * im1 + j) + *v.add(N * ip1 + j) + *v.add(N * i + jp1) + *v.add(N * i + jm1) - 4. * *v.add(N * i + j)) + + A * *u.add(N * i + j) - u2v; + } + } +} + +type StateType = [f64; 2 * N * N]; + +pub unsafe fn lorenz(x: *const StateType, dxdt: *mut StateType, t: f64) { + let p = [3.4, 1., 10.]; + let x = x as *const f64; + let dxdt = dxdt as *mut f64; + let dxdt1: *mut f64 = dxdt as *mut f64; + let dxdt2: *mut f64 = unsafe {dxdt.add(N * N)} as *mut f64; + //let (tmp1, tmp2) = dxdt.split_at_mut(N * N); + //let mut dxdt1: [f64; N * N] = tmp1.try_into().unwrap(); + //let mut dxdt2: [f64; N * N] = tmp2.try_into().unwrap(); + let u: *const f64 = x as *const f64; + let v: *const f64 = unsafe{x.add(N * N)} as *const f64; + //let (tmp1, tmp2) = x.split_at(N * N); + //let u: [f64; N * N] = tmp1.try_into().unwrap(); + //let v: [f64; N * N] = tmp2.try_into().unwrap(); + unsafe {brusselator_2d_loop_unsf(dxdt1 as *mut f64, dxdt2 as *mut f64, u as *const f64, v as *const f64, p.as_ptr(), t)}; +} +