diff --git a/include/swift/DependencyScan/DependencyScanningTool.h b/include/swift/DependencyScan/DependencyScanningTool.h index 8aa380c83fc0e..5cfddc17751a3 100644 --- a/include/swift/DependencyScan/DependencyScanningTool.h +++ b/include/swift/DependencyScan/DependencyScanningTool.h @@ -31,8 +31,23 @@ struct ScanQueryInstance { std::shared_ptr ScanDiagnostics; }; +/// Pure virtual Diagnostic consumer intended for collecting +/// emitted diagnostics in a thread-safe fashion +class ThreadSafeDiagnosticCollector : public DiagnosticConsumer { +private: + llvm::sys::SmartMutex DiagnosticConsumerStateLock; + void handleDiagnostic(SourceManager &SM, const DiagnosticInfo &Info) final; + +protected: + virtual void addDiagnostic(SourceManager &SM, const DiagnosticInfo &Info) = 0; + +public: + ThreadSafeDiagnosticCollector() {} + bool finishProcessing() final { return false; } +}; + /// Diagnostic consumer that simply collects the diagnostics emitted so-far -class DependencyScanDiagnosticCollector : public DiagnosticConsumer { +class DependencyScanDiagnosticCollector : public ThreadSafeDiagnosticCollector { private: struct ScannerDiagnosticInfo { std::string Message; @@ -40,12 +55,9 @@ class DependencyScanDiagnosticCollector : public DiagnosticConsumer { std::optional ImportLocation; }; std::vector Diagnostics; - llvm::sys::SmartMutex ScanningDiagnosticConsumerStateLock; - - void handleDiagnostic(SourceManager &SM, const DiagnosticInfo &Info) override; protected: - virtual void addDiagnostic(SourceManager &SM, const DiagnosticInfo &Info); + void addDiagnostic(SourceManager &SM, const DiagnosticInfo &Info) override; public: friend DependencyScanningTool; diff --git a/include/swift/DependencyScan/ModuleDependencyScanner.h b/include/swift/DependencyScan/ModuleDependencyScanner.h index 1b5d7d4050b77..184824fa07b92 100644 --- a/include/swift/DependencyScan/ModuleDependencyScanner.h +++ b/include/swift/DependencyScan/ModuleDependencyScanner.h @@ -83,6 +83,8 @@ class ModuleDependencyScanningWorker { // Worker-specific instance of CompilerInvocation std::unique_ptr workerCompilerInvocation; + // Worker-specific diagnostic engine + std::unique_ptr workerDiagnosticEngine; // Worker-specific instance of ASTContext std::unique_ptr workerASTContext; // An AST delegate for interface scanning. diff --git a/lib/DependencyScan/DependencyScanningTool.cpp b/lib/DependencyScan/DependencyScanningTool.cpp index 157cefec7da64..3eb619a0b64e7 100644 --- a/lib/DependencyScan/DependencyScanningTool.cpp +++ b/lib/DependencyScan/DependencyScanningTool.cpp @@ -64,8 +64,9 @@ llvm::ErrorOr getTargetInfo(ArrayRef Comma return c_string_utils::create_clone(ResultStr.c_str()); } -void DependencyScanDiagnosticCollector::handleDiagnostic(SourceManager &SM, +void ThreadSafeDiagnosticCollector::handleDiagnostic(SourceManager &SM, const DiagnosticInfo &Info) { + llvm::sys::SmartScopedLock Lock(DiagnosticConsumerStateLock); addDiagnostic(SM, Info); for (auto ChildInfo : Info.ChildDiagnosticInfo) { addDiagnostic(SM, *ChildInfo); @@ -74,8 +75,6 @@ void DependencyScanDiagnosticCollector::handleDiagnostic(SourceManager &SM, void DependencyScanDiagnosticCollector::addDiagnostic( SourceManager &SM, const DiagnosticInfo &Info) { - llvm::sys::SmartScopedLock Lock(ScanningDiagnosticConsumerStateLock); - // Determine what kind of diagnostic we're emitting. llvm::SourceMgr::DiagKind SMKind; switch (Info.Kind) { diff --git a/lib/DependencyScan/ModuleDependencyScanner.cpp b/lib/DependencyScan/ModuleDependencyScanner.cpp index 60c6b925e3f8e..7e362ded70700 100644 --- a/lib/DependencyScan/ModuleDependencyScanner.cpp +++ b/lib/DependencyScan/ModuleDependencyScanner.cpp @@ -200,6 +200,13 @@ ModuleDependencyScanningWorker::ModuleDependencyScanningWorker( // Create a scanner-specific Invocation and ASTContext. workerCompilerInvocation = std::make_unique(ScanCompilerInvocation); + + // Instantiate a worker-specific diagnostic engine and copy over + // the scanner's diagnostic consumers (expected to be thread-safe). + workerDiagnosticEngine = std::make_unique(ScanASTContext.SourceMgr); + for (auto &scannerDiagConsumer : Diagnostics.getConsumers()) + workerDiagnosticEngine->addConsumer(*scannerDiagConsumer); + workerASTContext = std::unique_ptr( ASTContext::get(workerCompilerInvocation->getLangOptions(), workerCompilerInvocation->getTypeCheckerOptions(), @@ -209,7 +216,8 @@ ModuleDependencyScanningWorker::ModuleDependencyScanningWorker( workerCompilerInvocation->getSymbolGraphOptions(), workerCompilerInvocation->getCASOptions(), workerCompilerInvocation->getSerializationOptions(), - ScanASTContext.SourceMgr, Diagnostics)); + ScanASTContext.SourceMgr, *workerDiagnosticEngine)); + auto loader = std::make_unique( *workerASTContext, /*DepTracker=*/nullptr, workerCompilerInvocation->getFrontendOptions().CacheReplayPrefixMap, diff --git a/unittests/DependencyScan/ModuleDeps.cpp b/unittests/DependencyScan/ModuleDeps.cpp index 27c8edf24e98a..6aaa094cf9745 100644 --- a/unittests/DependencyScan/ModuleDeps.cpp +++ b/unittests/DependencyScan/ModuleDeps.cpp @@ -307,7 +307,6 @@ public func funcB() { }\n")); for (auto &command : CommandStr) Command.push_back(command.c_str()); - auto ScanDiagnosticConsumer = std::make_shared(); auto DependenciesOrErr = ScannerTool.getDependencies(Command, {}, {}); @@ -329,3 +328,78 @@ public func funcB() { }\n")); ASSERT_TRUE(Dependencies->dependencies->modules[0]->link_libraries->count == 0); swiftscan_dependency_graph_dispose(Dependencies); } + +TEST_F(ScanTest, TestStressConcurrentDiagnostics) { + SmallString<256> tempDir; + ASSERT_FALSE(llvm::sys::fs::createUniqueDirectory("ScanTest.TestStressConcurrentDiagnostics", tempDir)); + SWIFT_DEFER { llvm::sys::fs::remove_directories(tempDir); }; + + // Create includes + std::string IncludeDirPath = createFilename(tempDir, "include"); + ASSERT_FALSE(llvm::sys::fs::create_directory(IncludeDirPath)); + std::string CHeadersDirPath = createFilename(IncludeDirPath, "CHeaders"); + ASSERT_FALSE(llvm::sys::fs::create_directory(CHeadersDirPath)); + + // Create test input file + std::string TestPathStr = createFilename(tempDir, "foo.swift"); + + // Create imported module C modulemap/headers + std::string modulemapContent = ""; + std::string testFileContent = ""; + for (int i = 0; i < 100; ++i) { + std::string headerName = "A_" + std::to_string(i) + ".h"; + std::string headerContent = "void funcA_" + std::to_string(i) + "(void);"; + ASSERT_FALSE( + emitFileWithContents(CHeadersDirPath, headerName, headerContent)); + + std::string moduleMapEntry = "module A_" + std::to_string(i) + "{\n"; + moduleMapEntry.append("header \"A_" + std::to_string(i) + ".h\"\n"); + moduleMapEntry.append("export *\n"); + moduleMapEntry.append("}\n"); + modulemapContent.append(moduleMapEntry); + testFileContent.append("import A_" + std::to_string(i) + "\n"); + } + + ASSERT_FALSE(emitFileWithContents(tempDir, "foo.swift", testFileContent)); + ASSERT_FALSE( + emitFileWithContents(CHeadersDirPath, "module.modulemap", modulemapContent)); + + // Paths to shims and stdlib + llvm::SmallString<128> ShimsLibDir = StdLibDir; + llvm::sys::path::append(ShimsLibDir, "shims"); + auto Target = llvm::Triple(llvm::sys::getDefaultTargetTriple()); + llvm::sys::path::append(StdLibDir, getPlatformNameForTriple(Target)); + + std::vector BaseCommandStrArr = { + TestPathStr, + std::string("-I ") + CHeadersDirPath, + std::string("-I ") + StdLibDir.str().str(), + std::string("-I ") + ShimsLibDir.str().str(), + // Pass in a flag which will cause every instantiation of + // the clang scanner to fail with "unknown argument" + "-Xcc", + "-foobar" + }; + + std::vector CommandStr = BaseCommandStrArr; + CommandStr.push_back("-module-name"); + CommandStr.push_back("testConcurrentDiags"); + + // On Windows we need to add an extra escape for path separator characters because otherwise + // the command line tokenizer will treat them as escape characters. + for (size_t i = 0; i < CommandStr.size(); ++i) { + std::replace(CommandStr[i].begin(), CommandStr[i].end(), '\\', '/'); + } + std::vector Command; + for (auto &command : CommandStr) + Command.push_back(command.c_str()); + + auto DependenciesOrErr = ScannerTool.getDependencies(Command, {}, {}); + + // Ensure a hollow output with diagnostic info is produced + ASSERT_FALSE(DependenciesOrErr.getError()); + auto Dependencies = DependenciesOrErr.get(); + auto Diagnostics = Dependencies->diagnostics; + ASSERT_TRUE(Diagnostics->count > 100); + swiftscan_dependency_graph_dispose(Dependencies); +}