# NumPy with `rustimport_jupyter`

On Google Colab, we install [rustimport_jupyter](https://github.com/thomasjpfan/rustimport_jupyter) and the Rust toolchain:

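```python
import os
import sys

# Only install when running on Google Colab
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install rustimport_jupyter
    # Install the Rust toolchain non-interactively with rustup
    !curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
    # Make cargo and rustc visible to the notebook process
    os.environ["PATH"] += ":/root/.cargo/bin"
```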

```python
%load_ext rustimport_jupyter
```

## Simple NumPy Function Written in Rust

Based on the example from [PyO3/rust-numpy](https://github.com/PyO3/rust-numpy#example), we define a Rust function that takes NumPy arrays and computes `a * x + y`. The `--release` flag compiles the extension in optimized (release) mode:

```rust
%%rustimport --release
//: [dependencies]
//: pyo3 = { version = "0.20", features = ["extension-module"] }
//: numpy = "0.20"

use pyo3::prelude::*;
use numpy::ndarray::{ArrayD, ArrayViewD};
use numpy::{IntoPyArray, PyArrayDyn, PyReadonlyArrayDyn};

// Pure-Rust kernel: works on ndarray views and knows nothing about Python
fn axpy(a: f64, x: ArrayViewD<'_, f64>, y: ArrayViewD<'_, f64>) -> ArrayD<f64> {
    a * &x + &y
}

// Python wrapper: exposed to Python under the name `axpy`
#[pyfunction]
#[pyo3(name = "axpy")]
fn axpy_py<'py>(
    py: Python<'py>,
    a: f64,
    x: PyReadonlyArrayDyn<'py, f64>,
    y: PyReadonlyArrayDyn<'py, f64>,
) -> &'py PyArrayDyn<f64> {
    // Borrow the NumPy buffers as read-only ndarray views (no copy)
    let x = x.as_array();
    let y = y.as_array();
    let z = axpy(a, x, y);
    // Move the result into a new NumPy array owned by Python
    z.into_pyarray(py)
}
```

[1m[32m    Updating[0m crates.io index
[1m[32m   Compiling[0m autocfg v1.1.0
[1m[32m   Compiling[0m target-lexicon v0.12.12
[1m[32m   Compiling[0m once_cell v1.19.0
[1m[32m   Compiling[0m proc-macro2 v1.0.71
[1m[32m   Compiling[0m libc v0.2.151
[1m[32m   Compiling[0m unicode-ident v1.0.12
[1m[32m   Compiling[0m parking_lot_core v0.9.9
[1m[32m   Compiling[0m heck v0.4.1
[1m[32m   Compiling[0m scopeguard v1.2.0
[1m[32m   Compiling[0m smallvec v1.11.2
[1m[32m   Compiling[0m cfg-if v1.0.0
[1m[32m   Compiling[0m rawpointer v0.2.1
[1m[32m   Compiling[0m indoc v2.0.4
[1m[32m   Compiling[0m unindent v0.2.3
[1m[32m   Compiling[0m rustc-hash v1.1.0
[1m[32m   Compiling[0m num-traits v0.2.17
[1m[32m   Compiling[0m lock_api v0.4.11
[1m[32m   Compiling[0m memoffset v0.9.0
[1m[32m   Compiling[0m matrixmultiply v0.3.8
[1m[32m   Compiling[0m num-integer v0.1.45
[1m[32m   Compiling[0m quote v1.0.33
[1m[32m   Compiling[0m pyo3-build-con

Because of the `#[pyo3(name = "axpy")]` attribute, the function is exposed to Python as `axpy` rather than `axpy_py`, and it is available directly in the notebook namespace:

```python
import numpy as np

a = 4.4
x = np.array([1.0, 3.0, 4.0], dtype=np.float64)
y = np.array([2.1, -1.0, -4.0], dtype=np.float64)

axpy(a, x, y)
```

```
array([ 6.5, 12.2, 13.6])
```
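Each element of the result is `a * x[i] + y[i]`, so the printed values can be checked by hand:

$$
4.4 \begin{bmatrix} 1.0 \\ 3.0 \\ 4.0 \end{bmatrix} + \begin{bmatrix} 2.1 \\ -1.0 \\ -4.0 \end{bmatrix}
= \begin{bmatrix} 4.4 + 2.1 \\ 13.2 - 1.0 \\ 17.6 - 4.0 \end{bmatrix}
= \begin{bmatrix} 6.5 \\ 12.2 \\ 13.6 \end{bmatrix}
$$

NumPy's array repr rounds values for display, so the underlying doubles may differ from these decimals in the last bits.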