Skip to content

Commit

Permalink
force_array -> is_consecutive
Browse files Browse the repository at this point in the history
The actual ABI implication here is that in some cases the values
are required to be "consecutive", i.e. must either all be passed
in registers or all on stack (without padding).

Adjust the code to either use Uniform::new() or Uniform::consecutive()
depending on which behavior is needed.

Then, when lowering this in LLVM, skip the [1 x i128] to i128
simplification if is_consecutive is set. i128 is the only case
I'm aware of where this is problematic right now. If we find
other cases, we can extend this (either based on target information
or possibly just by not simplifying for is_consecutive entirely).
  • Loading branch information
nikic committed Apr 8, 2024
1 parent 009280c commit 1b7342b
Show file tree
Hide file tree
Showing 13 changed files with 53 additions and 57 deletions.
5 changes: 4 additions & 1 deletion compiler/rustc_codegen_llvm/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,10 @@ impl LlvmType for CastTarget {
// Simplify to a single unit or an array if there's no prefix.
// This produces the same layout, but using a simpler type.
if self.prefix.iter().all(|x| x.is_none()) {
if rest_count == 1 && !self.rest.force_array {
// We can't do this if is_consecutive is set and the unit would get
// split on the target. Currently, this is only relevant for i128
// registers.
if rest_count == 1 && (!self.rest.is_consecutive || self.rest.unit != Reg::i128()) {
return rest_ll_unit;
}

Expand Down
8 changes: 4 additions & 4 deletions compiler/rustc_target/src/abi/call/aarch64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ where
RegKind::Vector => size.bits() == 64 || size.bits() == 128,
};

valid_unit.then_some(Uniform { unit, total: size, force_array: false })
valid_unit.then_some(Uniform::consecutive(unit, size))
})
}

Expand Down Expand Up @@ -60,7 +60,7 @@ where
let size = ret.layout.size;
let bits = size.bits();
if bits <= 128 {
ret.cast_to(Uniform { unit: Reg::i64(), total: size, force_array: false });
ret.cast_to(Uniform::new(Reg::i64(), size));
return;
}
ret.make_indirect();
Expand Down Expand Up @@ -100,9 +100,9 @@ where
};
if size.bits() <= 128 {
if align.bits() == 128 {
arg.cast_to(Uniform { unit: Reg::i128(), total: size, force_array: false });
arg.cast_to(Uniform::new(Reg::i128(), size));
} else {
arg.cast_to(Uniform { unit: Reg::i64(), total: size, force_array: false });
arg.cast_to(Uniform::new(Reg::i64(), size));
}
return;
}
Expand Down
10 changes: 3 additions & 7 deletions compiler/rustc_target/src/abi/call/arm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ where
RegKind::Vector => size.bits() == 64 || size.bits() == 128,
};

valid_unit.then_some(Uniform { unit, total: size, force_array: false })
valid_unit.then_some(Uniform::consecutive(unit, size))
})
}

Expand Down Expand Up @@ -49,7 +49,7 @@ where
let size = ret.layout.size;
let bits = size.bits();
if bits <= 32 {
ret.cast_to(Uniform { unit: Reg::i32(), total: size, force_array: false });
ret.cast_to(Uniform::new(Reg::i32(), size));
return;
}
ret.make_indirect();
Expand Down Expand Up @@ -78,11 +78,7 @@ where

let align = arg.layout.align.abi.bytes();
let total = arg.layout.size;
arg.cast_to(Uniform {
unit: if align <= 4 { Reg::i32() } else { Reg::i64() },
total,
force_array: false,
});
arg.cast_to(Uniform::consecutive(if align <= 4 { Reg::i32() } else { Reg::i64() }, total));
}

