Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mypy/typeshed/stubs/librt/librt/base64.pyi
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
def b64encode(s: bytes) -> bytes: ...
def b64decode(s: bytes | str) -> bytes: ...
191 changes: 189 additions & 2 deletions mypyc/lib-rt/librt_base64.c
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <stdbool.h>
#include "librt_base64.h"
#include "libbase64.h"
#include "pythoncapi_compat.h"

#ifdef MYPYC_EXPERIMENTAL

static PyObject *
b64decode_handle_invalid_input(
PyObject *out_bytes, char *outbuf, size_t max_out, const char *src, size_t srclen);

#define BASE64_MAXBIN ((PY_SSIZE_T_MAX - 3) / 2)

#define STACK_BUFFER_SIZE 1024
Expand Down Expand Up @@ -63,11 +68,193 @@ b64encode(PyObject *self, PyObject *const *args, size_t nargs) {
return b64encode_internal(args[0]);
}

static inline int
is_valid_base64_char(char c, bool allow_padding) {
return ((c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') ||
(c >= '0' && c <= '9') || (c == '+') || (c == '/') || (allow_padding && c == '='));
}

static PyObject *
b64decode_internal(PyObject *arg) {
const char *src;
Py_ssize_t srclen_ssz;

// Get input pointer and length
if (PyBytes_Check(arg)) {
src = PyBytes_AS_STRING(arg);
srclen_ssz = PyBytes_GET_SIZE(arg);
} else if (PyUnicode_Check(arg)) {
if (!PyUnicode_IS_ASCII(arg)) {
PyErr_SetString(PyExc_ValueError,
"string argument should contain only ASCII characters");
return NULL;
}
src = (const char *)PyUnicode_1BYTE_DATA(arg);
srclen_ssz = PyUnicode_GET_LENGTH(arg);
} else {
PyErr_SetString(PyExc_TypeError,
"argument should be a bytes-like object or ASCII string");
return NULL;
}

// Fast-path: empty input
if (srclen_ssz == 0) {
return PyBytes_FromStringAndSize(NULL, 0);
}

// Quickly ignore invalid characters at the end. Other invalid characters
// are also accepted, but they need a slow path.
while (srclen_ssz > 0 && !is_valid_base64_char(src[srclen_ssz - 1], true)) {
srclen_ssz--;
}

// Compute an output capacity that's at least 3/4 of input, without overflow:
// ceil(3/4 * N) == N - floor(N/4)
size_t srclen = (size_t)srclen_ssz;
size_t max_out = srclen - (srclen / 4);
if (max_out == 0) {
max_out = 1; // defensive (srclen > 0 implies >= 1 anyway)
}
if (max_out > (size_t)PY_SSIZE_T_MAX) {
PyErr_SetString(PyExc_OverflowError, "input too large");
return NULL;
}

// Allocate output bytes (uninitialized) of the max capacity
PyObject *out_bytes = PyBytes_FromStringAndSize(NULL, (Py_ssize_t)max_out);
if (out_bytes == NULL) {
return NULL; // Propagate memory error
}

char *outbuf = PyBytes_AS_STRING(out_bytes);
size_t outlen = max_out;

int ret = base64_decode(src, srclen, outbuf, &outlen, 0);

if (ret != 1) {
if (ret == 0) {
// Slow path: handle non-base64 input
return b64decode_handle_invalid_input(out_bytes, outbuf, max_out, src, srclen);
}
Py_DECREF(out_bytes);
if (ret == -1) {
PyErr_SetString(PyExc_NotImplementedError, "base64 codec not available in this build");
} else {
PyErr_SetString(PyExc_RuntimeError, "base64_decode failed");
}
return NULL;
}

// Sanity-check contract (decoder must not overflow our buffer)
if (outlen > max_out) {
Py_DECREF(out_bytes);
PyErr_SetString(PyExc_RuntimeError, "decoder wrote past output buffer");
return NULL;
}

// Shrink in place to the actual decoded length
if (_PyBytes_Resize(&out_bytes, (Py_ssize_t)outlen) < 0) {
// _PyBytes_Resize sets an exception and may free the old object
return NULL;
}
return out_bytes;
}

// Process non-base64 input by ignoring non-base64 characters, for compatibility
// with stdlib b64decode.
static PyObject *
b64decode_handle_invalid_input(
PyObject *out_bytes, char *outbuf, size_t max_out, const char *src, size_t srclen)
{
// Copy input to a temporary buffer, with non-base64 characters and extra suffix
// characters removed
size_t newbuf_len = 0;
char *newbuf = PyMem_Malloc(srclen);
if (newbuf == NULL) {
Py_DECREF(out_bytes);
return PyErr_NoMemory();
}

// Copy base64 characters and some padding to the new buffer
for (size_t i = 0; i < srclen; i++) {
char c = src[i];
if (is_valid_base64_char(c, false)) {
newbuf[newbuf_len++] = c;
} else if (c == '=') {
// Copy a necessary amount of padding
int remainder = newbuf_len % 4;
if (remainder == 0) {
// No padding needed
break;
}
int numpad = 4 - remainder;
// Check that there is at least the required amount padding (CPython ignores
// extra padding)
while (numpad > 0) {
if (i == srclen || src[i] != '=') {
break;
}
newbuf[newbuf_len++] = '=';
i++;
numpad--;
// Skip non-base64 alphabet characters within padding
while (i < srclen && !is_valid_base64_char(src[i], true)) {
i++;
}
}
break;
}
}

// Stdlib always performs a non-strict padding check
if (newbuf_len % 4 != 0) {
Py_DECREF(out_bytes);
PyMem_Free(newbuf);
PyErr_SetString(PyExc_ValueError, "Incorrect padding");
return NULL;
}

size_t outlen = max_out;
int ret = base64_decode(newbuf, newbuf_len, outbuf, &outlen, 0);
PyMem_Free(newbuf);

if (ret != 1) {
Py_DECREF(out_bytes);
if (ret == 0) {
PyErr_SetString(PyExc_ValueError, "Only base64 data is allowed");
}
if (ret == -1) {
PyErr_SetString(PyExc_NotImplementedError, "base64 codec not available in this build");
} else {
PyErr_SetString(PyExc_RuntimeError, "base64_decode failed");
}
return NULL;
}

// Shrink in place to the actual decoded length
if (_PyBytes_Resize(&out_bytes, (Py_ssize_t)outlen) < 0) {
// _PyBytes_Resize sets an exception and may free the old object
return NULL;
}
return out_bytes;
}


static PyObject*
b64decode(PyObject *self, PyObject *const *args, size_t nargs) {
if (nargs != 1) {
PyErr_SetString(PyExc_TypeError, "b64decode() takes exactly one argument");
return 0;
}
return b64decode_internal(args[0]);
}

#endif

static PyMethodDef librt_base64_module_methods[] = {
#ifdef MYPYC_EXPERIMENTAL
{"b64encode", (PyCFunction)b64encode, METH_FASTCALL, PyDoc_STR("Encode bytes-like object using Base64.")},
{"b64encode", (PyCFunction)b64encode, METH_FASTCALL, PyDoc_STR("Encode bytes object using Base64.")},
{"b64decode", (PyCFunction)b64decode, METH_FASTCALL, PyDoc_STR("Decode a Base64 encoded bytes object or ASCII string.")},
#endif
{NULL, NULL, 0, NULL}
};
Expand Down Expand Up @@ -111,7 +298,7 @@ static PyModuleDef_Slot librt_base64_module_slots[] = {
static PyModuleDef librt_base64_module = {
.m_base = PyModuleDef_HEAD_INIT,
.m_name = "base64",
.m_doc = "base64 encoding and decoding optimized for mypyc",
.m_doc = "Fast base64 encoding and decoding optimized for mypyc",
.m_size = 0,
.m_methods = librt_base64_module_methods,
.m_slots = librt_base64_module_slots,
Expand Down
108 changes: 107 additions & 1 deletion mypyc/test-data/run-base64.test
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
[case testAllBase64Features_librt_experimental]
from typing import Any
import base64
import binascii

from librt.base64 import b64encode
from librt.base64 import b64encode, b64decode

from testutil import assertRaises

Expand Down Expand Up @@ -44,6 +45,111 @@ def test_encode_wrapper() -> None:
with assertRaises(TypeError):
enc(b"x", b"y")

def test_decode_basic() -> None:
assert b64decode(b"eA==") == b"x"

with assertRaises(TypeError):
b64decode(bytearray(b"eA=="))

for non_ascii in "\x80", "foo\u100bar", "foo\ua1234bar":
with assertRaises(ValueError):
b64decode(non_ascii)

def check_decode(b: bytes, encoded: bool = False) -> None:
if encoded:
enc = b
else:
enc = b64encode(b)
assert b64decode(enc) == getattr(base64, "b64decode")(enc)
if getattr(enc, "isascii")(): # Test stub has no "isascii"
enc_str = enc.decode("ascii")
assert b64decode(enc_str) == getattr(base64, "b64decode")(enc_str)

def test_decode_different_strings() -> None:
for i in range(256):
check_decode(bytes([i]))
check_decode(bytes([i]) + b"x")
check_decode(bytes([i]) + b"xy")
check_decode(bytes([i]) + b"xyz")
check_decode(bytes([i]) + b"xyza")
check_decode(b"x" + bytes([i]))
check_decode(b"xy" + bytes([i]))
check_decode(b"xyz" + bytes([i]))
check_decode(b"xyza" + bytes([i]))

b = b"a\x00\xb7" * 1000
for i in range(1000):
check_decode(b[:i])

for b in b"", b"ab", b"bac", b"1234", b"xyz88", b"abc" * 200:
check_decode(b)

def is_base64_char(x: int) -> bool:
c = chr(x)
return ('a' <= c <= 'z') or ('A' <= c <= 'Z') or ('0' <= c <= '9') or c in '+/='

def test_decode_with_non_base64_chars() -> None:
# For stdlib compatibility, non-base64 characters should be ignored.

# Invalid characters as a suffix use a fast path.
check_decode(b"eA== ", encoded=True)
check_decode(b"eA==\n", encoded=True)
check_decode(b"eA== \t\n", encoded=True)
check_decode(b"\n", encoded=True)

check_decode(b" e A = = ", encoded=True)

# Special case: Two different encodings of the same data
check_decode(b"eAa=", encoded=True)
check_decode(b"eAY=", encoded=True)

for x in range(256):
if not is_base64_char(x):
b = bytes([x])
check_decode(b, encoded=True)
check_decode(b"eA==" + b, encoded=True)
check_decode(b"e" + b + b"A==", encoded=True)
check_decode(b"eA=" + b + b"=", encoded=True)

def check_decode_error(b: bytes, ignore_stdlib: bool = False) -> None:
if not ignore_stdlib:
with assertRaises(binascii.Error):
getattr(base64, "b64decode")(b)

# The raised error is different, since librt shouldn't depend on binascii
with assertRaises(ValueError):
b64decode(b)

def test_decode_with_invalid_padding() -> None:
check_decode_error(b"eA")
check_decode_error(b"eA=")
check_decode_error(b"eHk")
check_decode_error(b"eA = ")

# Here stdlib behavior seems nonsensical, so we don't try to duplicate it
check_decode_error(b"eA=a=", ignore_stdlib=True)

def test_decode_with_extra_data_after_padding() -> None:
check_decode(b"=", encoded=True)
check_decode(b"==", encoded=True)
check_decode(b"===", encoded=True)
check_decode(b"====", encoded=True)
check_decode(b"eA===", encoded=True)
check_decode(b"eHk==", encoded=True)
check_decode(b"eA==x", encoded=True)
check_decode(b"eHk=x", encoded=True)
check_decode(b"eA==abc=======efg", encoded=True)

def test_decode_wrapper() -> None:
dec: Any = b64decode
assert dec(b"eA==") == b"x"

with assertRaises(TypeError):
dec()

with assertRaises(TypeError):
dec(b"x", b"y")

[case testBase64FeaturesNotAvailableInNonExperimentalBuild_librt_base64]
# This also ensures librt.base64 can be built without experimental features
import librt.base64
Expand Down