integer overflow in std.sort.sort due to incorrect lessThan implementation #8289

Closed
zbanks opened this issue Mar 18, 2021 · 6 comments · Fixed by #16183
Labels
standard library This issue involves writing Zig code for the standard library.

@zbanks
Contributor

zbanks commented Mar 18, 2021

I encountered an integer overflow in std.sort.sort on master (96ae451)

Example code (sort_test.zig)

const std = @import("std");

fn compareLeq(context: void, left: u8, right: u8) bool {
    return left <= right;       // This crashes when buffer.len >= 1024
    // return true;             // This also has the same behavior
}

test {
    var buffer = try std.testing.allocator.alloc(u8, 1024);
    defer std.testing.allocator.free(buffer);
    for (buffer) |*b| b.* = 0;

    // These work fine...
    std.sort.sort(u8, buffer[0..1023], {}, comptime std.sort.asc(u8));
    std.sort.sort(u8, buffer[0..1024], {}, comptime std.sort.asc(u8));
    std.sort.sort(u8, buffer[0..1023], {}, comptime std.sort.desc(u8));
    std.sort.sort(u8, buffer[0..1024], {}, comptime std.sort.desc(u8));
    std.sort.sort(u8, buffer[0..1023], {}, compareLeq);

    // ..but this crashes
    std.sort.sort(u8, buffer[0..1024], {}, compareLeq);
}

zig test results:

> zig version
0.8.0-dev.4206+96ae451

> uname -mo
x86_64 GNU/Linux

> zig test sort_test.zig
thread 25201 panic: integer overflow
/home/zbanks/zig/build/lib/zig/std/sort.zig:584:85: 0x21d64a in std.sort.sort (test)
                        mem.rotate(T, items[range.start..range.end], range.length() - count);
                                                                                    ^
/tmp/sort_test.zig:22:18: 0x207ad1 in test "" (test)
    std.sort.sort(usize, buffer[0..1024], {}, compareLeq);
                 ^
/home/zbanks/zig/build/lib/zig/std/special/test_runner.zig:69:28: 0x258ca8 in std.special.main (test)
        } else test_fn.func();
                           ^
/home/zbanks/zig/build/lib/zig/std/start.zig:345:37: 0x21fd34 in std.start.posixCallMainAndExit (test)
            const result = root.main() catch |err| {
                                    ^
/home/zbanks/zig/build/lib/zig/std/start.zig:163:5: 0x21fbd2 in std.start._start (test)
    @call(.{ .modifier = .never_inline }, posixCallMainAndExit, .{});
    ^
error: the following test command crashed:
zig-cache/o/434532b9c0442450ef292055642d5b54/test /home/zbanks/zig/build/zig

Locals, from gdb:

>6  0x000000000024bd0b in std.sort.sort (items=...) at /home/zbanks/zig/build/lib/zig/std/sort.zig:584
584	                        mem.rotate(T, items[range.start..range.end], range.length() - count);
(gdb) info local
range = {start = 24, end = 24}
length = 24
A = {start = 0, end = 512}
last = 23
pull_index = 0
pull = {{from = 23, to = 0, count = 24, range = {start = 0, end = 1024}}, {from = 0, to = 0, count = 0, range = {start = 0, end = 0}}}
buffer1 = {start = 0, end = 24}
block_size = 22
buffer_size = 24
B = {start = 512, end = 1024}
count = 1
buffer2 = {start = 0, end = 0}
find_separately = false
index = 23
find = 24
start = 0
cache = {0 <repeats 256 times>, 12297829382473034410 <repeats 256 times>}
iterator = {size = 1024, power_of_two = 1024, numerator = 0, decimal = 1024, denominator = 256, decimal_step = 512, numerator_step = 0}

I originally encountered this on a list with unique (I think) elements, but this is a more concise example.

@LemonBoy
Contributor

The problem is your predicate: the algorithm seems to work only when the comparison is a strict less-than (or greater-than).
You can open a ticket here to let the algorithm author know of this failure mode.
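
To illustrate the point about the predicate, here is a minimal sketch. The names `lessOrEqual` and `strictLess` are hypothetical and not part of the standard library, and the code only assumes `std.testing.expect`. A comparator passed as `lessThan` should never report both "a before b" and "b before a"; `<=` does exactly that for equal elements.

```zig
const std = @import("std");

// Non-strict comparator, like compareLeq in the report above.
fn lessOrEqual(_: void, left: u8, right: u8) bool {
    return left <= right;
}

// Strict comparator: the shape the sort expects.
fn strictLess(_: void, left: u8, right: u8) bool {
    return left < right;
}

test "non-strict comparator claims both orders for equal elements" {
    // For equal inputs, `<=` answers true no matter which side comes first.
    try std.testing.expect(lessOrEqual({}, 0, 0));
    // `<` answers false in both directions, which is consistent.
    try std.testing.expect(!strictLess({}, 0, 0));
}
```

Swapping `compareLeq` for a strict comparator like this one should be enough to avoid the problem in the original reproduction, since the built-in `std.sort.asc(u8)` and `std.sort.desc(u8)` calls there already work.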

@andrewrk andrewrk added bug Observed behavior contradicts documented or intended behavior standard library This issue involves writing Zig code for the standard library. labels May 19, 2021
@andrewrk andrewrk added this to the 0.9.0 milestone May 19, 2021
@andrewrk andrewrk modified the milestones: 0.9.0, 0.10.0 Nov 24, 2021
@andrewrk andrewrk modified the milestones: 0.10.0, 0.11.0 Apr 16, 2022
@andrewrk andrewrk removed the bug Observed behavior contradicts documented or intended behavior label Jun 8, 2022
@andrewrk
Member

andrewrk commented Jun 8, 2022

Would be nice to have some debug safety checks, perhaps calling lessThan(a, b) and lessThan(b, a) and asserting that the return values are not equal.

@andrewrk andrewrk changed the title Integer overflow in std.sort.sort integer overflow in std.sort.sort due to incorrect lessThan implementation Apr 8, 2023
@wooster0
Contributor

wooster0 commented Apr 17, 2023

When trying to implement the suggested solution, I face the problem of obtaining two unique values A and B of an arbitrary type T. It's easy for types that you can @bitCast, because you can just set @sizeOf(T) bytes and cast them to a value of that type (do this twice, setting all bytes to 0 the first time and all to 1 the second), but a lot of types (such as regular structs) cannot be cast.

My current solution would be to handle each type separately, which is very cumbersome and seems kind of fragile:

/// Applies a safety check to the user-provided lessThan function to make sure that
/// the comparison is strictly less-than or greater-than and does not also return true
/// for equal A and B, as the algorithm does not work with such a comparison.
fn checkLessThan(
    comptime T: type,
    context: anytype,
    comptime lessThan: fn (context: @TypeOf(context), lhs: T, rhs: T) bool,
) void {
    // We need two values A and B distinct from each other of type T. If A and B are equal, it means the comparison function checks for equality.

    const a = comptime getUnique(T, 0) orelse return; // TODO: do we @compileError if we can't obtain two unique values from this type?
    const b = comptime getUnique(T, 1) orelse return;
    assert(lessThan(context, a, b) != lessThan(context, b, a)); // Make sure lessThan only checks for less-than or greater-than and doesn't include an equality check!
}

fn getUnique(comptime T: type, comptime index: u1) ?T {
    switch (@typeInfo(T)) {
        .ComptimeInt, .Int, .ComptimeFloat, .Float => {
            return @as(T, index);
        },
        .Enum, .EnumLiteral => {
            return std.meta.intToEnum(T, index) catch null; // intToEnum returns an error union, not an optional
        },
        .Void => {
            return null;
        },
        .Bool => {
            return index != 0;
        },
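        .Optional => |optional_info| {
            return switch (index) {
                0 => @as(T, null),
                // Recurse on the payload type; recursing on T itself would never terminate.
                1 => @as(T, getUnique(optional_info.child, index) orelse return null),
            };
        },
        .Null => return null, // only one value exists, so two unique ones can't be obtained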
        .Struct => |struct_info| {
            if (@sizeOf(T) == 0) return null;
            var value: T = undefined;
            inline for (struct_info.fields) |field| {
                if (!field.is_comptime) {
                    @field(value, field.name) = getUnique(field.type, index) orelse return null;
                }
            }
            return value;
        },
        .Pointer => return @intToPtr(T, index + 1),
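        .Array => |array_info| {
            if (array_info.len == 0) return null;
            var arr: T = undefined;
            // Use the element type here, not the array type itself.
            arr[0] = getUnique(array_info.child, index) orelse return null;
            return arr;
        },
        .Vector => |vector_info| {
            // TODO: test this
            if (vector_info.len == 0) return null;
            var arr: T = undefined;
            arr[0] = getUnique(vector_info.child, index) orelse return null;
            return arr;
        },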
        .Union => {
            return null; // TODO
        },
        .ErrorUnion => return @as(u1!error{}, index),
        .ErrorSet => {
            return switch (index) {
                0 => error{a},
                1 => error{b},
            };
        },
        .Fn => {
            const a = struct {
                fn a() void {}
            }.a;
            const b = struct {
                fn b() void {}
            }.b;
            return switch (index) {
                0 => a,
                1 => b,
            };
        },
        .Type => {
            return switch (index) {
                0 => u0,
                1 => u1,
            };
        },
        .NoReturn => return null,
        .Undefined => return null,
        .Opaque => @compileError("TODO"),
        .Frame => @compileError("TODO"),
        .AnyFrame => @compileError("TODO"),
    }
}

Is this approach viable at all?
checkLessThan would be used as checkLessThan(T, context, lessThan) at the start of for example pub fn sort.

Alternatively, should we only support this safety check for types that can be bit-cast, or just for types where <, >, and == can be used?
If we do this, maybe we can have std.meta.bitCast (returns null if not bit-castable) to save others from writing this logic in the future.

So, the complexity of this safety check depends on whether we consider it possible for the user to, for example, check two function prototypes for equality in lessThan (meaning, do they have the same names, do they have the same parameters, and so on). It seems a bit weird to go to such lengths like I'm doing in my example.


As another alternative, maybe we can just add a note to the doc comments of the public functions about how lessThan has to be implemented.

@kprotty
Member

kprotty commented May 23, 2023

The check could just be done at each call to lessThan (pseudo code):

sort(lessThan):
  actualSort(lambda a, b:
    if debug: assert(lessThan(a, b) != lessThan(b, a))
    return lessThan(a, b)
  )

The invariant doesn't need to be amortized to the start, since the check only runs in debug builds and perf doesn't matter there. The invariant could also be broken for lessThan(a, c) rather than lessThan(a, b), so the check should happen on every call.

@alichraghi
Contributor

This doesn't work when the array contains duplicate items:

const std = @import("std");

test {
    var items: [4]u8 = .{ 1, 2, 2, 3 };
    std.sort.block(u8, &items, {}, std.sort.asc(u8));
}
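
A small sketch of why the `!=` form of the check trips on duplicates even with a correct comparator. The helper name `notEqualCheckHolds` is hypothetical and exists only for this illustration:

```zig
const std = @import("std");

fn strictLess(_: void, left: u8, right: u8) bool {
    return left < right;
}

// The check in its `!=` form: the two call orders must disagree.
fn notEqualCheckHolds(a: u8, b: u8) bool {
    return strictLess({}, a, b) != strictLess({}, b, a);
}

test "the != check rejects a valid comparator on equal items" {
    // Distinct items: exactly one order is true, so the check holds.
    try std.testing.expect(notEqualCheckHolds(1, 2));
    // Duplicates: both orders are false, so `false != false` fails.
    try std.testing.expect(!notEqualCheckHolds(2, 2));
}
```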

@andrewrk
Member

It only needs to check that if lessThan(a, b) is true, then lessThan(b, a) is false. It's allowed to return false for both of them.
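
The relaxed rule can be sketched as a per-call wrapper. This is only an illustration under stated assumptions, not necessarily how the eventual fix implemented it: `checkedLessThan` and `strictLess` are hypothetical names, and the sketch assumes `std.debug.runtime_safety` is available to gate the check to safe build modes.

```zig
const std = @import("std");

// Hypothetical wrapper: forwards to the user comparator and, in safe builds,
// asserts that lessThan(a, b) and lessThan(b, a) are never both true.
fn checkedLessThan(
    comptime T: type,
    context: anytype,
    comptime lessThan: fn (@TypeOf(context), T, T) bool,
    a: T,
    b: T,
) bool {
    const result = lessThan(context, a, b);
    if (std.debug.runtime_safety and result) {
        // Both calls returning false is fine (equal items); both true is not.
        std.debug.assert(!lessThan(context, b, a));
    }
    return result;
}

fn strictLess(_: void, left: u8, right: u8) bool {
    return left < right;
}

test "strict comparator passes the check, including on duplicates" {
    try std.testing.expect(checkedLessThan(u8, {}, strictLess, 1, 2));
    try std.testing.expect(!checkedLessThan(u8, {}, strictLess, 2, 1));
    try std.testing.expect(!checkedLessThan(u8, {}, strictLess, 2, 2));
}
```

A sort implementation could route every comparison through a wrapper like this, which matches the idea of checking on each call rather than once at the start. A comparator such as `left <= right` would hit the assertion as soon as two equal items are compared.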
