Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 87 additions & 75 deletions generator/native_generator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self):

def add_line(self, text: str = ""):
if text.strip():
self.lines.append(" " * self.indent_level + text)
self.lines.append(" " * self.indent_level + text)
else:
self.lines.append("")

Expand All @@ -62,14 +62,15 @@ def dedent(self):
self.indent_level = max(0, self.indent_level - 1)

def add_block(self, header: str, content_func):
self.add_line(header + " {")
self.add_line(header)
self.add_line("{")
self.indent()
content_func()
self.dedent()
self.add_line("}")

def get_code(self) -> str:
return "\n".join(self.lines)
return "\r\n".join(self.lines)


def split_by_last_dot(value: str):
Expand All @@ -94,7 +95,6 @@ def parse_native(lines: list[str]):
writer.add_line("#pragma warning disable CS0649")
writer.add_line("#pragma warning disable CS0169")
writer.add_line()
writer.add_line("using System.Buffers;")
writer.add_line("using System.Text;")
writer.add_line("using System.Threading;")
writer.add_line("using SwiftlyS2.Shared.Natives;")
Expand Down Expand Up @@ -158,24 +158,68 @@ def write_method_content():
if is_marked_sync:
writer.add_block("if (Thread.CurrentThread.ManagedThreadId != _MainThreadID)", lambda: writer.add_line('throw new InvalidOperationException("This method can only be called from the main thread.");'))

string_params = []
bytes_params = []
pool_declared = False

for t, n in param_signatures:
if t == "string":
if not pool_declared:
writer.add_line("var pool = ArrayPool<byte>.Shared;")
pool_declared = True
writer.add_line(f"var {n}Length = Encoding.UTF8.GetByteCount({n});")
writer.add_line(f"var {n}Buffer = pool.Rent({n}Length + 1);")
writer.add_line(f"Encoding.UTF8.GetBytes({n}, {n}Buffer);")
writer.add_line(f"{n}Buffer[{n}Length] = 0;")
string_params.append(n)
elif t == "byte[]":
writer.add_line(f"var {n}Length = {n}.Length;")
bytes_params.append(n)
string_params = [n for t, n in param_signatures if t == "string"]
bytes_params = [n for t, n in param_signatures if t == "byte[]"]

for param in bytes_params:
writer.add_line(f"var {param}Length = {param}.Length;")

if not string_params:
fixed_blocks = []
for param in bytes_params:
fixed_blocks.append(f"fixed (byte* {param}BufferPtr = {param})")

def write_simple_call():
call_args = []
for t, n in param_signatures:
if t == "byte[]":
call_args.extend([f"{n}BufferPtr", f"{n}Length"])
elif t == "bool":
call_args.append(f"{n} ? (byte)1 : (byte)0")
else:
call_args.append(n)

if is_buffer_return(return_type):
first_call_args = ["null"] + call_args
writer.add_line(f"var ret = _{function_name}({', '.join(first_call_args)});")
if return_type == "string":
writer.add_line("var retBuffer = new byte[ret + 1];")
else:
writer.add_line("var retBuffer = new byte[ret];")

def write_ret_fixed():
second_call_args = ["retBufferPtr"] + call_args
writer.add_line(f"ret = _{function_name}({', '.join(second_call_args)});")
if return_type == "string":
writer.add_line("return Encoding.UTF8.GetString(retBufferPtr, ret);")
else:
writer.add_line("var retBytes = new byte[ret];")
writer.add_line("for (int i = 0; i < ret; i++) retBytes[i] = retBufferPtr[i];")
writer.add_line("return retBytes;")

writer.add_block("fixed (byte* retBufferPtr = retBuffer)", write_ret_fixed)
else:
if return_type == "void":
writer.add_line(f"_{function_name}({', '.join(call_args)});")
else:
writer.add_line(f"var ret = _{function_name}({', '.join(call_args)});")
writer.add_line("return ret == 1;" if return_type == "bool" else "return ret;")

def write_with_fixed_blocks(blocks, index=0):
if index < len(blocks):
writer.add_block(blocks[index], lambda: write_with_fixed_blocks(blocks, index + 1))
else:
write_simple_call()

if fixed_blocks:
write_with_fixed_blocks(fixed_blocks)
else:
write_simple_call()
return

for param in string_params:
writer.add_line(f"byte[] {param}Buffer = Encoding.UTF8.GetBytes({param} + \"\\0\");")

