Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions include/swift/SIL/SILInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -2046,6 +2046,11 @@ class ApplyInstBase<Impl, Base, true>
ApplyInstBase(As &&...args)
: ApplyInstBase<Impl, Base, false>(std::forward<As>(args)...) {}

// SWIFT_ENABLE_TENSORFLOW
private:
const Impl &asImpl() const { return static_cast<const Impl &>(*this); }
// SWIFT_ENABLE_TENSORFLOW END

public:
using super::getCallee;
using super::getSubstCalleeType;
Expand Down Expand Up @@ -2133,6 +2138,35 @@ class ApplyInstBase<Impl, Base, true>
bool hasSemantics(StringRef semanticsString) const {
return doesApplyCalleeHaveSemantics(getCallee(), semanticsString);
}

// SWIFT_ENABLE_TENSORFLOW
private:
/// Predicate used to filter InoutArgumentRange.
struct OperandToInoutArgument {
// Parameter conventions of the substituted callee, aligned index-for-index
// with `arguments` below (the constructor asserts this invariant).
ArrayRef<SILParameterInfo> paramInfos;
// The apply's arguments with indirect results excluded, so that index i
// here corresponds to parameter i in `paramInfos`.
OperandValueArrayRef arguments;
// NOTE(review): `Impl` is the concrete apply instruction type (CRTP);
// both accessors below are assumed to come from the "full application"
// ApplyInstBase — confirm against the enclosing class.
OperandToInoutArgument(const Impl &inst)
: paramInfos(inst.getSubstCalleeConv().getParameters()),
arguments(inst.getArgumentsWithoutIndirectResults()) {
assert(paramInfos.size() == arguments.size());
}
// Maps index i to the argument value when parameter i is indirect
// mutating (`@inout`/`@inout_aliasable`), otherwise None — the filter
// predicate consumed by OptionalTransformRange in InoutArgumentRange.
Optional<SILValue> operator()(unsigned long i) const {
if (paramInfos[i].isIndirectMutating())
return arguments[i];
return None;
}
};

public:
using InoutArgumentRange =
OptionalTransformRange<IntRange<unsigned long>, OperandToInoutArgument>;
/// Returns all `@inout` and `@inout_aliasable` arguments passed to the
/// instruction.
InoutArgumentRange getInoutArguments() const {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: getInoutArguments is currently defined in the "full application" case of ApplyInstBase. This required duplicating asImpl from the "partial application" ApplyInstBase superclass, which is a private function.

Other options:

  • Define getInoutArguments as a top-level function in Differentiation.cpp specialized for just ApplyInst. No problems there.
  • Move getInoutArguments to the "partial application" superclass ApplyInstBase. This requires more duplication because getArgumentsWithoutIndirectResults is defined only in the "full application" ApplyInstBase.

return InoutArgumentRange(indices(getArgumentsWithoutIndirectResults()),
OperandToInoutArgument(asImpl()));
}
// SWIFT_ENABLE_TENSORFLOW END
};

