Skip to content

Commit

Permalink
Add Float16 pattern matches
Browse files Browse the repository at this point in the history
  • Loading branch information
tmcgilchrist committed May 17, 2024
1 parent c3230eb commit ef69ebe
Show file tree
Hide file tree
Showing 7 changed files with 166 additions and 0 deletions.
1 change: 1 addition & 0 deletions owl-base.opam
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ depends: [
"ocaml" {>= "4.10.0"}
"base-bigarray"
"dune" {>= "2.0.0"}
"cppo"
]
15 changes: 15 additions & 0 deletions src/base/core/owl_const.ml → src/base/core/owl_const.cppo.ml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ let eps = 1e-15
(** Functions that return constants using Bigarray kind *)

let zero : type a b. (a, b) kind -> a = function
#if OCAML_VERSION >= (5, 2, 0)
| Float16 -> 0.0
#endif
| Float32 -> 0.0
| Complex32 -> Complex.zero
| Float64 -> 0.0
Expand All @@ -62,6 +65,9 @@ let zero : type a b. (a, b) kind -> a = function


let one : type a b. (a, b) kind -> a = function
#if OCAML_VERSION >= (5, 2, 0)
| Float16 -> 1.0
#endif
| Float32 -> 1.0
| Complex32 -> Complex.one
| Float64 -> 1.0
Expand All @@ -78,6 +84,9 @@ let one : type a b. (a, b) kind -> a = function


let neg_one : type a b. (a, b) kind -> a = function
#if OCAML_VERSION >= (5, 2, 0)
| Float16 -> -1.0
#endif
| Float32 -> -1.0
| Float64 -> -1.0
| Complex32 -> Complex.{ re = -1.; im = 0. }
Expand All @@ -94,6 +103,9 @@ let neg_one : type a b. (a, b) kind -> a = function


let pos_inf : type a b. (a, b) kind -> a = function
#if OCAML_VERSION >= (5, 2, 0)
| Float16 -> infinity
#endif
| Float32 -> infinity
| Float64 -> infinity
| Complex32 -> Complex.{ re = infinity; im = infinity }
Expand All @@ -102,6 +114,9 @@ let pos_inf : type a b. (a, b) kind -> a = function


let neg_inf : type a b. (a, b) kind -> a = function
#if OCAML_VERSION >= (5, 2, 0)
| Float16 -> neg_infinity
#endif
| Float32 -> neg_infinity
| Float64 -> neg_infinity
| Complex32 -> Complex.{ re = neg_infinity; im = neg_infinity }
Expand Down
10 changes: 10 additions & 0 deletions src/base/dune
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@

