Skip to content

[OpenMP][clang] 6.0: parsing/sema for num_threads 'strict' modifier #145490

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 24, 2025

Conversation

ro-i
Copy link
Contributor

@ro-i ro-i commented Jun 24, 2025

Implement parsing and semantic analysis support for the optional 'strict' modifier of the num_threads clause. This modifier has been introduced in OpenMP 6.0, section 12.1.2.

Note: this is basically 1:1 https://reviews.llvm.org/D138328.

Implement parsing and semantic analysis support for the optional
'strict' modifier of the num_threads clause. This modifier has been
introduced in OpenMP 6.0, section 12.1.2.
@ro-i ro-i requested review from alexey-bataev and kparzysz June 24, 2025 10:11
@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:modules C++20 modules and Clang Header Modules flang:openmp clang:openmp OpenMP related changes to Clang labels Jun 24, 2025
@llvmbot
Copy link
Member

llvmbot commented Jun 24, 2025

@llvm/pr-subscribers-flang-openmp

@llvm/pr-subscribers-clang-modules

Author: Robert Imschweiler (ro-i)

Changes

Implement parsing and semantic analysis support for the optional 'strict' modifier of the num_threads clause. This modifier has been introduced in OpenMP 6.0, section 12.1.2.

Note: this is basically 1:1 https://reviews.llvm.org/D138328.


Patch is 33.07 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145490.diff

14 Files Affected:

  • (modified) clang/include/clang/AST/OpenMPClause.h (+25-4)
  • (modified) clang/include/clang/Basic/OpenMPKinds.def (+7)
  • (modified) clang/include/clang/Basic/OpenMPKinds.h (+6)
  • (modified) clang/include/clang/Sema/SemaOpenMP.h (+4-4)
  • (modified) clang/lib/AST/OpenMPClause.cpp (+5)
  • (modified) clang/lib/Basic/OpenMPKinds.cpp (+19-2)
  • (modified) clang/lib/Parse/ParseOpenMP.cpp (+31-2)
  • (modified) clang/lib/Sema/SemaOpenMP.cpp (+46-30)
  • (modified) clang/lib/Sema/TreeTransform.h (+7-4)
  • (modified) clang/lib/Serialization/ASTReader.cpp (+2)
  • (modified) clang/lib/Serialization/ASTWriter.cpp (+2)
  • (modified) clang/test/OpenMP/parallel_ast_print.cpp (+41-1)
  • (modified) clang/test/OpenMP/parallel_num_threads_messages.cpp (+70-3)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMP.td (+7)
diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h
index 05239668b34b1..c6f99fb21a0f0 100644
--- a/clang/include/clang/AST/OpenMPClause.h
+++ b/clang/include/clang/AST/OpenMPClause.h
@@ -824,31 +824,52 @@ class OMPNumThreadsClause final
       public OMPClauseWithPreInit {
   friend class OMPClauseReader;
 
+  /// Modifiers for 'num_threads' clause.
+  OpenMPNumThreadsClauseModifier Modifier = OMPC_NUMTHREADS_unknown;
+
+  /// Location of the modifier.
+  SourceLocation ModifierLoc;
+
+  /// Sets modifier.
+  void setModifier(OpenMPNumThreadsClauseModifier M) { Modifier = M; }
+
+  /// Sets modifier location.
+  void setModifierLoc(SourceLocation Loc) { ModifierLoc = Loc; }
+
   /// Set condition.
   void setNumThreads(Expr *NThreads) { setStmt(NThreads); }
 
 public:
   /// Build 'num_threads' clause with condition \a NumThreads.
   ///
+  /// \param Modifier Clause modifier.
   /// \param NumThreads Number of threads for the construct.
   /// \param HelperNumThreads Helper Number of threads for the construct.
   /// \param CaptureRegion Innermost OpenMP region where expressions in this
   /// clause must be captured.
   /// \param StartLoc Starting location of the clause.
   /// \param LParenLoc Location of '('.
+  /// \param ModifierLoc Modifier location.
   /// \param EndLoc Ending location of the clause.
-  OMPNumThreadsClause(Expr *NumThreads, Stmt *HelperNumThreads,
-                      OpenMPDirectiveKind CaptureRegion,
+  OMPNumThreadsClause(OpenMPNumThreadsClauseModifier Modifier, Expr *NumThreads,
+                      Stmt *HelperNumThreads, OpenMPDirectiveKind CaptureRegion,
                       SourceLocation StartLoc, SourceLocation LParenLoc,
-                      SourceLocation EndLoc)
+                      SourceLocation ModifierLoc, SourceLocation EndLoc)
       : OMPOneStmtClause(NumThreads, StartLoc, LParenLoc, EndLoc),
-        OMPClauseWithPreInit(this) {
+        OMPClauseWithPreInit(this), Modifier(Modifier),
+        ModifierLoc(ModifierLoc) {
     setPreInitStmt(HelperNumThreads, CaptureRegion);
   }
 
   /// Build an empty clause.
   OMPNumThreadsClause() : OMPOneStmtClause(), OMPClauseWithPreInit(this) {}
 
+  /// Gets modifier.
+  OpenMPNumThreadsClauseModifier getModifier() const { return Modifier; }
+
+  /// Gets modifier location.
+  SourceLocation getModifierLoc() const { return ModifierLoc; }
+
   /// Returns number of threads.
   Expr *getNumThreads() const { return getStmtAs<Expr>(); }
 };
diff --git a/clang/include/clang/Basic/OpenMPKinds.def b/clang/include/clang/Basic/OpenMPKinds.def
index 2b1dc1e0121b2..9d6f816eea91f 100644
--- a/clang/include/clang/Basic/OpenMPKinds.def
+++ b/clang/include/clang/Basic/OpenMPKinds.def
@@ -86,6 +86,9 @@
 #ifndef OPENMP_NUMTASKS_MODIFIER
 #define OPENMP_NUMTASKS_MODIFIER(Name)
 #endif
+#ifndef OPENMP_NUMTHREADS_MODIFIER
+#define OPENMP_NUMTHREADS_MODIFIER(Name)
+#endif
 #ifndef OPENMP_DOACROSS_MODIFIER
 #define OPENMP_DOACROSS_MODIFIER(Name)
 #endif
@@ -227,6 +230,9 @@ OPENMP_GRAINSIZE_MODIFIER(strict)
 // Modifiers for the 'num_tasks' clause.
 OPENMP_NUMTASKS_MODIFIER(strict)
 
+// Modifiers for the 'num_tasks' clause.
+OPENMP_NUMTHREADS_MODIFIER(strict)
+
 // Modifiers for 'allocate' clause.
 OPENMP_ALLOCATE_MODIFIER(allocator)
 OPENMP_ALLOCATE_MODIFIER(align)
@@ -238,6 +244,7 @@ OPENMP_DOACROSS_MODIFIER(sink_omp_cur_iteration)
 OPENMP_DOACROSS_MODIFIER(source_omp_cur_iteration)
 
 #undef OPENMP_NUMTASKS_MODIFIER
+#undef OPENMP_NUMTHREADS_MODIFIER
 #undef OPENMP_GRAINSIZE_MODIFIER
 #undef OPENMP_BIND_KIND
 #undef OPENMP_ADJUST_ARGS_KIND
diff --git a/clang/include/clang/Basic/OpenMPKinds.h b/clang/include/clang/Basic/OpenMPKinds.h
index 9afcce21a499d..f40db4c13c55a 100644
--- a/clang/include/clang/Basic/OpenMPKinds.h
+++ b/clang/include/clang/Basic/OpenMPKinds.h
@@ -223,6 +223,12 @@ enum OpenMPNumTasksClauseModifier {
   OMPC_NUMTASKS_unknown
 };
 
+enum OpenMPNumThreadsClauseModifier {
+#define OPENMP_NUMTHREADS_MODIFIER(Name) OMPC_NUMTHREADS_##Name,
+#include "clang/Basic/OpenMPKinds.def"
+  OMPC_NUMTHREADS_unknown
+};
+
 /// OpenMP dependence types for 'doacross' clause.
 enum OpenMPDoacrossClauseModifier {
 #define OPENMP_DOACROSS_MODIFIER(Name) OMPC_DOACROSS_##Name,
diff --git a/clang/include/clang/Sema/SemaOpenMP.h b/clang/include/clang/Sema/SemaOpenMP.h
index 7b169f56b6807..91c3d4bd5210e 100644
--- a/clang/include/clang/Sema/SemaOpenMP.h
+++ b/clang/include/clang/Sema/SemaOpenMP.h
@@ -877,10 +877,10 @@ class SemaOpenMP : public SemaBase {
                                     SourceLocation LParenLoc,
                                     SourceLocation EndLoc);
   /// Called on well-formed 'num_threads' clause.
-  OMPClause *ActOnOpenMPNumThreadsClause(Expr *NumThreads,
-                                         SourceLocation StartLoc,
-                                         SourceLocation LParenLoc,
-                                         SourceLocation EndLoc);
+  OMPClause *ActOnOpenMPNumThreadsClause(
+      OpenMPNumThreadsClauseModifier Modifier, Expr *NumThreads,
+      SourceLocation StartLoc, SourceLocation LParenLoc,
+      SourceLocation ModifierLoc, SourceLocation EndLoc);
   /// Called on well-formed 'align' clause.
   OMPClause *ActOnOpenMPAlignClause(Expr *Alignment, SourceLocation StartLoc,
                                     SourceLocation LParenLoc,
diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp
index 8b1caa05eec32..de8b5996818de 100644
--- a/clang/lib/AST/OpenMPClause.cpp
+++ b/clang/lib/AST/OpenMPClause.cpp
@@ -1830,6 +1830,11 @@ void OMPClausePrinter::VisitOMPFinalClause(OMPFinalClause *Node) {
 
 void OMPClausePrinter::VisitOMPNumThreadsClause(OMPNumThreadsClause *Node) {
   OS << "num_threads(";
+  OpenMPNumThreadsClauseModifier Modifier = Node->getModifier();
+  if (Modifier != OMPC_NUMTHREADS_unknown) {
+    OS << getOpenMPSimpleClauseTypeName(Node->getClauseKind(), Modifier)
+       << ": ";
+  }
   Node->getNumThreads()->printPretty(OS, nullptr, Policy, 0);
   OS << ")";
 }
diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp
index a451fc7c01841..d3d393bd09396 100644
--- a/clang/lib/Basic/OpenMPKinds.cpp
+++ b/clang/lib/Basic/OpenMPKinds.cpp
@@ -185,11 +185,19 @@ unsigned clang::getOpenMPSimpleClauseType(OpenMPClauseKind Kind, StringRef Str,
 #define OPENMP_ALLOCATE_MODIFIER(Name) .Case(#Name, OMPC_ALLOCATE_##Name)
 #include "clang/Basic/OpenMPKinds.def"
         .Default(OMPC_ALLOCATE_unknown);
+  case OMPC_num_threads: {
+    unsigned Type = llvm::StringSwitch<unsigned>(Str)
+#define OPENMP_NUMTHREADS_MODIFIER(Name) .Case(#Name, OMPC_NUMTHREADS_##Name)
+#include "clang/Basic/OpenMPKinds.def"
+                        .Default(OMPC_NUMTHREADS_unknown);
+    if (LangOpts.OpenMP < 60)
+      return OMPC_NUMTHREADS_unknown;
+    return Type;
+  }
   case OMPC_unknown:
   case OMPC_threadprivate:
   case OMPC_if:
   case OMPC_final:
-  case OMPC_num_threads:
   case OMPC_safelen:
   case OMPC_simdlen:
   case OMPC_sizes:
@@ -520,11 +528,20 @@ const char *clang::getOpenMPSimpleClauseTypeName(OpenMPClauseKind Kind,
 #include "clang/Basic/OpenMPKinds.def"
     }
     llvm_unreachable("Invalid OpenMP 'allocate' clause modifier");
+  case OMPC_num_threads:
+    switch (Type) {
+    case OMPC_NUMTHREADS_unknown:
+      return "unknown";
+#define OPENMP_NUMTHREADS_MODIFIER(Name)                                       \
+  case OMPC_NUMTHREADS_##Name:                                                 \
+    return #Name;
+#include "clang/Basic/OpenMPKinds.def"
+    }
+    llvm_unreachable("Invalid OpenMP 'num_threads' clause modifier");
   case OMPC_unknown:
   case OMPC_threadprivate:
   case OMPC_if:
   case OMPC_final:
-  case OMPC_num_threads:
   case OMPC_safelen:
   case OMPC_simdlen:
   case OMPC_sizes:
diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
index 78d3503d8eb68..f694ae1d0d112 100644
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -3196,7 +3196,8 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
     if ((CKind == OMPC_ordered || CKind == OMPC_partial) &&
         PP.LookAhead(/*N=*/0).isNot(tok::l_paren))
       Clause = ParseOpenMPClause(CKind, WrongDirective);
-    else if (CKind == OMPC_grainsize || CKind == OMPC_num_tasks)
+    else if (CKind == OMPC_grainsize || CKind == OMPC_num_tasks ||
+             CKind == OMPC_num_threads)
       Clause = ParseOpenMPSingleExprWithArgClause(DKind, CKind, WrongDirective);
     else
       Clause = ParseOpenMPSingleExprClause(CKind, WrongDirective);
@@ -3981,6 +3982,33 @@ OMPClause *Parser::ParseOpenMPSingleExprWithArgClause(OpenMPDirectiveKind DKind,
       Arg.push_back(OMPC_NUMTASKS_unknown);
       KLoc.emplace_back();
     }
+  } else if (Kind == OMPC_num_threads) {
+    // Parse optional <num_threads modifier> ':'
+    OpenMPNumThreadsClauseModifier Modifier =
+        static_cast<OpenMPNumThreadsClauseModifier>(getOpenMPSimpleClauseType(
+            Kind, Tok.isAnnotation() ? "" : PP.getSpelling(Tok),
+            getLangOpts()));
+    if (getLangOpts().OpenMP >= 60) {
+      if (NextToken().is(tok::colon)) {
+        Arg.push_back(Modifier);
+        KLoc.push_back(Tok.getLocation());
+        // Parse modifier
+        ConsumeAnyToken();
+        // Parse ':'
+        ConsumeAnyToken();
+      } else {
+        if (Modifier == OMPC_NUMTHREADS_strict) {
+          Diag(Tok, diag::err_modifier_expected_colon) << "strict";
+          // Parse modifier
+          ConsumeAnyToken();
+        }
+        Arg.push_back(OMPC_NUMTHREADS_unknown);
+        KLoc.emplace_back();
+      }
+    } else {
+      Arg.push_back(OMPC_NUMTHREADS_unknown);
+      KLoc.emplace_back();
+    }
   } else {
     assert(Kind == OMPC_if);
     KLoc.push_back(Tok.getLocation());
@@ -4004,7 +4032,8 @@ OMPClause *Parser::ParseOpenMPSingleExprWithArgClause(OpenMPDirectiveKind DKind,
   bool NeedAnExpression = (Kind == OMPC_schedule && DelimLoc.isValid()) ||
                           (Kind == OMPC_dist_schedule && DelimLoc.isValid()) ||
                           Kind == OMPC_if || Kind == OMPC_device ||
-                          Kind == OMPC_grainsize || Kind == OMPC_num_tasks;
+                          Kind == OMPC_grainsize || Kind == OMPC_num_tasks ||
+                          Kind == OMPC_num_threads;
   if (NeedAnExpression) {
     SourceLocation ELoc = Tok.getLocation();
     ExprResult LHS(ParseCastExpression(CastParseKind::AnyCastExpr, false,
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 00f4658180807..a30acbe9a4bca 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -15509,9 +15509,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
   case OMPC_final:
     Res = ActOnOpenMPFinalClause(Expr, StartLoc, LParenLoc, EndLoc);
     break;
-  case OMPC_num_threads:
-    Res = ActOnOpenMPNumThreadsClause(Expr, StartLoc, LParenLoc, EndLoc);
-    break;
   case OMPC_safelen:
     Res = ActOnOpenMPSafelenClause(Expr, StartLoc, LParenLoc, EndLoc);
     break;
@@ -15565,6 +15562,7 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
     break;
   case OMPC_grainsize:
   case OMPC_num_tasks:
+  case OMPC_num_threads:
   case OMPC_device:
   case OMPC_if:
   case OMPC_default:
@@ -15911,10 +15909,41 @@ isNonNegativeIntegerValue(Expr *&ValExpr, Sema &SemaRef, OpenMPClauseKind CKind,
   return true;
 }
 
-OMPClause *SemaOpenMP::ActOnOpenMPNumThreadsClause(Expr *NumThreads,
-                                                   SourceLocation StartLoc,
-                                                   SourceLocation LParenLoc,
-                                                   SourceLocation EndLoc) {
+static std::string getListOfPossibleValues(OpenMPClauseKind K, unsigned First,
+                                           unsigned Last,
+                                           ArrayRef<unsigned> Exclude = {}) {
+  SmallString<256> Buffer;
+  llvm::raw_svector_ostream Out(Buffer);
+  unsigned Skipped = Exclude.size();
+  for (unsigned I = First; I < Last; ++I) {
+    if (llvm::is_contained(Exclude, I)) {
+      --Skipped;
+      continue;
+    }
+    Out << "'" << getOpenMPSimpleClauseTypeName(K, I) << "'";
+    if (I + Skipped + 2 == Last)
+      Out << " or ";
+    else if (I + Skipped + 1 != Last)
+      Out << ", ";
+  }
+  return std::string(Out.str());
+}
+
+OMPClause *SemaOpenMP::ActOnOpenMPNumThreadsClause(
+    OpenMPNumThreadsClauseModifier Modifier, Expr *NumThreads,
+    SourceLocation StartLoc, SourceLocation LParenLoc,
+    SourceLocation ModifierLoc, SourceLocation EndLoc) {
+  assert((ModifierLoc.isInvalid() || getLangOpts().OpenMP >= 60) &&
+         "Unexpected num_threads modifier in OpenMP < 60.");
+
+  if (ModifierLoc.isValid() && Modifier == OMPC_NUMTHREADS_unknown) {
+    std::string Values = getListOfPossibleValues(OMPC_num_threads, /*First=*/0,
+                                                 OMPC_NUMTHREADS_unknown);
+    Diag(ModifierLoc, diag::err_omp_unexpected_clause_value)
+        << Values << getOpenMPClauseNameForDiag(OMPC_num_threads);
+    return nullptr;
+  }
+
   Expr *ValExpr = NumThreads;
   Stmt *HelperValStmt = nullptr;
 
@@ -15935,8 +15964,9 @@ OMPClause *SemaOpenMP::ActOnOpenMPNumThreadsClause(Expr *NumThreads,
     HelperValStmt = buildPreInits(getASTContext(), Captures);
   }
 
-  return new (getASTContext()) OMPNumThreadsClause(
-      ValExpr, HelperValStmt, CaptureRegion, StartLoc, LParenLoc, EndLoc);
+  return new (getASTContext())
+      OMPNumThreadsClause(Modifier, ValExpr, HelperValStmt, CaptureRegion,
+                          StartLoc, LParenLoc, ModifierLoc, EndLoc);
 }
 
 ExprResult SemaOpenMP::VerifyPositiveIntegerConstantInClause(
@@ -16301,26 +16331,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPSimpleClause(
   return Res;
 }
 
-static std::string getListOfPossibleValues(OpenMPClauseKind K, unsigned First,
-                                           unsigned Last,
-                                           ArrayRef<unsigned> Exclude = {}) {
-  SmallString<256> Buffer;
-  llvm::raw_svector_ostream Out(Buffer);
-  unsigned Skipped = Exclude.size();
-  for (unsigned I = First; I < Last; ++I) {
-    if (llvm::is_contained(Exclude, I)) {
-      --Skipped;
-      continue;
-    }
-    Out << "'" << getOpenMPSimpleClauseTypeName(K, I) << "'";
-    if (I + Skipped + 2 == Last)
-      Out << " or ";
-    else if (I + Skipped + 1 != Last)
-      Out << ", ";
-  }
-  return std::string(Out.str());
-}
-
 OMPClause *SemaOpenMP::ActOnOpenMPDefaultClause(DefaultKind Kind,
                                                 SourceLocation KindKwLoc,
                                                 SourceLocation StartLoc,
@@ -16693,8 +16703,14 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprWithArgClause(
         static_cast<OpenMPNumTasksClauseModifier>(Argument.back()), Expr,
         StartLoc, LParenLoc, ArgumentLoc.back(), EndLoc);
     break;
-  case OMPC_final:
   case OMPC_num_threads:
+    assert(Argument.size() == 1 && ArgumentLoc.size() == 1 &&
+           "Modifier for num_threads clause and its location are expected.");
+    Res = ActOnOpenMPNumThreadsClause(
+        static_cast<OpenMPNumThreadsClauseModifier>(Argument.back()), Expr,
+        StartLoc, LParenLoc, ArgumentLoc.back(), EndLoc);
+    break;
+  case OMPC_final:
   case OMPC_safelen:
   case OMPC_simdlen:
   case OMPC_sizes:
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index 26bee7a96de22..0d58587cb8a99 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -1714,12 +1714,14 @@ class TreeTransform {
   ///
   /// By default, performs semantic analysis to build the new OpenMP clause.
   /// Subclasses may override this routine to provide different behavior.
-  OMPClause *RebuildOMPNumThreadsClause(Expr *NumThreads,
+  OMPClause *RebuildOMPNumThreadsClause(OpenMPNumThreadsClauseModifier Modifier,
+                                        Expr *NumThreads,
                                         SourceLocation StartLoc,
                                         SourceLocation LParenLoc,
+                                        SourceLocation ModifierLoc,
                                         SourceLocation EndLoc) {
-    return getSema().OpenMP().ActOnOpenMPNumThreadsClause(NumThreads, StartLoc,
-                                                          LParenLoc, EndLoc);
+    return getSema().OpenMP().ActOnOpenMPNumThreadsClause(
+        Modifier, NumThreads, StartLoc, LParenLoc, ModifierLoc, EndLoc);
   }
 
   /// Build a new OpenMP 'safelen' clause.
@@ -10461,7 +10463,8 @@ TreeTransform<Derived>::TransformOMPNumThreadsClause(OMPNumThreadsClause *C) {
   if (NumThreads.isInvalid())
     return nullptr;
   return getDerived().RebuildOMPNumThreadsClause(
-      NumThreads.get(), C->getBeginLoc(), C->getLParenLoc(), C->getEndLoc());
+      C->getModifier(), NumThreads.get(), C->getBeginLoc(), C->getLParenLoc(),
+      C->getModifierLoc(), C->getEndLoc());
 }
 
 template <typename Derived>
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index 6f082fe840b4c..b696cb2efee3d 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -11461,7 +11461,9 @@ void OMPClauseReader::VisitOMPFinalClause(OMPFinalClause *C) {
 
 void OMPClauseReader::VisitOMPNumThreadsClause(OMPNumThreadsClause *C) {
   VisitOMPClauseWithPreInit(C);
+  C->setModifier(Record.readEnum<OpenMPNumThreadsClauseModifier>());
   C->setNumThreads(Record.readSubExpr());
+  C->setModifierLoc(Record.readSourceLocation());
   C->setLParenLoc(Record.readSourceLocation());
 }
 
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index c6487c5366a29..4cca214f8e308 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -7802,7 +7802,9 @@ void OMPClauseWriter::VisitOMPFinalClause(OMPFinalClause *C) {
 
 void OMPClauseWriter::VisitOMPNumThreadsClause(OMPNumThreadsClause *C) {
   VisitOMPClauseWithPreInit(C);
+  Record.writeEnum(C->getModifier());
   Record.AddStmt(C->getNumThreads());
+  Record.AddSourceLocation(C->getModifierLoc());
   Record.AddSourceLocation(C->getLParenLoc());
 }
 
diff --git a/clang/test/OpenMP/parallel_ast_print.cpp b/clang/test/OpenMP/parallel_ast_print.cpp
index 83afedcb740da..948baaff30d89 100644
--- a/clang/test/OpenMP/parallel_ast_print.cpp
+++ b/clang/test/OpenMP/parallel_ast_print.cpp
@@ -13,6 +13,14 @@
 // RUN: %clang_cc1 -DOMP51 -verify -Wno-vla -fopenmp-simd -ast-print %s | FileCheck -check-prefixes=CHECK,OMP51 %s
 // RUN: %clang_cc1 -DOMP51 -fopenmp-simd -x c++ -std=c++11 -emit-pch -o %t %s
 // RUN: %clang_cc1 -DOMP51 -fopenmp-simd -std=c++11 -include-pch %t -verify -Wno-vla %s -ast-print | FileCheck -check-prefixes=CHECK,OMP51 %s
+
+// RUN: %clang_cc1 -DOMP60 -verify -Wno-vla -fopenmp -fopenmp-version=60 -ast-print %s | FileCheck -check-prefixes=CHECK,OMP60 %s
+// RUN: %clang_cc1 -DOMP60 -fopenmp -fopenmp-version=60 -x c++ -std=c++11 -emit-pch -o %t %s
+// RUN: %clang_cc1 -DOMP60 -fopenmp -fopenmp-version=60 -std=c++11 -include-pch %t -verify -Wno-vla %s -ast-print | FileCheck -check-prefixes=CHECK,OMP60 %s
+
+// RUN: %clang_cc1 -DOMP60 -verify -Wno-vla -fopenmp-simd -fopenmp-version=60 -ast-print %s | FileCheck -check-prefixes=CHECK,OMP60 %s
+// RUN: %clang_cc1 -DOMP60 -fopenmp-simd -fopenmp-version=60 -x c++ -std=c++11 -emit-pch -o %t %s
+// RUN: %clang_cc1 -DOMP60 -fopenmp-simd -fopenmp-version=60 -std=c++11 -include-pch %t -verify -Wno-vla %s -ast-print | FileCheck -check-prefixes=CHECK,OMP60 %s
 // expected-no-diagnostics
 
 #ifndef HEADER
@@ -164,8 +172,16 @@ T tmain(T argc, T *argv) {
 #pragma omp parallel default(none), private(argc,b) firstprivate(argv) shared (d) if (parallel:argc > 0) num_threads(C) copyin(S<T>::TS, thrp) proc_bind(primary) reduction(+:c, arr1[argc]) reduction(max:e, arr[:C][0:10])
   foo();
 #endif
+#ifdef OMP60
+#pragma omp parallel default(none), private(argc,b) firstprivate(argv) shared (d) if (parallel:argc > 0) num_threads(strict: C) copyin(S<T>::TS, thrp) proc_bind(primary) reduction(+:c, arr1[argc]) reduction(max:e, arr[:C][0:10])
+  foo();
+#endif
 #pragma omp parallel if (C) num_threads(s) proc_bind(close) re...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jun 24, 2025

@llvm/pr-subscribers-clang

Author: Robert Imschweiler (ro-i)

Changes

Implement parsing and semantic analysis support for the optional 'strict' modifier of the num_threads clause. This modifier has been introduced in OpenMP 6.0, section 12.1.2.

Note: this is basically 1:1 https://reviews.llvm.org/D138328.


Patch is 33.07 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145490.diff

14 Files Affected:

  • (modified) clang/include/clang/AST/OpenMPClause.h (+25-4)
  • (modified) clang/include/clang/Basic/OpenMPKinds.def (+7)
  • (modified) clang/include/clang/Basic/OpenMPKinds.h (+6)
  • (modified) clang/include/clang/Sema/SemaOpenMP.h (+4-4)
  • (modified) clang/lib/AST/OpenMPClause.cpp (+5)
  • (modified) clang/lib/Basic/OpenMPKinds.cpp (+19-2)
  • (modified) clang/lib/Parse/ParseOpenMP.cpp (+31-2)
  • (modified) clang/lib/Sema/SemaOpenMP.cpp (+46-30)
  • (modified) clang/lib/Sema/TreeTransform.h (+7-4)
  • (modified) clang/lib/Serialization/ASTReader.cpp (+2)
  • (modified) clang/lib/Serialization/ASTWriter.cpp (+2)
  • (modified) clang/test/OpenMP/parallel_ast_print.cpp (+41-1)
  • (modified) clang/test/OpenMP/parallel_num_threads_messages.cpp (+70-3)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMP.td (+7)
diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h
index 05239668b34b1..c6f99fb21a0f0 100644
--- a/clang/include/clang/AST/OpenMPClause.h
+++ b/clang/include/clang/AST/OpenMPClause.h
@@ -824,31 +824,52 @@ class OMPNumThreadsClause final
       public OMPClauseWithPreInit {
   friend class OMPClauseReader;
 
+  /// Modifiers for 'num_threads' clause.
+  OpenMPNumThreadsClauseModifier Modifier = OMPC_NUMTHREADS_unknown;
+
+  /// Location of the modifier.
+  SourceLocation ModifierLoc;
+
+  /// Sets modifier.
+  void setModifier(OpenMPNumThreadsClauseModifier M) { Modifier = M; }
+
+  /// Sets modifier location.
+  void setModifierLoc(SourceLocation Loc) { ModifierLoc = Loc; }
+
   /// Set condition.
   void setNumThreads(Expr *NThreads) { setStmt(NThreads); }
 
 public:
   /// Build 'num_threads' clause with condition \a NumThreads.
   ///
+  /// \param Modifier Clause modifier.
   /// \param NumThreads Number of threads for the construct.
   /// \param HelperNumThreads Helper Number of threads for the construct.
   /// \param CaptureRegion Innermost OpenMP region where expressions in this
   /// clause must be captured.
   /// \param StartLoc Starting location of the clause.
   /// \param LParenLoc Location of '('.
+  /// \param ModifierLoc Modifier location.
   /// \param EndLoc Ending location of the clause.
-  OMPNumThreadsClause(Expr *NumThreads, Stmt *HelperNumThreads,
-                      OpenMPDirectiveKind CaptureRegion,
+  OMPNumThreadsClause(OpenMPNumThreadsClauseModifier Modifier, Expr *NumThreads,
+                      Stmt *HelperNumThreads, OpenMPDirectiveKind CaptureRegion,
                       SourceLocation StartLoc, SourceLocation LParenLoc,
-                      SourceLocation EndLoc)
+                      SourceLocation ModifierLoc, SourceLocation EndLoc)
       : OMPOneStmtClause(NumThreads, StartLoc, LParenLoc, EndLoc),
-        OMPClauseWithPreInit(this) {
+        OMPClauseWithPreInit(this), Modifier(Modifier),
+        ModifierLoc(ModifierLoc) {
     setPreInitStmt(HelperNumThreads, CaptureRegion);
   }
 
   /// Build an empty clause.
   OMPNumThreadsClause() : OMPOneStmtClause(), OMPClauseWithPreInit(this) {}
 
+  /// Gets modifier.
+  OpenMPNumThreadsClauseModifier getModifier() const { return Modifier; }
+
+  /// Gets modifier location.
+  SourceLocation getModifierLoc() const { return ModifierLoc; }
+
   /// Returns number of threads.
   Expr *getNumThreads() const { return getStmtAs<Expr>(); }
 };
diff --git a/clang/include/clang/Basic/OpenMPKinds.def b/clang/include/clang/Basic/OpenMPKinds.def
index 2b1dc1e0121b2..9d6f816eea91f 100644
--- a/clang/include/clang/Basic/OpenMPKinds.def
+++ b/clang/include/clang/Basic/OpenMPKinds.def
@@ -86,6 +86,9 @@
 #ifndef OPENMP_NUMTASKS_MODIFIER
 #define OPENMP_NUMTASKS_MODIFIER(Name)
 #endif
+#ifndef OPENMP_NUMTHREADS_MODIFIER
+#define OPENMP_NUMTHREADS_MODIFIER(Name)
+#endif
 #ifndef OPENMP_DOACROSS_MODIFIER
 #define OPENMP_DOACROSS_MODIFIER(Name)
 #endif
@@ -227,6 +230,9 @@ OPENMP_GRAINSIZE_MODIFIER(strict)
 // Modifiers for the 'num_tasks' clause.
 OPENMP_NUMTASKS_MODIFIER(strict)
 
+// Modifiers for the 'num_tasks' clause.
+OPENMP_NUMTHREADS_MODIFIER(strict)
+
 // Modifiers for 'allocate' clause.
 OPENMP_ALLOCATE_MODIFIER(allocator)
 OPENMP_ALLOCATE_MODIFIER(align)
@@ -238,6 +244,7 @@ OPENMP_DOACROSS_MODIFIER(sink_omp_cur_iteration)
 OPENMP_DOACROSS_MODIFIER(source_omp_cur_iteration)
 
 #undef OPENMP_NUMTASKS_MODIFIER
+#undef OPENMP_NUMTHREADS_MODIFIER
 #undef OPENMP_GRAINSIZE_MODIFIER
 #undef OPENMP_BIND_KIND
 #undef OPENMP_ADJUST_ARGS_KIND
diff --git a/clang/include/clang/Basic/OpenMPKinds.h b/clang/include/clang/Basic/OpenMPKinds.h
index 9afcce21a499d..f40db4c13c55a 100644
--- a/clang/include/clang/Basic/OpenMPKinds.h
+++ b/clang/include/clang/Basic/OpenMPKinds.h
@@ -223,6 +223,12 @@ enum OpenMPNumTasksClauseModifier {
   OMPC_NUMTASKS_unknown
 };
 
+enum OpenMPNumThreadsClauseModifier {
+#define OPENMP_NUMTHREADS_MODIFIER(Name) OMPC_NUMTHREADS_##Name,
+#include "clang/Basic/OpenMPKinds.def"
+  OMPC_NUMTHREADS_unknown
+};
+
 /// OpenMP dependence types for 'doacross' clause.
 enum OpenMPDoacrossClauseModifier {
 #define OPENMP_DOACROSS_MODIFIER(Name) OMPC_DOACROSS_##Name,
diff --git a/clang/include/clang/Sema/SemaOpenMP.h b/clang/include/clang/Sema/SemaOpenMP.h
index 7b169f56b6807..91c3d4bd5210e 100644
--- a/clang/include/clang/Sema/SemaOpenMP.h
+++ b/clang/include/clang/Sema/SemaOpenMP.h
@@ -877,10 +877,10 @@ class SemaOpenMP : public SemaBase {
                                     SourceLocation LParenLoc,
                                     SourceLocation EndLoc);
   /// Called on well-formed 'num_threads' clause.
-  OMPClause *ActOnOpenMPNumThreadsClause(Expr *NumThreads,
-                                         SourceLocation StartLoc,
-                                         SourceLocation LParenLoc,
-                                         SourceLocation EndLoc);
+  OMPClause *ActOnOpenMPNumThreadsClause(
+      OpenMPNumThreadsClauseModifier Modifier, Expr *NumThreads,
+      SourceLocation StartLoc, SourceLocation LParenLoc,
+      SourceLocation ModifierLoc, SourceLocation EndLoc);
   /// Called on well-formed 'align' clause.
   OMPClause *ActOnOpenMPAlignClause(Expr *Alignment, SourceLocation StartLoc,
                                     SourceLocation LParenLoc,
diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp
index 8b1caa05eec32..de8b5996818de 100644
--- a/clang/lib/AST/OpenMPClause.cpp
+++ b/clang/lib/AST/OpenMPClause.cpp
@@ -1830,6 +1830,11 @@ void OMPClausePrinter::VisitOMPFinalClause(OMPFinalClause *Node) {
 
 void OMPClausePrinter::VisitOMPNumThreadsClause(OMPNumThreadsClause *Node) {
   OS << "num_threads(";
+  OpenMPNumThreadsClauseModifier Modifier = Node->getModifier();
+  if (Modifier != OMPC_NUMTHREADS_unknown) {
+    OS << getOpenMPSimpleClauseTypeName(Node->getClauseKind(), Modifier)
+       << ": ";
+  }
   Node->getNumThreads()->printPretty(OS, nullptr, Policy, 0);
   OS << ")";
 }
diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp
index a451fc7c01841..d3d393bd09396 100644
--- a/clang/lib/Basic/OpenMPKinds.cpp
+++ b/clang/lib/Basic/OpenMPKinds.cpp
@@ -185,11 +185,19 @@ unsigned clang::getOpenMPSimpleClauseType(OpenMPClauseKind Kind, StringRef Str,
 #define OPENMP_ALLOCATE_MODIFIER(Name) .Case(#Name, OMPC_ALLOCATE_##Name)
 #include "clang/Basic/OpenMPKinds.def"
         .Default(OMPC_ALLOCATE_unknown);
+  case OMPC_num_threads: {
+    unsigned Type = llvm::StringSwitch<unsigned>(Str)
+#define OPENMP_NUMTHREADS_MODIFIER(Name) .Case(#Name, OMPC_NUMTHREADS_##Name)
+#include "clang/Basic/OpenMPKinds.def"
+                        .Default(OMPC_NUMTHREADS_unknown);
+    if (LangOpts.OpenMP < 60)
+      return OMPC_NUMTHREADS_unknown;
+    return Type;
+  }
   case OMPC_unknown:
   case OMPC_threadprivate:
   case OMPC_if:
   case OMPC_final:
-  case OMPC_num_threads:
   case OMPC_safelen:
   case OMPC_simdlen:
   case OMPC_sizes:
@@ -520,11 +528,20 @@ const char *clang::getOpenMPSimpleClauseTypeName(OpenMPClauseKind Kind,
 #include "clang/Basic/OpenMPKinds.def"
     }
     llvm_unreachable("Invalid OpenMP 'allocate' clause modifier");
+  case OMPC_num_threads:
+    switch (Type) {
+    case OMPC_NUMTHREADS_unknown:
+      return "unknown";
+#define OPENMP_NUMTHREADS_MODIFIER(Name)                                       \
+  case OMPC_NUMTHREADS_##Name:                                                 \
+    return #Name;
+#include "clang/Basic/OpenMPKinds.def"
+    }
+    llvm_unreachable("Invalid OpenMP 'num_threads' clause modifier");
   case OMPC_unknown:
   case OMPC_threadprivate:
   case OMPC_if:
   case OMPC_final:
-  case OMPC_num_threads:
   case OMPC_safelen:
   case OMPC_simdlen:
   case OMPC_sizes:
diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
index 78d3503d8eb68..f694ae1d0d112 100644
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -3196,7 +3196,8 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
     if ((CKind == OMPC_ordered || CKind == OMPC_partial) &&
         PP.LookAhead(/*N=*/0).isNot(tok::l_paren))
       Clause = ParseOpenMPClause(CKind, WrongDirective);
-    else if (CKind == OMPC_grainsize || CKind == OMPC_num_tasks)
+    else if (CKind == OMPC_grainsize || CKind == OMPC_num_tasks ||
+             CKind == OMPC_num_threads)
       Clause = ParseOpenMPSingleExprWithArgClause(DKind, CKind, WrongDirective);
     else
       Clause = ParseOpenMPSingleExprClause(CKind, WrongDirective);
@@ -3981,6 +3982,33 @@ OMPClause *Parser::ParseOpenMPSingleExprWithArgClause(OpenMPDirectiveKind DKind,
       Arg.push_back(OMPC_NUMTASKS_unknown);
       KLoc.emplace_back();
     }
+  } else if (Kind == OMPC_num_threads) {
+    // Parse optional <num_threads modifier> ':'
+    OpenMPNumThreadsClauseModifier Modifier =
+        static_cast<OpenMPNumThreadsClauseModifier>(getOpenMPSimpleClauseType(
+            Kind, Tok.isAnnotation() ? "" : PP.getSpelling(Tok),
+            getLangOpts()));
+    if (getLangOpts().OpenMP >= 60) {
+      if (NextToken().is(tok::colon)) {
+        Arg.push_back(Modifier);
+        KLoc.push_back(Tok.getLocation());
+        // Parse modifier
+        ConsumeAnyToken();
+        // Parse ':'
+        ConsumeAnyToken();
+      } else {
+        if (Modifier == OMPC_NUMTHREADS_strict) {
+          Diag(Tok, diag::err_modifier_expected_colon) << "strict";
+          // Parse modifier
+          ConsumeAnyToken();
+        }
+        Arg.push_back(OMPC_NUMTHREADS_unknown);
+        KLoc.emplace_back();
+      }
+    } else {
+      Arg.push_back(OMPC_NUMTHREADS_unknown);
+      KLoc.emplace_back();
+    }
   } else {
     assert(Kind == OMPC_if);
     KLoc.push_back(Tok.getLocation());
@@ -4004,7 +4032,8 @@ OMPClause *Parser::ParseOpenMPSingleExprWithArgClause(OpenMPDirectiveKind DKind,
   bool NeedAnExpression = (Kind == OMPC_schedule && DelimLoc.isValid()) ||
                           (Kind == OMPC_dist_schedule && DelimLoc.isValid()) ||
                           Kind == OMPC_if || Kind == OMPC_device ||
-                          Kind == OMPC_grainsize || Kind == OMPC_num_tasks;
+                          Kind == OMPC_grainsize || Kind == OMPC_num_tasks ||
+                          Kind == OMPC_num_threads;
   if (NeedAnExpression) {
     SourceLocation ELoc = Tok.getLocation();
     ExprResult LHS(ParseCastExpression(CastParseKind::AnyCastExpr, false,
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 00f4658180807..a30acbe9a4bca 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -15509,9 +15509,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
   case OMPC_final:
     Res = ActOnOpenMPFinalClause(Expr, StartLoc, LParenLoc, EndLoc);
     break;
-  case OMPC_num_threads:
-    Res = ActOnOpenMPNumThreadsClause(Expr, StartLoc, LParenLoc, EndLoc);
-    break;
   case OMPC_safelen:
     Res = ActOnOpenMPSafelenClause(Expr, StartLoc, LParenLoc, EndLoc);
     break;
@@ -15565,6 +15562,7 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
     break;
   case OMPC_grainsize:
   case OMPC_num_tasks:
+  case OMPC_num_threads:
   case OMPC_device:
   case OMPC_if:
   case OMPC_default:
@@ -15911,10 +15909,41 @@ isNonNegativeIntegerValue(Expr *&ValExpr, Sema &SemaRef, OpenMPClauseKind CKind,
   return true;
 }
 
-OMPClause *SemaOpenMP::ActOnOpenMPNumThreadsClause(Expr *NumThreads,
-                                                   SourceLocation StartLoc,
-                                                   SourceLocation LParenLoc,
-                                                   SourceLocation EndLoc) {
+static std::string getListOfPossibleValues(OpenMPClauseKind K, unsigned First,
+                                           unsigned Last,
+                                           ArrayRef<unsigned> Exclude = {}) {
+  SmallString<256> Buffer;
+  llvm::raw_svector_ostream Out(Buffer);
+  unsigned Skipped = Exclude.size();
+  for (unsigned I = First; I < Last; ++I) {
+    if (llvm::is_contained(Exclude, I)) {
+      --Skipped;
+      continue;
+    }
+    Out << "'" << getOpenMPSimpleClauseTypeName(K, I) << "'";
+    if (I + Skipped + 2 == Last)
+      Out << " or ";
+    else if (I + Skipped + 1 != Last)
+      Out << ", ";
+  }
+  return std::string(Out.str());
+}
+
+OMPClause *SemaOpenMP::ActOnOpenMPNumThreadsClause(
+    OpenMPNumThreadsClauseModifier Modifier, Expr *NumThreads,
+    SourceLocation StartLoc, SourceLocation LParenLoc,
+    SourceLocation ModifierLoc, SourceLocation EndLoc) {
+  assert((ModifierLoc.isInvalid() || getLangOpts().OpenMP >= 60) &&
+         "Unexpected num_threads modifier in OpenMP < 60.");
+
+  if (ModifierLoc.isValid() && Modifier == OMPC_NUMTHREADS_unknown) {
+    std::string Values = getListOfPossibleValues(OMPC_num_threads, /*First=*/0,
+                                                 OMPC_NUMTHREADS_unknown);
+    Diag(ModifierLoc, diag::err_omp_unexpected_clause_value)
+        << Values << getOpenMPClauseNameForDiag(OMPC_num_threads);
+    return nullptr;
+  }
+
   Expr *ValExpr = NumThreads;
   Stmt *HelperValStmt = nullptr;
 
@@ -15935,8 +15964,9 @@ OMPClause *SemaOpenMP::ActOnOpenMPNumThreadsClause(Expr *NumThreads,
     HelperValStmt = buildPreInits(getASTContext(), Captures);
   }
 
-  return new (getASTContext()) OMPNumThreadsClause(
-      ValExpr, HelperValStmt, CaptureRegion, StartLoc, LParenLoc, EndLoc);
+  return new (getASTContext())
+      OMPNumThreadsClause(Modifier, ValExpr, HelperValStmt, CaptureRegion,
+                          StartLoc, LParenLoc, ModifierLoc, EndLoc);
 }
 
 ExprResult SemaOpenMP::VerifyPositiveIntegerConstantInClause(
@@ -16301,26 +16331,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPSimpleClause(
   return Res;
 }
 
-static std::string getListOfPossibleValues(OpenMPClauseKind K, unsigned First,
-                                           unsigned Last,
-                                           ArrayRef<unsigned> Exclude = {}) {
-  SmallString<256> Buffer;
-  llvm::raw_svector_ostream Out(Buffer);
-  unsigned Skipped = Exclude.size();
-  for (unsigned I = First; I < Last; ++I) {
-    if (llvm::is_contained(Exclude, I)) {
-      --Skipped;
-      continue;
-    }
-    Out << "'" << getOpenMPSimpleClauseTypeName(K, I) << "'";
-    if (I + Skipped + 2 == Last)
-      Out << " or ";
-    else if (I + Skipped + 1 != Last)
-      Out << ", ";
-  }
-  return std::string(Out.str());
-}
-
 OMPClause *SemaOpenMP::ActOnOpenMPDefaultClause(DefaultKind Kind,
                                                 SourceLocation KindKwLoc,
                                                 SourceLocation StartLoc,
@@ -16693,8 +16703,14 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprWithArgClause(
         static_cast<OpenMPNumTasksClauseModifier>(Argument.back()), Expr,
         StartLoc, LParenLoc, ArgumentLoc.back(), EndLoc);
     break;
-  case OMPC_final:
   case OMPC_num_threads:
+    assert(Argument.size() == 1 && ArgumentLoc.size() == 1 &&
+           "Modifier for num_threads clause and its location are expected.");
+    Res = ActOnOpenMPNumThreadsClause(
+        static_cast<OpenMPNumThreadsClauseModifier>(Argument.back()), Expr,
+        StartLoc, LParenLoc, ArgumentLoc.back(), EndLoc);
+    break;
+  case OMPC_final:
   case OMPC_safelen:
   case OMPC_simdlen:
   case OMPC_sizes:
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index 26bee7a96de22..0d58587cb8a99 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -1714,12 +1714,14 @@ class TreeTransform {
   ///
   /// By default, performs semantic analysis to build the new OpenMP clause.
   /// Subclasses may override this routine to provide different behavior.
-  OMPClause *RebuildOMPNumThreadsClause(Expr *NumThreads,
+  OMPClause *RebuildOMPNumThreadsClause(OpenMPNumThreadsClauseModifier Modifier,
+                                        Expr *NumThreads,
                                         SourceLocation StartLoc,
                                         SourceLocation LParenLoc,
+                                        SourceLocation ModifierLoc,
                                         SourceLocation EndLoc) {
-    return getSema().OpenMP().ActOnOpenMPNumThreadsClause(NumThreads, StartLoc,
-                                                          LParenLoc, EndLoc);
+    return getSema().OpenMP().ActOnOpenMPNumThreadsClause(
+        Modifier, NumThreads, StartLoc, LParenLoc, ModifierLoc, EndLoc);
   }
 
   /// Build a new OpenMP 'safelen' clause.
@@ -10461,7 +10463,8 @@ TreeTransform<Derived>::TransformOMPNumThreadsClause(OMPNumThreadsClause *C) {
   if (NumThreads.isInvalid())
     return nullptr;
   return getDerived().RebuildOMPNumThreadsClause(
-      NumThreads.get(), C->getBeginLoc(), C->getLParenLoc(), C->getEndLoc());
+      C->getModifier(), NumThreads.get(), C->getBeginLoc(), C->getLParenLoc(),
+      C->getModifierLoc(), C->getEndLoc());
 }
 
 template <typename Derived>
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index 6f082fe840b4c..b696cb2efee3d 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -11461,7 +11461,9 @@ void OMPClauseReader::VisitOMPFinalClause(OMPFinalClause *C) {
 
 void OMPClauseReader::VisitOMPNumThreadsClause(OMPNumThreadsClause *C) {
   VisitOMPClauseWithPreInit(C);
+  C->setModifier(Record.readEnum<OpenMPNumThreadsClauseModifier>());
   C->setNumThreads(Record.readSubExpr());
+  C->setModifierLoc(Record.readSourceLocation());
   C->setLParenLoc(Record.readSourceLocation());
 }
 
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index c6487c5366a29..4cca214f8e308 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -7802,7 +7802,9 @@ void OMPClauseWriter::VisitOMPFinalClause(OMPFinalClause *C) {
 
 void OMPClauseWriter::VisitOMPNumThreadsClause(OMPNumThreadsClause *C) {
   VisitOMPClauseWithPreInit(C);
+  Record.writeEnum(C->getModifier());
   Record.AddStmt(C->getNumThreads());
+  Record.AddSourceLocation(C->getModifierLoc());
   Record.AddSourceLocation(C->getLParenLoc());
 }
 
diff --git a/clang/test/OpenMP/parallel_ast_print.cpp b/clang/test/OpenMP/parallel_ast_print.cpp
index 83afedcb740da..948baaff30d89 100644
--- a/clang/test/OpenMP/parallel_ast_print.cpp
+++ b/clang/test/OpenMP/parallel_ast_print.cpp
@@ -13,6 +13,14 @@
 // RUN: %clang_cc1 -DOMP51 -verify -Wno-vla -fopenmp-simd -ast-print %s | FileCheck -check-prefixes=CHECK,OMP51 %s
 // RUN: %clang_cc1 -DOMP51 -fopenmp-simd -x c++ -std=c++11 -emit-pch -o %t %s
 // RUN: %clang_cc1 -DOMP51 -fopenmp-simd -std=c++11 -include-pch %t -verify -Wno-vla %s -ast-print | FileCheck -check-prefixes=CHECK,OMP51 %s
+
+// RUN: %clang_cc1 -DOMP60 -verify -Wno-vla -fopenmp -fopenmp-version=60 -ast-print %s | FileCheck -check-prefixes=CHECK,OMP60 %s
+// RUN: %clang_cc1 -DOMP60 -fopenmp -fopenmp-version=60 -x c++ -std=c++11 -emit-pch -o %t %s
+// RUN: %clang_cc1 -DOMP60 -fopenmp -fopenmp-version=60 -std=c++11 -include-pch %t -verify -Wno-vla %s -ast-print | FileCheck -check-prefixes=CHECK,OMP60 %s
+
+// RUN: %clang_cc1 -DOMP60 -verify -Wno-vla -fopenmp-simd -fopenmp-version=60 -ast-print %s | FileCheck -check-prefixes=CHECK,OMP60 %s
+// RUN: %clang_cc1 -DOMP60 -fopenmp-simd -fopenmp-version=60 -x c++ -std=c++11 -emit-pch -o %t %s
+// RUN: %clang_cc1 -DOMP60 -fopenmp-simd -fopenmp-version=60 -std=c++11 -include-pch %t -verify -Wno-vla %s -ast-print | FileCheck -check-prefixes=CHECK,OMP60 %s
 // expected-no-diagnostics
 
 #ifndef HEADER
@@ -164,8 +172,16 @@ T tmain(T argc, T *argv) {
 #pragma omp parallel default(none), private(argc,b) firstprivate(argv) shared (d) if (parallel:argc > 0) num_threads(C) copyin(S<T>::TS, thrp) proc_bind(primary) reduction(+:c, arr1[argc]) reduction(max:e, arr[:C][0:10])
   foo();
 #endif
+#ifdef OMP60
+#pragma omp parallel default(none), private(argc,b) firstprivate(argv) shared (d) if (parallel:argc > 0) num_threads(strict: C) copyin(S<T>::TS, thrp) proc_bind(primary) reduction(+:c, arr1[argc]) reduction(max:e, arr[:C][0:10])
+  foo();
+#endif
 #pragma omp parallel if (C) num_threads(s) proc_bind(close) re...
[truncated]

@ro-i ro-i merged commit f624ba2 into llvm:main Jun 24, 2025
13 checks passed
@ro-i
Copy link
Contributor Author

ro-i commented Jun 24, 2025

merged, thanks for the review!

@ro-i ro-i deleted the omp-num-threads-strict branch June 24, 2025 19:13
DrSergei pushed a commit to DrSergei/llvm-project that referenced this pull request Jun 24, 2025
…lvm#145490)

Implement parsing and semantic analysis support for the optional
'strict' modifier of the num_threads clause. This modifier has been
introduced in OpenMP 6.0, section 12.1.2.

Note: this is basically 1:1 https://reviews.llvm.org/D138328.
anthonyhatran pushed a commit to anthonyhatran/llvm-project that referenced this pull request Jun 26, 2025
…lvm#145490)

Implement parsing and semantic analysis support for the optional
'strict' modifier of the num_threads clause. This modifier has been
introduced in OpenMP 6.0, section 12.1.2.

Note: this is basically 1:1 https://reviews.llvm.org/D138328.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:modules C++20 modules and Clang Header Modules clang:openmp OpenMP related changes to Clang clang Clang issues not falling into any other category flang:openmp
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants