Skip to content

Commit

Permalink
dialects: allow for default regions in inits (#1040)
Browse files Browse the repository at this point in the history
Another offshoot of #1017, wanted to discuss some extra inits directly.
This PR allows the creation of ops without the regions populated, to
allow for this sort of thing once 1017 is merged:

``` python
with ImplicitBuilder(toy.FuncOp("main", ((), ())).body):
  ...
```
  • Loading branch information
superlopuh committed Jun 2, 2023
1 parent 90747e5 commit 680c7f4
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 12 deletions.
9 changes: 8 additions & 1 deletion docs/Toy/toy/dialects/toy.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,19 @@ class FuncOp(IRDLOperation):
sym_visibility: OptOpAttr[StringAttr]

def __init__(
self, name: str, ftype: FunctionType, region: Region, /, private: bool = False
self,
name: str,
ftype: FunctionType,
region: Region | type[Region.DEFAULT] = Region.DEFAULT,
/,
private: bool = False,
):
attributes: dict[str, Attribute] = {
"sym_name": StringAttr(name),
"function_type": ftype,
}
if not isinstance(region, Region):
region = Region(Block(arg_types=ftype.inputs))
if private:
attributes["sym_visibility"] = StringAttr("private")

Expand Down
5 changes: 5 additions & 0 deletions tests/dialects/test_pdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ def test_build_rewrite():
assert r.attributes["name"] == StringAttr("r")
assert r.external_args == (type_val, attr_val)
assert len(r.results) == 0
assert r.body is None

r1 = pdl.RewriteOp(name="r", root=None, external_args=[type_val, attr_val])

assert r1.body is not None


def test_build_operation_replace():
Expand Down
6 changes: 4 additions & 2 deletions xdsl/dialects/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,16 @@ def __init__(
self,
name: str,
function_type: FunctionType | tuple[Sequence[Attribute], Sequence[Attribute]],
region: Region,
region: Region | type[Region.DEFAULT] = Region.DEFAULT,
visibility: StringAttr | str | None = None,
):
if isinstance(visibility, str):
visibility = StringAttr(visibility)
if isinstance(function_type, tuple):
inputs, outputs = function_type
function_type = FunctionType.from_lists(inputs, outputs)
if not isinstance(region, Region):
region = Region(Block(arg_types=function_type.inputs))
attributes: dict[str, Attribute | None] = {
"sym_name": StringAttr(name),
"function_type": function_type,
Expand Down Expand Up @@ -199,7 +201,7 @@ def from_region(
name: str,
input_types: Sequence[Attribute],
return_types: Sequence[Attribute],
region: Region,
region: Region | type[Region.DEFAULT] = Region.DEFAULT,
visibility: StringAttr | str | None = None,
) -> FuncOp:
return FuncOp(
Expand Down
26 changes: 17 additions & 9 deletions xdsl/dialects/pdl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Annotated, Generic, Iterable, Sequence, TypeVar
from typing import Annotated, Generic, Iterable, Sequence, TypeVar, cast

from xdsl.dialects.builtin import (
AnyArrayAttr,
Expand Down Expand Up @@ -472,13 +472,15 @@ def __init__(
self,
benefit: int | IntegerAttr[IntegerType],
sym_name: str | StringAttr | None,
body: Region | Block.BlockCallback,
body: Region | Block.BlockCallback | None = None,
):
if isinstance(benefit, int):
benefit = IntegerAttr(benefit, 16)
if isinstance(sym_name, str):
sym_name = StringAttr(sym_name)
if not isinstance(body, Region):
if body is None:
body = Region(Block())
elif not isinstance(body, Region):
body = Region(Block.from_callable([], body))
super().__init__(
attributes={
Expand Down Expand Up @@ -743,7 +745,10 @@ def verify_(self) -> None:
def __init__(
self,
root: SSAValue | None,
body: Region | Block.BlockCallback | None = None,
body: Region
| Block.BlockCallback
| type[Region.DEFAULT]
| None = Region.DEFAULT,
name: str | StringAttr | None = None,
external_args: Sequence[SSAValue] = (),
) -> None:
Expand All @@ -758,12 +763,15 @@ def __init__(
operands.append(external_args)

regions: list[Region | list[Region]] = []
if isinstance(body, Region):
regions.append([body])
elif body is not None:
regions.append(Region(Block.from_callable([], body)))
else:
if body is Region.DEFAULT:
regions.append(Region(Block()))
elif isinstance(body, Region):
regions.append(body)
elif body is None:
regions.append([])
else:
body = cast(Block.BlockCallback, body)
regions.append(Region(Block.from_callable([], body)))

attributes: dict[str, Attribute] = {}
if name is not None:
Expand Down
6 changes: 6 additions & 0 deletions xdsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1359,6 +1359,12 @@ def __hash__(self) -> int:
class Region(IRNode):
"""A region contains a CFG of blocks. Regions are contained in operations."""

class DEFAULT:
"""
A marker to be used as a default parameter to functions when a default
single-block region should be constructed.
"""

blocks: list[Block] = field(default_factory=list)
"""Blocks contained in the region. The first block is the entry block."""

Expand Down

0 comments on commit 680c7f4

Please sign in to comment.