diff --git a/include/swift/ABI/MetadataValues.h b/include/swift/ABI/MetadataValues.h index 84f59a9d001e5..dd52a922a28ed 100644 --- a/include/swift/ABI/MetadataValues.h +++ b/include/swift/ABI/MetadataValues.h @@ -1936,9 +1936,9 @@ class JobFlags : public FlagSet { // Kind-specific flags. - Task_IsChildTask = 24, - Task_IsFuture = 25, - Task_IsTaskGroup = 26, + Task_IsChildTask = 24, + Task_IsFuture = 25, + Task_IsTaskGroup = 26 }; explicit JobFlags(size_t bits) : FlagSet(bits) {} @@ -1965,11 +1965,9 @@ class JobFlags : public FlagSet { FLAGSET_DEFINE_FLAG_ACCESSORS(Task_IsFuture, task_isFuture, task_setIsFuture) - FLAGSET_DEFINE_FLAG_ACCESSORS(Task_IsTaskGroup, task_isTaskGroup, task_setIsTaskGroup) - }; /// Kinds of task status record. diff --git a/include/swift/ABI/Task.h b/include/swift/ABI/Task.h index 23b7b68ebee6c..e72ff9ab2d111 100644 --- a/include/swift/ABI/Task.h +++ b/include/swift/ABI/Task.h @@ -137,14 +137,15 @@ class ActiveTaskStatus { /// ### Fragments /// An AsyncTask may have the following fragments: /// -/// +------------------+ -/// | childFragment? | -/// | groupFragment? | -/// | futureFragment? |* -/// +------------------+ +/// +--------------------------+ +/// | childFragment? | +/// | taskLocalValuesFragment? | +/// | groupFragment? | +/// | futureFragment? |* +/// +--------------------------+ /// -/// The future fragment is dynamic in size, based on the future result type -/// it can hold, and thus must be the *last* fragment. +/// * The future fragment is dynamic in size, based on the future result type +/// it can hold, and thus must be the *last* fragment. class AsyncTask : public HeapObject, public Job { public: /// The context for resuming the job. When a task is scheduled @@ -175,13 +176,15 @@ class AsyncTask : public HeapObject, public Job { void run(ExecutorRef currentExecutor) { ResumeTask(this, currentExecutor, ResumeContext); } - + /// Check whether this task has been cancelled. /// Checking this is, of course, inherently race-prone on its own. bool isCancelled() const { return Status.load(std::memory_order_relaxed).isCancelled(); } + // ==== Child Fragment ------------------------------------------------------- + /// A fragment of an async task structure that happens to be a child task. class ChildFragment { /// The parent task of this task. @@ -205,14 +208,252 @@ class AsyncTask : public HeapObject, public Job { } }; - // TODO: rename? all other functions are `is...` rather than `has...Fragment` - bool hasChildFragment() const { return Flags.task_isChildTask(); } + bool hasChildFragment() const { + return Flags.task_isChildTask(); + } ChildFragment *childFragment() { assert(hasChildFragment()); return reinterpret_cast(this + 1); } + // ==== Task Locals Values --------------------------------------------------- + + class TaskLocalValuesFragment { + public: + /// Type of the pointed at `next` task local item. + enum class NextLinkType : uintptr_t { + /// This task is known to be a "terminal" node in the lookup of task locals. + /// In other words, even if it had a parent, the parent (and its parents) + /// are known to not contain any any more task locals, and thus any further + /// search beyond this task. + IsTerminal = 0b00, + /// The storage pointer points at the next TaskLocalChainItem in this task. + IsNext = 0b01, + /// The storage pointer points at a parent AsyncTask, in which we should + /// continue the lookup. + /// + /// Note that this may not necessarily be the same as the task's parent + /// task -- we may point to a super-parent if we know / that the parent + /// does not "contribute" any task local values. This is to speed up + /// lookups by skipping empty parent tasks during get(), and explained + /// in depth in `createParentLink`. + IsParent = 0b11 + }; + + /// Values must match `TaskLocalInheritance` declared in `TaskLocal.swift`. + enum class TaskLocalInheritance : uint8_t { + Default = 0, + Never = 1 + }; + + class TaskLocalItem { + private: + /// Mask used for the low status bits in a task local chain item. + static const uintptr_t statusMask = 0x03; + + /// Pointer to the next task local item; be it in this task or in a parent. + /// Low bits encode `NextLinkType`. + /// TaskLocalItem *next = nullptr; + uintptr_t next; + + public: + /// The type of the key with which this value is associated. + const Metadata *keyType; + /// The type of the value stored by this item. + const Metadata *valueType; + + // Trailing storage for the value itself. The storage will be + // uninitialized or contain an instance of \c valueType. + + private: + explicit TaskLocalItem(const Metadata *keyType, const Metadata *valueType) + : keyType(keyType), + valueType(valueType), + next(0) { } + + public: + /// TaskLocalItem which does not by itself store any value, but only points + /// to the nearest task-local-value containing parent's first task item. + /// + /// This item type is used to link to the appropriate parent task's item, + /// when the current task itself does not have any task local values itself. + /// + /// When a task actually has its own task locals, it should rather point + /// to the parent's *first* task-local item in its *last* item, extending + /// the TaskLocalItem linked list into the appropriate parent. + static TaskLocalItem* createParentLink(AsyncTask *task, AsyncTask *parent) { + assert(parent); + size_t amountToAllocate = TaskLocalItem::itemSize(/*valueType*/nullptr); + // assert(amountToAllocate % MaximumAlignment == 0); // TODO: do we need this? + void *allocation = malloc(amountToAllocate); // TODO: use task-local allocator + + TaskLocalItem *item = + new(allocation) TaskLocalItem(nullptr, nullptr); + + auto parentHead = parent->localValuesFragment()->head; + if (parentHead) { + if (parentHead->isEmpty()) { + switch (parentHead->getNextLinkType()) { + case NextLinkType::IsParent: + // it has no values, and just points to its parent, + // therefore skip also skip pointing to that parent and point + // to whichever parent it was pointing to as well, it may be its + // immediate parent, or some super-parent. + item->next = reinterpret_cast(parentHead->getNext()); + static_cast(NextLinkType::IsParent); + break; + case NextLinkType::IsNext: + assert(false && "empty taskValue head in parent task, yet parent's 'head' is `IsNext`, " + "this should not happen, as it implies the parent must have stored some value."); + break; + case NextLinkType::IsTerminal: + item->next = reinterpret_cast(parentHead->getNext()); + static_cast(NextLinkType::IsTerminal); + break; + } + } else { + item->next = reinterpret_cast(parentHead) | + static_cast(NextLinkType::IsParent); + } + } else { + item->next = reinterpret_cast(parentHead) | + static_cast(NextLinkType::IsTerminal); + } + + return item; + } + + static TaskLocalItem* createLink(AsyncTask *task, + const Metadata *keyType, + const Metadata *valueType) { + assert(task); + size_t amountToAllocate = TaskLocalItem::itemSize(valueType); + // assert(amountToAllocate % MaximumAlignment == 0); // TODO: do we need this? + void *allocation = malloc(amountToAllocate); // TODO: use task-local allocator + TaskLocalItem *item = + new(allocation) TaskLocalItem(keyType, valueType); + + auto next = task->localValuesFragment()->head; + auto nextLinkType = next ? NextLinkType::IsNext : NextLinkType::IsTerminal; + item->next = reinterpret_cast(next) | + static_cast(nextLinkType); + + return item; + } + + void destroy() { + if (valueType) { + valueType->vw_destroy(getStoragePtr()); + } + } + + TaskLocalItem *getNext() { + return reinterpret_cast(next & ~statusMask); + } + + NextLinkType getNextLinkType() { + return static_cast(next & statusMask); + } + + /// Item does not contain any actual value, and is only used to point at + /// a specific parent item. + bool isEmpty() { + return !valueType; + } + + /// Retrieve a pointer to the storage of the value. + OpaqueValue *getStoragePtr() { + return reinterpret_cast( + reinterpret_cast(this) + storageOffset(valueType)); + } + + /// Compute the offset of the storage from the base of the item. + static size_t storageOffset(const Metadata *valueType) { + size_t offset = sizeof(TaskLocalItem); + if (valueType) { + size_t alignment = valueType->vw_alignment(); + return (offset + alignment - 1) & ~(alignment - 1); + } else { + return offset; + } + } + + /// Determine the size of the item given a particular value type. + static size_t itemSize(const Metadata *valueType) { + size_t offset = storageOffset(valueType); + if (valueType) { + offset += valueType->vw_size(); + } + return offset; + } + }; + + private: + /// A stack (single-linked list) of task local values. + /// + /// Once task local values within this task are traversed, the list continues + /// to the "next parent that contributes task local values," or if no such + /// parent exists it terminates with null. + /// + /// If the TaskLocalValuesFragment was allocated, it is expected that this + /// value should be NOT null; it either has own values, or at least one + /// parent that has values. If this task does not have any values, the head + /// pointer MAY immediately point at this task's parent task which has values. + /// + /// ### Concurrency + /// Access to the head is only performed from the task itself, when it + /// creates child tasks, the child during creation will inspect its parent's + /// task local value stack head, and point to it. This is done on the calling + /// task, and thus needs not to be synchronized. Subsequent traversal is + /// performed by child tasks concurrently, however they use their own + /// pointers/stack and can never mutate the parent's stack. + /// + /// The stack is only pushed/popped by the owning task, at the beginning and + /// end a `body` block of `withLocal(_:boundTo:body:)` respectively. + /// + /// Correctness of the stack strongly relies on the guarantee that tasks + /// never outline a scope in which they are created. Thanks to this, if + /// tasks are created inside the `body` of `withLocal(_:,boundTo:body:)` + /// all tasks created inside the `withLocal` body must complete before it + /// returns, as such, any child tasks potentially accessing the value stack + /// are guaranteed to be completed by the time we pop values off the stack + /// (after the body has completed). + TaskLocalItem *head = nullptr; + + public: + TaskLocalValuesFragment() {} + + void destroy(); + + /// If the parent task has task local values defined, point to in + /// the task local values chain. + void initializeLinkParent(AsyncTask* task, AsyncTask* parent); + + void pushValue(AsyncTask *task, const Metadata *keyType, + /* +1 */ OpaqueValue *value, const Metadata *valueType); + + void popValue(AsyncTask *task); + + OpaqueValue* get(const Metadata *keType, TaskLocalInheritance inheritance); + }; + + TaskLocalValuesFragment *localValuesFragment() { + auto offset = reinterpret_cast(this); + offset += sizeof(AsyncTask); + + if (hasChildFragment()) { + offset += sizeof(ChildFragment); + } + + return reinterpret_cast(offset); + } + + OpaqueValue* localValueGet(const Metadata *keyType, + TaskLocalValuesFragment::TaskLocalInheritance inheritance) { + return localValuesFragment()->get(keyType, inheritance); + } + // ==== TaskGroup ------------------------------------------------------------ class GroupFragment { @@ -516,12 +757,16 @@ class AsyncTask : public HeapObject, public Job { GroupFragment *groupFragment() { assert(isTaskGroup()); + auto offset = reinterpret_cast(this); + offset += sizeof(AsyncTask); + if (hasChildFragment()) { - return reinterpret_cast( - reinterpret_cast(this + 1) + 1); + offset += sizeof(ChildFragment); } - return reinterpret_cast(this + 1); + offset += sizeof(TaskLocalValuesFragment); + + return reinterpret_cast(offset); } /// Offer result of a task into this channel. @@ -647,13 +892,15 @@ class AsyncTask : public HeapObject, public Job { FutureFragment *futureFragment() { assert(isFuture()); - auto offset = reinterpret_cast(this); // TODO: char* instead? + auto offset = reinterpret_cast(this); offset += sizeof(AsyncTask); if (hasChildFragment()) { offset += sizeof(ChildFragment); } + offset += sizeof(TaskLocalValuesFragment); + if (isTaskGroup()) { offset += sizeof(GroupFragment); } diff --git a/include/swift/Runtime/Concurrency.h b/include/swift/Runtime/Concurrency.h index 403e5dd28f1d4..7971d921935b1 100644 --- a/include/swift/Runtime/Concurrency.h +++ b/include/swift/Runtime/Concurrency.h @@ -225,6 +225,55 @@ size_t swift_task_getJobFlags(AsyncTask* task); SWIFT_EXPORT_FROM(swift_Concurrency) SWIFT_CC(swift) bool swift_task_isCancelled(AsyncTask* task); +using TaskLocalValuesFragment = AsyncTask::TaskLocalValuesFragment; + +/// Get a task local value from the passed in task. Its Swift signature is +/// +/// \code +/// func _taskLocalValueGet( +/// _ task: Builtin.NativeObject, +/// keyType: Any.Type /*Key.Type*/, +/// inheritance: UInt8/*TaskLocalInheritance*/ +/// ) -> UnsafeMutableRawPointer? where Key: TaskLocalKey +/// \endcode +SWIFT_EXPORT_FROM(swift_Concurrency) SWIFT_CC(swift) +OpaqueValue* swift_task_localValueGet(AsyncTask* task, + const Metadata *keyType, + TaskLocalValuesFragment::TaskLocalInheritance inheritance); + +/// Add a task local value to the passed in task. +/// +/// This must be only invoked by the task itself to avoid concurrent writes. +/// +/// Its Swift signature is +/// +/// \code +/// public func _taskLocalValuePush( +/// _ task: Builtin.NativeObject, +/// keyType: Any.Type/*Key.Type*/, +/// value: __owned Value +/// ) +/// \endcode +SWIFT_EXPORT_FROM(swift_Concurrency) SWIFT_CC(swift) +void swift_task_localValuePush(AsyncTask* task, + const Metadata *keyType, + /* +1 */ OpaqueValue *value, + const Metadata *valueType); + +/// Remove task a local binding from the task local values stack. +/// +/// This must be only invoked by the task itself to avoid concurrent writes. +/// +/// Its Swift signature is +/// +/// \code +/// public func _taskLocalValuePop( +/// _ task: Builtin.NativeObject +/// ) +/// \endcode +SWIFT_EXPORT_FROM(swift_Concurrency) SWIFT_CC(swift) +void swift_task_localValuePop(AsyncTask* task); + /// This should have the same representation as an enum like this: /// enum NearestTaskDeadline { /// case none diff --git a/stdlib/public/Concurrency/CMakeLists.txt b/stdlib/public/Concurrency/CMakeLists.txt index ea98ad19b5513..060de9c0c50e9 100644 --- a/stdlib/public/Concurrency/CMakeLists.txt +++ b/stdlib/public/Concurrency/CMakeLists.txt @@ -63,6 +63,8 @@ add_swift_target_library(swift_Concurrency ${SWIFT_STDLIB_LIBRARY_BUILD_TYPES} I TaskStatus.cpp TaskGroup.cpp TaskGroup.swift + TaskLocal.cpp + TaskLocal.swift Mutex.cpp ${swift_concurrency_objc_sources} diff --git a/stdlib/public/Concurrency/Task.cpp b/stdlib/public/Concurrency/Task.cpp index bc550decf32d5..ceb87e9493c5d 100644 --- a/stdlib/public/Concurrency/Task.cpp +++ b/stdlib/public/Concurrency/Task.cpp @@ -32,6 +32,8 @@ using namespace swift; using FutureFragment = AsyncTask::FutureFragment; using GroupFragment = AsyncTask::GroupFragment; +using TaskLocalValuesFragment = AsyncTask::TaskLocalValuesFragment; +using TaskLocalInheritance = AsyncTask::TaskLocalValuesFragment::TaskLocalInheritance; void FutureFragment::destroy() { auto queueHead = waitQueue.load(std::memory_order_acquire); @@ -145,6 +147,9 @@ static void destroyTask(SWIFT_CONTEXT HeapObject *obj) { task->futureFragment()->destroy(); } + // release any objects potentially held as task local values. + task->localValuesFragment()->destroy(); + // The task execution itself should always hold a reference to it, so // if we get here, we know the task has finished running, which means // swift_task_complete should have been run, which will have torn down @@ -233,6 +238,8 @@ AsyncTaskAndContext swift::swift_task_create_future_f( headerSize += sizeof(AsyncTask::ChildFragment); } + headerSize += sizeof(AsyncTask::TaskLocalValuesFragment); + if (flags.task_isTaskGroup()) { headerSize += sizeof(AsyncTask::GroupFragment); } @@ -268,12 +275,16 @@ AsyncTaskAndContext swift::swift_task_create_future_f( new (childFragment) AsyncTask::ChildFragment(parent); } - // Initialize the channel fragment if applicable. - if (flags.task_isTaskGroup()) { + auto taskLocalsFragment = task->localValuesFragment(); + new (taskLocalsFragment) AsyncTask::TaskLocalValuesFragment(); + taskLocalsFragment->initializeLinkParent(task, parent); + + // Initialize the task group fragment if applicable. + if (flags.task_isTaskGroup()) { auto groupFragment = task->groupFragment(); new (groupFragment) GroupFragment(); } - + // Initialize the future fragment if applicable. if (futureResultType) { assert(task->isFuture()); @@ -300,8 +311,7 @@ AsyncTaskAndContext swift::swift_task_create_future_f( initialContext->Flags.setShouldNotDeallocateInCallee(true); // Initialize the task-local allocator. - // TODO: consider providing an initial pre-allocated first slab to the - // allocator. + // TODO: consider providing an initial pre-allocated first slab to the allocator. _swift_task_alloc_initialize(task); return {task, initialContext}; @@ -464,8 +474,8 @@ void swift::swift_task_runAndBlockThread(const void *function, RunAndBlockSemaphore semaphore; // Set up a task that runs the runAndBlock async function above. - auto pair = swift_task_create_f(JobFlags(JobKind::Task, - JobPriority::Default), + auto flags = JobFlags(JobKind::Task, JobPriority::Default); + auto pair = swift_task_create_f(flags, /*parent*/ nullptr, &runAndBlock_start, sizeof(RunAndBlockContext)); @@ -485,6 +495,23 @@ size_t swift::swift_task_getJobFlags(AsyncTask *task) { return task->Flags.getOpaqueValue(); } +void swift::swift_task_localValuePush(AsyncTask *task, + const Metadata *keyType, + /* +1 */ OpaqueValue *value, + const Metadata *valueType) { + task->localValuesFragment()->pushValue(task, keyType, value, valueType); +} + +void swift::swift_task_localValuePop(AsyncTask *task) { + task->localValuesFragment()->popValue(task); +} + +OpaqueValue* swift::swift_task_localValueGet(AsyncTask *task, + const Metadata *keyType, + TaskLocalInheritance inheritance) { + return task->localValueGet(keyType, inheritance); +} + namespace { /// Structure that gets filled in when a task is suspended by `withUnsafeContinuation`. diff --git a/stdlib/public/Concurrency/Task.swift b/stdlib/public/Concurrency/Task.swift index f5af87802ea3f..fe0d3047d8df5 100644 --- a/stdlib/public/Concurrency/Task.swift +++ b/stdlib/public/Concurrency/Task.swift @@ -324,7 +324,7 @@ extension Task { } } - /// Whether this is a channel. + /// Whether this is a task group. var isTaskGroup: Bool { get { (bits & (1 << 26)) != 0 @@ -339,6 +339,21 @@ extension Task { } } + /// Whether this (or its parents) have task local values. + var hasLocalValues: Bool { + get { + (bits & (1 << 27)) != 0 + } + + set { + if newValue { + bits = bits | 1 << 27 + } else { + bits = (bits & ~(1 << 27)) + } + } + } + } } @@ -381,6 +396,7 @@ extension Task { priority: Priority = .default, startingOn executor: ExecutorRef? = nil, operation: @concurrent @escaping () async -> T + // TODO: Allow inheriting task-locals? ) -> Handle { assert(executor == nil, "Custom executor support is not implemented yet.") // FIXME @@ -639,9 +655,17 @@ public func _runChildTask( return task } +class StringLike: CustomStringConvertible { + let value: String + init(_ value: String) { + self.value = value + } + var description: String { value } +} public func _runGroupChildTask( overridingPriority priorityOverride: Task.Priority? = nil, + withLocalValues hasLocalValues: Bool = false, operation: @concurrent @escaping () async throws -> T ) async -> Builtin.NativeObject { let currentTask = Builtin.getCurrentAsyncTask() diff --git a/stdlib/public/Concurrency/TaskGroup.cpp b/stdlib/public/Concurrency/TaskGroup.cpp index 0446713a5dd4a..797d92ed7d35a 100644 --- a/stdlib/public/Concurrency/TaskGroup.cpp +++ b/stdlib/public/Concurrency/TaskGroup.cpp @@ -1,4 +1,4 @@ -//===--- TaskGroup.cpp - Task Group internal message channel ------------===// +//===--- TaskGroup.cpp - Task Groups --------------------------------------===// // // This source file is part of the Swift.org open source project // @@ -10,7 +10,7 @@ // //===----------------------------------------------------------------------===// // -// Object management for async child tasks that are children of a task group. +// Object management for child tasks that are children of a task group. // //===----------------------------------------------------------------------===// diff --git a/stdlib/public/Concurrency/TaskLocal.cpp b/stdlib/public/Concurrency/TaskLocal.cpp new file mode 100644 index 0000000000000..2afe4ed629774 --- /dev/null +++ b/stdlib/public/Concurrency/TaskLocal.cpp @@ -0,0 +1,97 @@ +//===--- TaskLocal.cpp - Task Local Values --------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2020 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +#include "swift/Runtime/Concurrency.h" +#include "swift/ABI/Task.h" +#include "swift/ABI/Metadata.h" + +using namespace swift; +using TaskLocalValuesFragment = AsyncTask::TaskLocalValuesFragment; + +// ============================================================================= +// ==== destroy ---------------------------------------------------------------- + +void TaskLocalValuesFragment::destroy() { + auto item = head; + head = nullptr; + TaskLocalItem *next; + while (item) { + switch (item->getNextLinkType()) { + case TaskLocalValuesFragment::NextLinkType::IsNext: + next = item->getNext(); + item->destroy(); + free(item); + item = next; + break; + + case TaskLocalValuesFragment::NextLinkType::IsParent: + case TaskLocalValuesFragment::NextLinkType::IsTerminal: + // we're done here, we must not destroy values owned by the parent task. + return; + } + } +} + +// ============================================================================= +// ==== Initialization --------------------------------------------------------- + +void TaskLocalValuesFragment::initializeLinkParent(AsyncTask* task, + AsyncTask* parent) { + assert(!head && "fragment was already initialized"); + if (parent) { + head = TaskLocalItem::createParentLink(task, parent); + } +} + +// ============================================================================= +// ==== push / pop / get ------------------------------------------------------- + +void TaskLocalValuesFragment::pushValue(AsyncTask *task, + const Metadata *keyType, + /* +1 */ OpaqueValue *value, + const Metadata *valueType) { + assert(value && "Task local value must not be nil"); + + auto item = TaskLocalItem::createLink(task, keyType, valueType); + valueType->vw_initializeWithTake(item->getStoragePtr(), value); + head = item; +} + +void TaskLocalValuesFragment::popValue(AsyncTask *task) { + assert(head && "attempted to pop value off empty task-local stack"); + head->destroy(); + head = head->getNext(); +} + +OpaqueValue *TaskLocalValuesFragment::get( + const Metadata *keyType, + const TaskLocalInheritance inherit) { + assert(keyType && "Task.Local key must not be null."); + + auto item = head; + while (item) { + if (item->keyType == keyType) { + return item->getStoragePtr(); + } + + // if the key is an `inherit = .never` type, we stop our search the first + // time we would be jumping to a parent task to continue the search. + if (item->getNextLinkType() == NextLinkType::IsParent && + inherit == TaskLocalInheritance::Never) + return nullptr; + + item = item->getNext(); + } + + return nullptr; +} + diff --git a/stdlib/public/Concurrency/TaskLocal.swift b/stdlib/public/Concurrency/TaskLocal.swift new file mode 100644 index 0000000000000..90f40e9ee37ed --- /dev/null +++ b/stdlib/public/Concurrency/TaskLocal.swift @@ -0,0 +1,178 @@ +////===----------------------------------------------------------------------===// +//// +//// This source file is part of the Swift.org open source project +//// +//// Copyright (c) 2020 Apple Inc. and the Swift project authors +//// Licensed under Apache License v2.0 with Runtime Library Exception +//// +//// See https://swift.org/LICENSE.txt for license information +//// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +//// +////===----------------------------------------------------------------------===// + +import Swift +@_implementationOnly import _SwiftConcurrencyShims + +/// Namespace for declaring `TaskLocalKey`s. +public enum TaskLocalValues {} + +/// A `TaskLocalKey` is used to identify, bind and get a task local value from +/// a `Task` in which a function is currently executing. +/// +/// - SeeAlso: `Task.withLocal(_:boundTo:operation:)` +/// - SeeAlso: `Task.local(_:)` +public protocol TaskLocalKey { + /// The type of `Value` uniquely identified by this key. + associatedtype Value + + /// If a task local value is not present in a given context, its `defaultValue` + /// will be returned instead. + /// + /// A common pattern is to use an `Optional` type and use `nil` as default value, + /// if the type itself does not have a good "undefined" or "zero" value that could + /// be used here. + static var defaultValue: Value { get } + + /// Allows configuring specialized inheritance strategies for task local values. + /// + /// By default, task local values are accessible by the current or any of its + /// child tasks (with this rule applying recursively). + /// + /// Some, rare yet important, use-cases may require specialized inheritance + /// strategies, and this property allows them to configure these for their keys. + static var inherit: TaskLocalInheritance { get } +} + +extension TaskLocalKey { + public static var inherit: TaskLocalInheritance { .default } +} + +/// Determines task local value behavior in child tasks. +// TODO: should likely remain extensible +public enum TaskLocalInheritance: UInt8, Equatable { + /// The default inheritance strategy. + /// + /// Task local values whose keys are `default` inherited are available to the + /// task which declared them, as well as recursively by any child tasks + case `default` = 0 + + /// Causes task local values to never be inherited. + /// If the parent task has a value bound using this key, and a child task + /// attempts to look up a value of that key, it will return `defaultValue`. + case never = 1 +} + +extension Task { + + /// Read a task-local value, bound to the specified key. + /// + /// - Parameter keyPath: key path to the `TaskLocalKey` to be used for lookup + /// - Returns: the value bound to the key, or its default value it if was not + /// bound in the current (or any parent) tasks. + public static func local(_ keyPath: KeyPath) + async -> Key.Value where Key: TaskLocalKey { + let task = Builtin.getCurrentAsyncTask() + + let value = _taskLocalValueGet(task, keyType: Key.self, inheritance: Key.inherit.rawValue) + guard let rawValue = value else { + return Key.defaultValue + } + + // Take the value; The type should be correct by construction + let storagePtr = + rawValue.bindMemory(to: Key.Value.self, capacity: 1) + return UnsafeMutablePointer(mutating: storagePtr).pointee + } + + /// Bind the task local key to the given value for the scope of the `body` function. + /// Any child tasks spawned within this scope will inherit the binding. + /// + /// - Parameters: + /// - keyPath: key path to the `TaskLocalKey` to be used for lookup + /// - value: + /// - body: + /// - Returns: the value returned by the `body` function. + public static func withLocal( + _ keyPath: KeyPath, + boundTo value: Key.Value, + body: @escaping () async -> BodyResult + ) async -> BodyResult where Key: TaskLocalKey { + let task = Builtin.getCurrentAsyncTask() + + _taskLocalValuePush(task, keyType: Key.self, value: value) + + defer { + _taskLocalValuePop(task) + } + + return await body() + } + + /// Bind the task local key to the given value for the scope of the `body` function. + /// Any child tasks spawned within this scope will inherit the binding. + /// + /// - Parameters: + /// - key: + /// - value: + /// - body: + /// - Returns: the value returned by the `body` function, or throws. + public static func withLocal( + _ keyPath: KeyPath, + boundTo value: Key.Value, + body: @escaping () async throws -> BodyResult + ) async throws -> BodyResult where Key: TaskLocalKey { + let task = Builtin.getCurrentAsyncTask() + + _taskLocalValuePush(task, keyType: Key.self, value: value) + + defer { + _taskLocalValuePop(task) + } + + return try! await body() + } +} + +// ==== ------------------------------------------------------------------------ + +/// A type-erased `TaskLocalKey` used when iterating through the `Baggage` using its `forEach` method. +struct AnyTaskLocalKey { + let keyType: Any.Type + let valueType: Any.Type + + init(_: Key.Type) where Key: TaskLocalKey { + self.keyType = Key.self + self.valueType = Key.Value.self + } +} + +extension AnyTaskLocalKey: Hashable { + static func ==(lhs: AnyTaskLocalKey, rhs: AnyTaskLocalKey) -> Bool { + return ObjectIdentifier(lhs.keyType) == ObjectIdentifier(rhs.keyType) + } + + func hash(into hasher: inout Hasher) { + hasher.combine(ObjectIdentifier(self.keyType)) + } +} + +// ==== ------------------------------------------------------------------------ + +@_silgen_name("swift_task_localValuePush") +public func _taskLocalValuePush( + _ task: Builtin.NativeObject, + keyType: Any.Type/*Key.Type*/, + value: __owned Value +) // where Key: TaskLocalKey + +@_silgen_name("swift_task_localValuePop") +public func _taskLocalValuePop( + _ task: Builtin.NativeObject +) + +@_silgen_name("swift_task_localValueGet") +public func _taskLocalValueGet( + _ task: Builtin.NativeObject, + keyType: Any.Type/*Key.Type*/, + inheritance: UInt8/*TaskLocalInheritance*/ +) -> UnsafeMutableRawPointer? // where Key: TaskLocalKey diff --git a/test/Concurrency/Runtime/task_locals_async_let.swift b/test/Concurrency/Runtime/task_locals_async_let.swift new file mode 100644 index 0000000000000..3d4f76eb8f831 --- /dev/null +++ b/test/Concurrency/Runtime/task_locals_async_let.swift @@ -0,0 +1,87 @@ +// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-concurrency -parse-as-library %import-libdispatch) | %FileCheck %s + +// REQUIRES: executable_test +// REQUIRES: concurrency +// REQUIRES: libdispatch + +class StringLike: CustomStringConvertible { + let value: String + init(_ value: String) { + self.value = value + } + + var description: String { value } +} + +extension TaskLocalValues { + struct NumberKey: TaskLocalKey { + static var defaultValue: Int { 0 } + } + var number: NumberKey { .init() } +} + +@discardableResult +func printTaskLocal( + _ key: KeyPath, + _ expected: Key.Value? = nil, + file: String = #file, line: UInt = #line +) async throws -> Key.Value? where Key: TaskLocalKey { + let value = await Task.local(key) + print("\(Key.self): \(value) at \(file):\(line)") + if let expected = expected { + assert("\(expected)" == "\(value)", + "Expected [\(expected)] but found: \(value), at \(file):\(line)") + } + return expected +} + +// ==== ------------------------------------------------------------------------ + +func async_let_nested() async { + _ = try! await printTaskLocal(\.number) // COM: NumberKey: 0 {{.*}} + async let x1: () = Task.withLocal(\.number, boundTo: 2) { + async let x2 = printTaskLocal(\.number) // COM: NumberKey: 2 {{.*}} + + @concurrent + func test() async { + try! await printTaskLocal(\.number) // COM: NumberKey: 2 {{.*}} + async let x31 = printTaskLocal(\.number) // COM: NumberKey: 2 {{.*}} + _ = try! await x31 + } + async let x3: () = test() + + _ = try! await x2 + await x3 + } + + _ = await x1 + try! await printTaskLocal(\.number) // COM: NumberKey: 0 {{.*}} +} + +func async_let_nested_skip_optimization() async { + async let x1: Int? = Task.withLocal(\.number, boundTo: 2) { + async let x2: Int? = { () async -> Int? in + async let x3: Int? = { () async -> Int? in + async let x4: Int? = { () async -> Int? in + async let x5: Int? = { () async -> Int? in + async let xx = printTaskLocal(\.number) // CHECK: NumberKey: 2 {{.*}} + return try! await xx + }() + return await x5 + }() + return await x4 + }() + return await x3 + }() + return await x2 + } + + _ = await x1 +} + +@main struct Main { + static func main() async { + await async_let_nested() + await async_let_nested_skip_optimization() + } +} diff --git a/test/Concurrency/Runtime/task_locals_basic.swift b/test/Concurrency/Runtime/task_locals_basic.swift new file mode 100644 index 0000000000000..4fb62dbc8ad3a --- /dev/null +++ b/test/Concurrency/Runtime/task_locals_basic.swift @@ -0,0 +1,138 @@ +// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-concurrency -parse-as-library %import-libdispatch) | %FileCheck %s + +// REQUIRES: executable_test +// REQUIRES: concurrency +// REQUIRES: libdispatch + +class StringLike: CustomStringConvertible { + let value: String + init(_ value: String) { + self.value = value + } + + var description: String { value } +} + + +extension TaskLocalValues { + + struct StringKey: TaskLocalKey { + static var defaultValue: String { .init("") } + } + var string: StringKey { .init() } + + struct NumberKey: TaskLocalKey { + static var defaultValue: Int { 0 } + } + var number: NumberKey { .init() } + + struct NeverKey: TaskLocalKey { + static var defaultValue: StringLike { .init("") } + } + var never: NeverKey { .init() } + + struct ClazzKey: TaskLocalKey { + static var defaultValue: ClassTaskLocal? { nil } + } + var clazz: ClazzKey { .init() } + +} + +final class ClassTaskLocal { + init() { + print("clazz init \(ObjectIdentifier(self))") + } + + deinit { + print("clazz deinit \(ObjectIdentifier(self))") + } +} + +func printTaskLocal( + _ key: KeyPath, + _ expected: Key.Value? = nil, + file: String = #file, line: UInt = #line +) async throws where Key: TaskLocalKey { + let value = await Task.local(key) + print("\(Key.self): \(value) at \(file):\(line)") + if let expected = expected { + assert("\(expected)" == "\(value)", + "Expected [\(expected)] but found: \(value), at \(file):\(line)") + } +} + +// ==== ------------------------------------------------------------------------ + +func simple() async { + try! await printTaskLocal(\.number) // CHECK: NumberKey: 0 {{.*}} + await Task.withLocal(\.number, boundTo: 1) { + try! await printTaskLocal(\.number) // CHECK-NEXT: NumberKey: 1 {{.*}} + } +} + +func simple_deinit() async { + await Task.withLocal(\.clazz, boundTo: ClassTaskLocal()) { + // CHECK: clazz init [[C:.*]] + try! await printTaskLocal(\.clazz) // CHECK: ClazzKey: Optional(main.ClassTaskLocal) {{.*}} + } + // CHECK: clazz deinit [[C]] + try! await printTaskLocal(\.clazz) // CHECK: ClazzKey: nil {{.*}} +} + +func nested() async { + try! await printTaskLocal(\.string) // CHECK: StringKey: {{.*}} + await Task.withLocal(\.string, boundTo: "hello") { + try! await printTaskLocal(\.number) // CHECK-NEXT: NumberKey: 0 {{.*}} + try! await printTaskLocal(\.string)// CHECK-NEXT: StringKey: hello {{.*}} + await Task.withLocal(\.number, boundTo: 2) { + try! await printTaskLocal(\.number) // CHECK-NEXT: NumberKey: 2 {{.*}} + try! await printTaskLocal(\.string, "hello") // CHECK: StringKey: hello {{.*}} + } + try! await printTaskLocal(\.number) // CHECK-NEXT: NumberKey: 0 {{.*}} + try! await printTaskLocal(\.string) // CHECK-NEXT: StringKey: hello {{.*}} + } + try! await printTaskLocal(\.number) // CHECK-NEXT: NumberKey: 0 {{.*}} + try! await printTaskLocal(\.string) // CHECK-NEXT: StringKey: {{.*}} +} + +func nested_allContribute() async { + try! await printTaskLocal(\.string) // CHECK: StringKey: {{.*}} + await Task.withLocal(\.string, boundTo: "one") { + try! await printTaskLocal(\.string, "one")// CHECK-NEXT: StringKey: one {{.*}} + await Task.withLocal(\.string, boundTo: "two") { + try! await printTaskLocal(\.string, "two") // CHECK-NEXT: StringKey: two {{.*}} + await Task.withLocal(\.string, boundTo: "three") { + try! await printTaskLocal(\.string, "three") // CHECK-NEXT: StringKey: three {{.*}} + } + try! await printTaskLocal(\.string, "two") // CHECK-NEXT: StringKey: two {{.*}} + } + try! await printTaskLocal(\.string, "one")// CHECK-NEXT: StringKey: one {{.*}} + } + try! await printTaskLocal(\.string) // CHECK-NEXT: StringKey: {{.*}} +} + +func nested_3_onlyTopContributes() async { + try! await printTaskLocal(\.string) // CHECK: StringKey: {{.*}} + await Task.withLocal(\.string, boundTo: "one") { + try! await printTaskLocal(\.string)// CHECK-NEXT: StringKey: one {{.*}} + await Task.withLocal(\.number, boundTo: 2) { + try! await printTaskLocal(\.string) // CHECK-NEXT: StringKey: one {{.*}} + await Task.withLocal(\.number, boundTo: 3) { + try! await printTaskLocal(\.string) // CHECK-NEXT: StringKey: one {{.*}} + } + try! await printTaskLocal(\.string) // CHECK-NEXT: StringKey: one {{.*}} + } + try! await printTaskLocal(\.string)// CHECK-NEXT: StringKey: one {{.*}} + } + try! await printTaskLocal(\.string) // CHECK-NEXT: StringKey: {{.*}} +} + +@main struct Main { + static func main() async { + await simple() + await simple_deinit() + await nested() + await nested_allContribute() + await nested_3_onlyTopContributes() + } +} diff --git a/test/Concurrency/Runtime/task_locals_groups.swift b/test/Concurrency/Runtime/task_locals_groups.swift new file mode 100644 index 0000000000000..6c16940fac4a3 --- /dev/null +++ b/test/Concurrency/Runtime/task_locals_groups.swift @@ -0,0 +1,85 @@ +// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-concurrency -parse-as-library %import-libdispatch) | %FileCheck %s + +// REQUIRES: executable_test +// REQUIRES: concurrency +// REQUIRES: libdispatch + +class StringLike: CustomStringConvertible { + let value: String + init(_ value: String) { + self.value = value + } + + var description: String { value } +} + + +extension TaskLocalValues { + struct NumberKey: TaskLocalKey { + static var defaultValue: Int { 0 } + } + var number: NumberKey { .init() } +} + +func printTaskLocal( + _ key: KeyPath, + _ expected: Key.Value? = nil, + file: String = #file, line: UInt = #line +) async throws where Key: TaskLocalKey { + let value = await Task.local(key) + print("\(Key.self): \(value) at \(file):\(line)") + if let expected = expected { + assert("\(expected)" == "\(value)", + "Expected [\(expected)] but found: \(value), at \(file):\(line)") + } +} + +// ==== ------------------------------------------------------------------------ + + +func groups() async { + // no value + try! await Task.withGroup(resultType: Int.self) { group in + try! await printTaskLocal(\.number) // CHECK: NumberKey: 0 {{.*}} + } + + // no value in parent, value in child + let x1: Int = try! await Task.withGroup(resultType: Int.self) { group in + await group.add { + try! await printTaskLocal(\.number) // CHECK: NumberKey: 0 {{.*}} + // inside the child task, set a value + await Task.withLocal(\.number, boundTo: 1) { + try! await printTaskLocal(\.number) // CHECK: NumberKey: 1 {{.*}} + } + try! await printTaskLocal(\.number) // CHECK: NumberKey: 0 {{.*}} + return await Task.local(\.number) // 0 + } + + return try! await group.next()! + } + assert(x1 == 0) + + // value in parent and in groups + await Task.withLocal(\.number, boundTo: 2) { + try! await printTaskLocal(\.number) // CHECK: NumberKey: 2 {{.*}} + + let x2: Int = try! await Task.withGroup(resultType: Int.self) { group in + try! await printTaskLocal(\.number) // CHECK: NumberKey: 2 {{.*}} + await group.add { + try! await printTaskLocal(\.number) // CHECK: NumberKey: 2 {{.*}} + return await Task.local(\.number) + } + try! await printTaskLocal(\.number) // CHECK: NumberKey: 2 {{.*}} + + return try! await group.next()! + } + + assert(x2 == 2) + } +} + +@main struct Main { + static func main() async { + await groups() + } +} diff --git a/test/Concurrency/Runtime/task_locals_inherit_never.swift b/test/Concurrency/Runtime/task_locals_inherit_never.swift new file mode 100644 index 0000000000000..a17f6978ab02b --- /dev/null +++ b/test/Concurrency/Runtime/task_locals_inherit_never.swift @@ -0,0 +1,82 @@ +// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-concurrency -parse-as-library %import-libdispatch) | %FileCheck %s + +// REQUIRES: executable_test +// REQUIRES: concurrency +// REQUIRES: libdispatch + +class StringLike: CustomStringConvertible { + let value: String + init(_ value: String) { + self.value = value + } + + var description: String { value } +} + +func printTaskLocal( + _ key: KeyPath, + _ expected: Key.Value? = nil, + file: String = #file, line: UInt = #line +) async where Key: TaskLocalKey { + let value = await Task.local(key) + print("\(Key.self): \(value) at \(file):\(line)") + if let expected = expected { + assert("\(expected)" == "\(value)", + "Expected [\(expected)] but found: \(value), at \(file):\(line)") + } +} + +extension TaskLocalValues { + + struct StringKey: TaskLocalKey { + static var defaultValue: String { .init("") } + static var inherit: TaskLocalInheritance { .never } + } + var string: StringKey { .init() } + +} + +// ==== ------------------------------------------------------------------------ + +func test_async_let() async { + // CHECK: StringKey: {{.*}} + await printTaskLocal(\.string) + await Task.withLocal(\.string, boundTo: "top") { + // CHECK: StringKey: top {{.*}} + await printTaskLocal(\.string) + + // CHECK: StringKey: {{.*}} + async let child: () = printTaskLocal(\.string) + await child + + // CHECK: StringKey: top {{.*}} + await printTaskLocal(\.string) + } +} + +func pending_async_group() async { + print("SKIPPED: \(#function)") // FIXME: unlock once https://github.com/apple/swift/pull/35874 is merged + return + + // COM: CHECK: test_async_group + print(#function) + + // COM: CHECK: StringKey: {{.*}} + await printTaskLocal(\.string) + await Task.withLocal(\.string, boundTo: "top") { + // COM: CHECK: StringKey: top {{.*}} + await printTaskLocal(\.string) + + try! await Task.withGroup(resultType: String.self) { group -> Void in + // COM: CHECK: StringKey: top {{.*}} + await printTaskLocal(\.string) + } + } +} + +@main struct Main { + static func main() async { + await test_async_let() + await pending_async_group() + } +}