Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SYCL free function namespace support #17585

Open
wants to merge 18 commits into
base: sycl
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 130 additions & 65 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
@@ -15,10 +15,9 @@
#include "clang/AST/QualTypeNames.h"
#include "clang/AST/RecordLayout.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/AST/TemplateArgumentVisitor.h"
#include "clang/AST/Mangle.h"
#include "clang/AST/SYCLKernelInfo.h"
#include "clang/AST/StmtSYCL.h"
#include "clang/AST/TemplateArgumentVisitor.h"
#include "clang/AST/TypeOrdering.h"
#include "clang/AST/TypeVisitor.h"
#include "clang/Analysis/CallGraph.h"
@@ -27,7 +26,6 @@
#include "clang/Basic/Diagnostic.h"
#include "clang/Basic/TargetInfo.h"
#include "clang/Basic/Version.h"
#include "clang/AST/SYCLKernelInfo.h"
#include "clang/Sema/Attr.h"
#include "clang/Sema/Initialization.h"
#include "clang/Sema/ParsedAttr.h"
@@ -6425,6 +6423,120 @@ static void EmitPragmaDiagnosticPop(raw_ostream &O) {
O << "\n";
}

template <typename BeforeFn, typename AfterFn>
static void PrintNSHelper(BeforeFn Before, AfterFn After, raw_ostream &OS,
const DeclContext *DC) {
if (DC->isTranslationUnit())
return;

const auto *CurDecl = cast<Decl>(DC);
// Ensure we are in the canonical version, so that we know we have the 'full'
// name of the thing.
CurDecl = CurDecl->getCanonicalDecl();

// We are intentionally skipping linkage decls and record decls. Namespaces
// can appear in a linkage decl, but not a record decl, so we don't have to
// worry about the names getting messed up from that. We handle record decls
// later when printing the name of the thing.
const auto *NS = dyn_cast<NamespaceDecl>(CurDecl);
if (NS)
Before(OS, NS);

if (const DeclContext *NewDC = CurDecl->getDeclContext())
PrintNSHelper(Before, After, OS, NewDC);

if (NS)
After(OS, NS);
}

static void PrintNamespaces(raw_ostream &OS, const DeclContext *DC,
bool isPrintNamesOnly = false) {
PrintNSHelper([](raw_ostream &OS, const NamespaceDecl *NS) {},
[isPrintNamesOnly](raw_ostream &OS, const NamespaceDecl *NS) {
if (!isPrintNamesOnly) {
if (NS->isInline())
OS << "inline ";
OS << "namespace ";
}
if (!NS->isAnonymousNamespace()) {
OS << NS->getName();
if (isPrintNamesOnly)
OS << "::";
else
OS << " ";
}
if (!isPrintNamesOnly) {
OS << "{\n";
}
},
OS, DC);
}

static void PrintNSClosingBraces(raw_ostream &OS, const DeclContext *DC) {
PrintNSHelper(
[](raw_ostream &OS, const NamespaceDecl *NS) {
OS << "} // ";
if (NS->isInline())
OS << "inline ";

OS << "namespace ";
if (!NS->isAnonymousNamespace())
OS << NS->getName();

OS << '\n';
},
[](raw_ostream &OS, const NamespaceDecl *NS) {}, OS, DC);
}

class FreeFunctionPrinter {
raw_ostream &O;
const PrintingPolicy &Policy;
bool NSInserted = false;

public:
FreeFunctionPrinter(raw_ostream &O, const PrintingPolicy &Policy)
: O(O), Policy(Policy) {}

/// Emits the function declaration of a free function.
/// \param FD The function declaration to print.
/// \param Args The arguments of the function.
void printFreeFunctionDeclaration(const FunctionDecl *FD,
const std::string &Args) {
const DeclContext *DC = FD->getDeclContext();
if (DC) {
// if function in namespace, print namespace
if (isa<NamespaceDecl>(DC)) {
PrintNamespaces(O, FD);
// Set flag to print closing braces for namespaces and namespace in shim
// function
NSInserted = true;
}
O << FD->getReturnType().getAsString() << " ";
O << FD->getNameAsString() << "(" << Args << ");";
if (NSInserted) {
O << "\n";
PrintNSClosingBraces(O, FD);
}
O << "\n";
}
}

/// Emits free function shim function.
/// \param FD The function declaration to print.
/// \param ShimCounter The counter for the shim function.
/// \param ParmList The parameter list of the function.
void printFreeFunctionShim(const FunctionDecl *FD, const unsigned ShimCounter,
const std::string &ParmList) {
// Generate a shim function that returns the address of the free function.
O << "static constexpr auto __sycl_shim" << ShimCounter << "() {\n";
O << " return (void (*)(" << ParmList << "))";

if (NSInserted)
PrintNamespaces(O, FD, /*isPrintNamesOnly=*/true);
O << FD->getIdentifier()->getName().data();
}
};

