Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 137 additions & 5 deletions src/arch/x86_64/CodeGen.zig
Original file line number Diff line number Diff line change
Expand Up @@ -85049,10 +85049,132 @@ fn airShlShrBinOp(self: *CodeGen, inst: Air.Inst.Index) !void {
}

/// Lowers the AIR `shl_sat` instruction (saturating shift left, `<<|`) for
/// integer operands.
///
/// Approach: shift `lhs` left by `rhs`, shift that result back right by `rhs`,
/// and compare the round-tripped value with the original `lhs`. When they
/// differ, the destination is overwritten with the saturation bound of
/// `lhs_ty`: the maximum for unsigned types, and for signed types the minimum
/// or maximum depending on the sign of `lhs`.
///
/// Non-integer operand types and unsupported lhs/rhs bit-width combinations are
/// reported through `self.fail` as TODOs.
fn airShlSat(self: *CodeGen, inst: Air.Inst.Index) !void {
    const zcu = self.pt.zcu;
    const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op;
    const lhs_ty = self.typeOf(bin_op.lhs);
    const rhs_ty = self.typeOf(bin_op.rhs);

    const result: MCValue = result: {
        switch (lhs_ty.zigTypeTag(zcu)) {
            .int => {
                const lhs_bits = lhs_ty.bitSize(zcu);
                const rhs_bits = rhs_ty.bitSize(zcu);
                // Bail out unless the shift-amount width is one of the supported
                // combinations for the given lhs width.
                if (!(lhs_bits <= 32 and rhs_bits <= 5) and !(lhs_bits > 32 and lhs_bits <= 64 and rhs_bits <= 6) and !(rhs_bits <= std.math.log2(lhs_bits))) {
                    return self.fail("TODO implement shl_sat for {} with lhs bits {}, rhs bits {}", .{ self.target.cpu.arch, lhs_bits, rhs_bits });
                }

                // clobbered by genShiftBinOp
                try self.spillRegisters(&.{.rcx});

                const lhs_mcv = try self.resolveInst(bin_op.lhs);
                // Temp over the original lhs, compared against the round-tripped
                // value further down.
                var lhs_temp1 = try self.tempInit(lhs_ty, lhs_mcv);
                const rhs_mcv = try self.resolveInst(bin_op.rhs);

                // Keep the lhs register locked while the shifts are generated.
                const lhs_lock = switch (lhs_mcv) {
                    .register => |reg| self.register_manager.lockRegAssumeUnused(reg),
                    else => null,
                };
                defer if (lhs_lock) |lock| self.register_manager.unlockReg(lock);

                // shift left
                const dst_mcv = try self.genShiftBinOp(.shl, null, lhs_mcv, rhs_mcv, lhs_ty, rhs_ty);
                // Truncate the shifted value to `lhs_ty`, wherever it lives.
                switch (dst_mcv) {
                    .register => |dst_reg| try self.truncateRegister(lhs_ty, dst_reg),
                    .register_pair => |dst_regs| try self.truncateRegister(lhs_ty, dst_regs[1]),
                    .load_frame => |frame_addr| {
                        const tmp_reg =
                            try self.register_manager.allocReg(null, abi.RegisterClass.gp);
                        const tmp_lock = self.register_manager.lockRegAssumeUnused(tmp_reg);
                        defer self.register_manager.unlockReg(tmp_lock);

                        const lhs_bits_u31: u31 = @intCast(lhs_bits);
                        const tmp_ty: Type = if (lhs_bits_u31 > 64) .usize else lhs_ty;
                        // Byte offset of the 8-byte limb that holds the highest bit
                        // of the value; only that limb is reloaded, truncated and
                        // written back.
                        const off = frame_addr.off + (lhs_bits_u31 - 1) / 64 * 8;
                        try self.genSetReg(
                            tmp_reg,
                            tmp_ty,
                            .{ .load_frame = .{ .index = frame_addr.index, .off = off } },
                            .{},
                        );
                        try self.truncateRegister(lhs_ty, tmp_reg);
                        try self.genSetMem(
                            .{ .frame = frame_addr.index },
                            off,
                            tmp_ty,
                            .{ .register = tmp_reg },
                            .{},
                        );
                    },
                    else => {},
                }
                // Keep the destination register locked while the remaining code is
                // generated.
                const dst_lock = switch (dst_mcv) {
                    .register => |reg| self.register_manager.lockRegAssumeUnused(reg),
                    else => null,
                };
                defer if (dst_lock) |lock| self.register_manager.unlockReg(lock);

                // shift right
                const tmp_mcv = try self.genShiftBinOp(.shr, null, dst_mcv, rhs_mcv, lhs_ty, rhs_ty);
                var tmp_temp = try self.tempInit(lhs_ty, tmp_mcv);

                // check if overflow happens
                const cc_temp = lhs_temp1.cmpInts(.neq, &tmp_temp, self) catch |err| switch (err) {
                    error.SelectFailed => unreachable,
                    else => |e| return e,
                };
                try lhs_temp1.die(self);
                try tmp_temp.die(self);
                // This reloc is resolved after the saturation code below.
                const overflow_reloc = try self.genCondBrMir(lhs_ty, cc_temp.tracking(self).short);
                try cc_temp.die(self);

                // if overflow,
                // for unsigned integers, the saturating result is just its max
                // for signed integers,
                // if lhs is positive, the result is its max
                // if lhs is negative, it is min
                switch (lhs_ty.intInfo(zcu).signedness) {
                    .unsigned => {
                        const bound_mcv = try self.genTypedValue(try lhs_ty.maxIntScalar(self.pt, lhs_ty));
                        try self.genCopy(lhs_ty, dst_mcv, bound_mcv, .{});
                    },
                    .signed => {
                        // check the sign of lhs
                        // TODO: optimize this.
                        // we only need the highest bit so shifting the highest part of lhs_mcv
                        // is enough to check the signedness. other parts can be skipped here.
                        var lhs_temp2 = try self.tempInit(lhs_ty, lhs_mcv);
                        var zero_temp = try self.tempInit(lhs_ty, try self.genTypedValue(try self.pt.intValue(lhs_ty, 0)));
                        const sign_cc_temp = lhs_temp2.cmpInts(.lt, &zero_temp, self) catch |err| switch (err) {
                            error.SelectFailed => unreachable,
                            else => |e| return e,
                        };
                        try lhs_temp2.die(self);
                        try zero_temp.die(self);
                        const sign_reloc_condbr = try self.genCondBrMir(lhs_ty, sign_cc_temp.tracking(self).short);
                        try sign_cc_temp.die(self);

                        // if it is negative
                        const min_mcv = try self.genTypedValue(try lhs_ty.minIntScalar(self.pt, lhs_ty));
                        try self.genCopy(lhs_ty, dst_mcv, min_mcv, .{});
                        // Jump over the "positive" code; the target is patched below.
                        const sign_reloc_br = try self.asmJmpReloc(undefined);
                        self.performReloc(sign_reloc_condbr);

                        // if it is positive
                        const max_mcv = try self.genTypedValue(try lhs_ty.maxIntScalar(self.pt, lhs_ty));
                        try self.genCopy(lhs_ty, dst_mcv, max_mcv, .{});
                        self.performReloc(sign_reloc_br);
                    },
                }

                self.performReloc(overflow_reloc);
                break :result dst_mcv;
            },
            else => {
                return self.fail("TODO implement shl_sat for {} op type {}", .{ self.target.cpu.arch, lhs_ty.zigTypeTag(zcu) });
            },
        }
    };
    return self.finishAir(inst, result, .{ bin_op.lhs, bin_op.rhs, .none });
}

