Optimise stringToEnum #3863

Open
daurnimator opened this issue Dec 8, 2019 · 16 comments
Labels
optimization standard library This issue involves writing Zig code for the standard library.

Comments

@daurnimator
Collaborator

The current implementation of stringToEnum is:

pub fn stringToEnum(comptime T: type, str: []const u8) ?T {
    inline for (@typeInfo(T).Enum.fields) |enumField| {
        if (std.mem.eql(u8, str, enumField.name)) {
            return @field(T, enumField.name);
        }
    }
    return null;
}

This could be much more efficient with a perfect hash generated at compile time, which is one motivation for adding a perfect hashing algorithm to the standard library.

FWIW my current use case for this is converting known HTTP field names to an enum.
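For readers who have not used the function, here is a minimal sketch of that use case. The `Field` enum and its members are hypothetical, chosen only for illustration, and the snippet assumes a reasonably recent Zig (0.11 or later) rather than the 2019 syntax used elsewhere in this thread.

```zig
const std = @import("std");

// Hypothetical enum of known HTTP field names (illustration only).
// The @"..." syntax allows names that are not plain identifiers.
const Field = enum { host, accept, @"content-type" };

test "stringToEnum on HTTP field names" {
    // A known name maps to its enum member.
    try std.testing.expect(std.meta.stringToEnum(Field, "host").? == .host);
    try std.testing.expect(std.meta.stringToEnum(Field, "content-type").? == .@"content-type");
    // An unknown name yields null.
    try std.testing.expect(std.meta.stringToEnum(Field, "x-unknown") == null);
}
```

With the implementation quoted above, each lookup compares the input against every field name in turn, which is the cost this issue is about.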

@daurnimator daurnimator added optimization standard library This issue involves writing Zig code for the standard library. labels Dec 8, 2019
@data-man
Contributor

data-man commented Dec 8, 2019

perfect hashing algorithm to be in the standard library

Based on https://andrewkelley.me/post/string-matching-comptime-perfect-hashing-zig.html

perfecthash.zig
const std = @import("std");
const assert = std.debug.assert;

pub fn perfectHash(comptime strs: []const []const u8) type {
    const Op = union(enum) {
        /// add the length of the string
        Length,

        /// add the byte at index % len
        Index: usize,

        /// xor the hash with the constant right-shifted by 16
        XorShiftMultiply: u32,
    };
    const S = struct {
        fn hash(comptime plan: []Op, s: []const u8) u32 {
            var h: u32 = 0;
            inline for (plan) |op| {
                switch (op) {
                    Op.Length => {
                        h +%= @truncate(u32, s.len);
                    },
                    Op.Index => |index| {
                        h +%= s[index % s.len];
                    },
                    Op.XorShiftMultiply => |x| {
                        h ^= x >> 16;
                    },
                }
            }
            return h;
        }

        fn testPlan(comptime plan: []Op) bool {
            var hit = [1]bool{false} ** strs.len;
            for (strs) |s| {
                const h = hash(plan, s);
                const i = h % hit.len;
                if (hit[i]) {
                    // hit this index twice
                    return false;
                }
                hit[i] = true;
            }
            return true;
        }
    };

    var ops_buf: [10]Op = undefined;

    const plan = have_a_plan: {
        var seed: u32 = 0x45d9f3b;
        var index_i: usize = 0;
        const try_seed_count = 50;
        const try_index_count = 50;

        while (index_i < try_index_count) : (index_i += 1) {
            const bool_values = if (index_i == 0) [_]bool{true} else [_]bool{ false, true };
            for (bool_values) |try_length| {
                var seed_i: usize = 0;
                while (seed_i < try_seed_count) : (seed_i += 1) {
                    comptime var rand_state = std.rand.Xoroshiro128.init(seed + seed_i);
                    const rng = &rand_state.random;

                    var ops_index = 0;

                    if (try_length) {
                        ops_buf[ops_index] = Op.Length;
                        ops_index += 1;

                        if (S.testPlan(ops_buf[0..ops_index]))
                            break :have_a_plan ops_buf[0..ops_index];

                        ops_buf[ops_index] = Op{ .XorShiftMultiply = rng.scalar(u32) };
                        ops_index += 1;

                        if (S.testPlan(ops_buf[0..ops_index]))
                            break :have_a_plan ops_buf[0..ops_index];
                    }

                    ops_buf[ops_index] = Op{ .XorShiftMultiply = rng.scalar(u32) };
                    ops_index += 1;

                    if (S.testPlan(ops_buf[0..ops_index]))
                        break :have_a_plan ops_buf[0..ops_index];

                    const before_bytes_it_index = ops_index;

                    var byte_index = 0;
                    while (byte_index < index_i) : (byte_index += 1) {
                        ops_index = before_bytes_it_index;

                        ops_buf[ops_index] = Op{ .Index = rng.scalar(u32) % try_index_count };
                        ops_index += 1;

                        if (S.testPlan(ops_buf[0..ops_index]))
                            break :have_a_plan ops_buf[0..ops_index];

                        ops_buf[ops_index] = Op{ .XorShiftMultiply = rng.scalar(u32) };
                        ops_index += 1;

                        if (S.testPlan(ops_buf[0..ops_index]))
                            break :have_a_plan ops_buf[0..ops_index];
                    }
                }
            }
        }

        @compileError("unable to come up with perfect hash");
    };

    return struct {
        pub fn case(comptime s: []const u8) usize {
            inline for (strs) |str| {
                if (std.mem.eql(u8, str, s))
                    return hash(s);
            }
            @compileError("case value '" ++ s ++ "' not declared");
        }
        pub fn hash(s: []const u8) usize {
            const ok = for (strs) |str| {
                if (std.mem.eql(u8, str, s))
                    break true;
            } else false;
            if (ok) {
                return S.hash(plan, s) % strs.len;
            } else {
                return S.hash(plan, s);
            }
        }
    };
}
test_pf.zig
const std = @import("std");
const perfectHash = @import("perfecthash.zig").perfectHash;
const assert = std.debug.assert;

test "perfect hashing" {
    basedOnLength("ab");
}

fn basedOnLength(target: []const u8) void {
    const ph = perfectHash(&[_][]const u8{
        "a",
        "ab",
        "abc",
    });
    switch (ph.hash(target)) {
        ph.case("a") => @panic("wrong one a\n"),
        ph.case("ab") => {}, // test pass
        ph.case("abc") => @panic("wrong one abc\n"),
        else => std.debug.warn("not found\n"),
    }
}

test "perfect hashing 2" {
    @setEvalBranchQuota(100000);
    const target = "eno";
    const ph = perfectHash(&[_][]const u8{
        "one",
        "eno",
        "two",
        "three",
        "four",
        "five",
    });
    switch (ph.hash(target)) {
        ph.case("one") => std.debug.warn("handle the one case\n"),
        ph.case("eno") => std.debug.warn("handle the eno case\n"),
        ph.case("two") => std.debug.warn("handle the two case\n"),
        ph.case("three") => std.debug.warn("handle the three case\n"),
        ph.case("four") => std.debug.warn("handle the four case\n"),
        ph.case("five") => std.debug.warn("handle the five case\n"),
        else => std.debug.warn("not found\n"),
    }
}

test "perfect hashing 3" {
    @setEvalBranchQuota(100000);
    const target = "six";
    const ph = perfectHash(&[_][]const u8{
        "one",
        "eno",
        "two",
        "three",
        "four",
        "five",
    });
    switch (ph.hash(target)) {
        ph.case("one") => std.debug.warn("handle the one case\n"),
        ph.case("eno") => std.debug.warn("handle the eno case\n"),
        ph.case("two") => std.debug.warn("handle the two case\n"),
        ph.case("three") => std.debug.warn("handle the three case\n"),
        ph.case("four") => std.debug.warn("handle the four case\n"),
        ph.case("five") => std.debug.warn("handle the five case\n"),
        else => std.debug.warn("{} not found\n", target),
    }
}

@andrewrk andrewrk added this to the 0.7.0 milestone Dec 9, 2019
@frmdstryr
Contributor

I think it would be awesome to have this built into the language, e.g.:

switch (target) : (hashFn) { // Or some way to set which hash fn to use at comptime
    "one" => ...,
    "two" => ...,
    // etc...
}

@data-man
Contributor

Some useful links:

BBHash, go-bbhash and rust-boomphf are all based on the paper "Fast and scalable minimal perfect hashing for massive key sets".

rust-phf uses the CHD algorithm.

@N00byEdge
Sponsor Contributor

Or constructing a compile time trie...

@data-man
Contributor

Attempt №2

perfectHash for any type
const std = @import("std");
const mem = std.mem;
const warn = std.debug.warn;

pub fn perfectHash(comptime T: type, comptime cases: []const T) type {
    const Op = union(enum) {
        /// add the length of the string
        Length,

        /// add the byte at index % len
        Index: usize,

        /// xor the hash with the constant right-shifted by 16
        XorShiftMultiply: u32,
    };
    const S = struct {
        fn hash(comptime plan: []Op, s: []const u8) u32 {
            var h: u32 = 0;
            inline for (plan) |op| {
                switch (op) {
                    Op.Length => {
                        h +%= @truncate(u32, s.len);
                    },
                    Op.Index => |index| {
                        h +%= s[index % s.len];
                    },
                    Op.XorShiftMultiply => |x| {
                        h ^= x >> 16;
                    },
                }
            }
            return h;
        }

        fn testPlan(comptime plan: []Op) bool {
            comptime var hit = [1]bool{false} ** cases.len;
            for (cases) |c| {
                const b = mem.toBytes(c);
                const h = hash(plan, b[0..]);
                const i = h % hit.len;
                if (hit[i]) {
                    // hit this index twice
                    return false;
                }
                hit[i] = true;
            }
            return true;
        }
    };

    var ops_buf: [10]Op = undefined;

    const plan = have_a_plan: {
        var seed: u32 = 0x45d9f3b;
        var index_i: usize = 0;
        const try_seed_count = 50;
        const try_index_count = 50;

        @setEvalBranchQuota(50000);

        while (index_i < try_index_count) : (index_i += 1) {
            const bool_values = if (index_i == 0) [_]bool{true} else [_]bool{ false, true };
            for (bool_values) |try_length| {
                var seed_i: usize = 0;
                while (seed_i < try_seed_count) : (seed_i += 1) {
                    comptime var rand_state = std.rand.Xoroshiro128.init(seed + seed_i);
                    const rng = &rand_state.random;

                    var ops_index = 0;

                    if (try_length) {
                        ops_buf[ops_index] = Op.Length;
                        ops_index += 1;

                        if (S.testPlan(ops_buf[0..ops_index]))
                            break :have_a_plan ops_buf[0..ops_index];

                        ops_buf[ops_index] = Op{ .XorShiftMultiply = rng.scalar(u32) };
                        ops_index += 1;

                        if (S.testPlan(ops_buf[0..ops_index]))
                            break :have_a_plan ops_buf[0..ops_index];
                    }

                    ops_buf[ops_index] = Op{ .XorShiftMultiply = rng.scalar(u32) };
                    ops_index += 1;

                    if (S.testPlan(ops_buf[0..ops_index]))
                        break :have_a_plan ops_buf[0..ops_index];

                    const before_bytes_it_index = ops_index;

                    var byte_index = 0;
                    while (byte_index < index_i) : (byte_index += 1) {
                        ops_index = before_bytes_it_index;

                        ops_buf[ops_index] = Op{ .Index = rng.scalar(u32) % try_index_count };
                        ops_index += 1;

                        if (S.testPlan(ops_buf[0..ops_index]))
                            break :have_a_plan ops_buf[0..ops_index];

                        ops_buf[ops_index] = Op{ .XorShiftMultiply = rng.scalar(u32) };
                        ops_index += 1;

                        if (S.testPlan(ops_buf[0..ops_index]))
                            break :have_a_plan ops_buf[0..ops_index];
                    }
                }
            }
        }

        @compileError("unable to come up with perfect hash");
    };

    return struct {
        pub fn case(comptime c: T) usize {
            inline for (cases) |c2| {
                if (std.meta.eql(c, c2))
                    return hash(c);
            }
            @compileLog("case value ", c, " not declared!");
        }
        pub fn hash(c: T) usize {
            const ok = for (cases) |c2| {
                if (std.meta.eql(c, c2))
                    break true;
            } else false;
            const b = mem.toBytes(c);
            if (ok) {
                return S.hash(plan, b[0..]) % cases.len;
            } else {
                return S.hash(plan, b[0..]);
            }
        }
    };
}

test "perfect hashing v2" {
    const target1 = 3;
    const target2 = 30;
    const ph = perfectHash(u16, &[_]u16{
        1,
        2,
        3,
        4,
        5,
        6,
    });
    switch (ph.hash(target1)) {
        ph.case(1) => warn("handle the {} case\n", .{target1}),
        ph.case(2) => warn("handle the {} case\n", .{target1}),
        ph.case(3) => warn("handle the {} case\n", .{target1}),
        ph.case(4) => warn("handle the {} case\n", .{target1}),
        ph.case(5) => warn("handle the {} case\n", .{target1}),
        ph.case(6) => warn("handle the {} case\n", .{target1}),
        else => warn("case {} not found\n", .{target1}),
    }
    switch (ph.hash(target2)) {
        ph.case(1) => warn("handle the {} case\n", .{target2}),
        ph.case(2) => warn("handle the {} case\n", .{target2}),
        ph.case(3) => warn("handle the {} case\n", .{target2}),
        ph.case(4) => warn("handle the {} case\n", .{target2}),
        ph.case(5) => warn("handle the {} case\n", .{target2}),
        ph.case(6) => warn("handle the {} case\n", .{target2}),
        else => warn("case {} not found\n", .{target2}),
    }
}

My goals: autoPerfectHash and autoPerfectHashMap.

@daurnimator
Collaborator Author

const ok = for (cases) |c2| {
    if (std.meta.eql(c, c2))
        break true;
} else false;

This looks like the expensive comparison that perfect hashing is meant to avoid?

@data-man
Contributor

All questions to the creator. 😄

I hope this loop is executed at comptime only.

@squeek502
Collaborator

I hope this loop is executed at comptime only.

In Andrew's blog post, that loop is wrapped in if (std.debug.runtime_safety):

if (std.debug.runtime_safety) {
    const ok = for (strs) |str| {
        if (std.mem.eql(u8, str, s))
            break true;
    } else false;
    if (!ok) {
        std.debug.panic("attempt to perfect hash {} which was not declared", s);
    }
}

i.e. it's not included in ReleaseFast/ReleaseSmall.

@daurnimator
Collaborator Author

For stringToEnum you wouldn't want that loop in there either: return null when the element is not in the perfect hash.

@squeek502
Collaborator

squeek502 commented Dec 31, 2019

return null when the element is not in the perfect hash.

Is this feasible? How could the perfect hash know when something is not one of the original set of values?

As an example, if it turns out that "in" and "not_in" both hash to the same value with the chosen perfect hashing algorithm, wouldn't it just treat "not_in" the same as "in"? How could it know to return null for "not_in" instead?

@daurnimator
Collaborator Author

Is this feasible? How could the perfect hash know when something is not one of the original set of values?

Once you hash to a given element, you then verify that the input matches that member.
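A minimal sketch of that "hash, then verify" step, assuming a recent Zig (0.11 or later). The `Key` enum, the `table` layout and the `lookup` function are hypothetical, and the hash is deliberately trivial: the three names have distinct lengths, so `len % 3` happens to be a perfect hash for this particular set. A real implementation would search for a suitable hash at comptime instead.

```zig
const std = @import("std");

// Hypothetical key set (illustration only). Name lengths are 1, 2 and 3,
// so `len % 3` maps each name to its own slot.
const Key = enum { a, ab, abc };

// Slot i holds the key whose name length satisfies len % 3 == i.
const table = [3]Key{ .abc, .a, .ab };

fn lookup(str: []const u8) ?Key {
    // The hash only selects a single candidate...
    const candidate = table[str.len % table.len];
    // ...and one comparison against that candidate rejects non-members.
    if (std.mem.eql(u8, str, @tagName(candidate))) return candidate;
    return null;
}

test "hash then verify" {
    try std.testing.expect(lookup("ab").? == .ab);
    // "zz" hashes to the same slot as "ab" but fails the comparison.
    try std.testing.expect(lookup("zz") == null);
    try std.testing.expect(lookup("") == null);
}
```

The point is that a non-member such as "zz" may land in an occupied slot, but it costs exactly one `mem.eql` to find that out, instead of one comparison per field.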

@squeek502
Collaborator

squeek502 commented Dec 31, 2019

Ah, I see; sounds similar to what the Zig tokenizer does now. Hashes for each keyword are computed at compile time and the lookup function checks the hash first before checking mem.eql. Perfect hashing could remove the need for the loop in getKeyword, though.

@daurnimator
Collaborator Author

Yep! Seems like you found another place that would be able to make immediate use of the new machinery :)

@andrewrk
Member

One thing to remember is to perf test. @Hejsil did some experiments with this earlier and determined that, at least in release-fast mode, the optimizer, when given if-else chains, was able to outperform a perfect hash implementation.

@pixelherodev
Contributor

pixelherodev commented May 17, 2020

Dumb question: what does "perfect" hash mean? Got a quick answer, thanks haze :)

@squeek502
Collaborator

As discovered in Vexu/arocc#524 (see the linked comment there), stringToEnum could get considerably better codegen for large enums if the fields were sorted by name length before the inline for:

zig/lib/std/meta.zig

Lines 41 to 46 in a126afa

inline for (@typeInfo(T).Enum.fields) |enumField| {
    if (mem.eql(u8, str, enumField.name)) {
        return @field(T, enumField.name);
    }
}
return null;

Here's a benchmark that compares only different orderings of the enum fields (the enum has 3948 fields; the shortest field name is 3 characters and the longest is 43):

-------------------- unsorted ---------------------
            always found: 3718ns per lookup
not found (random bytes): 6638ns per lookup
 not found (1 char diff): 3819ns per lookup

----------- sorted by length (desc) ---------------
            always found: 1176ns per lookup
not found (random bytes): 68ns per lookup
 not found (1 char diff): 1173ns per lookup

----------- sorted by length (asc) ----------------
            always found: 1054ns per lookup
not found (random bytes): 67ns per lookup
 not found (1 char diff): 1053ns per lookup

-------- sorted lexicographically (asc) -----------
            always found: 2764ns per lookup
not found (random bytes): 4615ns per lookup
 not found (1 char diff): 2750ns per lookup

This would ultimately be a trade-off between compile time and runtime performance. I haven't tested to see how much of an impact on the compile time the comptime sorting of the fields would incur. We might end up hitting #4055, in which case this optimization might need to wait a bit.

Note: Sorting would also make it easy to add a fast path that checks whether str.len lies between the lengths of the shortest and longest enum field names.
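A rough sketch of such a length fast path, assuming a recent Zig (0.11 or later). The names `nameLenBounds`, `stringToEnumBounded` and `Method` are hypothetical, and the bounds are found here with a plain comptime scan rather than by sorting, so this shows only the early-out and not the sorted inline for.

```zig
const std = @import("std");

// Hypothetical helper: returns { shortest, longest } field-name length of T.
fn nameLenBounds(comptime T: type) [2]usize {
    var min: usize = std.math.maxInt(usize);
    var max: usize = 0;
    for (std.meta.fieldNames(T)) |name| {
        if (name.len < min) min = name.len;
        if (name.len > max) max = name.len;
    }
    return [2]usize{ min, max };
}

// Hypothetical wrapper: rejects inputs whose length cannot match any field
// before falling back to the regular lookup.
fn stringToEnumBounded(comptime T: type, str: []const u8) ?T {
    const bounds = comptime nameLenBounds(T); // evaluated at compile time
    if (str.len < bounds[0] or str.len > bounds[1]) return null;
    return std.meta.stringToEnum(T, str);
}

// Hypothetical enum for the test; name lengths range from 3 to 6.
const Method = enum { get, post, delete };

test "length fast path" {
    try std.testing.expect(stringToEnumBounded(Method, "post").? == .post);
    try std.testing.expect(stringToEnumBounded(Method, "x") == null); // too short
    try std.testing.expect(stringToEnumBounded(Method, "toolongname") == null); // too long
    try std.testing.expect(stringToEnumBounded(Method, "put") == null); // in range, no match
}
```

This only helps for inputs outside the length range, which matches the "not found (random bytes)" rows above more than the "1 char diff" rows; whether it pays off in practice would need measuring.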

7 participants