Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
351 changes: 166 additions & 185 deletions managed/src/SwiftlyS2.Core/Modules/Hooks/HookManager.cs
Original file line number Diff line number Diff line change
@@ -1,228 +1,209 @@
using System.Collections.Concurrent;
using System.Runtime.InteropServices;
using System.Runtime.CompilerServices;
using Spectre.Console;
using SwiftlyS2.Core.Natives;
using SwiftlyS2.Shared.Memory;

namespace SwiftlyS2.Core.Hooks;

[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
internal unsafe delegate void MidHookInternalDelegate( void* contextPtr );

internal class HookManager
{
private class HookNode
{
public required Guid Id { get; init; }
public nint HookHandle { get; set; }
public nint OriginalFuncPtr { get; set; }
public required Func<Func<nint>, Delegate> CallbackBuilder { get; init; }
public Delegate? BuiltDelegate { get; set; }
public nint BuiltPointer { get; set; }
}

private class HookNode
{
public required Guid Id { get; init; }

public nint HookHandle { get; set; }
public nint OriginalFuncPtr { get; set; }
public required Func<Func<nint>, Delegate> CallbackBuilder { get; init; }
public Delegate? BuiltDelegate { get; set; }
public nint BuiltPointer { get; set; }
}

private class MidHookNode
{
public required Guid Id { get; init; }
public nint HookHandle { get; set; }
public required MidHookDelegate BuiltDelegate { get; init; }
}

private class HookChain
{
public bool Hooked { get; set; } = false;
public required nint FunctionAddress { get; set; }
public nint HookHandle { get; set; }
public nint OriginalFunctionAddress { get; set; }
public List<HookNode> Nodes { get; } = new();
}

private class MidHookChain
{
public bool Hooked { get; set; } = false;
public required nint Address { get; set; }
public nint HookHandle { get; set; }
public List<MidHookNode> Nodes { get; } = new();
}

private readonly Lock _sync = new();
private readonly Dictionary<nint, HookChain> _chains = new();
private readonly Dictionary<nint, MidHookChain> _midChains = new();

public bool IsMidHooked( nint address )
{
lock (_sync)
private class MidHookNode
{
return _midChains.TryGetValue(address, out var chain) && chain.Hooked;
public required Guid Id { get; init; }
public nint HookHandle { get; set; }
public required MidHookDelegate BuiltDelegate { get; init; }
}
}

public bool IsHooked( nint functionAddress )
{
lock (_sync)
private class HookChain
{
return _chains.TryGetValue(functionAddress, out var chain) && chain.Hooked;
public bool Hooked { get; set; } = false;
public required nint FunctionAddress { get; set; }
public nint HookHandle { get; set; }
public nint OriginalFunctionAddress { get; set; }
public List<HookNode> Nodes { get; } = [];
}
}

public nint GetOriginal( nint functionAddress )
{
lock (_sync)
private class MidHookChain
{
if (_chains.TryGetValue(functionAddress, out var chain))
{
if (!chain.Hooked)
{
return functionAddress;
}
if (chain.Nodes.Count == 0)
{
return chain.OriginalFunctionAddress;
}
return chain.Nodes[^1].OriginalFuncPtr;
}
return nint.Zero;
public bool Hooked { get; set; } = false;
public required nint Address { get; set; }
public nint HookHandle { get; set; }
public List<MidHookNode> Nodes { get; } = [];
public MidHookInternalDelegate? InternalCallback { get; set; } // Keep delegate alive
}

private readonly ConcurrentDictionary<nint, HookChain> chains = new();
private readonly ConcurrentDictionary<nint, MidHookChain> midChains = new();

public bool IsMidHooked( nint address )
{
return midChains.TryGetValue(address, out var chain) && chain.Hooked;
}

public bool IsHooked( nint functionAddress )
{
return chains.TryGetValue(functionAddress, out var chain) && chain.Hooked;
}
}

public Guid AddMidHook( nint address, MidHookDelegate callback )
{
MidHookChain chain;
MidHookNode node = new MidHookNode {
Id = Guid.NewGuid(),
BuiltDelegate = callback,
};
public nint GetOriginal( nint functionAddress )
{
return chains.TryGetValue(functionAddress, out var chain)
? !chain.Hooked ? functionAddress : chain.Nodes.Count == 0 ? chain.OriginalFunctionAddress : chain.Nodes[^1].OriginalFuncPtr
: nint.Zero;
}

lock (_sync)
public Guid AddMidHook( nint address, MidHookDelegate callback )
{
if (!_midChains.TryGetValue(address, out chain))
{
chain = new MidHookChain { Address = address };
chain.HookHandle = NativeHooks.AllocateMHook();
MidHookDelegate _unmanagedCallback = ( ref MidHookContext ctx ) =>
var node = new MidHookNode {
Id = Guid.NewGuid(),
BuiltDelegate = callback,
};

if (!midChains.TryGetValue(address, out var chain))
{
try
{
foreach (var n in chain.Nodes)
chain = new MidHookChain {
Address = address,
HookHandle = NativeHooks.AllocateMHook()
};

MidHookInternalDelegate internalCallback;

unsafe
{
n.BuiltDelegate(ref ctx);
internalCallback = ( contextPtr ) =>
{
try
{
ref var ctx = ref Unsafe.AsRef<MidHookContext>(contextPtr);
foreach (var n in chain.Nodes)
{
n.BuiltDelegate(ref ctx);
}
}
catch (Exception e)
{
if (!GlobalExceptionHandler.Handle(e)) return;
}
};
}
}
catch (Exception e)
{
if (!GlobalExceptionHandler.Handle(e)) return;
}
};
NativeHooks.SetMHook(chain.HookHandle, address, Marshal.GetFunctionPointerForDelegate(_unmanagedCallback));
NativeHooks.EnableMHook(chain.HookHandle);
chain.Hooked = true;
_midChains[address] = chain;
}
chain.Nodes.Add(node);
}

return node.Id;
}
// Keep delegate alive to prevent GC
chain.InternalCallback = internalCallback;
var callbackPtr = Marshal.GetFunctionPointerForDelegate(internalCallback);

public Guid AddHook( nint functionAddress, Func<Func<nint>, Delegate> callbackBuilder )
{
HookChain chain;
HookNode node = new HookNode {
Id = Guid.NewGuid(),
CallbackBuilder = callbackBuilder,
};
NativeHooks.SetMHook(chain.HookHandle, address, callbackPtr);
NativeHooks.EnableMHook(chain.HookHandle);

lock (_sync)
{
if (!_chains.TryGetValue(functionAddress, out chain))
{
chain = new HookChain { FunctionAddress = functionAddress };
_chains[functionAddress] = chain;
}
chain.Nodes.Add(node);
RebuildChain(chain);
}
chain.Hooked = true;
midChains[address] = chain;
}

return node.Id;
}
chain.Nodes.Add(node);
return node.Id;
}

