Skip to content

Commit

Permalink
Replace itertools.tee with list
Browse files Browse the repository at this point in the history
  • Loading branch information
mariosasko committed Dec 2, 2020
1 parent 2a14d00 commit 2a6b375
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions torch/optim/sparse_adam.py
@@ -1,4 +1,3 @@
import itertools
import math
import torch
from .optimizer import Optimizer
Expand Down Expand Up @@ -33,12 +32,10 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

# if params are in the form of a generator, the next for-loop exhausts it,
# so the copy is passed to the loop
params, params_copy = itertools.tee(params)
params = list(params)

sparse_params = []
for index, param in enumerate(params_copy):
for index, param in enumerate(params):
if isinstance(param, dict):
for d_index, d_param in enumerate(param.get("params", [])):
if d_param.is_sparse:
Expand Down

0 comments on commit 2a6b375

Please sign in to comment.