diff --git a/.gitignore b/.gitignore index de1d236..2828314 100644 --- a/.gitignore +++ b/.gitignore @@ -340,3 +340,4 @@ src/RemoteNET/publish.ps1 src/detours_build/ src/ConsoleApp/ +src/RemoteNET/Resources/ diff --git a/README.md b/README.md index d663817..9d1d83f 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,46 @@ The limitations: 2. The callback must define the exact number of parameters for that event 3. Lambda expression are not allowed. The callback must be cast to an `Action<...>`. +### ✳️ Registering Custom Functions (Unmanaged/C++ Targets Only) +For unmanaged (MSVC C++) targets, you can register custom functions that aren't automatically discovered by the RTTI scanner. +This is useful when you know the address of a function in the target process and want to call it through RemoteNET. + +**Important:** The type must be a `RemoteRttiType` instance that was fetched from the remote process. Any existing instances of that type will be automatically updated with the new method. + +```C# +// Connect to an unmanaged target +UnmanagedRemoteApp unmanagedApp = (UnmanagedRemoteApp)RemoteAppFactory.Connect("MyNativeTarget.exe", RuntimeType.Unmanaged); + +// Get a remote type (this creates a RemoteRttiType) +Type remoteType = unmanagedApp.GetRemoteType("MyNamespace::MyClass", "MyModule.dll"); + +// Register a custom function on the type - returns the MethodInfo or null on error +MethodInfo customMethod = unmanagedApp.RegisterCustomFunction( + parentType: remoteType, + functionName: "MyCustomFunction", + moduleName: "MyModule.dll", + offset: 0x1234, // Offset from module base address + returnType: typeof(int), + parameterTypes: new[] { typeof(int), typeof(float) } +); + +if (customMethod != null) +{ + // After registration, the function can be invoked like any other remote method + // All existing instances of this type will have the new method available + dynamic dynamicObj = remoteObject.Dynamify(); + int result = dynamicObj.MyCustomFunction(42, 3.14f); +} +``` + +**Notes:** +- This feature is only available for unmanaged (C++) targets +- The parent type must be a RemoteRttiType (obtained via `UnmanagedRemoteApp.GetRemoteType()`) +- You need to know the module name and offset where the function is located +- The function will be added to the type's method list and available on all instances +- Parameter and return types should be specified as .NET types +- Returns the `MethodInfo` for the registered method, or `null` if registration fails + ## TODOs 1. Static members 2. Document "Reflection API" (RemoteType, RemoteMethodInfo, ... ) diff --git a/src/RemoteNET.Tests/RttiTypesFactoryTests.cs b/src/RemoteNET.Tests/RttiTypesFactoryTests.cs index a2bce320..26b3527 100644 --- a/src/RemoteNET.Tests/RttiTypesFactoryTests.cs +++ b/src/RemoteNET.Tests/RttiTypesFactoryTests.cs @@ -260,7 +260,7 @@ public void AddFunctionImpl_DifferentDeclaringClassOnFunc_DifferentDeclaringType RemoteApp? fakeApp = new FakeRemoteApp(); // Act - RttiTypesFactory.AddFunctionImpl(fakeApp, typeDump, func, childType, false); + RttiTypesFactory.AddFunctionImpl(fakeApp, typeDump.Assembly, func, childType, false); // Assert MethodInfo? method = childType.GetMethods().Single(); @@ -291,7 +291,7 @@ public void AddFunctionImpl_SameDeclaringClassOnFunc_SameDeclaringType() RemoteApp? fakeApp = new FakeRemoteApp(); // Act - RttiTypesFactory.AddFunctionImpl(fakeApp, typeDump, func, childType, false); + RttiTypesFactory.AddFunctionImpl(fakeApp, typeDump.Assembly, func, childType, false); // Assert MethodInfo? method = childType.GetMethods().Single(); @@ -345,7 +345,7 @@ public void UndecoratingConstRef_ParseType_NoMethod() RemoteApp? fakeApp = new FakeRemoteApp(); // Act - RttiTypesFactory.AddFunctionImpl(fakeApp, typeDump, func, childType, false); + RttiTypesFactory.AddFunctionImpl(fakeApp, typeDump.Assembly, func, childType, false); // Assert // Expecting `AddFunctionImpl` to NOT add that function (not supported yet) diff --git a/src/RemoteNET/Internal/Reflection/Rtti/RttiTypesFactory.cs b/src/RemoteNET/Internal/Reflection/Rtti/RttiTypesFactory.cs index 72abf21..d383e5c 100644 --- a/src/RemoteNET/Internal/Reflection/Rtti/RttiTypesFactory.cs +++ b/src/RemoteNET/Internal/Reflection/Rtti/RttiTypesFactory.cs @@ -164,18 +164,16 @@ private static void AddGroupOfFunctions(RemoteApp app, TypeDump typeDump, List parameters = new List(func.Parameters.Count); int i = 1; foreach (TypeDump.TypeMethod.MethodParameter restarizedParameter in func.Parameters) @@ -211,6 +209,7 @@ public static void AddFunctionImpl(RemoteApp app, TypeDump typeDump, TypeDump.Ty RemoteRttiConstructorInfo ctorInfo = new RemoteRttiConstructorInfo(declaringTypeResolver, parameters.ToArray()); declaringType.AddConstructor(ctorInfo); + return ctorInfo; } else { @@ -232,6 +231,7 @@ public static void AddFunctionImpl(RemoteApp app, TypeDump typeDump, TypeDump.Ty new RemoteRttiMethodInfo(declaringTypeResolver, returnTypeResolver, func.Name, mangledName, parameters.ToArray(), (MethodAttributes)func.Attributes); declaringType.AddMethod(methodInfo); + return methodInfo; } Lazy CreateTypeFactory(string namespaceAndTypeName, string moduleName) @@ -253,7 +253,7 @@ Lazy CreateTypeFactory(string namespaceAndTypeName, string moduleName) if (_shittyCache.TryGetValue(namespaceAndTypeName, out resultType)) return resultType; - resultType = RttiTypesResolver.Instance.Resolve(typeDump.Assembly, $"{typeDump.Assembly}!{namespaceAndTypeName}"); + resultType = RttiTypesResolver.Instance.Resolve(moduleName, $"{moduleName}!{namespaceAndTypeName}"); if (resultType != null) return resultType; @@ -282,7 +282,7 @@ Lazy CreateTypeFactory(string namespaceAndTypeName, string moduleName) // Prefer any matches in the existing assembly var paramTypeInSameAssembly = - possibleParamTypes.Where(t => t.Assembly == typeDump.Assembly).ToArray(); + possibleParamTypes.Where(t => t.Assembly == moduleName).ToArray(); if (paramTypeInSameAssembly.Length > 0) { if (paramTypeInSameAssembly.Length > 1) diff --git a/src/RemoteNET/UnmanagedRemoteApp.cs b/src/RemoteNET/UnmanagedRemoteApp.cs index d5020a6..f9c2296 100644 --- a/src/RemoteNET/UnmanagedRemoteApp.cs +++ b/src/RemoteNET/UnmanagedRemoteApp.cs @@ -2,10 +2,12 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; +using System.Reflection; using RemoteNET.Common; using RemoteNET.Internal; using RemoteNET.RttiReflection; using ScubaDiver.API; +using ScubaDiver.API.Interactions; using ScubaDiver.API.Interactions.Dumps; using ScubaDiver.API.Utils; @@ -215,6 +217,108 @@ public override bool InjectDll(string path) return res; } + // + // Custom Functions + // + + /// + /// Registers a custom function on a remote type for unmanaged targets + /// + /// The type to add the function to (must be a RemoteRttiType) + /// Name of the function + /// Module name where the function is located (e.g., "MyModule.dll") + /// Offset within the module where the function is located + /// Return type of the function + /// Parameter types of the function + /// The MethodInfo for the registered function, or null if registration failed + public MethodInfo RegisterCustomFunction( + Type parentType, + string functionName, + string moduleName, + ulong offset, + Type returnType, + params Type[] parameterTypes) + { + if (parentType == null) + throw new ArgumentNullException(nameof(parentType)); + if (string.IsNullOrEmpty(functionName)) + throw new ArgumentException("Function name cannot be null or empty", nameof(functionName)); + if (string.IsNullOrEmpty(moduleName)) + throw new ArgumentException("Module name cannot be null or empty", nameof(moduleName)); + + // Verify the type is a RemoteRttiType + if (!(parentType is RemoteRttiType rttiType)) + { + throw new ArgumentException("The parent type must be a RemoteRttiType. Only remote RTTI types created by UnmanagedRemoteApp can have custom functions registered.", nameof(parentType)); + } + + static RegisterCustomFunctionRequest.ParameterTypeInfo CreateParamInfo(Type pt, int idx) + { + var typeName = pt?.FullName ?? "void"; + var assembly = pt.Assembly?.GetName()?.Name; + + if (pt.IsPrimitive) + { + typeName = pt.Name; + assembly = "mscorlib"; + + if (pt == typeof(ulong)) + { + typeName = "ulong"; + } + } + + return new RegisterCustomFunctionRequest.ParameterTypeInfo + { + Name = $"param{idx}", + TypeFullName = typeName, + Assembly = assembly + }; + } + + var retTypeName = returnType?.FullName ?? "void"; + var retTypeAssembly = returnType?.Assembly?.GetName()?.Name; + if (returnType != null && returnType.IsPrimitive) + { + retTypeName = returnType.Name; + retTypeAssembly = "mscorlib"; + if (returnType == typeof(ulong)) + { + retTypeName = "ulong"; + } + } + + var request = new RegisterCustomFunctionRequest + { + ParentTypeFullName = rttiType.Namespace + "::" + rttiType.Name, // Use namespace::name format for RTTI types + ParentAssembly = rttiType.Assembly?.GetName()?.Name, + FunctionName = functionName, + ModuleName = moduleName, + Offset = offset, + ReturnTypeFullName = retTypeName, + ReturnTypeAssembly = retTypeAssembly, + Parameters = parameterTypes?.Select(CreateParamInfo).ToList() ?? new List() + }; + + bool success = _unmanagedCommunicator.RegisterCustomFunction(request, out var methodDump); + + if (!success || methodDump == null) + { + return null; + } + + // Create a TypeDump for the parent type to pass to AddFunctionImpl + TypeDump parentTypeDump = new TypeDump + { + Assembly = rttiType.Assembly?.GetName()?.Name, + FullTypeName = rttiType.Namespace + "::" + rttiType.Name + }; + + // Use the existing factory method to add the function to the RemoteRttiType + // AddFunctionImpl returns the newly created MethodInfo + return RttiTypesFactory.AddFunctionImpl(this, parentTypeDump.Assembly, methodDump, rttiType, areConstructors: false) as MethodInfo; + } + // // IDisposable // diff --git a/src/ScubaDiver.API.Tests/RegisterCustomFunctionTests.cs b/src/ScubaDiver.API.Tests/RegisterCustomFunctionTests.cs new file mode 100644 index 0000000..47a6803 --- /dev/null +++ b/src/ScubaDiver.API.Tests/RegisterCustomFunctionTests.cs @@ -0,0 +1,94 @@ +using ScubaDiver.API.Interactions; +using ScubaDiver.API.Interactions.Dumps; +using System.Collections.Generic; + +namespace ScubaDiver.API.Tests +{ + [TestFixture] + public class RegisterCustomFunctionTests + { + [Test] + public void RegisterCustomFunctionRequest_ValidData_SetsPropertiesCorrectly() + { + // Arrange + var request = new RegisterCustomFunctionRequest + { + ParentTypeFullName = "MyNamespace::MyClass", + ParentAssembly = "MyModule.dll", + FunctionName = "MyCustomFunction", + ModuleName = "MyModule.dll", + Offset = 0x1234, + ReturnTypeFullName = "int", + ReturnTypeAssembly = "System.Private.CoreLib", + Parameters = new List + { + new RegisterCustomFunctionRequest.ParameterTypeInfo + { + Name = "param1", + TypeFullName = "int", + Assembly = "System.Private.CoreLib" + }, + new RegisterCustomFunctionRequest.ParameterTypeInfo + { + Name = "param2", + TypeFullName = "float", + Assembly = "System.Private.CoreLib" + } + } + }; + + // Assert + Assert.That(request.ParentTypeFullName, Is.EqualTo("MyNamespace::MyClass")); + Assert.That(request.ParentAssembly, Is.EqualTo("MyModule.dll")); + Assert.That(request.FunctionName, Is.EqualTo("MyCustomFunction")); + Assert.That(request.ModuleName, Is.EqualTo("MyModule.dll")); + Assert.That(request.Offset, Is.EqualTo(0x1234)); + Assert.That(request.ReturnTypeFullName, Is.EqualTo("int")); + Assert.That(request.Parameters.Count, Is.EqualTo(2)); + Assert.That(request.Parameters[0].Name, Is.EqualTo("param1")); + Assert.That(request.Parameters[0].TypeFullName, Is.EqualTo("int")); + Assert.That(request.Parameters[1].Name, Is.EqualTo("param2")); + Assert.That(request.Parameters[1].TypeFullName, Is.EqualTo("float")); + } + + [Test] + public void RegisterCustomFunctionResponse_Success_SetsPropertiesCorrectly() + { + // Arrange + var methodDump = new TypeDump.TypeMethod + { + Name = "TestMethod", + ReturnTypeFullName = "int" + }; + var response = new RegisterCustomFunctionResponse + { + Success = true, + ErrorMessage = null, + RegisteredMethod = methodDump + }; + + // Assert + Assert.That(response.Success, Is.True); + Assert.That(response.ErrorMessage, Is.Null); + Assert.That(response.RegisteredMethod, Is.Not.Null); + Assert.That(response.RegisteredMethod.Name, Is.EqualTo("TestMethod")); + } + + [Test] + public void RegisterCustomFunctionResponse_Failure_SetsPropertiesCorrectly() + { + // Arrange + var response = new RegisterCustomFunctionResponse + { + Success = false, + ErrorMessage = "Failed to register custom function", + RegisteredMethod = null + }; + + // Assert + Assert.That(response.Success, Is.False); + Assert.That(response.ErrorMessage, Is.EqualTo("Failed to register custom function")); + Assert.That(response.RegisteredMethod, Is.Null); + } + } +} diff --git a/src/ScubaDiver.API/DiverCommunicator.cs b/src/ScubaDiver.API/DiverCommunicator.cs index 070773e..df244c8 100644 --- a/src/ScubaDiver.API/DiverCommunicator.cs +++ b/src/ScubaDiver.API/DiverCommunicator.cs @@ -537,6 +537,32 @@ public void UnhookMethod(LocalHookCallback callback) public delegate (bool voidReturnType, ObjectOrRemoteAddress res) LocalEventCallback(ObjectOrRemoteAddress[] args, ObjectOrRemoteAddress retVal); + public bool RegisterCustomFunction(RegisterCustomFunctionRequest request) + { + return RegisterCustomFunction(request, out _); + } + + public bool RegisterCustomFunction(RegisterCustomFunctionRequest request, out TypeDump.TypeMethod methodDump) + { + methodDump = null; + var requestJsonBody = JsonConvert.SerializeObject(request); + var resJson = SendRequest("register_custom_function", null, requestJsonBody); + + try + { + RegisterCustomFunctionResponse response = JsonConvert.DeserializeObject(resJson, _withErrors); + if (response?.Success == true) + { + methodDump = response.RegisteredMethod; + } + return response?.Success ?? false; + } + catch (Exception ex) + { + throw new Exception($"Failed to register custom function. Error: {ex.Message}", ex); + } + } + public void Dispose() { if (_httpClient != null) diff --git a/src/ScubaDiver.API/Interactions/RegisterCustomFunctionRequest.cs b/src/ScubaDiver.API/Interactions/RegisterCustomFunctionRequest.cs new file mode 100644 index 0000000..d3d9d8d --- /dev/null +++ b/src/ScubaDiver.API/Interactions/RegisterCustomFunctionRequest.cs @@ -0,0 +1,57 @@ +using System.Collections.Generic; + +namespace ScubaDiver.API.Interactions +{ + /// + /// Request to register a custom function on a type (primarily for unmanaged targets) + /// + public class RegisterCustomFunctionRequest + { + /// + /// Full type name of the parent type that this function belongs to + /// + public string ParentTypeFullName { get; set; } + + /// + /// Assembly name of the parent type + /// + public string ParentAssembly { get; set; } + + /// + /// Name of the function to register + /// + public string FunctionName { get; set; } + + /// + /// Module name where the function is located (e.g., "MyModule.dll") + /// + public string ModuleName { get; set; } + + /// + /// Offset within the module where the function is located + /// + public ulong Offset { get; set; } + + /// + /// Full type name of the return type + /// + public string ReturnTypeFullName { get; set; } + + /// + /// Assembly of the return type + /// + public string ReturnTypeAssembly { get; set; } + + /// + /// List of parameter types (full type names) + /// + public List Parameters { get; set; } + + public class ParameterTypeInfo + { + public string Name { get; set; } + public string TypeFullName { get; set; } + public string Assembly { get; set; } + } + } +} diff --git a/src/ScubaDiver.API/Interactions/RegisterCustomFunctionResponse.cs b/src/ScubaDiver.API/Interactions/RegisterCustomFunctionResponse.cs new file mode 100644 index 0000000..2c7ebca --- /dev/null +++ b/src/ScubaDiver.API/Interactions/RegisterCustomFunctionResponse.cs @@ -0,0 +1,25 @@ +using ScubaDiver.API.Interactions.Dumps; + +namespace ScubaDiver.API.Interactions +{ + /// + /// Response for a custom function registration request + /// + public class RegisterCustomFunctionResponse + { + /// + /// Whether the registration was successful + /// + public bool Success { get; set; } + + /// + /// Error message if registration failed + /// + public string ErrorMessage { get; set; } + + /// + /// The method dump for the registered function (when successful) + /// + public TypeDump.TypeMethod RegisteredMethod { get; set; } + } +} diff --git a/src/ScubaDiver/DiverBase.cs b/src/ScubaDiver/DiverBase.cs index c49d465..6bc0b1e 100644 --- a/src/ScubaDiver/DiverBase.cs +++ b/src/ScubaDiver/DiverBase.cs @@ -62,6 +62,8 @@ public DiverBase(IRequestsListener listener) // Hooking {"/hook_method", MakeHookMethodResponse}, {"/unhook_method", MakeUnhookMethodResponse}, + // Custom Functions + {"/register_custom_function", MakeRegisterCustomFunctionResponse}, }; _remoteHooks = new ConcurrentDictionary(); } @@ -128,6 +130,13 @@ public string QuickError(string error, string stackTrace = null) private void HandleDispatchedRequest(object obj, ScubaDiverMessage request) { + // Check if the "debug" query parameter is set. If so, launch the debugger + if (request.QueryString.Get("debug") == "1") + { + Logger.Debug("[DiverBase] Debugging enabled"); + Debugger.Launch(); + } + Stopwatch sw = Stopwatch.StartNew(); string body; if (_responseBodyCreators.TryGetValue(request.UrlAbsolutePath, out var respBodyGenerator)) @@ -370,6 +379,7 @@ private string MakeLaunchDebuggerResponse(ScubaDiverMessage arg) protected abstract string MakeSetFieldResponse(ScubaDiverMessage arg); protected abstract string MakeArrayItemResponse(ScubaDiverMessage arg); protected abstract string MakeUnpinResponse(ScubaDiverMessage arg); + protected abstract string MakeRegisterCustomFunctionResponse(ScubaDiverMessage arg); private string MakeDieResponse(ScubaDiverMessage req) { diff --git a/src/ScubaDiver/DotNetDiver.cs b/src/ScubaDiver/DotNetDiver.cs index eb5842a..3eeda4f 100644 --- a/src/ScubaDiver/DotNetDiver.cs +++ b/src/ScubaDiver/DotNetDiver.cs @@ -1531,6 +1531,18 @@ protected override string MakeUnpinResponse(ScubaDiverMessage arg) } } + protected override string MakeRegisterCustomFunctionResponse(ScubaDiverMessage arg) + { + // Custom function registration is not supported for managed (.NET) targets + // This feature is only available for unmanaged (native C++) targets + RegisterCustomFunctionResponse response = new RegisterCustomFunctionResponse + { + Success = false, + ErrorMessage = "Custom function registration is not supported for managed (.NET) targets" + }; + return JsonConvert.SerializeObject(response); + } + // IDisposable public override void Dispose() { diff --git a/src/ScubaDiver/MsvcDiver.cs b/src/ScubaDiver/MsvcDiver.cs index 0f02308..f39325f 100644 --- a/src/ScubaDiver/MsvcDiver.cs +++ b/src/ScubaDiver/MsvcDiver.cs @@ -212,7 +212,6 @@ protected override ObjectOrRemoteAddress InvokeEventCallback(IPEndPoint callback protected override HookResponse InvokeHookCallback(IPEndPoint callbacksEndpoint, int token, string stackTrace, object retValue, params object[] parameters) { - Logger.Debug($"[{nameof(MsvcDiver)}] InvokeHookCallback Entered. EndPoint: {callbacksEndpoint} Token: {token}"); ReverseCommunicator reverseCommunicator = new(callbacksEndpoint); ObjectOrRemoteAddress[] remoteParams = new ObjectOrRemoteAddress[parameters.Length]; @@ -296,7 +295,7 @@ protected override string MakeTypesResponse(ScubaDiverMessage req) string typeFilter = req.QueryString.Get("type_filter"); if (string.IsNullOrWhiteSpace(typeFilter)) - return QuickError("Missing parameter 'type_filter'"); + return QuickError("Missing parameter 'type_filter'. Try this: /types?type_filter=*"); ParseFullTypeName(typeFilter, out var assemblyFilter, out typeFilter); Predicate typeFilterPredicate = Filter.CreatePredicate(typeFilter); @@ -339,7 +338,6 @@ protected override string MakeTypeResponse(ScubaDiverMessage req) { return QuickError("Failed to deserialize body"); } - Logger.Debug($"[MsvcDiver][MakeTypeResponse] Resolving type Name: {request.Assembly} {request.TypeFullName} vftable: 0x{request.MethodTableAddress:x16}"); TypeDump dump; if (request.MethodTableAddress != 0) @@ -499,7 +497,6 @@ protected override string MakeObjectResponse(ScubaDiverMessage arg) // Check if the object is already frozen if (_freezer.IsFrozen(objAddr)) { - Logger.Debug($"[MsvcDiver][MakeObjectResponse] Object at 0x{objAddr:X16} is already frozen."); ObjectDump alreadyFrozenObjDump = new ObjectDump() { Type = fullTypeName, @@ -556,7 +553,6 @@ protected override string MakeCreateObjectResponse(ScubaDiverMessage arg) protected override string MakeInvokeResponse(ScubaDiverMessage arg) { - Console.WriteLine($"[{nameof(MsvcDiver)}] MakeInvokeResponse Entered (!)"); if (string.IsNullOrEmpty(arg.Body)) return QuickError("Missing body"); @@ -865,6 +861,87 @@ protected override string MakeUnpinResponse(ScubaDiverMessage arg) return QuickError("Not Implemented"); } + protected override string MakeRegisterCustomFunctionResponse(ScubaDiverMessage arg) + { + string body = arg.Body; + if (string.IsNullOrEmpty(body)) + { + return QuickError("Missing body"); + } + + var request = JsonConvert.DeserializeObject(body); + if (request == null) + { + return QuickError("Failed to deserialize body"); + } + + Logger.Debug($"[MsvcDiver][MakeRegisterCustomFunctionResponse] Registering custom function: {request.FunctionName} on type {request.ParentTypeFullName}"); + + try + { + // Extract parameter type names from the request + string[] argTypeFullNames = request.Parameters?.Select(p => p.TypeFullName).ToArray() ?? Array.Empty(); + + bool success = _typesManager.RegisterCustomFunction( + request.ParentTypeFullName, + request.ParentAssembly, + request.FunctionName, + request.ModuleName, + request.Offset, + request.ReturnTypeFullName, + argTypeFullNames); + + if (!success) + { + RegisterCustomFunctionResponse failResponse = new RegisterCustomFunctionResponse + { + Success = false, + ErrorMessage = "Failed to register custom function" + }; + return JsonConvert.SerializeObject(failResponse); + } + + // Create a TypeMethod dump from the registered function + TypeDump.TypeMethod methodDump = new TypeDump.TypeMethod + { + Name = request.FunctionName, + DecoratedName = request.FunctionName, + UndecoratedFullName = request.FunctionName, + ReturnTypeFullName = request.ReturnTypeFullName, + ReturnTypeAssembly = request.ReturnTypeAssembly, + ReturnTypeName = request.ReturnTypeFullName, + Visibility = "Public", + Attributes = 0, + Parameters = request.Parameters?.Select((p, idx) => new TypeDump.TypeMethod.MethodParameter + { + Name = p.Name ?? $"a{idx}", + FullTypeName = p.TypeFullName, + TypeName = p.TypeFullName, + Assembly = p.Assembly + }).ToList() ?? new List() + }; + + RegisterCustomFunctionResponse response = new RegisterCustomFunctionResponse + { + Success = true, + ErrorMessage = null, + RegisteredMethod = methodDump + }; + + return JsonConvert.SerializeObject(response); + } + catch (Exception ex) + { + Logger.Debug($"[MsvcDiver][MakeRegisterCustomFunctionResponse] Exception: {ex}"); + RegisterCustomFunctionResponse response = new RegisterCustomFunctionResponse + { + Success = false, + ErrorMessage = ex.Message + }; + return JsonConvert.SerializeObject(response); + } + } + public override void Dispose() { } diff --git a/src/ScubaDiver/MsvcPrimitives/CustomUndecoratedFunction.cs b/src/ScubaDiver/MsvcPrimitives/CustomUndecoratedFunction.cs new file mode 100644 index 0000000..00edad8 --- /dev/null +++ b/src/ScubaDiver/MsvcPrimitives/CustomUndecoratedFunction.cs @@ -0,0 +1,42 @@ +using Microsoft.Diagnostics.Runtime.AbstractDac; +using ScubaDiver.Rtti; +using System; + +namespace ScubaDiver +{ + /// + /// Represents a custom user-defined function that can be registered on a type + /// + public class CustomUndecoratedFunction : UndecoratedFunction + { + private readonly ModuleInfo _module; + private readonly nuint _address; + private readonly string _retType; + private readonly string[] _argTypes; + + public CustomUndecoratedFunction( + ModuleInfo module, + ulong offset, + string functionName, + string returnType, + string[] argTypes) + : base(functionName, functionName, functionName, argTypes?.Length) + { + _module = module; + + // Check for potential overflow when adding baseAddress and offset + checked + { + _address = module.BaseAddress + (nuint)offset; + } + + _retType = returnType; + _argTypes = argTypes ?? Array.Empty(); + } + + public override ModuleInfo Module => _module; + public override nuint Address => _address; + public override string RetType => _retType; + public override string[] ArgTypes => _argTypes; + } +} diff --git a/src/ScubaDiver/MsvcPrimitives/MsvcTypesManager.cs b/src/ScubaDiver/MsvcPrimitives/MsvcTypesManager.cs index 0b4d22b..8a27c02 100644 --- a/src/ScubaDiver/MsvcPrimitives/MsvcTypesManager.cs +++ b/src/ScubaDiver/MsvcPrimitives/MsvcTypesManager.cs @@ -1,4 +1,5 @@ using NtApiDotNet.Win32; +using ScubaDiver.API.Utils; using ScubaDiver.Rtti; using System; using System.Collections.Generic; @@ -135,6 +136,7 @@ public class MsvcType : Type private MsvcModule _module; private MsvcMethod[] _methods; private VftableInfo[] _vftables; + private List _customMethods = new List(); public Rtti.TypeInfo TypeInfo { get; set; } public MsvcType(MsvcModule module, Rtti.TypeInfo typeInfo) @@ -152,6 +154,11 @@ public void SetVftables(VftableInfo[] vftables) _vftables = vftables; } + public void AddCustomMethod(MsvcMethod method) + { + _customMethods.Add(method); + } + public override MsvcModule Module => _module; public override string Name => TypeInfo.Name; @@ -159,7 +166,12 @@ public void SetVftables(VftableInfo[] vftables) public override string FullName => TypeInfo.FullTypeName; - public override MsvcMethod[] GetMethods(BindingFlags bindingAttr) => _methods; + public override MsvcMethod[] GetMethods(BindingFlags bindingAttr) + { + if (_methods == null) + return _customMethods.ToArray(); + return _methods.Concat(_customMethods).ToArray(); + } public new MsvcMethod[] GetMethods() => GetMethods(0); public VftableInfo[] GetVftables() => _vftables; public override MemberInfo[] GetMembers(BindingFlags bindingAttr) => [.. GetVftables(), .. GetMethods()]; @@ -467,6 +479,65 @@ bool IsNotExport(ulong addr) return res; } } + + /// + /// Registers a custom function on a type + /// + public bool RegisterCustomFunction( + string parentTypeFullName, + string parentAssembly, + string functionName, + string moduleName, + ulong offset, + string returnTypeFullName, + string[] argTypeFullNames) + { + try + { + // Get the parent type + Predicate moduleFilter = Filter.CreatePredicate(parentAssembly); + Predicate typeFilter = Filter.CreatePredicate(parentTypeFullName); + MsvcTypeStub typeStub = GetType(moduleFilter, typeFilter); + if (typeStub == null) + return false; + + MsvcType parentType = typeStub.Upgrade(); + if (parentType == null) + return false; + + // Find the module base address + List modules = GetUndecoratedModules(Filter.CreatePredicate(moduleName)); + if (modules.Count == 0) + { + modules = GetUndecoratedModules(Filter.CreatePredicate(moduleName + ".dll")); + if (modules.Count == 0) + return false; + } + + // Since we've verified the list is not empty, we can safely access the first element + UndecoratedModule targetModule = modules[0]; + + // Create a custom undecorated function + CustomUndecoratedFunction customFunc = new CustomUndecoratedFunction( + targetModule.ModuleInfo, + offset, + functionName, + returnTypeFullName, + argTypeFullNames); + + // Create an MsvcMethod from the custom function and add it to the type + MsvcMethod customMethod = new MsvcMethod(parentType, customFunc); + parentType.AddCustomMethod(customMethod); + + return true; + } + catch (Exception ex) + { + // Log the exception for debugging purposes + Logger.Debug($"[MsvcTypesManager][RegisterCustomFunction] Failed to register custom function. Exception: {ex}"); + return false; + } + } } public class MsvcModuleExports diff --git a/src/ScubaDiver/project_net6_x64/ScubaDiver_Net6_x64.csproj b/src/ScubaDiver/project_net6_x64/ScubaDiver_Net6_x64.csproj index cb1445b..877fc3b 100644 --- a/src/ScubaDiver/project_net6_x64/ScubaDiver_Net6_x64.csproj +++ b/src/ScubaDiver/project_net6_x64/ScubaDiver_Net6_x64.csproj @@ -49,6 +49,7 @@ + diff --git a/src/ScubaDiver/project_net6_x86/ScubaDiver_Net6_x86.csproj b/src/ScubaDiver/project_net6_x86/ScubaDiver_Net6_x86.csproj index c99e5cd..fc1a424 100644 --- a/src/ScubaDiver/project_net6_x86/ScubaDiver_Net6_x86.csproj +++ b/src/ScubaDiver/project_net6_x86/ScubaDiver_Net6_x86.csproj @@ -49,6 +49,7 @@ +