Skip to content

Commit 48176b5

Browse files
authored
Merge pull request #272 from oxideai/release/0.25.1
release: Version 0.25.1
2 parents db46794 + 473ca87 commit 48176b5

File tree

4 files changed

+106
-26
lines changed

4 files changed

+106
-26
lines changed

Cargo.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[workspace.package]
22
# All but mlx-sys should follow the same version. mlx-sys should follow
33
# the version of mlx-c.
4-
version = "0.25.0"
4+
version = "0.25.1"
55
edition = "2021"
66
authors = [
77
"Minghua Wu <michael.wu1107@gmail.com>",
@@ -29,9 +29,9 @@ resolver = "2"
2929
[workspace.dependencies]
3030
# workspace local dependencies
3131
mlx-sys = { version = "=0.2.0", path = "mlx-sys" }
32-
mlx-macros = { version = "0.25.0", path = "mlx-macros" }
33-
mlx-internal-macros = { version = "0.25.0", path = "mlx-internal-macros" }
34-
mlx-rs = { version = "0.25.0", path = "mlx-rs" }
32+
mlx-macros = { version = "0.25", path = "mlx-macros" }
33+
mlx-internal-macros = { version = "0.25", path = "mlx-internal-macros" }
34+
mlx-rs = { version = "0.25.1", path = "mlx-rs" }
3535

3636
# external dependencies
3737
thiserror = "2"

mlx-rs/CHANGELOG.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
# Changelog
22

3-
## 0.25.0-alpha.1
3+
## 0.25.1
44

5-
- Update `mlx-c` to version "0.2.0-alpha" and changes function signatures to
5+
- Fix bug with `index_mut`
6+
7+
## 0.25.0
8+
9+
- Update `mlx-c` to version "0.2.0" and changes function signatures to
610
match the new API
711
- Update `thiserror` to version "2"
812
- Fix wrong states number in `compile_with_state`

mlx-rs/src/ops/arithmetic.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,22 @@ impl Array {
652652
mlx_sys::mlx_square(res, self.as_ptr(), stream.as_ref().as_ptr())
653653
})
654654
}
655+
656+
/// Element-wise real part from a complex array.
657+
#[default_device]
658+
pub fn real_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
659+
Array::try_from_op(|res| unsafe {
660+
mlx_sys::mlx_real(res, self.as_ptr(), stream.as_ref().as_ptr())
661+
})
662+
}
663+
664+
/// Element-wise imag part from a complex array.
665+
#[default_device]
666+
pub fn imag_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
667+
Array::try_from_op(|res| unsafe {
668+
mlx_sys::mlx_imag(res, self.as_ptr(), stream.as_ref().as_ptr())
669+
})
670+
}
655671
}
656672

657673
/// Element-wise absolute value.
@@ -1306,6 +1322,24 @@ pub fn tanh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>)
13061322
})
13071323
}
13081324

1325+
/// Element-wise real part from a complex array.
1326+
#[generate_macro]
1327+
#[default_device]
1328+
pub fn real_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1329+
Array::try_from_op(|res| unsafe {
1330+
mlx_sys::mlx_real(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1331+
})
1332+
}
1333+
1334+
/// Element-wise imaginary part from a complex array.
1335+
#[generate_macro]
1336+
#[default_device]
1337+
pub fn imag_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1338+
Array::try_from_op(|res| unsafe {
1339+
mlx_sys::mlx_imag(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1340+
})
1341+
}
1342+
13091343
/// Matrix multiplication with block masking.
13101344
///
13111345
/// See the [python API docs](
@@ -2500,6 +2534,13 @@ mod tests {
25002534
.item::<bool>());
25012535
}
25022536

