Skip to content

Commit

Permalink
cleaning up get methods
Browse files Browse the repository at this point in the history
  • Loading branch information
ShaolunWang committed Jul 26, 2023
1 parent 5c2dbdf commit 703446c
Showing 1 changed file with 21 additions and 121 deletions.
142 changes: 21 additions & 121 deletions xdsl/dialects/llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,34 +523,7 @@ def get(
magic values.
"""
# construct default mutable argument here:
if ssa_indices is None:
ssa_indices = []

# convert a potential Operation into an SSAValue
ptr_val = SSAValue.get(ptr)
ptr_type = ptr_val.type

if not isinstance(ptr_type, LLVMPointerType):
raise ValueError("Input must be a pointer")

attrs: dict[str, Attribute] = {
"rawConstantIndices": DenseArrayBase.create_dense_int_or_index(
i32, indices
),
}

if not ptr_type.is_typed():
if pointee_type is None:
raise ValueError("Opaque types must have a pointee type passed")
# opaque input ptr => opaque output ptr
attrs["elem_type"] = LLVMPointerType.opaque()

if inbounds:
attrs["inbounds"] = UnitAttr()

return GEPOp.build(
operands=[ptr, ssa_indices], result_types=[result_type], attributes=attrs
)
return GEPOp(ptr, indices, ssa_indices, result_type, inbounds, pointee_type)

@staticmethod
def from_mixed_indices(
Expand Down Expand Up @@ -623,18 +596,7 @@ def get(
alignment: int = 32,
as_untyped_ptr: bool = False,
):
attrs: dict[str, Attribute] = {
"alignment": IntegerAttr.from_int_and_width(alignment, 64)
}
if as_untyped_ptr:
ptr_type = LLVMPointerType.opaque()
attrs["elem_type"] = elem_type
else:
ptr_type = LLVMPointerType.typed(elem_type)

return AllocaOp.build(
operands=[size], attributes=attrs, result_types=[ptr_type]
)
return AllocaOp(size, elem_type, alignment, as_untyped_ptr)


@irdl_op_definition
Expand All @@ -655,11 +617,7 @@ def __init__(self, input: SSAValue | Operation, ptr_type: Attribute | None = Non
@deprecated("Use IntToPtrOp(...) instead")
@staticmethod
def get(input: SSAValue | Operation, ptr_type: Attribute | None = None):
if ptr_type is None:
ptr_type = LLVMPointerType.opaque()
else:
ptr_type = LLVMPointerType.typed(ptr_type)
return IntToPtrOp.build(operands=[input], result_types=[ptr_type])
return IntToPtrOp(input, ptr_type)


@irdl_op_definition
Expand All @@ -676,7 +634,7 @@ def __init__(self, arg: SSAValue | Operation, int_type: Attribute = i64):
@deprecated("Use PtrToIntOp(...) instead")
@staticmethod
def get(arg: SSAValue | Operation, int_type: Attribute = i64):
return PtrToIntOp.build(operands=[arg], result_types=[int_type])
return PtrToIntOp(arg, int_type)


@irdl_op_definition
Expand All @@ -703,17 +661,7 @@ def __init__(self, ptr: SSAValue | Operation, result_type: Attribute | None = No
@deprecated("Use LoadOp(...) instead")
@staticmethod
def get(ptr: SSAValue | Operation, result_type: Attribute | None = None):
if result_type is None:
ptr = SSAValue.get(ptr)
assert isinstance(ptr.type, LLVMPointerType)

if isinstance(ptr.type.type, NoneAttr):
raise ValueError(
"llvm.load requires either a result type or a typed pointer!"
)
result_type = ptr.type.type

return LoadOp.build(operands=[ptr], result_types=[result_type])
return LoadOp(ptr, result_type)


@irdl_op_definition
Expand Down Expand Up @@ -764,22 +712,7 @@ def get(
volatile: bool = False,
nontemporal: bool = False,
):
attrs: dict[str, Attribute] = {
"ordering": IntegerAttr(ordering, i64),
}

if alignment is not None:
attrs["alignment"] = IntegerAttr[IntegerType](alignment, i64)
if volatile:
attrs["volatile_"] = UnitAttr()
if nontemporal:
attrs["nontemporal"] = UnitAttr()

return StoreOp.build(
operands=[value, ptr],
attributes=attrs,
result_types=[],
)
return StoreOp(value, ptr, alignment, ordering, volatile, nontemporal)


@irdl_op_definition
Expand All @@ -798,11 +731,7 @@ def __init__(self, ptr_type: LLVMPointerType | None = None):
@deprecated("Use NullOp(...) instead")
@staticmethod
def get(ptr_type: LLVMPointerType | None = None):
if ptr_type is None:
ptr_type = LLVMPointerType.opaque()
assert isinstance(ptr_type, LLVMPointerType)

return NullOp.build(result_types=[ptr_type])
return NullOp(ptr_type)


@irdl_op_definition
Expand Down Expand Up @@ -964,43 +893,19 @@ def get(
unnamed_addr: int | None = None,
section: str | StringAttr | None = None,
):
if isinstance(sym_name, str):
sym_name = StringAttr(sym_name)

if isinstance(linkage, str):
linkage = LinkageAttr(linkage)

attrs: dict[str, Attribute] = {
"global_type": global_type,
"sym_name": sym_name,
"linkage": linkage,
"addr_space": IntegerAttr(addr_space, 32),
}

if constant is not None and constant:
attrs["constant"] = UnitAttr()

if dso_local is not None and dso_local:
attrs["dso_local"] = UnitAttr()

if thread_local_ is not None and thread_local_:
attrs["thread_local_"] = UnitAttr()

if value is not None:
attrs["value"] = value

if alignment is not None:
attrs["alignment"] = IntegerAttr(alignment, 64)

if unnamed_addr is not None:
attrs["unnamed_addr"] = IntegerAttr(unnamed_addr, 64)

if section is not None:
if isinstance(section, str):
section = StringAttr(section)
attrs["section"] = section

return GlobalOp.build(attributes=attrs, regions=[Region([])])
return GlobalOp(
global_type,
sym_name,
linkage,
addr_space,
constant,
dso_local,
thread_local_,
value,
alignment,
unnamed_addr,
section,
)


@irdl_op_definition
Expand All @@ -1027,12 +932,7 @@ def __init__(
def get(
global_name: str | StringAttr | SymbolRefAttr, result_type: LLVMPointerType
):
if isinstance(global_name, (StringAttr, str)):
global_name = SymbolRefAttr(global_name)

return AddressOfOp.build(
attributes={"global_name": global_name}, result_types=[result_type]
)
return AddressOfOp(global_name, result_type)


LLVM_CALLING_CONVS: set[str] = {
Expand Down

0 comments on commit 703446c

Please sign in to comment.