diff --git a/include/swift/AST/Types.h b/include/swift/AST/Types.h index e4cf35742e836..c317ddcbfdbf3 100644 --- a/include/swift/AST/Types.h +++ b/include/swift/AST/Types.h @@ -4284,6 +4284,10 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode, } // SWIFT_ENABLE_TENSORFLOW + DifferentiabilityKind getDifferentiabilityKind() const { + return getExtInfo().getDifferentiabilityKind(); + } + bool isDifferentiable() const { return getExtInfo().isDifferentiable(); } diff --git a/lib/Serialization/Deserialization.cpp b/lib/Serialization/Deserialization.cpp index 5dba0b025fa91..e603093dbd1b6 100644 --- a/lib/Serialization/Deserialization.cpp +++ b/lib/Serialization/Deserialization.cpp @@ -4401,6 +4401,23 @@ Optional getActualParameterConvention(uint8_t raw) { return None; } +// SWIFT_ENABLE_TENSORFLOW +/// Translate from the serialization DifferentiabilityKind enumerators, +/// which are guaranteed to be stable, to the AST ones. +static Optional +getActualDifferentiabilityKind(uint8_t raw) { + switch (serialization::DifferentiabilityKind(raw)) { +#define CASE(ID) \ + case serialization::DifferentiabilityKind::ID: \ + return swift::DifferentiabilityKind::ID; + CASE(NonDifferentiable) + CASE(Normal) + CASE(Linear) +#undef CASE + } + return None; +} + /// Translate from the serialization SILParameterDifferentiability enumerators, /// which are guaranteed to be stable, to the AST ones. static Optional @@ -4412,8 +4429,8 @@ getActualSILParameterDifferentiability(uint8_t raw) { CASE(DifferentiableOrNotApplicable) CASE(NotDifferentiable) } - return None; #undef CASE + return None; } /// Translate from the serialization ResultConvention enumerators, @@ -5010,7 +5027,7 @@ class swift::TypeDeserializer { bool pseudogeneric = false; bool noescape; // SWIFT_ENABLE_TENSORFLOW - bool differentiable; + uint8_t rawDifferentiabilityKind; bool hasErrorResult; unsigned numParams; unsigned numYields; @@ -5025,7 +5042,7 @@ class swift::TypeDeserializer { pseudogeneric, noescape, // SWIFT_ENABLE_TENSORFLOW - differentiable, + rawDifferentiabilityKind, hasErrorResult, numParams, numYields, @@ -5038,9 +5055,12 @@ class swift::TypeDeserializer { = getActualSILFunctionTypeRepresentation(rawRepresentation); if (!representation.hasValue()) MF.fatal(); - auto kind = DifferentiabilityKind((unsigned)differentiable); + auto differentiabilityKind = + getActualDifferentiabilityKind(rawDifferentiabilityKind); + if (!differentiabilityKind.hasValue()) + MF.fatal(); SILFunctionType::ExtInfo extInfo(*representation, pseudogeneric, - noescape, kind); + noescape, *differentiabilityKind); // Process the coroutine kind. auto coroutineKind = getActualSILCoroutineKind(rawCoroutineKind); @@ -5065,7 +5085,7 @@ class swift::TypeDeserializer { // SWIFT_ENABLE_TENSORFLOW auto paramDiff = swift::SILParameterDifferentiability::DifferentiableOrNotApplicable; - if (differentiable) { + if (differentiabilityKind != DifferentiabilityKind::NonDifferentiable) { auto paramDiffOpt = getActualSILParameterDifferentiability(rawParamDiff); if (!paramDiffOpt) { @@ -5102,7 +5122,9 @@ class swift::TypeDeserializer { // Bounds check. FIXME: overflow // SWIFT_ENABLE_TENSORFLOW - unsigned entriesPerParam = differentiable ? 3 : 2; + unsigned entriesPerParam = + differentiabilityKind != DifferentiabilityKind::NonDifferentiable + ? 3 : 2; if (entriesPerParam * numParams + 2 * numResults + 2 * unsigned(hasErrorResult) > variableData.size()) { @@ -5119,7 +5141,7 @@ class swift::TypeDeserializer { auto rawConvention = variableData[nextVariableDataIndex++]; // SWIFT_ENABLE_TENSORFLOW uint64_t paramDiff = 0; - if (differentiable) + if (differentiabilityKind != DifferentiabilityKind::NonDifferentiable) paramDiff = variableData[nextVariableDataIndex++]; auto param = processParameter(typeID, rawConvention, paramDiff); if (!param) diff --git a/lib/Serialization/ModuleFormat.h b/lib/Serialization/ModuleFormat.h index a3af7336da7c0..3d696ce54d753 100644 --- a/lib/Serialization/ModuleFormat.h +++ b/lib/Serialization/ModuleFormat.h @@ -52,7 +52,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0; /// describe what change you made. The content of this comment isn't important; /// it just ensures a conflict if two people change the module format. /// Don't worry about adhering to the 80-column limit for this line. -const uint16_t SWIFTMODULE_VERSION_MINOR = 521; // remove order from 'differentiation_function' layout +const uint16_t SWIFTMODULE_VERSION_MINOR = 522; // Add SIL function type DifferentiabilityKind field /// A standard hash seed used for all string hashes in a serialized module. /// @@ -327,6 +327,15 @@ enum class ParameterConvention : uint8_t { using ParameterConventionField = BCFixed<4>; // SWIFT_ENABLE_TENSORFLOW +// These IDs must \em not be renumbered or reordered without incrementing the +// module version. +enum class DifferentiabilityKind : uint8_t { + NonDifferentiable = 0, + Normal = 1, + Linear = 2 +}; +using DifferentiabilityKindField = BCFixed<2>; + // These IDs must \em not be renumbered or reordered without incrementing // the module version. enum class SILParameterDifferentiability : uint8_t { @@ -951,7 +960,7 @@ namespace decls_block { BCFixed<1>, // pseudogeneric? BCFixed<1>, // noescape? // SWIFT_ENABLE_TENSORFLOW - BCFixed<1>, // differentiable? + DifferentiabilityKindField, // differentiability kind BCFixed<1>, // error result? BCVBR<6>, // number of parameters BCVBR<5>, // number of yields diff --git a/lib/Serialization/Serialization.cpp b/lib/Serialization/Serialization.cpp index dcffb9d7848d6..0aee19402dda2 100644 --- a/lib/Serialization/Serialization.cpp +++ b/lib/Serialization/Serialization.cpp @@ -3661,6 +3661,19 @@ static uint8_t getRawStableSILCoroutineKind( llvm_unreachable("bad kind"); } +// SWIFT_ENABLE_TENSORFLOW +/// Translate from the AST differentiability kind enum to the Serialization enum +/// values, which are guaranteed to be stable. +static uint8_t getRawStableDifferentiabilityKind( + swift::DifferentiabilityKind kind) { + switch (kind) { + SIMPLE_CASE(DifferentiabilityKind, NonDifferentiable) + SIMPLE_CASE(DifferentiabilityKind, Normal) + SIMPLE_CASE(DifferentiabilityKind, Linear) + } + llvm_unreachable("bad differentiability kind"); +} + /// Translate from the AST ownership enum to the Serialization enum /// values, which are guaranteed to be stable. static uint8_t @@ -4015,8 +4028,11 @@ class Serializer::TypeSerializer : public TypeVisitor { using namespace decls_block; auto representation = fnTy->getRepresentation(); + // SWIFT_ENABLE_TENSORFLOW auto stableRepresentation = - getRawStableSILFunctionTypeRepresentation(representation); + getRawStableSILFunctionTypeRepresentation(representation); + auto stableDifferentiabilityKind = + getRawStableDifferentiabilityKind(fnTy->getDifferentiabilityKind()); SmallVector variableData; for (auto param : fnTy->getParameters()) { @@ -4059,7 +4075,7 @@ class Serializer::TypeSerializer : public TypeVisitor { stableCoroutineKind, stableCalleeConvention, stableRepresentation, fnTy->isPseudogeneric(), fnTy->isNoEscape(), // SWIFT_ENABLE_TENSORFLOW - fnTy->isDifferentiable(), fnTy->hasErrorResult(), + stableDifferentiabilityKind, fnTy->hasErrorResult(), fnTy->getParameters().size(), fnTy->getNumYields(), fnTy->getNumResults(), S.addGenericSignatureRef(sig), variableData); diff --git a/test/AutoDiff/differentiable_func_type.sil b/test/AutoDiff/differentiable_func_type.sil new file mode 100644 index 0000000000000..2defc5fe001fa --- /dev/null +++ b/test/AutoDiff/differentiable_func_type.sil @@ -0,0 +1,29 @@ +// RUN: %empty-directory(%t) +// RUN: %target-sil-opt %s -emit-sib -o %t/tmp.sib -module-name differentiable_func_type +// RUN: %target-sil-opt %t/tmp.sib -o %t/tmp.2.sib -module-name differentiable_func_type +// RUN: %target-sil-opt %t/tmp.2.sib -module-name differentiable_func_type | %FileCheck %s + +sil_stage raw + +import Swift + +sil @takeAndReturnLinear : $@convention(thin) (@differentiable(linear) (Float) -> Float) -> @differentiable(linear) (Float) -> Float { +bb0(%0 : $@differentiable(linear) (Float) -> Float): + return %0 : $@differentiable(linear) (Float) -> Float +} + +// CHECK-LABEL: sil @takeAndReturnLinear : $@convention(thin) (@differentiable(linear) (Float) -> Float) -> @differentiable(linear) (Float) -> Float { +// CHECK: bb0([[ARG:%.*]] : $@differentiable(linear) (Float) -> Float): +// CHECK: return [[ARG]] : $@differentiable(linear) (Float) -> Float +// CHECK: } + + +sil @takeAndReturnDifferentiable : $@convention(thin) (@differentiable (Float) -> Float) -> @differentiable (Float) -> Float { +bb0(%0 : $@differentiable (Float) -> Float): + return %0 : $@differentiable (Float) -> Float +} + +// CHECK-LABEL: sil @takeAndReturnDifferentiable : $@convention(thin) (@differentiable (Float) -> Float) -> @differentiable (Float) -> Float { +// CHECK: bb0([[ARG:%.*]] : $@differentiable (Float) -> Float): +// CHECK: return [[ARG]] : $@differentiable (Float) -> Float +// CHECK: } \ No newline at end of file diff --git a/test/AutoDiff/differentiable_function_inst.sil b/test/AutoDiff/differentiable_function_inst.sil index f6b3a4488cad4..8e9df1c6adc74 100644 --- a/test/AutoDiff/differentiable_function_inst.sil +++ b/test/AutoDiff/differentiable_function_inst.sil @@ -10,6 +10,35 @@ sil_stage raw import Swift import Builtin +sil @examplefunc : $@convention(thin) (Float, Float, Float) -> Float +sil @examplemethod : $@convention(method) (Float, Float, Float) -> Float + +// CHECK-LABEL: sil @test +sil @test : $@convention(thin) () -> () { +bb0: + %0 = function_ref @examplefunc : $@convention(thin) (Float, Float, Float) -> Float + %1 = differentiable_function [wrt 0 1 2] %0 : $@convention(thin) (Float, Float, Float) -> Float + + // CHECK: %2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (Float, Float, Float) -> Float + %2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (Float, Float, Float) -> Float + %3 = differentiable_function [wrt 0] %0 : $@convention(thin) (Float, Float, Float) -> Float + + // CHECK: %4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (Float, @nondiff Float, @nondiff Float) -> Float + %4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (Float, @nondiff Float, @nondiff Float) -> Float + %5 = function_ref @examplemethod : $@convention(method) (Float, Float, Float) -> Float + %6 = differentiable_function [wrt 0 1 2] %5 : $@convention(method) (Float, Float, Float) -> Float + + // CHECK: %7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (Float, Float, Float) -> Float + %7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (Float, Float, Float) -> Float + %8 = differentiable_function [wrt 0] %5 : $@convention(method) (Float, Float, Float) -> Float + + // CHECK: %9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (Float, @nondiff Float, @nondiff Float) -> Float + %9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (Float, @nondiff Float, @nondiff Float) -> Float + + %ret = tuple () + return %ret : $() +} + // The adjoint function emitted by the compiler. Parameter are a vector, as in // vector-Jacobian products, and pullback values. The function is partially // applied to a pullback struct to form a pullback, which takes a vector and diff --git a/test/AutoDiff/differentiable_sil_function_type_parse.sil b/test/AutoDiff/differentiable_sil_function_type_parse.sil deleted file mode 100644 index 05d28f36a8201..0000000000000 --- a/test/AutoDiff/differentiable_sil_function_type_parse.sil +++ /dev/null @@ -1,46 +0,0 @@ -// RUN: %target-sil-opt %s -module-name=autodiff_sil_function_type_parse | %target-sil-opt -module-name=autodiff_sil_function_type_parse | %FileCheck %s - -sil_stage raw - -import Swift - -sil @examplefunc : $@convention(thin) (Float, Float, Float) -> Float - -sil @examplemethod : $@convention(method) (Float, Float, Float) -> Float - -// CHECK-LABEL: sil @test -sil @test : $@convention(thin) () -> () { -bb0: - %0 = function_ref @examplefunc : $@convention(thin) (Float, Float, Float) -> Float - - %1 = differentiable_function [wrt 0 1 2] %0 : $@convention(thin) (Float, Float, Float) -> Float - // CHECK: %2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (Float, Float, Float) -> Float - %2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (Float, Float, Float) -> Float - - %3 = differentiable_function [wrt 0] %0 : $@convention(thin) (Float, Float, Float) -> Float - // CHECK: %4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (Float, @nondiff Float, @nondiff Float) -> Float - %4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (Float, @nondiff Float, @nondiff Float) -> Float - - %5 = function_ref @examplemethod : $@convention(method) (Float, Float, Float) -> Float - - %6 = differentiable_function [wrt 0 1 2] %5 : $@convention(method) (Float, Float, Float) -> Float - // CHECK: %7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (Float, Float, Float) -> Float - %7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (Float, Float, Float) -> Float - - %8 = differentiable_function [wrt 0] %5 : $@convention(method) (Float, Float, Float) -> Float - // CHECK: %9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (Float, @nondiff Float, @nondiff Float) -> Float - %9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (Float, @nondiff Float, @nondiff Float) -> Float - - %ret = tuple () - return %ret : $() -} - -sil @takeAndReturnLinear : $@convention(thin) (@differentiable(linear) (Float) -> Float) -> @differentiable(linear) (Float) -> Float { -bb0(%0 : $@differentiable(linear) (Float) -> Float): - return %0 : $@differentiable(linear) (Float) -> Float -} - -// CHECK-LABEL: sil @takeAndReturnLinear : $@convention(thin) (@differentiable(linear) (Float) -> Float) -> @differentiable(linear) (Float) -> Float { -// CHECK: bb0([[ARG:%.*]] : $@differentiable(linear) (Float) -> Float): -// CHECK: return [[ARG]] : $@differentiable(linear) (Float) -> Float -// CHECK: }