Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 85 additions & 104 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1400,7 +1400,11 @@ class DifferentiableActivityInfo {
void setUseful(SILValue value, unsigned dependentVariableIndex);
void setUsefulAcrossArrayInitialization(SILValue value,
unsigned dependentVariableIndex);
void recursivelySetVaried(SILValue value, unsigned independentVariableIndex);
/// Marks the given value as "varied" and recursively propagates "varied"
/// inwards (to operands) through projections. Skips any `@noDerivative`
/// struct field projections.
void propagateVariedInwardsThroughProjections(
SILValue value, unsigned independentVariableIndex);
void propagateUsefulThroughBuffer(SILValue value,
unsigned dependentVariableIndex);

Expand Down Expand Up @@ -1875,27 +1879,18 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di,
}
}
}
// Handle `store`.
else if (auto *si = dyn_cast<StoreInst>(&inst)) {
if (isVaried(si->getSrc(), i))
recursivelySetVaried(si->getDest(), i);
}
// Handle `store_borrow`.
else if (auto *si = dyn_cast<StoreBorrowInst>(&inst)) {
if (isVaried(si->getSrc(), i))
recursivelySetVaried(si->getDest(), i);
}
// Handle `copy_addr`.
else if (auto *cai = dyn_cast<CopyAddrInst>(&inst)) {
if (isVaried(cai->getSrc(), i))
recursivelySetVaried(cai->getDest(), i);
}
// Handle `unconditional_checked_cast_addr`.
else if (auto *uccai =
dyn_cast<UnconditionalCheckedCastAddrInst>(&inst)) {
if (isVaried(uccai->getSrc(), i))
recursivelySetVaried(uccai->getDest(), i);
// Handle store-like instructions:
// `store`, `store_borrow`, `copy_addr`, `unconditional_checked_cast`
#define PROPAGATE_VARIED_THROUGH_STORE(INST) \
else if (auto *si = dyn_cast<INST##Inst>(&inst)) { \
if (isVaried(si->getSrc(), i)) \
propagateVariedInwardsThroughProjections(si->getDest(), i); \
}
PROPAGATE_VARIED_THROUGH_STORE(Store)
PROPAGATE_VARIED_THROUGH_STORE(StoreBorrow)
PROPAGATE_VARIED_THROUGH_STORE(CopyAddr)
PROPAGATE_VARIED_THROUGH_STORE(UnconditionalCheckedCastAddr)
#undef PROPAGATE_VARIED_THROUGH_STORE
// Handle `tuple_element_addr`.
else if (auto *teai = dyn_cast<TupleElementAddrInst>(&inst)) {
if (isVaried(teai->getOperand(), i)) {
Expand All @@ -1908,24 +1903,19 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di,
setVaried(teai, i);
}
}

// Handle `struct_extract` and `struct_element_addr` instructions.
// - If the field is marked `@noDerivative`, do not set the result as varied
// because it is not in the set of differentiable variables.
// - Otherwise, propagate variedness from operand to result as usual.
// Handle `struct_extract` and `struct_element_addr` instructions.
// - If the field is marked `@noDerivative`, do not set the result as
// varied because it is not in the set of differentiable variables.
// - Otherwise, propagate variedness from operand to result as usual.
#define PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION(INST) \
else if (auto *sei = dyn_cast<INST##Inst>(&inst)) { \
if (isVaried(sei->getOperand(), i)) { \
auto hasNoDeriv = sei->getField()->getAttrs() \
.hasAttribute<NoDerivativeAttr>(); \
if (!hasNoDeriv) \
setVaried(sei, i); \
} \
}
PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION(StructExtract)
PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION(StructElementAddr)
#undef VISIT_STRUCT_ELEMENT_INNS

else if (auto *sei = dyn_cast<INST##Inst>(&inst)) { \
if (isVaried(sei->getOperand(), i) && \
!sei->getField()->getAttrs().hasAttribute<NoDerivativeAttr>()) \
setVaried(sei, i); \
}
PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION(StructExtract)
PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION(StructElementAddr)
#undef PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION
// Handle `br`.
else if (auto *bi = dyn_cast<BranchInst>(&inst)) {
for (auto &op : bi->getAllOperands())
Expand All @@ -1947,12 +1937,10 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di,
}
// Handle `switch_enum`.
else if (auto *sei = dyn_cast<SwitchEnumInst>(&inst)) {
if (isVaried(sei->getOperand(), i)) {
if (isVaried(sei->getOperand(), i))
for (auto *succBB : sei->getSuccessorBlocks())
for (auto *arg : succBB->getArguments())
setVaried(arg, i);
// Default block cannot have arguments.
}
}
// Handle everything else.
else {
Expand Down Expand Up @@ -2002,27 +1990,34 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di,
if (paramInfos[i].isIndirectInOut())
checkAndSetUseful(ai->getArgumentsWithoutIndirectResults()[i]);
}
// Handle `store`.
else if (auto *si = dyn_cast<StoreInst>(&inst)) {
if (isUseful(si->getDest(), i))
setUseful(si->getSrc(), i);
}
// Handle `store_borrow`.
else if (auto *sbi = dyn_cast<StoreBorrowInst>(&inst)) {
if (isUseful(sbi->getDest(), i))
setUseful(sbi->getSrc(), i);
}
// Handle `copy_addr`.
else if (auto *cai = dyn_cast<CopyAddrInst>(&inst)) {
if (isUseful(cai->getDest(), i))
propagateUsefulThroughBuffer(cai->getSrc(), i);
// Handle store-like instructions:
// `store`, `store_borrow`, `copy_addr`, `unconditional_checked_cast`
#define PROPAGATE_USEFUL_THROUGH_STORE(INST, PROPAGATE) \
else if (auto *si = dyn_cast<INST##Inst>(&inst)) { \
if (isUseful(si->getDest(), i)) \
PROPAGATE(si->getSrc(), i); \
}
// Handle `unconditional_checked_cast_addr`.
else if (auto *uccai =
dyn_cast<UnconditionalCheckedCastAddrInst>(&inst)) {
if (isUseful(uccai->getDest(), i))
propagateUsefulThroughBuffer(uccai->getSrc(), i);
PROPAGATE_USEFUL_THROUGH_STORE(Store, setUseful)
PROPAGATE_USEFUL_THROUGH_STORE(StoreBorrow, setUseful)
PROPAGATE_USEFUL_THROUGH_STORE(CopyAddr, propagateUsefulThroughBuffer)
PROPAGATE_USEFUL_THROUGH_STORE(UnconditionalCheckedCastAddr,
propagateUsefulThroughBuffer)
#undef PROPAGATE_USEFUL_THROUGH_STORE
// Handle struct element extraction, skipping `@noDerivative` fields:
// `struct_extract`, `struct_element_addr`.
#define PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(INST, PROPAGATE) \
else if (auto *sei = dyn_cast<INST##Inst>(&inst)) { \
if (isUseful(sei, i)) { \
auto hasNoDeriv = sei->getField()->getAttrs() \
.hasAttribute<NoDerivativeAttr>(); \
if (!hasNoDeriv) \
PROPAGATE(sei->getOperand(), i); \
} \
}
PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(StructExtract, setUseful)
PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(StructElementAddr,
propagateUsefulThroughBuffer)
#undef PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION
// Handle everything else.
else if (llvm::any_of(inst.getResults(),
[&](SILValue res) { return isUseful(res, i); })) {
Expand Down Expand Up @@ -2100,15 +2095,23 @@ void DifferentiableActivityInfo::setUseful(SILValue value,
setUsefulAcrossArrayInitialization(value, dependentVariableIndex);
}

void DifferentiableActivityInfo::recursivelySetVaried(
void DifferentiableActivityInfo::propagateVariedInwardsThroughProjections(
SILValue value, unsigned independentVariableIndex) {
setVaried(value, independentVariableIndex);
if (auto *inst = value->getDefiningInstruction()) {
if (auto *ai = dyn_cast<ApplyInst>(inst))
#define SKIP_NODERIVATIVE(INST) \
if (auto *sei = dyn_cast<INST##Inst>(value)) \
if (sei->getField()->getAttrs().hasAttribute<NoDerivativeAttr>()) \
return;
for (auto &op : inst->getAllOperands())
recursivelySetVaried(op.get(), independentVariableIndex);
}
SKIP_NODERIVATIVE(StructExtract)
SKIP_NODERIVATIVE(StructElementAddr)
#undef SKIP_NODERIVATIVE
setVaried(value, independentVariableIndex);
auto *inst = value->getDefiningInstruction();
if (!inst || isa<ApplyInst>(inst))
return;
// Standard propagation.
for (auto &op : inst->getAllOperands())
propagateVariedInwardsThroughProjections(
op.get(), independentVariableIndex);
}

void DifferentiableActivityInfo::propagateUsefulThroughBuffer(
Expand All @@ -2125,14 +2128,25 @@ void DifferentiableActivityInfo::propagateUsefulThroughBuffer(
propagateUsefulThroughBuffer(operand.get(), dependentVariableIndex);
// Recursively propagate usefulness through users that are projections or
// `begin_access` instructions.
for (auto use : value->getUses())
for (auto res : use->getUser()->getResults())
for (auto use : value->getUses()) {
for (auto res : use->getUser()->getResults()) {
#define SKIP_NODERIVATIVE(INST) \
if (auto *sei = dyn_cast<INST##Inst>(res)) \
if (sei->getField()->getAttrs().hasAttribute<NoDerivativeAttr>()) \
continue;
SKIP_NODERIVATIVE(StructExtract)
SKIP_NODERIVATIVE(StructElementAddr)
#undef SKIP_NODERIVATIVE
if (Projection::isAddressProjection(res) || isa<BeginAccessInst>(res))
propagateUsefulThroughBuffer(res, dependentVariableIndex);
}
}
}

bool DifferentiableActivityInfo::isVaried(
SILValue value, unsigned independentVariableIndex) const {
assert(independentVariableIndex < variedValueSets.size() &&
"Independent variable index out of range");
auto &set = variedValueSets[independentVariableIndex];
return set.count(value);
}
Expand All @@ -2147,6 +2161,8 @@ bool DifferentiableActivityInfo::isVaried(

bool DifferentiableActivityInfo::isUseful(
SILValue value, unsigned dependentVariableIndex) const {
assert(dependentVariableIndex < usefulValueSets.size() &&
"Dependent variable index out of range");
auto &set = usefulValueSets[dependentVariableIndex];
return set.count(value);
}
Expand Down Expand Up @@ -4534,8 +4550,6 @@ class JVPEmitter final
auto loc = si->getLoc();
auto tanValSrc = materializeTangent(getTangentValue(si->getSrc()), loc);
auto &tanValDest = getTangentBuffer(si->getParent(), si->getDest());
if (errorOccurred)
return;
diffBuilder.emitStoreValueOperation(
loc, tanValSrc, tanValDest, si->getOwnershipQualifier());
}
Expand All @@ -4548,8 +4562,6 @@ class JVPEmitter final
auto loc = sbi->getLoc();
auto tanValSrc = materializeTangent(getTangentValue(sbi->getSrc()), loc);
auto &tanValDest = getTangentBuffer(sbi->getParent(), sbi->getDest());
if (errorOccurred)
return;
diffBuilder.createStoreBorrow(loc, tanValSrc, tanValDest);
}

Expand All @@ -4569,8 +4581,6 @@ class JVPEmitter final
auto *bb = cai->getParent();
auto &tanSrc = getTangentBuffer(bb, cai->getSrc());
auto tanDest = getTangentBuffer(bb, cai->getDest());
if (errorOccurred)
return;

diffBuilder.createCopyAddr(loc, tanSrc, tanDest, cai->isTakeOfSrc(),
cai->isInitializationOfDest());
Expand All @@ -4586,8 +4596,6 @@ class JVPEmitter final
auto *bb = uccai->getParent();
auto &tanSrc = getTangentBuffer(bb, uccai->getSrc());
auto tanDest = getTangentBuffer(bb, uccai->getDest());
if (errorOccurred)
return;

diffBuilder.createUnconditionalCheckedCastAddr(
loc, tanSrc, tanSrc->getType().getASTType(), tanDest,
Expand Down Expand Up @@ -6051,17 +6059,6 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
if (!insertion.second) // not inserted
return insertion.first->getSecond();

// Diagnose `struct_element_addr` instructions to `@noDerivative` fields.
if (auto *seai = dyn_cast<StructElementAddrInst>(originalBuffer)) {
if (seai->getField()->getAttrs().hasAttribute<NoDerivativeAttr>()) {
getContext().emitNondifferentiabilityError(
originalBuffer, getInvoker(),
diag::autodiff_noderivative_stored_property);
errorOccurred = true;
return (bufferMap[{origBB, originalBuffer}] = SILValue());
}
}

// If the original buffer is a projection, return a corresponding projection
// into the adjoint buffer.
if (auto adjProj = getAdjointProjection(origBB, originalBuffer))
Expand Down Expand Up @@ -6099,8 +6096,6 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
assert(originalBuffer->getFunction() == &getOriginal());
assert(rhsBufferAccess->getFunction() == &getPullback());
auto adjointBuffer = getAdjointBuffer(origBB, originalBuffer);
if (errorOccurred)
return;
accumulateIndirect(adjointBuffer, rhsBufferAccess, loc);
}

Expand Down Expand Up @@ -6360,8 +6355,6 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
retElts.push_back(newVal);
} else {
auto adjBuf = getAdjointBuffer(origEntry, origParam);
if (errorOccurred)
return;
indParamAdjoints.push_back(adjBuf);
}
};
Expand Down Expand Up @@ -6830,8 +6823,6 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
seed = materializeAdjoint(getAdjointValue(bb, origResult), loc);
} else {
seed = getAdjointBuffer(bb, origResult);
if (errorOccurred)
return;
}

// Create allocations for pullback indirect results.
Expand Down Expand Up @@ -6887,8 +6878,6 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
builder.emitDestroyAddrAndFold(loc, tan);
} else {
if (origArg->getType().isAddress()) {
if (errorOccurred)
return;
auto *tmpBuf = builder.createAllocStack(loc, tan->getType());
builder.emitStoreValueOperation(loc, tan, tmpBuf,
StoreOwnershipQualifier::Init);
Expand Down Expand Up @@ -7182,8 +7171,6 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
StoreOwnershipQualifier::Init);
// Accumulate the adjoint value in the local buffer into the adjoint buffer.
addToAdjointBuffer(bb, inst->getOperand(0), localBuf, inst->getLoc());
if (errorOccurred)
return;
builder.emitDestroyAddr(inst->getLoc(), localBuf);
builder.createDeallocStack(inst->getLoc(), localBuf);
}
Expand All @@ -7196,8 +7183,6 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
void visitStoreOperation(SILBasicBlock *bb, SILLocation loc,
SILValue origSrc, SILValue origDest) {
auto &adjBuf = getAdjointBuffer(bb, origDest);
if (errorOccurred)
return;
auto bufType = remapType(adjBuf->getType());
auto adjVal = builder.emitLoadValueOperation(
loc, adjBuf, LoadOwnershipQualifier::Take);
Expand All @@ -7220,8 +7205,6 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
void visitCopyAddrInst(CopyAddrInst *cai) {
auto *bb = cai->getParent();
auto &adjDest = getAdjointBuffer(bb, cai->getDest());
if (errorOccurred)
return;
auto destType = remapType(adjDest->getType());
addToAdjointBuffer(bb, cai->getSrc(), adjDest, cai->getLoc());
builder.emitDestroyAddrAndFold(cai->getLoc(), adjDest);
Expand Down Expand Up @@ -7275,8 +7258,6 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
auto *bb = uccai->getParent();
auto &adjDest = getAdjointBuffer(bb, uccai->getDest());
auto &adjSrc = getAdjointBuffer(bb, uccai->getSrc());
if (errorOccurred)
return;
auto destType = remapType(adjDest->getType());
auto castBuf = builder.createAllocStack(uccai->getLoc(), adjSrc->getType());
builder.createUnconditionalCheckedCastAddr(
Expand Down
4 changes: 1 addition & 3 deletions test/AutoDiff/autodiff_diagnostics.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,9 @@ struct NoDerivativeProperty : Differentiable {
var x: Float
@noDerivative var y: Float
}
// expected-error @+1 {{function is not differentiable}}
_ = gradient(at: NoDerivativeProperty(x: 1, y: 1)) { s -> Float in
var tmp = s
// expected-note @+1 {{cannot differentiate through a '@noDerivative' stored property; do you want to use 'withoutDerivative(at:)'?}}
tmp.y = tmp.x
tmp.y = tmp.x // No diagnostics expected.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Big fix for semantics 👍

return tmp.x
}
_ = gradient(at: NoDerivativeProperty(x: 1, y: 1)) { s in
Expand Down
15 changes: 15 additions & 0 deletions test/AutoDiff/simple_math.swift
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,21 @@ SimpleMathTests.test("StructGeneric") {
expectEqual(405, gradient(at: 3, in: fifthPower))
}

SimpleMathTests.test("StructWithNoDerivativeProperty") {
struct NoDerivativeProperty : Differentiable {
var x: Float
@noDerivative var y: Float
}
expectEqual(
NoDerivativeProperty.TangentVector(x: 1),
gradient(at: NoDerivativeProperty(x: 1, y: 1)) { s -> Float in
var tmp = s
tmp.y = tmp.x
return tmp.x
}
)
}

SimpleMathTests.test("SubsetIndices") {
func grad(_ lossFunction: @differentiable (Float, Float) -> Float) -> Float {
return gradient(at: 1) { x in lossFunction(x * x, 10.0) }
Expand Down