diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 81c1e8b214e81..42a1811cc6502 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -1400,7 +1400,11 @@ class DifferentiableActivityInfo { void setUseful(SILValue value, unsigned dependentVariableIndex); void setUsefulAcrossArrayInitialization(SILValue value, unsigned dependentVariableIndex); - void recursivelySetVaried(SILValue value, unsigned independentVariableIndex); + /// Marks the given value as "varied" and recursively propagates "varied" + /// inwards (to operands) through projections. Skips any `@noDerivative` + /// struct field projections. + void propagateVariedInwardsThroughProjections( + SILValue value, unsigned independentVariableIndex); void propagateUsefulThroughBuffer(SILValue value, unsigned dependentVariableIndex); @@ -1875,27 +1879,18 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di, } } } - // Handle `store`. - else if (auto *si = dyn_cast(&inst)) { - if (isVaried(si->getSrc(), i)) - recursivelySetVaried(si->getDest(), i); - } - // Handle `store_borrow`. - else if (auto *si = dyn_cast(&inst)) { - if (isVaried(si->getSrc(), i)) - recursivelySetVaried(si->getDest(), i); - } - // Handle `copy_addr`. - else if (auto *cai = dyn_cast(&inst)) { - if (isVaried(cai->getSrc(), i)) - recursivelySetVaried(cai->getDest(), i); - } - // Handle `unconditional_checked_cast_addr`. 
- else if (auto *uccai = - dyn_cast(&inst)) { - if (isVaried(uccai->getSrc(), i)) - recursivelySetVaried(uccai->getDest(), i); + // Handle store-like instructions: + // `store`, `store_borrow`, `copy_addr`, `unconditional_checked_cast_addr` +#define PROPAGATE_VARIED_THROUGH_STORE(INST) \ + else if (auto *si = dyn_cast(&inst)) { \ + if (isVaried(si->getSrc(), i)) \ + propagateVariedInwardsThroughProjections(si->getDest(), i); \ } + PROPAGATE_VARIED_THROUGH_STORE(Store) + PROPAGATE_VARIED_THROUGH_STORE(StoreBorrow) + PROPAGATE_VARIED_THROUGH_STORE(CopyAddr) + PROPAGATE_VARIED_THROUGH_STORE(UnconditionalCheckedCastAddr) +#undef PROPAGATE_VARIED_THROUGH_STORE // Handle `tuple_element_addr`. else if (auto *teai = dyn_cast(&inst)) { if (isVaried(teai->getOperand(), i)) { @@ -1908,24 +1903,19 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di, setVaried(teai, i); } } - -// Handle `struct_extract` and `struct_element_addr` instructions. -// - If the field is marked `@noDerivative`, do not set the result as varied -// because it is not in the set of differentiable variables. -// - Otherwise, propagate variedness from operand to result as usual. + // Handle `struct_extract` and `struct_element_addr` instructions. + // - If the field is marked `@noDerivative`, do not set the result as + // varied because it is not in the set of differentiable variables. + // - Otherwise, propagate variedness from operand to result as usual. 
#define PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION(INST) \ - else if (auto *sei = dyn_cast(&inst)) { \ - if (isVaried(sei->getOperand(), i)) { \ - auto hasNoDeriv = sei->getField()->getAttrs() \ - .hasAttribute(); \ - if (!hasNoDeriv) \ - setVaried(sei, i); \ - } \ - } - PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION(StructExtract) - PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION(StructElementAddr) -#undef VISIT_STRUCT_ELEMENT_INNS - + else if (auto *sei = dyn_cast(&inst)) { \ + if (isVaried(sei->getOperand(), i) && \ + !sei->getField()->getAttrs().hasAttribute()) \ + setVaried(sei, i); \ + } + PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION(StructExtract) + PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION(StructElementAddr) +#undef PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION // Handle `br`. else if (auto *bi = dyn_cast(&inst)) { for (auto &op : bi->getAllOperands()) @@ -1947,12 +1937,10 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di, } // Handle `switch_enum`. else if (auto *sei = dyn_cast(&inst)) { - if (isVaried(sei->getOperand(), i)) { + if (isVaried(sei->getOperand(), i)) for (auto *succBB : sei->getSuccessorBlocks()) for (auto *arg : succBB->getArguments()) setVaried(arg, i); - // Default block cannot have arguments. - } } // Handle everything else. else { @@ -2002,27 +1990,34 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di, if (paramInfos[i].isIndirectInOut()) checkAndSetUseful(ai->getArgumentsWithoutIndirectResults()[i]); } - // Handle `store`. - else if (auto *si = dyn_cast(&inst)) { - if (isUseful(si->getDest(), i)) - setUseful(si->getSrc(), i); - } - // Handle `store_borrow`. - else if (auto *sbi = dyn_cast(&inst)) { - if (isUseful(sbi->getDest(), i)) - setUseful(sbi->getSrc(), i); - } - // Handle `copy_addr`. 
- else if (auto *cai = dyn_cast(&inst)) { - if (isUseful(cai->getDest(), i)) - propagateUsefulThroughBuffer(cai->getSrc(), i); + // Handle store-like instructions: + // `store`, `store_borrow`, `copy_addr`, `unconditional_checked_cast_addr` +#define PROPAGATE_USEFUL_THROUGH_STORE(INST, PROPAGATE) \ + else if (auto *si = dyn_cast(&inst)) { \ + if (isUseful(si->getDest(), i)) \ + PROPAGATE(si->getSrc(), i); \ } - // Handle `unconditional_checked_cast_addr`. - else if (auto *uccai = - dyn_cast(&inst)) { - if (isUseful(uccai->getDest(), i)) - propagateUsefulThroughBuffer(uccai->getSrc(), i); + PROPAGATE_USEFUL_THROUGH_STORE(Store, setUseful) + PROPAGATE_USEFUL_THROUGH_STORE(StoreBorrow, setUseful) + PROPAGATE_USEFUL_THROUGH_STORE(CopyAddr, propagateUsefulThroughBuffer) + PROPAGATE_USEFUL_THROUGH_STORE(UnconditionalCheckedCastAddr, + propagateUsefulThroughBuffer) +#undef PROPAGATE_USEFUL_THROUGH_STORE + // Handle struct element extraction, skipping `@noDerivative` fields: + // `struct_extract`, `struct_element_addr`. +#define PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(INST, PROPAGATE) \ + else if (auto *sei = dyn_cast(&inst)) { \ + if (isUseful(sei, i)) { \ + auto hasNoDeriv = sei->getField()->getAttrs() \ + .hasAttribute(); \ + if (!hasNoDeriv) \ + PROPAGATE(sei->getOperand(), i); \ + } \ } + PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(StructExtract, setUseful) + PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(StructElementAddr, + propagateUsefulThroughBuffer) +#undef PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION // Handle everything else. 
else if (llvm::any_of(inst.getResults(), [&](SILValue res) { return isUseful(res, i); })) { @@ -2100,15 +2095,23 @@ void DifferentiableActivityInfo::setUseful(SILValue value, setUsefulAcrossArrayInitialization(value, dependentVariableIndex); } -void DifferentiableActivityInfo::recursivelySetVaried( +void DifferentiableActivityInfo::propagateVariedInwardsThroughProjections( SILValue value, unsigned independentVariableIndex) { - setVaried(value, independentVariableIndex); - if (auto *inst = value->getDefiningInstruction()) { - if (auto *ai = dyn_cast(inst)) +#define SKIP_NODERIVATIVE(INST) \ + if (auto *sei = dyn_cast(value)) \ + if (sei->getField()->getAttrs().hasAttribute()) \ return; - for (auto &op : inst->getAllOperands()) - recursivelySetVaried(op.get(), independentVariableIndex); - } + SKIP_NODERIVATIVE(StructExtract) + SKIP_NODERIVATIVE(StructElementAddr) +#undef SKIP_NODERIVATIVE + setVaried(value, independentVariableIndex); + auto *inst = value->getDefiningInstruction(); + if (!inst || isa(inst)) + return; + // Standard propagation. + for (auto &op : inst->getAllOperands()) + propagateVariedInwardsThroughProjections( + op.get(), independentVariableIndex); } void DifferentiableActivityInfo::propagateUsefulThroughBuffer( @@ -2125,14 +2128,25 @@ void DifferentiableActivityInfo::propagateUsefulThroughBuffer( propagateUsefulThroughBuffer(operand.get(), dependentVariableIndex); // Recursively propagate usefulness through users that are projections or // `begin_access` instructions. 
- for (auto use : value->getUses()) - for (auto res : use->getUser()->getResults()) + for (auto use : value->getUses()) { + for (auto res : use->getUser()->getResults()) { +#define SKIP_NODERIVATIVE(INST) \ + if (auto *sei = dyn_cast(res)) \ + if (sei->getField()->getAttrs().hasAttribute()) \ + continue; + SKIP_NODERIVATIVE(StructExtract) + SKIP_NODERIVATIVE(StructElementAddr) +#undef SKIP_NODERIVATIVE if (Projection::isAddressProjection(res) || isa(res)) propagateUsefulThroughBuffer(res, dependentVariableIndex); + } + } } bool DifferentiableActivityInfo::isVaried( SILValue value, unsigned independentVariableIndex) const { + assert(independentVariableIndex < variedValueSets.size() && + "Independent variable index out of range"); auto &set = variedValueSets[independentVariableIndex]; return set.count(value); } @@ -2147,6 +2161,8 @@ bool DifferentiableActivityInfo::isVaried( bool DifferentiableActivityInfo::isUseful( SILValue value, unsigned dependentVariableIndex) const { + assert(dependentVariableIndex < usefulValueSets.size() && + "Dependent variable index out of range"); auto &set = usefulValueSets[dependentVariableIndex]; return set.count(value); } @@ -4534,8 +4550,6 @@ class JVPEmitter final auto loc = si->getLoc(); auto tanValSrc = materializeTangent(getTangentValue(si->getSrc()), loc); auto &tanValDest = getTangentBuffer(si->getParent(), si->getDest()); - if (errorOccurred) - return; diffBuilder.emitStoreValueOperation( loc, tanValSrc, tanValDest, si->getOwnershipQualifier()); } @@ -4548,8 +4562,6 @@ class JVPEmitter final auto loc = sbi->getLoc(); auto tanValSrc = materializeTangent(getTangentValue(sbi->getSrc()), loc); auto &tanValDest = getTangentBuffer(sbi->getParent(), sbi->getDest()); - if (errorOccurred) - return; diffBuilder.createStoreBorrow(loc, tanValSrc, tanValDest); } @@ -4569,8 +4581,6 @@ class JVPEmitter final auto *bb = cai->getParent(); auto &tanSrc = getTangentBuffer(bb, cai->getSrc()); auto tanDest = getTangentBuffer(bb, cai->getDest()); - 
if (errorOccurred) - return; diffBuilder.createCopyAddr(loc, tanSrc, tanDest, cai->isTakeOfSrc(), cai->isInitializationOfDest()); @@ -4586,8 +4596,6 @@ class JVPEmitter final auto *bb = uccai->getParent(); auto &tanSrc = getTangentBuffer(bb, uccai->getSrc()); auto tanDest = getTangentBuffer(bb, uccai->getDest()); - if (errorOccurred) - return; diffBuilder.createUnconditionalCheckedCastAddr( loc, tanSrc, tanSrc->getType().getASTType(), tanDest, @@ -6051,17 +6059,6 @@ class PullbackEmitter final : public SILInstructionVisitor { if (!insertion.second) // not inserted return insertion.first->getSecond(); - // Diagnose `struct_element_addr` instructions to `@noDerivative` fields. - if (auto *seai = dyn_cast(originalBuffer)) { - if (seai->getField()->getAttrs().hasAttribute()) { - getContext().emitNondifferentiabilityError( - originalBuffer, getInvoker(), - diag::autodiff_noderivative_stored_property); - errorOccurred = true; - return (bufferMap[{origBB, originalBuffer}] = SILValue()); - } - } - // If the original buffer is a projection, return a corresponding projection // into the adjoint buffer. 
if (auto adjProj = getAdjointProjection(origBB, originalBuffer)) @@ -6099,8 +6096,6 @@ class PullbackEmitter final : public SILInstructionVisitor { assert(originalBuffer->getFunction() == &getOriginal()); assert(rhsBufferAccess->getFunction() == &getPullback()); auto adjointBuffer = getAdjointBuffer(origBB, originalBuffer); - if (errorOccurred) - return; accumulateIndirect(adjointBuffer, rhsBufferAccess, loc); } @@ -6360,8 +6355,6 @@ class PullbackEmitter final : public SILInstructionVisitor { retElts.push_back(newVal); } else { auto adjBuf = getAdjointBuffer(origEntry, origParam); - if (errorOccurred) - return; indParamAdjoints.push_back(adjBuf); } }; @@ -6830,8 +6823,6 @@ class PullbackEmitter final : public SILInstructionVisitor { seed = materializeAdjoint(getAdjointValue(bb, origResult), loc); } else { seed = getAdjointBuffer(bb, origResult); - if (errorOccurred) - return; } // Create allocations for pullback indirect results. @@ -6887,8 +6878,6 @@ class PullbackEmitter final : public SILInstructionVisitor { builder.emitDestroyAddrAndFold(loc, tan); } else { if (origArg->getType().isAddress()) { - if (errorOccurred) - return; auto *tmpBuf = builder.createAllocStack(loc, tan->getType()); builder.emitStoreValueOperation(loc, tan, tmpBuf, StoreOwnershipQualifier::Init); @@ -7182,8 +7171,6 @@ class PullbackEmitter final : public SILInstructionVisitor { StoreOwnershipQualifier::Init); // Accumulate the adjoint value in the local buffer into the adjoint buffer. 
addToAdjointBuffer(bb, inst->getOperand(0), localBuf, inst->getLoc()); - if (errorOccurred) - return; builder.emitDestroyAddr(inst->getLoc(), localBuf); builder.createDeallocStack(inst->getLoc(), localBuf); } @@ -7196,8 +7183,6 @@ class PullbackEmitter final : public SILInstructionVisitor { void visitStoreOperation(SILBasicBlock *bb, SILLocation loc, SILValue origSrc, SILValue origDest) { auto &adjBuf = getAdjointBuffer(bb, origDest); - if (errorOccurred) - return; auto bufType = remapType(adjBuf->getType()); auto adjVal = builder.emitLoadValueOperation( loc, adjBuf, LoadOwnershipQualifier::Take); @@ -7220,8 +7205,6 @@ class PullbackEmitter final : public SILInstructionVisitor { void visitCopyAddrInst(CopyAddrInst *cai) { auto *bb = cai->getParent(); auto &adjDest = getAdjointBuffer(bb, cai->getDest()); - if (errorOccurred) - return; auto destType = remapType(adjDest->getType()); addToAdjointBuffer(bb, cai->getSrc(), adjDest, cai->getLoc()); builder.emitDestroyAddrAndFold(cai->getLoc(), adjDest); @@ -7275,8 +7258,6 @@ class PullbackEmitter final : public SILInstructionVisitor { auto *bb = uccai->getParent(); auto &adjDest = getAdjointBuffer(bb, uccai->getDest()); auto &adjSrc = getAdjointBuffer(bb, uccai->getSrc()); - if (errorOccurred) - return; auto destType = remapType(adjDest->getType()); auto castBuf = builder.createAllocStack(uccai->getLoc(), adjSrc->getType()); builder.createUnconditionalCheckedCastAddr( diff --git a/test/AutoDiff/autodiff_diagnostics.swift b/test/AutoDiff/autodiff_diagnostics.swift index d0fdd2fb7b4c2..8a36f7fb68be4 100644 --- a/test/AutoDiff/autodiff_diagnostics.swift +++ b/test/AutoDiff/autodiff_diagnostics.swift @@ -42,11 +42,9 @@ struct NoDerivativeProperty : Differentiable { var x: Float @noDerivative var y: Float } -// expected-error @+1 {{function is not differentiable}} _ = gradient(at: NoDerivativeProperty(x: 1, y: 1)) { s -> Float in var tmp = s - // expected-note @+1 {{cannot differentiate through a '@noDerivative' stored 
property; do you want to use 'withoutDerivative(at:)'?}} - tmp.y = tmp.x + tmp.y = tmp.x // No diagnostics expected. return tmp.x } _ = gradient(at: NoDerivativeProperty(x: 1, y: 1)) { s in diff --git a/test/AutoDiff/simple_math.swift b/test/AutoDiff/simple_math.swift index 401fb7ed8651f..a370a1d92e8e8 100644 --- a/test/AutoDiff/simple_math.swift +++ b/test/AutoDiff/simple_math.swift @@ -304,6 +304,21 @@ SimpleMathTests.test("StructGeneric") { expectEqual(405, gradient(at: 3, in: fifthPower)) } +SimpleMathTests.test("StructWithNoDerivativeProperty") { + struct NoDerivativeProperty : Differentiable { + var x: Float + @noDerivative var y: Float + } + expectEqual( + NoDerivativeProperty.TangentVector(x: 1), + gradient(at: NoDerivativeProperty(x: 1, y: 1)) { s -> Float in + var tmp = s + tmp.y = tmp.x + return tmp.x + } + ) +} + SimpleMathTests.test("SubsetIndices") { func grad(_ lossFunction: @differentiable (Float, Float) -> Float) -> Float { return gradient(at: 1) { x in lossFunction(x * x, 10.0) }