Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 24 additions & 24 deletions compiler/rustc_codegen_llvm/src/builder/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::builder::{Builder, PlaceRef, UNNAMED};
use crate::context::SimpleCx;
use crate::declare::declare_simple_fn;
use crate::llvm;
use crate::llvm::{Metadata, TRUE, Type};
use crate::llvm::{TRUE, Type};
use crate::value::Value;

pub(crate) fn adjust_activity_to_abi<'tcx>(
Expand Down Expand Up @@ -159,32 +159,32 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
let mut outer_pos: usize = 0;
let mut activity_pos = 0;

let enzyme_const = cx.create_metadata(b"enzyme_const");
let enzyme_out = cx.create_metadata(b"enzyme_out");
let enzyme_dup = cx.create_metadata(b"enzyme_dup");
let enzyme_dupv = cx.create_metadata(b"enzyme_dupv");
let enzyme_dupnoneed = cx.create_metadata(b"enzyme_dupnoneed");
let enzyme_dupnoneedv = cx.create_metadata(b"enzyme_dupnoneedv");
let global_const = cx.declare_global("enzyme_const", cx.type_ptr());
let global_out = cx.declare_global("enzyme_out", cx.type_ptr());
let global_dup = cx.declare_global("enzyme_dup", cx.type_ptr());
let global_dupv = cx.declare_global("enzyme_dupv", cx.type_ptr());
let global_dupnoneed = cx.declare_global("enzyme_dupnoneed", cx.type_ptr());
let global_dupnoneedv = cx.declare_global("enzyme_dupnoneedv", cx.type_ptr());

while activity_pos < inputs.len() {
let diff_activity = inputs[activity_pos as usize];
// Duplicated arguments received a shadow argument, into which enzyme will write the
// gradient.
let (activity, duplicated): (&Metadata, bool) = match diff_activity {
let (activity, duplicated): (&llvm::Value, bool) = match diff_activity {
DiffActivity::None => panic!("not a valid input activity"),
DiffActivity::Const => (enzyme_const, false),
DiffActivity::Active => (enzyme_out, false),
DiffActivity::ActiveOnly => (enzyme_out, false),
DiffActivity::Dual => (enzyme_dup, true),
DiffActivity::Dualv => (enzyme_dupv, true),
DiffActivity::DualOnly => (enzyme_dupnoneed, true),
DiffActivity::DualvOnly => (enzyme_dupnoneedv, true),
DiffActivity::Duplicated => (enzyme_dup, true),
DiffActivity::DuplicatedOnly => (enzyme_dupnoneed, true),
DiffActivity::FakeActivitySize(_) => (enzyme_const, false),
DiffActivity::Const => (global_const, false),
DiffActivity::Active => (global_out, false),
DiffActivity::ActiveOnly => (global_out, false),
DiffActivity::Dual => (global_dup, true),
DiffActivity::Dualv => (global_dupv, true),
DiffActivity::DualOnly => (global_dupnoneed, true),
DiffActivity::DualvOnly => (global_dupnoneedv, true),
DiffActivity::Duplicated => (global_dup, true),
DiffActivity::DuplicatedOnly => (global_dupnoneed, true),
DiffActivity::FakeActivitySize(_) => (global_const, false),
};
let outer_arg = outer_args[outer_pos];
args.push(cx.get_metadata_value(activity));
args.push(activity);
if matches!(diff_activity, DiffActivity::Dualv) {
let next_outer_arg = outer_args[outer_pos + 1];
let elem_bytes_size: u64 = match inputs[activity_pos + 1] {
Expand Down Expand Up @@ -244,7 +244,7 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
assert_eq!(cx.type_kind(next_outer_ty3), TypeKind::Integer);
args.push(next_outer_arg2);
}
args.push(cx.get_metadata_value(enzyme_const));
args.push(global_const);
args.push(next_outer_arg);
outer_pos += 2 + 2 * iterations;
activity_pos += 2;
Expand Down Expand Up @@ -353,13 +353,13 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
let mut args = Vec::with_capacity(num_args as usize + 1);
args.push(fn_to_diff);

let enzyme_primal_ret = cx.create_metadata(b"enzyme_primal_return");
let global_primal_ret = cx.declare_global("enzyme_primal_return", cx.type_ptr());
if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) {
args.push(cx.get_metadata_value(enzyme_primal_ret));
args.push(global_primal_ret);
}
if attrs.width > 1 {
let enzyme_width = cx.create_metadata(b"enzyme_width");
args.push(cx.get_metadata_value(enzyme_width));
let global_width = cx.declare_global("enzyme_width", cx.type_ptr());
args.push(global_width);
args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64));
}

