From 0d7b6ec83dd1795b2640796ab212fd82a701ce10 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Tue, 9 Apr 2019 16:28:26 +0200 Subject: [PATCH] [AutoDiff] Enable differentiation wrt "subset indices". Remove this error: `function is differentiable only with respect to a smaller subset of arguments`. The differentiation pass (and "minimal indices" logic) actually supports such differentiation wrt "subset indices". All that is needed is to remove the deprecated error. --- include/swift/AST/DiagnosticsSIL.def | 3 --- lib/SILOptimizer/Mandatory/Differentiation.cpp | 7 ------- test/AutoDiff/simple_math.swift | 8 ++++++++ 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/include/swift/AST/DiagnosticsSIL.def b/include/swift/AST/DiagnosticsSIL.def index acea284180310..d24955141f5b1 100644 --- a/include/swift/AST/DiagnosticsSIL.def +++ b/include/swift/AST/DiagnosticsSIL.def @@ -396,9 +396,6 @@ NOTE(autodiff_protocol_member_not_differentiable,none, NOTE(autodiff_protocol_member_subset_indices_not_differentiable,none, "member is differentiable only with respect to a smaller subset of " "arguments", ()) -NOTE(autodiff_function_subset_indices_not_differentiable,none, - "function is differentiable only with respect to a smaller subset of " - "arguments", ()) NOTE(autodiff_function_assoc_func_requirements_unmet,none, "function call is not differentiable because generic requirements are not " "met", ()) diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index fedac4c75cba9..c7e17179579ad 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -1793,13 +1793,6 @@ emitAssociatedFunctionReference(ADContext &context, SILBuilder &builder, if (autodiffFnType->isDifferentiable()) { SILValue assocFn = builder.createAutoDiffFunctionExtract( original.getLoc(), kind, /*differentiationOrder*/ 1, functionSource); - if (autodiffFnType->getDifferentiationParameterIndices().test( - desiredIndices.parameters)) { - context.emitNondifferentiabilityError( - original, parentTask, - diag::autodiff_function_subset_indices_not_differentiable); - return None; - } SILAutoDiffIndices indices(0, desiredIndices.parameters); return std::make_pair(assocFn, indices); } diff --git a/test/AutoDiff/simple_math.swift b/test/AutoDiff/simple_math.swift index 63e6a751e1da7..a202b9fb5699a 100644 --- a/test/AutoDiff/simple_math.swift +++ b/test/AutoDiff/simple_math.swift @@ -308,4 +308,12 @@ SimpleMathTests.test("StructGeneric") { expectEqual(405, gradient(at: 3, in: fifthPower)) } +SimpleMathTests.test("SubsetIndices") { + func train(_ lossFunction: @differentiable (Float, Float) -> Float) { + let y = Float(0) + _ = gradient(at: 0) { x in lossFunction(x, y) } + } + train { x, y in x + y } +} + runAllTests()