Skip to content

Commit

Permalink
saturating arithmetic builtins: add, sub, mul, shl (#9619)
Browse files Browse the repository at this point in the history
- adds 1 simple behavior tests for each
  which does integer and vector ops at
  runtime and comptime
- adds bigint_*_sat() methods for each

- use CreateIntrinsic() which accepts a
  variable number of arguments to pass
  the scale parameter

* update langref
- added case to test/compile_errors.zig given floats

- explain upstream bug in llvm.smul.fix.sat and link to #9643 in langref and commented out test cases

* sat-arithmetic: skip mul tests if arch == .wasm32 because ci is erroring with 'LLVM ERROR: Unable to expand fixed point multiplication' when compiling for wasm32
  • Loading branch information
travisstaloch committed Sep 1, 2021
1 parent 4f0aa7d commit 21a5769
Show file tree
Hide file tree
Showing 17 changed files with 613 additions and 3 deletions.
58 changes: 55 additions & 3 deletions doc/langref.html.in
Original file line number Diff line number Diff line change
Expand Up @@ -7031,6 +7031,16 @@ fn readFile(allocator: *Allocator, filename: []const u8) ![]u8 {
If no overflow or underflow occurs, returns {#syntax#}false{#endsyntax#}.
</p>
{#header_close#}
{#header_open|@addWithSaturation#}
<pre>{#syntax#}@addWithSaturation(a: T, b: T) T{#endsyntax#}</pre>
<p>
Returns {#syntax#}a + b{#endsyntax#}. The result will be clamped between the type maximum and minimum.
</p>
<p>
Once <a href="https://github.com/ziglang/zig/issues/1284">Saturating arithmetic</a>.
is completed, the syntax {#syntax#}a +| b{#endsyntax#} will be equivalent to calling {#syntax#}@addWithSaturation(a, b){#endsyntax#}.
</p>
{#header_close#}
{#header_open|@alignCast#}
<pre>{#syntax#}@alignCast(comptime alignment: u29, ptr: anytype) anytype{#endsyntax#}</pre>
<p>
Expand Down Expand Up @@ -8143,6 +8153,22 @@ test "@wasmMemoryGrow" {
If no overflow or underflow occurs, returns {#syntax#}false{#endsyntax#}.
</p>
{#header_close#}

{#header_open|@mulWithSaturation#}
<pre>{#syntax#}@mulWithSaturation(a: T, b: T) T{#endsyntax#}</pre>
<p>
Returns {#syntax#}a * b{#endsyntax#}. The result will be clamped between the type maximum and minimum.
</p>
<p>
Once <a href="https://github.com/ziglang/zig/issues/1284">Saturating arithmetic</a>.
is completed, the syntax {#syntax#}a *| b{#endsyntax#} will be equivalent to calling {#syntax#}@mulWithSaturation(a, b){#endsyntax#}.
</p>
<p>
NOTE: Currently there is a bug in the llvm.smul.fix.sat intrinsic which affects {#syntax#}@mulWithSaturation{#endsyntax#} of signed integers.
This may result in an incorrect sign bit when there is overflow. This will be fixed in zig's 0.9.0 release.
Check <a href="https://github.com/ziglang/zig/issues/9643">this issue</a> for more information.
</p>
{#header_close#}

{#header_open|@panic#}
<pre>{#syntax#}@panic(message: []const u8) noreturn{#endsyntax#}</pre>
Expand Down Expand Up @@ -8368,7 +8394,7 @@ test "@setRuntimeSafety" {
The type of {#syntax#}shift_amt{#endsyntax#} is an unsigned integer with {#syntax#}log2(T.bit_count){#endsyntax#} bits.
This is because {#syntax#}shift_amt >= T.bit_count{#endsyntax#} is undefined behavior.
</p>
{#see_also|@shrExact|@shlWithOverflow#}
{#see_also|@shrExact|@shlWithOverflow|@shlWithSaturation#}
{#header_close#}

{#header_open|@shlWithOverflow#}
Expand All @@ -8382,7 +8408,22 @@ test "@setRuntimeSafety" {
The type of {#syntax#}shift_amt{#endsyntax#} is an unsigned integer with {#syntax#}log2(T.bit_count){#endsyntax#} bits.
This is because {#syntax#}shift_amt >= T.bit_count{#endsyntax#} is undefined behavior.
</p>
{#see_also|@shlExact|@shrExact#}
{#see_also|@shlExact|@shrExact|@shlWithSaturation#}
{#header_close#}

{#header_open|@shlWithSaturation#}
<pre>{#syntax#}@shlWithSaturation(a: T, shift_amt: T) T{#endsyntax#}</pre>
<p>
Returns {#syntax#}a << b{#endsyntax#}. The result will be clamped between type minimum and maximum.
</p>
<p>
Once <a href="https://github.com/ziglang/zig/issues/1284">Saturating arithmetic</a>.
is completed, the syntax {#syntax#}a <<| b{#endsyntax#} will be equivalent to calling {#syntax#}@shlWithSaturation(a, b){#endsyntax#}.
</p>
<p>
Unlike other @shl builtins, shift_amt doesn't need to be a Log2T as saturated overshifting is well defined.
</p>
{#see_also|@shlExact|@shrExact|@shlWithOverflow#}
{#header_close#}

{#header_open|@shrExact#}
Expand All @@ -8395,7 +8436,7 @@ test "@setRuntimeSafety" {
The type of {#syntax#}shift_amt{#endsyntax#} is an unsigned integer with {#syntax#}log2(T.bit_count){#endsyntax#} bits.
This is because {#syntax#}shift_amt >= T.bit_count{#endsyntax#} is undefined behavior.
</p>
{#see_also|@shlExact|@shlWithOverflow#}
{#see_also|@shlExact|@shlWithOverflow|@shlWithSaturation#}
{#header_close#}

{#header_open|@shuffle#}
Expand Down Expand Up @@ -8694,6 +8735,17 @@ fn doTheTest() !void {
If no overflow or underflow occurs, returns {#syntax#}false{#endsyntax#}.
</p>
{#header_close#}

{#header_open|@subWithSaturation#}
<pre>{#syntax#}@subWithSaturation(a: T, b: T) T{#endsyntax#}</pre>
<p>
Returns {#syntax#}a - b{#endsyntax#}. The result will be clamped between the type maximum and minimum.
</p>
<p>
Once <a href="https://github.com/ziglang/zig/issues/1284">Saturating arithmetic</a>.
is completed, the syntax {#syntax#}a -| b{#endsyntax#} will be equivalent to calling {#syntax#}@subWithSaturation(a, b){#endsyntax#}.
</p>
{#header_close#}

{#header_open|@tagName#}
<pre>{#syntax#}@tagName(value: anytype) [:0]const u8{#endsyntax#}</pre>
Expand Down
23 changes: 23 additions & 0 deletions src/AstGen.zig
Original file line number Diff line number Diff line change
Expand Up @@ -7301,6 +7301,11 @@ fn builtinCall(
return rvalue(gz, rl, result, node);
},

.add_with_saturation => return saturatingArithmetic(gz, scope, rl, node, params, .add_with_saturation),
.sub_with_saturation => return saturatingArithmetic(gz, scope, rl, node, params, .sub_with_saturation),
.mul_with_saturation => return saturatingArithmetic(gz, scope, rl, node, params, .mul_with_saturation),
.shl_with_saturation => return saturatingArithmetic(gz, scope, rl, node, params, .shl_with_saturation),

.atomic_load => {
const int_type = try typeExpr(gz, scope, params[0]);
const ptr_type = try gz.add(.{ .tag = .ptr_type_simple, .data = .{
Expand Down Expand Up @@ -7693,6 +7698,24 @@ fn overflowArithmetic(
return rvalue(gz, rl, result, node);
}

fn saturatingArithmetic(
gz: *GenZir,
scope: *Scope,
rl: ResultLoc,
node: ast.Node.Index,
params: []const ast.Node.Index,
tag: Zir.Inst.Extended,
) InnerError!Zir.Inst.Ref {
const lhs = try expr(gz, scope, .none, params[0]);
const rhs = try expr(gz, scope, .none, params[1]);
const result = try gz.addExtendedPayload(tag, Zir.Inst.SaturatingArithmetic{
.node = gz.nodeIndexToRelative(node),
.lhs = lhs,
.rhs = rhs,
});
return rvalue(gz, rl, result, node);
}

fn callExpr(
gz: *GenZir,
scope: *Scope,
Expand Down
32 changes: 32 additions & 0 deletions src/BuiltinFn.zig
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ const std = @import("std");

pub const Tag = enum {
add_with_overflow,
add_with_saturation,
align_cast,
align_of,
as,
Expand Down Expand Up @@ -65,6 +66,7 @@ pub const Tag = enum {
wasm_memory_grow,
mod,
mul_with_overflow,
mul_with_saturation,
panic,
pop_count,
ptr_cast,
Expand All @@ -79,10 +81,12 @@ pub const Tag = enum {
set_runtime_safety,
shl_exact,
shl_with_overflow,
shl_with_saturation,
shr_exact,
shuffle,
size_of,
splat,
sub_with_saturation,
reduce,
src,
sqrt,
Expand Down Expand Up @@ -527,6 +531,34 @@ pub const list = list: {
.param_count = 2,
},
},
.{
"@addWithSaturation",
.{
.tag = .add_with_saturation,
.param_count = 2,
},
},
.{
"@subWithSaturation",
.{
.tag = .sub_with_saturation,
.param_count = 2,
},
},
.{
"@mulWithSaturation",
.{
.tag = .mul_with_saturation,
.param_count = 2,
},
},
.{
"@shlWithSaturation",
.{
.tag = .shl_with_saturation,
.param_count = 2,
},
},
.{
"@memcpy",
.{
Expand Down
17 changes: 17 additions & 0 deletions src/Sema.zig
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,10 @@ fn zirExtended(sema: *Sema, block: *Scope.Block, inst: Zir.Inst.Index) CompileEr
.c_define => return sema.zirCDefine( block, extended),
.wasm_memory_size => return sema.zirWasmMemorySize( block, extended),
.wasm_memory_grow => return sema.zirWasmMemoryGrow( block, extended),
.add_with_saturation=> return sema.zirSatArithmetic( block, extended),
.sub_with_saturation=> return sema.zirSatArithmetic( block, extended),
.mul_with_saturation=> return sema.zirSatArithmetic( block, extended),
.shl_with_saturation=> return sema.zirSatArithmetic( block, extended),
// zig fmt: on
}
}
Expand Down Expand Up @@ -5691,6 +5695,19 @@ fn zirOverflowArithmetic(
return sema.mod.fail(&block.base, src, "TODO implement Sema.zirOverflowArithmetic", .{});
}

fn zirSatArithmetic(
sema: *Sema,
block: *Scope.Block,
extended: Zir.Inst.Extended.InstData,
) CompileError!Air.Inst.Ref {
const tracy = trace(@src());
defer tracy.end();

const extra = sema.code.extraData(Zir.Inst.SaturatingArithmetic, extended.operand).data;
const src: LazySrcLoc = .{ .node_offset = extra.node };
return sema.mod.fail(&block.base, src, "TODO implement Sema.zirSatArithmetic", .{});
}

fn analyzeArithmetic(
sema: *Sema,
block: *Scope.Block,
Expand Down
39 changes: 39 additions & 0 deletions src/Zir.zig
Original file line number Diff line number Diff line change
Expand Up @@ -1629,6 +1629,22 @@ pub const Inst = struct {
wasm_memory_size,
/// `operand` is payload index to `BinNode`.
wasm_memory_grow,
/// Implements the `@addWithSaturation` builtin.
/// `operand` is payload index to `SaturatingArithmetic`.
/// `small` is unused.
add_with_saturation,
/// Implements the `@subWithSaturation` builtin.
/// `operand` is payload index to `SaturatingArithmetic`.
/// `small` is unused.
sub_with_saturation,
/// Implements the `@mulWithSaturation` builtin.
/// `operand` is payload index to `SaturatingArithmetic`.
/// `small` is unused.
mul_with_saturation,
/// Implements the `@shlWithSaturation` builtin.
/// `operand` is payload index to `SaturatingArithmetic`.
/// `small` is unused.
shl_with_saturation,

pub const InstData = struct {
opcode: Extended,
Expand Down Expand Up @@ -2751,6 +2767,12 @@ pub const Inst = struct {
ptr: Ref,
};

pub const SaturatingArithmetic = struct {
node: i32,
lhs: Ref,
rhs: Ref,
};

pub const Cmpxchg = struct {
ptr: Ref,
expected_value: Ref,
Expand Down Expand Up @@ -3231,6 +3253,11 @@ const Writer = struct {
.shl_with_overflow,
=> try self.writeOverflowArithmetic(stream, extended),

.add_with_saturation,
.sub_with_saturation,
.mul_with_saturation,
.shl_with_saturation,
=> try self.writeSaturatingArithmetic(stream, extended),
.struct_decl => try self.writeStructDecl(stream, extended),
.union_decl => try self.writeUnionDecl(stream, extended),
.enum_decl => try self.writeEnumDecl(stream, extended),
Expand Down Expand Up @@ -3584,6 +3611,18 @@ const Writer = struct {
try self.writeSrc(stream, src);
}

fn writeSaturatingArithmetic(self: *Writer, stream: anytype, extended: Inst.Extended.InstData) !void {
const extra = self.code.extraData(Zir.Inst.SaturatingArithmetic, extended.operand).data;
const src: LazySrcLoc = .{ .node_offset = extra.node };

try self.writeInstRef(stream, extra.lhs);
try stream.writeAll(", ");
try self.writeInstRef(stream, extra.rhs);
try stream.writeAll(", ");
try stream.writeAll(") ");
try self.writeSrc(stream, src);
}

fn writePlNodeCall(self: *Writer, stream: anytype, inst: Inst.Index) !void {
const inst_data = self.code.instructions.items(.data)[inst].pl_node;
const extra = self.code.extraData(Inst.Call, inst_data.payload_index);
Expand Down
8 changes: 8 additions & 0 deletions src/stage1/all_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1802,6 +1802,10 @@ enum BuiltinFnId {
BuiltinFnIdReduce,
BuiltinFnIdMaximum,
BuiltinFnIdMinimum,
BuiltinFnIdSatAdd,
BuiltinFnIdSatSub,
BuiltinFnIdSatMul,
BuiltinFnIdSatShl,
};

struct BuiltinFnEntry {
Expand Down Expand Up @@ -2946,6 +2950,10 @@ enum IrBinOp {
IrBinOpArrayMult,
IrBinOpMaximum,
IrBinOpMinimum,
IrBinOpSatAdd,
IrBinOpSatSub,
IrBinOpSatMul,
IrBinOpSatShl,
};

struct Stage1ZirInstBinOp {
Expand Down
60 changes: 60 additions & 0 deletions src/stage1/astgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4704,6 +4704,66 @@ static Stage1ZirInst *astgen_builtin_fn_call(Stage1AstGen *ag, Scope *scope, Ast
Stage1ZirInst *bin_op = ir_build_bin_op(ag, scope, node, IrBinOpMaximum, arg0_value, arg1_value, true);
return ir_lval_wrap(ag, scope, bin_op, lval, result_loc);
}
case BuiltinFnIdSatAdd:
{
AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
Stage1ZirInst *arg0_value = astgen_node(ag, arg0_node, scope);
if (arg0_value == ag->codegen->invalid_inst_src)
return arg0_value;

AstNode *arg1_node = node->data.fn_call_expr.params.at(1);
Stage1ZirInst *arg1_value = astgen_node(ag, arg1_node, scope);
if (arg1_value == ag->codegen->invalid_inst_src)
return arg1_value;

Stage1ZirInst *bin_op = ir_build_bin_op(ag, scope, node, IrBinOpSatAdd, arg0_value, arg1_value, true);
return ir_lval_wrap(ag, scope, bin_op, lval, result_loc);
}
case BuiltinFnIdSatSub:
{
AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
Stage1ZirInst *arg0_value = astgen_node(ag, arg0_node, scope);
if (arg0_value == ag->codegen->invalid_inst_src)
return arg0_value;

AstNode *arg1_node = node->data.fn_call_expr.params.at(1);
Stage1ZirInst *arg1_value = astgen_node(ag, arg1_node, scope);
if (arg1_value == ag->codegen->invalid_inst_src)
return arg1_value;

Stage1ZirInst *bin_op = ir_build_bin_op(ag, scope, node, IrBinOpSatSub, arg0_value, arg1_value, true);
return ir_lval_wrap(ag, scope, bin_op, lval, result_loc);
}
case BuiltinFnIdSatMul:
{
AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
Stage1ZirInst *arg0_value = astgen_node(ag, arg0_node, scope);
if (arg0_value == ag->codegen->invalid_inst_src)
return arg0_value;

AstNode *arg1_node = node->data.fn_call_expr.params.at(1);
Stage1ZirInst *arg1_value = astgen_node(ag, arg1_node, scope);
if (arg1_value == ag->codegen->invalid_inst_src)
return arg1_value;

Stage1ZirInst *bin_op = ir_build_bin_op(ag, scope, node, IrBinOpSatMul, arg0_value, arg1_value, true);
return ir_lval_wrap(ag, scope, bin_op, lval, result_loc);
}
case BuiltinFnIdSatShl:
{
AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
Stage1ZirInst *arg0_value = astgen_node(ag, arg0_node, scope);
if (arg0_value == ag->codegen->invalid_inst_src)
return arg0_value;

AstNode *arg1_node = node->data.fn_call_expr.params.at(1);
Stage1ZirInst *arg1_value = astgen_node(ag, arg1_node, scope);
if (arg1_value == ag->codegen->invalid_inst_src)
return arg1_value;

Stage1ZirInst *bin_op = ir_build_bin_op(ag, scope, node, IrBinOpSatShl, arg0_value, arg1_value, true);
return ir_lval_wrap(ag, scope, bin_op, lval, result_loc);
}
case BuiltinFnIdMemcpy:
{
AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
Expand Down

0 comments on commit 21a5769

Please sign in to comment.