diff --git a/include/swift/AST/DiagnosticsSIL.def b/include/swift/AST/DiagnosticsSIL.def
index acea284180310..d24955141f5b1 100644
--- a/include/swift/AST/DiagnosticsSIL.def
+++ b/include/swift/AST/DiagnosticsSIL.def
@@ -396,9 +396,6 @@ NOTE(autodiff_protocol_member_not_differentiable,none,
 NOTE(autodiff_protocol_member_subset_indices_not_differentiable,none,
      "member is differentiable only with respect to a smaller subset of "
      "arguments", ())
-NOTE(autodiff_function_subset_indices_not_differentiable,none,
-     "function is differentiable only with respect to a smaller subset of "
-     "arguments", ())
 NOTE(autodiff_function_assoc_func_requirements_unmet,none,
      "function call is not differentiable because generic requirements are not "
      "met", ())
diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp
index fedac4c75cba9..c7e17179579ad 100644
--- a/lib/SILOptimizer/Mandatory/Differentiation.cpp
+++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp
@@ -1793,13 +1793,6 @@ emitAssociatedFunctionReference(ADContext &context, SILBuilder &builder,
   if (autodiffFnType->isDifferentiable()) {
     SILValue assocFn = builder.createAutoDiffFunctionExtract(
         original.getLoc(), kind, /*differentiationOrder*/ 1, functionSource);
-    if (autodiffFnType->getDifferentiationParameterIndices().test(
-            desiredIndices.parameters)) {
-      context.emitNondifferentiabilityError(
-          original, parentTask,
-          diag::autodiff_function_subset_indices_not_differentiable);
-      return None;
-    }
     SILAutoDiffIndices indices(0, desiredIndices.parameters);
     return std::make_pair(assocFn, indices);
   }
diff --git a/test/AutoDiff/simple_math.swift b/test/AutoDiff/simple_math.swift
index 63e6a751e1da7..a202b9fb5699a 100644
--- a/test/AutoDiff/simple_math.swift
+++ b/test/AutoDiff/simple_math.swift
@@ -308,4 +308,12 @@ SimpleMathTests.test("StructGeneric") {
   expectEqual(405, gradient(at: 3, in: fifthPower))
 }
 
+SimpleMathTests.test("SubsetIndices") {
+  func train(_ lossFunction: @differentiable (Float, Float) -> Float) {
+    let y = Float(0)
+    _ = gradient(at: 0) { x in lossFunction(x, y) }
+  }
+  train { x, y in x + y }
+}
+
 runAllTests()