Expand Down
23 changes: 0 additions & 23 deletions tests/ui/autodiff/autodiff_illegal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,6 @@ fn f14(x: f32) -> Foo {

type MyFloat = f32;

// We would like to support type alias to f32/f64 in argument type in the future,
// but that requires us to implement our checks at a later stage
// like THIR which has type information available.
#[autodiff_reverse(df15, Active, Active)]
fn f15(x: MyFloat) -> f32 {
//~^^ ERROR failed to resolve: use of undeclared type `MyFloat` [E0433]
unimplemented!()
}

// We would like to support type alias to f32/f64 in return type in the future
#[autodiff_reverse(df16, Active, Active)]
fn f16(x: f32) -> MyFloat {
Expand All @@ -136,13 +127,6 @@ fn f17(x: f64) -> F64Trans {
unimplemented!()
}

// We would like to support `#[repr(transparent)]` f32/f64 wrapper in argument type in the future
#[autodiff_reverse(df18, Active, Active)]
fn f18(x: F64Trans) -> f64 {
//~^^ ERROR failed to resolve: use of undeclared type `F64Trans` [E0433]
unimplemented!()
}

// Invalid return activity
#[autodiff_forward(df19, Dual, Active)]
fn f19(x: f32) -> f32 {
Expand All @@ -163,11 +147,4 @@ fn f21(x: f32) -> f32 {
unimplemented!()
}

struct DoesNotImplDefault;
#[autodiff_forward(df22, Dual)]
pub fn f22() -> DoesNotImplDefault {
//~^^ ERROR the function or associated item `default` exists for tuple `(DoesNotImplDefault, DoesNotImplDefault)`, but its trait bounds were not satisfied
unimplemented!()
}

fn main() {}
39 changes: 5 additions & 34 deletions tests/ui/autodiff/autodiff_illegal.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -107,53 +107,24 @@ LL | #[autodiff_reverse(df13, Reverse)]
| ^^^^^^^

error: invalid return activity Active in Forward Mode
--> $DIR/autodiff_illegal.rs:147:1
--> $DIR/autodiff_illegal.rs:131:1
|
LL | #[autodiff_forward(df19, Dual, Active)]
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

error: invalid return activity Dual in Reverse Mode
--> $DIR/autodiff_illegal.rs:153:1
--> $DIR/autodiff_illegal.rs:137:1
|
LL | #[autodiff_reverse(df20, Active, Dual)]
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

error: invalid return activity Duplicated in Reverse Mode
--> $DIR/autodiff_illegal.rs:160:1
--> $DIR/autodiff_illegal.rs:144:1
|
LL | #[autodiff_reverse(df21, Active, Duplicated)]
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

error[E0433]: failed to resolve: use of undeclared type `MyFloat`
--> $DIR/autodiff_illegal.rs:116:1
|
LL | #[autodiff_reverse(df15, Active, Active)]
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ use of undeclared type `MyFloat`

error[E0433]: failed to resolve: use of undeclared type `F64Trans`
--> $DIR/autodiff_illegal.rs:140:1
|
LL | #[autodiff_reverse(df18, Active, Active)]
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ use of undeclared type `F64Trans`

error[E0599]: the function or associated item `default` exists for tuple `(DoesNotImplDefault, DoesNotImplDefault)`, but its trait bounds were not satisfied
--> $DIR/autodiff_illegal.rs:167:1
|
LL | struct DoesNotImplDefault;
| ------------------------- doesn't satisfy `DoesNotImplDefault: Default`
LL | #[autodiff_forward(df22, Dual)]
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ function or associated item cannot be called on `(DoesNotImplDefault, DoesNotImplDefault)` due to unsatisfied trait bounds
|
= note: the following trait bounds were not satisfied:
`DoesNotImplDefault: Default`
which is required by `(DoesNotImplDefault, DoesNotImplDefault): Default`
help: consider annotating `DoesNotImplDefault` with `#[derive(Default)]`
|
LL + #[derive(Default)]
LL | struct DoesNotImplDefault;
|

error: aborting due to 21 previous errors
error: aborting due to 18 previous errors

Some errors have detailed explanations: E0428, E0433, E0599, E0658.
Some errors have detailed explanations: E0428, E0658.
For more information about an error, try `rustc --explain E0428`.
41 changes: 41 additions & 0 deletions tests/ui/autodiff/incremental.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//@ revisions: DEBUG RELEASE
//@[RELEASE] compile-flags: -Zautodiff=Enable,NoTT -C opt-level=3 -Clto=fat
//@[DEBUG] compile-flags: -Zautodiff=Enable,NoTT -C opt-level=0 -Clto=fat -C debuginfo=2
//@ needs-enzyme
//@ incremental
//@ no-prefer-dynamic
//@ build-pass
#![crate_type = "bin"]
#![feature(autodiff)]

// We used to use llvm's metadata to instruct enzyme how to differentiate a function.
// In debug mode we would use incremental compilation which caused the metadata to be
// dropped. We now use globals instead and add this test to verify that incremental
// keeps working. Also testing debug mode while at it.

use std::autodiff::autodiff_reverse;

#[autodiff_reverse(bar, Duplicated, Duplicated)]
pub fn foo(r: &[f64; 10], res: &mut f64) {
let mut output = [0.0; 10];
output[0] = r[0];
output[1] = r[1] * r[2];
output[2] = r[4] * r[5];
output[3] = r[2] * r[6];
output[4] = r[1] * r[7];
output[5] = r[2] * r[8];
output[6] = r[1] * r[9];
output[7] = r[5] * r[6];
output[8] = r[5] * r[7];
output[9] = r[4] * r[8];
*res = output.iter().sum();
}
fn main() {
let inputs = Box::new([3.1; 10]);
let mut d_inputs = Box::new([0.0; 10]);
let mut res = Box::new(0.0);
let mut d_res = Box::new(1.0);

bar(&inputs, &mut d_inputs, &mut res, &mut d_res);
dbg!(&d_inputs);
}
Loading