Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bpo-38005: Fixed comparing and creating of InterpreterID and ChannelID. #15652

Merged
merged 9 commits into from Sep 13, 2019
44 changes: 28 additions & 16 deletions Lib/test/test__xxsubinterpreters.py
Expand Up @@ -526,20 +526,17 @@ def test_with_int(self):
self.assertEqual(int(id), 10)

def test_coerce_id(self):
id = interpreters.InterpreterID('10', force=True)
self.assertEqual(int(id), 10)

id = interpreters.InterpreterID(10.0, force=True)
self.assertEqual(int(id), 10)

class Int(str):
def __init__(self, value):
self._value = value
def __int__(self):
return self._value

id = interpreters.InterpreterID(Int(10), force=True)
self.assertEqual(int(id), 10)
for id in ('10', b'10', bytearray(b'10'), memoryview(b'10'), '1_0',
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a questionable behavior to me.

10.0, 10.1, Int(10)):
with self.subTest(id=id):
id = interpreters.InterpreterID(id, force=True)
self.assertEqual(int(id), 10)

def test_bad_id(self):
for id in [-1, 'spam']:
Expand All @@ -548,6 +545,8 @@ def test_bad_id(self):
interpreters.InterpreterID(id)
with self.assertRaises(OverflowError):
interpreters.InterpreterID(2**64)
with self.assertRaises(OverflowError):
interpreters.InterpreterID(float('inf'))
with self.assertRaises(TypeError):
interpreters.InterpreterID(object())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to have a test still for object().

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I replaced it with []. But if you prefer, I'll return object().


Expand All @@ -572,6 +571,14 @@ def test_equality(self):
self.assertTrue(id1 == id1)
self.assertTrue(id1 == id2)
self.assertTrue(id1 == int(id1))
self.assertTrue(int(id1) == id1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice :)

self.assertTrue(id1 == float(int(id1)))
self.assertTrue(float(int(id1)) == id1)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one test crashed with the past implementation.

self.assertFalse(id1 == float(int(id1)) + 0.1)
self.assertFalse(id1 == str(int(id1)))
self.assertFalse(id1 == 2**1000)
self.assertFalse(id1 == float('inf'))
self.assertFalse(id1 == 'spam')
self.assertFalse(id1 == id3)

self.assertFalse(id1 != id1)
Expand Down Expand Up @@ -1105,20 +1112,17 @@ def test_with_kwargs(self):
self.assertEqual(cid.end, 'both')

def test_coerce_id(self):
cid = interpreters._channel_id('10', force=True)
self.assertEqual(int(cid), 10)

cid = interpreters._channel_id(10.0, force=True)
self.assertEqual(int(cid), 10)

class Int(str):
def __init__(self, value):
self._value = value
def __int__(self):
return self._value

cid = interpreters._channel_id(Int(10), force=True)
self.assertEqual(int(cid), 10)
for id in ('10', b'10', bytearray(b'10'), memoryview(b'10'), '1_0',
10.0, 10.1, Int(10)):
with self.subTest(id=id):
cid = interpreters._channel_id(id, force=True)
self.assertEqual(int(cid), 10)

def test_bad_id(self):
for cid in [-1, 'spam']:
Expand Down Expand Up @@ -1164,6 +1168,14 @@ def test_equality(self):
self.assertTrue(cid1 == cid1)
self.assertTrue(cid1 == cid2)
self.assertTrue(cid1 == int(cid1))
self.assertTrue(int(cid1) == cid1)
self.assertTrue(cid1 == float(int(cid1)))
self.assertTrue(float(int(cid1)) == cid1)
self.assertFalse(cid1 == float(int(cid1)) + 0.1)
self.assertFalse(cid1 == str(int(cid1)))
self.assertFalse(cid1 == 2**1000)
self.assertFalse(cid1 == float('inf'))
self.assertFalse(cid1 == 'spam')
self.assertFalse(cid1 == cid3)

