From 640af51565e742b66ff1ead2b039f961dc01694e Mon Sep 17 00:00:00 2001 From: ShaolunWang Date: Wed, 26 Jul 2023 14:38:31 +0100 Subject: [PATCH] Deprecating .get in llvm dialect Making respective changes from .get to __init__ in mpi lowering and tests --- tests/dialects/test_llvm.py | 30 +++--- xdsl/dialects/llvm.py | 155 ++++++++++++++++++++++++------ xdsl/transforms/lower_mpi.py | 26 +++-- xdsl/transforms/printf_to_llvm.py | 2 +- 4 files changed, 153 insertions(+), 60 deletions(-) diff --git a/tests/dialects/test_llvm.py b/tests/dialects/test_llvm.py index 187a9517b2..8d5345ffd0 100644 --- a/tests/dialects/test_llvm.py +++ b/tests/dialects/test_llvm.py @@ -11,12 +11,12 @@ def test_llvm_pointer_ops(): module = builtin.ModuleOp( [ idx := arith.Constant.from_int_and_width(0, 64), - ptr := llvm.AllocaOp.get(idx, builtin.i32), - val := llvm.LoadOp.get(ptr), - nullptr := llvm.NullOp.get(), - alloc_ptr := llvm.AllocaOp.get(idx, elem_type=builtin.IndexType()), - llvm.LoadOp.get(alloc_ptr), - store := llvm.StoreOp.get( + ptr := llvm.AllocaOp(idx, builtin.i32), + val := llvm.LoadOp(ptr), + nullptr := llvm.NullOp(), + alloc_ptr := llvm.AllocaOp(idx, elem_type=builtin.IndexType()), + llvm.LoadOp(alloc_ptr), + store := llvm.StoreOp( val, ptr, alignment=32, volatile=True, nontemporal=True ), ] @@ -42,8 +42,8 @@ def test_llvm_pointer_ops(): def test_llvm_ptr_to_int_to_ptr(): idx = arith.Constant.from_int_and_width(0, 64) - ptr = llvm.IntToPtrOp.get(idx, ptr_type=builtin.i32) - int_val = llvm.PtrToIntOp.get(ptr) + ptr = llvm.IntToPtrOp(idx, ptr_type=builtin.i32) + int_val = llvm.PtrToIntOp(ptr) assert ptr.input == idx.result assert isinstance(ptr.output.type, llvm.LLVMPointerType) @@ -67,11 +67,11 @@ def test_llvm_pointer_type(): def test_llvm_getelementptr_op_invalid_construction(): size = arith.Constant.from_int_and_width(1, 32) - opaque_ptr = llvm.AllocaOp.get(size, builtin.i32, as_untyped_ptr=True) + opaque_ptr = llvm.AllocaOp(size, builtin.i32, as_untyped_ptr=True) # check that passing an opaque pointer to GEP without a pointee type fails with pytest.raises(ValueError): - llvm.GEPOp.get( + llvm.GEPOp( opaque_ptr, indices=[1], result_type=llvm.LLVMPointerType.typed(builtin.i32), @@ -79,7 +79,7 @@ def test_llvm_getelementptr_op_invalid_construction(): # check that non-pointer arguments fail with pytest.raises(ValueError): - llvm.GEPOp.get( + llvm.GEPOp( size, indices=[1], result_type=llvm.LLVMPointerType.opaque(), @@ -88,9 +88,9 @@ def test_llvm_getelementptr_op_invalid_construction(): def test_llvm_getelementptr_op(): size = arith.Constant.from_int_and_width(1, 32) - ptr = llvm.AllocaOp.get(size, builtin.i32) + ptr = llvm.AllocaOp(size, builtin.i32) ptr_type = llvm.LLVMPointerType.typed(ptr.res.type) - opaque_ptr = llvm.AllocaOp.get(size, builtin.i32, as_untyped_ptr=True) + opaque_ptr = llvm.AllocaOp(size, builtin.i32, as_untyped_ptr=True) # check that construction with static-only offsets and inbounds attr works: gep1 = llvm.GEPOp.from_mixed_indices( @@ -149,7 +149,7 @@ def test_linkage_attr_unknown_str(): def test_global_op(): - global_op = llvm.GlobalOp.get( + global_op = llvm.GlobalOp( builtin.i32, "testsymbol", "internal", @@ -179,7 +179,7 @@ def test_global_op(): def test_addressof_op(): ptr_type = llvm.LLVMPointerType.typed(builtin.i32) - address_of = llvm.AddressOfOp.get("test", ptr_type) + address_of = llvm.AddressOfOp("test", ptr_type) assert isinstance(address_of.global_name, builtin.SymbolRefAttr) assert address_of.global_name.root_reference.data == "test" diff --git a/xdsl/dialects/llvm.py b/xdsl/dialects/llvm.py index 9dc1009419..d3563fe5b8 100644 --- a/xdsl/dialects/llvm.py +++ b/xdsl/dialects/llvm.py @@ -52,6 +52,7 @@ from xdsl.parser import AttrParser, Parser from xdsl.printer import Printer from xdsl.traits import IsTerminator, SymbolOpInterface +from xdsl.utils.deprecation import deprecated from xdsl.utils.exceptions import VerifyException from xdsl.utils.hints import isa @@ -455,8 +456,8 @@ class GEPOp(IRDLOperation): rawConstantIndices: DenseArrayBase = attr_def(DenseArrayBase) inbounds: UnitAttr | None = opt_attr_def(UnitAttr) - @staticmethod - def get( + def __init__( + self, ptr: SSAValue | Operation, indices: Sequence[int], ssa_indices: Sequence[SSAValue | Operation] | None = None, @@ -473,7 +474,6 @@ def get( Take a look at `from_mixed_indices` for something without magic values. """ - # construct default mutable argument here: if ssa_indices is None: ssa_indices = [] @@ -499,10 +499,32 @@ def get( if inbounds: attrs["inbounds"] = UnitAttr() - return GEPOp.build( + super().__init__( operands=[ptr, ssa_indices], result_types=[result_type], attributes=attrs ) + @deprecated("Use GEPOp(...) instead") + @staticmethod + def get( + ptr: SSAValue | Operation, + indices: Sequence[int], + ssa_indices: Sequence[SSAValue | Operation] | None = None, + result_type: LLVMPointerType = LLVMPointerType.opaque(), + inbounds: bool = False, + pointee_type: Attribute | None = None, + ): + """ + A basic constructor for the GEPOp. + + Pass the GEP_USE_SSA_VAL magic value in place of each constant + index that you want to be read from an SSA value. + + Take a look at `from_mixed_indices` for something without + magic values. + """ + # construct default mutable argument here: + return GEPOp(ptr, indices, ssa_indices, result_type, inbounds, pointee_type) + @staticmethod def from_mixed_indices( ptr: SSAValue | Operation, @@ -528,7 +550,7 @@ def from_mixed_indices( else: const_indices.append(GEP_USE_SSA_VAL) ssa_indices.append(SSAValue.get(idx)) - return GEPOp.get( + return GEPOp( ptr, const_indices, ssa_indices, @@ -548,8 +570,8 @@ class AllocaOp(IRDLOperation): res: OpResult = result_def() - @staticmethod - def get( + def __init__( + self, size: SSAValue | Operation, elem_type: Attribute, alignment: int = 32, @@ -564,9 +586,17 @@ def get( else: ptr_type = LLVMPointerType.typed(elem_type) - return AllocaOp.build( - operands=[size], attributes=attrs, result_types=[ptr_type] - ) + super().__init__(operands=[size], attributes=attrs, result_types=[ptr_type]) + + @deprecated("Use Alloca(...) instead") + @staticmethod + def get( + size: SSAValue | Operation, + elem_type: Attribute, + alignment: int = 32, + as_untyped_ptr: bool = False, + ): + return AllocaOp(size, elem_type, alignment, as_untyped_ptr) @irdl_op_definition @@ -577,13 +607,17 @@ class IntToPtrOp(IRDLOperation): output: OpResult = result_def(LLVMPointerType) - @staticmethod - def get(input: SSAValue | Operation, ptr_type: Attribute | None = None): + def __init__(self, input: SSAValue | Operation, ptr_type: Attribute | None = None): if ptr_type is None: ptr_type = LLVMPointerType.opaque() else: ptr_type = LLVMPointerType.typed(ptr_type) - return IntToPtrOp.build(operands=[input], result_types=[ptr_type]) + super().__init__(operands=[input], result_types=[ptr_type]) + + @deprecated("Use IntToPtrOp(...) instead") + @staticmethod + def get(input: SSAValue | Operation, ptr_type: Attribute | None = None): + return IntToPtrOp(input, ptr_type) @irdl_op_definition @@ -594,9 +628,13 @@ class PtrToIntOp(IRDLOperation): output: OpResult = result_def(IntegerType) + def __init__(self, arg: SSAValue | Operation, int_type: Attribute = i64): + super().__init__(operands=[arg], result_types=[int_type]) + + @deprecated("Use PtrToIntOp(...) instead") @staticmethod def get(arg: SSAValue | Operation, int_type: Attribute = i64): - return PtrToIntOp.build(operands=[arg], result_types=[int_type]) + return PtrToIntOp(arg, int_type) @irdl_op_definition @@ -607,8 +645,7 @@ class LoadOp(IRDLOperation): dereferenced_value: OpResult = result_def() - @staticmethod - def get(ptr: SSAValue | Operation, result_type: Attribute | None = None): + def __init__(self, ptr: SSAValue | Operation, result_type: Attribute | None = None): if result_type is None: ptr = SSAValue.get(ptr) assert isinstance(ptr.type, LLVMPointerType) @@ -619,7 +656,12 @@ def get(ptr: SSAValue | Operation, result_type: Attribute | None = None): ) result_type = ptr.type.type - return LoadOp.build(operands=[ptr], result_types=[result_type]) + super().__init__(operands=[ptr], result_types=[result_type]) + + @deprecated("Use LoadOp(...) instead") + @staticmethod + def get(ptr: SSAValue | Operation, result_type: Attribute | None = None): + return LoadOp(ptr, result_type) @irdl_op_definition @@ -634,8 +676,8 @@ class StoreOp(IRDLOperation): volatile_: UnitAttr | None = opt_attr_def(UnitAttr) nontemporal: UnitAttr | None = opt_attr_def(UnitAttr) - @staticmethod - def get( + def __init__( + self, value: SSAValue | Operation, ptr: SSAValue | Operation, alignment: int | None = None, @@ -654,12 +696,24 @@ def get( if nontemporal: attrs["nontemporal"] = UnitAttr() - return StoreOp.build( + super().__init__( operands=[value, ptr], attributes=attrs, result_types=[], ) + @deprecated("Use Load(...) instead") + @staticmethod + def get( + value: SSAValue | Operation, + ptr: SSAValue | Operation, + alignment: int | None = None, + ordering: int = 0, + volatile: bool = False, + nontemporal: bool = False, + ): + return StoreOp(value, ptr, alignment, ordering, volatile, nontemporal) + @irdl_op_definition class NullOp(IRDLOperation): @@ -667,13 +721,17 @@ class NullOp(IRDLOperation): nullptr: OpResult = result_def(LLVMPointerType) - @staticmethod - def get(ptr_type: LLVMPointerType | None = None): + def __init__(self, ptr_type: LLVMPointerType | None = None): if ptr_type is None: ptr_type = LLVMPointerType.opaque() assert isinstance(ptr_type, LLVMPointerType) - return NullOp.build(result_types=[ptr_type]) + super().__init__(result_types=[ptr_type]) + + @deprecated("Use NullOp(...) instead") + @staticmethod + def get(ptr_type: LLVMPointerType | None = None): + return NullOp(ptr_type) @irdl_op_definition @@ -768,8 +826,8 @@ class GlobalOp(IRDLOperation): traits = frozenset([SymbolOpInterface()]) - @staticmethod - def get( + def __init__( + self, global_type: Attribute, sym_name: str | StringAttr, linkage: str | LinkageAttr, @@ -818,7 +876,36 @@ def get( section = StringAttr(section) attrs["section"] = section - return GlobalOp.build(attributes=attrs, regions=[Region([])]) + super().__init__(attributes=attrs, regions=[Region([])]) + + @deprecated("Use GlobalOp(...) instead") + @staticmethod + def get( + global_type: Attribute, + sym_name: str | StringAttr, + linkage: str | LinkageAttr, + addr_space: int = 0, + constant: bool | None = None, + dso_local: bool | None = None, + thread_local_: bool | None = None, + value: Attribute | None = None, + alignment: int | None = None, + unnamed_addr: int | None = None, + section: str | StringAttr | None = None, + ): + return GlobalOp( + global_type, + sym_name, + linkage, + addr_space, + constant, + dso_local, + thread_local_, + value, + alignment, + unnamed_addr, + section, + ) @irdl_op_definition @@ -828,17 +915,25 @@ class AddressOfOp(IRDLOperation): global_name: SymbolRefAttr = attr_def(SymbolRefAttr) result: OpResult = result_def(LLVMPointerType) - @staticmethod - def get( - global_name: str | StringAttr | SymbolRefAttr, result_type: LLVMPointerType + def __init__( + self, + global_name: str | StringAttr | SymbolRefAttr, + result_type: LLVMPointerType, ): if isinstance(global_name, (StringAttr, str)): global_name = SymbolRefAttr(global_name) - return AddressOfOp.build( + super().__init__( attributes={"global_name": global_name}, result_types=[result_type] ) + @deprecated("Use AddressOfOp(...) instead") + @staticmethod + def get( + global_name: str | StringAttr | SymbolRefAttr, result_type: LLVMPointerType + ): + return AddressOfOp(global_name, result_type) + LLVM_CALLING_CONVS: set[str] = { "ccc", diff --git a/xdsl/transforms/lower_mpi.py b/xdsl/transforms/lower_mpi.py index ff6a7dad34..eaee7a3ea3 100644 --- a/xdsl/transforms/lower_mpi.py +++ b/xdsl/transforms/lower_mpi.py @@ -168,7 +168,7 @@ def _emit_mpi_status_objs( return ( [ lit1 := arith.Constant.from_int_and_width(1, builtin.i64), - res := llvm.IntToPtrOp.get(lit1), + res := llvm.IntToPtrOp(lit1), ], [], res, @@ -179,7 +179,7 @@ def _emit_mpi_status_objs( lit1 := arith.Constant.from_int_and_width( number_to_output, builtin.i64 ), - res := llvm.AllocaOp.get( + res := llvm.AllocaOp( lit1, builtin.IntegerType(8 * self.info.MPI_Status_size), as_untyped_ptr=True, @@ -313,7 +313,7 @@ def _memref_get_llvm_ptr(self, ref: SSAValue) -> tuple[list[Operation], Operatio return [ index := memref.ExtractAlignedPointerAsIndexOp.get(ref), i64 := arith.IndexCastOp.get(index, builtin.i64), - ptr := llvm.IntToPtrOp.get(i64), + ptr := llvm.IntToPtrOp(i64), ], ptr @@ -327,7 +327,7 @@ def lower(self, op: mpi.Init) -> tuple[list[Operation], list[SSAValue | None]]: We currently don't model any argument passing to `MPI_Init()` and pass two nullptrs. """ return [ - nullptr := llvm.NullOp.get(), + nullptr := llvm.NullOp(), func.Call(self._mpi_name(op), [nullptr, nullptr], [i32]), ], [] @@ -656,9 +656,7 @@ def lower( """ datatype_size = self._get_mpi_dtype_size(op.dtype) return [ - request := llvm.AllocaOp.get( - op.count, builtin.IntegerType(8 * datatype_size) - ), + request := llvm.AllocaOp(op.count, builtin.IntegerType(8 * datatype_size)), ], [request.results[0]] @@ -681,13 +679,13 @@ def lower( datatype_size = self._get_mpi_dtype_size(op.result.type) return [ - ptr_int := llvm.PtrToIntOp.get(op.vect, i64), + ptr_int := llvm.PtrToIntOp(op.vect, i64), lit1 := arith.Constant.from_int_and_width(datatype_size, 64), idx_cast1 := arith.IndexCastOp.get(op.element, IndexType()), idx_cast2 := arith.IndexCastOp.get(idx_cast1, i64), mul := arith.Muli(lit1, idx_cast2), add := arith.Addi(mul, ptr_int), - out_ptr := llvm.IntToPtrOp.get(add, op.vect.type.type), + out_ptr := llvm.IntToPtrOp(add, op.vect.type.type), ], [out_ptr.results[0]] @@ -707,9 +705,9 @@ def lower(self, op: mpi.CommRank) -> tuple[list[Operation], list[SSAValue | None self.info.MPI_COMM_WORLD, i32 ), lit1 := arith.Constant.from_int_and_width(1, 64), - int_ptr := llvm.AllocaOp.get(lit1, i32), + int_ptr := llvm.AllocaOp(lit1, i32), func.Call(self._mpi_name(op), [comm_global, int_ptr], [i32]), - rank := llvm.LoadOp.get(int_ptr), + rank := llvm.LoadOp(int_ptr), ], [rank.dereferenced_value] @@ -729,9 +727,9 @@ def lower(self, op: mpi.CommSize) -> tuple[list[Operation], list[SSAValue | None self.info.MPI_COMM_WORLD, i32 ), lit1 := arith.Constant.from_int_and_width(1, 64), - int_ptr := llvm.AllocaOp.get(lit1, i32), + int_ptr := llvm.AllocaOp(lit1, i32), func.Call(self._mpi_name(op), [comm_global, int_ptr], [i32]), - rank := llvm.LoadOp.get(int_ptr), + rank := llvm.LoadOp(int_ptr), ], [rank.dereferenced_value] @@ -787,7 +785,7 @@ def lower( assert isa(op.request.type, llvm.LLVMPointerType) return [ val := arith.Constant.from_int_and_width(self.info.MPI_REQUEST_NULL, i32), - llvm.StoreOp.get(val, op.request), + llvm.StoreOp(val, op.request), ], [] diff --git a/xdsl/transforms/printf_to_llvm.py b/xdsl/transforms/printf_to_llvm.py index 40d9a0027b..c1be2ee086 100644 --- a/xdsl/transforms/printf_to_llvm.py +++ b/xdsl/transforms/printf_to_llvm.py @@ -97,7 +97,7 @@ def _construct_global(self, val: str): t_type = builtin.TensorType.from_type_and_list(i8, [len(data)]) - return llvm.GlobalOp.get( + return llvm.GlobalOp( llvm.LLVMArrayType.from_size_and_type(len(data), i8), _key_from_str(val), constant=True,