diff --git a/lib/Sema/DerivedConformanceDifferentiable.cpp b/lib/Sema/DerivedConformanceDifferentiable.cpp index 608905298baff..04c2ac6981575 100644 --- a/lib/Sema/DerivedConformanceDifferentiable.cpp +++ b/lib/Sema/DerivedConformanceDifferentiable.cpp @@ -786,8 +786,9 @@ deriveDifferentiable_AssociatedStruct(DerivedConformance &derived, } // Add a typealias declaration with the given name and underlying target - // struct type to the source struct. - auto addAssociatedTypeAliasDecl = [&](Identifier name, StructDecl *source, + // struct type to the source nominal type. + auto addAssociatedTypeAliasDecl = [&](Identifier name, + NominalTypeDecl *source, StructDecl *target) { auto lookup = source->lookupDirect(name); assert(lookup.size() < 2 && @@ -845,6 +846,10 @@ deriveDifferentiable_AssociatedStruct(DerivedConformance &derived, allDiffableVarsStruct, allDiffableVarsStruct); addAssociatedTypeAliasDecl(C.Id_CotangentVector, allDiffableVarsStruct, allDiffableVarsStruct); + addAssociatedTypeAliasDecl(C.Id_TangentVector, + nominal, allDiffableVarsStruct); + addAssociatedTypeAliasDecl(C.Id_CotangentVector, + nominal, allDiffableVarsStruct); TC.validateDecl(allDiffableVarsStruct); return parentDC->mapTypeIntoContext( allDiffableVarsStruct->getDeclaredInterfaceType()); diff --git a/test/AutoDiff/derived_differentiable_properties.swift b/test/AutoDiff/derived_differentiable_properties.swift index d68b2d013463f..6e8c219112b78 100644 --- a/test/AutoDiff/derived_differentiable_properties.swift +++ b/test/AutoDiff/derived_differentiable_properties.swift @@ -13,8 +13,8 @@ public struct Foo : Differentiable { // CHECK-AST: @_fieldwiseProductSpace typealias AllDifferentiableVariables = Foo.AllDifferentiableVariables // CHECK-AST: @_fieldwiseProductSpace typealias TangentVector = Foo.AllDifferentiableVariables // CHECK-AST: @_fieldwiseProductSpace typealias CotangentVector = Foo.AllDifferentiableVariables -// CHECK-AST: typealias TangentVector = Foo.AllDifferentiableVariables -// CHECK-AST: typealias CotangentVector = Foo.AllDifferentiableVariables +// CHECK-AST: @_fieldwiseProductSpace typealias TangentVector = Foo.AllDifferentiableVariables +// CHECK-AST: @_fieldwiseProductSpace typealias CotangentVector = Foo.AllDifferentiableVariables // CHECK-SILGEN-LABEL: // Foo.a.getter // CHECK-SILGEN: sil [transparent] [serialized] [differentiable source 0 wrt 0] @$s33derived_differentiable_properties3FooV1aSfvg : $@convention(method) (Foo) -> Float