From 233a13390cdd5abc831f6383124da3319c6fc279 Mon Sep 17 00:00:00 2001 From: carsonburr Date: Tue, 5 Mar 2024 22:43:54 -0500 Subject: [PATCH 01/15] rust attempt v2. needs build setup --- .gitignore | 1 + MANIFEST.in | 2 + pyproject.toml | 2 +- setup.py | 1 + src/markupsafe/__init__.py | 59 +++- src/markupsafe/_native.py | 58 +--- src/markupsafe/_rust_speedups.pyi | 1 + src/rust/Cargo.lock | 288 +++++++++++++++++++ src/rust/Cargo.toml | 12 + src/rust/src/lib.rs | 454 ++++++++++++++++++++++++++++++ 10 files changed, 815 insertions(+), 63 deletions(-) create mode 100644 src/markupsafe/_rust_speedups.pyi create mode 100644 src/rust/Cargo.lock create mode 100644 src/rust/Cargo.toml create mode 100644 src/rust/src/lib.rs diff --git a/.gitignore b/.gitignore index dbaee290..d9df6a15 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ __pycache__/ /.coverage* /htmlcov/ /docs/_build/ +/src/rust/target/ diff --git a/MANIFEST.in b/MANIFEST.in index 7dfa3f60..35d1a31f 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -6,4 +6,6 @@ prune docs/_build graft tests include src/markupsafe/py.typed include src/markupsafe/*.pyi +graft src/rust +prune src/rust/target global-exclude *.pyc diff --git a/pyproject.toml b/pyproject.toml index 5624a279..e3727f42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ Changes = "https://markupsafe.palletsprojects.com/changes/" Chat = "https://discord.gg/pallets" [build-system] -requires = ["setuptools"] +requires = ["setuptools", "setuptools-rust"] build-backend = "setuptools.build_meta" [tool.pytest.ini_options] diff --git a/setup.py b/setup.py index d19a4faa..a63fca99 100644 --- a/setup.py +++ b/setup.py @@ -3,6 +3,7 @@ import sys from setuptools import Extension +from setuptools_rust import RustExtension from setuptools import setup from setuptools.command.build_ext import build_ext from setuptools.errors import CCompilerError diff --git a/src/markupsafe/__init__.py b/src/markupsafe/__init__.py index cbc4f555..2afbfbc6 100644 --- a/src/markupsafe/__init__.py +++ b/src/markupsafe/__init__.py @@ -319,13 +319,60 @@ def __float__(self, /) -> float: # circular import try: - from ._speedups import escape as escape - from ._speedups import escape_silent as escape_silent - from ._speedups import soft_str as soft_str + from ._rust_speedups import escape_inner as escape_inner except ImportError: - from ._native import escape as escape - from ._native import escape_silent as escape_silent # noqa: F401 - from ._native import soft_str as soft_str # noqa: F401 + from ._native import escape_inner as escape_inner + +def escape(s: t.Any, /) -> Markup: + """Replace the characters ``&``, ``<``, ``>``, ``'``, and ``"`` in + the string with HTML-safe sequences. Use this if you need to display + text that might contain such characters in HTML. + + If the object has an ``__html__`` method, it is called and the + return value is assumed to already be safe for HTML. + + :param s: An object to be converted to a string and escaped. + :return: A :class:`Markup` string with the escaped text. + """ + if hasattr(s, "__html__"): + return Markup(s.__html__()) + + return Markup(escape_inner(str(s))) + +def escape_silent(s: t.Any | None, /) -> Markup: + """Like :func:`escape` but treats ``None`` as the empty string. + Useful with optional values, as otherwise you get the string + ``'None'`` when the value is ``None``. + + >>> escape(None) + Markup('None') + >>> escape_silent(None) + Markup('') + """ + if s is None: + return Markup() + + return escape(s) + + +def soft_str(s: t.Any, /) -> str: + """Convert an object to a string if it isn't already. This preserves + a :class:`Markup` string rather than converting it back to a basic + string, so it will still be marked as safe and won't be escaped + again. + + >>> value = escape("") + >>> value + Markup('<User 1>') + >>> escape(str(value)) + Markup('&lt;User 1&gt;') + >>> escape(soft_str(value)) + Markup('<User 1>') + """ + if not isinstance(s, str): + return str(s) + + return s def __getattr__(name: str) -> t.Any: diff --git a/src/markupsafe/_native.py b/src/markupsafe/_native.py index e5ac0c13..063b13c8 100644 --- a/src/markupsafe/_native.py +++ b/src/markupsafe/_native.py @@ -2,64 +2,10 @@ import typing as t -from . import Markup - - -def escape(s: t.Any, /) -> Markup: - """Replace the characters ``&``, ``<``, ``>``, ``'``, and ``"`` in - the string with HTML-safe sequences. Use this if you need to display - text that might contain such characters in HTML. - - If the object has an ``__html__`` method, it is called and the - return value is assumed to already be safe for HTML. - - :param s: An object to be converted to a string and escaped. - :return: A :class:`Markup` string with the escaped text. - """ - if hasattr(s, "__html__"): - return Markup(s.__html__()) - - return Markup( - str(s) +def escape_inner(s: str, /) -> str: + return s .replace("&", "&") .replace(">", ">") .replace("<", "<") .replace("'", "'") .replace('"', """) - ) - - -def escape_silent(s: t.Any | None, /) -> Markup: - """Like :func:`escape` but treats ``None`` as the empty string. - Useful with optional values, as otherwise you get the string - ``'None'`` when the value is ``None``. - - >>> escape(None) - Markup('None') - >>> escape_silent(None) - Markup('') - """ - if s is None: - return Markup() - - return escape(s) - - -def soft_str(s: t.Any, /) -> str: - """Convert an object to a string if it isn't already. This preserves - a :class:`Markup` string rather than converting it back to a basic - string, so it will still be marked as safe and won't be escaped - again. - - >>> value = escape("") - >>> value - Markup('<User 1>') - >>> escape(str(value)) - Markup('&lt;User 1&gt;') - >>> escape(soft_str(value)) - Markup('<User 1>') - """ - if not isinstance(s, str): - return str(s) - - return s diff --git a/src/markupsafe/_rust_speedups.pyi b/src/markupsafe/_rust_speedups.pyi new file mode 100644 index 00000000..764b8e61 --- /dev/null +++ b/src/markupsafe/_rust_speedups.pyi @@ -0,0 +1 @@ +def escape_inner(s: str) -> str: ... diff --git a/src/rust/Cargo.lock b/src/rust/Cargo.lock new file mode 100644 index 00000000..f8aa6734 --- /dev/null +++ b/src/rust/Cargo.lock @@ -0,0 +1,288 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + +[[package]] +name = "indoc" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8" + +[[package]] +name = "libc" +version = "0.2.153" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" + +[[package]] +name = "lock_api" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" +dependencies = [ + "autocfg", + "scopeguard", +] + +[[package]] +name = "markupsafe-rust" +version = "0.1.0" +dependencies = [ + "pyo3", +] + +[[package]] +name = "memoffset" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "parking_lot" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets", +] + +[[package]] +name = "portable-atomic" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" + +[[package]] +name = "proc-macro2" +version = "1.0.78" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = "0.20.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53bdbb96d49157e65d45cc287af5f32ffadd5f4761438b527b055fb0d4bb8233" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "parking_lot", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.20.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "deaa5745de3f5231ce10517a1f5dd97d53e5a2fd77aa6b5842292085831d48d7" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.20.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62b42531d03e08d4ef1f6e85a2ed422eb678b8cd62b762e53891c05faf0d4afa" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.20.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7305c720fa01b8055ec95e484a6eca7a83c841267f0dd5280f0c8b8551d2c158" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.20.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c7e9b68bb9c3149c5b0cade5d07f953d6d125eb4337723c4ccdb665f1f96185" +dependencies = [ + "heck", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn", +] + +[[package]] +name = "quote" +version = "1.0.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "redox_syscall" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" +dependencies = [ + "bitflags", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "smallvec" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" + +[[package]] +name = "syn" +version = "2.0.52" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b699d15b36d1f02c3e7c69f8ffef53de37aefae075d8488d4ba1a7788d574a07" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "target-lexicon" +version = "0.12.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f" + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "unindent" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" diff --git a/src/rust/Cargo.toml b/src/rust/Cargo.toml new file mode 100644 index 00000000..005928a4 --- /dev/null +++ b/src/rust/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "markupsafe-rust" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +pyo3 = "0.20.3" + +[lib] +name = "_rust_speedups" +crate-type = ["cdylib"] diff --git a/src/rust/src/lib.rs b/src/rust/src/lib.rs new file mode 100644 index 00000000..5c6893d0 --- /dev/null +++ b/src/rust/src/lib.rs @@ -0,0 +1,454 @@ +use std::ops::{BitAnd, BitOr}; +use std::str::from_utf8_unchecked; + +use pyo3::prelude::*; +use pyo3::{ + types::{PyString, PyStringData}, + PyResult, Python, +}; + +/// A Rust implemented find and replace for the characters `<`, `>`, `&`, `"`, and `'` +/// into the sanitized strings `<`, `>`, `&`, `#34;`, and `#39;` respectively + +// Since speed was is a concern, I try to strike a balance between fast and readable, borrowing +// some lessons from simd-json. Namely I use vectorized and branchless-ish processing, but instead +// of using explicit vectorization, I let the compiler auto-vectorize the code for us, keeping the +// code readable, portable, and safe. +// +// To guarantee this optimization we need to help the compiler to recognize the patterns. We need to +// operate on 128-bit chunks whenever possible, and to do so generically because python could give +// us characters encoded as u8, u16, or u32. +// > By default, the i686 and x86_64 Rust targets enable sse and sse2 (128-bit vector extensions) +// https://github.com/rust-lang/portable-simd/blob/master/beginners-guide.md +// +// To enable this, I used some intermediate Rust features, so I'll include comments labelled +// RUST_INTRO where they're relevant, but they include: +// +// - const generics - similar to generics, but for compile-time constants. This means that +// different calls to functions can work with different numbers in their types. +// - traits - like an interface in java. Defines the behavior we might want out of a +// generic type that can be implemented by many different concrete types. +// - macro_rules - used to create syntax sugar for some repetative code. +// - iterators - a lazy data stream used mostly to make for-in loops look nice. They have a +// similar interface to java's stream api, but compile to optimized loops. +// - lifetimes - less of a language feature than a constraint. It's a hint to the compiler +// that the reference it annotates (&_) lives (is not destroyed/freed) for at +// least as long as anything else with the same lifetime. If function arguments +// have the same lifetime as the return type, those arguments must live for at +// least the duration of the function. + +// A trait that describes anything we might want to do with bits so we can generically work with +// u8, u16, or u32 +// RUST_INTRO: The traits after the colon are trait-bounds. Any types implementing Bits must already +// implement these other traits. These are all handled by the standard library. +trait Bits: Copy + Eq + BitAnd + BitOr + Into + From { + fn ones() -> Self; + fn zeroes() -> Self; +} + +// RUST_INTRO: We can implement our trait on types that already exist in the standard library +impl Bits for u8 { + fn ones() -> Self { + Self::MAX + } + + fn zeroes() -> Self { + 0 + } +} + +impl Bits for u16 { + fn ones() -> Self { + Self::MAX + } + + fn zeroes() -> Self { + 0 + } +} + +impl Bits for u32 { + fn ones() -> Self { + Self::MAX + } + + fn zeroes() -> Self { + 0 + } +} + +// auto-vectorized OR +// RUST_INTRO: the `(ax: &[T; N], bx: &[T; N]) -> [T; N] { + let mut result = [T::zeroes(); N]; + for ((&a, &b), r) in ax.iter().zip(bx.iter()).zip(result.iter_mut()) { + *r = a | b; + } + result +} + +const MASK_BITS: [u64; 8] = [1, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 6, 1 << 7]; + +// I attempted to autovectorize to pmovmskb, but since I can't get it to +// recognize it, I'll just take advantage of v_eq setting all bits to 1. +// +// converts chunks of ones created by v_eq into a bitmask with 1s where the +// matching characters were. e.g. +// input = "xx(ax: &[T; N]) -> u64 { + let mut result = 0u64; + for (i, &a) in ax.iter().enumerate() { + result |= (a.into() & MASK_BITS[i % 8]) << ((i / 8) * 8); + } + result +} + +// auto-vectorized equal. designed to compile into the instruction pcmpeqb +#[inline(always)] +fn v_eq(ax: &[T], bx: &[T; N]) -> [T; N] { + let mut result = [T::zeroes(); N]; + for ((&a, &b), r) in ax.iter().zip(bx.iter()).zip(result.iter_mut()) { + *r = if a == b { T::ones() } else { T::zeroes() }; + } + result +} + +fn mask(input: &[T], splats: [[T; N]; M]) -> u64 { + let mut result = 0u64; + // split into 128-bit chunks to vectorize inside the loop + for (i, lane) in input.chunks_exact(N).enumerate() { + result |= v_bitmask( + &splats + .iter() + .map(|splat| v_eq(lane, splat)) + .reduce(|a, b| v_or(&a, &b)) + .unwrap(), + ) << (i * N); + } + result +} + +// a splat is an array containing the same element repeated. +// e.g. [u8; 16] splat of "<" is "<<<<<<<<<<<<<<<<" +// a single call to mask might need many of those, so we construct them with a macro +// RUST_INTRO: `$items:expr` matches any input that looks like an expression. `$(...),+` means that +// the pattern inside is repeated one or more times, separated by commas +macro_rules! make_splats { + ($($items:expr),+) => { + [$([$items.into(); N],)+] + }; +} + +macro_rules! is_equal_any { + ($lhs:expr, $($rhs:literal)|+) => { + $(($lhs == $rhs.into()))|+ + }; +} + +// Tying everything together, we calculate the delta between the input size and the output. +// the algorithm: take chunks of 64 items at a time so mask() can create a u64 representing +// which items, if any, have characters that need replacing. We can then count how many of each +// character class is in the input and keep any indices if needed. Finally if the input doesn't +// neatly fit into 64 item chunks, we need a slow version to do the same to the remainder. +fn delta(input: &[T], replacement_indices: &mut Vec) -> usize { + // calls to mask() create a u64 mask representing 64 items + let chunks = input.chunks_exact(64); + let remainder = chunks.remainder(); + + let mut delta = 0; + for (i, chunk) in chunks.enumerate() { + let delta_3_mask = mask::(chunk, make_splats!(b'<', b'>')); + let delta_4_mask = mask::(chunk, make_splats!(b'"', b'\'', b'&')); + // count_ones() is a single instruction on x86 + let delta_3 = delta_3_mask.count_ones(); + let delta_4 = delta_4_mask.count_ones(); + delta += (delta_3 * 3) + (delta_4 * 4); + + let mut all_mask = delta_3_mask | delta_4_mask; + let mut count = delta_3 + delta_4; + let idx = (i * 64) as u32; + while count > 0 { + replacement_indices.push(idx + all_mask.trailing_zeros()); + all_mask &= all_mask.wrapping_sub(1); + count -= 1; + } + } + + let idx = ((input.len() / 64) * 64) as u32; + for (i, &item) in remainder.iter().enumerate() { + if is_equal_any!(item, b'<' | b'>') { + delta += 3; + replacement_indices.push(idx + i as u32); + } else if is_equal_any!(item, b'"' | b'\'' | b'&') { + delta += 4; + replacement_indices.push(idx + i as u32); + } + } + delta as usize +} + +// very similar to delta(), but short-circuits because if there is anything to replace, we need to +// convert to utf-8 anyway to calculate delta and indices +fn no_change(input: &[T]) -> bool { + // calls to mask() create a u64 mask representing 64 items + let chunks = input.chunks_exact(64); + let remainder = chunks.remainder(); + + for chunk in chunks { + let any_mask = mask::(chunk, make_splats!(b'<', b'>', b'"', b'\'', b'&')); + if any_mask != 0 { + return false; + } + } + + for &item in remainder { + if is_equal_any!(item, b'<' | b'>' | b'"' | b'\'' | b'&') { + return false; + } + } + true +} + +// builds the sanitized output string. Copies the input bytes from sections that haven't changed +// and replaces the characters that need sanitizing. +fn build_replaced( + RebuildArgs { + delta, + replacement_indices, + input_str, + }: RebuildArgs<'_>, +) -> String { + // we could create the string without the size, but with_capacity means we + // never need to re-allocate the backing memory + let mut builder = String::with_capacity(input_str.len() + delta); + let mut prev_idx = 0usize; + for idx in replacement_indices { + let idx = idx as usize; + if prev_idx < idx { + builder.push_str(&input_str[prev_idx..idx]); + } + builder.push_str(match &input_str[idx..idx + 1] { + "<" => "<", + ">" => ">", + "\"" => """, + "\'" => "'", + "&" => "&", + _ => unreachable!(""), + }); + prev_idx = idx + 1; + } + if prev_idx != input_str.len() - 1 { + builder.push_str(&input_str[prev_idx..input_str.len()]); + } + builder +} + +struct RebuildArgs<'a> { + delta: usize, + replacement_indices: Vec, + input_str: &'a str, +} + +impl<'a> RebuildArgs<'a> { + fn new(delta: usize, replacement_indices: Vec, input_str: &'a str) -> Self { + RebuildArgs { + delta, + replacement_indices, + input_str, + } + } +} + +fn check_utf8(input: &[u8]) -> Option { + let mut replacement_indices = Vec::with_capacity(8); + let delta = delta::<16, u8>(input, &mut replacement_indices); + if delta == 0 { + None + } else { + // SAFETY: The rest of the code assumes that python gives us valid utf-8, so to avoid + // validation or copying, we will here too + // https://docs.python.org/3.12//c-api/unicode.html + let input_str = unsafe { from_utf8_unchecked(input) }; + Some(RebuildArgs::new(delta, replacement_indices, input_str)) + } +} + +fn escape_utf8<'a>(py: Python<'a>, orig: &'a PyString, input: &[u8]) -> &'a PyString { + check_utf8(input) + .map(|rb| PyString::new(py, build_replaced(rb).as_str())) + .unwrap_or(orig) +} + +fn escape_other_format<'a, const N: usize, T: Bits>( + py: Python<'a>, + orig: &'a PyString, + input: &[T], +) -> PyResult<&'a PyString> { + // there's no safe way to construct a utf-16 or unicode string to pass back to python without + // using the unsafe ffi bindings, so best case scenario we short circuit and pass the original + // string back, otherwise just slow-path convert to utf-8 and process it that way since we'd + // have to recompute the delta and indices anyway + if no_change::(input) { + Ok(orig) + } else { + orig.to_str() + .map(|utf8| escape_utf8(py, orig, utf8.as_bytes())) + } +} + +#[pyfunction] +pub fn escape_inner<'a>(py: Python<'a>, s: &'a PyString) -> PyResult<&'a PyString> { + // SAFETY: from the py03 docs: + // This function implementation relies on manually decoding a C bitfield. + // In practice, this works well on common little-endian architectures such + // as x86_64, where the bitfield has a common representation (even if it is + // not part of the C spec). The PyO3 CI tests this API on x86_64 platforms. + // + // The C implementation already does this. And Rust can't compete on performance unless it can + // access the raw bytes. Converting it to a rust string (validating utf-8 and probably copying + // the full string) is probably already slower than C. + let data_res = unsafe { s.data() }; + match data_res { + Ok(data) => match data { + PyStringData::Ucs1(raw) => Ok(escape_utf8(py, s, raw)), + PyStringData::Ucs2(raw) => escape_other_format::<8, u16>(py, s, raw), + PyStringData::Ucs4(raw) => escape_other_format::<4, u32>(py, s, raw), + }, + Err(e) => Err(e), + } +} + +#[cfg(test)] +mod tests { + use crate::{build_replaced, check_utf8, no_change}; + + #[test] + fn empty() { + let res = check_utf8("".as_bytes()); + assert!(res.is_none()) + } + + #[test] + fn middle() { + let res = build_replaced(check_utf8("abcd&><'\"efgh".as_bytes()).unwrap()); + assert_eq!("abcd&><'"efgh", res); + } + + #[test] + fn begin() { + let res = build_replaced(check_utf8("&><'\"efgh".as_bytes()).unwrap()); + assert_eq!("&><'"efgh", res); + } + + #[test] + fn end() { + let res = build_replaced(check_utf8("abcd&><'\"".as_bytes()).unwrap()); + assert_eq!("abcd&><'"", res); + } + + #[test] + fn middle_large() { + let res = build_replaced(check_utf8("abcd&><'\"efgh".repeat(1024).as_bytes()).unwrap()); + assert_eq!("abcd&><'"efgh".repeat(1024).as_str(), res); + } + + #[test] + fn begin_large() { + let res = build_replaced(check_utf8("&><'\"efgh".repeat(1024).as_bytes()).unwrap()); + assert_eq!("&><'"efgh".repeat(1024).as_str(), res); + } + + #[test] + fn end_large() { + let res = build_replaced(check_utf8("abcd&><'\"".repeat(1024).as_bytes()).unwrap()); + assert_eq!("abcd&><'"".repeat(1024).as_str(), res); + } + + #[test] + fn middle_16() { + let input = "こんにちは&><'\"こんばんは" + .encode_utf16() + .collect::>(); + assert!(!no_change::<8, u16>(input.as_slice())); + let utf8 = String::from_utf16(input.as_slice()).unwrap(); + let res = build_replaced(check_utf8(utf8.as_bytes()).unwrap()); + assert_eq!("こんにちは&><'"こんばんは", res); + } + + #[test] + fn begin_16() { + let input = "&><'\"こんばんは".encode_utf16().collect::>(); + assert!(!no_change::<8, u16>(input.as_slice())); + let utf8 = String::from_utf16(input.as_slice()).unwrap(); + let res = build_replaced(check_utf8(utf8.as_bytes()).unwrap()); + assert_eq!("&><'"こんばんは", res); + } + + #[test] + fn end_16() { + let input = "こんにちは&><'\"".encode_utf16().collect::>(); + assert!(!no_change::<8, u16>(input.as_slice())); + let utf8 = String::from_utf16(input.as_slice()).unwrap(); + let res = build_replaced(check_utf8(utf8.as_bytes()).unwrap()); + assert_eq!("こんにちは&><'"", res); + } + + #[test] + fn middle_16_large() { + let input = "こんにちは&><'\"こんばんは" + .repeat(1024) + .encode_utf16() + .collect::>(); + assert!(!no_change::<8, u16>(input.as_slice())); + let utf8 = String::from_utf16(input.as_slice()).unwrap(); + let res = build_replaced(check_utf8(utf8.as_bytes()).unwrap()); + assert_eq!( + "こんにちは&><'"こんばんは" + .repeat(1024) + .as_str(), + res + ); + } + + #[test] + fn begin_16_large() { + let input = "&><'\"こんばんは" + .repeat(1024) + .encode_utf16() + .collect::>(); + assert!(!no_change::<8, u16>(input.as_slice())); + let utf8 = String::from_utf16(input.as_slice()).unwrap(); + let res = build_replaced(check_utf8(utf8.as_bytes()).unwrap()); + assert_eq!( + "&><'"こんばんは".repeat(1024).as_str(), + res + ); + } + + #[test] + fn end_16_large() { + let input = "こんにちは&><'\"" + .repeat(1024) + .encode_utf16() + .collect::>(); + assert!(!no_change::<8, u16>(input.as_slice())); + let utf8 = String::from_utf16(input.as_slice()).unwrap(); + let res = build_replaced(check_utf8(utf8.as_bytes()).unwrap()); + assert_eq!( + "こんにちは&><'"".repeat(1024).as_str(), + res + ); + } +} From 420abebf578179e7bbce4fadd58f1406ddb2c90d Mon Sep 17 00:00:00 2001 From: carsonburr Date: Tue, 5 Mar 2024 22:58:16 -0500 Subject: [PATCH 02/15] docs tweaks --- src/rust/src/lib.rs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/rust/src/lib.rs b/src/rust/src/lib.rs index 5c6893d0..d3639983 100644 --- a/src/rust/src/lib.rs +++ b/src/rust/src/lib.rs @@ -31,11 +31,6 @@ use pyo3::{ // - macro_rules - used to create syntax sugar for some repetative code. // - iterators - a lazy data stream used mostly to make for-in loops look nice. They have a // similar interface to java's stream api, but compile to optimized loops. -// - lifetimes - less of a language feature than a constraint. It's a hint to the compiler -// that the reference it annotates (&_) lives (is not destroyed/freed) for at -// least as long as anything else with the same lifetime. If function arguments -// have the same lifetime as the return type, those arguments must live for at -// least the duration of the function. // A trait that describes anything we might want to do with bits so we can generically work with // u8, u16, or u32 From ad8caf11fd879645d29992d991acfc7b5a527c67 Mon Sep 17 00:00:00 2001 From: carsonburr Date: Wed, 6 Mar 2024 17:15:17 -0500 Subject: [PATCH 03/15] comment tweaks --- src/rust/src/lib.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/rust/src/lib.rs b/src/rust/src/lib.rs index d3639983..0421fe54 100644 --- a/src/rust/src/lib.rs +++ b/src/rust/src/lib.rs @@ -25,7 +25,8 @@ use pyo3::{ // RUST_INTRO where they're relevant, but they include: // // - const generics - similar to generics, but for compile-time constants. This means that -// different calls to functions can work with different numbers in their types. +// different calls to the same function can work with different numbers in their +// types. // - traits - like an interface in java. Defines the behavior we might want out of a // generic type that can be implemented by many different concrete types. // - macro_rules - used to create syntax sugar for some repetative code. @@ -311,9 +312,10 @@ pub fn escape_inner<'a>(py: Python<'a>, s: &'a PyString) -> PyResult<&'a PyStrin // as x86_64, where the bitfield has a common representation (even if it is // not part of the C spec). The PyO3 CI tests this API on x86_64 platforms. // - // The C implementation already does this. And Rust can't compete on performance unless it can - // access the raw bytes. Converting it to a rust string (validating utf-8 and probably copying - // the full string) is probably already slower than C. + // The C implementation already does this. Python strings can be represented + // as u8, u16, or u32, and this avoids converting to utf-8 if it's not + // necessary, meaning if a u16 or u32 string doesn't need any characters + // replaced, we can short-circuit without doing any converting let data_res = unsafe { s.data() }; match data_res { Ok(data) => match data { From 62c408774f02ae1c46ce33fe6bceca2f9ce8fdee Mon Sep 17 00:00:00 2001 From: carsonburr Date: Wed, 6 Mar 2024 17:33:46 -0500 Subject: [PATCH 04/15] add tests --- src/rust/src/lib.rs | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/src/rust/src/lib.rs b/src/rust/src/lib.rs index 0421fe54..b8b8e0c3 100644 --- a/src/rust/src/lib.rs +++ b/src/rust/src/lib.rs @@ -334,7 +334,13 @@ mod tests { #[test] fn empty() { let res = check_utf8("".as_bytes()); - assert!(res.is_none()) + assert!(res.is_none()); + } + + #[test] + fn no_change_test() { + let res = check_utf8("abcdefgh".as_bytes()); + assert!(res.is_none()); } #[test] @@ -355,6 +361,11 @@ mod tests { assert_eq!("abcd&><'"", res); } + #[test] + fn no_change_large() { + assert!(check_utf8("abcdefgh".repeat(1024).as_bytes()).is_none()); + } + #[test] fn middle_large() { let res = build_replaced(check_utf8("abcd&><'\"efgh".repeat(1024).as_bytes()).unwrap()); @@ -373,6 +384,12 @@ mod tests { assert_eq!("abcd&><'"".repeat(1024).as_str(), res); } + #[test] + fn no_change_16() { + let input = "こんにちはこんばんは".encode_utf16().collect::>(); + assert!(no_change::<8, u16>(input.as_slice())); + } + #[test] fn middle_16() { let input = "こんにちは&><'\"こんばんは" @@ -402,6 +419,15 @@ mod tests { assert_eq!("こんにちは&><'"", res); } + #[test] + fn no_change_16_large() { + let input = "こんにちはこんばんは" + .repeat(1024) + .encode_utf16() + .collect::>(); + assert!(no_change::<8, u16>(input.as_slice())); + } + #[test] fn middle_16_large() { let input = "こんにちは&><'\"こんばんは" From ec70e303dd2b9fe87640797ff99ed6a084c40b14 Mon Sep 17 00:00:00 2001 From: carsonburr Date: Sun, 17 Mar 2024 11:15:34 -0400 Subject: [PATCH 05/15] added benchmark against naive implementation --- bench.py | 21 ++++++ setup.py | 95 ++++++---------------------- src/markupsafe/__init__.py | 3 + src/markupsafe/_native.py | 6 +- src/markupsafe/_rust_speedups.pyi | 1 + src/rust/Cargo.toml | 3 + src/rust/src/lib.rs | 102 +++++++++++++++++++++++++++--- 7 files changed, 143 insertions(+), 88 deletions(-) create mode 100644 bench.py diff --git a/bench.py b/bench.py new file mode 100644 index 00000000..f3318738 --- /dev/null +++ b/bench.py @@ -0,0 +1,21 @@ +import pyperf + +runner = pyperf.Runner() + +name = "native" +runner.timeit( + f"escape_inner {name}", + setup=f"from markupsafe._{name} import escape_inner", + stmt='escape_inner("Hello, World!" * 1024)', +) +name = "rust_speedups" +runner.timeit( + f"escape_inner_naive {name}", + setup=f"from markupsafe._{name} import escape_inner_naive", + stmt='escape_inner_naive("Hello, World!" * 1024)', +) +runner.timeit( + f"escape_inner {name}", + setup=f"from markupsafe._{name} import escape_inner", + stmt='escape_inner("Hello, World!" * 1024)', +) diff --git a/setup.py b/setup.py index a63fca99..b5eef907 100644 --- a/setup.py +++ b/setup.py @@ -1,85 +1,30 @@ import os import platform -import sys from setuptools import Extension -from setuptools_rust import RustExtension from setuptools import setup -from setuptools.command.build_ext import build_ext -from setuptools.errors import CCompilerError -from setuptools.errors import ExecError -from setuptools.errors import PlatformError - -ext_modules = [Extension("markupsafe._speedups", ["src/markupsafe/_speedups.c"])] - - -class BuildFailed(Exception): - pass - - -class ve_build_ext(build_ext): - """This class allows C extension building to fail.""" - - def run(self): - try: - super().run() - except PlatformError as e: - raise BuildFailed() from e - - def build_extension(self, ext): - try: - super().build_extension(ext) - except (CCompilerError, ExecError, PlatformError) as e: - raise BuildFailed() from e - except ValueError as e: - # this can happen on Windows 64 bit, see Python issue 7511 - if "'path'" in str(sys.exc_info()[1]): # works with Python 2 and 3 - raise BuildFailed() from e - raise - - -def run_setup(with_binary): - setup( - cmdclass={"build_ext": ve_build_ext}, - ext_modules=ext_modules if with_binary else [], - ) - - -def show_message(*lines): - print("=" * 74) - for line in lines: - print(line) - print("=" * 74) - +from setuptools_rust import RustExtension -supports_speedups = platform.python_implementation() not in { +if platform.python_implementation() not in { "PyPy", "Jython", "GraalVM", -} - -if os.environ.get("CIBUILDWHEEL", "0") == "1" and supports_speedups: - run_setup(True) -elif supports_speedups: - try: - run_setup(True) - except BuildFailed: - show_message( - "WARNING: The C extension could not be compiled, speedups" - " are not enabled.", - "Failure information, if any, is above.", - "Retrying the build without the C extension now.", - ) - run_setup(False) - show_message( - "WARNING: The C extension could not be compiled, speedups" - " are not enabled.", - "Plain-Python build succeeded.", - ) -else: - run_setup(False) - show_message( - "WARNING: C extensions are not supported on this Python" - " platform, speedups are not enabled.", - "Plain-Python build succeeded.", +}: + local = os.environ.get("CIBUILDWHEEL", "0") != "1" + setup( + ext_modules=[ + Extension( + "markupsafe._speedups", ["src/markupsafe/_speedups.c"], optional=local + ) + ], + rust_extensions=[ + RustExtension( + "markupsafe._rust_speedups", + "src/rust/Cargo.toml", + optional=local, + debug=False, + ) + ], ) +else: + setup() diff --git a/src/markupsafe/__init__.py b/src/markupsafe/__init__.py index 2afbfbc6..6482bf4d 100644 --- a/src/markupsafe/__init__.py +++ b/src/markupsafe/__init__.py @@ -320,9 +320,11 @@ def __float__(self, /) -> float: # circular import try: from ._rust_speedups import escape_inner as escape_inner + from ._rust_speedups import escape_inner_naive as escape_inner_naive except ImportError: from ._native import escape_inner as escape_inner + def escape(s: t.Any, /) -> Markup: """Replace the characters ``&``, ``<``, ``>``, ``'``, and ``"`` in the string with HTML-safe sequences. Use this if you need to display @@ -339,6 +341,7 @@ def escape(s: t.Any, /) -> Markup: return Markup(escape_inner(str(s))) + def escape_silent(s: t.Any | None, /) -> Markup: """Like :func:`escape` but treats ``None`` as the empty string. Useful with optional values, as otherwise you get the string diff --git a/src/markupsafe/_native.py b/src/markupsafe/_native.py index 063b13c8..fa116bf1 100644 --- a/src/markupsafe/_native.py +++ b/src/markupsafe/_native.py @@ -1,11 +1,11 @@ from __future__ import annotations -import typing as t def escape_inner(s: str, /) -> str: - return s - .replace("&", "&") + return ( + s.replace("&", "&") .replace(">", ">") .replace("<", "<") .replace("'", "'") .replace('"', """) + ) diff --git a/src/markupsafe/_rust_speedups.pyi b/src/markupsafe/_rust_speedups.pyi index 764b8e61..6dc2a7e0 100644 --- a/src/markupsafe/_rust_speedups.pyi +++ b/src/markupsafe/_rust_speedups.pyi @@ -1 +1,2 @@ def escape_inner(s: str) -> str: ... +def escape_inner_naive(s: str) -> str: ... diff --git a/src/rust/Cargo.toml b/src/rust/Cargo.toml index 005928a4..e8a1f7ba 100644 --- a/src/rust/Cargo.toml +++ b/src/rust/Cargo.toml @@ -4,6 +4,9 @@ version = "0.1.0" edition = "2021" publish = false +[profile.release] +debug = true + [dependencies] pyo3 = "0.20.3" diff --git a/src/rust/src/lib.rs b/src/rust/src/lib.rs index b8b8e0c3..d211a8ad 100644 --- a/src/rust/src/lib.rs +++ b/src/rust/src/lib.rs @@ -7,6 +7,15 @@ use pyo3::{ PyResult, Python, }; +// #[cfg(target_arch = "x86")] +// use std::arch::x86 as arch; + +// #[cfg(target_arch = "x86_64")] +// use std::arch::x86_64 as arch; + +// #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +// use arch::_mm_movemask_epi8; + /// A Rust implemented find and replace for the characters `<`, `>`, `&`, `"`, and `'` /// into the sanitized strings `<`, `>`, `&`, `#34;`, and `#39;` respectively @@ -40,6 +49,7 @@ use pyo3::{ trait Bits: Copy + Eq + BitAnd + BitOr + Into + From { fn ones() -> Self; fn zeroes() -> Self; + fn as_u8(self) -> u8; } // RUST_INTRO: We can implement our trait on types that already exist in the standard library @@ -51,6 +61,10 @@ impl Bits for u8 { fn zeroes() -> Self { 0 } + + fn as_u8(self) -> u8 { + self + } } impl Bits for u16 { @@ -61,6 +75,10 @@ impl Bits for u16 { fn zeroes() -> Self { 0 } + + fn as_u8(self) -> u8 { + self as u8 + } } impl Bits for u32 { @@ -71,6 +89,10 @@ impl Bits for u32 { fn zeroes() -> Self { 0 } + + fn as_u8(self) -> u8 { + self as u8 + } } // auto-vectorized OR @@ -82,7 +104,7 @@ impl Bits for u32 { // merges 2 streams into one with a tuple (a, b) for each item. We can do this twice to get a // tuple containing a tuple ((a, b), r). #[inline(always)] -fn v_or(ax: &[T; N], bx: &[T; N]) -> [T; N] { +fn v_or(ax: [T; N], bx: [T; N]) -> [T; N] { let mut result = [T::zeroes(); N]; for ((&a, &b), r) in ax.iter().zip(bx.iter()).zip(result.iter_mut()) { *r = a | b; @@ -90,7 +112,24 @@ fn v_or(ax: &[T; N], bx: &[T; N]) -> [T; N] { result } -const MASK_BITS: [u64; 8] = [1, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 6, 1 << 7]; +const MASK_BITS: [u8; 16] = [ + 1, + 1 << 1, + 1 << 2, + 1 << 3, + 1 << 4, + 1 << 5, + 1 << 6, + 1 << 7, + 1, + 1 << 1, + 1 << 2, + 1 << 3, + 1 << 4, + 1 << 5, + 1 << 6, + 1 << 7, +]; // I attempted to autovectorize to pmovmskb, but since I can't get it to // recognize it, I'll just take advantage of v_eq setting all bits to 1. @@ -103,17 +142,21 @@ const MASK_BITS: [u64; 8] = [1, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 6, // technically the last line has the bits reversed. I wrote it this way to // show the relationship between the matched character and the bits. #[inline(always)] -fn v_bitmask(ax: &[T; N]) -> u64 { +fn v_bitmask(ax: [T; N]) -> u64 { + let mut masked = [0u8; N]; + for ((a, m), r) in ax.iter().zip(MASK_BITS.iter()).zip(masked.iter_mut()) { + *r = a.as_u8() & m + } let mut result = 0u64; - for (i, &a) in ax.iter().enumerate() { - result |= (a.into() & MASK_BITS[i % 8]) << ((i / 8) * 8); + for (i, &m) in masked.iter().enumerate() { + result |= (m as u64) << ((i / 8) * 8); } result } // auto-vectorized equal. designed to compile into the instruction pcmpeqb #[inline(always)] -fn v_eq(ax: &[T], bx: &[T; N]) -> [T; N] { +fn v_eq(ax: &[T], bx: [T; N]) -> [T; N] { let mut result = [T::zeroes(); N]; for ((&a, &b), r) in ax.iter().zip(bx.iter()).zip(result.iter_mut()) { *r = if a == b { T::ones() } else { T::zeroes() }; @@ -121,16 +164,15 @@ fn v_eq(ax: &[T], bx: &[T; N]) -> [T; N] { result } +#[inline(always)] fn mask(input: &[T], splats: [[T; N]; M]) -> u64 { let mut result = 0u64; // split into 128-bit chunks to vectorize inside the loop for (i, lane) in input.chunks_exact(N).enumerate() { result |= v_bitmask( - &splats + splats .iter() - .map(|splat| v_eq(lane, splat)) - .reduce(|a, b| v_or(&a, &b)) - .unwrap(), + .fold([T::zeroes(); N], |acc, &splat| v_or(acc, v_eq(lane, splat))), ) << (i * N); } result @@ -327,6 +369,46 @@ pub fn escape_inner<'a>(py: Python<'a>, s: &'a PyString) -> PyResult<&'a PyStrin } } +fn delta_naive(input: &str, replacement_indices: &mut Vec) -> usize { + let mut delta = 0; + for (i, item) in input.chars().enumerate() { + if is_equal_any!(item, b'<' | b'>') { + delta += 3; + replacement_indices.push(i as u32); + } else if is_equal_any!(item, b'"' | b'\'' | b'&') { + delta += 4; + replacement_indices.push(i as u32); + } + } + delta +} + +fn check_utf8_naive(input: &str) -> Option { + let mut replacement_indices = Vec::with_capacity(8); + let delta = delta_naive(input, &mut replacement_indices); + if delta == 0 { + None + } else { + Some(RebuildArgs::new(delta, replacement_indices, input)) + } +} + +#[pyfunction] +pub fn escape_inner_naive<'a>(py: Python<'a>, s: &'a PyString) -> PyResult<&'a PyString> { + let input = s.to_str()?; + Ok(check_utf8_naive(input) + .map(|rb| PyString::new(py, build_replaced(rb).as_str())) + .unwrap_or(s)) +} + +#[pymodule] +#[pyo3(name = "_rust_speedups")] +fn speedups(_py: Python<'_>, m: &PyModule) -> PyResult<()> { + m.add_function(wrap_pyfunction!(escape_inner, m)?)?; + m.add_function(wrap_pyfunction!(escape_inner_naive, m)?)?; + Ok(()) +} + #[cfg(test)] mod tests { use crate::{build_replaced, check_utf8, no_change}; From 286993502a0c02690e324802de12042a1b0dcdfe Mon Sep 17 00:00:00 2001 From: carsonburr Date: Fri, 29 Mar 2024 18:56:51 -0400 Subject: [PATCH 06/15] added benchmark to compare native python vs naive rust speedups vs optimized rust speedups --- src/rust/src/lib.rs | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/rust/src/lib.rs b/src/rust/src/lib.rs index d211a8ad..2d6e7489 100644 --- a/src/rust/src/lib.rs +++ b/src/rust/src/lib.rs @@ -7,15 +7,6 @@ use pyo3::{ PyResult, Python, }; -// #[cfg(target_arch = "x86")] -// use std::arch::x86 as arch; - -// #[cfg(target_arch = "x86_64")] -// use std::arch::x86_64 as arch; - -// #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -// use arch::_mm_movemask_epi8; - /// A Rust implemented find and replace for the characters `<`, `>`, `&`, `"`, and `'` /// into the sanitized strings `<`, `>`, `&`, `#34;`, and `#39;` respectively From e48ca4b114dc6793b8b99100ff2875184709773f Mon Sep 17 00:00:00 2001 From: carsonburr Date: Fri, 29 Mar 2024 19:05:13 -0400 Subject: [PATCH 07/15] mask operates on 4 vectors at a time to try to squeeze some ilp perf --- src/rust/src/lib.rs | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/src/rust/src/lib.rs b/src/rust/src/lib.rs index 2d6e7489..d97ff911 100644 --- a/src/rust/src/lib.rs +++ b/src/rust/src/lib.rs @@ -159,12 +159,30 @@ fn v_eq(ax: &[T], bx: [T; N]) -> [T; N] { fn mask(input: &[T], splats: [[T; N]; M]) -> u64 { let mut result = 0u64; // split into 128-bit chunks to vectorize inside the loop - for (i, lane) in input.chunks_exact(N).enumerate() { - result |= v_bitmask( - splats - .iter() - .fold([T::zeroes(); N], |acc, &splat| v_or(acc, v_eq(lane, splat))), - ) << (i * N); + let mut i = 0; + let mut iter = input.chunks_exact(N); + while let Some(v0) = iter.next() { + // operate on 4 counts of 128-bit chunks at a time to unlock instruction-level parallelism. + // meaning these v_or and v_eq can execute at the same time on the cpu + // the number of items in iter are always a multiple of 4: + // [u8; 16] * 4 = 64 items + // [u16; 8] * 8 = 64 items + // [u32; 4] * 16 = 64 items + let v1 = iter.next().unwrap(); + let v2 = iter.next().unwrap(); + let v3 = iter.next().unwrap(); + let v_masks = splats.iter().fold([[T::zeroes(); N]; 4], |acc, &splat| { + [ + v_or(acc[0], v_eq(v0, splat)), + v_or(acc[1], v_eq(v1, splat)), + v_or(acc[2], v_eq(v2, splat)), + v_or(acc[3], v_eq(v3, splat)), + ] + }); + for (j, &v_mask) in v_masks.iter().enumerate() { + result |= v_bitmask(v_mask) << ((i + j) * N) + } + i += 1; } result } From a21d5a1740061aeadb0576ef769d17c174ed4bad Mon Sep 17 00:00:00 2001 From: carsonburr Date: Mon, 22 Apr 2024 19:40:39 -0400 Subject: [PATCH 08/15] fix rust_speedups export and add it to bench.py --- bench.py | 2 +- src/markupsafe/__init__.py | 2 +- src/rust/src/lib.rs | 37 ++----------------------------------- 3 files changed, 4 insertions(+), 37 deletions(-) diff --git a/bench.py b/bench.py index 59617aa8..74555b34 100644 --- a/bench.py +++ b/bench.py @@ -8,7 +8,7 @@ ("long plain", '"Hello, World!" * 1000'), ("long suffix", '"Hello, World!" + "x" * 100_000'), ): - for mod in "native", "speedups": + for mod in "native", "speedups", "rust_speedups": subprocess.run( [ sys.executable, diff --git a/src/markupsafe/__init__.py b/src/markupsafe/__init__.py index 00cf6b8f..efdcfce1 100644 --- a/src/markupsafe/__init__.py +++ b/src/markupsafe/__init__.py @@ -6,7 +6,7 @@ import typing as t try: - from ._speedups import _escape_inner + from ._rust_speedups import _escape_inner except ImportError: from ._native import _escape_inner diff --git a/src/rust/src/lib.rs b/src/rust/src/lib.rs index d97ff911..28d1571c 100644 --- a/src/rust/src/lib.rs +++ b/src/rust/src/lib.rs @@ -356,7 +356,7 @@ fn escape_other_format<'a, const N: usize, T: Bits>( } #[pyfunction] -pub fn escape_inner<'a>(py: Python<'a>, s: &'a PyString) -> PyResult<&'a PyString> { +pub fn _escape_inner<'a>(py: Python<'a>, s: &'a PyString) -> PyResult<&'a PyString> { // SAFETY: from the py03 docs: // This function implementation relies on manually decoding a C bitfield. // In practice, this works well on common little-endian architectures such @@ -378,43 +378,10 @@ pub fn escape_inner<'a>(py: Python<'a>, s: &'a PyString) -> PyResult<&'a PyStrin } } -fn delta_naive(input: &str, replacement_indices: &mut Vec) -> usize { - let mut delta = 0; - for (i, item) in input.chars().enumerate() { - if is_equal_any!(item, b'<' | b'>') { - delta += 3; - replacement_indices.push(i as u32); - } else if is_equal_any!(item, b'"' | b'\'' | b'&') { - delta += 4; - replacement_indices.push(i as u32); - } - } - delta -} - -fn check_utf8_naive(input: &str) -> Option { - let mut replacement_indices = Vec::with_capacity(8); - let delta = delta_naive(input, &mut replacement_indices); - if delta == 0 { - None - } else { - Some(RebuildArgs::new(delta, replacement_indices, input)) - } -} - -#[pyfunction] -pub fn escape_inner_naive<'a>(py: Python<'a>, s: &'a PyString) -> PyResult<&'a PyString> { - let input = s.to_str()?; - Ok(check_utf8_naive(input) - .map(|rb| PyString::new(py, build_replaced(rb).as_str())) - .unwrap_or(s)) -} - #[pymodule] #[pyo3(name = "_rust_speedups")] fn speedups(_py: Python<'_>, m: &PyModule) -> PyResult<()> { - m.add_function(wrap_pyfunction!(escape_inner, m)?)?; - m.add_function(wrap_pyfunction!(escape_inner_naive, m)?)?; + m.add_function(wrap_pyfunction!(_escape_inner, m)?)?; Ok(()) } From 4a3e904f654e3956fd72b870e05af26b4beeb5aa Mon Sep 17 00:00:00 2001 From: carsonburr Date: Mon, 22 Apr 2024 19:52:06 -0400 Subject: [PATCH 09/15] remove naive rust impl signature --- src/markupsafe/_rust_speedups.pyi | 1 - 1 file changed, 1 deletion(-) diff --git a/src/markupsafe/_rust_speedups.pyi b/src/markupsafe/_rust_speedups.pyi index 6dc2a7e0..764b8e61 100644 --- a/src/markupsafe/_rust_speedups.pyi +++ b/src/markupsafe/_rust_speedups.pyi @@ -1,2 +1 @@ def escape_inner(s: str) -> str: ... -def escape_inner_naive(s: str) -> str: ... From 97ee7e8c85d5741a2472b5295d643a899d40388a Mon Sep 17 00:00:00 2001 From: carsonburr Date: Mon, 22 Apr 2024 19:55:17 -0400 Subject: [PATCH 10/15] fix _escape_inner signature --- src/markupsafe/_rust_speedups.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/markupsafe/_rust_speedups.pyi b/src/markupsafe/_rust_speedups.pyi index 764b8e61..839223ed 100644 --- a/src/markupsafe/_rust_speedups.pyi +++ b/src/markupsafe/_rust_speedups.pyi @@ -1 +1 @@ -def escape_inner(s: str) -> str: ... +def _escape_inner(s: str) -> str: ... From 7755c7e2914ba129a47dc4a537db54d819564cda Mon Sep 17 00:00:00 2001 From: carsonburr Date: Mon, 22 Apr 2024 20:01:14 -0400 Subject: [PATCH 11/15] fix _escape_inner signature attempt 2 --- src/markupsafe/_rust_speedups.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/markupsafe/_rust_speedups.pyi b/src/markupsafe/_rust_speedups.pyi index 839223ed..8c888585 100644 --- a/src/markupsafe/_rust_speedups.pyi +++ b/src/markupsafe/_rust_speedups.pyi @@ -1 +1 @@ -def _escape_inner(s: str) -> str: ... +def _escape_inner(s: str, /) -> str: ... From adcba6f326b87b5de247352964d3a2d57bb74811 Mon Sep 17 00:00:00 2001 From: carsonburr Date: Thu, 29 Aug 2024 18:09:41 -0700 Subject: [PATCH 12/15] rust speedups: simplified to lookup table instead of simd --- bench.py | 2 +- src/rust/Cargo.lock | 151 ++---------- src/rust/Cargo.toml | 2 +- src/rust/src/lib.rs | 577 +++++++++----------------------------------- 4 files changed, 128 insertions(+), 604 deletions(-) diff --git a/bench.py b/bench.py index 74555b34..88c66fa2 100644 --- a/bench.py +++ b/bench.py @@ -8,7 +8,7 @@ ("long plain", '"Hello, World!" * 1000'), ("long suffix", '"Hello, World!" + "x" * 100_000'), ): - for mod in "native", "speedups", "rust_speedups": + for mod in "native", "rust_speedups": subprocess.run( [ sys.executable, diff --git a/src/rust/Cargo.lock b/src/rust/Cargo.lock index f8aa6734..5f43fd20 100644 --- a/src/rust/Cargo.lock +++ b/src/rust/Cargo.lock @@ -8,12 +8,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" -[[package]] -name = "bitflags" -version = "1.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" - [[package]] name = "cfg-if" version = "1.0.0" @@ -22,9 +16,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "heck" -version = "0.4.1" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] name = "indoc" @@ -38,16 +32,6 @@ version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" -[[package]] -name = "lock_api" -version = "0.4.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" -dependencies = [ - "autocfg", - "scopeguard", -] - [[package]] name = "markupsafe-rust" version = "0.1.0" @@ -70,29 +54,6 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" -[[package]] -name = "parking_lot" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" -dependencies = [ - "lock_api", - "parking_lot_core", -] - -[[package]] -name = "parking_lot_core" -version = "0.9.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" -dependencies = [ - "cfg-if", - "libc", - "redox_syscall", - "smallvec", - "windows-targets", -] - [[package]] name = "portable-atomic" version = "1.6.0" @@ -101,24 +62,24 @@ checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" [[package]] name = "proc-macro2" -version = "1.0.78" +version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" +checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" dependencies = [ "unicode-ident", ] [[package]] name = "pyo3" -version = "0.20.3" +version = "0.22.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53bdbb96d49157e65d45cc287af5f32ffadd5f4761438b527b055fb0d4bb8233" +checksum = "831e8e819a138c36e212f3af3fd9eeffed6bf1510a805af35b0edee5ffa59433" dependencies = [ "cfg-if", "indoc", "libc", "memoffset", - "parking_lot", + "once_cell", "portable-atomic", "pyo3-build-config", "pyo3-ffi", @@ -128,9 +89,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.20.3" +version = "0.22.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "deaa5745de3f5231ce10517a1f5dd97d53e5a2fd77aa6b5842292085831d48d7" +checksum = "1e8730e591b14492a8945cdff32f089250b05f5accecf74aeddf9e8272ce1fa8" dependencies = [ "once_cell", "target-lexicon", @@ -138,9 +99,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.20.3" +version = "0.22.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62b42531d03e08d4ef1f6e85a2ed422eb678b8cd62b762e53891c05faf0d4afa" +checksum = "5e97e919d2df92eb88ca80a037969f44e5e70356559654962cbb3316d00300c6" dependencies = [ "libc", "pyo3-build-config", @@ -148,9 +109,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.20.3" +version = "0.22.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7305c720fa01b8055ec95e484a6eca7a83c841267f0dd5280f0c8b8551d2c158" +checksum = "eb57983022ad41f9e683a599f2fd13c3664d7063a3ac5714cae4b7bee7d3f206" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -160,9 +121,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.20.3" +version = "0.22.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c7e9b68bb9c3149c5b0cade5d07f953d6d125eb4337723c4ccdb665f1f96185" +checksum = "ec480c0c51ddec81019531705acac51bcdbeae563557c982aa8263bb96880372" dependencies = [ "heck", "proc-macro2", @@ -180,32 +141,11 @@ dependencies = [ "proc-macro2", ] -[[package]] -name = "redox_syscall" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" -dependencies = [ - "bitflags", -] - -[[package]] -name = "scopeguard" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" - -[[package]] -name = "smallvec" -version = "1.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" - [[package]] name = "syn" -version = "2.0.52" +version = "2.0.76" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b699d15b36d1f02c3e7c69f8ffef53de37aefae075d8488d4ba1a7788d574a07" +checksum = "578e081a14e0cefc3279b0472138c513f37b41a08d5a3cca9b6e4e8ceb6cd525" dependencies = [ "proc-macro2", "quote", @@ -229,60 +169,3 @@ name = "unindent" version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" - -[[package]] -name = "windows-targets" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" -dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", -] - -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" - -[[package]] -name = "windows_i686_gnu" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" - -[[package]] -name = "windows_i686_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" - -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" diff --git a/src/rust/Cargo.toml b/src/rust/Cargo.toml index e8a1f7ba..f97c4901 100644 --- a/src/rust/Cargo.toml +++ b/src/rust/Cargo.toml @@ -8,7 +8,7 @@ publish = false debug = true [dependencies] -pyo3 = "0.20.3" +pyo3 = "0.22.2" [lib] name = "_rust_speedups" diff --git a/src/rust/src/lib.rs b/src/rust/src/lib.rs index 28d1571c..510c2c9e 100644 --- a/src/rust/src/lib.rs +++ b/src/rust/src/lib.rs @@ -1,535 +1,176 @@ -use std::ops::{BitAnd, BitOr}; -use std::str::from_utf8_unchecked; - use pyo3::prelude::*; -use pyo3::{ - types::{PyString, PyStringData}, - PyResult, Python, +use pyo3::{types::PyString, PyResult, Python}; + +static NEEDS_SANITIZE: [bool; 256] = { + const __: bool = false; + const XX: bool = true; + [ + // 1 2 3 4 5 6 7 8 9 A B C D E F + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 0 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 1 + __, __, XX, __, __, __, XX, XX, __, __, __, __, __, __, __, __, // 2 + __, __, __, __, __, __, __, __, __, __, __, __, XX, __, XX, __, // 3 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 4 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 5 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 6 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 7 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 8 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 9 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // A + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // B + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // C + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // D + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // E + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // F + ] }; -/// A Rust implemented find and replace for the characters `<`, `>`, `&`, `"`, and `'` -/// into the sanitized strings `<`, `>`, `&`, `#34;`, and `#39;` respectively - -// Since speed was is a concern, I try to strike a balance between fast and readable, borrowing -// some lessons from simd-json. Namely I use vectorized and branchless-ish processing, but instead -// of using explicit vectorization, I let the compiler auto-vectorize the code for us, keeping the -// code readable, portable, and safe. -// -// To guarantee this optimization we need to help the compiler to recognize the patterns. We need to -// operate on 128-bit chunks whenever possible, and to do so generically because python could give -// us characters encoded as u8, u16, or u32. -// > By default, the i686 and x86_64 Rust targets enable sse and sse2 (128-bit vector extensions) -// https://github.com/rust-lang/portable-simd/blob/master/beginners-guide.md -// -// To enable this, I used some intermediate Rust features, so I'll include comments labelled -// RUST_INTRO where they're relevant, but they include: -// -// - const generics - similar to generics, but for compile-time constants. This means that -// different calls to the same function can work with different numbers in their -// types. -// - traits - like an interface in java. Defines the behavior we might want out of a -// generic type that can be implemented by many different concrete types. -// - macro_rules - used to create syntax sugar for some repetative code. -// - iterators - a lazy data stream used mostly to make for-in loops look nice. They have a -// similar interface to java's stream api, but compile to optimized loops. - -// A trait that describes anything we might want to do with bits so we can generically work with -// u8, u16, or u32 -// RUST_INTRO: The traits after the colon are trait-bounds. Any types implementing Bits must already -// implement these other traits. These are all handled by the standard library. -trait Bits: Copy + Eq + BitAnd + BitOr + Into + From { - fn ones() -> Self; - fn zeroes() -> Self; - fn as_u8(self) -> u8; -} - -// RUST_INTRO: We can implement our trait on types that already exist in the standard library -impl Bits for u8 { - fn ones() -> Self { - Self::MAX - } - - fn zeroes() -> Self { - 0 - } - - fn as_u8(self) -> u8 { - self - } -} - -impl Bits for u16 { - fn ones() -> Self { - Self::MAX - } - - fn zeroes() -> Self { - 0 - } - - fn as_u8(self) -> u8 { - self as u8 - } -} - -impl Bits for u32 { - fn ones() -> Self { - Self::MAX - } - - fn zeroes() -> Self { - 0 - } - - fn as_u8(self) -> u8 { - self as u8 - } -} - -// auto-vectorized OR -// RUST_INTRO: the `(ax: [T; N], bx: [T; N]) -> [T; N] { - let mut result = [T::zeroes(); N]; - for ((&a, &b), r) in ax.iter().zip(bx.iter()).zip(result.iter_mut()) { - *r = a | b; - } - result -} - -const MASK_BITS: [u8; 16] = [ - 1, - 1 << 1, - 1 << 2, - 1 << 3, - 1 << 4, - 1 << 5, - 1 << 6, - 1 << 7, - 1, - 1 << 1, - 1 << 2, - 1 << 3, - 1 << 4, - 1 << 5, - 1 << 6, - 1 << 7, -]; - -// I attempted to autovectorize to pmovmskb, but since I can't get it to -// recognize it, I'll just take advantage of v_eq setting all bits to 1. -// -// converts chunks of ones created by v_eq into a bitmask with 1s where the -// matching characters were. e.g. -// input = "xx(ax: [T; N]) -> u64 { - let mut masked = [0u8; N]; - for ((a, m), r) in ax.iter().zip(MASK_BITS.iter()).zip(masked.iter_mut()) { - *r = a.as_u8() & m - } - let mut result = 0u64; - for (i, &m) in masked.iter().enumerate() { - result |= (m as u64) << ((i / 8) * 8); - } - result -} - -// auto-vectorized equal. designed to compile into the instruction pcmpeqb -#[inline(always)] -fn v_eq(ax: &[T], bx: [T; N]) -> [T; N] { - let mut result = [T::zeroes(); N]; - for ((&a, &b), r) in ax.iter().zip(bx.iter()).zip(result.iter_mut()) { - *r = if a == b { T::ones() } else { T::zeroes() }; - } - result -} - -#[inline(always)] -fn mask(input: &[T], splats: [[T; N]; M]) -> u64 { - let mut result = 0u64; - // split into 128-bit chunks to vectorize inside the loop - let mut i = 0; - let mut iter = input.chunks_exact(N); - while let Some(v0) = iter.next() { - // operate on 4 counts of 128-bit chunks at a time to unlock instruction-level parallelism. - // meaning these v_or and v_eq can execute at the same time on the cpu - // the number of items in iter are always a multiple of 4: - // [u8; 16] * 4 = 64 items - // [u16; 8] * 8 = 64 items - // [u32; 4] * 16 = 64 items - let v1 = iter.next().unwrap(); - let v2 = iter.next().unwrap(); - let v3 = iter.next().unwrap(); - let v_masks = splats.iter().fold([[T::zeroes(); N]; 4], |acc, &splat| { - [ - v_or(acc[0], v_eq(v0, splat)), - v_or(acc[1], v_eq(v1, splat)), - v_or(acc[2], v_eq(v2, splat)), - v_or(acc[3], v_eq(v3, splat)), - ] - }); - for (j, &v_mask) in v_masks.iter().enumerate() { - result |= v_bitmask(v_mask) << ((i + j) * N) - } - i += 1; - } - result -} - -// a splat is an array containing the same element repeated. -// e.g. [u8; 16] splat of "<" is "<<<<<<<<<<<<<<<<" -// a single call to mask might need many of those, so we construct them with a macro -// RUST_INTRO: `$items:expr` matches any input that looks like an expression. `$(...),+` means that -// the pattern inside is repeated one or more times, separated by commas -macro_rules! make_splats { - ($($items:expr),+) => { - [$([$items.into(); N],)+] - }; -} - -macro_rules! is_equal_any { - ($lhs:expr, $($rhs:literal)|+) => { - $(($lhs == $rhs.into()))|+ - }; -} - -// Tying everything together, we calculate the delta between the input size and the output. -// the algorithm: take chunks of 64 items at a time so mask() can create a u64 representing -// which items, if any, have characters that need replacing. We can then count how many of each -// character class is in the input and keep any indices if needed. Finally if the input doesn't -// neatly fit into 64 item chunks, we need a slow version to do the same to the remainder. -fn delta(input: &[T], replacement_indices: &mut Vec) -> usize { - // calls to mask() create a u64 mask representing 64 items - let chunks = input.chunks_exact(64); - let remainder = chunks.remainder(); - - let mut delta = 0; - for (i, chunk) in chunks.enumerate() { - let delta_3_mask = mask::(chunk, make_splats!(b'<', b'>')); - let delta_4_mask = mask::(chunk, make_splats!(b'"', b'\'', b'&')); - // count_ones() is a single instruction on x86 - let delta_3 = delta_3_mask.count_ones(); - let delta_4 = delta_4_mask.count_ones(); - delta += (delta_3 * 3) + (delta_4 * 4); - - let mut all_mask = delta_3_mask | delta_4_mask; - let mut count = delta_3 + delta_4; - let idx = (i * 64) as u32; - while count > 0 { - replacement_indices.push(idx + all_mask.trailing_zeros()); - all_mask &= all_mask.wrapping_sub(1); - count -= 1; - } - } - - let idx = ((input.len() / 64) * 64) as u32; - for (i, &item) in remainder.iter().enumerate() { - if is_equal_any!(item, b'<' | b'>') { - delta += 3; - replacement_indices.push(idx + i as u32); - } else if is_equal_any!(item, b'"' | b'\'' | b'&') { - delta += 4; - replacement_indices.push(idx + i as u32); +pub fn needs_sanitize(bytes: &[u8]) -> Option { + for (i, &b) in bytes.iter().enumerate() { + if NEEDS_SANITIZE[b as usize] { + return Some(i); } } - delta as usize -} -// very similar to delta(), but short-circuits because if there is anything to replace, we need to -// convert to utf-8 anyway to calculate delta and indices -fn no_change(input: &[T]) -> bool { - // calls to mask() create a u64 mask representing 64 items - let chunks = input.chunks_exact(64); - let remainder = chunks.remainder(); - - for chunk in chunks { - let any_mask = mask::(chunk, make_splats!(b'<', b'>', b'"', b'\'', b'&')); - if any_mask != 0 { - return false; - } - } + None +} + +static SANITIZE_INDEX: [i8; 256] = { + const __: i8 = -1; + [ + // 1 2 3 4 5 6 7 8 9 A B C D E F + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 0 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 1 + __, __, 00, __, __, __, 01, 02, __, __, __, __, __, __, __, __, // 2 + __, __, __, __, __, __, __, __, __, __, __, __, 03, __, 04, __, // 3 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 4 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 5 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 6 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 7 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 8 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 9 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // A + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // B + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // C + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // D + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // E + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // F + ] +}; - for &item in remainder { - if is_equal_any!(item, b'<' | b'>' | b'"' | b'\'' | b'&') { - return false; +static SANITIZED_VALUE: [&str; 5] = [""", "&", "'", "<", ">"]; + +pub fn lut_replace(input: &str) -> Option { + let bytes = input.as_bytes(); + if let Some(mut idx) = needs_sanitize(bytes) { + let mut out = String::with_capacity(input.len()); + let mut prev_idx = 0; + for &b in bytes[idx..].iter() { + let replace_idx = SANITIZE_INDEX[b as usize]; + if replace_idx >= 0 { + if prev_idx < idx { + out.push_str(&input[prev_idx..idx]); + } + out.push_str(SANITIZED_VALUE[replace_idx as usize]); + prev_idx = idx + 1; + } + idx += 1; } - } - true -} - -// builds the sanitized output string. Copies the input bytes from sections that haven't changed -// and replaces the characters that need sanitizing. -fn build_replaced( - RebuildArgs { - delta, - replacement_indices, - input_str, - }: RebuildArgs<'_>, -) -> String { - // we could create the string without the size, but with_capacity means we - // never need to re-allocate the backing memory - let mut builder = String::with_capacity(input_str.len() + delta); - let mut prev_idx = 0usize; - for idx in replacement_indices { - let idx = idx as usize; if prev_idx < idx { - builder.push_str(&input_str[prev_idx..idx]); - } - builder.push_str(match &input_str[idx..idx + 1] { - "<" => "<", - ">" => ">", - "\"" => """, - "\'" => "'", - "&" => "&", - _ => unreachable!(""), - }); - prev_idx = idx + 1; - } - if prev_idx != input_str.len() - 1 { - builder.push_str(&input_str[prev_idx..input_str.len()]); - } - builder -} - -struct RebuildArgs<'a> { - delta: usize, - replacement_indices: Vec, - input_str: &'a str, -} - -impl<'a> RebuildArgs<'a> { - fn new(delta: usize, replacement_indices: Vec, input_str: &'a str) -> Self { - RebuildArgs { - delta, - replacement_indices, - input_str, + out.push_str(&input[prev_idx..idx]); } - } -} - -fn check_utf8(input: &[u8]) -> Option { - let mut replacement_indices = Vec::with_capacity(8); - let delta = delta::<16, u8>(input, &mut replacement_indices); - if delta == 0 { - None - } else { - // SAFETY: The rest of the code assumes that python gives us valid utf-8, so to avoid - // validation or copying, we will here too - // https://docs.python.org/3.12//c-api/unicode.html - let input_str = unsafe { from_utf8_unchecked(input) }; - Some(RebuildArgs::new(delta, replacement_indices, input_str)) - } -} - -fn escape_utf8<'a>(py: Python<'a>, orig: &'a PyString, input: &[u8]) -> &'a PyString { - check_utf8(input) - .map(|rb| PyString::new(py, build_replaced(rb).as_str())) - .unwrap_or(orig) -} - -fn escape_other_format<'a, const N: usize, T: Bits>( - py: Python<'a>, - orig: &'a PyString, - input: &[T], -) -> PyResult<&'a PyString> { - // there's no safe way to construct a utf-16 or unicode string to pass back to python without - // using the unsafe ffi bindings, so best case scenario we short circuit and pass the original - // string back, otherwise just slow-path convert to utf-8 and process it that way since we'd - // have to recompute the delta and indices anyway - if no_change::(input) { - Ok(orig) + Some(out) } else { - orig.to_str() - .map(|utf8| escape_utf8(py, orig, utf8.as_bytes())) + None } } #[pyfunction] -pub fn _escape_inner<'a>(py: Python<'a>, s: &'a PyString) -> PyResult<&'a PyString> { - // SAFETY: from the py03 docs: - // This function implementation relies on manually decoding a C bitfield. - // In practice, this works well on common little-endian architectures such - // as x86_64, where the bitfield has a common representation (even if it is - // not part of the C spec). The PyO3 CI tests this API on x86_64 platforms. - // - // The C implementation already does this. Python strings can be represented - // as u8, u16, or u32, and this avoids converting to utf-8 if it's not - // necessary, meaning if a u16 or u32 string doesn't need any characters - // replaced, we can short-circuit without doing any converting - let data_res = unsafe { s.data() }; - match data_res { - Ok(data) => match data { - PyStringData::Ucs1(raw) => Ok(escape_utf8(py, s, raw)), - PyStringData::Ucs2(raw) => escape_other_format::<8, u16>(py, s, raw), - PyStringData::Ucs4(raw) => escape_other_format::<4, u32>(py, s, raw), - }, - Err(e) => Err(e), +pub fn _escape_inner<'py>( + py: Python<'py>, + s: Bound<'py, PyString>, +) -> PyResult> { + if let Some(out) = lut_replace(s.to_str()?) { + Ok(PyString::new_bound(py, out.as_str())) + } else { + Ok(s) } } #[pymodule] #[pyo3(name = "_rust_speedups")] -fn speedups(_py: Python<'_>, m: &PyModule) -> PyResult<()> { +fn speedups<'py>(_py: Python<'py>, m: &Bound<'py, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(_escape_inner, m)?)?; Ok(()) } #[cfg(test)] mod tests { - use crate::{build_replaced, check_utf8, no_change}; + use crate::lut_replace; #[test] fn empty() { - let res = check_utf8("".as_bytes()); - assert!(res.is_none()); + let inp = ""; + assert_eq!(inp, lut_replace(inp).unwrap()); } #[test] fn no_change_test() { - let res = check_utf8("abcdefgh".as_bytes()); - assert!(res.is_none()); + let inp = "abcdefgh"; + assert_eq!(inp, lut_replace(inp).unwrap()); } #[test] fn middle() { - let res = build_replaced(check_utf8("abcd&><'\"efgh".as_bytes()).unwrap()); - assert_eq!("abcd&><'"efgh", res); + assert_eq!( + "abcd&><'"efgh", + lut_replace("abcd&><'\"efgh").unwrap() + ); } #[test] fn begin() { - let res = build_replaced(check_utf8("&><'\"efgh".as_bytes()).unwrap()); - assert_eq!("&><'"efgh", res); + assert_eq!( + "&><'"efgh", + lut_replace("&><'\"efgh").unwrap() + ); } #[test] fn end() { - let res = build_replaced(check_utf8("abcd&><'\"".as_bytes()).unwrap()); - assert_eq!("abcd&><'"", res); + assert_eq!( + "abcd&><'"", + lut_replace("abcd&><'\"").unwrap() + ); } #[test] fn no_change_large() { - assert!(check_utf8("abcdefgh".repeat(1024).as_bytes()).is_none()); + let inp = "abcdefgh".repeat(1024); + assert_eq!(inp, lut_replace(inp.as_str()).unwrap()); } #[test] fn middle_large() { - let res = build_replaced(check_utf8("abcd&><'\"efgh".repeat(1024).as_bytes()).unwrap()); - assert_eq!("abcd&><'"efgh".repeat(1024).as_str(), res); - } - - #[test] - fn begin_large() { - let res = build_replaced(check_utf8("&><'\"efgh".repeat(1024).as_bytes()).unwrap()); - assert_eq!("&><'"efgh".repeat(1024).as_str(), res); - } - - #[test] - fn end_large() { - let res = build_replaced(check_utf8("abcd&><'\"".repeat(1024).as_bytes()).unwrap()); - assert_eq!("abcd&><'"".repeat(1024).as_str(), res); - } - - #[test] - fn no_change_16() { - let input = "こんにちはこんばんは".encode_utf16().collect::>(); - assert!(no_change::<8, u16>(input.as_slice())); - } - - #[test] - fn middle_16() { - let input = "こんにちは&><'\"こんばんは" - .encode_utf16() - .collect::>(); - assert!(!no_change::<8, u16>(input.as_slice())); - let utf8 = String::from_utf16(input.as_slice()).unwrap(); - let res = build_replaced(check_utf8(utf8.as_bytes()).unwrap()); - assert_eq!("こんにちは&><'"こんばんは", res); - } - - #[test] - fn begin_16() { - let input = "&><'\"こんばんは".encode_utf16().collect::>(); - assert!(!no_change::<8, u16>(input.as_slice())); - let utf8 = String::from_utf16(input.as_slice()).unwrap(); - let res = build_replaced(check_utf8(utf8.as_bytes()).unwrap()); - assert_eq!("&><'"こんばんは", res); - } - - #[test] - fn end_16() { - let input = "こんにちは&><'\"".encode_utf16().collect::>(); - assert!(!no_change::<8, u16>(input.as_slice())); - let utf8 = String::from_utf16(input.as_slice()).unwrap(); - let res = build_replaced(check_utf8(utf8.as_bytes()).unwrap()); - assert_eq!("こんにちは&><'"", res); - } - - #[test] - fn no_change_16_large() { - let input = "こんにちはこんばんは" - .repeat(1024) - .encode_utf16() - .collect::>(); - assert!(no_change::<8, u16>(input.as_slice())); - } - - #[test] - fn middle_16_large() { - let input = "こんにちは&><'\"こんばんは" - .repeat(1024) - .encode_utf16() - .collect::>(); - assert!(!no_change::<8, u16>(input.as_slice())); - let utf8 = String::from_utf16(input.as_slice()).unwrap(); - let res = build_replaced(check_utf8(utf8.as_bytes()).unwrap()); assert_eq!( - "こんにちは&><'"こんばんは" - .repeat(1024) - .as_str(), - res + "abcd&><'"efgh".repeat(1024).as_str(), + lut_replace("abcd&><'\"efgh".repeat(1024).as_str()).unwrap() ); } #[test] - fn begin_16_large() { - let input = "&><'\"こんばんは" - .repeat(1024) - .encode_utf16() - .collect::>(); - assert!(!no_change::<8, u16>(input.as_slice())); - let utf8 = String::from_utf16(input.as_slice()).unwrap(); - let res = build_replaced(check_utf8(utf8.as_bytes()).unwrap()); + fn begin_large() { assert_eq!( - "&><'"こんばんは".repeat(1024).as_str(), - res + "&><'"efgh".repeat(1024).as_str(), + lut_replace("&><'\"efgh".repeat(1024).as_str()).unwrap() ); } #[test] - fn end_16_large() { - let input = "こんにちは&><'\"" - .repeat(1024) - .encode_utf16() - .collect::>(); - assert!(!no_change::<8, u16>(input.as_slice())); - let utf8 = String::from_utf16(input.as_slice()).unwrap(); - let res = build_replaced(check_utf8(utf8.as_bytes()).unwrap()); + fn end_large() { assert_eq!( - "こんにちは&><'"".repeat(1024).as_str(), - res + "abcd&><'"".repeat(1024).as_str(), + lut_replace("abcd&><'\"".repeat(1024).as_str()).unwrap() ); } } From ca768ad3d62345a37cd10964b72c80c99f1cd213 Mon Sep 17 00:00:00 2001 From: carsonburr Date: Thu, 29 Aug 2024 20:20:16 -0700 Subject: [PATCH 13/15] rust speedups: instruction-level-parallelsim for short circuit check --- src/rust/src/lib.rs | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/src/rust/src/lib.rs b/src/rust/src/lib.rs index 510c2c9e..a3fcecb4 100644 --- a/src/rust/src/lib.rs +++ b/src/rust/src/lib.rs @@ -26,9 +26,22 @@ static NEEDS_SANITIZE: [bool; 256] = { }; pub fn needs_sanitize(bytes: &[u8]) -> Option { - for (i, &b) in bytes.iter().enumerate() { + let chunks = bytes.chunks_exact(4); + let rest = chunks.remainder(); + + for (i, chunk) in chunks.enumerate() { + let a = NEEDS_SANITIZE[chunk[0] as usize]; + let b = NEEDS_SANITIZE[chunk[1] as usize]; + let c = NEEDS_SANITIZE[chunk[2] as usize]; + let d = NEEDS_SANITIZE[chunk[3] as usize]; + if a | b | c | d { + return Some(i * 4); + } + } + + for (i, &b) in rest.iter().enumerate() { if NEEDS_SANITIZE[b as usize] { - return Some(i); + return Some(((bytes.len() / 4) * 4) + i); } } @@ -86,12 +99,9 @@ pub fn lut_replace(input: &str) -> Option { } #[pyfunction] -pub fn _escape_inner<'py>( - py: Python<'py>, - s: Bound<'py, PyString>, -) -> PyResult> { +pub fn _escape_inner<'py>(py: Python<'py>, s: &'py PyString) -> PyResult<&'py PyString> { if let Some(out) = lut_replace(s.to_str()?) { - Ok(PyString::new_bound(py, out.as_str())) + Ok(PyString::new(py, out.as_str())) } else { Ok(s) } @@ -99,7 +109,7 @@ pub fn _escape_inner<'py>( #[pymodule] #[pyo3(name = "_rust_speedups")] -fn speedups<'py>(_py: Python<'py>, m: &Bound<'py, PyModule>) -> PyResult<()> { +fn speedups<'py>(_py: Python<'py>, m: &'py PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(_escape_inner, m)?)?; Ok(()) } From b14f35dea87222338391d8a19b8eada7d5f36a05 Mon Sep 17 00:00:00 2001 From: carsonburr Date: Thu, 29 Aug 2024 20:33:01 -0700 Subject: [PATCH 14/15] rust speedups: bug fix --- src/rust/src/lib.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/rust/src/lib.rs b/src/rust/src/lib.rs index a3fcecb4..38a679a2 100644 --- a/src/rust/src/lib.rs +++ b/src/rust/src/lib.rs @@ -99,9 +99,12 @@ pub fn lut_replace(input: &str) -> Option { } #[pyfunction] -pub fn _escape_inner<'py>(py: Python<'py>, s: &'py PyString) -> PyResult<&'py PyString> { +pub fn _escape_inner<'py>( + py: Python<'py>, + s: Bound<'py, PyString>, +) -> PyResult> { if let Some(out) = lut_replace(s.to_str()?) { - Ok(PyString::new(py, out.as_str())) + Ok(PyString::new_bound(py, out.as_str())) } else { Ok(s) } @@ -109,7 +112,7 @@ pub fn _escape_inner<'py>(py: Python<'py>, s: &'py PyString) -> PyResult<&'py Py #[pymodule] #[pyo3(name = "_rust_speedups")] -fn speedups<'py>(_py: Python<'py>, m: &'py PyModule) -> PyResult<()> { +fn speedups<'py>(_py: Python<'py>, m: &Bound<'py, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(_escape_inner, m)?)?; Ok(()) } From e431070fbd14a86cda530af7493ffd599da097b5 Mon Sep 17 00:00:00 2001 From: carsonburr Date: Fri, 30 Aug 2024 17:44:09 -0700 Subject: [PATCH 15/15] rust speedups: changed table representation --- src/rust/src/lib.rs | 61 +++++++++++++-------------------------------- 1 file changed, 17 insertions(+), 44 deletions(-) diff --git a/src/rust/src/lib.rs b/src/rust/src/lib.rs index 38a679a2..1a75564d 100644 --- a/src/rust/src/lib.rs +++ b/src/rust/src/lib.rs @@ -2,27 +2,13 @@ use pyo3::prelude::*; use pyo3::{types::PyString, PyResult, Python}; static NEEDS_SANITIZE: [bool; 256] = { - const __: bool = false; - const XX: bool = true; - [ - // 1 2 3 4 5 6 7 8 9 A B C D E F - __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 0 - __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 1 - __, __, XX, __, __, __, XX, XX, __, __, __, __, __, __, __, __, // 2 - __, __, __, __, __, __, __, __, __, __, __, __, XX, __, XX, __, // 3 - __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 4 - __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 5 - __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 6 - __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 7 - __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 8 - __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 9 - __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // A - __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // B - __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // C - __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // D - __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // E - __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // F - ] + let mut needs_sanitize = [false; 256]; + needs_sanitize[b'"' as usize] = true; + needs_sanitize[b'&' as usize] = true; + needs_sanitize[b'\'' as usize] = true; + needs_sanitize[b'<' as usize] = true; + needs_sanitize[b'>' as usize] = true; + needs_sanitize }; pub fn needs_sanitize(bytes: &[u8]) -> Option { @@ -49,26 +35,13 @@ pub fn needs_sanitize(bytes: &[u8]) -> Option { } static SANITIZE_INDEX: [i8; 256] = { - const __: i8 = -1; - [ - // 1 2 3 4 5 6 7 8 9 A B C D E F - __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 0 - __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 1 - __, __, 00, __, __, __, 01, 02, __, __, __, __, __, __, __, __, // 2 - __, __, __, __, __, __, __, __, __, __, __, __, 03, __, 04, __, // 3 - __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 4 - __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 5 - __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 6 - __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 7 - __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 8 - __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 9 - __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // A - __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // B - __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // C - __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // D - __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // E - __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // F - ] + let mut sanitize_index = [-1; 256]; + sanitize_index[b'"' as usize] = 0; + sanitize_index[b'&' as usize] = 1; + sanitize_index[b'\'' as usize] = 2; + sanitize_index[b'<' as usize] = 3; + sanitize_index[b'>' as usize] = 4; + sanitize_index }; static SANITIZED_VALUE: [&str; 5] = [""", "&", "'", "<", ">"]; @@ -124,13 +97,13 @@ mod tests { #[test] fn empty() { let inp = ""; - assert_eq!(inp, lut_replace(inp).unwrap()); + assert!(lut_replace(inp).is_none()); } #[test] fn no_change_test() { let inp = "abcdefgh"; - assert_eq!(inp, lut_replace(inp).unwrap()); + assert!(lut_replace(inp).is_none()); } #[test] @@ -160,7 +133,7 @@ mod tests { #[test] fn no_change_large() { let inp = "abcdefgh".repeat(1024); - assert_eq!(inp, lut_replace(inp.as_str()).unwrap()); + assert!(lut_replace(inp.as_str()).is_none()); } #[test]