pub fn compute_abi_info<'a, Ty, C>(cx: &C, fn_abi: &mut FnAbi<'a, Ty>)
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_target/src/abi/call/csky.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ fn classify_ret<Ty>(arg: &mut ArgAbi<'_, Ty>) {
if total.bits() > 64 {
arg.make_indirect();
} else if total.bits() > 32 {
arg.cast_to(Uniform { unit: Reg::i32(), total, force_array: false });
arg.cast_to(Uniform::new(Reg::i32(), total));
} else {
arg.cast_to(Reg::i32());
}
Expand All @@ -38,7 +38,7 @@ fn classify_arg<Ty>(arg: &mut ArgAbi<'_, Ty>) {
if arg.layout.is_aggregate() {
let total = arg.layout.size;
if total.bits() > 32 {
arg.cast_to(Uniform { unit: Reg::i32(), total, force_array: false });
arg.cast_to(Uniform::new(Reg::i32(), total));
} else {
arg.cast_to(Reg::i32());
}
Expand Down
15 changes: 5 additions & 10 deletions compiler/rustc_target/src/abi/call/loongarch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,7 @@ where
if total.bits() <= xlen {
arg.cast_to(xlen_reg);
} else {
arg.cast_to(Uniform {
unit: xlen_reg,
total: Size::from_bits(xlen * 2),
force_array: false,
});
arg.cast_to(Uniform::new(xlen_reg, Size::from_bits(xlen * 2)));
}
return false;
}
Expand Down Expand Up @@ -282,11 +278,10 @@ fn classify_arg<'a, Ty, C>(
if total.bits() > xlen {
let align_regs = align > xlen;
if is_loongarch_aggregate(arg) {
arg.cast_to(Uniform {
unit: if align_regs { double_xlen_reg } else { xlen_reg },
total: Size::from_bits(xlen * 2),
force_array: false,
});
arg.cast_to(Uniform::new(
if align_regs { double_xlen_reg } else { xlen_reg },
Size::from_bits(xlen * 2),
));
}
if align_regs && is_vararg {
*avail_gprs -= *avail_gprs % 2;
Expand Down
5 changes: 1 addition & 4 deletions compiler/rustc_target/src/abi/call/mips.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@ where

if arg.layout.is_aggregate() {
let pad_i32 = !offset.is_aligned(align);
arg.cast_to_and_pad_i32(
Uniform { unit: Reg::i32(), total: size, force_array: false },
pad_i32,
);
arg.cast_to_and_pad_i32(Uniform::new(Reg::i32(), size), pad_i32);
} else {
arg.extend_integer_width_to(32);
}
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_target/src/abi/call/mips64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ where
}

// Cast to a uniform int structure
ret.cast_to(Uniform { unit: Reg::i64(), total: size, force_array: false });
ret.cast_to(Uniform::new(Reg::i64(), size));
} else {
ret.make_indirect();
}
Expand Down Expand Up @@ -139,7 +139,7 @@ where
let rest_size = size - Size::from_bytes(8) * prefix_index as u64;
arg.cast_to(CastTarget {
prefix,
rest: Uniform { unit: Reg::i64(), total: rest_size, force_array: false },
rest: Uniform::new(Reg::i64(), rest_size),
attrs: ArgAttributes {
regular: ArgAttribute::default(),
arg_ext: ArgExtension::None,
Expand Down
20 changes: 17 additions & 3 deletions compiler/rustc_target/src/abi/call/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,20 +256,34 @@ pub struct Uniform {
/// this size will be rounded up to the nearest multiple of `unit.size`.
pub total: Size,

/// Force the use of an array, even if there is only a single element.
pub force_array: bool,
/// Indicate that the argument is consecutive, in the sense that either all values need to be
/// passed in register, or all on the stack. If they are passed on the stack, there should be
/// no additional padding between elements.
pub is_consecutive: bool,
}

impl From<Reg> for Uniform {
fn from(unit: Reg) -> Uniform {
Uniform { unit, total: unit.size, force_array: false }
Uniform { unit, total: unit.size, is_consecutive: false }
}
}

impl Uniform {
pub fn align<C: HasDataLayout>(&self, cx: &C) -> Align {
self.unit.align(cx)
}

/// Pass using one or more values of the given type, without requiring them to be consecutive.
/// That is, some values may be passed in register and some on the stack.
pub fn new(unit: Reg, total: Size) -> Self {
Uniform { unit, total, is_consecutive: false }
}

/// Pass using one or more consecutive values of the given type. Either all values will be
/// passed in registers, or all on the stack.
pub fn consecutive(unit: Reg, total: Size) -> Self {
Uniform { unit, total, is_consecutive: true }
}
}

/// Describes the type used for `PassMode::Cast`.
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_target/src/abi/call/nvptx64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ where
16 => Reg::i128(),
_ => unreachable!("Align is given as power of 2 no larger than 16 bytes"),
};
arg.cast_to(Uniform { unit, total: Size::from_bytes(2 * align_bytes), force_array: false });
arg.cast_to(Uniform::new(unit, Size::from_bytes(2 * align_bytes)));
} else {
// FIXME: find a better way to do this. See https://github.com/rust-lang/rust/issues/117271.
arg.make_direct_deprecated();
Expand Down
13 changes: 6 additions & 7 deletions compiler/rustc_target/src/abi/call/powerpc64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ where
RegKind::Vector => arg.layout.size.bits() == 128,
};

valid_unit.then_some(Uniform { unit, total: arg.layout.size, force_array: false })
valid_unit.then_some(Uniform::consecutive(unit, arg.layout.size))
})
}

