Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions codegen/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def prepare():


def update_api():
""" Update the public API and patch the public-facing API of the backends. """
"""Update the public API and patch the public-facing API of the backends."""

print("## Updating API")

Expand Down Expand Up @@ -50,7 +50,7 @@ def update_api():


def update_rs():
""" Update and check the rs backend. """
"""Update and check the rs backend."""

print("## Validating rs.py")

Expand All @@ -68,7 +68,7 @@ def update_rs():


def main():
""" Codegen entry point. """
"""Codegen entry point."""

with PrintToFile(os.path.join(lib_dir, "resources", "codegen_report.md")):
print("# Code generatation report")
Expand Down
77 changes: 71 additions & 6 deletions codegen/apipatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

import os

from .utils import print, lib_dir, blacken, to_snake_case, to_camel_case, Patcher
from .idlparser import get_idl_parser
from codegen.utils import print, lib_dir, blacken, to_snake_case, to_camel_case, Patcher
from codegen.idlparser import get_idl_parser


def patch_base_api(code):
Expand Down Expand Up @@ -42,7 +42,11 @@ def patch_backend_api(code):
base_api_code = f.read().decode()

# Patch!
for patcher in [CommentRemover(), BackendApiPatcher(base_api_code)]:
for patcher in [
CommentRemover(),
BackendApiPatcher(base_api_code),
StructValidationChecker(),
]:
patcher.apply(code)
code = patcher.dumps()
return code
Expand All @@ -53,7 +57,7 @@ class CommentRemover(Patcher):
to prevent accumulating comments.
"""

triggers = "# IDL:", "# FIXME: unknown api"
triggers = "# IDL:", "# FIXME: unknown api", "# FIXME: missing check_struct"

def apply(self, code):
self._init(code)
Expand Down Expand Up @@ -174,7 +178,7 @@ def patch_properties(self, classname, i1, i2):
self._apidiffs_from_lines(pre_lines, propname)
if self.prop_is_known(classname, propname):
if "@apidiff.add" in pre_lines:
print(f"Error: apidiff.add for known {classname}.{propname}")
print(f"ERROR: apidiff.add for known {classname}.{propname}")
elif "@apidiff.hide" in pre_lines:
pass # continue as normal
old_line = self.lines[j1]
Expand Down Expand Up @@ -207,7 +211,7 @@ def patch_methods(self, classname, i1, i2):
self._apidiffs_from_lines(pre_lines, methodname)
if self.method_is_known(classname, methodname):
if "@apidiff.add" in pre_lines:
print(f"Error: apidiff.add for known {classname}.{methodname}")
print(f"ERROR: apidiff.add for known {classname}.{methodname}")
elif "@apidiff.hide" in pre_lines:
pass # continue as normal
elif "@apidiff.change" in pre_lines:
Expand Down Expand Up @@ -443,3 +447,64 @@ def get_required_prop_names(self, classname):
def get_required_method_names(self, classname):
_, methods = self.classes[classname]
return list(name for name in methods.keys() if methods[name][1])


class StructValidationChecker(Patcher):
"""Checks that all structs are vaildated in the methods that have incoming structs."""

def apply(self, code):
self._init(code)

idl = get_idl_parser()
all_structs = set()
ignore_structs = {"Extent3D"}

for classname, i1, i2 in self.iter_classes():
if classname not in idl.classes:
continue

# For each method ...
for methodname, j1, j2 in self.iter_methods(i1 + 1):
code = "\n".join(self.lines[j1 : j2 + 1])
# Get signature and cut it up in words
sig_words = code.partition("(")[2].split("):")[0]
for c in "][(),\"'":
sig_words = sig_words.replace(c, " ")
# Collect incoming structs from signature
method_structs = set()
for word in sig_words.split():
if word.startswith("structs."):
structname = word.partition(".")[2]
method_structs.update(self._get_sub_structs(idl, structname))
all_structs.update(method_structs)
# Collect structs being checked
checked = set()
for line in code.splitlines():
line = line.lstrip()
if line.startswith("check_struct("):
name = line.split("(")[1].split(",")[0].strip('"')
checked.add(name)
# Test that a matching check is done
unchecked = method_structs.difference(checked)
unchecked = list(sorted(unchecked.difference(ignore_structs)))
if (
methodname.endswith("_async")
and f"return self.{methodname[:-7]}" in code
):
pass
elif unchecked:
msg = f"missing check_struct in {methodname}: {unchecked}"
self.insert_line(j1, f"# FIXME: {msg}")
print(f"ERROR: {msg}")

# Test that we did find structs. In case our detection fails for
# some reason, this would probably catch that.
assert len(all_structs) > 10

def _get_sub_structs(self, idl, structname):
structnames = {structname}
for structfield in idl.structs[structname].values():
structname2 = structfield.typename[3:] # remove "GPU"
if structname2 in idl.structs:
structnames.update(self._get_sub_structs(idl, structname2))
return structnames
2 changes: 1 addition & 1 deletion codegen/hparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


def get_h_parser(*, allow_cache=True):
""" Get the global HParser object. """
"""Get the global HParser object."""

# Singleton pattern
global _parser
Expand Down
2 changes: 1 addition & 1 deletion codegen/idlparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


def get_idl_parser(*, allow_cache=True):
""" Get the global IdlParser object. """
"""Get the global IdlParser object."""

# Singleton pattern
global _parser
Expand Down
6 changes: 3 additions & 3 deletions codegen/rspatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def apply(self, code):
if name not in hp.functions:
msg = f"unknown C function {name}"
self.insert_line(i, f"{indent}# FIXME: {msg}")
print(f"Error: {msg}")
print(f"ERROR: {msg}")
else:
detected.add(name)
anno = hp.functions[name].replace(name, "f").strip(";")
Expand Down Expand Up @@ -302,7 +302,7 @@ def _validate_struct(self, hp, i1, i2):
if struct_name not in hp.structs:
msg = f"unknown C struct {struct_name}"
self.insert_line(i1, f"{indent}# FIXME: {msg}")
print(f"Error: {msg}")
print(f"ERROR: {msg}")
return
else:
struct = hp.structs[struct_name]
Expand All @@ -322,7 +322,7 @@ def _validate_struct(self, hp, i1, i2):
if key not in struct:
msg = f"unknown C struct field {struct_name}.{key}"
self.insert_line(i1 + j, f"{indent}# FIXME: {msg}")
print(f"Error: {msg}")
print(f"ERROR: {msg}")

# Insert comments for unused keys
more_lines = []
Expand Down
10 changes: 10 additions & 0 deletions codegen/tests/test_codegen_z.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,15 @@ def test_that_code_is_up_to_date():
print("Codegen check ok!")


def test_that_codegen_report_has_no_errors():
filename = os.path.join(lib_dir, "resources", "codegen_report.md")
with open(filename, "rb") as f:
text = f.read().decode()

# The codegen uses a prefix "ERROR:" for unacceptable things.
# All caps, some function names may contain the name "error".
assert "ERROR" not in text


if __name__ == "__main__":
test_that_code_is_up_to_date()
2 changes: 1 addition & 1 deletion codegen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def print(*args, **kwargs):


class PrintToFile:
""" Context manager to print to file. """
"""Context manager to print to file."""

def __init__(self, f):
if isinstance(f, str):
Expand Down
2 changes: 1 addition & 1 deletion examples/cube_glfw.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def draw_frame():


def simple_event_loop():
""" A real simple event loop, but it keeps the CPU busy. """
"""A real simple event loop, but it keeps the CPU busy."""
while update_glfw_canvasses():
glfw.poll_events()

Expand Down
4 changes: 2 additions & 2 deletions examples/triangle_glfw.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@


def simple_event_loop():
""" A real simple event loop, but it keeps the CPU busy. """
"""A real simple event loop, but it keeps the CPU busy."""
while update_glfw_canvasses():
glfw.poll_events()


def better_event_loop(max_fps=100):
""" A simple event loop that schedules draws. """
"""A simple event loop that schedules draws."""
td = 1 / max_fps
while update_glfw_canvasses():
# Determine next time to draw
Expand Down
9 changes: 8 additions & 1 deletion tests/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,13 @@ def compute_shader(
)
bind_group = device.create_bind_group(layout=bind_group_layout, entries=bindings)

# Create and run the pipeline, fail - test check_struct
with raises(ValueError):
compute_pipeline = device.create_compute_pipeline(
layout=pipeline_layout,
compute={"module": cshader, "entry_point": "main", "foo": 42},
)

# Create and run the pipeline
compute_pipeline = device.create_compute_pipeline(
layout=pipeline_layout,
Expand Down Expand Up @@ -259,7 +266,7 @@ def compute_shader(
compute_with_buffers({0: in1}, {0: c_int32 * 100}, compute_shader, n=-1)

with raises(TypeError): # invalid shader
compute_with_buffers({0: in1}, {0: c_int32 * 100}, "not a shader")
compute_with_buffers({0: in1}, {0: c_int32 * 100}, {"not", "a", "shader"})


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion tests/test_rs_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def test_shader_module_creation():
with raises(TypeError):
device.create_shader_module(code=code4)
with raises(TypeError):
device.create_shader_module(code="not a shader")
device.create_shader_module(code={"not", "a", "shader"})
with raises(ValueError):
device.create_shader_module(code=b"bytes but no SpirV magic number")

Expand Down
4 changes: 2 additions & 2 deletions tests/test_rs_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,13 +511,13 @@ def cb(renderpass):
format=wgpu.TextureFormat.depth24plus_stencil8,
depth_write_enabled=True,
depth_compare=wgpu.CompareFunction.less_equal,
front={
stencil_front={
"compare": wgpu.CompareFunction.equal,
"fail_op": wgpu.StencilOperation.keep,
"depth_fail_op": wgpu.StencilOperation.keep,
"pass_op": wgpu.StencilOperation.keep,
},
back={
stencil_back={
"compare": wgpu.CompareFunction.equal,
"fail_op": wgpu.StencilOperation.keep,
"depth_fail_op": wgpu.StencilOperation.keep,
Expand Down
Loading