Skip to content

Commit 90765f7

Browse files
authored
Merge branch 'tensorflow' into marcrasi-use-diff-witness
2 parents 2d984fb + 21a4bc5 commit 90765f7

File tree

11 files changed

+355
-116
lines changed

11 files changed

+355
-116
lines changed

include/swift/AST/Attr.def

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -514,20 +514,21 @@ DECL_ATTR(differentiable, Differentiable,
514514
91)
515515
DECL_ATTR(differentiating, Differentiating,
516516
OnFunc | LongAttribute | AllowMultipleAttributes |
517-
ABIStableToAdd | ABIStableToRemove | APIStableToAdd | APIStableToRemove |
517+
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove |
518518
NotSerialized, 92)
519519
SIMPLE_DECL_ATTR(compilerEvaluable, CompilerEvaluable,
520520
OnAccessor | OnFunc | OnConstructor | OnSubscript |
521521
ABIStableToAdd | ABIStableToRemove | APIStableToAdd | APIStableToRemove |
522522
NotSerialized, 93)
523523
SIMPLE_DECL_ATTR(noDerivative, NoDerivative,
524524
OnVar |
525-
ABIStableToAdd | ABIStableToRemove | APIStableToAdd | APIStableToRemove,
525+
ABIBreakingToAdd | ABIBreakingToRemove | APIBreakingToAdd | APIBreakingToRemove,
526526
94)
527527
DECL_ATTR(transposing, Transposing,
528528
OnFunc | LongAttribute | AllowMultipleAttributes |
529-
ABIStableToAdd | ABIStableToRemove | APIStableToAdd | APIStableToRemove |
529+
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove |
530530
NotSerialized, 96)
531+
// SWIFT_ENABLE_TENSORFLOW END
531532

532533
SIMPLE_DECL_ATTR(IBSegueAction, IBSegueAction,
533534
OnFunc |

include/swift/AST/DiagnosticsSIL.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,8 @@ NOTE(autodiff_when_differentiating_function_definition,none,
524524
"when differentiating this function definition", ())
525525
NOTE(autodiff_cannot_differentiate_through_inout_arguments,none,
526526
"cannot differentiate through 'inout' arguments", ())
527+
NOTE(autodiff_cannot_differentiate_through_multiple_results,none,
528+
"cannot differentiate through multiple results", ())
527529
NOTE(autodiff_class_member_not_supported,none,
528530
"differentiating class members is not yet supported", ())
529531
// TODO(TF-642): Remove when `partial_apply` works with `@differentiable`

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 23 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1513,20 +1513,6 @@ class DifferentiableActivityInfo {
15131513
const SILAutoDiffIndices &indices) const;
15141514
};
15151515

