Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions compiler/rustc_codegen_llvm/src/builder/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
use rustc_ast::expand::typetree::FncTree;
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv};
use rustc_middle::ty::{PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
use rustc_middle::{bug, ty};
use rustc_target::callconv::PassMode;
use tracing::debug;
Expand All @@ -16,25 +16,23 @@ use crate::llvm::{self, TRUE, Type, Value};

pub(crate) fn adjust_activity_to_abi<'tcx>(
tcx: TyCtxt<'tcx>,
instance: Instance<'tcx>,
fn_ptr_ty: Ty<'tcx>,
typing_env: TypingEnv<'tcx>,
da: &mut Vec<DiffActivity>,
) {
let fn_ty = instance.ty(tcx, typing_env);

if !matches!(fn_ty.kind(), ty::FnDef(..)) {
bug!("expected fn def for autodiff, got {:?}", fn_ty);
if !matches!(fn_ptr_ty.kind(), ty::FnPtr(..)) {
bug!("expected fn ptr for autodiff, got {:?}", fn_ptr_ty);
}

// We don't actually pass the types back into the type system.
// All we do is decide how to handle the arguments.
let sig = fn_ty.fn_sig(tcx).skip_binder();
let poly_sig = fn_ptr_ty.fn_sig(tcx);
let sig = poly_sig.skip_binder();

// FIXME(Sa4dUs): pass proper varargs once we have support for differentiating variadic functions
let Ok(fn_abi) =
tcx.fn_abi_of_instance(typing_env.as_query_input((instance, ty::List::empty())))
else {
bug!("failed to get fn_abi of instance with empty varargs");
let pci = PseudoCanonicalInput { typing_env, value: (poly_sig, ty::List::empty()) };
let Ok(fn_abi) = tcx.fn_abi_of_fn_ptr(pci) else {
bug!("failed to get fn_abi of fn_ptr with empty varargs");
};

let mut new_activities = vec![];
Expand Down
6 changes: 5 additions & 1 deletion compiler/rustc_codegen_llvm/src/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1229,9 +1229,13 @@ fn codegen_autodiff<'ll, 'tcx>(
bug!("could not find autodiff attrs")
};

let fn_ty = fn_source.ty(tcx, TypingEnv::fully_monomorphized());
let fn_sig = fn_ty.fn_sig(tcx);
let fn_ptr_ty = Ty::new_fn_ptr(tcx, fn_sig);

adjust_activity_to_abi(
tcx,
fn_source,
fn_ptr_ty,
TypingEnv::fully_monomorphized(),
&mut diff_attrs.input_activity,
);
Expand Down
Loading