Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 28 additions & 57 deletions include/swift/SILOptimizer/Differentiation/LinearMapInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,26 +69,23 @@ class LinearMapInfo {
/// Differentiation indices of the function.
const AutoDiffConfig config;

/// Mapping from original basic blocks to linear map structs.
llvm::DenseMap<SILBasicBlock *, StructDecl *> linearMapStructs;
/// Mapping from original basic blocks to linear map tuple types.
llvm::DenseMap<SILBasicBlock *, TupleType *> linearMapTuples;

/// Mapping from original basic blocks to branching trace enums.
/// For pullbacks: these are predecessor enums.
/// For differentials: these are successor enums.
llvm::DenseMap<SILBasicBlock *, EnumDecl *> branchingTraceDecls;

/// Mapping from `apply` instructions in the original function to the
/// corresponding linear map field declaration in the linear map struct.
llvm::DenseMap<ApplyInst *, VarDecl *> linearMapFieldMap;
/// corresponding linear map tuple type index.
llvm::DenseMap<ApplyInst *, unsigned> linearMapIndexMap;

/// Mapping from predecessor-successor basic block pairs in the original
/// function to the corresponding branching trace enum case.
llvm::DenseMap<std::pair<SILBasicBlock *, SILBasicBlock *>, EnumElementDecl *>
branchingTraceEnumCases;

/// Mapping from linear map structs to their branching trace enum fields.
llvm::DenseMap<StructDecl *, VarDecl *> linearMapStructEnumFields;

/// Blocks in a loop.
llvm::SmallSetVector<SILBasicBlock *, 4> blocksInLoop;

Expand All @@ -102,37 +99,21 @@ class LinearMapInfo {
/// Remaps the given type into the derivative function's context.
SILType remapTypeInDerivative(SILType ty);

/// Adds a `VarDecl` member with the given name and type to the given nominal
/// declaration.
VarDecl *addVarDecl(NominalTypeDecl *nominal, StringRef name, Type type);

/// Retrieves the file unit that contains implicit declarations in the
/// current Swift module.
SynthesizedFileUnit &getSynthesizedFile() { return synthesizedFile; }

/// Computes and sets the access level for the given nominal type, given the
/// original function linkage.
void computeAccessLevel(NominalTypeDecl *nominal, SILLinkage originalLinkage);

/// Creates an enum declaration with the given JVP/VJP generic signature,
/// whose cases represent the predecessors/successors of the given original
/// block.
EnumDecl *createBranchingTraceDecl(SILBasicBlock *originalBB,
CanGenericSignature genericSig,
SILLoopInfo *loopInfo);
CanGenericSignature genericSig);
void populateBranchingTraceDecl(SILBasicBlock *originalBB,
SILLoopInfo *loopInfo);

/// Creates a struct declaration with the given JVP/VJP generic signature, for
/// storing the linear map values and predecessor/successor basic block of the
/// given original block.
StructDecl *createLinearMapStruct(SILBasicBlock *originalBB,
CanGenericSignature genericSig);

/// Adds a linear map field to the linear map struct.
VarDecl *addLinearMapDecl(ApplyInst *ai, SILType linearMapType);

/// Given an `apply` instruction, conditionally adds a linear map struct field
/// for its linear map function if it is active.
void addLinearMapToStruct(ADContext &context, ApplyInst *ai);
/// Given an `apply` instruction, conditionally gets a linear map tuple field
/// AST type for its linear map function if it is active.
Type getLinearMapType(ADContext &context, ApplyInst *ai);

/// Generates linear map struct and branching enum declarations for the given
/// function. Linear map structs are populated with linear map fields and a
Expand All @@ -153,22 +134,20 @@ class LinearMapInfo {
const DifferentiableActivityInfo &activityInfo,
SILLoopInfo *loopInfo);

/// Returns the linear map struct associated with the given original block.
StructDecl *getLinearMapStruct(SILBasicBlock *origBB) const {
return linearMapStructs.lookup(origBB);
/// Returns the linear map tuple associated with the given original block.
TupleType *getLinearMapTupleType(SILBasicBlock *origBB) const {
return linearMapTuples.lookup(origBB);
}

/// Returns the lowered SIL type of the linear map struct associated with the
/// Returns the lowered SIL type of the linear map tuple associated with the
/// given original block.
SILType getLinearMapStructLoweredType(SILBasicBlock *origBB) const {
SILType getLinearMapTupleLoweredType(SILBasicBlock *origBB) const {
auto derivativeGenSig =
derivative->getLoweredFunctionType()->getSubstGenericSignature();
auto *linMapStruct = getLinearMapStruct(origBB);
auto linMapStructType =
linMapStruct->getDeclaredInterfaceType()->getReducedType(
derivativeGenSig);
Lowering::AbstractionPattern pattern(derivativeGenSig, linMapStructType);
return typeConverter.getLoweredType(pattern, linMapStructType,
auto linMapTupleType =
getLinearMapTupleType(origBB)->getReducedType(derivativeGenSig);
Lowering::AbstractionPattern pattern(derivativeGenSig, linMapTupleType);
return typeConverter.getLoweredType(pattern, linMapTupleType,
TypeExpansionContext::minimal());
}

Expand Down Expand Up @@ -199,29 +178,21 @@ class LinearMapInfo {
return branchingTraceEnumCases.lookup({origPredBB, origSuccBB});
}

/// Returns the mapping from linear map structs to their branching trace enum
/// fields.
llvm::DenseMap<StructDecl *, VarDecl *> &getLinearMapStructEnumFields() {
return linearMapStructEnumFields;
}

/// Returns the branching trace enum field for the linear map struct of the
/// given original block.
VarDecl *lookUpLinearMapStructEnumField(SILBasicBlock *origBB) const {
auto *linearMapStruct = getLinearMapStruct(origBB);
return linearMapStructEnumFields.lookup(linearMapStruct);
}

/// Finds the linear map declaration in the pullback struct for the given
/// Finds the linear map index in the pullback tuple for the given
/// `apply` instruction in the original function.
VarDecl *lookUpLinearMapDecl(ApplyInst *ai) const {
unsigned lookUpLinearMapIndex(ApplyInst *ai) const {
assert(ai->getFunction() == original);
auto lookup = linearMapFieldMap.find(ai);
assert(lookup != linearMapFieldMap.end() &&
auto lookup = linearMapIndexMap.find(ai);
assert(lookup != linearMapIndexMap.end() &&
"No linear map field corresponding to the given `apply`");
return lookup->getSecond();
}

Type lookUpLinearMapType(ApplyInst *ai) const {
unsigned idx = lookUpLinearMapIndex(ai);
return getLinearMapTupleType(ai->getParentBlock())->getElement(idx).getType();
}

bool hasLoops() const {
return !blocksInLoop.empty();
}
Expand Down
31 changes: 31 additions & 0 deletions lib/SIL/IR/SILPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3449,6 +3449,36 @@ static void printSILDifferentiabilityWitnesses(
dw->print(Ctx.OS(), Ctx.printVerbose());
}

static void printSILLinearMapTypes(SILPrintContext &Ctx,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this only for printing branching trace enums now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

const ModuleDecl *M) {
auto &OS = Ctx.OS();

PrintOptions Options = PrintOptions::printSIL();
Options.TypeDefinitions = true;
Options.VarInitializers = true;
Options.ExplodePatternBindingDecls = true;
Options.SkipImplicit = false;
Options.PrintGetSetOnRWProperties = true;
Options.PrintInSILBody = false;

SmallVector<Decl *, 32> topLevelDecls;
M->getTopLevelDecls(topLevelDecls);
for (const Decl *D : topLevelDecls) {
if (D->getDeclContext() == M)
continue;

if (!isa<StructDecl>(D) && !isa<EnumDecl>(D))
continue;

StringRef Name = cast<TypeDecl>(D)->getNameStr();
if (!Name.startswith("_AD__"))
continue;

D->print(OS, Options);
OS << "\n\n";
}
}

static void
printSILCoverageMaps(SILPrintContext &Ctx,
const SILModule::CoverageMapCollectionType &CoverageMaps) {
Expand Down Expand Up @@ -3624,6 +3654,7 @@ void SILModule::print(SILPrintContext &PrintCtx, ModuleDecl *M,
printSILGlobals(PrintCtx, getSILGlobalList());
printSILDifferentiabilityWitnesses(PrintCtx,
getDifferentiabilityWitnessList());
printSILLinearMapTypes(PrintCtx, getSwiftModule());
printSILFunctions(PrintCtx, getFunctionList());
printSILVTables(PrintCtx, getVTables());
printSILWitnessTables(PrintCtx, getWitnessTableList());
Expand Down
4 changes: 3 additions & 1 deletion lib/SIL/Parser/ParseSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2416,7 +2416,9 @@ static bool parseSILDifferentiabilityWitnessConfigAndFunction(
auto origFnType = resultOrigFn->getLoweredFunctionType();
auto *parameterIndices = IndexSubset::get(
P.Context, origFnType->getNumParameters(), rawParameterIndices);
auto *resultIndices = IndexSubset::get(P.Context, origFnType->getNumResults(),
auto *resultIndices = IndexSubset::get(P.Context,
origFnType->getNumResults() +
origFnType->getNumIndirectMutatingParameters(),
rawResultIndices);
resultConfig = AutoDiffConfig(parameterIndices, resultIndices, witnessGenSig);
return false;
Expand Down
93 changes: 39 additions & 54 deletions lib/SILOptimizer/Differentiation/JVPCloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class JVPCloner::Implementation final
/// elements destructured from the linear map basic block argument. In the
/// beginning of each differential basic block, the block's differential
/// struct is destructured into the individual elements stored here.
llvm::DenseMap<VarDecl *, SILValue> differentialStructElements;
llvm::DenseMap<SILBasicBlock *, SILInstructionResultArray> differentialTupleElements;

/// An auxiliary differential local allocation builder.
TangentBuilder diffLocalAllocBuilder;
Expand All @@ -119,23 +119,17 @@ class JVPCloner::Implementation final
TangentBuilder &getDifferentialBuilder() { return differentialBuilder; }
SILFunction &getDifferential() { return differentialBuilder.getFunction(); }
SILArgument *getDifferentialStructArgument(SILBasicBlock *origBB) {
#ifndef NDEBUG
auto *diffStruct = differentialStructArguments[origBB]
->getType()
.getStructOrBoundGenericStruct();
assert(diffStruct == differentialInfo.getLinearMapStruct(origBB));
#endif
return differentialStructArguments[origBB];
}

//--------------------------------------------------------------------------//
// Differential struct mapping
// Differential tuple mapping
//--------------------------------------------------------------------------//

void initializeDifferentialStructElements(SILBasicBlock *origBB,
SILInstructionResultArray values);
void initializeDifferentialTupleElements(SILBasicBlock *origBB,
SILInstructionResultArray values);

SILValue getDifferentialStructElement(SILBasicBlock *origBB, VarDecl *field);
SILValue getDifferentialTupleElement(ApplyInst *ai);

//--------------------------------------------------------------------------//
// General utilities
Expand All @@ -158,22 +152,21 @@ class JVPCloner::Implementation final

/// Build a differential struct value for the original block corresponding to
/// the given terminator.
StructInst *buildDifferentialValueStructValue(TermInst *termInst) {
TupleInst *buildDifferentialValueStructValue(TermInst *termInst) {
assert(termInst->getFunction() == original);
auto loc = termInst->getFunction()->getLocation();
auto *origBB = termInst->getParent();
auto *jvpBB = BBMap[origBB];
assert(jvpBB && "Basic block mapping should exist");
auto *diffStruct = differentialInfo.getLinearMapStruct(origBB);
assert(diffStruct && "The differential struct should have been declared");
auto structLoweredTy = getNominalDeclLoweredType(diffStruct);
auto tupleLoweredTy =
remapType(differentialInfo.getLinearMapTupleLoweredType(origBB));
auto bbDifferentialValues = differentialValues[origBB];
if (!origBB->isEntry()) {
auto *enumArg = jvpBB->getArguments().back();
bbDifferentialValues.insert(bbDifferentialValues.begin(), enumArg);
}
return getBuilder().createStruct(loc, structLoweredTy,
bbDifferentialValues);
return getBuilder().createTuple(loc, tupleLoweredTy,
bbDifferentialValues);
}

//--------------------------------------------------------------------------//
Expand Down Expand Up @@ -438,8 +431,8 @@ class JVPCloner::Implementation final
auto *mainDifferentialStruct = diffBB->getArguments().back();
diffBuilder.setInsertionPoint(diffBB);
auto *dsi =
diffBuilder.createDestructureStruct(diffLoc, mainDifferentialStruct);
initializeDifferentialStructElements(bb, dsi->getResults());
diffBuilder.createDestructureTuple(diffLoc, mainDifferentialStruct);
initializeDifferentialTupleElements(bb, dsi->getResults());
TypeSubstCloner::visitInstructionsInBlock(bb);
}

Expand Down Expand Up @@ -667,12 +660,11 @@ class JVPCloner::Implementation final
// Add the differential function for when we create the struct we partially
// apply to the differential we are generating.
auto differential = jvpDirectResults.back();
auto *differentialDecl = differentialInfo.lookUpLinearMapDecl(ai);
auto differentialType = differentialInfo.lookUpLinearMapType(ai);
auto originalDifferentialType =
getOpType(differential->getType()).getAs<SILFunctionType>();
auto loweredDifferentialType =
getOpType(getLoweredType(differentialDecl->getInterfaceType()))
.castTo<SILFunctionType>();
getOpType(getLoweredType(differentialType)).castTo<SILFunctionType>();
// If actual differential type does not match lowered differential type,
// reabstract the differential using a thunk.
if (!loweredDifferentialType->isEqual(originalDifferentialType)) {
Expand Down Expand Up @@ -1218,9 +1210,7 @@ class JVPCloner::Implementation final
auto &diffBuilder = getDifferentialBuilder();

// Get the differential value.
auto *field = differentialInfo.lookUpLinearMapDecl(ai);
assert(field);
SILValue differential = getDifferentialStructElement(bb, field);
SILValue differential = getDifferentialTupleElement(ai);
auto differentialType = remapSILTypeInDifferential(differential->getType())
.castTo<SILFunctionType>();

Expand Down Expand Up @@ -1432,31 +1422,27 @@ JVPCloner::~JVPCloner() { delete &impl; }
// Differential struct mapping
//--------------------------------------------------------------------------//

void JVPCloner::Implementation::initializeDifferentialStructElements(
SILBasicBlock *origBB, SILInstructionResultArray values) {
auto *diffStructDecl = differentialInfo.getLinearMapStruct(origBB);
assert(diffStructDecl->getStoredProperties().size() == values.size() &&
"The number of differential struct fields must equal the number of "
void JVPCloner::Implementation::initializeDifferentialTupleElements(
SILBasicBlock *origBB, SILInstructionResultArray values) {
auto *diffTupleTyple = differentialInfo.getLinearMapTupleType(origBB);
assert(diffTupleTyple->getNumElements() == values.size() &&
"The number of differential tuple fields must equal the number of "
"differential struct element values");
for (auto pair : llvm::zip(diffStructDecl->getStoredProperties(), values)) {
assert(std::get<1>(pair)->getOwnershipKind() != OwnershipKind::Guaranteed &&
"Differential struct elements must be @owned");
auto insertion = differentialStructElements.insert(
{std::get<0>(pair), std::get<1>(pair)});
(void)insertion;
assert(insertion.second &&
"A differential struct element mapping already exists!");
}
auto res = differentialTupleElements.insert({origBB, values});
(void)res;
assert(res.second && "A pullback struct element already exists!");
}

SILValue
JVPCloner::Implementation::getDifferentialStructElement(SILBasicBlock *origBB,
VarDecl *field) {
assert(differentialInfo.getLinearMapStruct(origBB) ==
cast<StructDecl>(field->getDeclContext()));
assert(differentialStructElements.count(field) &&
"Differential struct element for this field does not exist!");
return differentialStructElements.lookup(field);
/// Returns the differential tuple element value corresponding to the given
/// original block and apply inst.
SILValue JVPCloner::Implementation::getDifferentialTupleElement(ApplyInst *ai) {
unsigned idx = differentialInfo.lookUpLinearMapIndex(ai);
assert((idx > 0 || (idx == 0 && ai->getParentBlock()->isEntry())) &&
"impossible linear map index");
auto values = differentialTupleElements.lookup(ai->getParentBlock());
assert(idx < values.size() &&
"differential tuple element for this apply does not exist!");
return values[idx];
}

//--------------------------------------------------------------------------//
Expand All @@ -1481,9 +1467,9 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
createEntryArguments(&differential);
auto *lastArg = diffBB->getArguments().back();
#ifndef NDEBUG
auto diffStructLoweredType = remapSILTypeInDifferential(
differentialInfo.getLinearMapStructLoweredType(&origBB));
assert(lastArg->getType() == diffStructLoweredType);
auto diffTupleLoweredType = remapSILTypeInDifferential(
differentialInfo.getLinearMapTupleLoweredType(&origBB));
assert(lastArg->getType() == diffTupleLoweredType);
#endif
differentialStructArguments[&origBB] = lastArg;
}
Expand Down Expand Up @@ -1671,10 +1657,9 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
// Accept a differential struct in the differential parameter list. This is
// the returned differential's closure context.
auto *origEntry = original->getEntryBlock();
auto *dfStruct = linearMapInfo->getLinearMapStruct(origEntry);
auto dfStructType =
dfStruct->getDeclaredInterfaceType()->getReducedType(witnessCanGenSig);
dfParams.push_back({dfStructType, ParameterConvention::Direct_Owned});
auto dfTupleType =
linearMapInfo->getLinearMapTupleLoweredType(origEntry).getASTType();
dfParams.push_back({dfTupleType, ParameterConvention::Direct_Owned});

Mangle::DifferentiationMangler mangler;
auto diffName = mangler.mangleLinearMap(
Expand Down
Loading