1516-
/// Given a parameter argument (not indirect result) and some differentiation
1517-
/// indices, figure out whether the parent function is being differentiated with
1518-
/// respect to this parameter, according to the indices.
1519-
static bool isDifferentiationParameter(SILArgument *argument,
1520-
IndexSubset *indices) {
1521-
if (!argument) return false;
1522-
auto *function = argument->getFunction();
1523-
auto paramArgs = function->getArgumentsWithoutIndirectResults();
1524-
for (unsigned i : indices->getIndices())
1525-
if (paramArgs[i] == argument)
1526-
return true;
1527-
return false;
1528-
}
1529-
15301516
/// For an `apply` instruction with active results, compute:
15311517
/// - The results of the `apply` instruction, in type order.
15321518
/// - The set of minimal parameter and result indices for differentiating the
@@ -1543,9 +1529,7 @@ static void collectMinimalIndicesForFunctionCall(
15431529
// Record all parameter indices in type order.
15441530
unsigned currentParamIdx = 0;
15451531
for (auto applyArg : ai->getArgumentsWithoutIndirectResults()) {
1546-
if (activityInfo.isVaried(applyArg, parentIndices.parameters) ||
1547-
isDifferentiationParameter(dyn_cast<SILArgument>(applyArg),
1548-
parentIndices.parameters))
1532+
if (activityInfo.isActive(applyArg, parentIndices))
15491533
paramIndices.push_back(currentParamIdx);
15501534
++currentParamIdx;
15511535
}
@@ -1566,13 +1550,12 @@ static void collectMinimalIndicesForFunctionCall(
15661550
if (res.isFormalDirect()) {
15671551
results.push_back(directResults[dirResIdx]);
15681552
if (auto dirRes = directResults[dirResIdx])
1569-
if (dirRes && activityInfo.isUseful(dirRes, parentIndices.source))
1553+
if (dirRes && activityInfo.isActive(dirRes, parentIndices))
15701554
resultIndices.push_back(idx);
15711555
++dirResIdx;
15721556
} else {
15731557
results.push_back(indirectResults[indResIdx]);
1574-
if (activityInfo.isUseful(indirectResults[indResIdx],
1575-
parentIndices.source))
1558+
if (activityInfo.isActive(indirectResults[indResIdx], parentIndices))
15761559
resultIndices.push_back(idx);
15771560
++indResIdx;
15781561
}
@@ -3813,10 +3796,12 @@ class VJPEmitter final
38133796
activeResultIndices.begin(), activeResultIndices.end(),
38143797
[&s](unsigned i) { s << i; }, [&s] { s << ", "; });
38153798
s << "}\n";);
3816-
// FIXME: We don't support multiple active results yet.
3799+
// Diagnose multiple active results.
3800+
// TODO(TF-983): Support multiple active results.
38173801
if (activeResultIndices.size() > 1) {
38183802
context.emitNondifferentiabilityError(
3819-
ai, invoker, diag::autodiff_expression_not_differentiable_note);
3803+
ai, invoker,
3804+
diag::autodiff_cannot_differentiate_through_multiple_results);
38203805
errorOccurred = true;
38213806
return;
38223807
}
@@ -5492,13 +5477,16 @@ class JVPEmitter final
54925477
activeResultIndices.begin(), activeResultIndices.end(),
54935478
[&s](unsigned i) { s << i; }, [&s] { s << ", "; });
54945479
s << "}\n";);
5495-
// FIXME: We don't support multiple active results yet.
5480+
// Diagnose multiple active results.
5481+
// TODO(TF-983): Support multiple active results.
54965482
if (activeResultIndices.size() > 1) {
54975483
context.emitNondifferentiabilityError(
5498-
ai, invoker, diag::autodiff_expression_not_differentiable_note);
5484+
ai, invoker,
5485+
diag::autodiff_cannot_differentiate_through_multiple_results);
54995486
errorOccurred = true;
55005487
return;
55015488
}
5489+
55025490
// Form expected indices, assuming there's only one result.
55035491
SILAutoDiffIndices indices(
55045492
activeResultIndices.front(),
@@ -6257,6 +6245,8 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
62576245
// Adjoint values of dominated active values are passed as pullback block
62586246
// arguments.
62596247
DominanceOrder domOrder(original.getEntryBlock(), domInfo);
6248+
// Keep track of visited values.
6249+
SmallPtrSet<SILValue, 8> visited;
62606250
while (auto *bb = domOrder.getNext()) {
62616251
auto &bbActiveValues = activeValues[bb];
62626252
// If the current block has an immediate dominator, append the immediate
@@ -6266,13 +6256,12 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
62666256
bbActiveValues.append(domBBActiveValues.begin(),
62676257
domBBActiveValues.end());
62686258
}
6269-
SmallPtrSet<SILValue, 8> visited(bbActiveValues.begin(),
6270-
bbActiveValues.end());
6271-
// Register a value as active if it has not yet been visited.
62726259
bool diagnosedActiveEnumValue = false;
6273-
auto addActiveValue = [&](SILValue v) {
6260+
// Mark the activity of a value if it has not yet been visited.
6261+
auto markValueActivity = [&](SILValue v) {
62746262
if (visited.count(v))
62756263
return;
6264+
visited.insert(v);
62766265
// Diagnose active enum values. Differentiation of enum values requires
62776266
// special adjoint value handling and is not yet supported. Diagnose
62786267
// only the first active enum value to prevent too many diagnostics.
@@ -6288,17 +6277,19 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
62886277
// become projections into their adjoint base buffer.
62896278
if (Projection::isAddressProjection(v))
62906279
return;
6291-
visited.insert(v);
62926280
bbActiveValues.push_back(v);
62936281
};
6294-
// Register bb arguments and all instruction operands/results.
6282+
// Visit bb arguments and all instruction operands/results.
6283+
for (auto *arg : bb->getArguments())
6284+
if (getActivityInfo().isActive(arg, getIndices()))
6285+
markValueActivity(arg);
62956286
for (auto &inst : *bb) {
62966287
for (auto op : inst.getOperandValues())
62976288
if (getActivityInfo().isActive(op, getIndices()))
6298-
addActiveValue(op);
6289+
markValueActivity(op);
62996290
for (auto result : inst.getResults())
63006291
if (getActivityInfo().isActive(result, getIndices()))
6301-
addActiveValue(result);
6292+
markValueActivity(result);
63026293
}
63036294
domOrder.pushChildren(bb);
63046295
if (errorOccurred)

lib/Serialization/DeserializeSIL.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -799,7 +799,9 @@ SILDeserializer::readSILFunctionChecked(DeclID FID, SILFunction *existingFn,
799799
// SIL_VTABLE or SIL_GLOBALVAR or SIL_WITNESS_TABLE record also means the end
800800
// of this SILFunction.
801801
while (kind != SIL_FUNCTION && kind != SIL_VTABLE && kind != SIL_GLOBALVAR &&
802-
kind != SIL_WITNESS_TABLE) {
802+
// SWIFT_ENABLE_TENSORFLOW
803+
kind != SIL_WITNESS_TABLE && kind != SIL_DIFFERENTIABILITY_WITNESS) {
804+
// SWIFT_ENABLE_TENSORFLOW END
803805
if (kind == SIL_BASIC_BLOCK)
804806
// Handle a SILBasicBlock record.
805807
CurrentBB = readSILBasicBlock(fn, CurrentBB, scratch);

lib/Serialization/SerializeSIL.cpp

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,12 @@ namespace {
201201
/// Global variables that we've emitted a reference to.
202202
llvm::DenseSet<const SILGlobalVariable *> GlobalsToEmit;
203203

204+
// SWIFT_ENABLE_TENSORFLOW
205+
/// Referenced differentiability witnesses that need to be emitted.
206+
llvm::DenseSet<const SILDifferentiabilityWitness *>
207+
DifferentiabilityWitnessesToEmit;
208+
// SWIFT_ENABLE_TENSORFLOW END
209+
204210
/// Additional functions we might need to serialize.
205211
llvm::SmallVector<const SILFunction *, 16> Worklist;
206212

@@ -1061,6 +1067,7 @@ void SILSerializer::writeSILInstruction(const SILInstruction &SI) {
10611067
case SILInstructionKind::DifferentiabilityWitnessFunctionInst: {
10621068
auto *dwfi = cast<DifferentiabilityWitnessFunctionInst>(&SI);
10631069
auto *witness = dwfi->getWitness();
1070+
DifferentiabilityWitnessesToEmit.insert(witness);
10641071
Mangle::ASTMangler mangler;
10651072
auto mangledKey = mangler.mangleSILDifferentiabilityWitnessKey(
10661073
witness->getKey());
@@ -2718,17 +2725,6 @@ void SILSerializer::writeSILBlock(const SILModule *SILMod) {
27182725
writeSILDefaultWitnessTable(wt);
27192726
}
27202727

2721-
// SWIFT_ENABLE_TENSORFLOW
2722-
// Write out differentiability witnesses.
2723-
for (const auto &diffWitness : SILMod->getDifferentiabilityWitnessList()) {
2724-
// TODO(TF-893): Consider checking
2725-
// `SILMod->shouldSerializeEntitiesAssociatedWithDeclContext` on the JVP/VJP
2726-
// functions.
2727-
if ((ShouldSerializeAll || diffWitness.isSerialized()))
2728-
writeSILDifferentiabilityWitness(diffWitness);
2729-
}
2730-
// SWIFT_ENABLE_TENSORFLOW END
2731-
27322728
// Emit only declarations if it is a module with pre-specializations.
27332729
// And only do it in optimized builds.
27342730
bool emitDeclarationsForOnoneSupport =
@@ -2751,6 +2747,26 @@ void SILSerializer::writeSILBlock(const SILModule *SILMod) {
27512747
processSILFunctionWorklist();
27522748
}
27532749

2750+
// SWIFT_ENABLE_TENSORFLOW
2751+
// Write out differentiability witnesses.
2752+
// Note: this must be done after visiting SIL functions above so that
2753+
// differentiability witness references (`differentiability_witness_function`
2754+
// instructions) have been tracked.
2755+
for (const auto &diffWitness : SILMod->getDifferentiabilityWitnessList()) {
2756+
// TODO(TF-893): Consider checking
2757+
// `SILMod->shouldSerializeEntitiesAssociatedWithDeclContext` on the JVP/VJP
2758+
// functions.
2759+
if ((ShouldSerializeAll || diffWitness.isSerialized()))
2760+
DifferentiabilityWitnessesToEmit.insert(&diffWitness);
2761+
}
2762+
for (auto *diffWitness : DifferentiabilityWitnessesToEmit)
2763+
writeSILDifferentiabilityWitness(*diffWitness);
2764+
// Process SIL functions referenced by differentiability witnesses.
2765+
// Note: this is necessary despite processing `FuncsToEmit` below because
2766+
// `Worklist` is processed separately.
2767+
processSILFunctionWorklist();
2768+
// SWIFT_ENABLE_TENSORFLOW END
2769+
27542770
// Now write function declarations for every function we've
27552771
// emitted a reference to without emitting a function body for.
27562772
for (const SILFunction &F : *SILMod) {

test/AutoDiff/activity_analysis.swift

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: %target-swift-emit-sil -verify -Xllvm -debug-only=differentiation 2>&1 %s | %FileCheck %s
2+
// REQUIRES: asserts
23

34
// Check that `@noDerivative` struct projections have "NONE" activity.
45

@@ -204,6 +205,106 @@ func TF_954(_ x: Float) -> Float {
204205
// CHECK: [ACTIVE] %40 = begin_access [read] [static] %2 : $*Float
205206
// CHECK: [ACTIVE] %41 = load [trivial] %40 : $*Float
206207

208+
// Check `inout` argument activity.
209+
210+
struct Mut: Differentiable {}
211+
extension Mut {
212+
@differentiable(wrt: x)
213+
mutating func mutatingMethod(_ x: Mut) -> Mut {
214+
return x
215+
}
216+
}
217+
218+
// CHECK-LABEL: [AD] Activity info for $s17activity_analysis3MutV14mutatingMethodyA2CF at (source=0 parameters=(0))
219+
// CHECK: [ACTIVE] %0 = argument of bb0 : $Mut
220+
// CHECK: [NONE] %1 = argument of bb0 : $*Mut
221+
222+
// TODO(TF-985): Find workaround to avoid marking non-wrt `inout` argument as
223+
// active.
224+
// expected-error @+1 {{function is not differentiable}}
225+
@differentiable(wrt: x)
226+
// expected-note @+1 {{when differentiating this function definition}}
227+
func nonActiveInoutArg(_ nonactive: inout Mut, _ x: Mut) -> Mut {
228+
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
229+
return nonactive.mutatingMethod(x)
230+
}
231+
232+
// CHECK-LABEL: [AD] Activity info for $s17activity_analysis17nonActiveInoutArgyAA3MutVADz_ADtF at (source=0 parameters=(1))
233+
// CHECK: [ACTIVE] %0 = argument of bb0 : $*Mut
234+
// CHECK: [ACTIVE] %1 = argument of bb0 : $Mut
235+
// CHECK: [ACTIVE] %4 = begin_access [modify] [static] %0 : $*Mut
236+
// CHECK: [NONE] // function_ref Mut.mutatingMethod(_:)
237+
// CHECK: [ACTIVE] %6 = apply %5(%1, %4) : $@convention(method) (Mut, @inout Mut) -> Mut
238+
239+
// expected-error @+1 {{function is not differentiable}}
240+
@differentiable(wrt: x)
241+
// expected-note @+1 {{when differentiating this function definition}}
242+
func activeInoutArgMutatingMethod(_ x: Mut) -> Mut {
243+
var result = x
244+
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
245+
result = result.mutatingMethod(result)
246+
return result
247+
}
248+
249+
// CHECK-LABEL: [AD] Activity info for $s17activity_analysis28activeInoutArgMutatingMethodyAA3MutVADF at (source=0 parameters=(0))
250+
// CHECK: [ACTIVE] %0 = argument of bb0 : $Mut
251+
// CHECK: [ACTIVE] %2 = alloc_stack $Mut, var, name "result"
252+
// CHECK: [ACTIVE] %4 = begin_access [read] [static] %2 : $*Mut
253+
// CHECK: [ACTIVE] %5 = load [trivial] %4 : $*Mut
254+
// CHECK: [ACTIVE] %7 = begin_access [modify] [static] %2 : $*Mut
255+
// CHECK: [NONE] // function_ref Mut.mutatingMethod(_:)
256+
// CHECK: [ACTIVE] %9 = apply %8(%5, %7) : $@convention(method) (Mut, @inout Mut) -> Mut
257+
// CHECK: [ACTIVE] %11 = begin_access [modify] [static] %2 : $*Mut
258+
// CHECK: [ACTIVE] %14 = begin_access [read] [static] %2 : $*Mut
259+
// CHECK: [ACTIVE] %15 = load [trivial] %14 : $*Mut
260+
261+
// expected-error @+1 {{function is not differentiable}}
262+
@differentiable(wrt: x)
263+
// expected-note @+1 {{when differentiating this function definition}}
264+
func activeInoutArgMutatingMethodVar(_ nonactive: inout Mut, _ x: Mut) -> Mut {
265+
var result = nonactive
266+
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
267+
result = result.mutatingMethod(x)
268+
return result
269+
}
270+
271+
// CHECK_LABEL: [AD] Activity info for $s17activity_analysis31activeInoutArgMutatingMethodVaryAA3MutVADz_ADtF at (source=0 parameters=(1))
272+
// CHECK: [USEFUL] %0 = argument of bb0 : $*Mut
273+
// CHECK: [ACTIVE] %1 = argument of bb0 : $Mut
274+
// CHECK: [ACTIVE] %4 = alloc_stack $Mut, var, name "result"
275+
// CHECK: [USEFUL] %5 = begin_access [read] [static] %0 : $*Mut
276+
// CHECK: [ACTIVE] %8 = begin_access [modify] [static] %4 : $*Mut
277+
// CHECK: [NONE] // function_ref Mut.mutatingMethod(_:)
278+
// CHECK: [ACTIVE] %10 = apply %9(%1, %8) : $@convention(method) (Mut, @inout Mut) -> Mut
279+
// CHECK: [ACTIVE] %12 = begin_access [modify] [static] %4 : $*Mut
280+
// CHECK: [ACTIVE] %15 = begin_access [read] [static] %4 : $*Mut
281+
// CHECK: [ACTIVE] %16 = load [trivial] %15 : $*Mut
282+
283+
// expected-error @+1 {{function is not differentiable}}
284+
@differentiable(wrt: x)
285+
// expected-note @+1 {{when differentiating this function definition}}
286+
func activeInoutArgMutatingMethodTuple(_ nonactive: inout Mut, _ x: Mut) -> Mut {
287+
var result = (nonactive, x)
288+
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
289+
let result2 = result.0.mutatingMethod(result.0)
290+
return result2
291+
}
292+
293+
// CHECK-LABEL: [AD] Activity info for $s17activity_analysis33activeInoutArgMutatingMethodTupleyAA3MutVADz_ADtF at (source=0 parameters=(1))
294+
// CHECK: [USEFUL] %0 = argument of bb0 : $*Mut
295+
// CHECK: [ACTIVE] %1 = argument of bb0 : $Mut
296+
// CHECK: [ACTIVE] %4 = alloc_stack $(Mut, Mut), var, name "result"
297+
// CHECK: [ACTIVE] %5 = tuple_element_addr %4 : $*(Mut, Mut), 0
298+
// CHECK: [ACTIVE] %6 = tuple_element_addr %4 : $*(Mut, Mut), 1
299+
// CHECK: [USEFUL] %7 = begin_access [read] [static] %0 : $*Mut
300+
// CHECK: [ACTIVE] %11 = begin_access [read] [static] %4 : $*(Mut, Mut)
301+
// CHECK: [ACTIVE] %12 = tuple_element_addr %11 : $*(Mut, Mut), 0
302+
// CHECK: [ACTIVE] %13 = load [trivial] %12 : $*Mut
303+
// CHECK: [ACTIVE] %15 = begin_access [modify] [static] %4 : $*(Mut, Mut)
304+
// CHECK: [ACTIVE] %16 = tuple_element_addr %15 : $*(Mut, Mut), 0
305+
// CHECK: [NONE] // function_ref Mut.mutatingMethod(_:)
306+
// CHECK: [ACTIVE] %18 = apply %17(%13, %16) : $@convention(method) (Mut, @inout Mut) -> Mut
307+
207308
//===----------------------------------------------------------------------===//
208309
// Non-differentiable functions
209310
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)