diff --git a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp index 786da33ac272d..c947ec6aa93b8 100644 --- a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp +++ b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp @@ -15,7 +15,6 @@ // //===----------------------------------------------------------------------===// -#include "swift/Basic/STLExtras.h" #define DEBUG_TYPE "differentiation" #include "swift/SILOptimizer/Differentiation/PullbackCloner.h" @@ -31,6 +30,7 @@ #include "swift/AST/PropertyWrappers.h" #include "swift/AST/TypeCheckRequests.h" #include "swift/Basic/Assertions.h" +#include "swift/Basic/STLExtras.h" #include "swift/SIL/ApplySite.h" #include "swift/SIL/InstructionUtils.h" #include "swift/SIL/Projection.h" @@ -131,6 +131,10 @@ class PullbackCloner::Implementation final /// Stack buffers allocated for storing local adjoint values. SmallVector functionLocalAllocations; + /// Copies created to deal with destructive enum operations + /// (unchecked_take_enum_addr) + llvm::SmallDenseMap enumDataAdjCopies; + /// A set used to remember local allocations that were destroyed. llvm::SmallDenseSet destroyedLocalAllocations; @@ -1858,7 +1862,7 @@ class PullbackCloner::Implementation final /// Handle a sequence of `init_enum_data_addr` and `inject_enum_addr` /// instructions. 
/// - /// Original: y = init_enum_data_addr x + /// Original: x = init_enum_data_addr y : $*Enum, #Enum.Case /// inject_enum_addr y /// /// Adjoint: adj[x] += unchecked_take_enum_data_addr adj[y] @@ -1879,6 +1883,10 @@ class PullbackCloner::Implementation final return; } + // No associated value => no adjoint to propagate + if (!inject->getElement()->hasAssociatedValues()) + return; + InitEnumDataAddrInst *origData = nullptr; for (auto use : origEnum->getUses()) { if (auto *init = dyn_cast(use->getUser())) { @@ -1900,9 +1908,9 @@ class PullbackCloner::Implementation final } } - SILValue adjStruct = getAdjointBuffer(bb, origEnum); + SILValue adjDest = getAdjointBuffer(bb, origEnum); StructDecl *adjStructDecl = - adjStruct->getType().getStructOrBoundGenericStruct(); + adjDest->getType().getStructOrBoundGenericStruct(); VarDecl *adjOptVar = nullptr; if (adjStructDecl) { @@ -1922,7 +1930,7 @@ class PullbackCloner::Implementation final SILLocation loc = origData->getLoc(); StructElementAddrInst *adjOpt = - builder.createStructElementAddr(loc, adjStruct, adjOptVar); + builder.createStructElementAddr(loc, adjDest, adjOptVar); // unchecked_take_enum_data_addr is destructive, so copy // Optional to a new alloca. @@ -1930,27 +1938,27 @@ class PullbackCloner::Implementation final createFunctionLocalAllocation(adjOpt->getType(), loc); builder.createCopyAddr(loc, adjOpt, adjOptCopy, IsNotTake, IsInitialization); + // The Optional copy is invalidated, do not attempt to destroy it at the end + // of the pullback. The value returned from unchecked_take_enum_data_addr is + // destroyed in visitInitEnumDataAddrInst. 
+ auto [_, inserted] = enumDataAdjCopies.try_emplace(origData, adjOptCopy); + assert(inserted && "expected single buffer"); EnumElementDecl *someElemDecl = getASTContext().getOptionalSomeDecl(); UncheckedTakeEnumDataAddrInst *adjData = builder.createUncheckedTakeEnumDataAddr(loc, adjOptCopy, someElemDecl); - setAdjointBuffer(bb, origData, adjData); - - // The Optional copy is invalidated, do not attempt to destroy it at the end - // of the pullback. The value returned from unchecked_take_enum_data_addr is - // destroyed in visitInitEnumDataAddrInst. - destroyedLocalAllocations.insert(adjOptCopy); + addToAdjointBuffer(bb, origData, adjData, loc); } /// Handle `init_enum_data_addr` instruction. /// Destroy the value returned from `unchecked_take_enum_data_addr`. void visitInitEnumDataAddrInst(InitEnumDataAddrInst *init) { - auto bufIt = bufferMap.find({init->getParent(), SILValue(init)}); - if (bufIt == bufferMap.end()) - return; - SILValue adjData = bufIt->second; - builder.emitDestroyAddr(init->getLoc(), adjData); + SILValue adjOptCopy = enumDataAdjCopies.at(init); + + builder.emitDestroyAddr(init->getLoc(), adjOptCopy); + destroyedLocalAllocations.insert(adjOptCopy); + enumDataAdjCopies.erase(init); } /// Handle `unchecked_ref_cast` instruction. @@ -2567,6 +2575,12 @@ bool PullbackCloner::Implementation::run() { } } } + // Ensure all enum adjoint copies have been cleaned up + for (const auto &enumData : enumDataAdjCopies) { + leakFound = true; + getADDebugStream() << "Found leaked temporary:\n" << enumData.second; + } + // Ensure all local allocations have been cleaned up. 
for (auto localAlloc : functionLocalAllocations) { if (!destroyedLocalAllocations.count(localAlloc)) { diff --git a/test/AutoDiff/SILOptimizer/optional_pullback.swift b/test/AutoDiff/SILOptimizer/optional_pullback.swift index 909a9e54d12da..b851e0e0d1df4 100644 --- a/test/AutoDiff/SILOptimizer/optional_pullback.swift +++ b/test/AutoDiff/SILOptimizer/optional_pullback.swift @@ -7,10 +7,11 @@ import _Differentiation // CHECK-SAME: (@in_guaranteed Optional<τ_0_0>.TangentVector) -> @out τ_0_0.TangentVector // // CHECK: bb0(%[[RET_TAN:.+]] : $*τ_0_0.TangentVector, %[[OPT_TAN:.+]] : $*Optional<τ_0_0>.TangentVector): -// CHECK: %[[RET_TAN_BUF:.+]] = alloc_stack $τ_0_0.TangentVector +// CHECK: %[[RET_TAN_BUF:.+]] = alloc_stack $τ_0_0.TangentVector, let, name "derivative of 'x' // CHECK: %[[ZERO1:.+]] = witness_method $τ_0_0.TangentVector, #AdditiveArithmetic.zero!getter // CHECK: apply %[[ZERO1]]<τ_0_0.TangentVector>(%[[RET_TAN_BUF]], %{{.*}}) +// CHECK: %[[ADJ_IN_BB:.+]] = alloc_stack $τ_0_0.TangentVector // // CHECK: %[[TAN_VAL_COPY:.+]] = alloc_stack $Optional<τ_0_0.TangentVector> // CHECK: %[[TAN_BUF:.+]] = alloc_stack $Optional<τ_0_0>.TangentVector @@ -21,13 +22,12 @@ import _Differentiation // // CHECK: %[[TAN_DATA:.+]] = unchecked_take_enum_data_addr %[[TAN_VAL_COPY]] : $*Optional<τ_0_0.TangentVector>, #Optional.some!enumelt // CHECK: %[[PLUS_EQUAL:.+]] = witness_method $τ_0_0.TangentVector, #AdditiveArithmetic."+=" -// CHECK: apply %[[PLUS_EQUAL]]<τ_0_0.TangentVector>(%[[RET_TAN_BUF]], %[[TAN_DATA]], %{{.*}}) -// -// CHECK: destroy_addr %[[TAN_DATA]] : $*τ_0_0.TangentVector -// CHECK: %[[ZERO2:.+]] = witness_method $τ_0_0.TangentVector, #AdditiveArithmetic.zero!getter -// CHECK: apply %[[ZERO2]]<τ_0_0.TangentVector>(%[[TAN_DATA]], %{{.*}}) -// CHECK: destroy_addr %[[TAN_DATA]] : $*τ_0_0.TangentVector -// +// CHECK: apply %[[PLUS_EQUAL]]<τ_0_0.TangentVector>(%[[ADJ_IN_BB]], %[[TAN_DATA]], %{{.*}}) + +// CHECK: %[[PLUS_EQUAL:.+]] = witness_method 
$τ_0_0.TangentVector, #AdditiveArithmetic."+=" +// CHECK: apply %[[PLUS_EQUAL]]<τ_0_0.TangentVector>(%[[RET_TAN_BUF]], %[[ADJ_IN_BB]], %{{.*}}) +// CHECK: destroy_addr %[[ADJ_IN_BB]] : $*τ_0_0.TangentVector + // CHECK: copy_addr [take] %[[RET_TAN_BUF:.+]] to [init] %[[RET_TAN:.+]] // CHECK: destroy_addr %[[TAN_BUF]] : $*Optional<τ_0_0>.TangentVector // CHECK: dealloc_stack %[[TAN_BUF]] : $*Optional<τ_0_0>.TangentVector diff --git a/test/AutoDiff/compiler_crashers_fixed/issue-75280-guard-let-optional.swift b/test/AutoDiff/compiler_crashers_fixed/issue-75280-guard-let-optional.swift new file mode 100644 index 0000000000000..a2f9f47baa89d --- /dev/null +++ b/test/AutoDiff/compiler_crashers_fixed/issue-75280-guard-let-optional.swift @@ -0,0 +1,13 @@ +// RUN: %target-swift-frontend -emit-sil -verify %s + +// https://github.com/swiftlang/swift/issues/75280 +// Ensure we accumulate adjoints properly for inject_enum_addr instructions and +// handle `nil` case (no adjoint value to propagate) + + +import _Differentiation +@differentiable(reverse) func a(_ f: Optional, c: @differentiable(reverse) (F) -> A) -> Optional where F: Differentiable, A: Differentiable +{ + guard let f else {return nil} + return c(f) +}