From ba9d148b4170e4080e09128de99426c49c74810f Mon Sep 17 00:00:00 2001 From: Ryan Grout Date: Wed, 5 Dec 2018 21:22:31 -0600 Subject: [PATCH 1/2] Use a non-recursive implementation inspired by cytoolz. --- toolz/dicttoolz.py | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/toolz/dicttoolz.py b/toolz/dicttoolz.py index e07b35cd..84ba3a01 100644 --- a/toolz/dicttoolz.py +++ b/toolz/dicttoolz.py @@ -7,7 +7,6 @@ 'valfilter', 'keyfilter', 'itemfilter', 'assoc', 'dissoc', 'assoc_in', 'update_in', 'get_in') - def _get_factory(f, kwargs): factory = kwargs.pop('factory', dict) if kwargs: @@ -266,16 +265,28 @@ def update_in(d, keys, func, default=None, factory=dict): >>> update_in({1: 'foo'}, [2, 3, 4], inc, 0) {1: 'foo', 2: {3: {4: 1}}} """ - assert len(keys) > 0 - k, ks = keys[0], keys[1:] - if ks: - return assoc(d, k, update_in(d[k] if (k in d) else factory(), - ks, func, default, factory), - factory) - else: - innermost = func(d[k]) if (k in d) else func(default) - return assoc(d, k, innermost, factory) + ks = iter(keys) + k = next(ks) + + rv = inner = factory() + rv.update(d) + + for key in ks: + if k in d: + d = d[k] + dtemp = factory() + dtemp.update(d) + else: + d = dtemp = factory() + inner[k] = inner = dtemp + k = key + + if k in d: + inner[k] = func(d[k]) + else: + inner[k] = func(default) + return rv def get_in(keys, coll, default=None, no_default=False): """ Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys. From b2eb929e4c88c6876a8510e5229842556a001245 Mon Sep 17 00:00:00 2001 From: Ryan Grout Date: Wed, 5 Dec 2018 23:06:56 -0600 Subject: [PATCH 2/2] Fix linting error. --- toolz/dicttoolz.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/toolz/dicttoolz.py b/toolz/dicttoolz.py index 84ba3a01..dbdd9eb7 100644 --- a/toolz/dicttoolz.py +++ b/toolz/dicttoolz.py @@ -7,6 +7,7 @@ 'valfilter', 'keyfilter', 'itemfilter', 'assoc', 'dissoc', 'assoc_in', 'update_in', 'get_in') + def _get_factory(f, kwargs): factory = kwargs.pop('factory', dict) if kwargs: @@ -288,6 +289,7 @@ def update_in(d, keys, func, default=None, factory=dict): inner[k] = func(default) return rv + def get_in(keys, coll, default=None, no_default=False): """ Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys.