Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions c10/core/impl/DeviceGuardImplInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,18 @@ class C10_API DeviceGuardImplRegistrar {
static ::c10::impl::DeviceGuardImplRegistrar C10_ANONYMOUS_VARIABLE(g_##DeviceType)(::c10::DeviceType::DevType, new DeviceGuardImpl());

inline const DeviceGuardImplInterface* getDeviceGuardImpl(DeviceType type) {
#if defined(__CUDACC__)
// Two adjacent int16_t fields DeviceType and DeviceIndex has field access
// miscompiled on NVCC. To workaround this issue, we apply a mask to the
// DeviceType. First check if the DeviceType is 16-bit.
// FB employees can see
// https://fb.workplace.com/groups/llvm.gcc/permalink/4053565044692080/
// for more details
static_assert(sizeof(DeviceType) == 2, "DeviceType is not 16-bit");
auto p = device_guard_impl_registry[static_cast<size_t>(type) & 0xFFFF].load();
#else
auto p = device_guard_impl_registry[static_cast<size_t>(type)].load();
#endif
// This seems to be the first place where you make use of a device
// when you pass devices to factory functions. Give a nicer error
// message in this case.
Expand Down