Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding single context window #5

Merged
merged 3 commits into from
Apr 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
102 changes: 81 additions & 21 deletions libs/residual2vec/residual2vec/residual2vec_sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __init__(
q=1,
cuda=False,
buffer_size=100000,
context_window_type="double",
miniters=200,
):
"""Residual2Vec based on the stochastic gradient descent.
Expand All @@ -108,6 +109,8 @@ def __init__(
:type q: float, optional
:param buffer_size: Buffer size for sampled center and context pairs, defaults to 10000
:type buffer_size: int, optional
:param context_window_type: The type of context window. `context_window_type="double"` specifies a context window that extends both left and right of a focal node. context_window_type="left" and ="right" specifies that extends left and right, respectively.
:type context_window_type: str, optional
:param miniter: Minimum number of iterations, defaults to 200
:type miniter: int, optional
"""
Expand All @@ -121,6 +124,7 @@ def __init__(
self.batch_size = batch_size
self.buffer_size = buffer_size
self.miniters = miniters
self.context_window_type = context_window_type

def fit(self, adjmat):
"""Learn the graph structure to generate the node embeddings.
Expand Down Expand Up @@ -179,6 +183,7 @@ def transform(self, dim):
p=self.p,
q=self.q,
buffer_size=self.buffer_size,
context_window_type=self.context_window_type,
)
dataloader = DataLoader(
dataset,
Expand Down Expand Up @@ -221,6 +226,7 @@ def __init__(
walk_length=40,
p=1.0,
q=1.0,
context_window_type="double",
buffer_size=100000,
):
"""Dataset for training word2vec with negative sampling.
Expand All @@ -241,6 +247,8 @@ def __init__(
:type p: float, optional
:param q: node2vec parameter q (1/q) is the weights of the edges to nodes that are not directly connected to the previously visted node, defaults to 1
:type q: float, optional
:param context_window_type: The type of context window. `context_window_type="double"` specifies a context window that extends both left and right of a focal node. context_window_type="left" and ="right" specifies that extends left and right, respectively.
:type context_window_type: str, optional
:param buffer_size: Buffer size for sampled center and context pairs, defaults to 10000
:type buffer_size: int, optional
"""
Expand All @@ -250,6 +258,9 @@ def __init__(
self.noise_sampler = noise_sampler
self.walk_length = walk_length
self.padding_id = padding_id
self.context_window_type = {"double": 0, "left": -1, "right": 1}[
context_window_type
]
self.rw_sampler = RandomWalkSampler(adjmat, walk_length=walk_length, p=p, q=q)
self.node_order = np.random.choice(
adjmat.shape[0], adjmat.shape[0], replace=False
Expand Down Expand Up @@ -293,10 +304,11 @@ def _generate_samples(self):
self.node_order[self.scanned_node_id : next_scanned_node_id]
)
self.centers, self.contexts = _get_center_context(
walks,
walks.shape[0],
walks.shape[1],
self.window_length,
context_window_type=self.context_window_type,
walks=walks,
n_walks=walks.shape[0],
walk_len=walks.shape[1],
window_length=self.window_length,
padding_id=self.padding_id,
)
self.random_contexts = self.noise_sampler.sampling(
Expand All @@ -310,31 +322,79 @@ def _generate_samples(self):
self.sample_id = 0


def _get_center_context(
context_window_type, walks, n_walks, walk_len, window_length, padding_id
):
"""Get center and context pairs from a sequence
window_type = {-1,0,1} specifies the type of context window.
window_type = 0 specifies a context window of length window_length that extends both
left and right of a center word. window_type = -1 and 1 specifies a context window
that extends either left or right of a center word, respectively.
"""
if context_window_type == 0:
return _get_center_double_context_windows(
walks, n_walks, walk_len, window_length, padding_id
)
elif context_window_type == -1:
return _get_center_single_context_window(
walks, n_walks, walk_len, window_length, padding_id, is_left_window=True
)
elif context_window_type == 1:
return _get_center_single_context_window(
walks, n_walks, walk_len, window_length, padding_id, is_left_window=False
)
else:
raise ValueError("Unknown window type")


@njit(nogil=True)
def _get_center_context(walks, n_walks, walk_len, window_length, padding_id):
def _get_center_double_context_windows(
walks, n_walks, walk_len, window_length, padding_id
):
centers = np.zeros(n_walks * walk_len, dtype=np.int64)
contexts = np.zeros((n_walks * walk_len, 2 * window_length), dtype=np.int64)
contexts = padding_id * np.ones(
(n_walks * walk_len, 2 * window_length), dtype=np.int64
)
for t_walk in range(walk_len):
start, end = n_walks * t_walk, n_walks * (t_walk + 1)
centers[start:end] = walks[:, t_walk]
contexts[start:end, :] = _get_context(
walks, n_walks, walk_len, t_walk, window_length, padding_id
)

for i in range(window_length):
if t_walk - 1 - i < 0:
break
contexts[start:end, window_length - 1 - i] = walks[:, t_walk - 1 - i]

for i in range(window_length):
if t_walk + 1 + i >= walk_len:
break
contexts[start:end, window_length + i] = walks[:, t_walk + 1 + i]

order = np.arange(walk_len * n_walks)
random.shuffle(order)
return centers[order], contexts[order, :]


@njit(nogil=True)
def _get_context(walks, n_walks, walk_len, t_walk, window_length, padding_id):
retval = padding_id * np.ones((n_walks, 2 * window_length), dtype=np.int64)
for i in range(window_length):
if t_walk - 1 - i < 0:
break
retval[:, window_length - 1 - i] = walks[:, t_walk - 1 - i]

for i in range(window_length):
if t_walk + 1 + i >= walk_len:
break
retval[:, window_length + i] = walks[:, t_walk + 1 + i]
return retval
def _get_center_single_context_window(
walks, n_walks, walk_len, window_length, padding_id, is_left_window=True
):
centers = np.zeros(n_walks * walk_len, dtype=np.int64)
contexts = padding_id * np.ones((n_walks * walk_len, window_length), dtype=np.int64)
for t_walk in range(walk_len):
start, end = n_walks * t_walk, n_walks * (t_walk + 1)
centers[start:end] = walks[:, t_walk]

if is_left_window:
for i in range(window_length):
if t_walk - 1 - i < 0:
break
contexts[start:end, window_length - 1 - i] = walks[:, t_walk - 1 - i]
else:
for i in range(window_length):
if t_walk + 1 + i >= walk_len:
break
contexts[start:end, i] = walks[:, t_walk + 1 + i]

order = np.arange(walk_len * n_walks)
random.shuffle(order)
return centers[order], contexts[order, :]
2 changes: 1 addition & 1 deletion libs/residual2vec/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from setuptools import find_packages, setup

__version__ = "0.0.6"
__version__ = "0.0.7"


def load_requires_from_file(fname):
Expand Down