public void RemoveMidHook( List<Guid> nodeIds )
{
lock (_sync)
public Guid AddHook( nint functionAddress, Func<Func<nint>, Delegate> callbackBuilder )
{
var chains = _midChains.Values.Where(c => c.Nodes.Any(n => nodeIds.Contains(n.Id))).ToList();
if (chains.Count == 0) return;
foreach (var chain in chains)
{
chain.Nodes.RemoveAll(n => nodeIds.Contains(n.Id));
}
var node = new HookNode {
Id = Guid.NewGuid(),
CallbackBuilder = callbackBuilder,
};

if (!chains.TryGetValue(functionAddress, out var chain))
{
chain = new HookChain { FunctionAddress = functionAddress };
chains[functionAddress] = chain;
}
chain.Nodes.Add(node);
RebuildChain(chain);

return node.Id;
}
}

public void Remove( List<Guid> nodeIds )
{
lock (_sync)
public void RemoveMidHook( List<Guid> nodeIds )
{
var chains = _chains.Values.Where(c => c.Nodes.Any(n => nodeIds.Contains(n.Id))).ToList();
if (chains.Count == 0) return;
foreach (var chain in chains)
{
chain.Nodes.RemoveAll(n => nodeIds.Contains(n.Id));
RebuildChain(chain);
}
midChains.Values.Where(c => c.Nodes.Any(n => nodeIds.Contains(n.Id))).ToList().ForEach(chain =>
{
_ = chain.Nodes.RemoveAll(n => nodeIds.Contains(n.Id));
});
}
}