/// ApplyInst - Represents the full application of a function value.
Expand Down
58 changes: 23 additions & 35 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1622,11 +1622,8 @@ LinearMapInfo::LinearMapInfo(ADContext &context,
/// active argument.
bool LinearMapInfo::shouldDifferentiateApplyInst(ApplyInst *ai) {
// Function applications with an inout argument should be differentiated.
auto paramInfos = ai->getSubstCalleeConv().getParameters();
auto arguments = ai->getArgumentsWithoutIndirectResults();
for (auto i : swift::indices(paramInfos))
if (paramInfos[i].isIndirectInOut() &&
activityInfo.isActive(arguments[i], indices))
for (auto inoutArg : ai->getInoutArguments())
if (activityInfo.isActive(inoutArg, indices))
return true;

bool hasActiveDirectResults = false;
Expand All @@ -1642,6 +1639,7 @@ bool LinearMapInfo::shouldDifferentiateApplyInst(ApplyInst *ai) {
if (isArrayLiteralIntrinsic(ai) && hasActiveResults)
return true;

auto arguments = ai->getArgumentsWithoutIndirectResults();
bool hasActiveArguments = llvm::any_of(arguments,
[&](SILValue arg) { return activityInfo.isActive(arg, indices); });
return hasActiveResults && hasActiveArguments;
Expand Down Expand Up @@ -1834,20 +1832,13 @@ void LinearMapInfo::generateDifferentiationDataStructures(
for (auto &origBB : *original) {
for (auto &inst : origBB) {
if (auto *ai = dyn_cast<ApplyInst>(&inst)) {
// Check for active 'inout' arguments.
bool isInout = false;
auto paramInfos = ai->getSubstCalleeConv().getParameters();
for (unsigned i : swift::indices(paramInfos)) {
if (paramInfos[i].isIndirectInOut() &&
activityInfo.isActive(ai->getArgumentsWithoutIndirectResults()[i],
indices)) {
// Reject functions with active inout arguments. It's not yet
// supported.
isInout = true;
break;
}
}
if (isInout)
// Skip `apply` instructions with active `inout` arguments.
// TODO(TF-129): Support `inout` argument differentiation.
bool hasActiveInoutArgument =
llvm::any_of(ai->getInoutArguments(), [&](SILValue inoutArg) {
return activityInfo.isActive(inoutArg, indices);
});
if (hasActiveInoutArgument)
continue;

// Add linear map field to struct for active `apply` instructions.
Expand Down Expand Up @@ -2008,10 +1999,13 @@ void DifferentiableActivityInfo::propagateVaried(
// If callee is non-varying, skip.
if (isWithoutDerivative(ai->getCallee()))
return;
// If operand is varied, set all direct and indirect results as varied.
// If operand is varied, set all direct/indirect results and inout arguments
// as varied.
if (isVaried(operand->get(), i)) {
for (auto indRes : ai->getIndirectSILResults())
propagateVariedInwardsThroughProjections(indRes, i);
for (auto inoutArg : ai->getInoutArguments())
propagateVariedInwardsThroughProjections(inoutArg, i);
forEachApplyDirectResult(ai, [&](SILValue directResult) {
setVariedAndPropagateToUsers(directResult, i);
});
Expand Down Expand Up @@ -3778,7 +3772,7 @@ class VJPEmitter final
sei->getLoc(), getOpValue(sei->getOperand()), newDefaultBB, caseBBs);
}

// If an `apply` has active results or active inout parameters, replace it
// If an `apply` has active results or active inout arguments, replace it
// with an `apply` of its VJP.
void visitApplyInst(ApplyInst *ai) {
// If the function should not be differentiated or it's the array literal
Expand All @@ -3790,13 +3784,10 @@ class VJPEmitter final
return;
}

// Check and reject functions with active inout arguments. It's not yet
// supported.
auto paramInfos = ai->getSubstCalleeConv().getParameters();
auto paramArgs = ai->getArgumentsWithoutIndirectResults();
for (unsigned i : swift::indices(paramInfos)) {
if (paramInfos[i].isIndirectInOut() &&
activityInfo.isActive(paramArgs[i], getIndices())) {
// Diagnose functions with active inout arguments.
// TODO(TF-129): Support `inout` argument differentiation.
for (auto inoutArg : ai->getInoutArguments()) {
if (activityInfo.isActive(inoutArg, getIndices())) {
context.emitNondifferentiabilityError(ai, invoker,
diag::autodiff_cannot_differentiate_through_inout_arguments);
errorOccurred = true;
Expand Down Expand Up @@ -5472,13 +5463,10 @@ class JVPEmitter final
return;
}

// Check and reject functions with active inout arguments. It's not yet
// supported.
auto paramInfos = ai->getSubstCalleeConv().getParameters();
auto paramArgs = ai->getArgumentsWithoutIndirectResults();
for (unsigned i : swift::indices(paramInfos)) {
if (paramInfos[i].isIndirectInOut() &&
activityInfo.isActive(paramArgs[i], getIndices())) {
// Diagnose functions with active inout arguments.
// TODO(TF-129): Support `inout` argument differentiation.
for (auto inoutArg : ai->getInoutArguments()) {
if (activityInfo.isActive(inoutArg, getIndices())) {
context.emitNondifferentiabilityError(ai, invoker,
diag::autodiff_cannot_differentiate_through_inout_arguments);
errorOccurred = true;
Expand Down
1 change: 0 additions & 1 deletion lib/SILOptimizer/Mandatory/Differentiation.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ inline void createEntryArguments(SILFunction *f) {
auto *decl = new (ctx) ParamDecl(loc, loc, Identifier(), loc,
Identifier(), moduleDecl);
decl->setSpecifier(ParamDecl::Specifier::Default);
// decl->setType(type.getASTType());
entry->createFunctionArgument(type, decl);
};
for (auto indResTy : conv.getIndirectSILResultTypes())
Expand Down
49 changes: 48 additions & 1 deletion test/AutoDiff/activity_analysis.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %target-swift-emit-sil -Xllvm -debug-only=differentiation 2>&1 %s | %FileCheck %s
// RUN: %target-swift-emit-sil -verify -Xllvm -debug-only=differentiation 2>&1 %s | %FileCheck %s

// Check that `@noDerivative` struct projections have "NONE" activity.

Expand Down Expand Up @@ -203,3 +203,50 @@ func TF_954(_ x: Float) -> Float {
// CHECK: bb5:
// CHECK: [ACTIVE] %40 = begin_access [read] [static] %2 : $*Float
// CHECK: [ACTIVE] %41 = load [trivial] %40 : $*Float

//===----------------------------------------------------------------------===//
// Non-differentiable functions
//===----------------------------------------------------------------------===//

// Check `inout` arguments.

// expected-error @+1 {{function is not differentiable}}
@differentiable
// expected-note @+1 {{when differentiating this function definition}}
func activeInoutArg(_ x: Float) -> Float {
var result = x
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
result += x
return result
}

// CHECK-LABEL: [AD] Activity info for ${{.*}}activeInoutArg{{.*}} at (source=0 parameters=(0))
// CHECK: [ACTIVE] %0 = argument of bb0 : $Float
// CHECK: [ACTIVE] %2 = alloc_stack $Float, var, name "result"
// CHECK: [ACTIVE] %5 = begin_access [modify] [static] %2 : $*Float
// CHECK: [NONE] // function_ref static Float.+= infix(_:_:)
// CHECK: [NONE] %7 = apply %6(%5, %0, %4) : $@convention(method) (@inout Float, Float, @thin Float.Type) -> ()
// CHECK: [ACTIVE] %9 = begin_access [read] [static] %2 : $*Float
// CHECK: [ACTIVE] %10 = load [trivial] %9 : $*Float

// expected-error @+1 {{function is not differentiable}}
@differentiable
// expected-note @+1 {{when differentiating this function definition}}
func activeInoutArgNonactiveInitialResult(_ x: Float) -> Float {
var result: Float = 1
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
result += x
return result
}

// CHECK-LABEL: [AD] Activity info for ${{.*}}activeInoutArgNonactiveInitialResult{{.*}} at (source=0 parameters=(0))
// CHECK-LABEL: [ACTIVE] %0 = argument of bb0 : $Float
// CHECK-LABEL: [ACTIVE] %2 = alloc_stack $Float, var, name "result"
// CHECK-LABEL: [NONE] // function_ref Float.init(_builtinIntegerLiteral:)
// CHECK-LABEL: [USEFUL] %6 = apply %5(%3, %4) : $@convention(method) (Builtin.IntLiteral, @thin Float.Type) -> Float
// CHECK-LABEL: [USEFUL] %8 = metatype $@thin Float.Type
// CHECK-LABEL: [ACTIVE] %9 = begin_access [modify] [static] %2 : $*Float
// CHECK-LABEL: [NONE] // function_ref static Float.+= infix(_:_:)
// CHECK-LABEL: [NONE] %11 = apply %10(%9, %0, %8) : $@convention(method) (@inout Float, Float, @thin Float.Type) -> ()
// CHECK-LABEL: [ACTIVE] %13 = begin_access [read] [static] %2 : $*Float
// CHECK-LABEL: [ACTIVE] %14 = load [trivial] %13 : $*Float
55 changes: 50 additions & 5 deletions test/AutoDiff/differentiation_transform_diagnostics.swift
Original file line number Diff line number Diff line change
Expand Up @@ -270,23 +270,68 @@ func roundingGivesError(x: Float) -> Float {
// Inout arguments
//===----------------------------------------------------------------------===//

// expected-error @+1 {{function is not differentiable}}
@differentiable
// expected-note @+1 {{when differentiating this function definition}}
func activeInoutArg(_ x: Float) -> Float {
var a = x
var result = x
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
a += x
return a
result += x
return result
}

// expected-error @+1 {{function is not differentiable}}
_ = pullback(at: .zero, in: activeInoutArg(_:))
@differentiable
// expected-note @+1 {{when differentiating this function definition}}
func activeInoutArgNonactiveInitialResult(_ x: Float) -> Float {
var result: Float = 1
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
result += x
return result
}

// expected-error @+1 {{function is not differentiable}}
@differentiable
// expected-note @+1 {{when differentiating this function definition}}
func activeInoutArgTuple(_ x: Float) -> Float {
var tuple = (x, x)
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
tuple.0 *= x
return x * tuple.0
}

// expected-error @+1 {{function is not differentiable}}
_ = pullback(at: .zero, in: activeInoutArgTuple(_:))
@differentiable
// expected-note @+1 {{when differentiating this function definition}}
func activeInoutArgControlFlow(_ array: [Float]) -> Float {
var result: Float = 1
for i in withoutDerivative(at: array).indices {
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
result += array[i]
}
return result
}

// expected-error @+1 {{function is not differentiable}}
@differentiable
// expected-note @+1 {{when differentiating this function definition}}
func activeInoutArgControlFlowComplex(_ array: [Float], _ bool: Bool) -> Float {
var result: Float = 1
if bool {
if bool {}
for i in withoutDerivative(at: array).indices {
switch i % 2 {
case 0: continue
case 1: break
default: break
}
result = result + 1
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
result += array[i]
}
}
return result
}

//===----------------------------------------------------------------------===//
// Non-varied results
Expand Down
28 changes: 21 additions & 7 deletions test/AutoDiff/forward_mode_diagnostics.swift
Original file line number Diff line number Diff line change
Expand Up @@ -53,23 +53,37 @@ func calls_diff_of_nested(_ x: Float) -> Float {
// Inout arguments
//===----------------------------------------------------------------------===//

func activeInoutArg(_ x: Float) -> Float {
var a = x
// expected-error @+1 {{function is not differentiable}}
@differentiable
// expected-note @+1 {{when differentiating this function definition}}
func activeInoutArgNonactiveInitialResult(_ x: Float) -> Float {
var result: Float = 1
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
a += x
return a
result += x
return result
}
// expected-error @+1 {{function is not differentiable}}
_ = differential(at: .zero, in: activeInoutArg(_:))

// expected-error @+1 {{function is not differentiable}}
@differentiable
// expected-note @+1 {{when differentiating this function definition}}
func activeInoutArgTuple(_ x: Float) -> Float {
var tuple = (x, x)
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
tuple.0 *= x
return x * tuple.0
}

// expected-error @+1 {{function is not differentiable}}
_ = differential(at: .zero, in: activeInoutArgTuple(_:))
@differentiable
// expected-note @+2 {{when differentiating this function definition}}
// expected-note @+1 {{forward-mode differentiation does not yet support control flow}}
func activeInoutArgControlFlow(_ array: [Float]) -> Float {
var result: Float = 1
for i in withoutDerivative(at: array).indices {
result += array[i]
}
return result
}

//===----------------------------------------------------------------------===//
// Non-varied results
Expand Down