Skip to content

Commit

Permalink
Deprecating .get in llvm dialect
Browse files Browse the repository at this point in the history
Making respective changes from .get to __init__ in
mpi lowering and tests
  • Loading branch information
ShaolunWang committed Jul 26, 2023
1 parent d8f4b59 commit 640af51
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 60 deletions.
30 changes: 15 additions & 15 deletions tests/dialects/test_llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ def test_llvm_pointer_ops():
module = builtin.ModuleOp(
[
idx := arith.Constant.from_int_and_width(0, 64),
ptr := llvm.AllocaOp.get(idx, builtin.i32),
val := llvm.LoadOp.get(ptr),
nullptr := llvm.NullOp.get(),
alloc_ptr := llvm.AllocaOp.get(idx, elem_type=builtin.IndexType()),
llvm.LoadOp.get(alloc_ptr),
store := llvm.StoreOp.get(
ptr := llvm.AllocaOp(idx, builtin.i32),
val := llvm.LoadOp(ptr),
nullptr := llvm.NullOp(),
alloc_ptr := llvm.AllocaOp(idx, elem_type=builtin.IndexType()),
llvm.LoadOp(alloc_ptr),
store := llvm.StoreOp(
val, ptr, alignment=32, volatile=True, nontemporal=True
),
]
Expand All @@ -42,8 +42,8 @@ def test_llvm_pointer_ops():

def test_llvm_ptr_to_int_to_ptr():
idx = arith.Constant.from_int_and_width(0, 64)
ptr = llvm.IntToPtrOp.get(idx, ptr_type=builtin.i32)
int_val = llvm.PtrToIntOp.get(ptr)
ptr = llvm.IntToPtrOp(idx, ptr_type=builtin.i32)
int_val = llvm.PtrToIntOp(ptr)

assert ptr.input == idx.result
assert isinstance(ptr.output.type, llvm.LLVMPointerType)
Expand All @@ -67,19 +67,19 @@ def test_llvm_pointer_type():

def test_llvm_getelementptr_op_invalid_construction():
size = arith.Constant.from_int_and_width(1, 32)
opaque_ptr = llvm.AllocaOp.get(size, builtin.i32, as_untyped_ptr=True)
opaque_ptr = llvm.AllocaOp(size, builtin.i32, as_untyped_ptr=True)

# check that passing an opaque pointer to GEP without a pointee type fails
with pytest.raises(ValueError):
llvm.GEPOp.get(
llvm.GEPOp(
opaque_ptr,
indices=[1],
result_type=llvm.LLVMPointerType.typed(builtin.i32),
)

# check that non-pointer arguments fail
with pytest.raises(ValueError):
llvm.GEPOp.get(
llvm.GEPOp(
size,
indices=[1],
result_type=llvm.LLVMPointerType.opaque(),
Expand All @@ -88,9 +88,9 @@ def test_llvm_getelementptr_op_invalid_construction():

def test_llvm_getelementptr_op():
size = arith.Constant.from_int_and_width(1, 32)
ptr = llvm.AllocaOp.get(size, builtin.i32)
ptr = llvm.AllocaOp(size, builtin.i32)
ptr_type = llvm.LLVMPointerType.typed(ptr.res.type)
opaque_ptr = llvm.AllocaOp.get(size, builtin.i32, as_untyped_ptr=True)
opaque_ptr = llvm.AllocaOp(size, builtin.i32, as_untyped_ptr=True)

# check that construction with static-only offsets and inbounds attr works:
gep1 = llvm.GEPOp.from_mixed_indices(
Expand Down Expand Up @@ -149,7 +149,7 @@ def test_linkage_attr_unknown_str():


def test_global_op():
global_op = llvm.GlobalOp.get(
global_op = llvm.GlobalOp(
builtin.i32,
"testsymbol",
"internal",
Expand Down Expand Up @@ -179,7 +179,7 @@ def test_global_op():

def test_addressof_op():
ptr_type = llvm.LLVMPointerType.typed(builtin.i32)
address_of = llvm.AddressOfOp.get("test", ptr_type)
address_of = llvm.AddressOfOp("test", ptr_type)

assert isinstance(address_of.global_name, builtin.SymbolRefAttr)
assert address_of.global_name.root_reference.data == "test"
Expand Down
155 changes: 125 additions & 30 deletions xdsl/dialects/llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from xdsl.parser import AttrParser, Parser
from xdsl.printer import Printer
from xdsl.traits import IsTerminator, SymbolOpInterface
from xdsl.utils.deprecation import deprecated
from xdsl.utils.exceptions import VerifyException
from xdsl.utils.hints import isa

Expand Down Expand Up @@ -455,8 +456,8 @@ class GEPOp(IRDLOperation):
rawConstantIndices: DenseArrayBase = attr_def(DenseArrayBase)
inbounds: UnitAttr | None = opt_attr_def(UnitAttr)

@staticmethod
def get(
def __init__(
self,
ptr: SSAValue | Operation,
indices: Sequence[int],
ssa_indices: Sequence[SSAValue | Operation] | None = None,
Expand All @@ -473,7 +474,6 @@ def get(
Take a look at `from_mixed_indices` for something without
magic values.
"""
# construct default mutable argument here:
if ssa_indices is None:
ssa_indices = []

Expand All @@ -499,10 +499,32 @@ def get(
if inbounds:
attrs["inbounds"] = UnitAttr()

return GEPOp.build(
super().__init__(
operands=[ptr, ssa_indices], result_types=[result_type], attributes=attrs
)

@deprecated("Use GEPOp(...) instead")
@staticmethod
def get(
ptr: SSAValue | Operation,
indices: Sequence[int],
ssa_indices: Sequence[SSAValue | Operation] | None = None,
result_type: LLVMPointerType = LLVMPointerType.opaque(),
inbounds: bool = False,
pointee_type: Attribute | None = None,
):
"""
A basic constructor for the GEPOp.
Pass the GEP_USE_SSA_VAL magic value in place of each constant
index that you want to be read from an SSA value.
Take a look at `from_mixed_indices` for something without
magic values.
"""
# construct default mutable argument here:
return GEPOp(ptr, indices, ssa_indices, result_type, inbounds, pointee_type)

@staticmethod
def from_mixed_indices(
ptr: SSAValue | Operation,
Expand All @@ -528,7 +550,7 @@ def from_mixed_indices(
else:
const_indices.append(GEP_USE_SSA_VAL)
ssa_indices.append(SSAValue.get(idx))
return GEPOp.get(
return GEPOp(
ptr,
const_indices,
ssa_indices,
Expand All @@ -548,8 +570,8 @@ class AllocaOp(IRDLOperation):

res: OpResult = result_def()

@staticmethod
def get(
def __init__(
self,
size: SSAValue | Operation,
elem_type: Attribute,
alignment: int = 32,
Expand All @@ -564,9 +586,17 @@ def get(
else:
ptr_type = LLVMPointerType.typed(elem_type)

return AllocaOp.build(
operands=[size], attributes=attrs, result_types=[ptr_type]
)
super().__init__(operands=[size], attributes=attrs, result_types=[ptr_type])

@deprecated("Use Alloca(...) instead")
@staticmethod
def get(
size: SSAValue | Operation,
elem_type: Attribute,
alignment: int = 32,
as_untyped_ptr: bool = False,
):
return AllocaOp(size, elem_type, alignment, as_untyped_ptr)


@irdl_op_definition
Expand All @@ -577,13 +607,17 @@ class IntToPtrOp(IRDLOperation):

output: OpResult = result_def(LLVMPointerType)

@staticmethod
def get(input: SSAValue | Operation, ptr_type: Attribute | None = None):
def __init__(self, input: SSAValue | Operation, ptr_type: Attribute | None = None):
if ptr_type is None:
ptr_type = LLVMPointerType.opaque()
else:
ptr_type = LLVMPointerType.typed(ptr_type)
return IntToPtrOp.build(operands=[input], result_types=[ptr_type])
super().__init__(operands=[input], result_types=[ptr_type])

@deprecated("Use IntToPtrOp(...) instead")
@staticmethod
def get(input: SSAValue | Operation, ptr_type: Attribute | None = None):
return IntToPtrOp(input, ptr_type)


@irdl_op_definition
Expand All @@ -594,9 +628,13 @@ class PtrToIntOp(IRDLOperation):

output: OpResult = result_def(IntegerType)

def __init__(self, arg: SSAValue | Operation, int_type: Attribute = i64):
super().__init__(operands=[arg], result_types=[int_type])

@deprecated("Use PtrToIntOp(...) instead")
@staticmethod
def get(arg: SSAValue | Operation, int_type: Attribute = i64):
return PtrToIntOp.build(operands=[arg], result_types=[int_type])
return PtrToIntOp(arg, int_type)


@irdl_op_definition
Expand All @@ -607,8 +645,7 @@ class LoadOp(IRDLOperation):

dereferenced_value: OpResult = result_def()

@staticmethod
def get(ptr: SSAValue | Operation, result_type: Attribute | None = None):
def __init__(self, ptr: SSAValue | Operation, result_type: Attribute | None = None):
if result_type is None:
ptr = SSAValue.get(ptr)
assert isinstance(ptr.type, LLVMPointerType)
Expand All @@ -619,7 +656,12 @@ def get(ptr: SSAValue | Operation, result_type: Attribute | None = None):
)
result_type = ptr.type.type

return LoadOp.build(operands=[ptr], result_types=[result_type])
super().__init__(operands=[ptr], result_types=[result_type])

@deprecated("Use LoadOp(...) instead")
@staticmethod
def get(ptr: SSAValue | Operation, result_type: Attribute | None = None):
return LoadOp(ptr, result_type)


@irdl_op_definition
Expand All @@ -634,8 +676,8 @@ class StoreOp(IRDLOperation):
volatile_: UnitAttr | None = opt_attr_def(UnitAttr)
nontemporal: UnitAttr | None = opt_attr_def(UnitAttr)

@staticmethod
def get(
def __init__(
self,
value: SSAValue | Operation,
ptr: SSAValue | Operation,
alignment: int | None = None,
Expand All @@ -654,26 +696,42 @@ def get(
if nontemporal:
attrs["nontemporal"] = UnitAttr()

return StoreOp.build(
super().__init__(
operands=[value, ptr],
attributes=attrs,
result_types=[],
)

@deprecated("Use Load(...) instead")
@staticmethod
def get(
value: SSAValue | Operation,
ptr: SSAValue | Operation,
alignment: int | None = None,
ordering: int = 0,
volatile: bool = False,
nontemporal: bool = False,
):
return StoreOp(value, ptr, alignment, ordering, volatile, nontemporal)


@irdl_op_definition
class NullOp(IRDLOperation):
name = "llvm.mlir.null"

nullptr: OpResult = result_def(LLVMPointerType)

@staticmethod
def get(ptr_type: LLVMPointerType | None = None):
def __init__(self, ptr_type: LLVMPointerType | None = None):
if ptr_type is None:
ptr_type = LLVMPointerType.opaque()
assert isinstance(ptr_type, LLVMPointerType)

return NullOp.build(result_types=[ptr_type])
super().__init__(result_types=[ptr_type])

@deprecated("Use NullOp(...) instead")
@staticmethod
def get(ptr_type: LLVMPointerType | None = None):
return NullOp(ptr_type)


@irdl_op_definition
Expand Down Expand Up @@ -768,8 +826,8 @@ class GlobalOp(IRDLOperation):

traits = frozenset([SymbolOpInterface()])

@staticmethod
def get(
def __init__(
self,
global_type: Attribute,
sym_name: str | StringAttr,
linkage: str | LinkageAttr,
Expand Down Expand Up @@ -818,7 +876,36 @@ def get(
section = StringAttr(section)
attrs["section"] = section

return GlobalOp.build(attributes=attrs, regions=[Region([])])
super().__init__(attributes=attrs, regions=[Region([])])

@deprecated("Use GlobalOp(...) instead")
@staticmethod
def get(
global_type: Attribute,
sym_name: str | StringAttr,
linkage: str | LinkageAttr,
addr_space: int = 0,
constant: bool | None = None,
dso_local: bool | None = None,
thread_local_: bool | None = None,
value: Attribute | None = None,
alignment: int | None = None,
unnamed_addr: int | None = None,
section: str | StringAttr | None = None,
):
return GlobalOp(
global_type,
sym_name,
linkage,
addr_space,
constant,
dso_local,
thread_local_,
value,
alignment,
unnamed_addr,
section,
)


@irdl_op_definition
Expand All @@ -828,17 +915,25 @@ class AddressOfOp(IRDLOperation):
global_name: SymbolRefAttr = attr_def(SymbolRefAttr)
result: OpResult = result_def(LLVMPointerType)

@staticmethod
def get(
global_name: str | StringAttr | SymbolRefAttr, result_type: LLVMPointerType
def __init__(
self,
global_name: str | StringAttr | SymbolRefAttr,
result_type: LLVMPointerType,
):
if isinstance(global_name, (StringAttr, str)):
global_name = SymbolRefAttr(global_name)

return AddressOfOp.build(
super().__init__(
attributes={"global_name": global_name}, result_types=[result_type]
)

@deprecated("Use AddressOfOp(...) instead")
@staticmethod
def get(
global_name: str | StringAttr | SymbolRefAttr, result_type: LLVMPointerType
):
return AddressOfOp(global_name, result_type)


LLVM_CALLING_CONVS: set[str] = {
"ccc",
Expand Down

0 comments on commit 640af51

Please sign in to comment.