2537+
#[test]
2538+
fn test_unary_real_imag() {
2539+
let x = Array::from_complex(complex64::new(0.0, 1.0));
2540+
assert_eq!(real(&x).unwrap(), Array::from_f32(0.0));
2541+
assert_eq!(imag(&x).unwrap(), Array::from_f32(1.0));
2542+
}
2543+
25032544
#[test]
25042545
fn test_binary_add() {
25052546
let x = array![1.0];

mlx-rs/src/ops/indexing/indexmut_impl.rs

Lines changed: 55 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use crate::{
88
ops::{
99
broadcast_arrays_device, broadcast_to_device,
1010
indexing::{count_non_new_axis_operations, expand_ellipsis_operations},
11+
reshape_device,
1112
},
1213
utils::{resolve_index_signed_unchecked, VectorArray},
1314
Array, Stream,
@@ -84,40 +85,59 @@ fn update_slice(
8485
let operations = expand_ellipsis_operations(ndim, operations);
8586

8687
// If no non-None indices return the broadcasted update
87-
if count_non_new_axis_operations(&operations) == 0 {
88+
let non_new_axis_operation_count = count_non_new_axis_operations(&operations);
89+
if non_new_axis_operation_count == 0 {
8890
return Ok(Some(broadcast_to_device(&update, src.shape(), &stream)?));
8991
}
9092

9193
// Process entries
92-
let mut update_expand_dims: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = SmallVec::new();
93-
let mut axis = 0i32;
94-
for item in operations.iter() {
94+
// let mut update_expand_dims: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = SmallVec::new();
95+
let mut update_reshape: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = smallvec![0; ndim];
96+
let mut axis = src.ndim() - 1;
97+
let mut update_axis = update.ndim() as i32 - 1;
98+
99+
while axis >= non_new_axis_operation_count {
100+
if update_axis >= 0 {
101+
update_reshape[axis] = update.dim(update_axis);
102+
update_axis -= 1;
103+
} else {
104+
update_reshape[axis] = 1;
105+
}
106+
axis -= 1;
107+
}
108+
109+
for item in operations.iter().rev() {
95110
use ArrayIndexOp::*;
96111

97112
match item {
98113
TakeIndex { index } => {
99-
let size = src.dim(axis);
114+
let size = src.dim(axis as i32);
100115
let index = if index.is_negative() {
101116
size + index
102117
} else {
103118
*index
104119
};
105120
// SAFETY: axis is always non-negative
106-
starts[axis as usize] = index;
107-
ends[axis as usize] = index + 1;
108-
if ndim - (axis as usize) < update.ndim() {
109-
update_expand_dims.push(axis.saturating_sub_unsigned(ndim as u32));
110-
}
121+
starts[axis] = index;
122+
ends[axis] = index.saturating_add(1);
111123

112-
axis = axis.saturating_add(1);
124+
update_reshape[axis] = 1;
125+
axis = axis.saturating_sub(1);
113126
}
114-
Slice(range_index) => {
115-
let size = src.dim(axis);
127+
Slice(slice) => {
128+
let size = src.dim(axis as i32);
116129
// SAFETY: axis is always non-negative
117-
starts[axis as usize] = range_index.start(size);
118-
ends[axis as usize] = range_index.end(size);
119-
strides[axis as usize] = range_index.stride();
120-
axis = axis.saturating_add(1);
130+
starts[axis] = slice.start(size);
131+
ends[axis] = slice.end(size);
132+
strides[axis] = slice.stride();
133+
134+
if update_axis >= 0 {
135+
update_reshape[axis] = update.dim(update_axis);
136+
update_axis = update_axis.saturating_sub(1);
137+
} else {
138+
update_reshape[axis] = 1;
139+
}
140+
axis = axis.saturating_sub(1);
121141
}
122142
ExpandDims => {}
123143
Ellipsis | TakeArray { indices: _ } | TakeArrayRef { indices: _ } => {
@@ -126,8 +146,8 @@ fn update_slice(
126146
}
127147
}
128148

129-
if !update_expand_dims.is_empty() {
130-
update = Cow::Owned(update.expand_dims_axes_device(&update_expand_dims, &stream)?);
149+
if update.shape() != &update_reshape[..] {
150+
update = Cow::Owned(reshape_device(update, &update_reshape, &stream)?);
131151
}
132152

133153
Ok(Some(src.slice_update_device(
@@ -1277,7 +1297,10 @@ where
12771297
/// The unit tests below are adapted from the Swift binding tests
12781298
#[cfg(test)]
12791299
mod tests {
1280-
use crate::{ops::indexing::*, Array};
1300+
use crate::{
1301+
ops::{indexing::*, ones, zeros},
1302+
Array,
1303+
};
12811304

12821305
#[test]
12831306
fn test_array_mutate_single_index() {
@@ -1473,4 +1496,16 @@ mod tests {
14731496
128142
14741497
);
14751498
}
1499+
1500+
#[test]
1501+
fn test_slice_update_with_broadcast() {
1502+
let mut xs = zeros::<f32>(&[4, 3, 2]).unwrap();
1503+
let x = ones::<f32>(&[4, 2]).unwrap();
1504+
1505+
let result = xs.try_index_mut((.., 0, ..), x);
1506+
assert!(
1507+
result.is_ok(),
1508+
"Failed to update slice with broadcast: {result:?}"
1509+
);
1510+
}
14761511
}

0 commit comments

Comments
 (0)