Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "swift/AST/TypeCheckRequests.h"
#include "swift/SIL/FormalLinkage.h"
#include "swift/SIL/LoopInfo.h"
#include "swift/SIL/PrettyStackTrace.h"
#include "swift/SIL/Projection.h"
#include "swift/SIL/SILBuilder.h"
#include "swift/SIL/TypeSubstCloner.h"
Expand Down Expand Up @@ -3343,6 +3344,7 @@ class VJPEmitter final
context, original, attr->getIndices(), vjp)),
pullbackInfo(context, AutoDiffLinearMapKind::Pullback, original, vjp,
attr->getIndices(), activityInfo) {
PrettyStackTraceSILFunction trace("generating VJP for", original);
// Create empty pullback function.
pullback = createEmptyPullback();
context.getGeneratedFunctions().push_back(pullback);
Expand Down Expand Up @@ -5234,6 +5236,8 @@ class JVPEmitter final
differentialBuilder(SILBuilder(*createEmptyDifferential(
context, original, attr, &differentialInfo))),
diffLocalAllocBuilder(getDifferential()) {
PrettyStackTraceSILFunction trace("generating JVP and differential for",
original);
// Create empty differential function.
context.getGeneratedFunctions().push_back(&getDifferential());
}
Expand Down Expand Up @@ -5787,6 +5791,8 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
explicit PullbackEmitter(VJPEmitter &vjpEmitter)
: vjpEmitter(vjpEmitter), builder(getPullback()),
localAllocBuilder(getPullback()) {
PrettyStackTraceSILFunction trace("generating pullback for",
&getOriginal());
// Get dominance and post-order info for the original function.
auto &passManager = getContext().getPassManager();
auto *domAnalysis = passManager.getAnalysis<DominanceAnalysis>();
Expand Down Expand Up @@ -7969,6 +7975,14 @@ static SILFunction *createEmptyJVP(
bool ADContext::processDifferentiableAttribute(
SILFunction *original, SILDifferentiableAttr *attr,
DifferentiationInvoker invoker) {
std::string traceMessage;
llvm::raw_string_ostream OS(traceMessage);
OS << "processing `[differentiable ";
attr->print(OS);
OS << "]` attribute on";
OS.flush();
PrettyStackTraceSILFunction trace(traceMessage.c_str(), original);

auto &module = getModule();
// Try to look up JVP only if attribute specifies JVP name or if original
// function is an external declaration. If JVP function cannot be found,
Expand Down Expand Up @@ -8750,17 +8764,21 @@ void ADContext::foldDifferentiableFunctionExtraction(

bool ADContext::processDifferentiableFunctionInst(
DifferentiableFunctionInst *dfi) {
PrettyStackTraceSILNode dfiTrace("canonicalizing `differentiable_function`",
cast<SILInstruction>(dfi));
PrettyStackTraceSILFunction fnTrace("...in", dfi->getFunction());
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: this pattern of PrettyStackTraceSILNode followed by PrettyStackTraceSILFunction is used because PrettyStackTraceSILNode by itself does not print any info about the containing SILFunction.

I copied this technique from elsewhere in the codebase. It's less intrusive than adding a flag to PrettyStackTraceSILNode for printing SIL function info.

Example stack trace:

Stack dump:
...
3.	While canonicalizing `differentiable_function` SIL node   %18 = differentiable_function [parameters 0] %17 : $@callee_guaranteed (Float, Int) -> Float // users: %22, %19
4.	While ...in SIL function "@main".

LLVM_DEBUG({
auto &s = getADDebugStream() << "Processing DifferentiableFunctionInst:\n";
dfi->printInContext(s);
});

// If `dfi` already has derivative functions, do not process.
if (dfi->hasDerivativeFunctions())
return false;

SILFunction *parent = dfi->getFunction();
auto loc = dfi->getLoc();
SILBuilder builder(dfi);

auto differentiableFnValue =
promoteToDifferentiableFunction(dfi, builder, loc, dfi);
// Mark `dfi` as processed so that it won't be reprocessed after deletion.
Expand Down