From ebc6c28a252915e6689ba7559bdc17eed0eea10c Mon Sep 17 00:00:00 2001 From: Kiran Chandramohan Date: Fri, 24 May 2024 11:21:35 +0100 Subject: [PATCH] =?UTF-8?q?Revert=20"[mlir]=20Fix=20race=20condition=20int?= =?UTF-8?q?roduced=20in=20ThreadLocalCache=20(#93=E2=80=A6=20(#93290)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …280)" This reverts commit 6977bfb57c3efb9488aef463cd7ea521fd25a067. --- mlir/include/mlir/Support/ThreadLocalCache.h | 97 +++++--------------- 1 file changed, 25 insertions(+), 72 deletions(-) diff --git a/mlir/include/mlir/Support/ThreadLocalCache.h b/mlir/include/mlir/Support/ThreadLocalCache.h index fe6c6fa3cf6bd9..d19257bf6e25e0 100644 --- a/mlir/include/mlir/Support/ThreadLocalCache.h +++ b/mlir/include/mlir/Support/ThreadLocalCache.h @@ -16,6 +16,7 @@ #include "mlir/Support/LLVM.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/Support/ManagedStatic.h" #include "llvm/Support/Mutex.h" namespace mlir { @@ -24,80 +25,28 @@ namespace mlir { /// cache has very large lock contention. template class ThreadLocalCache { - struct PerInstanceState; - - /// The "observer" is owned by a thread-local cache instance. It is - /// constructed the first time a `ThreadLocalCache` instance is accessed by a - /// thread, unless `perInstanceState` happens to get re-allocated to the same - /// address as a previous one. This class is destructed the thread in which - /// the `thread_local` cache lives is destroyed. - /// - /// This class is called the "observer" because while values cached in - /// thread-local caches are owned by `PerInstanceState`, a reference is stored - /// via this class in the TLC. With a double pointer, it knows when the - /// referenced value has been destroyed. - struct Observer { - /// This is the double pointer, explicitly allocated because we need to keep - /// the address stable if the TLC map re-allocates. It is owned by the - /// observer and shared with the value owner. - std::shared_ptr ptr = std::make_shared(nullptr); - /// Because `Owner` living inside `PerInstanceState` contains a reference to - /// the double pointer, and livkewise this class contains a reference to the - /// value, we need to synchronize destruction of the TLC and the - /// `PerInstanceState` to avoid racing. This weak pointer is acquired during - /// TLC destruction if the `PerInstanceState` hasn't entered its destructor - /// yet, and prevents it from happening. - std::weak_ptr keepalive; - }; - - /// This struct owns the cache entries. It contains a reference back to the - /// reference inside the cache so that it can be written to null to indicate - /// that the cache entry is invalidated. It needs to do this because - /// `perInstanceState` could get re-allocated to the same pointer and we don't - /// remove entries from the TLC when it is deallocated. Thus, we have to reset - /// the TLC entries to a starting state in case the `ThreadLocalCache` lives - /// shorter than the threads. - struct Owner { - /// Save a pointer to the reference and write it to the newly created entry. - Owner(Observer &observer) - : value(std::make_unique()), ptrRef(observer.ptr) { - *observer.ptr = value.get(); - } - ~Owner() { - if (std::shared_ptr ptr = ptrRef.lock()) - *ptr = nullptr; - } - - Owner(Owner &&) = default; - Owner &operator=(Owner &&) = default; - - std::unique_ptr value; - std::weak_ptr ptrRef; - }; - // Keep a separate shared_ptr protected state that can be acquired atomically // instead of using shared_ptr's for each value. This avoids a problem // where the instance shared_ptr is locked() successfully, and then the // ThreadLocalCache gets destroyed before remove() can be called successfully. struct PerInstanceState { - /// Remove the given value entry. This is called when a thread local cache - /// is destructing but still contains references to values owned by the - /// `PerInstanceState`. Removal is required because it prevents writeback to - /// a pointer that was deallocated. + /// Remove the given value entry. This is generally called when a thread + /// local cache is destructing. void remove(ValueT *value) { // Erase the found value directly, because it is guaranteed to be in the // list. llvm::sys::SmartScopedLock threadInstanceLock(instanceMutex); - auto it = llvm::find_if(instances, [&](Owner &instance) { - return instance.value.get() == value; - }); + auto it = + llvm::find_if(instances, [&](std::unique_ptr &instance) { + return instance.get() == value; + }); assert(it != instances.end() && "expected value to exist in cache"); instances.erase(it); } /// Owning pointers to all of the values that have been constructed for this /// object in the static cache. - SmallVector instances; + SmallVector, 1> instances; /// A mutex used when a new thread instance has been added to the cache for /// this object. @@ -108,14 +57,14 @@ class ThreadLocalCache { /// instance of the non-static cache and a weak reference to an instance of /// ValueT. We use a weak reference here so that the object can be destroyed /// without needing to lock access to the cache itself. - struct CacheType : public llvm::SmallDenseMap { + struct CacheType + : public llvm::SmallDenseMap, ValueT *>> { ~CacheType() { - // Remove the values of this cache that haven't already expired. This is - // required because if we don't remove them, they will contain a reference - // back to the data here that is being destroyed. - for (auto &[instance, observer] : *this) - if (std::shared_ptr state = observer.keepalive.lock()) - state->remove(*observer.ptr); + // Remove the values of this cache that haven't already expired. + for (auto &it : *this) + if (std::shared_ptr value = it.second.first.lock()) + it.first->remove(value.get()); } /// Clear out any unused entries within the map. This method is not @@ -123,7 +72,7 @@ class ThreadLocalCache { void clearExpiredEntries() { for (auto it = this->begin(), e = this->end(); it != e;) { auto curIt = it++; - if (!*curIt->second.ptr) + if (curIt->second.first.expired()) this->erase(curIt); } } @@ -140,23 +89,27 @@ class ThreadLocalCache { ValueT &get() { // Check for an already existing instance for this thread. CacheType &staticCache = getStaticCache(); - Observer &threadInstance = staticCache[perInstanceState.get()]; - if (ValueT *value = *threadInstance.ptr) + std::pair, ValueT *> &threadInstance = + staticCache[perInstanceState.get()]; + if (ValueT *value = threadInstance.second) return *value; // Otherwise, create a new instance for this thread. { llvm::sys::SmartScopedLock threadInstanceLock( perInstanceState->instanceMutex); - perInstanceState->instances.emplace_back(threadInstance); + threadInstance.second = + perInstanceState->instances.emplace_back(std::make_unique()) + .get(); } - threadInstance.keepalive = perInstanceState; + threadInstance.first = + std::shared_ptr(perInstanceState, threadInstance.second); // Before returning the new instance, take the chance to clear out any used // entries in the static map. The cache is only cleared within the same // thread to remove the need to lock the cache itself. staticCache.clearExpiredEntries(); - return **threadInstance.ptr; + return *threadInstance.second; } ValueT &operator*() { return get(); } ValueT *operator->() { return &get(); }