Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 30 additions & 16 deletions lib/SILOptimizer/Differentiation/PullbackCloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
//
//===----------------------------------------------------------------------===//

#include "swift/Basic/STLExtras.h"
#define DEBUG_TYPE "differentiation"

#include "swift/SILOptimizer/Differentiation/PullbackCloner.h"
Expand All @@ -31,6 +30,7 @@
#include "swift/AST/PropertyWrappers.h"
#include "swift/AST/TypeCheckRequests.h"
#include "swift/Basic/Assertions.h"
#include "swift/Basic/STLExtras.h"
#include "swift/SIL/ApplySite.h"
#include "swift/SIL/InstructionUtils.h"
#include "swift/SIL/Projection.h"
Expand Down Expand Up @@ -131,6 +131,10 @@ class PullbackCloner::Implementation final
/// Stack buffers allocated for storing local adjoint values.
SmallVector<AllocStackInst *, 64> functionLocalAllocations;

/// Copies created to deal with destructive enum operations
/// (unchecked_take_enum_addr)
llvm::SmallDenseMap<InitEnumDataAddrInst*, SILValue> enumDataAdjCopies;

/// A set used to remember local allocations that were destroyed.
llvm::SmallDenseSet<SILValue> destroyedLocalAllocations;

Expand Down Expand Up @@ -1858,7 +1862,7 @@ class PullbackCloner::Implementation final
/// Handle a sequence of `init_enum_data_addr` and `inject_enum_addr`
/// instructions.
///
/// Original: y = init_enum_data_addr x
/// Original: x = init_enum_data_addr y : $*Enum, #Enum.Case
/// inject_enum_addr y
///
/// Adjoint: adj[x] += unchecked_take_enum_data_addr adj[y]
Expand All @@ -1879,6 +1883,10 @@ class PullbackCloner::Implementation final
return;
}

// No associated value => no adjoint to propagate
if (!inject->getElement()->hasAssociatedValues())
return;

InitEnumDataAddrInst *origData = nullptr;
for (auto use : origEnum->getUses()) {
if (auto *init = dyn_cast<InitEnumDataAddrInst>(use->getUser())) {
Expand All @@ -1900,9 +1908,9 @@ class PullbackCloner::Implementation final
}
}

SILValue adjStruct = getAdjointBuffer(bb, origEnum);
SILValue adjDest = getAdjointBuffer(bb, origEnum);
StructDecl *adjStructDecl =
adjStruct->getType().getStructOrBoundGenericStruct();
adjDest->getType().getStructOrBoundGenericStruct();

VarDecl *adjOptVar = nullptr;
if (adjStructDecl) {
Expand All @@ -1922,35 +1930,35 @@ class PullbackCloner::Implementation final

SILLocation loc = origData->getLoc();
StructElementAddrInst *adjOpt =
builder.createStructElementAddr(loc, adjStruct, adjOptVar);
builder.createStructElementAddr(loc, adjDest, adjOptVar);

// unchecked_take_enum_data_addr is destructive, so copy
// Optional<T.TangentVector> to a new alloca.
AllocStackInst *adjOptCopy =
createFunctionLocalAllocation(adjOpt->getType(), loc);
builder.createCopyAddr(loc, adjOpt, adjOptCopy, IsNotTake,
IsInitialization);
// The Optional copy is invalidated, do not attempt to destroy it at the end
// of the pullback. The value returned from unchecked_take_enum_data_addr is
// destroyed in visitInitEnumDataAddrInst.
auto [_, inserted] = enumDataAdjCopies.try_emplace(origData, adjOptCopy);
assert(inserted && "expected single buffer");

EnumElementDecl *someElemDecl = getASTContext().getOptionalSomeDecl();
UncheckedTakeEnumDataAddrInst *adjData =
builder.createUncheckedTakeEnumDataAddr(loc, adjOptCopy, someElemDecl);

setAdjointBuffer(bb, origData, adjData);

// The Optional copy is invalidated, do not attempt to destroy it at the end
// of the pullback. The value returned from unchecked_take_enum_data_addr is
// destroyed in visitInitEnumDataAddrInst.
destroyedLocalAllocations.insert(adjOptCopy);
addToAdjointBuffer(bb, origData, adjData, loc);
}

/// Handle `init_enum_data_addr` instruction.
/// Destroy the value returned from `unchecked_take_enum_data_addr`.
void visitInitEnumDataAddrInst(InitEnumDataAddrInst *init) {
auto bufIt = bufferMap.find({init->getParent(), SILValue(init)});
if (bufIt == bufferMap.end())
return;
SILValue adjData = bufIt->second;
builder.emitDestroyAddr(init->getLoc(), adjData);
SILValue adjOptCopy = enumDataAdjCopies.at(init);

builder.emitDestroyAddr(init->getLoc(), adjOptCopy);
destroyedLocalAllocations.insert(adjOptCopy);
enumDataAdjCopies.erase(init);
}

