diff --git a/include/swift/Remote/CMemoryReader.h b/include/swift/Remote/CMemoryReader.h index da4a5274f6dc2..4408bcb046c51 100644 --- a/include/swift/Remote/CMemoryReader.h +++ b/include/swift/Remote/CMemoryReader.h @@ -42,38 +42,16 @@ namespace remote { class CMemoryReader final : public MemoryReader { MemoryReaderImpl Impl; - uint64_t ptrauthMask; - - uint64_t getPtrauthMask() { - if (ptrauthMask == 0) { - int success; - if (Impl.PointerSize == 4) { - uint32_t ptrauthMask32 = 0; - success = queryDataLayout(DataLayoutQueryType::DLQ_GetPtrAuthMask, - nullptr, &ptrauthMask32); - ptrauthMask = ptrauthMask32; - } else if (Impl.PointerSize == 8) { - success = queryDataLayout(DataLayoutQueryType::DLQ_GetPtrAuthMask, - nullptr, &ptrauthMask); - } else { - success = 0; - } - - if (!success) - ptrauthMask = ~0ull; - } - return ptrauthMask; - } - // Check to see if an address has bits outside the ptrauth mask. This suggests // that we're likely failing to strip a signed pointer when reading from it. bool hasSignatureBits(RemoteAddress address) { uint64_t addressData = address.getAddressData(); - return addressData != (addressData & getPtrauthMask()); + uint64_t mask = getPtrAuthMask().value_or(~uint64_t(0)); + return addressData != (addressData & mask); } public: - CMemoryReader(MemoryReaderImpl Impl) : Impl(Impl), ptrauthMask(0) { + CMemoryReader(MemoryReaderImpl Impl) : Impl(Impl) { assert(this->Impl.queryDataLayout && "No queryDataLayout implementation"); assert(this->Impl.getStringLength && "No stringLength implementation"); assert(this->Impl.readBytes && "No readBytes implementation"); diff --git a/include/swift/Remote/MemoryReader.h b/include/swift/Remote/MemoryReader.h index 772bcfbdfbc3e..5881dff884533 100644 --- a/include/swift/Remote/MemoryReader.h +++ b/include/swift/Remote/MemoryReader.h @@ -22,6 +22,7 @@ #include "swift/SwiftRemoteMirror/MemoryReaderInterface.h" #include +#include #include #include #include @@ -37,7 +38,80 @@ namespace remote { /// This abstraction presents memory as if it were a read-only /// representation of the address space of a remote process. class MemoryReader { + uint8_t cachedPointerSize = 0; + uint8_t cachedSizeSize = 0; + uint64_t cachedPtrAuthMask = 0; + uint8_t cachedObjCReservedLowBits = 0xff; + uint64_t cachedLeastValidPointerValue = 0; + uint8_t cachedObjCInteropIsEnabled = 0xff; + +protected: + virtual bool queryDataLayout(DataLayoutQueryType type, void *inBuffer, + void *outBuffer) = 0; + public: + std::optional getPointerSize() { + if (cachedPointerSize == 0) { + if (!queryDataLayout(DLQ_GetPointerSize, nullptr, &cachedPointerSize)) + return std::nullopt; + } + return cachedPointerSize; + } + + std::optional getSizeSize() { + if (cachedSizeSize == 0) { + if (!queryDataLayout(DLQ_GetSizeSize, nullptr, &cachedSizeSize)) + return std::nullopt; + } + return cachedSizeSize; + } + + std::optional getPtrAuthMask() { + if (cachedPtrAuthMask == 0) { + auto ptrSize = getPointerSize(); + if (!ptrSize) + return std::nullopt; + + if (ptrSize.value() == sizeof(uint64_t)) { + if (!queryDataLayout(DLQ_GetPtrAuthMask, nullptr, &cachedPtrAuthMask)) + return std::nullopt; + } else if (ptrSize.value() == sizeof(uint32_t)) { + uint32_t mask; + if (!queryDataLayout(DLQ_GetPtrAuthMask, nullptr, &mask)) + return std::nullopt; + cachedPtrAuthMask = mask; + } + } + return cachedPtrAuthMask; + } + + std::optional getObjCReservedLowBits() { + if (cachedObjCReservedLowBits == 0xff) { + if (!queryDataLayout(DLQ_GetObjCReservedLowBits, nullptr, + &cachedObjCReservedLowBits)) + return std::nullopt; + } + return cachedObjCReservedLowBits; + } + + std::optional getLeastValidPointerValue() { + if (cachedLeastValidPointerValue == 0) { + if (!queryDataLayout(DLQ_GetLeastValidPointerValue, nullptr, + &cachedLeastValidPointerValue)) + return std::nullopt; + } + return cachedLeastValidPointerValue; + } + + std::optional getObjCInteropIsEnabled() { + if (cachedObjCInteropIsEnabled == 0xff) { + if (!queryDataLayout(DLQ_GetObjCInteropIsEnabled, nullptr, + &cachedObjCInteropIsEnabled)) + return std::nullopt; + } + return cachedObjCInteropIsEnabled; + } + /// A convenient name for the return type from readBytes. using ReadBytesResult = std::unique_ptr>; @@ -46,9 +120,6 @@ class MemoryReader { using ReadObjResult = std::unique_ptr>; - virtual bool queryDataLayout(DataLayoutQueryType type, void *inBuffer, - void *outBuffer) = 0; - /// Look up the given public symbol name in the remote process. virtual RemoteAddress getSymbolAddress(const std::string &name) = 0; @@ -209,50 +280,37 @@ class MemoryReader { // index (counting from 0). bool readHeapObjectExtraInhabitantIndex(RemoteAddress address, int *extraInhabitantIndex) { - uint8_t PointerSize; - if (!queryDataLayout(DataLayoutQueryType::DLQ_GetPointerSize, - nullptr, &PointerSize)) { - return false; - } - uint64_t LeastValidPointerValue; - if (!queryDataLayout(DataLayoutQueryType::DLQ_GetLeastValidPointerValue, - nullptr, &LeastValidPointerValue)) { - return false; - } - uint8_t ObjCReservedLowBits; - if (!queryDataLayout(DataLayoutQueryType::DLQ_GetObjCReservedLowBits, - nullptr, &ObjCReservedLowBits)) { + auto PointerSize = getPointerSize(); + auto LeastValidPointerValue = getLeastValidPointerValue(); + auto ObjCReservedLowBits = getObjCReservedLowBits(); + + if (!PointerSize || !LeastValidPointerValue || !ObjCReservedLowBits) return false; - } + uint64_t RawPointerValue; - if (!readInteger(address, PointerSize, &RawPointerValue)) { + if (!readInteger(address, PointerSize.value(), &RawPointerValue)) { return false; } - if (RawPointerValue >= LeastValidPointerValue) { + if (RawPointerValue >= LeastValidPointerValue.value()) { *extraInhabitantIndex = -1; // Valid value, not an XI } else { - *extraInhabitantIndex = (RawPointerValue >> ObjCReservedLowBits); + *extraInhabitantIndex = (RawPointerValue >> ObjCReservedLowBits.value()); } return true; } bool readFunctionPointerExtraInhabitantIndex(RemoteAddress address, int *extraInhabitantIndex) { - uint8_t PointerSize; - if (!queryDataLayout(DataLayoutQueryType::DLQ_GetPointerSize, - nullptr, &PointerSize)) { + auto PointerSize = getPointerSize(); + auto LeastValidPointerValue = getLeastValidPointerValue(); + if (!PointerSize || !LeastValidPointerValue) return false; - } - uint64_t LeastValidPointerValue; - if (!queryDataLayout(DataLayoutQueryType::DLQ_GetLeastValidPointerValue, - nullptr, &LeastValidPointerValue)) { - return false; - } + uint64_t RawPointerValue; - if (!readInteger(address, PointerSize, &RawPointerValue)) { + if (!readInteger(address, PointerSize.value(), &RawPointerValue)) { return false; } - if (RawPointerValue >= LeastValidPointerValue) { + if (RawPointerValue >= LeastValidPointerValue.value()) { *extraInhabitantIndex = -1; // Valid value, not an XI } else { *extraInhabitantIndex = RawPointerValue; diff --git a/include/swift/Remote/MetadataReader.h b/include/swift/Remote/MetadataReader.h index ee4dc8e0cdc16..a1ecaaab2a703 100644 --- a/include/swift/Remote/MetadataReader.h +++ b/include/swift/Remote/MetadataReader.h @@ -422,12 +422,8 @@ class MetadataReader { } StoredPointer queryPtrAuthMask() { - StoredPointer QueryResult; - if (Reader->queryDataLayout(DataLayoutQueryType::DLQ_GetPtrAuthMask, - nullptr, &QueryResult)) { - return QueryResult; - } - return ~StoredPointer(0); + auto QueryResult = Reader->getPtrAuthMask(); + return QueryResult.value_or(~StoredPointer(0)); } template diff --git a/lib/StaticMirror/ObjectFileContext.cpp b/lib/StaticMirror/ObjectFileContext.cpp index a2bca9bcc0d77..7172f6661e727 100644 --- a/lib/StaticMirror/ObjectFileContext.cpp +++ b/lib/StaticMirror/ObjectFileContext.cpp @@ -565,16 +565,18 @@ std::unique_ptr makeReflectionContextForObjectFiles( const std::vector &objectFiles, bool ObjCInterop) { auto Reader = std::make_shared(objectFiles); - uint8_t pointerSize; - Reader->queryDataLayout(DataLayoutQueryType::DLQ_GetPointerSize, nullptr, - &pointerSize); + auto pointerSize = Reader->getPointerSize(); + if (!pointerSize) { + fputs("unable to get target pointer size\n", stderr); + abort(); + } - switch (pointerSize) { + switch (pointerSize.value()) { case 4: #define MAKE_CONTEXT(INTEROP, PTRSIZE) \ makeReflectionContextForMetadataReader< \ External>>>(std::move(Reader), \ - pointerSize) + pointerSize.value()) #if SWIFT_OBJC_INTEROP if (ObjCInterop) return MAKE_CONTEXT(WithObjCInterop, 4);