fixed_blocks = []
for param in string_params:
fixed_blocks.append(f"fixed (byte* {param}BufferPtr = {param}Buffer)")
Expand All @@ -184,77 +228,42 @@ def write_method_content():

def write_native_call():
call_args = []
for t, n in param_signatures:
if t == "string":
call_args.append(f"{n}BufferPtr")
elif t == "byte[]":
call_args.extend([f"{n}BufferPtr", f"{n}Length"])
elif t == "bool":
call_args.append(f"{n} ? (byte)1 : (byte)0")
else:
call_args.append(n)

if is_buffer_return(return_type):
first_call_args = ["null"]
for t, n in param_signatures:
if t == "string":
first_call_args.append(f"{n}BufferPtr")
elif t == "byte[]":
first_call_args.extend([f"{n}BufferPtr", f"{n}Length"])
elif t == "bool":
first_call_args.append(f"{n} ? (byte)1 : (byte)0")
else:
first_call_args.append(n)

first_call_args = ["null"] + call_args
writer.add_line(f"var ret = _{function_name}({', '.join(first_call_args)});")

if not pool_declared:
writer.add_line("var pool = ArrayPool<byte>.Shared;")
writer.add_line("var retBuffer = pool.Rent(ret + 1);")
if return_type == "string":
writer.add_line("var retBuffer = new byte[ret + 1];")
else:
writer.add_line("var retBuffer = new byte[ret];")

def write_ret_fixed():
second_call_args = ["retBufferPtr"]
for t, n in param_signatures:
if t == "string":
second_call_args.append(f"{n}BufferPtr")
elif t == "byte[]":
second_call_args.extend([f"{n}BufferPtr", f"{n}Length"])
elif t == "bool":
second_call_args.append(f"{n} ? (byte)1 : (byte)0")
else:
second_call_args.append(n)

second_call_args = ["retBufferPtr"] + call_args
writer.add_line(f"ret = _{function_name}({', '.join(second_call_args)});")

if return_type == "string":
writer.add_line("var retString = Encoding.UTF8.GetString(retBufferPtr, ret);")
writer.add_line("pool.Return(retBuffer);")
for param in string_params:
writer.add_line(f"pool.Return({param}Buffer);")
writer.add_line("return retString;")
writer.add_line("return Encoding.UTF8.GetString(retBufferPtr, ret);")
else:
writer.add_line("var retBytes = new byte[ret];")
writer.add_line("for (int i = 0; i < ret; i++) retBytes[i] = retBufferPtr[i];")
writer.add_line("pool.Return(retBuffer);")
for param in string_params:
writer.add_line(f"pool.Return({param}Buffer);")
writer.add_line("return retBytes;")

writer.add_block("fixed (byte* retBufferPtr = retBuffer)", write_ret_fixed)

else:
for t, n in param_signatures:
if t == "string":
call_args.append(f"{n}BufferPtr")
elif t == "byte[]":
call_args.extend([f"{n}BufferPtr", f"{n}Length"])
elif t == "bool":
call_args.append(f"{n} ? (byte)1 : (byte)0")
else:
call_args.append(n)

if return_type == "void":
writer.add_line(f"_{function_name}({', '.join(call_args)});")
else:
writer.add_line(f"var ret = _{function_name}({', '.join(call_args)});")

for param in string_params:
writer.add_line(f"pool.Return({param}Buffer);")

if return_type != "void":
writer.add_line("return ret == 1;" if return_type == "bool" else f"return ret;")

writer.add_line("return ret == 1;" if return_type == "bool" else "return ret;")

def write_with_fixed_blocks(blocks, index=0):
if index < len(blocks):
writer.add_block(blocks[index], lambda: write_with_fixed_blocks(blocks, index + 1))
Expand All @@ -267,7 +276,10 @@ def write_with_fixed_blocks(blocks, index=0):
write_native_call()

writer.add_block(f"public unsafe static {RETURN_TYPE_MAP[return_type]} {function_name}({method_signature})", write_method_content)
writer.add_block(f"internal static class Native{class_name}", write_class_content)

# Benchmark class should be public for external access
access_modifier = "public" if class_name == "Benchmark" else "internal"
writer.add_block(f"{access_modifier} static class Native{class_name}", write_class_content)

with open(out_path, "w", encoding="utf-8", newline="") as f:
f.write(writer.get_code())
Expand Down
Loading
Loading