self.assertFalse(cid1 != cid1)
Expand Down
@@ -0,0 +1 @@
Fixed comparing and creating of InterpreterID and ChannelID.
38 changes: 18 additions & 20 deletions Modules/_xxsubinterpretersmodule.c
Expand Up @@ -1592,30 +1592,28 @@ channelid_richcompare(PyObject *self, PyObject *other, int op)
int equal;
if (PyObject_TypeCheck(other, &ChannelIDtype)) {
channelid *othercid = (channelid *)other;
if (cid->end != othercid->end) {
equal = 0;
}
else {
equal = (cid->id == othercid->id);
}
equal = (cid->end == othercid->end) && (cid->id == othercid->id);
}
else {
other = PyNumber_Long(other);
if (other == NULL) {
PyErr_Clear();
Py_RETURN_NOTIMPLEMENTED;
}
int64_t othercid = PyLong_AsLongLong(other);
Py_DECREF(other);
if (othercid == -1 && PyErr_Occurred() != NULL) {
else if (PyLong_Check(other)) {
/* Fast path */
int overflow;
long long othercid = PyLong_AsLongLongAndOverflow(other, &overflow);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use int64_t or the macro?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using int64_t can cause loss of bits and returning wrong result if long long is larger than int64_t.

But using long long is good here. The comparison with cid->id will return false if the value is too large for int64_t. Overflow also returns false instead of raising an exception (added a new test for this).

if (othercid == -1 && PyErr_Occurred()) {
return NULL;
}
if (othercid < 0) {
equal = 0;
}
else {
equal = (cid->id == othercid);
equal = !overflow && (othercid >= 0) && (cid->id == othercid);
}
else if (PyNumber_Check(other)) {
PyObject *pyid = PyLong_FromLongLong(cid->id);
if (pyid == NULL) {
return NULL;
}
PyObject *res = PyObject_RichCompare(pyid, other, op);
Py_DECREF(pyid);
return res;
}
else {
Py_RETURN_NOTIMPLEMENTED;
}

if ((op == Py_EQ && equal) || (op == Py_NE && !equal)) {
Expand Down
41 changes: 20 additions & 21 deletions Objects/interpreteridobject.c
Expand Up @@ -12,21 +12,17 @@ _Py_CoerceID(PyObject *orig)
if (pyid == NULL) {
if (PyErr_ExceptionMatches(PyExc_TypeError)) {
PyErr_Format(PyExc_TypeError,
"'id' must be a non-negative int, got %R", orig);
"'id' must be a non-negative int, got %s", orig->ob_type->tp_name);
}
else {
else if (PyErr_ExceptionMatches(PyExc_ValueError)) {
PyErr_Format(PyExc_ValueError,
"'id' must be a non-negative int, got %R", orig);
}
return -1;
}
int64_t id = PyLong_AsLongLong(pyid);
Py_DECREF(pyid);
if (id == -1 && PyErr_Occurred() != NULL) {
if (!PyErr_ExceptionMatches(PyExc_OverflowError)) {
PyErr_Format(PyExc_ValueError,
"'id' must be a non-negative int, got %R", orig);
}
if (id == -1 && PyErr_Occurred()) {
return -1;
}
if (id < 0) {
Expand Down Expand Up @@ -202,23 +198,26 @@ interpid_richcompare(PyObject *self, PyObject *other, int op)
interpid *otherid = (interpid *)other;
equal = (id->id == otherid->id);
}
else {
other = PyNumber_Long(other);
if (other == NULL) {
PyErr_Clear();
Py_RETURN_NOTIMPLEMENTED;
}
int64_t otherid = PyLong_AsLongLong(other);
Py_DECREF(other);
if (otherid == -1 && PyErr_Occurred() != NULL) {
else if (PyLong_CheckExact(other)) {
/* Fast path */
int overflow;
long long otherid = PyLong_AsLongLongAndOverflow(other, &overflow);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use int64_t or the macro?

if (otherid == -1 && PyErr_Occurred()) {
return NULL;
}
if (otherid < 0) {
equal = 0;
}
else {
equal = (id->id == otherid);
equal = !overflow && (otherid >= 0) && (id->id == otherid);
}
else if (PyNumber_Check(other)) {
PyObject *pyid = PyLong_FromLongLong(id->id);
if (pyid == NULL) {
return NULL;
}
PyObject *res = PyObject_RichCompare(pyid, other, op);
Py_DECREF(pyid);
return res;
}
else {
Py_RETURN_NOTIMPLEMENTED;
}

if ((op == Py_EQ && equal) || (op == Py_NE && !equal)) {
Expand Down