Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 3 additions & 31 deletions include/swift/Basic/ThreadSafeRefCounted.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,37 +15,10 @@

#include <atomic>
#include <cassert>
#include "llvm/ADT/IntrusiveRefCntPtr.h"

namespace swift {

/// A thread-safe version of \c llvm::RefCountedBase.
///
/// A generic base class for objects that wish to have their lifetimes managed
/// using reference counts. Classes subclass \c ThreadSafeRefCountedBase to
/// obtain such functionality, and are typically handled with
/// \c IntrusiveRefCntPtr "smart pointers" which automatically handle the
/// management of reference counts.
/// FIXME: This should eventually move to llvm.
template <class Derived>
class ThreadSafeRefCountedBase {
mutable std::atomic<unsigned> ref_cnt;

protected:
ThreadSafeRefCountedBase() : ref_cnt(0) {}

public:
void Retain() const {
ref_cnt.fetch_add(1, std::memory_order_acq_rel);
}

void Release() const {
int refCount =
static_cast<int>(ref_cnt.fetch_sub(1, std::memory_order_acq_rel));
assert(refCount >= 0 && "Reference count was already zero.");
if (refCount == 0) delete static_cast<const Derived*>(this);
}
};

/// A class that has the same function as \c ThreadSafeRefCountedBase, but with
/// a virtual destructor.
///
Expand All @@ -62,12 +35,11 @@ class ThreadSafeRefCountedBaseVPTR {

public:
void Retain() const {
ref_cnt.fetch_add(1, std::memory_order_acq_rel);
ref_cnt += 1;
}

void Release() const {
int refCount =
static_cast<int>(ref_cnt.fetch_sub(1, std::memory_order_acq_rel));
int refCount = static_cast<int>(--ref_cnt);
assert(refCount >= 0 && "Reference count was already zero.");
if (refCount == 0) delete this;
}
Expand Down
2 changes: 1 addition & 1 deletion include/swift/IDE/CodeCompletionCache.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class CodeCompletionCache {
}
};

struct Value : public ThreadSafeRefCountedBase<Value> {
struct Value : public llvm::ThreadSafeRefCountedBase<Value> {
llvm::sys::TimeValue ModuleModificationTime;
CodeCompletionResultSink Sink;
};
Expand Down
4 changes: 2 additions & 2 deletions tools/SourceKit/include/SourceKit/Core/LLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ namespace llvm {
// Reference counting.
template <typename T> class IntrusiveRefCntPtr;
template <typename T> struct IntrusiveRefCntPtrInfo;
template <class Derived> class ThreadSafeRefCountedBase;

class raw_ostream;
// TODO: DenseMap, ...
Expand All @@ -69,7 +70,6 @@ namespace llvm {
}

namespace swift {
template <class Derived> class ThreadSafeRefCountedBase;
class ThreadSafeRefCountedBaseVPTR;
}

Expand All @@ -95,7 +95,7 @@ namespace SourceKit {
// Reference counting.
using llvm::IntrusiveRefCntPtr;
using llvm::IntrusiveRefCntPtrInfo;
using swift::ThreadSafeRefCountedBase;
using llvm::ThreadSafeRefCountedBase;
using swift::ThreadSafeRefCountedBaseVPTR;
template <typename T> class ThreadSafeRefCntPtr;

Expand Down
2 changes: 1 addition & 1 deletion tools/SourceKit/lib/SwiftLang/SwiftInterfaceGenContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ namespace SourceKit {
typedef IntrusiveRefCntPtr<ASTUnit> ASTUnitRef;

class SwiftInterfaceGenContext :
public swift::ThreadSafeRefCountedBase<SwiftInterfaceGenContext> {
public llvm::ThreadSafeRefCountedBase<SwiftInterfaceGenContext> {
public:
static SwiftInterfaceGenContextRef create(StringRef DocumentName,
bool IsModule,
Expand Down
2 changes: 1 addition & 1 deletion tools/SourceKit/lib/SwiftLang/SwiftInvocation.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace SourceKit {

/// Encompasses an invocation for getting an AST. This is used to control AST
/// sharing among different requests.
class SwiftInvocation : public swift::ThreadSafeRefCountedBase<SwiftInvocation> {
class SwiftInvocation : public llvm::ThreadSafeRefCountedBase<SwiftInvocation> {
public:
~SwiftInvocation();

Expand Down
1 change: 1 addition & 0 deletions unittests/Basic/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ add_swift_unittest(SwiftBasicTests
SourceManager.cpp
StringExtrasTest.cpp
SuccessorMapTest.cpp
ThreadSafeRefCntPointerTests.cpp
TreeScopedHashTableTests.cpp
Unicode.cpp
${generated_tests}
Expand Down
51 changes: 51 additions & 0 deletions unittests/Basic/ThreadSafeRefCntPointerTests.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#include "swift/Basic/ThreadSafeRefCounted.h"
#include "llvm/ADT/IntrusiveRefCntPtr.h"
#include "gtest/gtest.h"

using llvm::IntrusiveRefCntPtr;

struct TestRelease : llvm::ThreadSafeRefCountedBase<TestRelease> {
bool &destroy;
TestRelease(bool &destroy) : destroy(destroy) {}
~TestRelease() { destroy = true; }
};

TEST(ThreadSafeRefCountedBase, ReleaseSimple) {
bool destroyed = false;
{
IntrusiveRefCntPtr<TestRelease> ref = new TestRelease(destroyed);
}
EXPECT_TRUE(destroyed);
}
TEST(ThreadSafeRefCountedBase, Release) {
bool destroyed = false;
{
IntrusiveRefCntPtr<TestRelease> ref = new TestRelease(destroyed);
ref->Retain();
ref->Release();
}
EXPECT_TRUE(destroyed);
}

struct TestReleaseVPTR : swift::ThreadSafeRefCountedBaseVPTR {
bool &destroy;
TestReleaseVPTR(bool &destroy) : destroy(destroy) {}
virtual ~TestReleaseVPTR() { destroy = true; }
};

TEST(ThreadSafeRefCountedBaseVPTR, ReleaseSimple) {
bool destroyed = false;
{
IntrusiveRefCntPtr<TestReleaseVPTR> ref = new TestReleaseVPTR(destroyed);
}
EXPECT_TRUE(destroyed);
}
TEST(ThreadSafeRefCountedBaseVPTR, Release) {
bool destroyed = false;
{
IntrusiveRefCntPtr<TestReleaseVPTR> ref = new TestReleaseVPTR(destroyed);
ref->Retain();
ref->Release();
}
EXPECT_TRUE(destroyed);
}