diff --git a/mypy/typeshed/stubs/librt/librt/base64.pyi b/mypy/typeshed/stubs/librt/librt/base64.pyi
index 36366f5754ce..1cea838505d6 100644
--- a/mypy/typeshed/stubs/librt/librt/base64.pyi
+++ b/mypy/typeshed/stubs/librt/librt/base64.pyi
@@ -1 +1,2 @@
 def b64encode(s: bytes) -> bytes: ...
+def b64decode(s: bytes | str) -> bytes: ...
diff --git a/mypyc/lib-rt/librt_base64.c b/mypyc/lib-rt/librt_base64.c
index 020a56e412f4..1720359ef9a6 100644
--- a/mypyc/lib-rt/librt_base64.c
+++ b/mypyc/lib-rt/librt_base64.c
@@ -1,11 +1,16 @@
 #define PY_SSIZE_T_CLEAN
 #include <Python.h>
+#include <stdbool.h>
 #include "librt_base64.h"
 #include "libbase64.h"
 #include "pythoncapi_compat.h"
 
 #ifdef MYPYC_EXPERIMENTAL
 
+static PyObject *
+b64decode_handle_invalid_input(
+    PyObject *out_bytes, char *outbuf, size_t max_out, const char *src, size_t srclen);
+
 #define BASE64_MAXBIN ((PY_SSIZE_T_MAX - 3) / 2)
 
 #define STACK_BUFFER_SIZE 1024
@@ -63,11 +68,204 @@ b64encode(PyObject *self, PyObject *const *args, size_t nargs) {
     return b64encode_internal(args[0]);
 }
 
