@@ -252,8 +252,6 @@ class AttributeChecker : public AttributeVisitor<AttributeChecker> {
252252  void  visitTransposeAttr (TransposeAttr *attr);
253253  //  TODO(TF-999): Remove deprecated `@differentiating` attribute.
254254  void  visitDifferentiatingAttr (DerivativeAttr *attr);
255-   //  TODO(TF-999): Remove deprecated `@transposing` attribute.
256-   void  visitTransposingAttr (TransposeAttr *attr);
257255  void  visitCompilerEvaluableAttr (CompilerEvaluableAttr *attr);
258256  void  visitNoDerivativeAttr (NoDerivativeAttr *attr);
259257  //  SWIFT_ENABLE_TENSORFLOW END
@@ -3094,15 +3092,13 @@ static IndexSubset *computeDifferentiationParameters(
30943092}
30953093
30963094//  SWIFT_ENABLE_TENSORFLOW
3097- //  Computes `IndexSubset` from the given parsed transposing parameters
3098- //  (possibly empty) for the given function, then verifies that the parameter
3099- //  indices are valid.
3100- //  The attribute name/location are used in diagnostics.
3101- static  IndexSubset *computeTransposingParameters (
3095+ //  Computes `IndexSubset` from the given parsed transposed parameters (possibly
3096+ //  empty) for the given function, then verifies that the parameter indices are
3097+ //  valid. The attribute name/location are used in diagnostics.
3098+ static  IndexSubset *computeTransposedParameters (
31023099    ArrayRef<ParsedAutoDiffParameter> parsedWrtParams,
31033100    AbstractFunctionDecl *transposeFunction, bool  isCurried,
3104-     GenericEnvironment *derivativeGenEnv, SourceLoc attrLoc
3105- ) {
3101+     GenericEnvironment *derivativeGenEnv, SourceLoc attrLoc) {
31063102  auto  &ctx = transposeFunction->getASTContext ();
31073103  auto  &diags = ctx.Diags ;
31083104
@@ -3248,11 +3244,10 @@ static bool checkDifferentiationParameters(
32483244//  context. Returns true on error.
32493245//  The parsed differentiation parameters and attribute location are used in
32503246//  diagnostics.
3251- static  bool  checkTransposingParameters (
3252-     AbstractFunctionDecl *AFD,
3253-     SmallVector<Type, 4 > wrtParamTypes, GenericEnvironment *derivativeGenEnv,
3254-     ModuleDecl *module , ArrayRef<ParsedAutoDiffParameter> parsedWrtParams,
3255-     SourceLoc attrLoc) {
3247+ static  bool  checkTransposedParameters (
3248+     AbstractFunctionDecl *AFD, SmallVector<Type, 4 > wrtParamTypes,
3249+     GenericEnvironment *derivativeGenEnv, ModuleDecl *module ,
3250+     ArrayRef<ParsedAutoDiffParameter> parsedWrtParams, SourceLoc attrLoc) {
32563251  auto  &ctx = AFD->getASTContext ();
32573252  auto  &diags = ctx.Diags ;
32583253
@@ -4001,17 +3996,17 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) {
40013996  //  If checked wrt param indices are not specified, compute them.
40023997  bool  isCurried = transposeInterfaceType->getResult ()->is <AnyFunctionType>();
40033998  if  (!wrtParamIndices)
4004-     wrtParamIndices = computeTransposingParameters (
3999+     wrtParamIndices = computeTransposedParameters (
40054000        parsedWrtParams, transpose, isCurried,
40064001        transpose->getGenericEnvironment (), attr->getLocation ());
40074002  if  (!wrtParamIndices) {
40084003    D->getAttrs ().removeAttribute (attr);
40094004    attr->setInvalid ();
40104005    return ;
40114006  }
4012-    
4007+ 
40134008  //  Diagnose empty parameter indices. This occurs when no `wrt` clause is
4014-   //  declared and no differentiation  parameters can be inferred.
4009+   //  declared and no transposed  parameters can be inferred.
40154010  if  (wrtParamIndices->isEmpty ()) {
40164011    diagnose (attr->getLocation (),
40174012             diag::diff_params_clause_no_inferred_parameters);
@@ -4104,8 +4099,8 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) {
41044099             originalName.Name );
41054100  };
41064101
4107-   //  Returns true if the derivative  function and original function candidate are
4108-   //  defined in compatible type contexts. If the derivative  function and the
4102+   //  Returns true if the transpose  function and original function candidate are
4103+   //  defined in compatible type contexts. If the transpose  function and the
41094104  //  original function candidate have different parents, return false.
41104105  std::function<bool (AbstractFunctionDecl *)> hasValidTypeContext =
41114106      [&](AbstractFunctionDecl *decl) { return  true ; };
@@ -4138,34 +4133,29 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) {
41384133
41394134  attr->setOriginalFunction (originalAFD);
41404135
4141-   //  Gather differentiation parameters.
4142-   //  Differentiation parameters are with respect to the original function.
4136+   //  Get the transposed parameter types.
41434137  SmallVector<Type, 4 > wrtParamTypes;
41444138  autodiff::getSubsetParameterTypes (wrtParamIndices, expectedOriginalFnType,
41454139                                    wrtParamTypes);
41464140
4147-   //  Check if differentiation  parameter indices are valid.
4148-   if  (checkTransposingParameters (originalAFD, wrtParamTypes,
4149-                                   transpose->getGenericEnvironment (),
4150-                                   transpose->getModuleContext (), parsedWrtParams,
4151-                                   attr->getLocation ())) {
4141+   //  Check if transposed  parameter indices are valid.
4142+   if  (checkTransposedParameters (originalAFD, wrtParamTypes,
4143+                                 transpose->getGenericEnvironment (),
4144+                                 transpose->getModuleContext (), parsedWrtParams,
4145+                                 attr->getLocation ())) {
41524146    D->getAttrs ().removeAttribute (attr);
41534147    attr->setInvalid ();
41544148    return ;
41554149  }
41564150
4157-   //  Set the checked differentiation  parameter indices in the attribute.
4151+   //  Set the checked transposed  parameter indices in the attribute.
41584152  attr->setParameterIndices (wrtParamIndices);
41594153}
41604154
41614155void  AttributeChecker::visitDifferentiatingAttr (DerivativeAttr *attr) {
41624156  visitDerivativeAttr (attr);
41634157}
41644158
4165- void  AttributeChecker::visitTransposingAttr (TransposeAttr *attr) {
4166-   visitTransposeAttr (attr);
4167- }
4168- 
41694159static  bool 
41704160compilerEvaluableAllowedInExtensionDecl (ExtensionDecl *extensionDecl) {
41714161  auto  extendedTypeKind = extensionDecl->getExtendedType ()->getKind ();
0 commit comments