Skip to content

Commit e02a312

Browse files
committed
Fix bool un/marshaling
Test: import marshal; assert marshal.loads(marshal.dumps(True)) == True Test: import marshal; assert marshal.loads(marshal.dumps(False))== False
1 parent 6f54eb1 commit e02a312

File tree

2 files changed

+19
-12
lines changed

2 files changed

+19
-12
lines changed

Lib/test/test_bool.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,8 +273,6 @@ def test_operator(self):
273273
self.assertIs(operator.is_not(True, True), False)
274274
self.assertIs(operator.is_not(True, False), True)
275275

276-
# TODO: RUSTPYTHON
277-
@unittest.expectedFailure
278276
def test_marshal(self):
279277
import marshal
280278
self.assertIs(marshal.loads(marshal.dumps(True)), True)

vm/src/stdlib/marshal.rs

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@ mod decl {
1414
bytecode,
1515
function::{ArgBytesLike, IntoPyObject},
1616
protocol::PyBuffer,
17+
pyobject::{IdProtocol, TypeProtocol},
1718
PyObjectRef, PyResult, TryFromObject, VirtualMachine,
1819
};
1920

2021
const STR_BYTE: u8 = b's';
2122
const INT_BYTE: u8 = b'i';
2223
const FLOAT_BYTE: u8 = b'f';
24+
const BOOL_BYTE: u8 = b'b';
2325
const LIST_BYTE: u8 = b'[';
2426
const TUPLE_BYTE: u8 = b'(';
2527
const DICT_BYTE: u8 = b',';
@@ -51,16 +53,22 @@ mod decl {
5153
fn _dumps(value: PyObjectRef, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
5254
let r = match_class!(match value {
5355
pyint @ PyInt => {
54-
let (sign, mut int_bytes) = pyint.as_bigint().to_bytes_le();
55-
let sign_byte = match sign {
56-
Sign::Minus => b'-',
57-
Sign::NoSign => b'0',
58-
Sign::Plus => b'+',
59-
};
60-
// Return as [TYPE, SIGN, uint bytes]
61-
int_bytes.insert(0, sign_byte);
62-
int_bytes.push(INT_BYTE);
63-
int_bytes
56+
if pyint.class().is(&vm.ctx.types.bool_type) {
57+
let (_, mut bool_bytes) = pyint.as_bigint().to_bytes_le();
58+
bool_bytes.push(BOOL_BYTE);
59+
bool_bytes
60+
} else {
61+
let (sign, mut int_bytes) = pyint.as_bigint().to_bytes_le();
62+
let sign_byte = match sign {
63+
Sign::Minus => b'-',
64+
Sign::NoSign => b'0',
65+
Sign::Plus => b'+',
66+
};
67+
// Return as [TYPE, SIGN, uint bytes]
68+
int_bytes.insert(0, sign_byte);
69+
int_bytes.push(INT_BYTE);
70+
int_bytes
71+
}
6472
}
6573
pyfloat @ PyFloat => {
6674
let mut float_bytes = pyfloat.to_f64().to_le_bytes().to_vec();
@@ -190,6 +198,7 @@ mod decl {
190198
)
191199
})?;
192200
match *type_indicator {
201+
BOOL_BYTE => Ok((buf[0] != 0).into_pyobject(vm)),
193202
INT_BYTE => {
194203
let (sign_byte, uint_bytes) = buf
195204
.split_first()

0 commit comments

Comments
 (0)