Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 63 additions & 58 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,9 @@ class LinearMapInfo {
/// The original function.
SILFunction *const original;

/// The derivative function.
SILFunction *const derivative;

/// Activity info of the original function.
const DifferentiableActivityInfo &activityInfo;

Expand Down Expand Up @@ -464,9 +467,16 @@ class LinearMapInfo {
/// A type converter, used to compute struct/enum SIL types.
Lowering::TypeConverter &typeConverter;

SILBuilder &builder;

private:
/// Remaps the given type into the derivative function's context.
SILType remapTypeInDerivative(SILType ty) {
if (ty.hasArchetype())
return derivative->mapTypeIntoContext(ty.mapTypeOutOfContext());
return derivative->mapTypeIntoContext(ty);
}

/// Adds a `VarDecl` member with the given name and type to the given nominal
/// declaration.
VarDecl *addVarDecl(NominalTypeDecl *nominal, StringRef name, Type type) {
auto &astCtx = nominal->getASTContext();
auto id = astCtx.getIdentifier(name);
Expand All @@ -485,9 +495,9 @@ class LinearMapInfo {
/// Retrieves the file unit that contains implicit declarations in the
/// current Swift module. If it does not exist, create one.
///
// FIXME: Currently it defaults to the file containing `origFn`, if it can be
// determined. Otherwise, it defaults to any file unit in the module. To
// handle this more properly, we should make a DerivedFileUnit class to
// FIXME: Currently it defaults to the file containing `original`, if it can
// be determined. Otherwise, it defaults to any file unit in the module. To
// handle this more properly, we could revive the DerivedFileUnit class to
// contain all synthesized implicit type declarations.
SourceFile &getDeclarationFileUnit() {
if (original->hasLocation())
Expand Down Expand Up @@ -699,7 +709,7 @@ class LinearMapInfo {
/// branching enum field.
void generateDifferentiationDataStructures(
ADContext &context, const SILAutoDiffIndices &indices,
SILFunction *assocFn);
SILFunction *derivative);

public:
bool shouldDifferentiateApplyInst(ApplyInst *ai);
Expand All @@ -710,10 +720,9 @@ class LinearMapInfo {

explicit LinearMapInfo(ADContext &context,
AutoDiffLinearMapKind kind,
SILFunction *original, SILFunction *assocFn,
SILFunction *original, SILFunction *derivative,
const SILAutoDiffIndices &indices,
const DifferentiableActivityInfo &activityInfo,
SILBuilder &builder);
const DifferentiableActivityInfo &activityInfo);

/// Returns the linear map struct associated with the given original block.
StructDecl *getLinearMapStruct(SILBasicBlock *origBB) const {
Expand Down Expand Up @@ -771,7 +780,9 @@ class LinearMapInfo {
/// `struct_extract` in the original function.
VarDecl *lookUpLinearMapDecl(SILInstruction *inst) {
auto lookup = linearMapValueMap.find(inst);
return lookup == linearMapValueMap.end() ? nullptr : lookup->getSecond();
assert(lookup != linearMapValueMap.end() &&
"No linear map declaration corresponding to the given instruction");
return lookup->getSecond();
}
};

Expand Down Expand Up @@ -1513,14 +1524,13 @@ static void collectMinimalIndicesForFunctionCall(

LinearMapInfo::LinearMapInfo(ADContext &context,
AutoDiffLinearMapKind kind,
SILFunction *original, SILFunction *assocFn,
SILFunction *original, SILFunction *derivative,
const SILAutoDiffIndices &indices,
const DifferentiableActivityInfo &activityInfo,
SILBuilder &builder)
: kind(kind), original(original), activityInfo(activityInfo),
indices(indices), typeConverter(context.getTypeConverter()),
builder(builder) {
generateDifferentiationDataStructures(context, indices, assocFn);
const DifferentiableActivityInfo &activityInfo)
: kind(kind), original(original), derivative(derivative),
activityInfo(activityInfo), indices(indices),
typeConverter(context.getTypeConverter()) {
generateDifferentiationDataStructures(context, indices, derivative);
}

/// Returns a flag that indicates whether the `apply` instruction should be
Expand Down Expand Up @@ -1615,7 +1625,7 @@ bool LinearMapInfo::shouldDifferentiateInstruction(SILInstruction *inst) {
}

/// Takes an `apply` instruction and adds its linear map function to the
/// linear map struct if it's active.
/// linear map struct if it is active.
void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
const SILAutoDiffIndices &indices) {
SmallVector<SILValue, 4> allResults;
Expand All @@ -1627,8 +1637,7 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,

// Check if there are any active results or arguments. If not, skip
// this instruction.
auto hasActiveResults = llvm::any_of(
allResults, [&](SILValue res) {
auto hasActiveResults = llvm::any_of(allResults, [&](SILValue res) {
return activityInfo.isActive(res, indices);
});
auto hasActiveArguments = llvm::any_of(
Expand All @@ -1645,9 +1654,12 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
// parameters from the function type.
// - Otherwise, use the active parameters.
AutoDiffIndexSubset *parameters;
auto originalFnSubstTy = ai->getSubstCalleeType();
if (originalFnSubstTy->isDifferentiable()) {
parameters = originalFnSubstTy->getDifferentiationParameterIndices();
auto origFnSubstTy = ai->getSubstCalleeType();
auto remappedOrigFnSubstTy =
remapTypeInDerivative(SILType::getPrimitiveObjectType(origFnSubstTy))
.castTo<SILFunctionType>();
if (remappedOrigFnSubstTy->isDifferentiable()) {
parameters = remappedOrigFnSubstTy->getDifferentiationParameterIndices();
} else {
parameters = AutoDiffIndexSubset::get(
original->getASTContext(),
Expand All @@ -1660,29 +1672,29 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
// Check for non-differentiable original function type.
auto checkNondifferentiableOriginalFunctionType =
[&](CanSILFunctionType origFnTy) {
// Check and diagnose non-differentiable arguments.
// Check non-differentiable arguments.
for (unsigned paramIndex : range(origFnTy->getNumParameters())) {
auto remappedParamType =
origFnTy->getParameters()[paramIndex].getSILStorageType();
if (applyIndices.isWrtParameter(paramIndex) &&
!origFnTy->getParameters()[paramIndex]
.getSILStorageType()
.isDifferentiable(builder.getModule()))
!remappedParamType.isDifferentiable(derivative->getModule()))
return true;
}
// Check non-differentiable results.
if (!origFnTy->getResults()[applyIndices.source]
.getSILStorageType()
.isDifferentiable(builder.getModule()))
auto remappedResultType =
origFnTy->getResults()[applyIndices.source].getSILStorageType();
if (!remappedResultType.isDifferentiable(derivative->getModule()))
return true;
return false;
};
if (checkNondifferentiableOriginalFunctionType(originalFnSubstTy))
if (checkNondifferentiableOriginalFunctionType(remappedOrigFnSubstTy))
return;

AutoDiffAssociatedFunctionKind assocFnKind(kind);
auto assocFnType = originalFnSubstTy->getAutoDiffAssociatedFunctionType(
auto assocFnType = remappedOrigFnSubstTy->getAutoDiffAssociatedFunctionType(
parameters, source, /*differentiationOrder*/ 1, assocFnKind,
context.getTypeConverter(),
LookUpConformanceInModule(builder.getModule().getSwiftModule()));
LookUpConformanceInModule(derivative->getModule().getSwiftModule()));

auto assocFnResultTypes =
assocFnType->getAllResultsType().castTo<TupleType>();
Expand Down Expand Up @@ -1745,8 +1757,6 @@ void LinearMapInfo::generateDifferentiationDataStructures(
for (auto &origBB : *original) {
for (auto &inst : origBB) {
if (auto *ai = dyn_cast<ApplyInst>(&inst)) {
LLVM_DEBUG(getADDebugStream()
<< "Adding linear map struct field for " << *ai);
// Check for active 'inout' arguments.
bool isInout = false;
auto paramInfos = ai->getSubstCalleeConv().getParameters();
Expand All @@ -1761,13 +1771,17 @@ void LinearMapInfo::generateDifferentiationDataStructures(
}
}
if (isInout)
break;
continue;

// Add linear map field to struct for active `apply` instructions.
// Skip array literal intrinsic applications since array literal
// initialization is linear and handled separately.
if (!shouldDifferentiateApplyInst(ai) || isArrayLiteralIntrinsic(ai))
continue;

// Add linear map to struct for active instructions.
// Do not add it for array functions since those are already linear
// and we don't need to add it to the struct.
if (shouldDifferentiateApplyInst(ai) && !isArrayLiteralIntrinsic(ai))
addLinearMapToStruct(context, ai, indices);
LLVM_DEBUG(getADDebugStream() << "Adding linear map struct field for "
<< *ai);
addLinearMapToStruct(context, ai, indices);
}
}
}
Expand Down Expand Up @@ -3327,8 +3341,8 @@ class VJPEmitter final
context(context), original(original), attr(attr), vjp(vjp),
invoker(invoker), activityInfo(getActivityInfo(
context, original, attr->getIndices(), vjp)),
pullbackInfo(context, AutoDiffLinearMapKind::Pullback, original,
vjp, attr->getIndices(), activityInfo, getBuilder()) {
pullbackInfo(context, AutoDiffLinearMapKind::Pullback, original, vjp,
attr->getIndices(), activityInfo) {
// Create empty pullback function.
pullback = createEmptyPullback();
context.getGeneratedFunctions().push_back(pullback);
Expand Down Expand Up @@ -4156,7 +4170,7 @@ class JVPEmitter final
//--------------------------------------------------------------------------//

/// The builder for the differential function.
SILBuilder differentialAndBuilder;
SILBuilder differentialBuilder;

/// Mapping from original basic blocks to corresponding differential basic
/// blocks.
Expand Down Expand Up @@ -4196,9 +4210,9 @@ class JVPEmitter final
ASTContext &getASTContext() const { return jvp->getASTContext(); }
SILModule &getModule() const { return jvp->getModule(); }
const SILAutoDiffIndices &getIndices() const { return attr->getIndices(); }
SILBuilder &getDifferentialBuilder() { return differentialAndBuilder; }
SILBuilder &getDifferentialBuilder() { return differentialBuilder; }
SILFunction &getDifferential() {
return differentialAndBuilder.getFunction();
return differentialBuilder.getFunction();
}
SILArgument *getDifferentialStructArgument(SILBasicBlock *origBB) {
#ifndef NDEBUG
Expand Down Expand Up @@ -4242,15 +4256,6 @@ class JVPEmitter final
return activityInfo;
}

static SILBuilder
initializeDifferentialAndBuilder(ADContext &context, SILFunction *original,
SILDifferentiableAttr *attr,
LinearMapInfo *linearMapInfo) {
auto *differential =
createEmptyDifferential(context, original, attr, linearMapInfo);
return SILBuilder(*differential);
}

//--------------------------------------------------------------------------//
// Differential struct mapping
//--------------------------------------------------------------------------//
Expand Down Expand Up @@ -5226,9 +5231,9 @@ class JVPEmitter final
invoker(invoker), activityInfo(getActivityInfo(
context, original, attr->getIndices(), jvp)),
differentialInfo(context, AutoDiffLinearMapKind::Differential, original,
jvp, attr->getIndices(), activityInfo, getBuilder()),
differentialAndBuilder(initializeDifferentialAndBuilder(
context, original, attr, &differentialInfo)),
jvp, attr->getIndices(), activityInfo),
differentialBuilder(SILBuilder(*createEmptyDifferential(
context, original, attr, &differentialInfo))),
diffLocalAllocBuilder(getDifferential()) {
// Create empty differential function.
context.getGeneratedFunctions().push_back(&getDifferential());
Expand Down
19 changes: 19 additions & 0 deletions test/AutoDiff/generics.swift
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,25 @@ extension TF_697_Sequential: TF_697_Layer where Layer1: TF_697_Layer {
}
}

// TF-817: Test remapping `apply` callee types in derivative function context.
struct TF_817<T> {
func foo(_ index: Int) -> T {
fatalError()
}
}
extension TF_817: Differentiable where T: Differentiable {
@differentiating(foo)
func vjpFoo(index: Int) -> (value: T, pullback: (T.TangentVector) -> (TangentVector)) {
fatalError()
}
}
extension TF_817 {
@differentiable(wrt: self where T: Differentiable)
public func test(index: Int) -> T {
return self.foo(0) // crash happened here
}
}

// Test layout requirements.

// The layout requirement is "contextual": the requirement is not on `T`, the
Expand Down