From 39211c02012e61f614b355657a7d938ec0fe852f Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Sun, 13 Oct 2019 23:55:38 -0700 Subject: [PATCH 1/4] [AutoDiff] [IRGen] Lower `@differentiable(linear)` function types. --- include/swift/AST/AutoDiff.h | 14 +- include/swift/AST/Types.h | 9 +- lib/IRGen/GenDiffFunc.cpp | 244 +++++++++++++++++---- lib/IRGen/GenFunc.cpp | 10 +- lib/IRGen/GenType.h | 3 +- lib/SIL/SILFunctionType.cpp | 75 +++++++ lib/SIL/TypeLowering.cpp | 146 +++++++++--- test/AutoDiff/differentiable_func_type.sil | 60 ++++- 8 files changed, 469 insertions(+), 92 deletions(-) diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index f9092eff0a8dc..c0b1fa49d9342 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -39,12 +39,24 @@ class SILFunctionType; typedef CanTypeWrapper CanSILFunctionType; enum class SILLinkage : uint8_t; -enum class DifferentiabilityKind: uint8_t { +enum class DifferentiabilityKind : uint8_t { NonDifferentiable = 0b00, Normal = 0b01, Linear = 0b11 }; +// TODO(TF-904): Replace `DifferentiableFunctionExtractInst::Extractee`. +enum class NormalDifferentiableFunctionTypeComponent : uint8_t { + Original = 0, + JVP = 1, + VJP = 2 +}; + +enum class LinearDifferentiableFunctionTypeComponent : uint8_t { + Original = 0, + Transpose = 1 +}; + class ParsedAutoDiffParameter { public: enum class Kind { Named, Ordered, Self }; diff --git a/include/swift/AST/Types.h b/include/swift/AST/Types.h index c317ddcbfdbf3..0a9df0964498a 100644 --- a/include/swift/AST/Types.h +++ b/include/swift/AST/Types.h @@ -4220,14 +4220,19 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode, CanSILFunctionType getWithoutDifferentiability(); - /// Returns the type of a differentiation function that is associated with - /// a function of this type. + /// Returns the type of the derivative function. CanSILFunctionType getAutoDiffDerivativeFunctionType( IndexSubset *parameterIndices, unsigned resultIndex, AutoDiffDerivativeFunctionKind kind, Lowering::TypeConverter &TC, LookupConformanceFn lookupConformance, CanGenericSignature derivativeFunctionGenericSignature = nullptr); + /// Returns the type of the transpose function. + CanSILFunctionType getAutoDiffTransposeFunctionType( + IndexSubset *parameterIndices, Lowering::TypeConverter &TC, + LookupConformanceFn lookupConformance, + CanGenericSignature derivativeFunctionGenericSignature = nullptr); + /// Returns a bit vector that specifices which parameters you can /// differentiate with respect to for this differentiable function type. (e.g. /// which parameters are not `@nondiff`). The function type must be diff --git a/lib/IRGen/GenDiffFunc.cpp b/lib/IRGen/GenDiffFunc.cpp index 0ff51c2101c7e..3d942a85c6284 100644 --- a/lib/IRGen/GenDiffFunc.cpp +++ b/lib/IRGen/GenDiffFunc.cpp @@ -32,23 +32,27 @@ using namespace swift; using namespace irgen; -using DiffFuncIndex = DifferentiableFunctionExtractee; +//----------------------------------------------------------------------------// +// `@differentiable` (non-linear) function type info +//----------------------------------------------------------------------------// namespace { -class DiffFuncFieldInfo final : public RecordField { +class DifferentiableFuncFieldInfo final : public RecordField { public: - DiffFuncFieldInfo(DiffFuncIndex index, const TypeInfo &type, - IndexSubset *parameterIndices) - : RecordField(type), Index(index), ParameterIndices(parameterIndices) {} + DifferentiableFuncFieldInfo( + DifferentiableFunctionExtractee component, const TypeInfo &type, + IndexSubset *parameterIndices) + : RecordField(type), component(component), + parameterIndices(parameterIndices) {} /// The field index. - const DiffFuncIndex Index; + const DifferentiableFunctionExtractee component; /// The parameter indices. - IndexSubset *ParameterIndices; + IndexSubset *parameterIndices; std::string getFieldName() const { - switch (Index) { + switch (component) { case DifferentiableFunctionExtractee::Original: return "original"; case DifferentiableFunctionExtractee::JVP: @@ -61,32 +65,32 @@ class DiffFuncFieldInfo final : public RecordField { SILType getType(IRGenModule &IGM, SILType t) const { auto fnTy = t.castTo(); auto origFnTy = fnTy->getWithoutDifferentiability(); - if (Index == DifferentiableFunctionExtractee::Original) + if (component == DifferentiableFunctionExtractee::Original) return SILType::getPrimitiveObjectType(origFnTy); - auto kind = *Index.getExtracteeAsDerivativeFunction(); + auto kind = *component.getExtracteeAsDerivativeFunction(); auto assocTy = origFnTy->getAutoDiffDerivativeFunctionType( - ParameterIndices, /*resultIndex*/ 0, kind, + parameterIndices, /*resultIndex*/ 0, kind, IGM.getSILTypes(), LookUpConformanceInModule(IGM.getSwiftModule())); return SILType::getPrimitiveObjectType(assocTy); } }; -class DiffFuncTypeInfo final - : public RecordTypeInfo { +class DifferentiableFuncTypeInfo final + : public RecordTypeInfo { using super = - RecordTypeInfo; + RecordTypeInfo; public: - DiffFuncTypeInfo(ArrayRef fields, unsigned explosionSize, - llvm::Type *ty, Size size, SpareBitVector &&spareBits, - Alignment align, IsPOD_t isPOD, - IsFixedSize_t alwaysFixedSize) + DifferentiableFuncTypeInfo( + ArrayRef fields, unsigned explosionSize, + llvm::Type *ty, Size size, SpareBitVector &&spareBits, Alignment align, + IsPOD_t isPOD, IsFixedSize_t alwaysFixedSize) : super(fields, explosionSize, ty, size, std::move(spareBits), align, isPOD, alwaysFixedSize) {} Address projectFieldAddress(IRGenFunction &IGF, Address addr, SILType T, - const DiffFuncFieldInfo &field) const { + const DifferentiableFuncFieldInfo &field) const { return field.projectAddress(IGF, addr, getNonFixedOffsets(IGF, T)); } @@ -110,50 +114,52 @@ class DiffFuncTypeInfo final } }; -class DiffFuncTypeBuilder - : public RecordTypeBuilder { +class DifferentiableFuncTypeBuilder + : public RecordTypeBuilder { - SILFunctionType *origFnTy; + SILFunctionType *originalType; IndexSubset *parameterIndices; public: - DiffFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy) - : RecordTypeBuilder(IGM), origFnTy(fnTy->getWithoutDifferentiability()), + DifferentiableFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy) + : RecordTypeBuilder(IGM), + originalType(fnTy->getWithoutDifferentiability()), parameterIndices(fnTy->getDifferentiationParameterIndices()) { - assert(fnTy->isDifferentiable()); + assert(fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Normal); } - TypeInfo *createFixed(ArrayRef fields, + TypeInfo *createFixed(ArrayRef fields, StructLayout &&layout) { llvm_unreachable("@differentiable functions are always loadable"); } - DiffFuncTypeInfo *createLoadable(ArrayRef fields, - StructLayout &&layout, - unsigned explosionSize) { - return DiffFuncTypeInfo::create( + DifferentiableFuncTypeInfo *createLoadable( + ArrayRef fields, StructLayout &&layout, + unsigned explosionSize) { + return DifferentiableFuncTypeInfo::create( fields, explosionSize, layout.getType(), layout.getSize(), std::move(layout.getSpareBits()), layout.getAlignment(), layout.isPOD(), layout.isAlwaysFixedSize()); } - TypeInfo *createNonFixed(ArrayRef fields, + TypeInfo *createNonFixed(ArrayRef fields, FieldsAreABIAccessible_t fieldsAccessible, StructLayout &&layout) { llvm_unreachable("@differentiable functions are always loadable"); } - DiffFuncFieldInfo getFieldInfo(unsigned index, DiffFuncIndex field, - const TypeInfo &fieldTI) { - return DiffFuncFieldInfo(field, fieldTI, parameterIndices); + DifferentiableFuncFieldInfo getFieldInfo( + unsigned index, DifferentiableFunctionExtractee component, + const TypeInfo &fieldTI) { + return DifferentiableFuncFieldInfo(component, fieldTI, parameterIndices); } - SILType getType(DiffFuncIndex field) { - if (field == DifferentiableFunctionExtractee::Original) - return SILType::getPrimitiveObjectType(origFnTy->getCanonicalType()); - auto kind = *field.getExtracteeAsDerivativeFunction(); - auto assocTy = origFnTy->getAutoDiffDerivativeFunctionType( + SILType getType(DifferentiableFunctionExtractee component) { + if (component == DifferentiableFunctionExtractee::Original) + return SILType::getPrimitiveObjectType(originalType->getCanonicalType()); + auto kind = *component.getExtracteeAsDerivativeFunction(); + auto assocTy = originalType->getAutoDiffDerivativeFunctionType( parameterIndices, /*resultIndex*/ 0, kind, IGM.getSILTypes(), LookUpConformanceInModule(IGM.getSwiftModule())); return SILType::getPrimitiveObjectType(assocTy); @@ -166,11 +172,161 @@ class DiffFuncTypeBuilder }; } // end anonymous namespace +//----------------------------------------------------------------------------// +// `@differentiable(linear)` function type info +//----------------------------------------------------------------------------// +namespace { +class LinearFuncFieldInfo final : public RecordField { +public: + LinearFuncFieldInfo(LinearDifferentiableFunctionTypeComponent component, + const TypeInfo &type, IndexSubset *parameterIndices) + : RecordField(type), component(component), + parameterIndices(parameterIndices) {} + + /// The field index. + const LinearDifferentiableFunctionTypeComponent component; + + /// The parameter indices. + IndexSubset *parameterIndices; + + std::string getFieldName() const { + switch (component) { + case LinearDifferentiableFunctionTypeComponent::Original: + return "original"; + case LinearDifferentiableFunctionTypeComponent::Transpose: + return "transpose"; + } + } + + SILType getType(IRGenModule &IGM, SILType t) const { + auto fnTy = t.castTo(); + auto origFnTy = fnTy->getWithoutDifferentiability(); + switch (component) { + case LinearDifferentiableFunctionTypeComponent::Original: + return SILType::getPrimitiveObjectType(origFnTy); + case LinearDifferentiableFunctionTypeComponent::Transpose: + auto transposeTy = origFnTy->getAutoDiffTransposeFunctionType( + parameterIndices, IGM.getSILTypes(), + LookUpConformanceInModule(IGM.getSwiftModule())); + return SILType::getPrimitiveObjectType(transposeTy); + } + } +}; + +class LinearFuncTypeInfo final + : public RecordTypeInfo { + using super = + RecordTypeInfo; + +public: + LinearFuncTypeInfo( + ArrayRef fields, unsigned explosionSize, + llvm::Type *ty, Size size, SpareBitVector &&spareBits, Alignment align, + IsPOD_t isPOD, IsFixedSize_t alwaysFixedSize) + : super(fields, explosionSize, ty, size, std::move(spareBits), align, + isPOD, alwaysFixedSize) {} + + Address projectFieldAddress(IRGenFunction &IGF, Address addr, SILType T, + const LinearFuncFieldInfo &field) const { + return field.projectAddress(IGF, addr, getNonFixedOffsets(IGF, T)); + } + + void initializeFromParams(IRGenFunction &IGF, Explosion ¶ms, Address src, + SILType T, bool isOutlined) const override { + llvm_unreachable("unexploded @differentiable function as argument?"); + } + + void addToAggLowering(IRGenModule &IGM, SwiftAggLowering &lowering, + Size offset) const override { + for (auto &field : getFields()) { + auto fieldOffset = offset + field.getFixedByteOffset(); + cast(field.getTypeInfo()) + .addToAggLowering(IGM, lowering, fieldOffset); + } + } + + llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF) const { return None; } + llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF, SILType T) const { + return None; + } +}; + +class LinearFuncTypeBuilder + : public RecordTypeBuilder { + + SILFunctionType *originalType; + IndexSubset *parameterIndices; + +public: + LinearFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy) + : RecordTypeBuilder(IGM), + originalType(fnTy->getWithoutDifferentiability()), + parameterIndices(fnTy->getDifferentiationParameterIndices()) { + assert(fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Linear); + } + + TypeInfo *createFixed(ArrayRef fields, + StructLayout &&layout) { + llvm_unreachable("@differentiable functions are always loadable"); + } + + LinearFuncTypeInfo *createLoadable(ArrayRef fields, + StructLayout &&layout, + unsigned explosionSize) { + return LinearFuncTypeInfo::create( + fields, explosionSize, layout.getType(), layout.getSize(), + std::move(layout.getSpareBits()), layout.getAlignment(), layout.isPOD(), + layout.isAlwaysFixedSize()); + } + + TypeInfo *createNonFixed(ArrayRef fields, + FieldsAreABIAccessible_t fieldsAccessible, + StructLayout &&layout) { + llvm_unreachable("@differentiable functions are always loadable"); + } + + LinearFuncFieldInfo getFieldInfo( + unsigned index, LinearDifferentiableFunctionTypeComponent field, + const TypeInfo &fieldTI) { + return LinearFuncFieldInfo(field, fieldTI, parameterIndices); + } + + SILType getType(LinearDifferentiableFunctionTypeComponent component) { + switch (component) { + case LinearDifferentiableFunctionTypeComponent::Original: + return SILType::getPrimitiveObjectType(originalType->getCanonicalType()); + case LinearDifferentiableFunctionTypeComponent::Transpose: + auto transposeTy = originalType->getAutoDiffTransposeFunctionType( + parameterIndices, IGM.getSILTypes(), + LookUpConformanceInModule(IGM.getSwiftModule())); + return SILType::getPrimitiveObjectType(transposeTy); + } + } + + StructLayout performLayout(ArrayRef fieldTypes) { + return StructLayout(IGM, /*decl=*/nullptr, LayoutKind::NonHeapObject, + LayoutStrategy::Universal, fieldTypes); + } +}; +} // end anonymous namespace + +//----------------------------------------------------------------------------// +// Type converter entry points +//----------------------------------------------------------------------------// + const TypeInfo * -TypeConverter::convertDifferentiableFunctionType(SILFunctionType *type) { - assert(type->isDifferentiable()); - DiffFuncTypeBuilder builder(IGM, type); +TypeConverter::convertNormalDifferentiableFunctionType(SILFunctionType *type) { + DifferentiableFuncTypeBuilder builder(IGM, type); return builder.layout({DifferentiableFunctionExtractee::Original, DifferentiableFunctionExtractee::JVP, DifferentiableFunctionExtractee::VJP}); } + +const TypeInfo * +TypeConverter::convertLinearDifferentiableFunctionType(SILFunctionType *type) { + LinearFuncTypeBuilder builder(IGM, type); + return builder.layout({LinearDifferentiableFunctionTypeComponent::Original, + LinearDifferentiableFunctionTypeComponent::Transpose}); +} diff --git a/lib/IRGen/GenFunc.cpp b/lib/IRGen/GenFunc.cpp index 1ce7ff5270dd5..bdd375f9dea44 100644 --- a/lib/IRGen/GenFunc.cpp +++ b/lib/IRGen/GenFunc.cpp @@ -480,8 +480,14 @@ Address irgen::projectBlockStorageCapture(IRGenFunction &IGF, const TypeInfo *TypeConverter::convertFunctionType(SILFunctionType *T) { // SWIFT_ENABLE_TENSORFLOW - if (T->isDifferentiable()) - return convertDifferentiableFunctionType(T); + switch (T->getDifferentiabilityKind()) { + case DifferentiabilityKind::Normal: + return convertNormalDifferentiableFunctionType(T); + case DifferentiabilityKind::Linear: + return convertLinearDifferentiableFunctionType(T); + case DifferentiabilityKind::NonDifferentiable: + break; + } switch (T->getRepresentation()) { case SILFunctionType::Representation::Block: diff --git a/lib/IRGen/GenType.h b/lib/IRGen/GenType.h index 611678594b5e4..64565b8774544 100644 --- a/lib/IRGen/GenType.h +++ b/lib/IRGen/GenType.h @@ -138,7 +138,8 @@ class TypeConverter { const TypeInfo *convertStructType(TypeBase *key, CanType type, StructDecl *D); const TypeInfo *convertFunctionType(SILFunctionType *T); // SWIFT_ENABLE_TENSORFLOW - const TypeInfo *convertDifferentiableFunctionType(SILFunctionType *T); + const TypeInfo *convertNormalDifferentiableFunctionType(SILFunctionType *T); + const TypeInfo *convertLinearDifferentiableFunctionType(SILFunctionType *T); const TypeInfo *convertBlockStorageType(SILBlockStorageType *T); const TypeInfo *convertBoxType(SILBoxType *T); const TypeInfo *convertArchetypeType(ArchetypeType *T); diff --git a/lib/SIL/SILFunctionType.cpp b/lib/SIL/SILFunctionType.cpp index 4f0dba9f97bfe..6d9057e379a61 100644 --- a/lib/SIL/SILFunctionType.cpp +++ b/lib/SIL/SILFunctionType.cpp @@ -323,6 +323,81 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType( getWitnessMethodConformanceOrNone()); } +CanSILFunctionType SILFunctionType::getAutoDiffTransposeFunctionType( + IndexSubset *parameterIndices, Lowering::TypeConverter &TC, + LookupConformanceFn lookupConformance, CanGenericSignature genSig) { + // Get the canonical derivative function generic signature. + if (!genSig) + genSig = getGenericSignature(); + genSig = getAutoDiffDerivativeFunctionGenericSignature( + genSig, getParameters(), parameterIndices, &TC.M); + Lowering::GenericContextScope genericContextScope(TC, genSig); + + // Given a type, returns its formal SIL parameter info. + auto getParameterInfoForOriginalResult = [&]( + const SILResultInfo &result) -> SILParameterInfo { + auto &tl = TC.getTypeLowering( + result.getType(), ResilienceExpansion::Minimal); + ParameterConvention newConv; + switch (result.getConvention()) { + case ResultConvention::Owned: + case ResultConvention::Autoreleased: + newConv = tl.isTrivial() + ? ParameterConvention::Direct_Unowned + : ParameterConvention::Direct_Guaranteed; + break; + case ResultConvention::Unowned: + case ResultConvention::UnownedInnerPointer: + newConv = ParameterConvention::Direct_Unowned; + break; + case ResultConvention::Indirect: + newConv = ParameterConvention::Indirect_In_Guaranteed; + break; + } + return {result.getType()->getCanonicalType(genSig), newConv}; + }; + + // Given a type, returns its formal SIL result info. + auto getResultInfoForOriginalParameter = [&]( + const SILParameterInfo ¶m) -> SILResultInfo { + auto &tl = TC.getTypeLowering( + param.getType(), ResilienceExpansion::Minimal); + ResultConvention newConv; + switch (param.getConvention()) { + case ParameterConvention::Direct_Owned: + case ParameterConvention::Direct_Guaranteed: + case ParameterConvention::Direct_Unowned: + newConv = tl.isTrivial() + ? ResultConvention::Unowned + : ResultConvention::Owned; + break; + case ParameterConvention::Indirect_In: + case ParameterConvention::Indirect_Inout: + case ParameterConvention::Indirect_In_Constant: + case ParameterConvention::Indirect_In_Guaranteed: + case ParameterConvention::Indirect_InoutAliasable: + newConv = ResultConvention::Indirect; + break; + } + return {param.getType()->getCanonicalType(genSig), newConv}; + }; + + SmallVector newParameters; + SmallVector newResults; + for (auto param : llvm::enumerate(getParameters())) { + if (parameterIndices->contains(param.index())) + newResults.push_back(getResultInfoForOriginalParameter(param.value())); + else + newParameters.push_back(param.value()); + } + for (auto &res : getResults()) + newParameters.push_back(getParameterInfoForOriginalResult(res)); + return SILFunctionType::get( + genSig, getExtInfo(), getCoroutineKind(), + getCalleeConvention(), newParameters, getYields(), newResults, + getOptionalErrorResult(), getASTContext()); +} + ClassDecl * SILFunctionType::getWitnessMethodClass() const { auto selfTy = getSelfInstanceType(); diff --git a/lib/SIL/TypeLowering.cpp b/lib/SIL/TypeLowering.cpp index 118a157b72bd8..aa8dc443d43d2 100644 --- a/lib/SIL/TypeLowering.cpp +++ b/lib/SIL/TypeLowering.cpp @@ -146,27 +146,6 @@ namespace { ResilienceExpansion Expansion) : TC(TC), Sig(Sig), Expansion(Expansion) {} - // SWIFT_ENABLE_TENSORFLOW - RecursiveProperties getDifferentiableSILFunctionTypeRecursiveProperties( - CanSILFunctionType type) { - assert(type->isDifferentiable()); - auto &M = TC.M; - auto origTy = type->getWithoutDifferentiability(); - auto jvpTy = origTy->getAutoDiffDerivativeFunctionType( - type->getDifferentiationParameterIndices(), /*resultIndex*/ 0, - AutoDiffDerivativeFunctionKind::JVP, TC, - LookUpConformanceInModule(&M)); - auto vjpTy = origTy->getAutoDiffDerivativeFunctionType( - type->getDifferentiationParameterIndices(), /*resultIndex*/ 0, - AutoDiffDerivativeFunctionKind::VJP, TC, - LookUpConformanceInModule(&M)); - RecursiveProperties props; - props.addSubobject(classifyType(origTy, TC, Sig, Expansion)); - props.addSubobject(classifyType(jvpTy, TC, Sig, Expansion)); - props.addSubobject(classifyType(vjpTy, TC, Sig, Expansion)); - return props; - } - public: // The subclass should implement: // // Trivial, fixed-layout, and non-address-only. @@ -248,8 +227,18 @@ namespace { RetTy visitSILFunctionType(CanSILFunctionType type) { // SWIFT_ENABLE_TENSORFLOW - if (type->isDifferentiable()) - return asImpl().visitDifferentiableSILFunctionType(type); + switch (type->getDifferentiabilityKind()) { + case DifferentiabilityKind::Normal: + return asImpl().visitNormalDifferentiableSILFunctionType( + type, + getNormalDifferentiableSILFunctionTypeRecursiveProperties(type)); + case DifferentiabilityKind::Linear: + return asImpl().visitLinearDifferentiableSILFunctionType( + type, + getLinearDifferentiableSILFunctionTypeRecursiveProperties(type)); + case DifferentiabilityKind::NonDifferentiable: + break; + } // Only escaping closures are references. bool isSwiftEscaping = type->getExtInfo().isNoEscape() && @@ -262,10 +251,48 @@ namespace { } // SWIFT_ENABLE_TENSORFLOW - RetTy visitDifferentiableSILFunctionType(CanSILFunctionType type) { - assert(type->isDifferentiable()); - auto props = getDifferentiableSILFunctionTypeRecursiveProperties(type); - return asImpl().handleAggregateByProperties(type, props); + RecursiveProperties + getNormalDifferentiableSILFunctionTypeRecursiveProperties( + CanSILFunctionType type) { + auto &M = TC.M; + auto origTy = type->getWithoutDifferentiability(); + auto jvpTy = origTy->getAutoDiffDerivativeFunctionType( + type->getDifferentiationParameterIndices(), /*resultIndex*/ 0, + AutoDiffDerivativeFunctionKind::JVP, TC, + LookUpConformanceInModule(&M)); + auto vjpTy = origTy->getAutoDiffDerivativeFunctionType( + type->getDifferentiationParameterIndices(), /*resultIndex*/ 0, + AutoDiffDerivativeFunctionKind::VJP, TC, + LookUpConformanceInModule(&M)); + RecursiveProperties props; + props.addSubobject(classifyType(origTy, TC, Sig, Expansion)); + props.addSubobject(classifyType(jvpTy, TC, Sig, Expansion)); + props.addSubobject(classifyType(vjpTy, TC, Sig, Expansion)); + return props; + } + + RecursiveProperties + getLinearDifferentiableSILFunctionTypeRecursiveProperties( + CanSILFunctionType type) { + auto &M = TC.M; + auto origTy = type->getWithoutDifferentiability(); + auto transTy = origTy->getAutoDiffTransposeFunctionType( + type->getDifferentiationParameterIndices(), TC, + LookUpConformanceInModule(&M)); + RecursiveProperties props; + props.addSubobject(classifyType(origTy, TC, Sig, Expansion)); + props.addSubobject(classifyType(transTy, TC, Sig, Expansion)); + return props; + } + + RetTy visitNormalDifferentiableSILFunctionType( + CanSILFunctionType type, RecursiveProperties props) { + return handleAggregateByProperties(type, props); + } + + RetTy visitLinearDifferentiableSILFunctionType( + CanSILFunctionType type, RecursiveProperties props) { + return handleAggregateByProperties(type, props); } RetTy visitLValueType(CanLValueType type) { @@ -862,9 +889,10 @@ namespace { }; // SWIFT_ENABLE_TENSORFLOW - class DifferentiableSILFunctionTypeLowering final - : public LoadableAggTypeLowering { + class NormalDifferentiableSILFunctionTypeLowering final + : public LoadableAggTypeLowering< + NormalDifferentiableSILFunctionTypeLowering, + DifferentiableFunctionExtractee> { public: using LoadableAggTypeLowering::LoadableAggTypeLowering; @@ -916,6 +944,48 @@ namespace { } }; + class LinearDifferentiableSILFunctionTypeLowering final + : public LoadableAggTypeLowering< + LinearDifferentiableSILFunctionTypeLowering, + LinearDifferentiableFunctionTypeComponent> { + public: + using LoadableAggTypeLowering::LoadableAggTypeLowering; + + SILValue emitRValueProject( + SILBuilder &B, SILLocation loc, SILValue tupleValue, + LinearDifferentiableFunctionTypeComponent component, + const TypeLowering &eltLowering) const { + // TODO: Handle this once `linear_function_extract` instruction exists. + llvm_unreachable("Unhandled"); + } + + SILValue rebuildAggregate(SILBuilder &B, SILLocation loc, + ArrayRef values) const override { + // TODO: Handle this once `linear_function` instruction exists. + llvm_unreachable("Unhandled"); + } + + void lowerChildren(TypeConverter &TC, + SmallVectorImpl &children) const override { + auto fnTy = getLoweredType().castTo(); + children.reserve(2); + auto origFnTy = fnTy->getWithoutDifferentiability(); + auto paramIndices = fnTy->getDifferentiationParameterIndices(); + children.push_back(Child{ + LinearDifferentiableFunctionTypeComponent::Original, + TC.getTypeLowering(origFnTy, getResilienceExpansion()) + }); + auto transposeFnTy = origFnTy->getAutoDiffTransposeFunctionType( + paramIndices, TC, LookUpConformanceInModule(&TC.M)); + auto transposeSILFnTy = SILType::getPrimitiveObjectType(transposeFnTy); + children.push_back(Child{ + LinearDifferentiableFunctionTypeComponent::Transpose, + TC.getTypeLowering(transposeSILFnTy, getResilienceExpansion()) + }); + assert(children.size() == 2); + } + }; + /// A lowering for loadable but non-trivial tuple types. class LoadableTupleTypeLowering final : public LoadableAggTypeLowering { @@ -1412,11 +1482,17 @@ namespace { // SWIFT_ENABLE_TENSORFLOW TypeLowering * - visitDifferentiableSILFunctionType(CanSILFunctionType type) { - assert(type->isDifferentiable()); - auto props = getDifferentiableSILFunctionTypeRecursiveProperties(type); - return handleAggregateByProperties( - type, props); + visitNormalDifferentiableSILFunctionType(CanSILFunctionType type, + RecursiveProperties props) { + return handleAggregateByProperties + (type, props); + } + + TypeLowering * + visitLinearDifferentiableSILFunctionType(CanSILFunctionType type, + RecursiveProperties props) { + return handleAggregateByProperties + (type, props); } template diff --git a/test/AutoDiff/differentiable_func_type.sil b/test/AutoDiff/differentiable_func_type.sil index 2defc5fe001fa..ddab12c33c859 100644 --- a/test/AutoDiff/differentiable_func_type.sil +++ b/test/AutoDiff/differentiable_func_type.sil @@ -1,7 +1,9 @@ // RUN: %empty-directory(%t) // RUN: %target-sil-opt %s -emit-sib -o %t/tmp.sib -module-name differentiable_func_type // RUN: %target-sil-opt %t/tmp.sib -o %t/tmp.2.sib -module-name differentiable_func_type -// RUN: %target-sil-opt %t/tmp.2.sib -module-name differentiable_func_type | %FileCheck %s +// RUN: %target-sil-opt %t/tmp.2.sib -module-name differentiable_func_type | %FileCheck %s -check-prefix=CHECK-SIL + +// RUN: %target-swift-frontend %s -emit-ir -module-name differentiable_func_type | %FileCheck %s -check-prefix=CHECK-LLVM sil_stage raw @@ -13,9 +15,18 @@ bb0(%0 : $@differentiable(linear) (Float) -> Float): } // CHECK-LABEL: sil @takeAndReturnLinear : $@convention(thin) (@differentiable(linear) (Float) -> Float) -> @differentiable(linear) (Float) -> Float { -// CHECK: bb0([[ARG:%.*]] : $@differentiable(linear) (Float) -> Float): -// CHECK: return [[ARG]] : $@differentiable(linear) (Float) -> Float -// CHECK: } +// CHECK-SIL: bb0([[ARG:%.*]] : $@differentiable(linear) (Float) -> Float): +// CHECK-SIL: return [[ARG]] : $@differentiable(linear) (Float) -> Float +// CHECK-SIL: } + +// CHECK-LLVM-LABEL: define swiftcc { i8*, %swift.refcounted*, i8*, %swift.refcounted* } @takeAndReturnLinear(i8*, %swift.refcounted*, i8*, %swift.refcounted*) #0 { +// CHECK-LLVM: entry: +// CHECK-LLVM: %4 = insertvalue { i8*, %swift.refcounted*, i8*, %swift.refcounted* } undef, i8* %0, 0 +// CHECK-LLVM: %5 = insertvalue { i8*, %swift.refcounted*, i8*, %swift.refcounted* } %4, %swift.refcounted* %1, 1 +// CHECK-LLVM: %6 = insertvalue { i8*, %swift.refcounted*, i8*, %swift.refcounted* } %5, i8* %2, 2 +// CHECK-LLVM: %7 = insertvalue { i8*, %swift.refcounted*, i8*, %swift.refcounted* } %6, %swift.refcounted* %3, 3 +// CHECK-LLVM: ret { i8*, %swift.refcounted*, i8*, %swift.refcounted* } %7 +// CHECK-LLVM: } sil @takeAndReturnDifferentiable : $@convention(thin) (@differentiable (Float) -> Float) -> @differentiable (Float) -> Float { @@ -24,6 +35,41 @@ bb0(%0 : $@differentiable (Float) -> Float): } // CHECK-LABEL: sil @takeAndReturnDifferentiable : $@convention(thin) (@differentiable (Float) -> Float) -> @differentiable (Float) -> Float { -// CHECK: bb0([[ARG:%.*]] : $@differentiable (Float) -> Float): -// CHECK: return [[ARG]] : $@differentiable (Float) -> Float -// CHECK: } \ No newline at end of file +// CHECK-SIL: bb0([[ARG:%.*]] : $@differentiable (Float) -> Float): +// CHECK-SIL: return [[ARG]] : $@differentiable (Float) -> Float +// CHECK-SIL: } + +// CHECK-LLVM-LABEL: define swiftcc void @takeAndReturnDifferentiable(<{ %swift.function, %swift.function, %swift.function }>* noalias nocapture sret, <{ %swift.function, %swift.function, %swift.function }>* noalias nocapture dereferenceable(48)) #0 { +// CHECK-LLVM: entry: +// CHECK-LLVM: %.original = getelementptr inbounds <{ %swift.function, %swift.function, %swift.function }>, <{ %swift.function, %swift.function, %swift.function }>* %1, i32 0, i32 0 +// CHECK-LLVM: %.original.fn = getelementptr inbounds %swift.function, %swift.function* %.original, i32 0, i32 0 +// CHECK-LLVM: %2 = load i8*, i8** %.original.fn, align 8 +// CHECK-LLVM: %.original.data = getelementptr inbounds %swift.function, %swift.function* %.original, i32 0, i32 1 +// CHECK-LLVM: %3 = load %swift.refcounted*, %swift.refcounted** %.original.data, align 8 +// CHECK-LLVM: %.jvp = getelementptr inbounds <{ %swift.function, %swift.function, %swift.function }>, <{ %swift.function, %swift.function, %swift.function }>* %1, i32 0, i32 1 +// CHECK-LLVM: %.jvp.fn = getelementptr inbounds %swift.function, %swift.function* %.jvp, i32 0, i32 0 +// CHECK-LLVM: %4 = load i8*, i8** %.jvp.fn, align 8 +// CHECK-LLVM: %.jvp.data = getelementptr inbounds %swift.function, %swift.function* %.jvp, i32 0, i32 1 +// CHECK-LLVM: %5 = load %swift.refcounted*, %swift.refcounted** %.jvp.data, align 8 +// CHECK-LLVM: %.vjp = getelementptr inbounds <{ %swift.function, %swift.function, %swift.function }>, <{ %swift.function, %swift.function, %swift.function }>* %1, i32 0, i32 2 +// CHECK-LLVM: %.vjp.fn = getelementptr inbounds %swift.function, %swift.function* %.vjp, i32 0, i32 0 +// CHECK-LLVM: %6 = load i8*, i8** %.vjp.fn, align 8 +// CHECK-LLVM: %.vjp.data = getelementptr inbounds %swift.function, %swift.function* %.vjp, i32 0, i32 1 +// CHECK-LLVM: %7 = load %swift.refcounted*, %swift.refcounted** %.vjp.data, align 8 +// CHECK-LLVM: %.original1 = getelementptr inbounds <{ %swift.function, %swift.function, %swift.function }>, <{ %swift.function, %swift.function, %swift.function }>* %0, i32 0, i32 0 +// CHECK-LLVM: %.original1.fn = getelementptr inbounds %swift.function, %swift.function* %.original1, i32 0, i32 0 +// CHECK-LLVM: store i8* %2, i8** %.original1.fn, align 8 +// CHECK-LLVM: %.original1.data = getelementptr inbounds %swift.function, %swift.function* %.original1, i32 0, i32 1 +// CHECK-LLVM: store %swift.refcounted* %3, %swift.refcounted** %.original1.data, align 8 +// CHECK-LLVM: %.jvp2 = getelementptr inbounds <{ %swift.function, %swift.function, %swift.function }>, <{ %swift.function, %swift.function, %swift.function }>* %0, i32 0, i32 1 +// CHECK-LLVM: %.jvp2.fn = getelementptr inbounds %swift.function, %swift.function* %.jvp2, i32 0, i32 0 +// CHECK-LLVM: store i8* %4, i8** %.jvp2.fn, align 8 +// CHECK-LLVM: %.jvp2.data = getelementptr inbounds %swift.function, %swift.function* %.jvp2, i32 0, i32 1 +// CHECK-LLVM: store %swift.refcounted* %5, %swift.refcounted** %.jvp2.data, align 8 +// CHECK-LLVM: %.vjp3 = getelementptr inbounds <{ %swift.function, %swift.function, %swift.function }>, <{ %swift.function, %swift.function, %swift.function }>* %0, i32 0, i32 2 +// CHECK-LLVM: %.vjp3.fn = getelementptr inbounds %swift.function, %swift.function* %.vjp3, i32 0, i32 0 +// CHECK-LLVM: store i8* %6, i8** %.vjp3.fn, align 8 +// CHECK-LLVM: %.vjp3.data = getelementptr inbounds %swift.function, %swift.function* %.vjp3, i32 0, i32 1 +// CHECK-LLVM: store %swift.refcounted* %7, %swift.refcounted** %.vjp3.data, align 8 +// CHECK-LLVM: ret void +// CHECK-LLVM: } From 575c18609f210e1e4718110f57a7f5bf14090a1e Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Mon, 14 Oct 2019 10:23:13 -0700 Subject: [PATCH 2/4] Add line break. Co-Authored-By: Dan Zheng --- lib/IRGen/GenDiffFunc.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/IRGen/GenDiffFunc.cpp b/lib/IRGen/GenDiffFunc.cpp index 3d942a85c6284..98e1b572114ac 100644 --- a/lib/IRGen/GenDiffFunc.cpp +++ b/lib/IRGen/GenDiffFunc.cpp @@ -37,7 +37,8 @@ using namespace irgen; // `@differentiable` (non-linear) function type info //----------------------------------------------------------------------------// namespace { -class DifferentiableFuncFieldInfo final : public RecordField { +class DifferentiableFuncFieldInfo final + : public RecordField { public: DifferentiableFuncFieldInfo( DifferentiableFunctionExtractee component, const TypeInfo &type, From f7a042b4eec4a232e26694eac7fd52f0010d0295 Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Mon, 14 Oct 2019 10:50:01 -0700 Subject: [PATCH 3/4] Update test --- test/AutoDiff/differentiable_func_type.sil | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/AutoDiff/differentiable_func_type.sil b/test/AutoDiff/differentiable_func_type.sil index ddab12c33c859..9ca751d07eff3 100644 --- a/test/AutoDiff/differentiable_func_type.sil +++ b/test/AutoDiff/differentiable_func_type.sil @@ -19,7 +19,7 @@ bb0(%0 : $@differentiable(linear) (Float) -> Float): // CHECK-SIL: return [[ARG]] : $@differentiable(linear) (Float) -> Float // CHECK-SIL: } -// CHECK-LLVM-LABEL: define swiftcc { i8*, %swift.refcounted*, i8*, %swift.refcounted* } @takeAndReturnLinear(i8*, %swift.refcounted*, i8*, %swift.refcounted*) #0 { +// CHECK-LLVM-LABEL: define {{.*}} swiftcc { i8*, %swift.refcounted*, i8*, %swift.refcounted* } @takeAndReturnLinear(i8*, %swift.refcounted*, i8*, %swift.refcounted*) #0 { // CHECK-LLVM: entry: // CHECK-LLVM: %4 = insertvalue { i8*, %swift.refcounted*, i8*, %swift.refcounted* } undef, i8* %0, 0 // CHECK-LLVM: %5 = insertvalue { i8*, %swift.refcounted*, i8*, %swift.refcounted* } %4, %swift.refcounted* %1, 1 @@ -39,7 +39,7 @@ bb0(%0 : $@differentiable (Float) -> Float): // CHECK-SIL: return [[ARG]] : $@differentiable (Float) -> Float // CHECK-SIL: } -// CHECK-LLVM-LABEL: define swiftcc void @takeAndReturnDifferentiable(<{ %swift.function, %swift.function, %swift.function }>* noalias nocapture sret, <{ %swift.function, %swift.function, %swift.function }>* noalias nocapture dereferenceable(48)) #0 { +// CHECK-LLVM-LABEL: define {{.*}} swiftcc void @takeAndReturnDifferentiable(<{ %swift.function, %swift.function, %swift.function }>* noalias nocapture sret, <{ %swift.function, %swift.function, %swift.function }>* noalias nocapture dereferenceable(48)) #0 { // CHECK-LLVM: entry: // CHECK-LLVM: %.original = getelementptr inbounds <{ %swift.function, %swift.function, %swift.function }>, <{ %swift.function, %swift.function, %swift.function }>* %1, i32 0, i32 0 // CHECK-LLVM: %.original.fn = getelementptr inbounds %swift.function, %swift.function* %.original, i32 0, i32 0 From 733c262fc09323e77a0f498072c68b0891499567 Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Mon, 14 Oct 2019 12:43:36 -0700 Subject: [PATCH 4/4] Fix test. --- test/AutoDiff/differentiable_func_type.sil | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/AutoDiff/differentiable_func_type.sil b/test/AutoDiff/differentiable_func_type.sil index 9ca751d07eff3..0a57e14d25fb1 100644 --- a/test/AutoDiff/differentiable_func_type.sil +++ b/test/AutoDiff/differentiable_func_type.sil @@ -19,7 +19,7 @@ bb0(%0 : $@differentiable(linear) (Float) -> Float): // CHECK-SIL: return [[ARG]] : $@differentiable(linear) (Float) -> Float // CHECK-SIL: } -// CHECK-LLVM-LABEL: define {{.*}} swiftcc { i8*, %swift.refcounted*, i8*, %swift.refcounted* } @takeAndReturnLinear(i8*, %swift.refcounted*, i8*, %swift.refcounted*) #0 { +// CHECK-LLVM-LABEL: define{{.*}} swiftcc { i8*, %swift.refcounted*, i8*, %swift.refcounted* } @takeAndReturnLinear(i8*, %swift.refcounted*, i8*, %swift.refcounted*) #0 { // CHECK-LLVM: entry: // CHECK-LLVM: %4 = insertvalue { i8*, %swift.refcounted*, i8*, %swift.refcounted* } undef, i8* %0, 0 // CHECK-LLVM: %5 = insertvalue { i8*, %swift.refcounted*, i8*, %swift.refcounted* } %4, %swift.refcounted* %1, 1 @@ -39,7 +39,7 @@ bb0(%0 : $@differentiable (Float) -> Float): // CHECK-SIL: return [[ARG]] : $@differentiable (Float) -> Float // CHECK-SIL: } -// CHECK-LLVM-LABEL: define {{.*}} swiftcc void @takeAndReturnDifferentiable(<{ %swift.function, %swift.function, %swift.function }>* noalias nocapture sret, <{ %swift.function, %swift.function, %swift.function }>* noalias nocapture dereferenceable(48)) #0 { +// CHECK-LLVM-LABEL: define{{.*}} swiftcc void @takeAndReturnDifferentiable(<{ %swift.function, %swift.function, %swift.function }>* noalias nocapture sret, <{ %swift.function, %swift.function, %swift.function }>* noalias nocapture dereferenceable(48)) #0 { // CHECK-LLVM: entry: // CHECK-LLVM: %.original = getelementptr inbounds <{ %swift.function, %swift.function, %swift.function }>, <{ %swift.function, %swift.function, %swift.function }>* %1, i32 0, i32 0 // CHECK-LLVM: %.original.fn = getelementptr inbounds %swift.function, %swift.function* %.original, i32 0, i32 0