(copy_files# misc/*)

(rule
(targets owl_const.ml)
(deps (:out owl_const.cppo.ml))
(action (run cppo -V OCAML:%{ocaml_version} %{out} -o %{targets})))

(rule
(targets owl_utils_ndarray.ml)
(deps (:out owl_utils_ndarray.cppo.ml))
(action (run cppo -V OCAML:%{ocaml_version} %{out} -o %{targets})))

(library
(name owl_base)
(public_name owl-base)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ let elt_to_str : type a b. (a, b) kind -> a -> string = function
| Int -> fun v -> Printf.sprintf "%i" v
| Int32 -> fun v -> Printf.sprintf "%ld" v
| Int64 -> fun v -> Printf.sprintf "%Ld" v
#if OCAML_VERSION >= (5, 2, 0)
| Float16 -> fun v -> Printf.sprintf "%G" v
#endif
| Float32 -> fun v -> Printf.sprintf "%G" v
| Float64 -> fun v -> Printf.sprintf "%G" v
| Complex32 -> fun v -> Printf.sprintf "(%G, %Gi)" Complex.(v.re) Complex.(v.im)
Expand All @@ -33,6 +36,9 @@ let elt_of_str : type a b. (a, b) kind -> string -> a = function
| Int -> fun v -> int_of_string v
| Int32 -> fun v -> Int32.of_string v
| Int64 -> fun v -> Int64.of_string v
#if OCAML_VERSION >= (5, 2, 0)
| Float16 -> fun v -> float_of_string v
#endif
| Float32 -> fun v -> float_of_string v
| Float64 -> fun v -> float_of_string v
| Complex32 -> fun v ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -856,6 +856,10 @@ let symv
let _y = bigarray_start Ctypes_static.Array1 y in
let _a = bigarray_start Ctypes_static.Array1 a in
match Bigarray.Array1.kind x with
#if OCAML_VERSION >= (5, 2, 0)
| Bigarray.Float16 ->
C.ssymv ~order:_layout ~uplo:_uplo ~n ~alpha ~a:_a ~lda ~x:_x ~incx ~beta ~y:_y ~incy
#endif
| Bigarray.Float32 ->
C.ssymv ~order:_layout ~uplo:_uplo ~n ~alpha ~a:_a ~lda ~x:_x ~incx ~beta ~y:_y ~incy
| Bigarray.Float64 ->
Expand Down Expand Up @@ -887,6 +891,22 @@ let sbmv
let _y = bigarray_start Ctypes_static.Array1 y in
let _a = bigarray_start Ctypes_static.Array1 a in
match Bigarray.Array1.kind x with
#if OCAML_VERSION >= (5, 2, 0)
| Bigarray.Float16 ->
C.ssbmv
~order:_layout
~uplo:_uplo
~n
~k
~alpha
~a:_a
~lda
~x:_x
~incx
~beta
~y:_y
~incy
#endif
| Bigarray.Float32 ->
C.ssbmv
~order:_layout
Expand Down Expand Up @@ -941,6 +961,10 @@ let spmv
let _y = bigarray_start Ctypes_static.Array1 y in
let _ap = bigarray_start Ctypes_static.Array1 ap in
match Bigarray.Array1.kind x with
#if OCAML_VERSION >= (5, 2, 0)
| Bigarray.Float16 ->
C.sspmv ~order:_layout ~uplo:_uplo ~n ~alpha ~ap:_ap ~x:_x ~incx ~beta ~y:_y ~incy
#endif
| Bigarray.Float32 ->
C.sspmv ~order:_layout ~uplo:_uplo ~n ~alpha ~ap:_ap ~x:_x ~incx ~beta ~y:_y ~incy
| Bigarray.Float64 ->
Expand Down Expand Up @@ -1049,6 +1073,9 @@ let syr
let _x = bigarray_start Ctypes_static.Array1 x in
let _a = bigarray_start Ctypes_static.Array1 a in
match Bigarray.Array1.kind x with
#if OCAML_VERSION >= (5, 2, 0)
| Bigarray.Float16 -> C.ssyr ~order:_layout ~uplo:_uplo ~n ~alpha ~x:_x ~incx ~a:_a ~lda
#endif
| Bigarray.Float32 -> C.ssyr ~order:_layout ~uplo:_uplo ~n ~alpha ~x:_x ~incx ~a:_a ~lda
| Bigarray.Float64 -> C.dsyr ~order:_layout ~uplo:_uplo ~n ~alpha ~x:_x ~incx ~a:_a ~lda

Expand All @@ -1072,6 +1099,9 @@ let spr
let _x = bigarray_start Ctypes_static.Array1 x in
let _ap = bigarray_start Ctypes_static.Array1 ap in
match Bigarray.Array1.kind x with
#if OCAML_VERSION >= (5, 2, 0)
| Bigarray.Float16 -> C.sspr ~order:_layout ~uplo:_uplo ~n ~alpha ~x:_x ~incx ~ap:_ap
#endif
| Bigarray.Float32 -> C.sspr ~order:_layout ~uplo:_uplo ~n ~alpha ~x:_x ~incx ~ap:_ap
| Bigarray.Float64 -> C.dspr ~order:_layout ~uplo:_uplo ~n ~alpha ~x:_x ~incx ~ap:_ap

Expand Down Expand Up @@ -1099,6 +1129,10 @@ let syr2
let _y = bigarray_start Ctypes_static.Array1 y in
let _a = bigarray_start Ctypes_static.Array1 a in
match Bigarray.Array1.kind x with
#if OCAML_VERSION >= (5, 2, 0)
| Bigarray.Float16 ->
C.ssyr2 ~order:_layout ~uplo:_uplo ~n ~alpha ~x:_x ~incx ~y:_y ~incy ~a:_a ~lda
#endif
| Bigarray.Float32 ->
C.ssyr2 ~order:_layout ~uplo:_uplo ~n ~alpha ~x:_x ~incx ~y:_y ~incy ~a:_a ~lda
| Bigarray.Float64 ->
Expand Down Expand Up @@ -1127,6 +1161,10 @@ let spr2
let _y = bigarray_start Ctypes_static.Array1 y in
let _a = bigarray_start Ctypes_static.Array1 a in
match Bigarray.Array1.kind x with
#if OCAML_VERSION >= (5, 2, 0)
| Bigarray.Float16 ->
C.sspr2 ~order:_layout ~uplo:_uplo ~n ~alpha ~x:_x ~incx ~y:_y ~incy ~a:_a
#endif
| Bigarray.Float32 ->
C.sspr2 ~order:_layout ~uplo:_uplo ~n ~alpha ~x:_x ~incx ~y:_y ~incy ~a:_a
| Bigarray.Float64 ->
Expand Down
10 changes: 10 additions & 0 deletions src/owl/dune
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,16 @@

(copy_files# signal/*)

(rule
(targets owl_cblas_basic.ml)
(deps (:out owl_cblas_basic.cppo.ml))
(action (run cppo -V OCAML:%{ocaml_version} %{out} -o %{targets})))

(rule
(targets owl_lapacke.ml)
(deps (:out owl_lapacke.cppo.ml))
(action (run cppo -V OCAML:%{ocaml_version} %{out} -o %{targets})))

(library
(name owl)
(public_name owl)
Expand Down

0 comments on commit ef69ebe

Please sign in to comment.