From b4ac692a8c7c84566332bbccc96c485bb7ab76de Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Tue, 15 Oct 2019 00:46:51 -0700 Subject: [PATCH] [AutoDiff] NFC: Change `DifferentiableFunctionExtractee` to a top-level type. Many places in the compiler that are completely unrelated to {{DifferentiableFunctionExtractInst}} are using {{DifferentiableFunctionExtractInst::Extractee}}, including `@differentiable` type lowering (IRGen/GenDiffFunc.cpp). This patch refactors it and renames it to `NormalDifferentiableFunctionTypeComponent` so that it is no longer part of `DifferentiableFunctionInst`. Resolves [TF-904](https://bugs.swift.org/browse/TF-904). --- include/swift/AST/AutoDiff.h | 91 +++++++++++-------- include/swift/SIL/SILBuilder.h | 4 +- include/swift/SIL/SILInstruction.h | 38 ++------ lib/AST/AutoDiff.cpp | 27 ++++++ lib/IRGen/GenDiffFunc.cpp | 34 +++---- lib/ParseSIL/ParseSIL.cpp | 2 +- lib/SIL/SILInstructions.cpp | 43 ++------- lib/SIL/SILPrinter.cpp | 10 +- lib/SIL/TypeLowering.cpp | 20 ++-- lib/SILGen/SILGenPoly.cpp | 2 +- lib/SILGen/SILGenThunk.cpp | 3 +- .../Mandatory/Differentiation.cpp | 14 +-- lib/Serialization/DeserializeSIL.cpp | 4 +- 13 files changed, 144 insertions(+), 148 deletions(-) diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index af88e338d76a6..4fe355509ac1f 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -45,11 +45,58 @@ enum class DifferentiabilityKind: uint8_t { Linear = 2 }; -// TODO(TF-904): Replace `DifferentiableFunctionExtractInst::Extractee`. -enum class NormalDifferentiableFunctionTypeComponent : uint8_t { - Original = 0, - JVP = 1, - VJP = 2 +/// The kind of an linear map. +struct AutoDiffLinearMapKind { + enum innerty : uint8_t { + // The differential function. + Differential = 0, + // The pullback function. + Pullback = 1 + } rawValue; + + AutoDiffLinearMapKind() = default; + AutoDiffLinearMapKind(innerty rawValue) : rawValue(rawValue) {} + operator innerty() const { return rawValue; } +}; + +/// The kind of a derivative function. +struct AutoDiffDerivativeFunctionKind { + enum innerty : uint8_t { + // The Jacobian-vector products function. + JVP = 0, + // The vector-Jacobian products function. + VJP = 1 + } rawValue; + + AutoDiffDerivativeFunctionKind() = default; + AutoDiffDerivativeFunctionKind(innerty rawValue) : rawValue(rawValue) {} + AutoDiffDerivativeFunctionKind(AutoDiffLinearMapKind linMapKind) + : rawValue(static_cast(linMapKind.rawValue)) {} + explicit AutoDiffDerivativeFunctionKind(StringRef string); + operator innerty() const { return rawValue; } + AutoDiffLinearMapKind getLinearMapKind() { + return (AutoDiffLinearMapKind::innerty)rawValue; + } +}; + +struct NormalDifferentiableFunctionTypeComponent { + enum innerty : unsigned { + Original = 0, + JVP = 1, + VJP = 2 + } rawValue; + + NormalDifferentiableFunctionTypeComponent() = default; + NormalDifferentiableFunctionTypeComponent(innerty rawValue) + : rawValue(rawValue) {} + NormalDifferentiableFunctionTypeComponent( + AutoDiffDerivativeFunctionKind kind); + explicit NormalDifferentiableFunctionTypeComponent(unsigned rawValue) : + NormalDifferentiableFunctionTypeComponent((innerty)rawValue) {} + explicit NormalDifferentiableFunctionTypeComponent(StringRef name); + operator innerty() const { return rawValue; } + + Optional getAsDerivativeFunctionKind() const; }; struct LinearDifferentiableFunctionTypeComponent { @@ -196,40 +243,6 @@ inline llvm::raw_ostream &operator<<(llvm::raw_ostream &s, return s; } -/// The kind of an linear map. -struct AutoDiffLinearMapKind { - enum innerty : uint8_t { - // The differential function. - Differential = 0, - // The pullback function. - Pullback = 1 - } rawValue; - - AutoDiffLinearMapKind() = default; - AutoDiffLinearMapKind(innerty rawValue) : rawValue(rawValue) {} - operator innerty() const { return rawValue; } -}; - -/// The kind of a derivative function. -struct AutoDiffDerivativeFunctionKind { - enum innerty : uint8_t { - // The Jacobian-vector products function. - JVP = 0, - // The vector-Jacobian products function. - VJP = 1 - } rawValue; - - AutoDiffDerivativeFunctionKind() = default; - AutoDiffDerivativeFunctionKind(innerty rawValue) : rawValue(rawValue) {} - AutoDiffDerivativeFunctionKind(AutoDiffLinearMapKind linMapKind) - : rawValue(static_cast(linMapKind.rawValue)) {} - explicit AutoDiffDerivativeFunctionKind(StringRef string); - operator innerty() const { return rawValue; } - AutoDiffLinearMapKind getLinearMapKind() { - return (AutoDiffLinearMapKind::innerty)rawValue; - } -}; - /// Identifies an autodiff derivative function configuration: /// - Parameter indices. /// - Result indices. diff --git a/include/swift/SIL/SILBuilder.h b/include/swift/SIL/SILBuilder.h index 225f0803d5fc8..73cd0bcacbf40 100644 --- a/include/swift/SIL/SILBuilder.h +++ b/include/swift/SIL/SILBuilder.h @@ -528,7 +528,7 @@ class SILBuilder { } DifferentiableFunctionExtractInst *createDifferentiableFunctionExtract( - SILLocation Loc, DifferentiableFunctionExtractee Extractee, + SILLocation Loc, NormalDifferentiableFunctionTypeComponent Extractee, SILValue TheFunction) { return insert(new (getModule()) DifferentiableFunctionExtractInst( getModule(), getSILDebugLocation(Loc), Extractee, TheFunction)); @@ -546,7 +546,7 @@ class SILBuilder { SILValue TheFunction) { return insert(new (getModule()) DifferentiableFunctionExtractInst( getModule(), getSILDebugLocation(Loc), - DifferentiableFunctionExtractee::Original, TheFunction)); + NormalDifferentiableFunctionTypeComponent::Original, TheFunction)); } BuiltinInst *createBuiltin(SILLocation Loc, Identifier Name, SILType ResultTy, diff --git a/include/swift/SIL/SILInstruction.h b/include/swift/SIL/SILInstruction.h index 5ba8c7ae70050..6a66cd8dd98b9 100644 --- a/include/swift/SIL/SILInstruction.h +++ b/include/swift/SIL/SILInstruction.h @@ -7967,42 +7967,29 @@ class DifferentiableFunctionExtractInst : public InstructionBase< SILInstructionKind::DifferentiableFunctionExtractInst, SingleValueInstruction> { -public: - struct Extractee { - enum innerty : unsigned { - Original = 0, - JVP = 1, - VJP = 2 - } rawValue; - Extractee() = default; - Extractee(innerty rawValue) : rawValue(rawValue) {} - explicit Extractee(unsigned rawValue) : Extractee((innerty)rawValue) {} - Extractee(AutoDiffDerivativeFunctionKind kind); - explicit Extractee(StringRef name); - operator innerty() const { return rawValue; } - - Optional - getExtracteeAsDerivativeFunction() const; - }; - private: /// The extractee. - Extractee extractee; + NormalDifferentiableFunctionTypeComponent extractee; /// The list containing the `@differentiable` function operand. FixedOperandList<1> operands; static SILType - getExtracteeType(SILValue function, Extractee extractee, SILModule &module); + getExtracteeType( + SILValue function, NormalDifferentiableFunctionTypeComponent extractee, + SILModule &module); public: explicit DifferentiableFunctionExtractInst( - SILModule &module, SILDebugLocation debugLoc, Extractee extractee, + SILModule &module, SILDebugLocation debugLoc, + NormalDifferentiableFunctionTypeComponent extractee, SILValue theFunction); - Extractee getExtractee() const { return extractee; } + NormalDifferentiableFunctionTypeComponent getExtractee() const { + return extractee; + } AutoDiffDerivativeFunctionKind getDerivativeFunctionKind() const { - auto kind = extractee.getExtracteeAsDerivativeFunction(); + auto kind = extractee.getAsDerivativeFunctionKind(); assert(kind); return *kind; } @@ -8012,9 +7999,6 @@ class DifferentiableFunctionExtractInst MutableArrayRef getAllOperands() { return operands.asArray(); } }; -typedef DifferentiableFunctionExtractInst::Extractee - DifferentiableFunctionExtractee; - /// `linear_function_extract` - given an `@differentiable(linear)` function /// representing a bundle of the original function and the transpose function, /// extract the specified function. @@ -8047,8 +8031,6 @@ class LinearFunctionExtractInst ArrayRef getAllOperands() const { return operands.asArray(); } MutableArrayRef getAllOperands() { return operands.asArray(); } }; - -typedef LinearDifferentiableFunctionTypeComponent LinearFunctionExtractee; // SWIFT_ENABLE_TENSORFLOW END // This is defined out of line to work around the fact that this depends on diff --git a/lib/AST/AutoDiff.cpp b/lib/AST/AutoDiff.cpp index 8a1799f830e2f..2c01dd59d3791 100644 --- a/lib/AST/AutoDiff.cpp +++ b/lib/AST/AutoDiff.cpp @@ -32,6 +32,33 @@ AutoDiffDerivativeFunctionKind(StringRef string) { rawValue = *result; } +NormalDifferentiableFunctionTypeComponent:: +NormalDifferentiableFunctionTypeComponent(AutoDiffDerivativeFunctionKind kind) { + switch (kind) { + case AutoDiffDerivativeFunctionKind::JVP: rawValue = JVP; return; + case AutoDiffDerivativeFunctionKind::VJP: rawValue = VJP; return; + } +} + +NormalDifferentiableFunctionTypeComponent:: +NormalDifferentiableFunctionTypeComponent(StringRef string) { + Optional result = llvm::StringSwitch>(string) + .Case("original", Original) + .Case("jvp", JVP) + .Case("vjp", VJP); + assert(result && "Invalid string"); + rawValue = *result; +} + +Optional +NormalDifferentiableFunctionTypeComponent::getAsDerivativeFunctionKind() const { + switch (rawValue) { + case Original: return None; + case JVP: return {AutoDiffDerivativeFunctionKind::JVP}; + case VJP: return {AutoDiffDerivativeFunctionKind::VJP}; + } +} + LinearDifferentiableFunctionTypeComponent:: LinearDifferentiableFunctionTypeComponent(StringRef string) { Optional result = diff --git a/lib/IRGen/GenDiffFunc.cpp b/lib/IRGen/GenDiffFunc.cpp index 98e1b572114ac..3f54e8aac25f2 100644 --- a/lib/IRGen/GenDiffFunc.cpp +++ b/lib/IRGen/GenDiffFunc.cpp @@ -41,24 +41,24 @@ class DifferentiableFuncFieldInfo final : public RecordField { public: DifferentiableFuncFieldInfo( - DifferentiableFunctionExtractee component, const TypeInfo &type, + NormalDifferentiableFunctionTypeComponent component, const TypeInfo &type, IndexSubset *parameterIndices) : RecordField(type), component(component), parameterIndices(parameterIndices) {} /// The field index. - const DifferentiableFunctionExtractee component; + const NormalDifferentiableFunctionTypeComponent component; /// The parameter indices. IndexSubset *parameterIndices; std::string getFieldName() const { switch (component) { - case DifferentiableFunctionExtractee::Original: + case NormalDifferentiableFunctionTypeComponent::Original: return "original"; - case DifferentiableFunctionExtractee::JVP: + case NormalDifferentiableFunctionTypeComponent::JVP: return "jvp"; - case DifferentiableFunctionExtractee::VJP: + case NormalDifferentiableFunctionTypeComponent::VJP: return "vjp"; } } @@ -66,9 +66,9 @@ class DifferentiableFuncFieldInfo final SILType getType(IRGenModule &IGM, SILType t) const { auto fnTy = t.castTo(); auto origFnTy = fnTy->getWithoutDifferentiability(); - if (component == DifferentiableFunctionExtractee::Original) + if (component == NormalDifferentiableFunctionTypeComponent::Original) return SILType::getPrimitiveObjectType(origFnTy); - auto kind = *component.getExtracteeAsDerivativeFunction(); + auto kind = *component.getAsDerivativeFunctionKind(); auto assocTy = origFnTy->getAutoDiffDerivativeFunctionType( parameterIndices, /*resultIndex*/ 0, kind, IGM.getSILTypes(), LookUpConformanceInModule(IGM.getSwiftModule())); @@ -79,8 +79,8 @@ class DifferentiableFuncFieldInfo final class DifferentiableFuncTypeInfo final : public RecordTypeInfo { - using super = - RecordTypeInfo; + using super = RecordTypeInfo; public: DifferentiableFuncTypeInfo( @@ -117,7 +117,7 @@ class DifferentiableFuncTypeInfo final class DifferentiableFuncTypeBuilder : public RecordTypeBuilder { + NormalDifferentiableFunctionTypeComponent> { SILFunctionType *originalType; IndexSubset *parameterIndices; @@ -151,15 +151,15 @@ class DifferentiableFuncTypeBuilder } DifferentiableFuncFieldInfo getFieldInfo( - unsigned index, DifferentiableFunctionExtractee component, + unsigned index, NormalDifferentiableFunctionTypeComponent component, const TypeInfo &fieldTI) { return DifferentiableFuncFieldInfo(component, fieldTI, parameterIndices); } - SILType getType(DifferentiableFunctionExtractee component) { - if (component == DifferentiableFunctionExtractee::Original) + SILType getType(NormalDifferentiableFunctionTypeComponent component) { + if (component == NormalDifferentiableFunctionTypeComponent::Original) return SILType::getPrimitiveObjectType(originalType->getCanonicalType()); - auto kind = *component.getExtracteeAsDerivativeFunction(); + auto kind = *component.getAsDerivativeFunctionKind(); auto assocTy = originalType->getAutoDiffDerivativeFunctionType( parameterIndices, /*resultIndex*/ 0, kind, IGM.getSILTypes(), LookUpConformanceInModule(IGM.getSwiftModule())); @@ -320,9 +320,9 @@ class LinearFuncTypeBuilder const TypeInfo * TypeConverter::convertNormalDifferentiableFunctionType(SILFunctionType *type) { DifferentiableFuncTypeBuilder builder(IGM, type); - return builder.layout({DifferentiableFunctionExtractee::Original, - DifferentiableFunctionExtractee::JVP, - DifferentiableFunctionExtractee::VJP}); + return builder.layout({NormalDifferentiableFunctionTypeComponent::Original, + NormalDifferentiableFunctionTypeComponent::JVP, + NormalDifferentiableFunctionTypeComponent::VJP}); } const TypeInfo * diff --git a/lib/ParseSIL/ParseSIL.cpp b/lib/ParseSIL/ParseSIL.cpp index 1caa6ed4eb9ff..a5f0d9b79a639 100644 --- a/lib/ParseSIL/ParseSIL.cpp +++ b/lib/ParseSIL/ParseSIL.cpp @@ -3041,7 +3041,7 @@ bool SILParser::parseSILInstruction(SILBuilder &B) { case SILInstructionKind::DifferentiableFunctionExtractInst: { // Parse the rest of the instruction: an extractee, a differentiable // function operand, and a debug location. - DifferentiableFunctionExtractee extractee; + NormalDifferentiableFunctionTypeComponent extractee; StringRef extracteeNames[3] = {"original", "jvp", "vjp"}; SILValue functionOperand; SourceLoc lastLoc; diff --git a/lib/SIL/SILInstructions.cpp b/lib/SIL/SILInstructions.cpp index 426c0d369bc57..06a8b5ed48b51 100644 --- a/lib/SIL/SILInstructions.cpp +++ b/lib/SIL/SILInstructions.cpp @@ -664,45 +664,16 @@ LinearFunctionInst *LinearFunctionInst::create( HasOwnership); } -DifferentiableFunctionExtractInst::Extractee::Extractee( - AutoDiffDerivativeFunctionKind kind) { - switch (kind) { - case AutoDiffDerivativeFunctionKind::JVP: - rawValue = JVP; - return; - case AutoDiffDerivativeFunctionKind::VJP: - rawValue = VJP; - return; - } -} - -DifferentiableFunctionExtractInst::Extractee::Extractee(StringRef string) { - Optional result = llvm::StringSwitch>(string) - .Case("original", Original) - .Case("jvp", JVP) - .Case("vjp", VJP); - assert(result && "Invalid string"); - rawValue = *result; -} - -Optional -DifferentiableFunctionExtractInst::Extractee:: -getExtracteeAsDerivativeFunction() const { - switch (rawValue) { - case Original: return None; - case JVP: return {AutoDiffDerivativeFunctionKind::JVP}; - case VJP: return {AutoDiffDerivativeFunctionKind::VJP}; - } -} - SILType DifferentiableFunctionExtractInst:: -getExtracteeType(SILValue function, Extractee extractee, SILModule &module) { +getExtracteeType( + SILValue function, NormalDifferentiableFunctionTypeComponent extractee, + SILModule &module) { auto fnTy = function->getType().castTo(); assert(fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Normal); auto originalFnTy = fnTy->getWithoutDifferentiability(); - auto kindOpt = extractee.getExtracteeAsDerivativeFunction(); + auto kindOpt = extractee.getAsDerivativeFunctionKind(); if (!kindOpt) { - assert(extractee == Extractee::Original); + assert(extractee == NormalDifferentiableFunctionTypeComponent::Original); return SILType::getPrimitiveObjectType(originalFnTy); } auto resultFnTy = originalFnTy->getAutoDiffDerivativeFunctionType( @@ -713,8 +684,8 @@ getExtracteeType(SILValue function, Extractee extractee, SILModule &module) { } DifferentiableFunctionExtractInst::DifferentiableFunctionExtractInst( - SILModule &module, SILDebugLocation debugLoc, Extractee extractee, - SILValue theFunction) + SILModule &module, SILDebugLocation debugLoc, + NormalDifferentiableFunctionTypeComponent extractee, SILValue theFunction) : InstructionBase(debugLoc, getExtracteeType(theFunction, extractee, module)), extractee(extractee), operands(this, theFunction) {} diff --git a/lib/SIL/SILPrinter.cpp b/lib/SIL/SILPrinter.cpp index 4452543e3dc07..b31badb34f9a4 100644 --- a/lib/SIL/SILPrinter.cpp +++ b/lib/SIL/SILPrinter.cpp @@ -1194,13 +1194,13 @@ class SILPrinter : public SILInstructionVisitor { DifferentiableFunctionExtractInst *dfei) { *this << '['; switch (dfei->getExtractee()) { - case DifferentiableFunctionExtractee::Original: + case NormalDifferentiableFunctionTypeComponent::Original: *this << "original"; break; - case DifferentiableFunctionExtractee::JVP: + case NormalDifferentiableFunctionTypeComponent::JVP: *this << "jvp"; break; - case DifferentiableFunctionExtractee::VJP: + case NormalDifferentiableFunctionTypeComponent::VJP: *this << "vjp"; break; } @@ -1211,10 +1211,10 @@ class SILPrinter : public SILInstructionVisitor { void visitLinearFunctionExtractInst(LinearFunctionExtractInst *lfei) { *this << '['; switch (lfei->getExtractee()) { - case LinearFunctionExtractee::Original: + case LinearDifferentiableFunctionTypeComponent::Original: *this << "original"; break; - case LinearFunctionExtractee::Transpose: + case LinearDifferentiableFunctionTypeComponent::Transpose: *this << "transpose"; break; } diff --git a/lib/SIL/TypeLowering.cpp b/lib/SIL/TypeLowering.cpp index 10b5b1ed009be..fe727888ec486 100644 --- a/lib/SIL/TypeLowering.cpp +++ b/lib/SIL/TypeLowering.cpp @@ -892,14 +892,14 @@ namespace { class NormalDifferentiableSILFunctionTypeLowering final : public LoadableAggTypeLowering< NormalDifferentiableSILFunctionTypeLowering, - DifferentiableFunctionExtractee> { + NormalDifferentiableFunctionTypeComponent> { public: using LoadableAggTypeLowering::LoadableAggTypeLowering; - SILValue emitRValueProject(SILBuilder &B, SILLocation loc, - SILValue tupleValue, - DifferentiableFunctionExtractee extractee, - const TypeLowering &eltLowering) const { + SILValue emitRValueProject( + SILBuilder &B, SILLocation loc, SILValue tupleValue, + NormalDifferentiableFunctionTypeComponent extractee, + const TypeLowering &eltLowering) const { return B.createDifferentiableFunctionExtract( loc, extractee, tupleValue); } @@ -921,7 +921,7 @@ namespace { auto origFnTy = fnTy->getWithoutDifferentiability(); auto paramIndices = fnTy->getDifferentiationParameterIndices(); children.push_back(Child{ - DifferentiableFunctionExtractee::Original, + NormalDifferentiableFunctionTypeComponent::Original, TC.getTypeLowering(origFnTy, getResilienceExpansion()) }); for (AutoDiffDerivativeFunctionKind kind : @@ -931,12 +931,12 @@ namespace { paramIndices, 0, kind, TC, LookUpConformanceInModule(&TC.M)); auto silTy = SILType::getPrimitiveObjectType(derivativeFnTy); - DifferentiableFunctionExtractee extractee(kind); + NormalDifferentiableFunctionTypeComponent extractee(kind); // Assert that we have the right extractee. A terrible bug in the past // was caused by implicit conversions from `unsigned` to - // `DifferentiableFunctionExtractee` which resulted into a wrong - // extractee. - assert(extractee.getExtracteeAsDerivativeFunction() == kind); + // `NormalDifferentiableFunctionTypeComponent` which resulted into a + // wrong extractee. + assert(extractee.getAsDerivativeFunctionKind() == kind); children.push_back(Child{ extractee, TC.getTypeLowering(silTy, getResilienceExpansion())}); } diff --git a/lib/SILGen/SILGenPoly.cpp b/lib/SILGen/SILGenPoly.cpp index e30f50f038c33..018a9fe2fdfc0 100644 --- a/lib/SILGen/SILGenPoly.cpp +++ b/lib/SILGen/SILGenPoly.cpp @@ -4312,7 +4312,7 @@ getWitnessFunctionRef(SILGenFunction &SGF, auto autoDiffFn = SGF.B.createDifferentiableFunction( loc, loweredIndices, originalFn); return SGF.B.createDifferentiableFunctionExtract( - loc, DifferentiableFunctionExtractee(autoDiffFuncId->getKind()), + loc, NormalDifferentiableFunctionTypeComponent(autoDiffFuncId->getKind()), autoDiffFn); } diff --git a/lib/SILGen/SILGenThunk.cpp b/lib/SILGen/SILGenThunk.cpp index acddba5fb228a..204231e211062 100644 --- a/lib/SILGen/SILGenThunk.cpp +++ b/lib/SILGen/SILGenThunk.cpp @@ -146,7 +146,8 @@ SILFunction *SILGenModule::getOrCreateAutoDiffClassMethodThunk( auto diffFn = SGF.B.createDifferentiableFunction( loc, loweredIndices, originalFnRef); auto diffDerivativeFn = SGF.B.createDifferentiableFunctionExtract( - loc, DifferentiableFunctionExtractee(autoDiffFuncId->getKind()), diffFn); + loc, NormalDifferentiableFunctionTypeComponent(autoDiffFuncId->getKind()), + diffFn); auto autoDiffDerivativeFnSILTy = SILType::getPrimitiveObjectType(constantTy); SmallVector args(thunk->getArguments().begin(), thunk->getArguments().end()); diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 306249f8e7887..ae9527c3caa27 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -2563,7 +2563,8 @@ emitDerivativeFunctionReference( // derivative function, and return it. if (auto *inst = original->getDefiningInstruction()) if (auto *dfei = dyn_cast(inst)) - if (dfei->getExtractee() == DifferentiableFunctionExtractee::Original) + if (dfei->getExtractee() == + NormalDifferentiableFunctionTypeComponent::Original) functionSource = dfei->getFunctionOperand(); // If `functionSource` is a `@differentiable` function, just extract the @@ -3787,7 +3788,7 @@ class VJPEmitter final } auto borrowedDiffFunc = builder.emitBeginBorrowOperation(loc, original); vjpValue = builder.createDifferentiableFunctionExtract( - loc, DifferentiableFunctionExtractInst::Extractee::VJP, + loc, NormalDifferentiableFunctionTypeComponent::VJP, borrowedDiffFunc); vjpValue = builder.emitCopyValueOperation(loc, vjpValue); } @@ -3870,7 +3871,7 @@ class VJPEmitter final auto borrowedADFunc = builder.emitBeginBorrowOperation(loc, diffFuncInst); auto extractedVJP = getBuilder().createDifferentiableFunctionExtract( - loc, DifferentiableFunctionExtractInst::Extractee::VJP, + loc, NormalDifferentiableFunctionTypeComponent::VJP, borrowedADFunc); vjpValue = builder.emitCopyValueOperation(loc, extractedVJP); builder.emitEndBorrowOperation(loc, borrowedADFunc); @@ -5465,7 +5466,7 @@ class JVPEmitter final } auto borrowedDiffFunc = builder.emitBeginBorrowOperation(loc, original); jvpValue = builder.createDifferentiableFunctionExtract( - loc, DifferentiableFunctionExtractInst::Extractee::JVP, + loc, NormalDifferentiableFunctionTypeComponent::JVP, borrowedDiffFunc); jvpValue = builder.emitCopyValueOperation(loc, jvpValue); } @@ -5544,7 +5545,7 @@ class JVPEmitter final auto borrowedADFunc = builder.emitBeginBorrowOperation(loc, diffFuncInst); auto extractedJVP = builder.createDifferentiableFunctionExtract( - loc, DifferentiableFunctionExtractInst::Extractee::JVP, + loc, NormalDifferentiableFunctionTypeComponent::JVP, borrowedADFunc); jvpValue = builder.emitCopyValueOperation(loc, extractedJVP); builder.emitEndBorrowOperation(loc, borrowedADFunc); @@ -8712,7 +8713,8 @@ void ADContext::foldDifferentiableFunctionExtraction( if (!dfei) continue; // Fold original function extractors. - if (dfei->getExtractee() == DifferentiableFunctionExtractee::Original) { + if (dfei->getExtractee() == + NormalDifferentiableFunctionTypeComponent::Original) { auto originalFnValue = source->getOriginalFunction(); dfei->replaceAllUsesWith(originalFnValue); dfei->eraseFromParent(); diff --git a/lib/Serialization/DeserializeSIL.cpp b/lib/Serialization/DeserializeSIL.cpp index 5601e6252edaa..1228cb02f7916 100644 --- a/lib/Serialization/DeserializeSIL.cpp +++ b/lib/Serialization/DeserializeSIL.cpp @@ -1599,7 +1599,7 @@ bool SILDeserializer::readSILInstruction(SILFunction *Fn, SILBasicBlock *BB, auto astTy = MF->getType(TyID); auto silTy = getSILType(astTy, SILValueCategory::Object); auto val = getLocalValue(ValID, silTy); - DifferentiableFunctionExtractee extractee(Attr); + NormalDifferentiableFunctionTypeComponent extractee(Attr); ResultVal = Builder.createDifferentiableFunctionExtract(Loc, extractee, val); break; @@ -1608,7 +1608,7 @@ bool SILDeserializer::readSILInstruction(SILFunction *Fn, SILBasicBlock *BB, auto astTy = MF->getType(TyID); auto silTy = getSILType(astTy, SILValueCategory::Object); auto val = getLocalValue(ValID, silTy); - LinearFunctionExtractee extractee(Attr); + LinearDifferentiableFunctionTypeComponent extractee(Attr); ResultVal = Builder.createLinearFunctionExtract(Loc, extractee, val); break; }