diff --git a/docs/SIL.rst b/docs/SIL.rst index f2d4058ffdc77..412d556aa4fd9 100644 --- a/docs/SIL.rst +++ b/docs/SIL.rst @@ -5671,12 +5671,10 @@ differentiable_function_extract :: sil-instruction ::= 'differentiable_function_extract' - sil-differentiable-function-extractee + '[' sil-differentiable-function-extractee ']' sil-value ':' sil-type - sil-differentiable-function-extractee ::= - '[' sil-differentiable-function-extractee ']' - sil-differentiable-function-extractee-name ::= 'original' | 'jvp' | 'vjp' + sil-differentiable-function-extractee ::= 'original' | 'jvp' | 'vjp' differentiable_function_extract [original] %0 : $@differentiable (T) -> T differentiable_function_extract [jvp] %0 : $@differentiable (T) -> T @@ -5692,12 +5690,10 @@ linear_function_extract :: sil-instruction ::= 'linear_function_extract' - sil-linear-function-extractee + '[' sil-linear-function-extractee ']' sil-value ':' sil-type - sil-linear-function-extractee ::= - '[' sil-linear-function-extractee ']' - sil-linear-function-extractee-name ::= 'original' | 'transpose' + sil-linear-function-extractee ::= 'original' | 'transpose' linear_function_extract [original] %0 : $@differentiable(linear) (T) -> T linear_function_extract [transpose] %0 : $@differentiable(linear) (T) -> T @@ -5707,6 +5703,40 @@ Extracts the original function or a transpose function from the given ``[original]`` or ``[transpose]``. +differentiability_witness_function +`````````````````````````````````` +:: + + sil-instruction ::= + 'differentiability_witness_function' + '[' sil-differentiability-witness-function-kind ']' + '[' 'parameters' sil-differentiability-witness-function-index-list ']' + '[' 'results' sil-differentiability-witness-function-index-list ']' + generic-parameter-clause? + sil-function-name ':' sil-type + + sil-differentiability-witness-function-kind ::= 'jvp' | 'vjp' | 'transpose' + sil-differentiability-witness-function-index-list ::= [0-9]+ (' ' [0-9]+)* + + differentiability_witness_function [jvp] [parameters 0] [results 0] \ + @foo : $(T) -> T + +Looks up the differentiability witness function for the referenced function +using SIL differentiability witnesses. + +The differentiability witness function kind identifies the witness function to +look up: ``[jvp]``, ``[vjp]``, or ``[transpose]``. + +The remaining components identify the SIL differentiability witness: + +- Original function name. +- Parameter indices. +- Result indices. +- Witness generic parameter clause (optional). When parsing SIL, the parsed + witness generic parameter clause is combined with the original function's + generic signature to form the full witness generic signature. + + Assertion configuration ~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index a4fdca1e4f74a..8d86ed03b8db3 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -79,6 +79,27 @@ struct AutoDiffDerivativeFunctionKind { } }; +/// The kind of a differentiability witness function. +struct DifferentiabilityWitnessFunctionKind { + enum innerty : uint8_t { + // The Jacobian-vector products function. + JVP = 0, + // The vector-Jacobian products function. + VJP = 1, + // The transpose function. + Transpose = 2 + } rawValue; + + DifferentiabilityWitnessFunctionKind() = default; + DifferentiabilityWitnessFunctionKind(innerty rawValue) : rawValue(rawValue) {} + explicit DifferentiabilityWitnessFunctionKind(unsigned rawValue) + : rawValue(static_cast(rawValue)) {} + explicit DifferentiabilityWitnessFunctionKind(StringRef name); + operator innerty() const { return rawValue; } + + Optional getAsDerivativeFunctionKind() const; +}; + struct NormalDifferentiableFunctionTypeComponent { enum innerty : unsigned { Original = 0, diff --git a/include/swift/AST/DiagnosticsParse.def b/include/swift/AST/DiagnosticsParse.def index b9fa8a3e44952..e95f93cbd1ee1 100644 --- a/include/swift/AST/DiagnosticsParse.def +++ b/include/swift/AST/DiagnosticsParse.def @@ -1620,6 +1620,13 @@ ERROR(sil_inst_autodiff_expected_linear_extractee_kind,PointsToFirstBadToken, "and '[transpose]'", ()) ERROR(sil_inst_autodiff_expected_function_type_operand,PointsToFirstBadToken, "expected an operand of a function type", ()) +ERROR(sil_inst_autodiff_expected_differentiability_witness_kind,PointsToFirstBadToken, + "expected a differentiability witness kind, which can be one of '[jvp]', " + "'[vjp]', or '[transpose]'", ()) +ERROR(sil_inst_autodiff_invalid_witness_generic_signature,PointsToFirstBadToken, + "expected witness_generic signature '%0' does not have same generic " + "parameters as original function generic signature '%1'", + (StringRef, StringRef)) // Quoted attribute. ERROR(attr_quoted_enable_experimental_quasiquotes,PointsToFirstBadToken, diff --git a/include/swift/SIL/SILBuilder.h b/include/swift/SIL/SILBuilder.h index acb6010ab27b4..6fc7ee83bfd14 100644 --- a/include/swift/SIL/SILBuilder.h +++ b/include/swift/SIL/SILBuilder.h @@ -549,6 +549,18 @@ class SILBuilder { NormalDifferentiableFunctionTypeComponent::Original, TheFunction)); } + DifferentiabilityWitnessFunctionInst * + createDifferentiabilityWitnessFunction( + SILLocation Loc, SILFunction *OriginalFunction, + DifferentiabilityWitnessFunctionKind WitnessKind, + IndexSubset *ParameterIndices, IndexSubset *ResultIndices, + GenericSignature *WitnessGenericSignature) { + return insert(new (getModule()) DifferentiabilityWitnessFunctionInst( + getModule(), getSILDebugLocation(Loc), OriginalFunction, WitnessKind, + ParameterIndices, ResultIndices, WitnessGenericSignature)); + } + // SWIFT_ENABLE_TENSORFLOW END + BuiltinInst *createBuiltin(SILLocation Loc, Identifier Name, SILType ResultTy, SubstitutionMap Subs, ArrayRef Args) { diff --git a/include/swift/SIL/SILCloner.h b/include/swift/SIL/SILCloner.h index 5fd3160d14078..b4f40c3bcf5b2 100644 --- a/include/swift/SIL/SILCloner.h +++ b/include/swift/SIL/SILCloner.h @@ -1011,6 +1011,17 @@ visitLinearFunctionExtractInst(LinearFunctionExtractInst *Inst) { getOpLocation(Inst->getLoc()), Inst->getExtractee(), getOpValue(Inst->getFunctionOperand()))); } + +template +void SILCloner::visitDifferentiabilityWitnessFunctionInst( + DifferentiabilityWitnessFunctionInst *Inst) { + getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope())); + recordClonedInstruction( + Inst, getBuilder().createDifferentiabilityWitnessFunction( + getOpLocation(Inst->getLoc()), Inst->getOriginalFunction(), + Inst->getWitnessKind(), Inst->getParameterIndices(), + Inst->getResultIndices(), Inst->getWitnessGenericSignature())); +} // SWIFT_ENABLE_TENSORFLOW END template diff --git a/include/swift/SIL/SILInstruction.h b/include/swift/SIL/SILInstruction.h index 6a66cd8dd98b9..7abfcdf65457b 100644 --- a/include/swift/SIL/SILInstruction.h +++ b/include/swift/SIL/SILInstruction.h @@ -8031,6 +8031,54 @@ class LinearFunctionExtractInst ArrayRef getAllOperands() const { return operands.asArray(); } MutableArrayRef getAllOperands() { return operands.asArray(); } }; + +class DifferentiabilityWitnessFunctionInst + : public InstructionBase< + SILInstructionKind::DifferentiabilityWitnessFunctionInst, + SingleValueInstruction> { +private: + friend SILBuilder; + /// The original function. + SILFunction *originalFunction; + /// The differentiability witness function kind. + DifferentiabilityWitnessFunctionKind witnessKind; + /// The autodiff config: parameter indices, result indices, and witness + /// derivative signature. + AutoDiffConfig config; + + static SILType getDifferentiabilityWitnessType( + SILModule &module, SILFunction *originalFunction, + DifferentiabilityWitnessFunctionKind witnessKind, + IndexSubset *parameterIndices, IndexSubset *resultIndices, + GenericSignature *witnessGenericSignature); + +public: + DifferentiabilityWitnessFunctionInst( + SILModule &module, SILDebugLocation loc, SILFunction *originalFunction, + DifferentiabilityWitnessFunctionKind witnessKind, + IndexSubset *parameterIndices, IndexSubset *resultIndices, + GenericSignature *witnessGenericSignature); + + static DifferentiabilityWitnessFunctionInst *create( + SILModule &module, SILDebugLocation loc, SILFunction *originalFunction, + DifferentiabilityWitnessFunctionKind witnessKind, + IndexSubset *parameterIndices, IndexSubset *resultIndices, + GenericSignature *witnessGenericSignature); + + DifferentiabilityWitnessFunctionKind getWitnessKind() const { + return witnessKind; + } + SILFunction *getOriginalFunction() const { return originalFunction; } + AutoDiffConfig const &getConfig() const { return config; } + IndexSubset *getParameterIndices() const { return config.parameterIndices; } + IndexSubset *getResultIndices() const { return config.resultIndices; } + GenericSignature *getWitnessGenericSignature() const { + return config.derivativeGenericSignature; + } + + ArrayRef getAllOperands() const { return {}; } + MutableArrayRef getAllOperands() { return {}; } +}; // SWIFT_ENABLE_TENSORFLOW END // This is defined out of line to work around the fact that this depends on diff --git a/include/swift/SIL/SILNodes.def b/include/swift/SIL/SILNodes.def index 2caf4b5fe6e86..b8b28a90618be 100644 --- a/include/swift/SIL/SILNodes.def +++ b/include/swift/SIL/SILNodes.def @@ -700,6 +700,10 @@ ABSTRACT_VALUE_AND_INST(SingleValueInstruction, ValueBase, SILInstruction) SINGLE_VALUE_INST(LinearFunctionExtractInst, linear_function_extract, SingleValueInstruction, None, DoesNotRelease) + SINGLE_VALUE_INST(DifferentiabilityWitnessFunctionInst, + differentiability_witness_function, + SingleValueInstruction, None, DoesNotRelease) + // SWIFT_ENABLE_TENSORFLOW END // Key paths // TODO: The only "side effect" is potentially retaining the returned key path diff --git a/lib/AST/AutoDiff.cpp b/lib/AST/AutoDiff.cpp index 2c01dd59d3791..ba22d9ce427c8 100644 --- a/lib/AST/AutoDiff.cpp +++ b/lib/AST/AutoDiff.cpp @@ -32,6 +32,25 @@ AutoDiffDerivativeFunctionKind(StringRef string) { rawValue = *result; } +DifferentiabilityWitnessFunctionKind:: +DifferentiabilityWitnessFunctionKind(StringRef string) { + Optional result = llvm::StringSwitch>(string) + .Case("jvp", JVP) + .Case("vjp", VJP) + .Case("transpose", Transpose); + assert(result && "Invalid string"); + rawValue = *result; +} + +Optional +DifferentiabilityWitnessFunctionKind::getAsDerivativeFunctionKind() const { + switch (rawValue) { + case JVP: return {AutoDiffDerivativeFunctionKind::JVP}; + case VJP: return {AutoDiffDerivativeFunctionKind::VJP}; + case Transpose: return None; + } +} + NormalDifferentiableFunctionTypeComponent:: NormalDifferentiableFunctionTypeComponent(AutoDiffDerivativeFunctionKind kind) { switch (kind) { diff --git a/lib/IRGen/IRGenSIL.cpp b/lib/IRGen/IRGenSIL.cpp index fa92032fedf38..3db2756cbec4d 100644 --- a/lib/IRGen/IRGenSIL.cpp +++ b/lib/IRGen/IRGenSIL.cpp @@ -925,6 +925,8 @@ class IRGenSILFunction : void visitDifferentiableFunctionExtractInst( DifferentiableFunctionExtractInst *i); void visitLinearFunctionExtractInst(LinearFunctionExtractInst *i); + void visitDifferentiabilityWitnessFunctionInst( + DifferentiabilityWitnessFunctionInst *i); // SWIFT_ENABLE_TENSORFLOW END void visitFunctionRefBaseInst(FunctionRefBaseInst *i); @@ -1927,6 +1929,13 @@ visitLinearFunctionExtractInst(LinearFunctionExtractInst *i) { setLoweredExplosion(i, e); } +void IRGenSILFunction::visitDifferentiabilityWitnessFunctionInst( +DifferentiabilityWitnessFunctionInst *i) { + // TODO(TF-916): Implement IRGen for `differentiability_witness_function`. + llvm_unreachable("unimplemented"); +} +// SWIFT_ENABLE_TENSORFLOW END + void IRGenSILFunction::visitFunctionRefBaseInst(FunctionRefBaseInst *i) { auto fn = i->getInitiallyReferencedFunction(); diff --git a/lib/ParseSIL/ParseSIL.cpp b/lib/ParseSIL/ParseSIL.cpp index 6785ec708c95f..79bdac8cd920f 100644 --- a/lib/ParseSIL/ParseSIL.cpp +++ b/lib/ParseSIL/ParseSIL.cpp @@ -3091,6 +3091,110 @@ bool SILParser::parseSILInstruction(SILBuilder &B) { InstLoc, extractee, functionOperand); break; } + case SILInstructionKind::DifferentiabilityWitnessFunctionInst: { + // e.g. differentiability_witness_function + // [jvp] [parameters 0 1] [results 0] + // @foo : $(T) -> T + DifferentiabilityWitnessFunctionKind witnessKind; + StringRef witnessKindNames[3] = {"jvp", "vjp", "transpose"}; + SourceLoc lastLoc; + if (P.parseToken(tok::l_square, + diag::sil_inst_autodiff_expected_differentiability_witness_kind) || + parseSILIdentifierSwitch(witnessKind, witnessKindNames, + diag::sil_inst_autodiff_expected_differentiability_witness_kind) || + P.parseToken(tok::r_square, diag::sil_autodiff_expected_rsquare, + "differentiability witness function kind")) + return true; + // Parse an index set, prefaced with the given label. + auto parseIndexSet = [&](StringRef label, SmallVectorImpl &indices, + const Diagnostic &parseIndexDiag) -> bool { + // Parse `[