void SYCLIntegrationHeader::emit(raw_ostream &O) {
O << "// This is auto-generated SYCL integration header.\n";
O << "\n";
@@ -6713,16 +6825,25 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
if (K.SyclKernel->getLanguageLinkage() == CLanguageLinkage)
O << "extern \"C\" ";
std::string ParmList;
std::string ParmListWithNames;
bool FirstParam = true;
Policy.SuppressDefaultTemplateArgs = false;
Policy.PrintCanonicalTypes = true;
llvm::raw_string_ostream ParmListWithNamesOstream{ParmListWithNames};
for (ParmVarDecl *Param : K.SyclKernel->parameters()) {
if (FirstParam)
FirstParam = false;
else
else {
ParmList += ", ";
ParmListWithNamesOstream << ", ";
}
Policy.SuppressTagKeyword = true;
Param->getType().print(ParmListWithNamesOstream, Policy);
Policy.SuppressTagKeyword = false;
ParmListWithNamesOstream << " " << Param->getNameAsString();
Comment on lines +6840 to +6843
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please elaborate what this particular addition is trying to achieve, why the previous code did not suffice and how does it relate to namespace printing?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general, ParamList contains only parameter types, i.e. for the function
void some_func(float a, float* b)
ParamList contains {float, float*}
It was added new list to have additional list of parameters with names to pass already existed tests for free function which checked generated header.
flag
Policy.SuppressTagKeyword = true;
forces printing without type tags, i.e. without words class and struct.

ParmList += Param->getType().getCanonicalType().getAsString(Policy);
}
ParmListWithNamesOstream.flush();
FunctionTemplateDecl *FTD = K.SyclKernel->getPrimaryTemplate();
Policy.PrintCanonicalTypes = false;
Policy.SuppressDefinition = true;
@@ -6756,17 +6877,16 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
// template arguments that match default template arguments while printing
// template-ids, even if the source code doesn't reference them.
Policy.EnforceDefaultTemplateArgs = true;
FreeFunctionPrinter FFPrinter(O, Policy);
// bool NSInserted{false};
if (FTD) {
FTD->print(O, Policy);
O << ";\n";
} else {
K.SyclKernel->print(O, Policy);
FFPrinter.printFreeFunctionDeclaration(K.SyclKernel, ParmListWithNames);
}
O << ";\n";

// Generate a shim function that returns the address of the free function.
O << "static constexpr auto __sycl_shim" << ShimCounter << "() {\n";
O << " return (void (*)(" << ParmList << "))"
<< K.SyclKernel->getIdentifier()->getName().data();
FFPrinter.printFreeFunctionShim(K.SyclKernel, ShimCounter, ParmList);
if (FTD) {
const TemplateArgumentList *TAL =
K.SyclKernel->getTemplateSpecializationArgs();
@@ -6935,61 +7055,6 @@ bool SYCLIntegrationFooter::emit(StringRef IntHeaderName) {
return emit(Out);
}

template <typename BeforeFn, typename AfterFn>
static void PrintNSHelper(BeforeFn Before, AfterFn After, raw_ostream &OS,
const DeclContext *DC) {
if (DC->isTranslationUnit())
return;

const auto *CurDecl = cast<Decl>(DC);
// Ensure we are in the canonical version, so that we know we have the 'full'
// name of the thing.
CurDecl = CurDecl->getCanonicalDecl();

// We are intentionally skipping linkage decls and record decls. Namespaces
// can appear in a linkage decl, but not a record decl, so we don't have to
// worry about the names getting messed up from that. We handle record decls
// later when printing the name of the thing.
const auto *NS = dyn_cast<NamespaceDecl>(CurDecl);
if (NS)
Before(OS, NS);

if (const DeclContext *NewDC = CurDecl->getDeclContext())
PrintNSHelper(Before, After, OS, NewDC);

if (NS)
After(OS, NS);
}

static void PrintNamespaces(raw_ostream &OS, const DeclContext *DC) {
PrintNSHelper([](raw_ostream &OS, const NamespaceDecl *NS) {},
[](raw_ostream &OS, const NamespaceDecl *NS) {
if (NS->isInline())
OS << "inline ";
OS << "namespace ";
if (!NS->isAnonymousNamespace())
OS << NS->getName() << " ";
OS << "{\n";
},
OS, DC);
}

static void PrintNSClosingBraces(raw_ostream &OS, const DeclContext *DC) {
PrintNSHelper(
[](raw_ostream &OS, const NamespaceDecl *NS) {
OS << "} // ";
if (NS->isInline())
OS << "inline ";

OS << "namespace ";
if (!NS->isAnonymousNamespace())
OS << NS->getName();

OS << '\n';
},
[](raw_ostream &OS, const NamespaceDecl *NS) {}, OS, DC);
}

static std::string EmitShim(raw_ostream &OS, unsigned &ShimCounter,
const std::string &LastShim,
const NamespaceDecl *AnonNS) {
Original file line number Diff line number Diff line change
@@ -86,18 +86,20 @@ foo(Arg1<int> arg) {
// CHECK-NEXT: template <typename T, typename, int a, typename, typename ...TS> struct Arg;
// CHECK-NEXT: }

// CHECK: void ns::simple(ns::Arg<char, int, 12, ns::notatuple>);
// CHECK-NEXT: static constexpr auto __sycl_shim1() {
// CHECK-NEXT: return (void (*)(struct ns::Arg<char, int, 12, struct ns::notatuple>))simple;
// CHECK: namespace ns {
// CHECK-NEXT: void simple(ns::Arg<char, int, 12, ns::notatuple> );
// CHECK-NEXT: } // namespace ns
// CHECK: static constexpr auto __sycl_shim1() {
// CHECK-NEXT: return (void (*)(struct ns::Arg<char, int, 12, struct ns::notatuple>))ns::simple;
// CHECK-NEXT: }

// CHECK: Forward declarations of kernel and its argument types:
// CHECK: namespace ns {
// CHECK: namespace ns1 {
// CHECK-NEXT: template <typename A> class hasDefaultArg;
// CHECK-NEXT: }
// CHECK-NEXT: }}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add more FE tests? Like with various combinations of namespaces around the free function kernel declaration? With inline namespace and not. Can we also test that codegen and semantic analysis is ok for free function kernels defined in a (maybe nested) namespace?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added new e2e tests to check any possible namespaces: nested, anonymous, inline etc. Is it enough or add in these tests too? New tests do not check header directly but if something is emitted wrong, they will fail.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SYCL compiler is complicated and has a lot of components. If we only have a e2e test and it fails suddenly (for example after a pulldown), it may take a while to identify which component now has a problem.
This is one of reasons why we normally check each component separately with unit tests and everything together in e2e tests. FE-only tests are "unit" tests in this scenario. They will help more quickly to identify that the problem is in FE. They will also help people to fix any FE problems without needing to have sycl rt and device. So, I still encourage to add FE-only tests.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I did not see that these tests are units. Added new checks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well they are "unit" because clang is enormous itself and has its own unit tests but in terms of SYCL compiler we can consider them as unit tests.


// CHECK: void simple1(ns::Arg<ns::ns1::hasDefaultArg<ns::notatuple>, int, 12, ns::notatuple>);
// CHECK: void simple1(ns::Arg<ns::ns1::hasDefaultArg<ns::notatuple>, int, 12, ns::notatuple> );
// CHECK-NEXT: static constexpr auto __sycl_shim2() {
// CHECK-NEXT: return (void (*)(struct ns::Arg<class ns::ns1::hasDefaultArg<struct ns::notatuple>, int, 12, struct ns::notatuple>))simple1;
// CHECK-NEXT: }
Loading
Oops, something went wrong.
Loading
Oops, something went wrong.