diff --git a/src/arch/x86_64/CodeGen.zig b/src/arch/x86_64/CodeGen.zig index 7b734994f43c..d426ee6dcb82 100644 --- a/src/arch/x86_64/CodeGen.zig +++ b/src/arch/x86_64/CodeGen.zig @@ -85049,10 +85049,132 @@ fn airShlShrBinOp(self: *CodeGen, inst: Air.Inst.Index) !void { } fn airShlSat(self: *CodeGen, inst: Air.Inst.Index) !void { + const zcu = self.pt.zcu; const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op; - _ = bin_op; - return self.fail("TODO implement shl_sat for {}", .{self.target.cpu.arch}); - //return self.finishAir(inst, result, .{ bin_op.lhs, bin_op.rhs, .none }); + const lhs_ty = self.typeOf(bin_op.lhs); + const rhs_ty = self.typeOf(bin_op.rhs); + + const result: MCValue = result: { + switch (lhs_ty.zigTypeTag(zcu)) { + .int => { + const lhs_bits = lhs_ty.bitSize(zcu); + const rhs_bits = rhs_ty.bitSize(zcu); + if (!(lhs_bits <= 32 and rhs_bits <= 5) and !(lhs_bits > 32 and lhs_bits <= 64 and rhs_bits <= 6) and !(rhs_bits <= std.math.log2(lhs_bits))) { + return self.fail("TODO implement shl_sat for {} with lhs bits {}, rhs bits {}", .{ self.target.cpu.arch, lhs_bits, rhs_bits }); + } + + // clobberred by genShiftBinOp + try self.spillRegisters(&.{.rcx}); + + const lhs_mcv = try self.resolveInst(bin_op.lhs); + var lhs_temp1 = try self.tempInit(lhs_ty, lhs_mcv); + const rhs_mcv = try self.resolveInst(bin_op.rhs); + + const lhs_lock = switch (lhs_mcv) { + .register => |reg| self.register_manager.lockRegAssumeUnused(reg), + else => null, + }; + defer if (lhs_lock) |lock| self.register_manager.unlockReg(lock); + + // shift left + const dst_mcv = try self.genShiftBinOp(.shl, null, lhs_mcv, rhs_mcv, lhs_ty, rhs_ty); + switch (dst_mcv) { + .register => |dst_reg| try self.truncateRegister(lhs_ty, dst_reg), + .register_pair => |dst_regs| try self.truncateRegister(lhs_ty, dst_regs[1]), + .load_frame => |frame_addr| { + const tmp_reg = + try self.register_manager.allocReg(null, abi.RegisterClass.gp); + const tmp_lock = self.register_manager.lockRegAssumeUnused(tmp_reg); + defer self.register_manager.unlockReg(tmp_lock); + + const lhs_bits_u31: u31 = @intCast(lhs_bits); + const tmp_ty: Type = if (lhs_bits_u31 > 64) .usize else lhs_ty; + const off = frame_addr.off + (lhs_bits_u31 - 1) / 64 * 8; + try self.genSetReg( + tmp_reg, + tmp_ty, + .{ .load_frame = .{ .index = frame_addr.index, .off = off } }, + .{}, + ); + try self.truncateRegister(lhs_ty, tmp_reg); + try self.genSetMem( + .{ .frame = frame_addr.index }, + off, + tmp_ty, + .{ .register = tmp_reg }, + .{}, + ); + }, + else => {}, + } + const dst_lock = switch (dst_mcv) { + .register => |reg| self.register_manager.lockRegAssumeUnused(reg), + else => null, + }; + defer if (dst_lock) |lock| self.register_manager.unlockReg(lock); + + // shift right + const tmp_mcv = try self.genShiftBinOp(.shr, null, dst_mcv, rhs_mcv, lhs_ty, rhs_ty); + var tmp_temp = try self.tempInit(lhs_ty, tmp_mcv); + + // check if overflow happens + const cc_temp = lhs_temp1.cmpInts(.neq, &tmp_temp, self) catch |err| switch (err) { + error.SelectFailed => unreachable, + else => |e| return e, + }; + try lhs_temp1.die(self); + try tmp_temp.die(self); + const overflow_reloc = try self.genCondBrMir(lhs_ty, cc_temp.tracking(self).short); + try cc_temp.die(self); + + // if overflow, + // for unsigned integers, the saturating result is just its max + // for signed integers, + // if lhs is positive, the result is its max + // if lhs is negative, it is min + switch (lhs_ty.intInfo(zcu).signedness) { + .unsigned => { + const bound_mcv = try self.genTypedValue(try lhs_ty.maxIntScalar(self.pt, lhs_ty)); + try self.genCopy(lhs_ty, dst_mcv, bound_mcv, .{}); + }, + .signed => { + // check the sign of lhs + // TODO: optimize this. + // we only need the highest bit so shifting the highest part of lhs_mcv + // is enough to check the signedness. other parts can be skipped here. + var lhs_temp2 = try self.tempInit(lhs_ty, lhs_mcv); + var zero_temp = try self.tempInit(lhs_ty, try self.genTypedValue(try self.pt.intValue(lhs_ty, 0))); + const sign_cc_temp = lhs_temp2.cmpInts(.lt, &zero_temp, self) catch |err| switch (err) { + error.SelectFailed => unreachable, + else => |e| return e, + }; + try lhs_temp2.die(self); + try zero_temp.die(self); + const sign_reloc_condbr = try self.genCondBrMir(lhs_ty, sign_cc_temp.tracking(self).short); + try sign_cc_temp.die(self); + + // if it is negative + const min_mcv = try self.genTypedValue(try lhs_ty.minIntScalar(self.pt, lhs_ty)); + try self.genCopy(lhs_ty, dst_mcv, min_mcv, .{}); + const sign_reloc_br = try self.asmJmpReloc(undefined); + self.performReloc(sign_reloc_condbr); + + // if it is positive + const max_mcv = try self.genTypedValue(try lhs_ty.maxIntScalar(self.pt, lhs_ty)); + try self.genCopy(lhs_ty, dst_mcv, max_mcv, .{}); + self.performReloc(sign_reloc_br); + }, + } + + self.performReloc(overflow_reloc); + break :result dst_mcv; + }, + else => { + return self.fail("TODO implement shl_sat for {} op type {}", .{ self.target.cpu.arch, lhs_ty.zigTypeTag(zcu) }); + }, + } + }; + return self.finishAir(inst, result, .{ bin_op.lhs, bin_op.rhs, .none }); } fn airOptionalPayload(self: *CodeGen, inst: Air.Inst.Index) !void { @@ -88437,7 +88559,7 @@ fn genShiftBinOpMir( ) !void { const pt = self.pt; const zcu = pt.zcu; - const abi_size: u32 = @intCast(lhs_ty.abiSize(zcu)); + const abi_size: u31 = @intCast(lhs_ty.abiSize(zcu)); const shift_abi_size: u32 = @intCast(rhs_ty.abiSize(zcu)); try self.spillEflagsIfOccupied(); @@ -88621,7 +88743,17 @@ fn genShiftBinOpMir( .immediate => {}, else => self.performReloc(skip), } - } + } else try self.asmRegisterMemory(.{ ._, .mov }, temp_regs[2].to64(), .{ + .base = .{ .frame = lhs_mcv.load_frame.index }, + .mod = .{ .rm = .{ + .size = .qword, + .disp = switch (tag[0]) { + ._l => lhs_mcv.load_frame.off, + ._r => lhs_mcv.load_frame.off + abi_size - 8, + else => unreachable, + }, + } }, + }); switch (rhs_mcv) { .immediate => |shift_imm| try self.asmRegisterImmediate( tag, diff --git a/test/behavior/bit_shifting.zig b/test/behavior/bit_shifting.zig index da43fed7bb03..597f9c2182c5 100644 --- a/test/behavior/bit_shifting.zig +++ b/test/behavior/bit_shifting.zig @@ -1,5 +1,6 @@ const std = @import("std"); const expect = std.testing.expect; +const expectEqual = std.testing.expectEqual; const builtin = @import("builtin"); fn ShardedTable(comptime Key: type, comptime mask_bit_count: comptime_int, comptime V: type) type { @@ -111,7 +112,6 @@ test "comptime shift safety check" { } test "Saturating Shift Left where lhs is of a computed type" { - if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO @@ -159,3 +159,49 @@ comptime { _ = ℑ _ = @shlExact(@as(u16, image[0]), 8); } + +test "Saturating Shift Left" { + if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; + + const S = struct { + fn shlSat(x: anytype, y: std.math.Log2Int(@TypeOf(x))) @TypeOf(x) { + // workaround https://github.com/ziglang/zig/issues/23033 + @setRuntimeSafety(false); + return x <<| y; + } + + fn testType(comptime T: type) !void { + comptime var rhs: std.math.Log2Int(T) = 0; + inline while (true) : (rhs += 1) { + comptime var lhs: T = std.math.minInt(T); + inline while (true) : (lhs += 1) { + try expectEqual(lhs <<| rhs, shlSat(lhs, rhs)); + if (lhs == std.math.maxInt(T)) break; + } + if (rhs == @bitSizeOf(T) - 1) break; + } + } + }; + + try S.testType(u2); + try S.testType(i2); + try S.testType(u3); + try S.testType(i3); + try S.testType(u4); + try S.testType(i4); + + try expectEqual(0xfffffffffffffff0fffffffffffffff0, S.shlSat(@as(u128, 0x0fffffffffffffff0fffffffffffffff), 4)); + try expectEqual(0xffffffffffffffffffffffffffffffff, S.shlSat(@as(u128, 0x0fffffffffffffff0fffffffffffffff), 5)); + try expectEqual(-0x80000000000000000000000000000000, S.shlSat(@as(i128, -0x0fffffffffffffff0fffffffffffffff), 5)); + + // TODO + // try expectEqual(51146728248377216718956089012931236753385031969422887335676427626502090568823039920051095192592252455482604439493126109519019633529459266458258243583, S.shlSat(@as(i495, 0x2fe6bc5448c55ce18252e2c9d44777505dfe63ff249a8027a6626c7d8dd9893fd5731e51474727be556f757facb586a4e04bbc0148c6c7ad692302f46fbd), 0x31)); + try expectEqual(-57896044618658097711785492504343953926634992332820282019728792003956564819968, S.shlSat(@as(i256, -0x53d4148cee74ea43477a65b3daa7b8fdadcbf4508e793f4af113b8d8da5a7eb6), 0x91)); + try expectEqual(170141183460469231731687303715884105727, S.shlSat(@as(i128, 0x2fe6bc5448c55ce18252e2c9d4477750), 0x31)); + try expectEqual(0, S.shlSat(@as(i128, 0), 127)); +}