private void RebuildChain( HookChain chain )
{
try
public void Remove( List<Guid> nodeIds )
{
// Rebuild delegates from first to last, wiring each to previous pointer (or original for first)
if (chain.Hooked)
{
for (int i = 0; i < chain.Nodes.Count; i++)
chains.Values.Where(c => c.Nodes.Any(n => nodeIds.Contains(n.Id))).ToList().ForEach(chain =>
{
chain.Nodes[i].BuiltDelegate = null;
chain.Nodes[i].BuiltPointer = nint.Zero;
if (chain.Nodes[i].HookHandle != 0)
{
NativeHooks.DeallocateHook(chain.Nodes[i].HookHandle);
chain.Nodes[i].HookHandle = 0;
}
}
chain.OriginalFunctionAddress = 0;
NativeHooks.DeallocateHook(chain.HookHandle);
chain.HookHandle = 0;
chain.Hooked = false;
}
chain.HookHandle = NativeHooks.AllocateHook();

for (int i = 0; i < chain.Nodes.Count; i++)
{
var node = chain.Nodes[i];

var built = node.CallbackBuilder.Invoke(() => node.OriginalFuncPtr);
node.BuiltDelegate = built;
node.BuiltPointer = Marshal.GetFunctionPointerForDelegate(node.BuiltDelegate);
if (i == 0)
_ = chain.Nodes.RemoveAll(n => nodeIds.Contains(n.Id));
RebuildChain(chain);
});
}

private void RebuildChain( HookChain chain )
{
try
{
NativeHooks.SetHook(chain.HookHandle, chain.FunctionAddress, node.BuiltPointer);
node.OriginalFuncPtr = NativeHooks.GetHookOriginal(chain.HookHandle);
chain.OriginalFunctionAddress = node.OriginalFuncPtr;
NativeHooks.EnableHook(chain.HookHandle);
chain.Hooked = true;
// Rebuild delegates from first to last, wiring each to previous pointer (or original for first)
if (chain.Hooked)
{
for (var i = 0; i < chain.Nodes.Count; i++)
{
chain.Nodes[i].BuiltDelegate = null;
chain.Nodes[i].BuiltPointer = nint.Zero;
if (chain.Nodes[i].HookHandle != 0)
{
NativeHooks.DeallocateHook(chain.Nodes[i].HookHandle);
chain.Nodes[i].HookHandle = 0;
}
}
chain.OriginalFunctionAddress = 0;
NativeHooks.DeallocateHook(chain.HookHandle);
chain.HookHandle = 0;
chain.Hooked = false;
}
chain.HookHandle = NativeHooks.AllocateHook();

for (var i = 0; i < chain.Nodes.Count; i++)
{
var node = chain.Nodes[i];

var built = node.CallbackBuilder.Invoke(() => node.OriginalFuncPtr);
node.BuiltDelegate = built;
node.BuiltPointer = Marshal.GetFunctionPointerForDelegate(node.BuiltDelegate);
if (i == 0)
{
NativeHooks.SetHook(chain.HookHandle, chain.FunctionAddress, node.BuiltPointer);
node.OriginalFuncPtr = NativeHooks.GetHookOriginal(chain.HookHandle);
chain.OriginalFunctionAddress = node.OriginalFuncPtr;
NativeHooks.EnableHook(chain.HookHandle);
chain.Hooked = true;
}
else
{
node.HookHandle = NativeHooks.AllocateHook();
NativeHooks.SetHook(node.HookHandle, chain.Nodes[i - 1].OriginalFuncPtr, node.BuiltPointer);
NativeHooks.EnableHook(node.HookHandle);
node.OriginalFuncPtr = NativeHooks.GetHookOriginal(node.HookHandle);
}
}
}
else
catch (Exception e)
{
node.HookHandle = NativeHooks.AllocateHook();
NativeHooks.SetHook(node.HookHandle, chain.Nodes[i - 1].OriginalFuncPtr, node.BuiltPointer);
NativeHooks.EnableHook(node.HookHandle);
node.OriginalFuncPtr = NativeHooks.GetHookOriginal(node.HookHandle);
if (!GlobalExceptionHandler.Handle(e)) return;
AnsiConsole.WriteException(e);
}
}
}
catch (Exception e)
{
if (!GlobalExceptionHandler.Handle(e)) return;
AnsiConsole.WriteException(e);
}
}
}
Loading
Loading