Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -3001,18 +3001,18 @@ ERROR(differentiable_attr_protocol_req_assoc_func,none,
ERROR(differentiable_attr_stored_property_variable_unsupported,none,
"'@differentiable' attribute on stored property cannot specify "
"'jvp:' or 'vjp:'", ())
ERROR(differentiable_attr_class_member_no_dynamic_self,none,
"'@differentiable' attribute cannot be declared on class methods "
ERROR(differentiable_attr_class_member_dynamic_self_result_unsupported,none,
"'@differentiable' attribute cannot be declared on class members "
"returning 'Self'", ())
// TODO(TF-654): Remove when differentiation supports class initializers.
ERROR(differentiable_attr_class_init_not_yet_supported,none,
"'@differentiable' attribute does not yet support class initializers",
())
ERROR(differentiable_attr_nonfinal_class_init_unsupported,none,
"'@differentiable' attribute cannot be declared on 'init' in a non-final "
"class; consider making %0 final", (Type))
ERROR(differentiable_attr_empty_where_clause,none,
"empty 'where' clause in '@differentiable' attribute", ())
// SWIFT_ENABLE_TENSORFLOW
ERROR(differentiable_attr_nongeneric_trailing_where,none,
"trailing 'where' clause in '@differentiable' attribute of non-generic function %0", (DeclName))
"trailing 'where' clause in '@differentiable' attribute of non-generic "
"function %0", (DeclName))
ERROR(differentiable_attr_where_clause_for_nongeneric_original,none,
"'where' clause is valid only when original function is generic %0",
(DeclName))
Expand Down Expand Up @@ -3049,6 +3049,12 @@ ERROR(derivative_attr_not_in_same_file_as_original,none,
"derivative not in the same file as the original function", ())
ERROR(derivative_attr_original_stored_property_unsupported,none,
"cannot register derivative for stored property %0", (DeclNameRef))
ERROR(derivative_attr_class_member_dynamic_self_result_unsupported,none,
"cannot register derivative for class member %0 returning 'Self'",
(DeclNameRef))
ERROR(derivative_attr_nonfinal_class_init_unsupported,none,
"cannot register derivative for 'init' in a non-final class; consider "
"making %0 final", (Type))
ERROR(derivative_attr_original_already_has_derivative,none,
"a derivative already exists for %0", (DeclName))
NOTE(derivative_attr_duplicate_note,none,
Expand Down
10 changes: 10 additions & 0 deletions include/swift/SILOptimizer/Utils/Differentiation/PullbackEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,16 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
void visitUnconditionalCheckedCastAddrInst(
UnconditionalCheckedCastAddrInst *uccai);

/// Handle `unchecked_ref_cast` instruction.
/// Original: y = unchecked_ref_cast x
/// Adjoint: adj[x] += adj[y] (assuming x' and y' have the same type)
void visitUncheckedRefCastInst(UncheckedRefCastInst *urci);

/// Handle `upcast` instruction.
/// Original: y = upcast x
/// Adjoint: adj[x] += adj[y] (assuming x' and y' have the same type)
void visitUpcastInst(UpcastInst *ui);

#define NOT_DIFFERENTIABLE(INST, DIAG) void visit##INST##Inst(INST##Inst *inst);
#undef NOT_DIFFERENTIABLE

Expand Down
19 changes: 9 additions & 10 deletions lib/SIL/SILFunctionType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2705,17 +2705,15 @@ TypeConverter::getConstantInfo(TypeExpansionContext expansion,
loweredInterfaceType);

// SWIFT_ENABLE_TENSORFLOW
// In the case of autodiff derivative functions, the above computations
// determine `silFnType` by first computing the derivative function type at
// the AST level and then lowering that. Unfortunately, the actual
// SILFunctionType for the function is determined by first lowering the
// function's AST type, and then computing the derivative function type at the
// SIL level. "Lowering" does not commute with "getting the autodiff
// associated type", so these two computations produce different results.
// Therefore `silFnType` is not the actual type of the function that
// `constant` refers to.
// For derivative functions, the above computations determine `silFnType`
// by first computing the derivative AST function type and then lowering it to
// SIL. Unfortunately, the expected derivative SIL function type is determined
// by first lowering the original function's AST type, and then computing its
// SIL derivative function type. "Lowering" does not commute with "getting the
// derivative type", so these two computations produce different results.
// Therefore, `silFnType` is not the expected SIL derivative function type.
//
// We hackily fix this problem by redoing the computation in the right order.
// We fix this problem by performing the computation in the right order.
if (auto *autoDiffFuncId = constant.autoDiffDerivativeFunctionIdentifier) {
auto origFnConstantInfo = getConstantInfo(
TypeExpansionContext::minimal(), constant.asAutoDiffOriginalFunction());
Expand All @@ -2725,6 +2723,7 @@ TypeConverter::getConstantInfo(TypeExpansionContext expansion,
loweredIndices, /*resultIndex*/ 0, autoDiffFuncId->getKind(),
*this, LookUpConformanceInModule(&M));
}
// SWIFT_ENABLE_TENSORFLOW END

LLVM_DEBUG(llvm::dbgs() << "lowering type for constant ";
constant.print(llvm::dbgs());
Expand Down
37 changes: 31 additions & 6 deletions lib/SILGen/SILGenPoly.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3749,8 +3749,9 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
auto *thunk = fb.getOrCreateFunction(
loc, name, customDerivativeFn->getLinkage(), thunkFnTy, IsBare,
IsNotTransparent, customDerivativeFn->isSerialized(),
customDerivativeFn->isDynamicallyReplaceable(), customDerivativeFn->getEntryCount(),
IsThunk, customDerivativeFn->getClassSubclassScope());
customDerivativeFn->isDynamicallyReplaceable(),
customDerivativeFn->getEntryCount(), IsThunk,
customDerivativeFn->getClassSubclassScope());
thunk->setInlineStrategy(AlwaysInline);
if (!thunk->empty())
return thunk;
Expand All @@ -3762,15 +3763,39 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
thunkSGF.collectThunkParams(loc, params, &indirectResults);

auto *fnRef = thunkSGF.B.createFunctionRef(loc, customDerivativeFn);
auto fnRefType =
fnRef->getType().castTo<SILFunctionType>();
auto fnRefType = fnRef->getType().castTo<SILFunctionType>();

// Collect thunk arguments, converting ownership.
SmallVector<SILValue, 8> arguments;
for (auto *indRes : indirectResults)
arguments.push_back(indRes);
forwardFunctionArguments(thunkSGF, loc, fnRefType, params,
arguments);
forwardFunctionArguments(thunkSGF, loc, fnRefType, params, arguments);

// Special support for thunking class initializer derivatives.
//
// User-defined custom derivatives take a metatype as the last parameter:
// - `$(Param0, Param1, ..., @thick Class.Type) -> (...)`
// But class initializers take an allocated instance as the last parameter:
// - `$(Param0, Param1, ..., @owned Class) -> (...)`
//
// Adjust forwarded arguments:
// - Pop the last `@owned Class` argument.
// - Create a `@thick Class.Type` value and pass it as the last argument.
auto *origAFD =
cast<AbstractFunctionDecl>(originalFn->getDeclContext()->getAsDecl());
if (isa<ConstructorDecl>(origAFD) &&
SILDeclRef(origAFD, SILDeclRef::Kind::Initializer).mangle() ==
originalFn->getName()) {
auto classArgument = arguments.pop_back_val();
auto *classDecl = classArgument->getType().getClassOrBoundGenericClass();
assert(classDecl && "Expected last argument to have class type");
auto classMetatype = MetatypeType::get(
classDecl->getDeclaredInterfaceType(), MetatypeRepresentation::Thick);
auto canClassMetatype = classMetatype->getCanonicalType();
auto *metatype = thunkSGF.B.createMetatype(
loc, SILType::getPrimitiveObjectType(canClassMetatype));
arguments.push_back(metatype);
}
// Apply function argument.
auto apply = thunkSGF.emitApplyWithRethrow(
loc, fnRef, /*substFnType*/ fnRef->getType(),
Expand Down
28 changes: 24 additions & 4 deletions lib/SILOptimizer/Utils/Differentiation/PullbackEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,8 @@ SILValue PullbackEmitter::getAdjointProjection(SILBasicBlock *origBB,
auto *tanField = cast<VarDecl>(tanFieldLookup.front());
// Create a local allocation for the element adjoint buffer.
auto eltTanType = tanField->getValueInterfaceType()->getCanonicalType();
auto eltTanSILType = SILType::getPrimitiveAddressType(eltTanType);
auto eltTanSILType =
remapType(SILType::getPrimitiveAddressType(eltTanType));
auto *eltAdjBuffer = createFunctionLocalAllocation(eltTanSILType, loc);
builder.emitScopedBorrowOperation(
loc, adjClass, [&](SILValue borrowedAdjClass) {
Expand Down Expand Up @@ -1090,7 +1091,7 @@ PullbackEmitter::getArrayAdjointElementBuffer(SILValue arrayAdjoint,
auto arrayTanType = cast<StructType>(arrayAdjoint->getType().getASTType());
auto arrayType = arrayTanType->getParent()->castTo<BoundGenericStructType>();
auto eltTanType = arrayType->getGenericArgs().front()->getCanonicalType();
auto eltTanSILType = SILType::getPrimitiveAddressType(eltTanType);
auto eltTanSILType = remapType(SILType::getPrimitiveAddressType(eltTanType));
// Get `function_ref` and generic signature of
// `Array.TangentVector.subscript.getter`.
auto *arrayTanStructDecl = arrayTanType->getStructOrBoundGenericStruct();
Expand Down Expand Up @@ -1602,12 +1603,11 @@ void PullbackEmitter::visitLoadOperation(SingleValueInstruction *inst) {
void PullbackEmitter::visitStoreOperation(SILBasicBlock *bb, SILLocation loc,
SILValue origSrc, SILValue origDest) {
auto &adjBuf = getAdjointBuffer(bb, origDest);
auto bufType = remapType(adjBuf->getType());
auto adjVal =
builder.emitLoadValueOperation(loc, adjBuf, LoadOwnershipQualifier::Take);
recordTemporary(adjVal);
addAdjointValue(bb, origSrc, makeConcreteAdjointValue(adjVal), loc);
emitZeroIndirect(bufType.getASTType(), adjBuf, loc);
emitZeroIndirect(adjBuf->getType().getASTType(), adjBuf, loc);
}

void PullbackEmitter::visitStoreInst(StoreInst *si) {
Expand Down Expand Up @@ -1672,6 +1672,26 @@ void PullbackEmitter::visitUnconditionalCheckedCastAddrInst(
emitZeroIndirect(destType.getASTType(), adjDest, uccai->getLoc());
}

void PullbackEmitter::visitUncheckedRefCastInst(UncheckedRefCastInst *urci) {
auto *bb = urci->getParent();
assert(urci->getOperand()->getType().isObject());
assert(getRemappedTangentType(urci->getOperand()->getType()) ==
getRemappedTangentType(urci->getType()) &&
"Operand/result must have the same `TangentVector` type");
auto adj = getAdjointValue(bb, urci);
addAdjointValue(bb, urci->getOperand(), adj, urci->getLoc());
}

void PullbackEmitter::visitUpcastInst(UpcastInst *ui) {
auto *bb = ui->getParent();
assert(ui->getOperand()->getType().isObject());
assert(getRemappedTangentType(ui->getOperand()->getType()) ==
getRemappedTangentType(ui->getType()) &&
"Operand/result must have the same `TangentVector` type");
auto adj = getAdjointValue(bb, ui);
addAdjointValue(bb, ui->getOperand(), adj, ui->getLoc());
}

#define NOT_DIFFERENTIABLE(INST, DIAG) \
void PullbackEmitter::visit##INST##Inst(INST##Inst *inst) { \
getContext().emitNondifferentiabilityError(inst, getInvoker(), \
Expand Down
77 changes: 58 additions & 19 deletions lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3969,32 +3969,39 @@ llvm::Expected<IndexSubset *> DifferentiableAttributeTypeCheckRequest::evaluate(
return nullptr;
}

// Diagnose if original function is an invalid class member.
bool isOriginalClassMember = original->getDeclContext() &&
original->getDeclContext()->getSelfClassDecl();

// Diagnose if original function is an invalid class member.
if (isOriginalClassMember) {
// Class methods returning dynamic `Self` are not supported.
// (For class methods, dynamic `Self` is supported only as the single
// result - tuple-returning JVPs/VJPs would not type-check.)
if (auto *originalFn = dyn_cast<FuncDecl>(original)) {
if (originalFn->hasDynamicSelfResult()) {
diags.diagnose(attr->getLocation(),
diag::differentiable_attr_class_member_no_dynamic_self);
auto *classDecl = original->getDeclContext()->getSelfClassDecl();
assert(classDecl);
// Class members returning dynamic `Self` are not supported.
// Dynamic `Self` is supported only as a single top-level result for class
// members. JVP/VJP functions returning `(Self, ...)` tuples would not
// type-check.
bool diagnoseDynamicSelfResult = original->hasDynamicSelfResult();
if (diagnoseDynamicSelfResult) {
// Diagnose class initializers in non-final classes.
if (isa<ConstructorDecl>(original)) {
if (!classDecl->isFinal()) {
diags.diagnose(
attr->getLocation(),
diag::differentiable_attr_nonfinal_class_init_unsupported,
classDecl->getDeclaredInterfaceType());
attr->setInvalid();
return nullptr;
}
}
// Diagnose all other declarations returning dynamic `Self`.
else {
diags.diagnose(
attr->getLocation(),
diag::
differentiable_attr_class_member_dynamic_self_result_unsupported);
attr->setInvalid();
return nullptr;
}
}

// TODO(TF-654): Class initializers are not yet supported.
// Extra JVP/VJP type calculation logic is necessary because classes have
// both allocators and initializers.
if (auto *initDecl = dyn_cast<ConstructorDecl>(original)) {
diags.diagnose(attr->getLocation(),
diag::differentiable_attr_class_init_not_yet_supported);
attr->setInvalid();
return nullptr;
}
}

// Resolve the derivative generic signature.
Expand Down Expand Up @@ -4284,6 +4291,38 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
return true;
}
}
// Diagnose if original function is an invalid class member.
bool isOriginalClassMember =
originalAFD->getDeclContext() &&
originalAFD->getDeclContext()->getSelfClassDecl();
if (isOriginalClassMember) {
auto *classDecl = originalAFD->getDeclContext()->getSelfClassDecl();
assert(classDecl);
// Class members returning dynamic `Self` are not supported.
// Dynamic `Self` is supported only as a single top-level result for class
// members. JVP/VJP functions returning `(Self, ...)` tuples would not
// type-check.
bool diagnoseDynamicSelfResult = originalAFD->hasDynamicSelfResult();
if (diagnoseDynamicSelfResult) {
// Diagnose class initializers in non-final classes.
if (isa<ConstructorDecl>(originalAFD)) {
if (!classDecl->isFinal()) {
diags.diagnose(attr->getLocation(),
diag::derivative_attr_nonfinal_class_init_unsupported,
classDecl->getDeclaredInterfaceType());
return true;
}
}
// Diagnose all other declarations returning dynamic `Self`.
else {
diags.diagnose(
attr->getLocation(),
diag::derivative_attr_class_member_dynamic_self_result_unsupported,
DeclNameRef(originalAFD->getFullName()));
return true;
}
}
}
attr->setOriginalFunction(originalAFD);

// Get the resolved differentiability parameter indices.
Expand Down
17 changes: 14 additions & 3 deletions test/AutoDiff/Sema/differentiable_attr_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1080,8 +1080,7 @@ class Super: Differentiable {

var base: Float

// NOTE(TF-654): Class initializers are not yet supported.
// expected-error @+1 {{'@differentiable' attribute does not yet support class initializers}}
// expected-error @+1 {{'@differentiable' attribute cannot be declared on 'init' in a non-final class; consider making 'Super' final}}
@differentiable
init(base: Float) {
self.base = base
Expand Down Expand Up @@ -1124,7 +1123,7 @@ class Super: Differentiable {
func instanceMethod<T>(_ x: Float, y: T) -> Float { x }

// expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
// expected-error @+1 {{'@differentiable' attribute cannot be declared on class methods returning 'Self'}}
// expected-error @+1 {{'@differentiable' attribute cannot be declared on class members returning 'Self'}}
@differentiable(vjp: vjpDynamicSelfResult)
func dynamicSelfResult() -> Self { self }

Expand All @@ -1148,6 +1147,18 @@ class Sub: Super {
override func testSuperclassDerivatives(_ x: Float) -> Float { x }
}

final class FinalClass: Differentiable {
typealias TangentVector = DummyTangentVector
func move(along _: TangentVector) {}

var base: Float

@differentiable
init(base: Float) {
self.base = base
}
}

// Test unsupported accessors: `set`, `_read`, `_modify`.

struct UnsupportedAccessors: Differentiable {
Expand Down
Loading