From 2a6b375639ef1669a8fb37d57b9e40896c4ec669 Mon Sep 17 00:00:00 2001 From: mariosasko Date: Thu, 3 Dec 2020 00:30:05 +0100 Subject: [PATCH] Replace itertools.tee with list --- torch/optim/sparse_adam.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/torch/optim/sparse_adam.py b/torch/optim/sparse_adam.py index 7244d2adc428..909aa0c6cc62 100644 --- a/torch/optim/sparse_adam.py +++ b/torch/optim/sparse_adam.py @@ -1,4 +1,3 @@ -import itertools import math import torch from .optimizer import Optimizer @@ -33,12 +32,10 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8): if not 0.0 <= betas[1] < 1.0: raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - # if params are in the form of a generator, the next for-loop exhausts it, - # so the copy is passed to the loop - params, params_copy = itertools.tee(params) + params = list(params) sparse_params = [] - for index, param in enumerate(params_copy): + for index, param in enumerate(params): if isinstance(param, dict): for d_index, d_param in enumerate(param.get("params", [])): if d_param.is_sparse: