Skip to content

Commit

Permalink
Merge 7ff510f into aa93a76
Browse files Browse the repository at this point in the history
  • Loading branch information
elcritch committed Jan 7, 2020
2 parents aa93a76 + 7ff510f commit 25c6c7b
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 37 deletions.
101 changes: 65 additions & 36 deletions lib/matrex.ex
Expand Up @@ -298,31 +298,6 @@ defmodule Matrex do

@behaviour Access

defmacrop matrex_data(rows, columns, body) do
quote do
%Matrex{
data: <<
unquote(rows)::unsigned-integer-little-32,
unquote(columns)::unsigned-integer-little-32,
unquote(body)::binary
>>
}
end
end

defmacrop matrex_data(rows, columns, body, data) do
quote do
%Matrex{
data:
<<
unquote(rows)::unsigned-integer-little-32,
unquote(columns)::unsigned-integer-little-32,
unquote(body)::binary
>> = unquote(data)
}
end
end

@impl Access
def fetch(matrex, key)

Expand Down Expand Up @@ -416,17 +391,7 @@ defmodule Matrex do
# Matrix element size in bytes
@element_size 4

defmacrop matrex_data(rows, columns, data) do
quote do
%Matrex{
data: <<
unquote(rows)::unsigned-integer-little-32,
unquote(columns)::unsigned-integer-little-32,
unquote(data)::binary
>>
}
end
end
import Matrex.Guards

@doc false
def count(matrex_data(rows, cols, _data)), do: {:ok, rows * cols}
Expand Down Expand Up @@ -1082,6 +1047,32 @@ defmodule Matrex do
when columns1 == rows2,
do: %Matrex{data: NIFs.dot(first, second)}

@doc """
Matrix inner product for two "vector" matrices (e.g. rows == 1 and columns >= 1).
Number of columns of the first matrix must be equal to the number of rows of the second matrix.
Raises `ErlangError` if matrices' sizes do not match.
## Example
iex> Matrex.new([[1, 2, 3], [4, 5, 6]]) |>
...> Matrex.dot(Matrex.new([[1, 2], [3, 4], [5, 6]]))
#Matrex[2×2]
┌ ┐
│ 22.0 28.0 │
│ 49.0 64.0 │
└ ┘
"""
@spec inner_dot(matrex, matrex) :: matrex
def inner_dot(
vector_data(columns1, _data1, first),
vector_data(columns2, _data2, second)
)
when columns1 == columns2,
do: %Matrex{data: NIFs.dot_nt(first, second)}

@doc """
Matrix multiplication with addition of third matrix. NIF, via `cblas_sgemm()`.
Expand Down Expand Up @@ -1699,6 +1690,22 @@ defmodule Matrex do
new_matrix_from_function(size, rows, columns, function, initial)
end

@doc """
Creates new 1-column matrix (aka vector) from the given list.
## Examples
iex> [1,2,3] |> Matrex.from_list()
#Matrex[1×3]
┌ ┐
│ 1.0 2.0 3.0 │
└ ┘
"""
def from_list(lst) when is_list(lst) do
new([lst])
end

@spec float_to_binary(element | :nan | :inf | :neg_inf) :: binary
defp float_to_binary(val) when is_number(val), do: <<val::float-little-32>>
defp float_to_binary(:nan), do: @not_a_number
Expand Down Expand Up @@ -2410,6 +2417,28 @@ defmodule Matrex do
@spec square(matrex) :: matrex
def square(%Matrex{data: matrix}), do: %Matrex{data: Matrex.NIFs.multiply(matrix, matrix)}

@doc """
Produces element-wise pow matrix. NIF through `power/2`.
## Example
iex> m = Matrex.new("1 2 3; 4 5 6")
#Matrex[2×3]
┌ ┐
│ 1.0 2.0 3.0 │
│ 4.0 5.0 6.0 │
└ ┘
iex> Matrex.pow(m, 2)
#Matrex[2×3]
┌ ┐
│ 1.0 4.0 9.0 │
│ 16.0 25.0 36.0 │
└ ┘
"""
@spec pow(matrex, number) :: matrex
def pow(%Matrex{data: matrix}, exponent), do: %Matrex{data: Matrex.NIFs.power(exponent, matrix)}

@doc """
Returns submatrix for a given matrix. NIF.
Expand Down
50 changes: 50 additions & 0 deletions lib/matrex/guards.ex
Expand Up @@ -7,4 +7,54 @@ defmodule Matrex.Guards do
unquote(row) >= 1 and unquote(row) <= unquote(rows) and unquote(col) >= 1 and
unquote(col) <= unquote(columns)
)

defmacro vector_data(size, body) do
quote do
%Matrex{
data: <<
<<1, 0, 0, 0>>,
unquote(size)::unsigned-integer-little-32,
unquote(body)::binary
>>
}
end
end

