Skip to content

Commit 0cd940e

Browse files
committed
Implement simd_masked_load and simd_masked_store
1 parent bd64989 commit 0cd940e

File tree

1 file changed

+187
-2
lines changed

1 file changed

+187
-2
lines changed

src/intrinsic/simd.rs

Lines changed: 187 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ pub fn generic_simd_intrinsic<'a, 'gcc, 'tcx>(
5353
};
5454
}
5555

56+
// TODO(antoyo): refactor with the above require_simd macro that was changed in cg_llvm.
57+
macro_rules! require_simd2 {
58+
($ty: expr, $variant:ident) => {{
59+
require!($ty.is_simd(), InvalidMonomorphization::$variant { span, name, ty: $ty });
60+
$ty.simd_size_and_type(bx.tcx())
61+
}};
62+
}
63+
5664
if name == sym::simd_select_bitmask {
5765
require_simd!(
5866
args[1].layout.ty,
@@ -464,9 +472,8 @@ pub fn generic_simd_intrinsic<'a, 'gcc, 'tcx>(
464472
m_len == v_len,
465473
InvalidMonomorphization::MismatchedLengths { span, name, m_len, v_len }
466474
);
467-
// TODO: also support unsigned integers.
468475
match *m_elem_ty.kind() {
469-
ty::Int(_) => {}
476+
ty::Int(_) | ty::Uint(_) => {}
470477
_ => return_error!(InvalidMonomorphization::MaskWrongElementType {
471478
span,
472479
name,
@@ -1454,6 +1461,184 @@ pub fn generic_simd_intrinsic<'a, 'gcc, 'tcx>(
14541461
bitwise_red!(simd_reduce_all: BinaryOp::BitwiseAnd, true);
14551462
bitwise_red!(simd_reduce_any: BinaryOp::BitwiseOr, true);
14561463

1464+
#[cfg(feature = "master")]
1465+
if name == sym::simd_masked_load {
1466+
// simd_masked_load<_, _, _, const ALIGN: SimdAlign>(mask: <N x i{M}>, pointer: *_ T, values: <N x T>) -> <N x T>
1467+
// * N: number of elements in the input vectors
1468+
// * T: type of the element to load
1469+
// * M: any integer width is supported, will be truncated to i1
1470+
// Loads contiguous elements from memory behind `pointer`, but only for
1471+
// those lanes whose `mask` bit is enabled.
1472+
// The memory addresses corresponding to the “off” lanes are not accessed.
1473+
1474+
// TODO: handle the alignment.
1475+
1476+
// The element type of the "mask" argument must be a signed integer type of any width
1477+
let mask_ty = in_ty;
1478+
let mask_len = in_len;
1479+
1480+
// The second argument must be a pointer matching the element type
1481+
let pointer_ty = args[1].layout.ty;
1482+
1483+
// The last argument is a passthrough vector providing values for disabled lanes
1484+
let values_ty = args[2].layout.ty;
1485+
let (values_len, values_elem) = require_simd2!(values_ty, SimdThird);
1486+
1487+
require_simd2!(ret_ty, SimdReturn);
1488+
1489+
// Of the same length:
1490+
require!(
1491+
values_len == mask_len,
1492+
InvalidMonomorphization::ThirdArgumentLength {
1493+
span,
1494+
name,
1495+
in_len: mask_len,
1496+
in_ty: mask_ty,
1497+
arg_ty: values_ty,
1498+
out_len: values_len
1499+
}
1500+
);
1501+
1502+
// The return type must match the last argument type
1503+
require!(
1504+
ret_ty == values_ty,
1505+
InvalidMonomorphization::ExpectedReturnType { span, name, in_ty: values_ty, ret_ty }
1506+
);
1507+
1508+
require!(
1509+
matches!(
1510+
*pointer_ty.kind(),
1511+
ty::RawPtr(p_ty, _) if p_ty == values_elem && p_ty.kind() == values_elem.kind()
1512+
),
1513+
InvalidMonomorphization::ExpectedElementType {
1514+
span,
1515+
name,
1516+
expected_element: values_elem,
1517+
second_arg: pointer_ty,
1518+
in_elem: values_elem,
1519+
in_ty: values_ty,
1520+
mutability: ExpectedPointerMutability::Not,
1521+
}
1522+
);
1523+
1524+
let mask = args[0].immediate();
1525+
1526+
let pointer = args[1].immediate();
1527+
let default = args[2].immediate();
1528+
let default_type = default.get_type();
1529+
let vector_type = default_type.unqualified().dyncast_vector().expect("vector type");
1530+
let value_type = vector_type.get_element_type();
1531+
let new_pointer_type = value_type.make_pointer();
1532+
1533+
let pointer = bx.context.new_cast(None, pointer, new_pointer_type);
1534+
1535+
let mask_vector_type = mask.get_type().unqualified().dyncast_vector().expect("vector type");
1536+
let elem_type = mask_vector_type.get_element_type();
1537+
let zero = bx.context.new_rvalue_zero(elem_type);
1538+
let mut elements = vec![];
1539+
for i in 0..mask_len {
1540+
let i = bx.context.new_rvalue_from_int(bx.int_type, i as i32);
1541+
let mask = bx.context.new_vector_access(None, mask, i).to_rvalue();
1542+
let mask = bx.context.new_comparison(None, ComparisonOp::NotEquals, mask, zero);
1543+
let then_val = bx.context.new_array_access(None, pointer, i).to_rvalue();
1544+
let else_val = bx.context.new_vector_access(None, default, i).to_rvalue();
1545+
let element = bx.select(mask, then_val, else_val);
1546+
elements.push(element);
1547+
}
1548+
let result = bx.context.new_rvalue_from_vector(None, default_type, &elements);
1549+
return Ok(result);
1550+
}
1551+
1552+
#[cfg(feature = "master")]
1553+
if name == sym::simd_masked_store {
1554+
// simd_masked_store<_, _, _, const ALIGN: SimdAlign>(mask: <N x i{M}>, pointer: *mut T, values: <N x T>) -> ()
1555+
// * N: number of elements in the input vectors
1556+
// * T: type of the element to load
1557+
// * M: any integer width is supported, will be truncated to i1
1558+
// Stores contiguous elements to memory behind `pointer`, but only for
1559+
// those lanes whose `mask` bit is enabled.
1560+
// The memory addresses corresponding to the “off” lanes are not accessed.
1561+
1562+
// TODO: handle the alignment.
1563+
1564+
// The element type of the "mask" argument must be a signed integer type of any width
1565+
let mask_ty = in_ty;
1566+
let mask_len = in_len;
1567+
1568+
// The second argument must be a pointer matching the element type
1569+
let pointer_ty = args[1].layout.ty;
1570+
1571+
// The last argument specifies the values to store to memory
1572+
let values_ty = args[2].layout.ty;
1573+
let (values_len, values_elem) = require_simd2!(values_ty, SimdThird);
1574+
1575+
// Of the same length:
1576+
require!(
1577+
values_len == mask_len,
1578+
InvalidMonomorphization::ThirdArgumentLength {
1579+
span,
1580+
name,
1581+
in_len: mask_len,
1582+
in_ty: mask_ty,
1583+
arg_ty: values_ty,
1584+
out_len: values_len
1585+
}
1586+
);
1587+
1588+
// The second argument must be a mutable pointer type matching the element type
1589+
require!(
1590+
matches!(
1591+
*pointer_ty.kind(),
1592+
ty::RawPtr(p_ty, p_mutbl)
1593+
if p_ty == values_elem && p_ty.kind() == values_elem.kind() && p_mutbl.is_mut()
1594+
),
1595+
InvalidMonomorphization::ExpectedElementType {
1596+
span,
1597+
name,
1598+
expected_element: values_elem,
1599+
second_arg: pointer_ty,
1600+
in_elem: values_elem,
1601+
in_ty: values_ty,
1602+
mutability: ExpectedPointerMutability::Mut,
1603+
}
1604+
);
1605+
1606+
let mask = args[0].immediate();
1607+
let pointer = args[1].immediate();
1608+
let values = args[2].immediate();
1609+
let values_type = values.get_type();
1610+
let vector_type = values_type.unqualified().dyncast_vector().expect("vector type");
1611+
let value_type = vector_type.get_element_type();
1612+
let new_pointer_type = value_type.make_pointer();
1613+
1614+
let pointer = bx.context.new_cast(None, pointer, new_pointer_type);
1615+
1616+
let vector_type = mask.get_type().unqualified().dyncast_vector().expect("vector type");
1617+
let elem_type = vector_type.get_element_type();
1618+
let zero = bx.context.new_rvalue_zero(elem_type);
1619+
for i in 0..mask_len {
1620+
let i = bx.context.new_rvalue_from_int(bx.int_type, i as i32);
1621+
let mask = bx.context.new_vector_access(None, mask, i).to_rvalue();
1622+
let mask = bx.context.new_comparison(None, ComparisonOp::NotEquals, mask, zero);
1623+
1624+
let after_block = bx.current_func().new_block("after");
1625+
let then_block = bx.current_func().new_block("then");
1626+
bx.llbb().end_with_conditional(None, mask, then_block, after_block);
1627+
1628+
bx.switch_to_block(then_block);
1629+
let lvalue = bx.context.new_array_access(None, pointer, i);
1630+
let value = bx.context.new_vector_access(None, values, i).to_rvalue();
1631+
bx.llbb().add_assignment(None, lvalue, value);
1632+
bx.llbb().end_with_jump(None, after_block);
1633+
1634+
bx.switch_to_block(after_block);
1635+
}
1636+
1637+
let dummy_value = bx.context.new_rvalue_zero(bx.int_type);
1638+
1639+
return Ok(dummy_value);
1640+
}
1641+
14571642
unimplemented!("simd {}", name);
14581643
}
14591644

0 commit comments

Comments
 (0)