[Feature request] ArrayRef<A, Ix2>.dot() for axis greater than Ix2 #1587

@boyleconnor


In NumPy, the left-hand side of a matrix multiplication can have as many axes as desired, as long as the length of its last axis matches the length of the right-hand side's 0th axis (for a 2-D right-hand side), e.g.:

```python
import numpy as np
x = np.random.random((3, 2, 5, 9, 12))
y = np.random.random((12, 13))
(x @ y).shape

# (3, 2, 5, 9, 13)
```
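The shape rule above, for a 2-D right-hand side, can be written down as a small function. This is a hedged sketch using only the Rust standard library. `matmul_shape_2d_rhs` is a hypothetical helper, not part of ndarray or NumPy. It models only the 2-D right-hand-side case. NumPy's `matmul` also broadcasts the batch axes of an N-D right-hand side, and this sketch does not cover that.

```rust
/// Hypothetical helper (not an ndarray or NumPy API): returns the shape of
/// `lhs @ rhs` when `rhs` is 2-D, or `None` if the shapes are incompatible.
fn matmul_shape_2d_rhs(lhs: &[usize], rhs: [usize; 2]) -> Option<Vec<usize>> {
    // The left-hand side needs at least one axis to contract over.
    let (&last, leading) = lhs.split_last()?;
    // The contracted axes (last of lhs, 0th of rhs) must have the same length.
    if last != rhs[0] {
        return None;
    }
    // Leading axes pass through unchanged; the last axis becomes rhs's axis 1.
    let mut out = leading.to_vec();
    out.push(rhs[1]);
    Some(out)
}

fn main() {
    // Same shapes as the NumPy example above.
    let shape = matmul_shape_2d_rhs(&[3, 2, 5, 9, 12], [12, 13]);
    println!("{:?}", shape); // prints "Some([3, 2, 5, 9, 13])"
}
```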

In ndarray, you can't do this directly. The following doesn't compile:

```rust
use ndarray::prelude::*;
use ndarray_rand::RandomExt;
use ndarray_rand::rand_distr::Uniform;


fn main() {
    let x: Array<f64, Ix3> = Array::random(
        (12, 4, 3),
        Uniform::new(0., 1.).unwrap()
    );
    let y: Array<f64, Ix2> = Array::random(
        (3, 2),
        Uniform::new(0., 1.).unwrap()
    );
    let x_y = x.dot(&y);
    println!("{}", x_y);
}
```
Compiler output:

```text
$ cargo run
   Compiling playground v0.1.0 (/home/connor/RustroverProjects/playground)
error[E0275]: overflow evaluating the requirement `&ArrayBase<_, _, _>: Not`
  --> src/main.rs:15:17
   |
15 |     let x_y = x.dot(&y);
   |                 ^^^
   |
   = help: consider increasing the recursion limit by adding a `#![recursion_limit = "256"]` attribute to your crate (`playground`)
   = note: required for `&ArrayBase<_, _, _>` to implement `Not`
   = note: 127 redundant requirements hidden
   = note: required for `&ArrayBase<OwnedRepr<f64>, Dim<[usize; 3]>, f64>` to implement `Not`

For more information about this error, try `rustc --explain E0275`.
error: could not compile `playground` (bin "playground") due to 1 previous error
```

Emulating the behavior in the previous NumPy example requires a non-trivial amount of work, e.g. for x with 3 axes:

```rust
use ndarray::prelude::*;
use ndarray_rand::RandomExt;
use ndarray_rand::rand_distr::Uniform;


fn main() {
    let x: Array<f64, Ix3> = Array::random((12, 4, 3), Uniform::new(0., 1.).unwrap());
    let y = Array::random((3, 2), Uniform::new(0., 1.).unwrap());

    let (a, b, c) = (x.len_of(Axis(0)), x.len_of(Axis(1)), x.len_of(Axis(2)));
    let d = y.len_of(Axis(1));

    let x_y: Array3<f64> = x
        .to_shape((a * b, c)).unwrap()
        .dot(&y)
        .to_shape((a, b, d)).unwrap()
        .to_owned();
    println!("{:?}", x_y);
}
```
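The workaround relies on a property of row-major (C-order) data. Contracting the last axis of an `(a, b, c)` array with a `(c, d)` matrix gives the same numbers as an ordinary `(a*b, c) · (c, d)` product whose result is then reshaped to `(a, b, d)`. The std-only sketch below spells this out on flat buffers for any number of leading axes. `dot_last_axis` is a hypothetical name, not an ndarray API, and the sketch assumes contiguous row-major inputs.

```rust
/// Hypothetical std-only sketch (not an ndarray API): multiplies an N-D array `x`
/// (row-major, shape `x_shape`) by a 2-D matrix `y` (row-major, shape `(k, n)`),
/// contracting the last axis of `x` with axis 0 of `y`.
/// Returns the flat row-major result and its shape.
fn dot_last_axis(
    x: &[f64],
    x_shape: &[usize],
    y: &[f64],
    y_shape: (usize, usize),
) -> (Vec<f64>, Vec<usize>) {
    let (k, n) = y_shape;
    let (&last, leading) = x_shape.split_last().expect("x must have at least one axis");
    assert_eq!(last, k, "last axis of x must match axis 0 of y");
    assert_eq!(x.len(), x_shape.iter().product::<usize>());
    assert_eq!(y.len(), k * n);

    // Collapse all leading axes into one, viewing x as an (m, k) matrix.
    // This corresponds to `to_shape((a * b, c))` in the workaround above.
    let m: usize = leading.iter().product();

    // Plain (m, k) x (k, n) matrix product on row-major buffers.
    let mut out = vec![0.0; m * n];
    for i in 0..m {
        for p in 0..k {
            let xip = x[i * k + p];
            for j in 0..n {
                out[i * n + j] += xip * y[p * n + j];
            }
        }
    }

    // Restore the leading axes: the result has shape (leading..., n),
    // which corresponds to `to_shape((a, b, d))` in the workaround.
    let mut out_shape = leading.to_vec();
    out_shape.push(n);
    (out, out_shape)
}

fn main() {
    // x has shape (2, 2, 3); y has shape (3, 2).
    let x: Vec<f64> = (1..=12).map(|v| v as f64).collect();
    let y = vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
    let (out, shape) = dot_last_axis(&x, &[2, 2, 3], &y, (3, 2));
    println!("shape = {:?}, data = {:?}", shape, out);
}
```

Because only the leading axes are collapsed, the same loop handles 3-D, 5-D, or 1-D left-hand sides. A built-in N-D `dot()` could presumably dispatch to the existing 2-D kernel the same way.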

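Therefore, I think it would be nice to have `dot()` implemented for arrays with more than two axes (i.e. dimensionality greater than `Ix2`), contracting the last axis of the left-hand side with axis 0 of the right-hand side, as NumPy does.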
