Skip to content

Commit

Permalink
Replace torch._utils._accumulate with numpy.cumsum.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben Zickel committed May 3, 2024
1 parent 60b9cc3 commit 20c25a5
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 17 deletions.
2 changes: 1 addition & 1 deletion examples/contrib/mue/FactorMuE.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def main(args):
indices = torch.randperm(sum(data_lengths), device=device).tolist()
dataset_train, dataset_test = [
torch.utils.data.Subset(dataset, indices[(offset - length) : offset])
for offset, length in zip(pyro.util._accumulate(data_lengths), data_lengths)
for offset, length in zip(np.cumsum(data_lengths), data_lengths)
]
else:
dataset_train = dataset
Expand Down
2 changes: 1 addition & 1 deletion examples/contrib/mue/ProfileHMM.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def main(args):
indices = torch.randperm(sum(data_lengths), device=device).tolist()
dataset_train, dataset_test = [
torch.utils.data.Subset(dataset, indices[(offset - length) : offset])
for offset, length in zip(pyro.util._accumulate(data_lengths), data_lengths)
for offset, length in zip(np.cumsum(data_lengths), data_lengths)
]
else:
dataset_train = dataset
Expand Down
15 changes: 0 additions & 15 deletions pyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,18 +641,3 @@ def __exit__(self, exc_type, exc_val, exc_tb):

def torch_float(x):
return x.float() if isinstance(x, torch.Tensor) else float(x)


def _accumulate(iterable, fn=lambda x, y: x + y):
"Return running totals"
# _accumulate([1,2,3,4,5]) --> 1 3 6 10 15
# _accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
it = iter(iterable)
try:
total = next(it)
except StopIteration:
return
yield total
for element in it:
total = fn(total, element)
yield total

0 comments on commit 20c25a5

Please sign in to comment.