-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathnodeset.py
112 lines (93 loc) · 3.32 KB
/
nodeset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from collections.abc import MutableSet
from graphblas.semiring import any_pair, plus_pair
from . import _utils
class NodeSet(MutableSet):
    """Mutable set of node keys backed by a GraphBLAS vector.

    A key is a member when ``self.vector`` has a stored value at the key's
    index.  ``_key_to_id`` maps node keys to vector indices; the reverse
    mapping is exposed through the ``id_to_key`` property.
    """

    def __init__(self, v, *, key_to_id=None):
        """Wrap vector ``v`` as a set of node keys.

        Parameters
        ----------
        v : GraphBLAS vector
            Its stored indices are the members of the set.
        key_to_id : dict, optional
            Mapping from node key to vector index.  When omitted, every index
            ``0 .. v.size - 1`` is its own key.
        """
        self.vector = v
        if key_to_id is None:
            self._key_to_id = {i: i for i in range(v.size)}
        else:
            self._key_to_id = key_to_id
        # Reverse (index -> key) mapping; presumably filled in lazily by the
        # ``id_to_key`` property (``_utils.id_to_key``) -- verify there.
        self._id_to_key = None

    id_to_key = property(_utils.id_to_key)
    # Conversion helpers borrowed from ``_utils``.
    # get_property = _utils.get_property
    # get_properties = _utils.get_properties
    dict_to_vector = _utils.dict_to_vector
    list_to_vector = _utils.list_to_vector
    list_to_mask = _utils.list_to_mask
    list_to_ids = _utils.list_to_ids
    list_to_keys = _utils.list_to_keys
    matrix_to_dicts = _utils.matrix_to_dicts
    set_to_vector = _utils.set_to_vector
    # to_networkx = _utils.to_networkx
    vector_to_dict = _utils.vector_to_dict
    vector_to_list = _utils.vector_to_list
    vector_to_nodemap = _utils.vector_to_nodemap
    vector_to_nodeset = _utils.vector_to_nodeset
    vector_to_set = _utils.vector_to_set
    # _cacheit = _utils._cacheit

    # Requirements for MutableSet

    def __contains__(self, x):
        """Return whether node key ``x`` is a member.

        Keys unknown to ``_key_to_id`` are reported as absent instead of
        raising ``KeyError``, matching built-in ``set`` semantics (the inherited
        ``Set`` operations rely on ``in`` never raising for foreign values).
        """
        try:
            idx = self._key_to_id[x]
        except KeyError:
            return False
        return idx in self.vector

    def __iter__(self):
        """Iterate over member keys in no guaranteed order."""
        # Slow if we iterate over one; fast if we iterate over all:
        # every stored index is extracted up front in a single call.
        lookup = self.id_to_key.__getitem__
        indices = self.vector.to_coo(values=False, sort=False)[0].tolist()
        return map(lookup, indices)

    def __len__(self):
        """Return the number of members (stored values in the vector)."""
        return self.vector.nvals

    def add(self, value):
        """Add node key ``value``; raises ``KeyError`` if it is not a known key."""
        idx = self._key_to_id[value]
        self.vector[idx] = True

    def discard(self, value):
        """Remove node key ``value`` if present; otherwise do nothing.

        Unknown keys are ignored, as the ``MutableSet.discard`` contract requires.
        """
        try:
            idx = self._key_to_id[value]
        except KeyError:
            return
        # Deleting an index with no stored value leaves the vector unchanged
        # (GrB_Vector_removeElement semantics), so no membership check is needed.
        del self.vector[idx]

    # Override other MutableSet methods

    def __eq__(self, other):
        """Compare as sets, with a GraphBLAS fast path for two ``NodeSet``s.

        Two ``NodeSet`` objects are equal only when their vectors have the same
        size and the same stored indices, and their key mappings are equal.
        Any other operand falls back to ``Set.__eq__``.
        """
        if isinstance(other, NodeSet):
            a = self.vector
            b = other.vector
            return (
                a.size == b.size
                and (nvals := a.nvals) == b.nvals
                # plus_pair counts indices stored in both vectors; with equal
                # nvals this equals nvals iff the index patterns are identical.
                and plus_pair(a @ b).get(0) == nvals
                and self._key_to_id == other._key_to_id
            )
        return super().__eq__(other)

    # Fast versions of __and__, __or__, __sub__ and __xor__ are not written yet;
    # the ``Set`` mixin versions (built through ``_from_iterable``) are used.

    def clear(self):
        """Remove all members at once by clearing the vector."""
        self.vector.clear()

    def isdisjoint(self, other):
        """Return True if this set has no member in common with ``other``."""
        if (
            isinstance(other, NodeSet)
            and self.vector.size == other.vector.size
            # Index-level comparison is only meaningful when both sets map
            # keys to indices the same way.
            and (
                self._key_to_id is other._key_to_id
                or self._key_to_id == other._key_to_id
            )
        ):
            # Assumes the any_pair inner product is truthy only when some index
            # is stored in both vectors (empty result when disjoint).
            return not any_pair[bool](self.vector @ other.vector)
        return super().isdisjoint(other)

    # Backward-compatible alias for the previously misspelled method name.
    isdisjoin = isdisjoint

    def pop(self):
        """Remove and return an arbitrary member; raise ``KeyError`` if empty.

        Uses the ``ss`` (SuiteSparse extension) key iterator to find an index.
        """
        try:
            idx = next(self.vector.ss.iterkeys())
        except StopIteration:
            raise KeyError from None
        del self.vector[idx]
        return self.id_to_key[idx]

    def remove(self, value):
        """Remove node key ``value``; raise ``KeyError`` if it is not a member."""
        idx = self._key_to_id[value]
        if idx not in self.vector:
            raise KeyError(value)
        del self.vector[idx]

    def _from_iterable(self, it):
        """Build a new set sharing this set's key mapping from keys in ``it``.

        Called by the ``Set`` mixin operators (``&``, ``|``, ``-``, ``^``).
        """
        # The elements in the iterable must be contained within key_to_id
        rv = object.__new__(type(self))
        rv._key_to_id = self._key_to_id
        rv._id_to_key = self._id_to_key
        rv.vector = rv.set_to_vector(it, size=self.vector.size)
        return rv

    # Add more set methods (as needed)

    def union(self, *args):
        """Return the union with ``args`` as a built-in ``set`` (not a NodeSet)."""
        return set(self).union(*args)  # TODO: can we make this better?

    def copy(self):
        """Return a NodeSet with a duplicated vector and the same key mapping."""
        rv = type(self)(self.vector.dup(), key_to_id=self._key_to_id)
        rv._id_to_key = self._id_to_key
        return rv