Expand Down Expand Up @@ -81,7 +81,7 @@ where
Reg::i64()
};

ret.cast_to(Uniform { unit, total: size, force_array: false });
ret.cast_to(Uniform::new(unit, size));
return;
}

Expand Down Expand Up @@ -117,11 +117,10 @@ where
// of i64s or i128s, depending on the aggregate alignment. Always use an array for
// this, even if there is only a single element.
let reg = if arg.layout.align.abi.bytes() > 8 { Reg::i128() } else { Reg::i64() };
arg.cast_to(Uniform {
unit: reg,
total: size.align_to(Align::from_bytes(reg.size.bytes()).unwrap()),
force_array: true,
})
arg.cast_to(Uniform::consecutive(
reg,
size.align_to(Align::from_bytes(reg.size.bytes()).unwrap()),
))
};
}

Expand Down
15 changes: 5 additions & 10 deletions compiler/rustc_target/src/abi/call/riscv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,7 @@ where
if total.bits() <= xlen {
arg.cast_to(xlen_reg);
} else {
arg.cast_to(Uniform {
unit: xlen_reg,
total: Size::from_bits(xlen * 2),
force_array: false,
});
arg.cast_to(Uniform::new(xlen_reg, Size::from_bits(xlen * 2)));
}
return false;
}
Expand Down Expand Up @@ -288,11 +284,10 @@ fn classify_arg<'a, Ty, C>(
if total.bits() > xlen {
let align_regs = align > xlen;
if is_riscv_aggregate(arg) {
arg.cast_to(Uniform {
unit: if align_regs { double_xlen_reg } else { xlen_reg },
total: Size::from_bits(xlen * 2),
force_array: false,
});
arg.cast_to(Uniform::new(
if align_regs { double_xlen_reg } else { xlen_reg },
Size::from_bits(xlen * 2),
));
}
if align_regs && is_vararg {
*avail_gprs -= *avail_gprs % 2;
Expand Down
5 changes: 1 addition & 4 deletions compiler/rustc_target/src/abi/call/sparc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@ where

if arg.layout.is_aggregate() {
let pad_i32 = !offset.is_aligned(align);
arg.cast_to_and_pad_i32(
Uniform { unit: Reg::i32(), total: size, force_array: false },
pad_i32,
);
arg.cast_to_and_pad_i32(Uniform::new(Reg::i32(), size), pad_i32);
} else {
arg.extend_integer_width_to(32);
}
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_target/src/abi/call/sparc64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ where

arg.cast_to(CastTarget {
prefix: data.prefix,
rest: Uniform { unit: Reg::i64(), total: rest_size, force_array: false },
rest: Uniform::new(Reg::i64(), rest_size),
attrs: ArgAttributes {
regular: data.arg_attribute,
arg_ext: ArgExtension::None,
Expand All @@ -205,7 +205,7 @@ where
}
}

arg.cast_to(Uniform { unit: Reg::i64(), total, force_array: false });
arg.cast_to(Uniform::new(Reg::i64(), total));
}

pub fn compute_abi_info<'a, Ty, C>(cx: &C, fn_abi: &mut FnAbi<'a, Ty>)
Expand Down

0 comments on commit 1b7342b

Please sign in to comment.