+// Return nonzero if c belongs to the standard base64 alphabet (A-Z, a-z, 0-9,
+// '+', '/'). The padding character '=' is accepted only if allow_padding is true.
+static inline int
+is_valid_base64_char(char c, bool allow_padding) {
+    return ((c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') ||
+            (c >= '0' && c <= '9') || (c == '+') || (c == '/') || (allow_padding && c == '='));
+}
+
+// Decode a bytes object or an ASCII-only str object containing base64 data and
+// return the result as a new bytes object. Non-base64 characters are ignored, for
+// compatibility with stdlib b64decode. Return NULL with an exception set on error
+// (TypeError for other argument types, ValueError for non-ASCII str or bad data).
+static PyObject *
+b64decode_internal(PyObject *arg) {
+    const char *src;
+    Py_ssize_t srclen_ssz;
+
+    // Get input pointer and length
+    if (PyBytes_Check(arg)) {
+        src = PyBytes_AS_STRING(arg);
+        srclen_ssz = PyBytes_GET_SIZE(arg);
+    } else if (PyUnicode_Check(arg)) {
+        if (!PyUnicode_IS_ASCII(arg)) {
+            PyErr_SetString(PyExc_ValueError,
+                            "string argument should contain only ASCII characters");
+            return NULL;
+        }
+        src = (const char *)PyUnicode_1BYTE_DATA(arg);
+        srclen_ssz = PyUnicode_GET_LENGTH(arg);
+    } else {
+        PyErr_SetString(PyExc_TypeError,
+                        "argument should be a bytes-like object or ASCII string");
+        return NULL;
+    }
+
+    // Fast-path: empty input
+    if (srclen_ssz == 0) {
+        return PyBytes_FromStringAndSize(NULL, 0);
+    }
+
+    // Quickly ignore invalid characters at the end. Other invalid characters
+    // are also accepted, but they need a slow path.
+    while (srclen_ssz > 0 && !is_valid_base64_char(src[srclen_ssz - 1], true)) {
+        srclen_ssz--;
+    }
+
+    // Compute an output capacity that's at least 3/4 of input, without overflow:
+    // ceil(3/4 * N) == N - floor(N/4)
+    size_t srclen = (size_t)srclen_ssz;
+    size_t max_out = srclen - (srclen / 4);
+    if (max_out == 0) {
+        max_out = 1; // srclen can be 0 here if only invalid characters were stripped above
+    }
+    if (max_out > (size_t)PY_SSIZE_T_MAX) {
+        PyErr_SetString(PyExc_OverflowError, "input too large");
+        return NULL;
+    }
+
+    // Allocate output bytes (uninitialized) of the max capacity
+    PyObject *out_bytes = PyBytes_FromStringAndSize(NULL, (Py_ssize_t)max_out);
+    if (out_bytes == NULL) {
+        return NULL; // Propagate memory error
+    }
+
+    char *outbuf = PyBytes_AS_STRING(out_bytes);
+    size_t outlen = max_out;
+
+    int ret = base64_decode(src, srclen, outbuf, &outlen, 0);
+
+    if (ret != 1) {
+        if (ret == 0) {
+            // Slow path: handle non-base64 input
+            return b64decode_handle_invalid_input(out_bytes, outbuf, max_out, src, srclen);
+        }
+        Py_DECREF(out_bytes);
+        if (ret == -1) {
+            PyErr_SetString(PyExc_NotImplementedError, "base64 codec not available in this build");
+        } else {
+            PyErr_SetString(PyExc_RuntimeError, "base64_decode failed");
+        }
+        return NULL;
+    }
+
+    // Sanity-check contract (decoder must not overflow our buffer)
+    if (outlen > max_out) {
+        Py_DECREF(out_bytes);
+        PyErr_SetString(PyExc_RuntimeError, "decoder wrote past output buffer");
+        return NULL;
+    }
+
+    // Shrink in place to the actual decoded length
+    if (_PyBytes_Resize(&out_bytes, (Py_ssize_t)outlen) < 0) {
+        // _PyBytes_Resize sets an exception and may free the old object
+        return NULL;
+    }
+    return out_bytes;
+}
+
+// Process non-base64 input by ignoring non-base64 characters, for compatibility
+// with stdlib b64decode.
+//
+// This consumes the reference to out_bytes: it is resized and returned on success,
+// and released on failure (NULL is returned with an exception set). outbuf must
+// point to the data of out_bytes, which must have room for max_out bytes.
+static PyObject *
+b64decode_handle_invalid_input(
+    PyObject *out_bytes, char *outbuf, size_t max_out, const char *src, size_t srclen)
+{
+    // Copy input to a temporary buffer, with non-base64 characters and extra suffix
+    // characters removed
+    size_t newbuf_len = 0;
+    char *newbuf = PyMem_Malloc(srclen);
+    if (newbuf == NULL) {
+        Py_DECREF(out_bytes);
+        return PyErr_NoMemory();
+    }
+
+    // Copy base64 characters and some padding to the new buffer
+    for (size_t i = 0; i < srclen; i++) {
+        char c = src[i];
+        if (is_valid_base64_char(c, false)) {
+            newbuf[newbuf_len++] = c;
+        } else if (c == '=') {
+            // Copy a necessary amount of padding
+            int remainder = newbuf_len % 4;
+            if (remainder == 0) {
+                // No padding needed
+                break;
+            }
+            int numpad = 4 - remainder;
+            // Check that there is at least the required amount padding (CPython ignores
+            // extra padding)
+            while (numpad > 0) {
+                if (i == srclen || src[i] != '=') {
+                    break;
+                }
+                newbuf[newbuf_len++] = '=';
+                i++;
+                numpad--;
+                // Skip non-base64 alphabet characters within padding
+                while (i < srclen && !is_valid_base64_char(src[i], true)) {
+                    i++;
+                }
+            }
+            break;
+        }
+    }
+
+    // Stdlib always performs a non-strict padding check
+    if (newbuf_len % 4 != 0) {
+        Py_DECREF(out_bytes);
+        PyMem_Free(newbuf);
+        PyErr_SetString(PyExc_ValueError, "Incorrect padding");
+        return NULL;
+    }
+
+    size_t outlen = max_out;
+    int ret = base64_decode(newbuf, newbuf_len, outbuf, &outlen, 0);
+    PyMem_Free(newbuf);
+
+    if (ret != 1) {
+        Py_DECREF(out_bytes);
+        // Set exactly one exception: 0 means the data is still not valid base64
+        if (ret == 0) {
+            PyErr_SetString(PyExc_ValueError, "Only base64 data is allowed");
+        } else if (ret == -1) {
+            PyErr_SetString(PyExc_NotImplementedError, "base64 codec not available in this build");
+        } else {
+            PyErr_SetString(PyExc_RuntimeError, "base64_decode failed");
+        }
+        return NULL;
+    }
+
+    // Shrink in place to the actual decoded length
+    if (_PyBytes_Resize(&out_bytes, (Py_ssize_t)outlen) < 0) {
+        // _PyBytes_Resize sets an exception and may free the old object
+        return NULL;
+    }
+    return out_bytes;
+}
+
+
+// Module-level b64decode(s) entry point (METH_FASTCALL calling convention).
+static PyObject*
+b64decode(PyObject *self, PyObject *const *args, size_t nargs) {
+    if (nargs != 1) {
+        PyErr_SetString(PyExc_TypeError, "b64decode() takes exactly one argument");
+        return 0;
+    }
+    return b64decode_internal(args[0]);
+}
+
 #endif
 
 static PyMethodDef librt_base64_module_methods[] = {
 #ifdef MYPYC_EXPERIMENTAL
-    {"b64encode", (PyCFunction)b64encode, METH_FASTCALL, PyDoc_STR("Encode bytes-like object using Base64.")},
+    {"b64encode", (PyCFunction)b64encode, METH_FASTCALL, PyDoc_STR("Encode bytes object using Base64.")},
+    {"b64decode", (PyCFunction)b64decode, METH_FASTCALL, PyDoc_STR("Decode a Base64 encoded bytes object or ASCII string.")},
 #endif
     {NULL, NULL, 0, NULL}
 };
@@ -111,7 +309,7 @@ static PyModuleDef_Slot librt_base64_module_slots[] = {
 static PyModuleDef librt_base64_module = {
     .m_base = PyModuleDef_HEAD_INIT,
     .m_name = "base64",
-    .m_doc = "base64 encoding and decoding optimized for mypyc",
+    .m_doc = "Fast base64 encoding and decoding optimized for mypyc",
     .m_size = 0,
     .m_methods = librt_base64_module_methods,
     .m_slots = librt_base64_module_slots,
diff --git a/mypyc/test-data/run-base64.test b/mypyc/test-data/run-base64.test
index 0f9151c2b00b..8d7eb7c13482 100644
--- a/mypyc/test-data/run-base64.test
+++ b/mypyc/test-data/run-base64.test
@@ -1,8 +1,9 @@
 [case testAllBase64Features_librt_experimental]
 from typing import Any
 import base64
+import binascii
 
-from librt.base64 import b64encode
+from librt.base64 import b64encode, b64decode
 
 from testutil import assertRaises
 
@@ -44,6 +45,116 @@ def test_encode_wrapper() -> None:
     with assertRaises(TypeError):
         enc(b"x", b"y")
 
+def test_decode_basic() -> None:
+    assert b64decode(b"eA==") == b"x"
+
+    with assertRaises(TypeError):
+        b64decode(bytearray(b"eA=="))
+
+    for non_ascii in "\x80", "foo\u100bar", "foo\ua1234bar":
+        with assertRaises(ValueError):
+            b64decode(non_ascii)
+
+def check_decode(b: bytes, encoded: bool = False) -> None:
+    """Check that librt and stdlib b64decode agree (b is encoded first unless 'encoded')."""
+    if encoded:
+        enc = b
+    else:
+        enc = b64encode(b)
+    assert b64decode(enc) == getattr(base64, "b64decode")(enc)
+    if getattr(enc, "isascii")():  # Test stub has no "isascii"
+        enc_str = enc.decode("ascii")
+        assert b64decode(enc_str) == getattr(base64, "b64decode")(enc_str)
+
+def test_decode_different_strings() -> None:
+    for i in range(256):
+        check_decode(bytes([i]))
+        check_decode(bytes([i]) + b"x")
+        check_decode(bytes([i]) + b"xy")
+        check_decode(bytes([i]) + b"xyz")
+        check_decode(bytes([i]) + b"xyza")
+        check_decode(b"x" + bytes([i]))
+        check_decode(b"xy" + bytes([i]))
+        check_decode(b"xyz" + bytes([i]))
+        check_decode(b"xyza" + bytes([i]))
+
+    b = b"a\x00\xb7" * 1000
+    for i in range(1000):
+        check_decode(b[:i])
+
+    for b in b"", b"ab", b"bac", b"1234", b"xyz88", b"abc" * 200:
+        check_decode(b)
+
+def is_base64_char(x: int) -> bool:
+    """Return True if byte value x is in the base64 alphabet or is the padding character."""
+    c = chr(x)
+    return ('a' <= c <= 'z') or ('A' <= c <= 'Z') or ('0' <= c <= '9') or c in '+/='
+
+def test_decode_with_non_base64_chars() -> None:
+    # For stdlib compatibility, non-base64 characters should be ignored.
+
+    # Invalid characters as a suffix use a fast path.
+    check_decode(b"eA== ", encoded=True)
+    check_decode(b"eA==\n", encoded=True)
+    check_decode(b"eA== \t\n", encoded=True)
+    check_decode(b"\n", encoded=True)
+
+    check_decode(b" e A = = ", encoded=True)
+
+    # Special case: Two different encodings of the same data
+    check_decode(b"eAa=", encoded=True)
+    check_decode(b"eAY=", encoded=True)
+
+    for x in range(256):
+        if not is_base64_char(x):
+            b = bytes([x])
+            check_decode(b, encoded=True)
+            check_decode(b"eA==" + b, encoded=True)
+            check_decode(b"e" + b + b"A==", encoded=True)
+            check_decode(b"eA=" + b + b"=", encoded=True)
+
+def check_decode_error(b: bytes, ignore_stdlib: bool = False) -> None:
+    """Check that decoding b fails in librt (and in stdlib, unless ignore_stdlib is true)."""
+    if not ignore_stdlib:
+        with assertRaises(binascii.Error):
+            getattr(base64, "b64decode")(b)
+
+    # The raised error is different, since librt shouldn't depend on binascii
+    with assertRaises(ValueError):
+        b64decode(b)
+
+def test_decode_with_invalid_padding() -> None:
+    check_decode_error(b"eA")
+    check_decode_error(b"eA=")
+    check_decode_error(b"eHk")
+    check_decode_error(b"eA = ")
+    # A single data character is invalid even when followed by full padding
+    check_decode_error(b"e===")
+
+    # Here stdlib behavior seems nonsensical, so we don't try to duplicate it
+    check_decode_error(b"eA=a=", ignore_stdlib=True)
+
+def test_decode_with_extra_data_after_padding() -> None:
+    check_decode(b"=", encoded=True)
+    check_decode(b"==", encoded=True)
+    check_decode(b"===", encoded=True)
+    check_decode(b"====", encoded=True)
+    check_decode(b"eA===", encoded=True)
+    check_decode(b"eHk==", encoded=True)
+    check_decode(b"eA==x", encoded=True)
+    check_decode(b"eHk=x", encoded=True)
+    check_decode(b"eA==abc=======efg", encoded=True)
+
+def test_decode_wrapper() -> None:
+    dec: Any = b64decode
+    assert dec(b"eA==") == b"x"
+
+    with assertRaises(TypeError):
+        dec()
+
+    with assertRaises(TypeError):
+        dec(b"x", b"y")
+
 [case testBase64FeaturesNotAvailableInNonExperimentalBuild_librt_base64]
 # This also ensures librt.base64 can be built without experimental features
 import librt.base64