defmacro vector_data(size, body, data) do
quote do
%Matrex{
data: <<
<<1, 0, 0, 0>>,
unquote(size)::unsigned-integer-little-32,
unquote(body)::binary
>> = unquote(data)
}
end
end

defmacro matrex_data(rows, columns, body) do
quote do
%Matrex{
data: <<
unquote(rows)::unsigned-integer-little-32,
unquote(columns)::unsigned-integer-little-32,
unquote(body)::binary
>>
}
end
end

defmacro matrex_data(rows, columns, body, data) do
quote do
%Matrex{
data:
<<
unquote(rows)::unsigned-integer-little-32,
unquote(columns)::unsigned-integer-little-32,
unquote(body)::binary
>> = unquote(data)
}
end
end

end
5 changes: 5 additions & 0 deletions lib/matrex/nifs.ex
Expand Up @@ -88,6 +88,11 @@ defmodule Matrex.NIFs do
when is_binary(first) and is_binary(second),
do: :erlang.nif_error(:nif_library_not_loaded)

@spec power(number, binary) :: binary
def power(exponent, matrix)
when is_number(exponent) and is_binary(matrix),
do: :erlang.nif_error(:nif_library_not_loaded)

@spec divide(binary, binary) :: binary
def divide(first, second)
when is_binary(first) and is_binary(second),
Expand Down
3 changes: 3 additions & 0 deletions native/include/matrix.h
Expand Up @@ -57,6 +57,9 @@ matrix_argmax(const Matrix matrix);
void
matrix_concat_columns(const Matrix first, const Matrix second, Matrix result);

void
matrix_pow(const float scalar, const Matrix matrex, Matrix result);

void
matrix_divide(const Matrix first, const Matrix second, Matrix result);

Expand Down
25 changes: 25 additions & 0 deletions native/nifs/matrix_nifs.c
Expand Up @@ -248,6 +248,30 @@ concat_columns(ErlNifEnv *env, int32_t argc, const ERL_NIF_TERM *argv) {
return result;
}

static ERL_NIF_TERM
power(ErlNifEnv *env, int argc, const ERL_NIF_TERM *argv) {
ErlNifBinary matrix;
ERL_NIF_TERM result;
float scalar;
float *matrix_data, *result_data;
uint64_t data_size;
size_t result_size;

(void)(argc);

scalar = get_scalar(env, argv[0]);
if (!enif_inspect_binary(env, argv[1], &matrix)) return enif_make_badarg(env);

matrix_data = (float *) matrix.data;
data_size = MX_LENGTH(matrix_data);

result_size = sizeof(float) * data_size;
result_data = (float *) enif_make_new_binary(env, result_size, &result);

matrix_pow(scalar, matrix_data, result_data);

return result;
}

static ERL_NIF_TERM
divide(ErlNifEnv *env, int32_t argc, const ERL_NIF_TERM *argv) {
Expand Down Expand Up @@ -1242,6 +1266,7 @@ static ErlNifFunc nif_functions[] = {
{"argmax", 1, argmax, 0},
{"column_to_list", 2, column_to_list, 0},
{"concat_columns", 2, concat_columns, 0},
{"power", 2, power, 0},
{"divide", 2, divide, 0},
{"divide_scalar", 2, divide_scalar, 0},
{"divide_by_scalar", 2, divide_by_scalar, 0},
Expand Down
11 changes: 11 additions & 0 deletions native/src/matrix.c
Expand Up @@ -181,6 +181,17 @@ matrix_concat_columns(const Matrix first, const Matrix second, Matrix result) {
}
}

void
matrix_pow(const float scalar, const Matrix matrix, Matrix result) {
const int64_t data_size = MX_LENGTH(matrix);

MX_SET_ROWS(result, MX_ROWS(matrix));
MX_SET_COLS(result, MX_COLS(matrix));

for (int64_t index = 2; index < data_size; index += 1) {
result[index] = powf(matrix[index], scalar);
}
}

void
matrix_divide(const Matrix first, const Matrix second, Matrix result) {
Expand Down
3 changes: 2 additions & 1 deletion test/algorithms_test.exs
Expand Up @@ -89,7 +89,7 @@ defmodule AlgorithmsTest do

# Split data into training and testing set, permute it randomly
@spec split_data(Matrex.t(), Matrex.t()) :: {Matrex.t(), Matrex.t(), Matrex.t(), Matrex.t()}
defp split_data(x, y) do
def split_data(x, y) do
n = x[:rows]
n_train = trunc(0.8 * n)
n_test = n - n_train
Expand All @@ -115,4 +115,5 @@ defmodule AlgorithmsTest do

{x_train, y_train, x_test, y_test}
end

end

0 comments on commit 25c6c7b

Please sign in to comment.