fn airOptionalPayload(self: *CodeGen, inst: Air.Inst.Index) !void {
Expand Down Expand Up @@ -88437,7 +88559,7 @@ fn genShiftBinOpMir(
) !void {
const pt = self.pt;
const zcu = pt.zcu;
const abi_size: u32 = @intCast(lhs_ty.abiSize(zcu));
const abi_size: u31 = @intCast(lhs_ty.abiSize(zcu));
const shift_abi_size: u32 = @intCast(rhs_ty.abiSize(zcu));
try self.spillEflagsIfOccupied();

Expand Down Expand Up @@ -88621,7 +88743,17 @@ fn genShiftBinOpMir(
.immediate => {},
else => self.performReloc(skip),
}
}
} else try self.asmRegisterMemory(.{ ._, .mov }, temp_regs[2].to64(), .{
.base = .{ .frame = lhs_mcv.load_frame.index },
.mod = .{ .rm = .{
.size = .qword,
.disp = switch (tag[0]) {
._l => lhs_mcv.load_frame.off,
._r => lhs_mcv.load_frame.off + abi_size - 8,
else => unreachable,
},
} },
});
switch (rhs_mcv) {
.immediate => |shift_imm| try self.asmRegisterImmediate(
tag,
Expand Down
48 changes: 47 additions & 1 deletion test/behavior/bit_shifting.zig
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
const std = @import("std");
const expect = std.testing.expect;
const expectEqual = std.testing.expectEqual;
const builtin = @import("builtin");

fn ShardedTable(comptime Key: type, comptime mask_bit_count: comptime_int, comptime V: type) type {
Expand Down Expand Up @@ -111,7 +112,6 @@ test "comptime shift safety check" {
}

test "Saturating Shift Left where lhs is of a computed type" {
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
Expand Down Expand Up @@ -159,3 +159,49 @@ comptime {
_ = &image;
_ = @shlExact(@as(u16, image[0]), 8);
}

test "Saturating Shift Left" {
    // Skip on backends (and x86_64 object formats) where this test is not enabled.
    switch (builtin.zig_backend) {
        .stage2_c,
        .stage2_sparc64,
        .stage2_spirv64,
        .stage2_riscv64,
        .stage2_wasm,
        => return error.SkipZigTest,
        .stage2_x86_64 => if (builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest,
        else => {},
    }

    const Harness = struct {
        /// Performs `value <<| amount` through a function call so the operands
        /// are not comptime-known at the point of the shift.
        fn shlSat(value: anytype, amount: std.math.Log2Int(@TypeOf(value))) @TypeOf(value) {
            // workaround https://github.com/ziglang/zig/issues/23033
            @setRuntimeSafety(false);
            return value <<| amount;
        }

        /// Checks every (value, shift amount) pair of `Int`, comparing the
        /// comptime-evaluated `<<|` against the result returned by `shlSat`.
        fn testType(comptime Int: type) !void {
            const last_amount = @bitSizeOf(Int) - 1;
            const last_value = std.math.maxInt(Int);

            comptime var amount: std.math.Log2Int(Int) = 0;
            inline while (true) : (amount += 1) {
                comptime var value: Int = std.math.minInt(Int);
                inline while (true) : (value += 1) {
                    try expectEqual(value <<| amount, shlSat(value, amount));
                    // Break before the continue expression would step past maxInt.
                    if (value == last_value) break;
                }
                if (amount == last_amount) break;
            }
        }
    };

    // Exhaustive coverage for small integer types.
    inline for (.{ u2, i2, u3, i3, u4, i4 }) |Int| {
        try Harness.testType(Int);
    }

    // 128-bit operands: one shift that still fits, then saturation to max and to min.
    try expectEqual(0xfffffffffffffff0fffffffffffffff0, Harness.shlSat(@as(u128, 0x0fffffffffffffff0fffffffffffffff), 4));
    try expectEqual(0xffffffffffffffffffffffffffffffff, Harness.shlSat(@as(u128, 0x0fffffffffffffff0fffffffffffffff), 5));
    try expectEqual(-0x80000000000000000000000000000000, Harness.shlSat(@as(i128, -0x0fffffffffffffff0fffffffffffffff), 5));

    // TODO
    // try expectEqual(51146728248377216718956089012931236753385031969422887335676427626502090568823039920051095192592252455482604439493126109519019633529459266458258243583, S.shlSat(@as(i495, 0x2fe6bc5448c55ce18252e2c9d44777505dfe63ff249a8027a6626c7d8dd9893fd5731e51474727be556f757facb586a4e04bbc0148c6c7ad692302f46fbd), 0x31));
    try expectEqual(-57896044618658097711785492504343953926634992332820282019728792003956564819968, Harness.shlSat(@as(i256, -0x53d4148cee74ea43477a65b3daa7b8fdadcbf4508e793f4af113b8d8da5a7eb6), 0x91));
    try expectEqual(170141183460469231731687303715884105727, Harness.shlSat(@as(i128, 0x2fe6bc5448c55ce18252e2c9d4477750), 0x31));
    try expectEqual(0, Harness.shlSat(@as(i128, 0), 127));
}