-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
38 changed files
with
656 additions
and
348 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -105,3 +105,6 @@ venv.bak/ | |
|
||
# dir | ||
logdir/ | ||
|
||
# Untitled.ipynb | ||
Untitled.ipynb |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,34 +1,32 @@ | ||
from . import _NegativeSampler | ||
import torch | ||
|
||
from typing import Dict | ||
|
||
class MultinomialSampler(_NegativeSampler): | ||
r"""MutlinomialSampler is to generate negative samplers by multinomial distribution, i.e. draw samples by given probabilities | ||
""" | ||
def __init__(self, | ||
weights : torch.Tensor, | ||
with_replacement : bool = True): | ||
r"""Initialize a Negative sampler which draw samples with multinomial distribution | ||
@staticmethod | ||
def _getlen(v: Dict[str, int]) -> int: | ||
r"""Get length of field. | ||
Args: | ||
weights (torch.Tensor): weights (probabilities) to draw samples, with shape = (total number of words in dictionary, ). | ||
with_replacement (bool, optional): boolean flag to control the replacement of sampling. Defaults to True. | ||
Returns: | ||
int: Length of field. | ||
""" | ||
self.with_replacement = with_replacement | ||
if isinstance(weights, torch.Tensor): | ||
self.weights = weights | ||
else: | ||
self.weights = torch.Tensor(weights) | ||
self.dict_size = len(self.weights) | ||
return len(v["weights"]) | ||
|
||
def generate(self, size: int) -> torch.Tensor: | ||
r"""Return drawn samples. | ||
def _generate(self, | ||
weights : torch.Tensor, | ||
with_replacement : bool, | ||
size : int) -> torch.Tensor: | ||
"""A function to generate negative samples with multinomial distribution. | ||
Args: | ||
size (int): Number of negative samples to be drawn | ||
weights (torch.Tensor): the input tensor containing probabilities | ||
with_replacement (bool): whether to draw with replacement or not | ||
size (int): number of samples to draw | ||
Returns: | ||
torch.Tensor, shape = (size, 1), dtype = torch.long: Drawn negative samples | ||
T, shape = (N * Nneg, 1), dtype = torch.long: Tensor of negative samples generated by multinomial distribution. | ||
""" | ||
samples = torch.multinomial(self.weights, size, replacement=self.with_replacement) | ||
return samples.long() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,67 +1,62 @@ | ||
from . import _NegativeSampler | ||
import torch | ||
|
||
from typing import Dict | ||
|
||
class UniformSamplerWithReplacement(_NegativeSampler): | ||
r"""UniformSamplerWithReplacement is to generate negative samplers by uniform distribution with replacement, i.e. draw samples uniformlly with replacement | ||
""" | ||
def __init__(self, | ||
low : int, | ||
high : int): | ||
r"""Initialize a Negative sampler which draw samples with uniform distribution with replacement | ||
@staticmethod | ||
def _getlen(v: Dict[str, int]) -> int: | ||
r"""Get length of field. | ||
Args: | ||
low (int): minimum value (i.e. lower bound) of sampling id. | ||
high (int): maximum value (i.e. upper bound) of sampling id. | ||
Returns: | ||
int: Length of field. | ||
""" | ||
self.low = low | ||
self.high = high | ||
self.dict_size = self.high - self.low | ||
|
||
def generate(self, size: int) -> torch.Tensor: | ||
r"""Return drawn samples. | ||
return v["high"] - v["low"] | ||
|
||
@staticmethod | ||
def _generate(low : int, | ||
high : int, | ||
size : int) -> torch.Tensor: | ||
r"""A function to generate negative samples with uniform distribution with replacement. | ||
Args: | ||
size (int): Number of negative samples to be drawn | ||
low (int): Lowest integer to be drawn from the distribution. | ||
high (int): One above the highest integer to be drawn from the distribution. | ||
size (int): An integer defining the shape of the output tensor. | ||
Returns: | ||
torch.Tensor, shape = (size, 1), dtype = torch.long: Drawn negative samples | ||
T, shape = (N * Nneg, 1), dtype = torch.long: Tensor of negative samples generated by uniform distribution. | ||
""" | ||
return torch.randint(low=self.low, high=self.high, size=(size, )).long() | ||
return torch.randint(low=low, high=high, size=(size, 1)).long() | ||
|
||
|
||
class UniformSamplerWithoutReplacement(_NegativeSampler): | ||
r"""UniformSamplerWithReplacement is to generate negative samplers by uniform distribution without replacement, i.e. draw samples uniformlly without replacement | ||
""" | ||
def __init__(self, | ||
low : int, | ||
high : int): | ||
r"""Initialize a Negative sampler which draw samples with uniform distribution without replacement | ||
@staticmethod | ||
def _getlen(v: Dict[str, int]) -> int: | ||
r"""Get length of field. | ||
Returns: | ||
int: Length of field. | ||
""" | ||
return v["high"] - v["low"] | ||
|
||
@staticmethod | ||
def _generate(low : int, | ||
high : int, | ||
size : int) -> torch.Tensor: | ||
"""A function to generate negative samples with uniform distribution without replacement. | ||
Args: | ||
low (int): minimum value (i.e. lower bound) of sampling id. | ||
high (int): maximum value (i.e. upper bound) of sampling id. | ||
""" | ||
self.low = low | ||
self.high = high | ||
self.dict_size = self.high - self.low | ||
|
||
def generate(self, size: int) -> torch.Tensor: | ||
r"""Generate negative samples by the sampler | ||
Args: | ||
size (int): Number of negative samples to be drawn | ||
Raises: | ||
ValueError: if input size is larger than the size of dictionary (i.e. high - low) | ||
size (int): An integer of defining the shape of the output tensor. | ||
Returns: | ||
torch.Tensor, shape = (size, 1), dtype = torch.long: Drawn negative samples | ||
T, shape = (N * Nneg, 1), dtype = torch.long: Tensor of negative samples generated by uniform distribution. | ||
""" | ||
|
||
if size >= (self.high - self.low): | ||
raise ValueError("input size cannot be larger than size of samples.") | ||
|
||
samples = torch.randperm(n=self.high) + self.low | ||
samples = torch.randperm(n=high) + low | ||
samples = samples[:size] | ||
return samples.long() | ||
return samples.view(-1, 1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.