/// Handle `unchecked_ref_cast` instruction.
Expand Down Expand Up @@ -2567,6 +2575,12 @@ bool PullbackCloner::Implementation::run() {
}
}
}
// Ensure all enum adjoint copeis have been cleaned up
for (const auto &enumData : enumDataAdjCopies) {
leakFound = true;
getADDebugStream() << "Found leaked temporary:\n" << enumData.second;
}

// Ensure all local allocations have been cleaned up.
for (auto localAlloc : functionLocalAllocations) {
if (!destroyedLocalAllocations.count(localAlloc)) {
Expand Down
16 changes: 8 additions & 8 deletions test/AutoDiff/SILOptimizer/optional_pullback.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ import _Differentiation
// CHECK-SAME: (@in_guaranteed Optional<τ_0_0>.TangentVector) -> @out τ_0_0.TangentVector
//
// CHECK: bb0(%[[RET_TAN:.+]] : $*τ_0_0.TangentVector, %[[OPT_TAN:.+]] : $*Optional<τ_0_0>.TangentVector):
// CHECK: %[[RET_TAN_BUF:.+]] = alloc_stack $τ_0_0.TangentVector
// CHECK: %[[RET_TAN_BUF:.+]] = alloc_stack $τ_0_0.TangentVector, let, name "derivative of 'x'

// CHECK: %[[ZERO1:.+]] = witness_method $τ_0_0.TangentVector, #AdditiveArithmetic.zero!getter
// CHECK: apply %[[ZERO1]]<τ_0_0.TangentVector>(%[[RET_TAN_BUF]], %{{.*}})
// CHECK: %[[ADJ_IN_BB:.+]] = alloc_stack $τ_0_0.TangentVector
//
// CHECK: %[[TAN_VAL_COPY:.+]] = alloc_stack $Optional<τ_0_0.TangentVector>
// CHECK: %[[TAN_BUF:.+]] = alloc_stack $Optional<τ_0_0>.TangentVector
Expand All @@ -21,13 +22,12 @@ import _Differentiation
//
// CHECK: %[[TAN_DATA:.+]] = unchecked_take_enum_data_addr %[[TAN_VAL_COPY]] : $*Optional<τ_0_0.TangentVector>, #Optional.some!enumelt
// CHECK: %[[PLUS_EQUAL:.+]] = witness_method $τ_0_0.TangentVector, #AdditiveArithmetic."+="
// CHECK: apply %[[PLUS_EQUAL]]<τ_0_0.TangentVector>(%[[RET_TAN_BUF]], %[[TAN_DATA]], %{{.*}})
//
// CHECK: destroy_addr %[[TAN_DATA]] : $*τ_0_0.TangentVector
// CHECK: %[[ZERO2:.+]] = witness_method $τ_0_0.TangentVector, #AdditiveArithmetic.zero!getter
// CHECK: apply %[[ZERO2]]<τ_0_0.TangentVector>(%[[TAN_DATA]], %{{.*}})
// CHECK: destroy_addr %[[TAN_DATA]] : $*τ_0_0.TangentVector
//
// CHECK: apply %[[PLUS_EQUAL]]<τ_0_0.TangentVector>(%[[ADJ_IN_BB]], %[[TAN_DATA]], %{{.*}})

// CHECK: %[[PLUS_EQUAL:.+]] = witness_method $τ_0_0.TangentVector, #AdditiveArithmetic."+="
// CHECK: apply %[[PLUS_EQUAL]]<τ_0_0.TangentVector>(%[[RET_TAN_BUF]], %[[ADJ_IN_BB]], %{{.*}})
// CHECK: destroy_addr %[[ADJ_IN_BB]] : $*τ_0_0.TangentVector

// CHECK: copy_addr [take] %[[RET_TAN_BUF:.+]] to [init] %[[RET_TAN:.+]]
// CHECK: destroy_addr %[[TAN_BUF]] : $*Optional<τ_0_0>.TangentVector
// CHECK: dealloc_stack %[[TAN_BUF]] : $*Optional<τ_0_0>.TangentVector
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// RUN: %target-swift-frontend -emit-sil -verify %s

// https://github.com/swiftlang/swift/issues/75280
// Ensure we accumulate adjoints properly for inject_enum_addr instructions and
// handle `nil` case (no adjoint value to propagate)


import _Differentiation
@differentiable(reverse) func a<F, A>(_ f: Optional<F>, c: @differentiable(reverse) (F) -> A) -> Optional<A> where F: Differentiable, A: Differentiable
{
guard let f else {return nil}
return c(f)
}