Skip to content

Commit

Permalink
Fix Encoder reference leak
Browse files Browse the repository at this point in the history
Properly DECREF the default method, if present.
  • Loading branch information
lelit committed Jan 8, 2018
1 parent 799bce0 commit 2212be7
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 7 deletions.
21 changes: 14 additions & 7 deletions rapidjson.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3066,6 +3066,7 @@ encoder_call(PyObject* self, PyObject* args, PyObject* kwargs)
PyObject* chunkSizeObj = NULL;
size_t chunkSize = 65536;
PyObject* defaultFn = NULL;
PyObject* result;

if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O$O",
(char**) kwlist,
Expand Down Expand Up @@ -3102,14 +3103,20 @@ encoder_call(PyObject* self, PyObject* args, PyObject* kwargs)
return NULL;
}
}
return do_stream_encode(value, stream, chunkSize, e->skipInvalidKeys, defaultFn,
e->sortKeys, e->ensureAscii, e->prettyPrint, e->indent,
e->numberMode, e->datetimeMode, e->uuidMode);
result = do_stream_encode(value, stream, chunkSize, e->skipInvalidKeys, defaultFn,
e->sortKeys, e->ensureAscii, e->prettyPrint, e->indent,
e->numberMode, e->datetimeMode, e->uuidMode);
}
else
return do_encode(value, e->skipInvalidKeys, defaultFn, e->sortKeys,
e->ensureAscii, e->prettyPrint, e->indent,
e->numberMode, e->datetimeMode, e->uuidMode);
else {
result = do_encode(value, e->skipInvalidKeys, defaultFn, e->sortKeys,
e->ensureAscii, e->prettyPrint, e->indent,
e->numberMode, e->datetimeMode, e->uuidMode);
}

if (defaultFn != NULL)
Py_DECREF(defaultFn);

return result;
}


Expand Down
19 changes: 19 additions & 0 deletions tests/test_refs_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,22 @@ def string(self, string):
del value
rc1 = sys.gettotalrefcount()
assert (rc1 - rc0) < THRESHOLD


@pytest.mark.skipif(not hasattr(sys, 'gettotalrefcount'), reason='Non-debug Python')
def test_encoder_call_leaks():
class MyEncoder(rj.Encoder):
def default(self, obj):
return 'Foo'

class Foo:
pass

encoder = MyEncoder()
foo = Foo()
rc0 = sys.gettotalrefcount()
for i in range(1000):
value = encoder(foo)
del value
rc1 = sys.gettotalrefcount()
assert (rc1 - rc0) < THRESHOLD

0 comments on commit 2212be7

Please sign in to comment.