Skip to content

Assign operation #5

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 183 additions & 44 deletions mlir_graphblas/implementations.py

Large diffs are not rendered by default.

182 changes: 151 additions & 31 deletions mlir_graphblas/operations.py

Large diffs are not rendered by default.

32 changes: 26 additions & 6 deletions mlir_graphblas/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def __init__(self, func, *, input=None, output=None):
If input is defined, it must be one of (bool, int, float) to
indicate the restricted allowable input dtypes
If output is defined, it must be one of (bool, int, float) to
indicate the output will always be of that type
indicate the output will always be of that type or an
an integer (0, 1, ...) to indicate that the output dtype
will match the dtype of argument 0, 1, ...
"""
super().__init__(func.__name__)
self.func = func
Expand All @@ -70,7 +72,8 @@ def __init__(self, func, *, input=None, output=None):
self.input = input
# Validate output
if output is not None:
assert output in {bool, int, float}
if type(output) is not int:
assert output in {bool, int, float}
self.output = output

@classmethod
Expand All @@ -94,9 +97,19 @@ def validate_input(self, input_val):
elif self.input is float and not val_dtype.is_float():
raise GrbDomainMismatch("input must be float type")

def get_output_type(self, input_dtype):
def get_output_type(self, left_input_dtype, right_input_dtype=None):
if self.output is None:
return input_dtype
if right_input_dtype is None:
return left_input_dtype
if left_input_dtype != right_input_dtype:
raise TypeError(f"Unable to infer output type from {left_input_dtype} and {right_input_dtype}")
return left_input_dtype
elif self.output == 0:
return left_input_dtype
elif self.output == 1:
if right_input_dtype is None:
raise TypeError("No type provided for expected 2nd input argument")
return right_input_dtype
return self._type_convert[self.output]


Expand Down Expand Up @@ -132,6 +145,13 @@ def name_of_op(x, y, input_dtype):
def __call__(self, x, y):
dtype = self._dtype_of(x)
dtype2 = self._dtype_of(y)
if self.output == 0:
self.validate_input(x)
return self.func(x, y, dtype)
if self.output == 1:
self.validate_input(y)
return self.func(x, y, dtype2)
# If we reached this point, inputs must have the same dtype
if dtype is not dtype2:
raise TypeError(f"Types must match, {dtype} != {dtype2}")
self.validate_input(x)
Expand Down Expand Up @@ -417,12 +437,12 @@ def oneb(x, y, dtype):
BinaryOp.pair = BinaryOp.oneb


@BinaryOp._register
@BinaryOp._register(output=0) # dtype matches x
def first(x, y, dtype):
return x


@BinaryOp._register
@BinaryOp._register(output=1) # dtype matches y
def second(x, y, dtype):
return y

Expand Down
37 changes: 30 additions & 7 deletions mlir_graphblas/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,8 @@ def __repr__(self):
return f'Vector<{self.dtype.gb_name}, size={self.shape[0]}>'

@classmethod
def new(cls, dtype, size: int):
return cls(dtype, (size,))
def new(cls, dtype, size: int, *, intermediate_result=False):
return cls(dtype, (size,), intermediate_result=intermediate_result)

def resize(self, size: int):
raise NotImplementedError()
Expand All @@ -271,6 +271,7 @@ def build(self, indices, values, *, dup=None, sparsity=None):

indices: list or numpy array of int
values: list or numpy array with matching dtype as declared in `.new()`
can also be a scalar value to make the Vector iso-valued
dup: BinaryOp used to combined entries with the same index
NOTE: this is currently not support; passing dup will raise an error
sparsity: list of string or DimLevelType
Expand All @@ -287,7 +288,17 @@ def build(self, indices, values, *, dup=None, sparsity=None):
if not isinstance(indices, np.ndarray):
indices = np.array(indices, dtype=np.uint64)
if not isinstance(values, np.ndarray):
values = np.array(values, dtype=self.dtype.np_type)
if hasattr(values, '__len__'):
values = np.array(values, dtype=self.dtype.np_type)
else:
if type(values) is Scalar:
if values.dtype != self.dtype:
raise TypeError("Scalar value must have same dtype as Vector")
if values.nvals() == 0:
# Empty Scalar means nothing to build
return
values = values.extract_element()
values = np.ones(indices.shape, dtype=self.dtype.np_type) * values
if sparsity is None:
sparsity = [DimLevelType.compressed]
self._to_sparse_tensor(indices, values, sparsity=sparsity, ordering=[0])
Expand Down Expand Up @@ -329,8 +340,8 @@ def is_colwise(self):
return tuple(self._ordering) != self.permutation

@classmethod
def new(cls, dtype, nrows: int, ncols: int):
return cls(dtype, (nrows, ncols))
def new(cls, dtype, nrows: int, ncols: int, *, intermediate_result=False):
return cls(dtype, (nrows, ncols), intermediate_result=intermediate_result)

def diag(self, k: int):
raise NotImplementedError()
Expand All @@ -356,13 +367,15 @@ def nvals(self):

return nvals(self)

def build(self, row_indices, col_indices, values, *, dup=None, sparsity=None, colwise=False):
def build(self, row_indices, col_indices, values, *,
dup=None, sparsity=None, colwise=False):
"""
Build the underlying MLIRSparseTensor structure from COO.

row_indices: list or numpy array of int
col_indices: list or numpy array of int
values: list or numpy array with matching dtype as declared in `.new()`
can also be a scalar value to make the Vector iso-valued
dup: BinaryOp used to combined entries with the same (row, col) coordinate
NOTE: this is currently not support; passing dup will raise an error
sparsity: list of string or DimLevelType
Expand All @@ -383,7 +396,17 @@ def build(self, row_indices, col_indices, values, *, dup=None, sparsity=None, co
col_indices = np.array(col_indices, dtype=np.uint64)
indices = np.stack([row_indices, col_indices], axis=1)
if not isinstance(values, np.ndarray):
values = np.array(values, dtype=self.dtype.np_type)
if hasattr(values, '__len__'):
values = np.array(values, dtype=self.dtype.np_type)
else:
if type(values) is Scalar:
if values.dtype != self.dtype:
raise TypeError("Scalar value must have same dtype as Matrix")
if values.nvals() == 0:
# Empty Scalar means nothing to build
return
values = values.extract_element()
values = np.ones(indices.shape, dtype=self.dtype.np_type) * values
ordering = [1, 0] if colwise else [0, 1]
if sparsity is None:
sparsity = [DimLevelType.dense, DimLevelType.compressed]
Expand Down
Loading