diff --git a/include/swift/Basic/ThreadSafeRefCounted.h b/include/swift/Basic/ThreadSafeRefCounted.h index b8aa30ddaa343..7c3516b3eee65 100644 --- a/include/swift/Basic/ThreadSafeRefCounted.h +++ b/include/swift/Basic/ThreadSafeRefCounted.h @@ -15,37 +15,10 @@ #include #include +#include "llvm/ADT/IntrusiveRefCntPtr.h" namespace swift { -/// A thread-safe version of \c llvm::RefCountedBase. -/// -/// A generic base class for objects that wish to have their lifetimes managed -/// using reference counts. Classes subclass \c ThreadSafeRefCountedBase to -/// obtain such functionality, and are typically handled with -/// \c IntrusiveRefCntPtr "smart pointers" which automatically handle the -/// management of reference counts. -/// FIXME: This should eventually move to llvm. -template -class ThreadSafeRefCountedBase { - mutable std::atomic ref_cnt; - -protected: - ThreadSafeRefCountedBase() : ref_cnt(0) {} - -public: - void Retain() const { - ref_cnt.fetch_add(1, std::memory_order_acq_rel); - } - - void Release() const { - int refCount = - static_cast(ref_cnt.fetch_sub(1, std::memory_order_acq_rel)); - assert(refCount >= 0 && "Reference count was already zero."); - if (refCount == 0) delete static_cast(this); - } -}; - /// A class that has the same function as \c ThreadSafeRefCountedBase, but with /// a virtual destructor. /// @@ -62,12 +35,11 @@ class ThreadSafeRefCountedBaseVPTR { public: void Retain() const { - ref_cnt.fetch_add(1, std::memory_order_acq_rel); + ref_cnt += 1; } void Release() const { - int refCount = - static_cast(ref_cnt.fetch_sub(1, std::memory_order_acq_rel)); + int refCount = static_cast(--ref_cnt); assert(refCount >= 0 && "Reference count was already zero."); if (refCount == 0) delete this; } diff --git a/include/swift/IDE/CodeCompletionCache.h b/include/swift/IDE/CodeCompletionCache.h index 14e0b7c79fc61..e13e31b85b6e7 100644 --- a/include/swift/IDE/CodeCompletionCache.h +++ b/include/swift/IDE/CodeCompletionCache.h @@ -49,7 +49,7 @@ class CodeCompletionCache { } }; - struct Value : public ThreadSafeRefCountedBase { + struct Value : public llvm::ThreadSafeRefCountedBase { llvm::sys::TimeValue ModuleModificationTime; CodeCompletionResultSink Sink; }; diff --git a/tools/SourceKit/include/SourceKit/Core/LLVM.h b/tools/SourceKit/include/SourceKit/Core/LLVM.h index 449b72c740e96..f54234db21a75 100644 --- a/tools/SourceKit/include/SourceKit/Core/LLVM.h +++ b/tools/SourceKit/include/SourceKit/Core/LLVM.h @@ -43,6 +43,7 @@ namespace llvm { // Reference counting. template class IntrusiveRefCntPtr; template struct IntrusiveRefCntPtrInfo; + template class ThreadSafeRefCountedBase; class raw_ostream; // TODO: DenseMap, ... @@ -69,7 +70,6 @@ namespace llvm { } namespace swift { - template class ThreadSafeRefCountedBase; class ThreadSafeRefCountedBaseVPTR; } @@ -95,7 +95,7 @@ namespace SourceKit { // Reference counting. using llvm::IntrusiveRefCntPtr; using llvm::IntrusiveRefCntPtrInfo; - using swift::ThreadSafeRefCountedBase; + using llvm::ThreadSafeRefCountedBase; using swift::ThreadSafeRefCountedBaseVPTR; template class ThreadSafeRefCntPtr; diff --git a/tools/SourceKit/lib/SwiftLang/SwiftInterfaceGenContext.h b/tools/SourceKit/lib/SwiftLang/SwiftInterfaceGenContext.h index bfbd7e8a7d635..f3cbedbff8f4f 100644 --- a/tools/SourceKit/lib/SwiftLang/SwiftInterfaceGenContext.h +++ b/tools/SourceKit/lib/SwiftLang/SwiftInterfaceGenContext.h @@ -31,7 +31,7 @@ namespace SourceKit { typedef IntrusiveRefCntPtr ASTUnitRef; class SwiftInterfaceGenContext : - public swift::ThreadSafeRefCountedBase { + public llvm::ThreadSafeRefCountedBase { public: static SwiftInterfaceGenContextRef create(StringRef DocumentName, bool IsModule, diff --git a/tools/SourceKit/lib/SwiftLang/SwiftInvocation.h b/tools/SourceKit/lib/SwiftLang/SwiftInvocation.h index 6413acb2ff3d3..ab25def6bf4a6 100644 --- a/tools/SourceKit/lib/SwiftLang/SwiftInvocation.h +++ b/tools/SourceKit/lib/SwiftLang/SwiftInvocation.h @@ -26,7 +26,7 @@ namespace SourceKit { /// Encompasses an invocation for getting an AST. This is used to control AST /// sharing among different requests. -class SwiftInvocation : public swift::ThreadSafeRefCountedBase { +class SwiftInvocation : public llvm::ThreadSafeRefCountedBase { public: ~SwiftInvocation(); diff --git a/unittests/Basic/CMakeLists.txt b/unittests/Basic/CMakeLists.txt index 401dd043eb29b..e254411857a4d 100644 --- a/unittests/Basic/CMakeLists.txt +++ b/unittests/Basic/CMakeLists.txt @@ -19,6 +19,7 @@ add_swift_unittest(SwiftBasicTests SourceManager.cpp StringExtrasTest.cpp SuccessorMapTest.cpp + ThreadSafeRefCntPointerTests.cpp TreeScopedHashTableTests.cpp Unicode.cpp ${generated_tests} diff --git a/unittests/Basic/ThreadSafeRefCntPointerTests.cpp b/unittests/Basic/ThreadSafeRefCntPointerTests.cpp new file mode 100644 index 0000000000000..434be42d0a959 --- /dev/null +++ b/unittests/Basic/ThreadSafeRefCntPointerTests.cpp @@ -0,0 +1,51 @@ +#include "swift/Basic/ThreadSafeRefCounted.h" +#include "llvm/ADT/IntrusiveRefCntPtr.h" +#include "gtest/gtest.h" + +using llvm::IntrusiveRefCntPtr; + +struct TestRelease : llvm::ThreadSafeRefCountedBase { + bool &destroy; + TestRelease(bool &destroy) : destroy(destroy) {} + ~TestRelease() { destroy = true; } +}; + +TEST(ThreadSafeRefCountedBase, ReleaseSimple) { + bool destroyed = false; + { + IntrusiveRefCntPtr ref = new TestRelease(destroyed); + } + EXPECT_TRUE(destroyed); +} +TEST(ThreadSafeRefCountedBase, Release) { + bool destroyed = false; + { + IntrusiveRefCntPtr ref = new TestRelease(destroyed); + ref->Retain(); + ref->Release(); + } + EXPECT_TRUE(destroyed); +} + +struct TestReleaseVPTR : swift::ThreadSafeRefCountedBaseVPTR { + bool &destroy; + TestReleaseVPTR(bool &destroy) : destroy(destroy) {} + virtual ~TestReleaseVPTR() { destroy = true; } +}; + +TEST(ThreadSafeRefCountedBaseVPTR, ReleaseSimple) { + bool destroyed = false; + { + IntrusiveRefCntPtr ref = new TestReleaseVPTR(destroyed); + } + EXPECT_TRUE(destroyed); +} +TEST(ThreadSafeRefCountedBaseVPTR, Release) { + bool destroyed = false; + { + IntrusiveRefCntPtr ref = new TestReleaseVPTR(destroyed); + ref->Retain(); + ref->Release(); + } + EXPECT_TRUE(destroyed); +}