From ef69ebe19f7e3b5ff1e51453999171e2d4b32ced Mon Sep 17 00:00:00 2001 From: Tim McGilchrist Date: Fri, 17 May 2024 13:35:37 +1000 Subject: [PATCH] Add Float16 pattern matches --- owl-base.opam | 1 + .../core/{owl_const.ml => owl_const.cppo.ml} | 15 ++++ src/base/dune | 10 +++ ...s_ndarray.ml => owl_utils_ndarray.cppo.ml} | 6 ++ ...cblas_basic.ml => owl_cblas_basic.cppo.ml} | 38 ++++++++ src/owl/dune | 10 +++ .../{owl_lapacke.ml => owl_lapacke.cppo.ml} | 86 +++++++++++++++++++ 7 files changed, 166 insertions(+) rename src/base/core/{owl_const.ml => owl_const.cppo.ml} (97%) rename src/base/misc/{owl_utils_ndarray.ml => owl_utils_ndarray.cppo.ml} (96%) rename src/owl/cblas/{owl_cblas_basic.ml => owl_cblas_basic.cppo.ml} (97%) rename src/owl/lapacke/{owl_lapacke.ml => owl_lapacke.cppo.ml} (98%) diff --git a/owl-base.opam b/owl-base.opam index b07e4e1fc..966b5bbc5 100644 --- a/owl-base.opam +++ b/owl-base.opam @@ -17,4 +17,5 @@ depends: [ "ocaml" {>= "4.10.0"} "base-bigarray" "dune" {>= "2.0.0"} + "cppo" ] diff --git a/src/base/core/owl_const.ml b/src/base/core/owl_const.cppo.ml similarity index 97% rename from src/base/core/owl_const.ml rename to src/base/core/owl_const.cppo.ml index d3bbd321d..27e5b715b 100644 --- a/src/base/core/owl_const.ml +++ b/src/base/core/owl_const.cppo.ml @@ -46,6 +46,9 @@ let eps = 1e-15 (** Functions that return constants using Bigarray kind *) let zero : type a b. (a, b) kind -> a = function +#if OCAML_VERSION >= (5, 2, 0) + | Float16 -> 0.0 +#endif | Float32 -> 0.0 | Complex32 -> Complex.zero | Float64 -> 0.0 @@ -62,6 +65,9 @@ let zero : type a b. (a, b) kind -> a = function let one : type a b. (a, b) kind -> a = function +#if OCAML_VERSION >= (5, 2, 0) + | Float16 -> 1.0 +#endif | Float32 -> 1.0 | Complex32 -> Complex.one | Float64 -> 1.0 @@ -78,6 +84,9 @@ let one : type a b. (a, b) kind -> a = function let neg_one : type a b. (a, b) kind -> a = function +#if OCAML_VERSION >= (5, 2, 0) + | Float16 -> -1.0 +#endif | Float32 -> -1.0 | Float64 -> -1.0 | Complex32 -> Complex.{ re = -1.; im = 0. } @@ -94,6 +103,9 @@ let neg_one : type a b. (a, b) kind -> a = function let pos_inf : type a b. (a, b) kind -> a = function +#if OCAML_VERSION >= (5, 2, 0) + | Float16 -> infinity +#endif | Float32 -> infinity | Float64 -> infinity | Complex32 -> Complex.{ re = infinity; im = infinity } @@ -102,6 +114,9 @@ let pos_inf : type a b. (a, b) kind -> a = function let neg_inf : type a b. (a, b) kind -> a = function +#if OCAML_VERSION >= (5, 2, 0) + | Float16 -> neg_infinity +#endif | Float32 -> neg_infinity | Float64 -> neg_infinity | Complex32 -> Complex.{ re = neg_infinity; im = neg_infinity } diff --git a/src/base/dune b/src/base/dune index 00e39c685..88c0e665f 100644 --- a/src/base/dune +++ b/src/base/dune @@ -20,6 +20,16 @@ (copy_files# misc/*) +(rule + (targets owl_const.ml) + (deps (:out owl_const.cppo.ml)) + (action (run cppo -V OCAML:%{ocaml_version} %{out} -o %{targets}))) + +(rule + (targets owl_utils_ndarray.ml) + (deps (:out owl_utils_ndarray.cppo.ml)) + (action (run cppo -V OCAML:%{ocaml_version} %{out} -o %{targets}))) + (library (name owl_base) (public_name owl-base) diff --git a/src/base/misc/owl_utils_ndarray.ml b/src/base/misc/owl_utils_ndarray.cppo.ml similarity index 96% rename from src/base/misc/owl_utils_ndarray.ml rename to src/base/misc/owl_utils_ndarray.cppo.ml index 1ff9a58ac..e4c90555a 100644 --- a/src/base/misc/owl_utils_ndarray.ml +++ b/src/base/misc/owl_utils_ndarray.cppo.ml @@ -16,6 +16,9 @@ let elt_to_str : type a b. (a, b) kind -> a -> string = function | Int -> fun v -> Printf.sprintf "%i" v | Int32 -> fun v -> Printf.sprintf "%ld" v | Int64 -> fun v -> Printf.sprintf "%Ld" v +#if OCAML_VERSION >= (5, 2, 0) + | Float16 -> fun v -> Printf.sprintf "%G" v +#endif | Float32 -> fun v -> Printf.sprintf "%G" v | Float64 -> fun v -> Printf.sprintf "%G" v | Complex32 -> fun v -> Printf.sprintf "(%G, %Gi)" Complex.(v.re) Complex.(v.im) @@ -33,6 +36,9 @@ let elt_of_str : type a b. (a, b) kind -> string -> a = function | Int -> fun v -> int_of_string v | Int32 -> fun v -> Int32.of_string v | Int64 -> fun v -> Int64.of_string v +#if OCAML_VERSION >= (5, 2, 0) + | Float16 -> fun v -> float_of_string v +#endif | Float32 -> fun v -> float_of_string v | Float64 -> fun v -> float_of_string v | Complex32 -> fun v -> diff --git a/src/owl/cblas/owl_cblas_basic.ml b/src/owl/cblas/owl_cblas_basic.cppo.ml similarity index 97% rename from src/owl/cblas/owl_cblas_basic.ml rename to src/owl/cblas/owl_cblas_basic.cppo.ml index a6387f3dd..5f9746135 100644 --- a/src/owl/cblas/owl_cblas_basic.ml +++ b/src/owl/cblas/owl_cblas_basic.cppo.ml @@ -856,6 +856,10 @@ let symv let _y = bigarray_start Ctypes_static.Array1 y in let _a = bigarray_start Ctypes_static.Array1 a in match Bigarray.Array1.kind x with +#if OCAML_VERSION >= (5, 2, 0) + | Bigarray.Float16 -> + C.ssymv ~order:_layout ~uplo:_uplo ~n ~alpha ~a:_a ~lda ~x:_x ~incx ~beta ~y:_y ~incy +#endif | Bigarray.Float32 -> C.ssymv ~order:_layout ~uplo:_uplo ~n ~alpha ~a:_a ~lda ~x:_x ~incx ~beta ~y:_y ~incy | Bigarray.Float64 -> @@ -887,6 +891,22 @@ let sbmv let _y = bigarray_start Ctypes_static.Array1 y in let _a = bigarray_start Ctypes_static.Array1 a in match Bigarray.Array1.kind x with +#if OCAML_VERSION >= (5, 2, 0) + | Bigarray.Float16 -> + C.ssbmv + ~order:_layout + ~uplo:_uplo + ~n + ~k + ~alpha + ~a:_a + ~lda + ~x:_x + ~incx + ~beta + ~y:_y + ~incy +#endif | Bigarray.Float32 -> C.ssbmv ~order:_layout @@ -941,6 +961,10 @@ let spmv let _y = bigarray_start Ctypes_static.Array1 y in let _ap = bigarray_start Ctypes_static.Array1 ap in match Bigarray.Array1.kind x with +#if OCAML_VERSION >= (5, 2, 0) + | Bigarray.Float16 -> + C.sspmv ~order:_layout ~uplo:_uplo ~n ~alpha ~ap:_ap ~x:_x ~incx ~beta ~y:_y ~incy +#endif | Bigarray.Float32 -> C.sspmv ~order:_layout ~uplo:_uplo ~n ~alpha ~ap:_ap ~x:_x ~incx ~beta ~y:_y ~incy | Bigarray.Float64 -> @@ -1049,6 +1073,9 @@ let syr let _x = bigarray_start Ctypes_static.Array1 x in let _a = bigarray_start Ctypes_static.Array1 a in match Bigarray.Array1.kind x with +#if OCAML_VERSION >= (5, 2, 0) + | Bigarray.Float16 -> C.ssyr ~order:_layout ~uplo:_uplo ~n ~alpha ~x:_x ~incx ~a:_a ~lda +#endif | Bigarray.Float32 -> C.ssyr ~order:_layout ~uplo:_uplo ~n ~alpha ~x:_x ~incx ~a:_a ~lda | Bigarray.Float64 -> C.dsyr ~order:_layout ~uplo:_uplo ~n ~alpha ~x:_x ~incx ~a:_a ~lda @@ -1072,6 +1099,9 @@ let spr let _x = bigarray_start Ctypes_static.Array1 x in let _ap = bigarray_start Ctypes_static.Array1 ap in match Bigarray.Array1.kind x with +#if OCAML_VERSION >= (5, 2, 0) + | Bigarray.Float16 -> C.sspr ~order:_layout ~uplo:_uplo ~n ~alpha ~x:_x ~incx ~ap:_ap +#endif | Bigarray.Float32 -> C.sspr ~order:_layout ~uplo:_uplo ~n ~alpha ~x:_x ~incx ~ap:_ap | Bigarray.Float64 -> C.dspr ~order:_layout ~uplo:_uplo ~n ~alpha ~x:_x ~incx ~ap:_ap @@ -1099,6 +1129,10 @@ let syr2 let _y = bigarray_start Ctypes_static.Array1 y in let _a = bigarray_start Ctypes_static.Array1 a in match Bigarray.Array1.kind x with +#if OCAML_VERSION >= (5, 2, 0) + | Bigarray.Float16 -> + C.ssyr2 ~order:_layout ~uplo:_uplo ~n ~alpha ~x:_x ~incx ~y:_y ~incy ~a:_a ~lda +#endif | Bigarray.Float32 -> C.ssyr2 ~order:_layout ~uplo:_uplo ~n ~alpha ~x:_x ~incx ~y:_y ~incy ~a:_a ~lda | Bigarray.Float64 -> @@ -1127,6 +1161,10 @@ let spr2 let _y = bigarray_start Ctypes_static.Array1 y in let _a = bigarray_start Ctypes_static.Array1 a in match Bigarray.Array1.kind x with +#if OCAML_VERSION >= (5, 2, 0) + | Bigarray.Float16 -> + C.sspr2 ~order:_layout ~uplo:_uplo ~n ~alpha ~x:_x ~incx ~y:_y ~incy ~a:_a +#endif | Bigarray.Float32 -> C.sspr2 ~order:_layout ~uplo:_uplo ~n ~alpha ~x:_x ~incx ~y:_y ~incy ~a:_a | Bigarray.Float64 -> diff --git a/src/owl/dune b/src/owl/dune index 0dcc3c817..b601df419 100644 --- a/src/owl/dune +++ b/src/owl/dune @@ -38,6 +38,16 @@ (copy_files# signal/*) +(rule + (targets owl_cblas_basic.ml) + (deps (:out owl_cblas_basic.cppo.ml)) + (action (run cppo -V OCAML:%{ocaml_version} %{out} -o %{targets}))) + +(rule + (targets owl_lapacke.ml) + (deps (:out owl_lapacke.cppo.ml)) + (action (run cppo -V OCAML:%{ocaml_version} %{out} -o %{targets}))) + (library (name owl) (public_name owl) diff --git a/src/owl/lapacke/owl_lapacke.ml b/src/owl/lapacke/owl_lapacke.cppo.ml similarity index 98% rename from src/owl/lapacke/owl_lapacke.ml rename to src/owl/lapacke/owl_lapacke.cppo.ml index c6bb56913..82ac2df4f 100644 --- a/src/owl/lapacke/owl_lapacke.ml +++ b/src/owl/lapacke/owl_lapacke.cppo.ml @@ -101,6 +101,9 @@ let _gbtrf let ldab = _stride ab in let ret = match _kind with +#if OCAML_VERSION >= (5, 2, 0) + | Float16 -> L.sgbtrf ~layout ~m ~n ~kl ~ku ~ab:_ab ~ldab ~ipiv:_ipiv +#endif | Float32 -> L.sgbtrf ~layout ~m ~n ~kl ~ku ~ab:_ab ~ldab ~ipiv:_ipiv | Float64 -> L.dgbtrf ~layout ~m ~n ~kl ~ku ~ab:_ab ~ldab ~ipiv:_ipiv | Complex32 -> L.cgbtrf ~layout ~m ~n ~kl ~ku ~ab:_ab ~ldab ~ipiv:_ipiv @@ -557,6 +560,9 @@ let ormrz let ldc = Stdlib.max 1 (_stride c) in let ret = match _kind with +#if OCAML_VERSION >= (5, 2, 0) + | Float16 -> L.sormrz ~layout ~side ~trans ~m ~n ~k ~l ~a:_a ~lda ~tau:_tau ~c:_c ~ldc +#endif | Float32 -> L.sormrz ~layout ~side ~trans ~m ~n ~k ~l ~a:_a ~lda ~tau:_tau ~c:_c ~ldc | Float64 -> L.dormrz ~layout ~side ~trans ~m ~n ~k ~l ~a:_a ~lda ~tau:_tau ~c:_c ~ldc in @@ -2132,6 +2138,9 @@ let orglq : type a. ?k:int -> a:(float, a) t -> tau:(float, a) t -> (float, a) t let _tau = bigarray_start Ctypes_static.Genarray tau in let ret = match _kind with +#if OCAML_VERSION >= (5, 2, 0) + | Float16 -> L.sorglq ~layout ~m:minmn ~n ~k ~a:_a ~lda ~tau:_tau +#endif | Float32 -> L.sorglq ~layout ~m:minmn ~n ~k ~a:_a ~lda ~tau:_tau | Float64 -> L.dorglq ~layout ~m:minmn ~n ~k ~a:_a ~lda ~tau:_tau in @@ -2196,6 +2205,9 @@ let orgqr : type a. ?k:int -> a:(float, a) t -> tau:(float, a) t -> (float, a) t let _tau = bigarray_start Ctypes_static.Genarray tau in let ret = match _kind with +#if OCAML_VERSION >= (5, 2, 0) + | Float16 -> L.sorgqr ~layout ~m ~n:minmn ~k ~a:_a ~lda ~tau:_tau +#endif | Float32 -> L.sorgqr ~layout ~m ~n:minmn ~k ~a:_a ~lda ~tau:_tau | Float64 -> L.dorgqr ~layout ~m ~n:minmn ~k ~a:_a ~lda ~tau:_tau in @@ -2260,6 +2272,9 @@ let orgql : type a. ?k:int -> a:(float, a) t -> tau:(float, a) t -> (float, a) t let _tau = bigarray_start Ctypes_static.Genarray tau in let ret = match _kind with +#if OCAML_VERSION >= (5, 2, 0) + | Float16 -> L.sorgql ~layout ~m ~n:minmn ~k ~a:_a ~lda ~tau:_tau +#endif | Float32 -> L.sorgql ~layout ~m ~n:minmn ~k ~a:_a ~lda ~tau:_tau | Float64 -> L.dorgql ~layout ~m ~n:minmn ~k ~a:_a ~lda ~tau:_tau in @@ -2291,6 +2306,9 @@ let orgrq : type a. ?k:int -> a:(float, a) t -> tau:(float, a) t -> (float, a) t let _tau = bigarray_start Ctypes_static.Genarray tau in let ret = match _kind with +#if OCAML_VERSION >= (5, 2, 0) + | Float16 -> L.sorgrq ~layout ~m ~n:minmn ~k ~a:_a ~lda ~tau:_tau +#endif | Float32 -> L.sorgrq ~layout ~m ~n:minmn ~k ~a:_a ~lda ~tau:_tau | Float64 -> L.dorgrq ~layout ~m ~n:minmn ~k ~a:_a ~lda ~tau:_tau in @@ -2333,6 +2351,9 @@ let ormlq let _tau = bigarray_start Ctypes_static.Genarray tau in let ret = match _kind with +#if OCAML_VERSION >= (5, 2, 0) + | Float16 -> L.sormlq ~layout ~side ~trans ~m ~n ~k ~a:_a ~lda ~tau:_tau ~c:_c ~ldc +#endif | Float32 -> L.sormlq ~layout ~side ~trans ~m ~n ~k ~a:_a ~lda ~tau:_tau ~c:_c ~ldc | Float64 -> L.dormlq ~layout ~side ~trans ~m ~n ~k ~a:_a ~lda ~tau:_tau ~c:_c ~ldc in @@ -2371,6 +2392,9 @@ let ormqr let _tau = bigarray_start Ctypes_static.Genarray tau in let ret = match _kind with +#if OCAML_VERSION >= (5, 2, 0) + | Float16 -> L.sormqr ~layout ~side ~trans ~m ~n ~k ~a:_a ~lda ~tau:_tau ~c:_c ~ldc +#endif | Float32 -> L.sormqr ~layout ~side ~trans ~m ~n ~k ~a:_a ~lda ~tau:_tau ~c:_c ~ldc | Float64 -> L.dormqr ~layout ~side ~trans ~m ~n ~k ~a:_a ~lda ~tau:_tau ~c:_c ~ldc in @@ -2409,6 +2433,9 @@ let ormql let _tau = bigarray_start Ctypes_static.Genarray tau in let ret = match _kind with +#if OCAML_VERSION >= (5, 2, 0) + | Float16 -> L.sormql ~layout ~side ~trans ~m ~n ~k ~a:_a ~lda ~tau:_tau ~c:_c ~ldc +#endif | Float32 -> L.sormql ~layout ~side ~trans ~m ~n ~k ~a:_a ~lda ~tau:_tau ~c:_c ~ldc | Float64 -> L.dormql ~layout ~side ~trans ~m ~n ~k ~a:_a ~lda ~tau:_tau ~c:_c ~ldc in @@ -2447,6 +2474,9 @@ let ormrq let _tau = bigarray_start Ctypes_static.Genarray tau in let ret = match _kind with +#if OCAML_VERSION >= (5, 2, 0) + | Float16 -> L.sormrq ~layout ~side ~trans ~m ~n ~k ~a:_a ~lda ~tau:_tau ~c:_c ~ldc +#endif | Float32 -> L.sormrq ~layout ~side ~trans ~m ~n ~k ~a:_a ~lda ~tau:_tau ~c:_c ~ldc | Float64 -> L.dormrq ~layout ~side ~trans ~m ~n ~k ~a:_a ~lda ~tau:_tau ~c:_c ~ldc in @@ -3127,6 +3157,9 @@ let stev let _z = bigarray_start Ctypes_static.Genarray z in let ret = match _kind with +#if OCAML_VERSION >= (5, 2, 0) + | Float16 -> L.sstev ~layout ~jobz ~n ~d:_d ~e:_e ~z:_z ~ldz +#endif | Float32 -> L.sstev ~layout ~jobz ~n ~d:_d ~e:_e ~z:_z ~ldz | Float64 -> L.dstev ~layout ~jobz ~n ~d:_d ~e:_e ~z:_z ~ldz in @@ -3167,6 +3200,25 @@ let stebz let _isplit = bigarray_start Ctypes_static.Genarray isplit in let ret = match _kind with +#if OCAML_VERSION >= (5, 2, 0) + | Float16 -> + L.sstebz + ~range + ~order + ~n + ~vl + ~vu + ~il + ~iu + ~abstol + ~d:_d + ~e:_e + ~m:_m + ~nsplit:_nsplit + ~w:_w + ~iblock:_iblock + ~isplit:_isplit +#endif | Float32 -> L.sstebz ~range @@ -3784,6 +3836,9 @@ let syev : type a. jobz:char -> uplo:char -> a:(float, a) t -> (float, a) t * (f let lda = Stdlib.max 1 (_stride a) in let ret = match _kind with +#if OCAML_VERSION >= (5, 2, 0) + | Float16 -> L.ssyev ~layout ~jobz ~uplo ~n ~a:_a ~lda ~w:_w +#endif | Float32 -> L.ssyev ~layout ~jobz ~uplo ~n ~a:_a ~lda ~w:_w | Float64 -> L.dsyev ~layout ~jobz ~uplo ~n ~a:_a ~lda ~w:_w in @@ -3838,6 +3893,27 @@ let syevr let _isuppz = bigarray_start Ctypes_static.Genarray isuppz in let ret = match _kind with +#if OCAML_VERSION >= (5, 2, 0) + | Float16 -> + L.ssyevr + ~layout + ~jobz + ~range + ~uplo + ~n + ~a:_a + ~lda + ~vl + ~vu + ~il + ~iu + ~abstol + ~m:_m + ~w:_w + ~z:_z + ~ldz + ~isuppz:_isuppz +#endif | Float32 -> L.ssyevr ~layout @@ -3914,6 +3990,9 @@ let sygvd let _b = bigarray_start Ctypes_static.Genarray b in let ret = match _kind with +#if OCAML_VERSION >= (5, 2, 0) + | Float16 -> L.ssygvd ~layout ~ityp ~jobz ~uplo ~n ~a:_a ~lda ~b:_b ~ldb ~w:_w +#endif | Float32 -> L.ssygvd ~layout ~ityp ~jobz ~uplo ~n ~a:_a ~lda ~b:_b ~ldb ~w:_w | Float64 -> L.dsygvd ~layout ~ityp ~jobz ~uplo ~n ~a:_a ~lda ~b:_b ~ldb ~w:_w in @@ -4091,6 +4170,10 @@ let bdsdc let _iq = bigarray_start Ctypes_static.Genarray iq in let ret = match _kind with +#if OCAML_VERSION >= (5, 2, 0) + | Float16 -> + L.sbdsdc ~layout ~uplo ~compq ~n ~d:_d ~e:_e ~u:_u ~ldu ~vt:_vt ~ldvt ~q:_q ~iq:_iq +#endif | Float32 -> L.sbdsdc ~layout ~uplo ~compq ~n ~d:_d ~e:_e ~u:_u ~ldu ~vt:_vt ~ldvt ~q:_q ~iq:_iq | Float64 -> @@ -4181,6 +4264,9 @@ let orghr let lda = Stdlib.max 1 (_stride a) in let ret = match _kind with +#if OCAML_VERSION >= (5, 2, 0) + | Float16 -> L.sorghr ~layout ~n ~ilo ~ihi ~a:_a ~lda ~tau:_tau +#endif | Float32 -> L.sorghr ~layout ~n ~ilo ~ihi ~a:_a ~lda ~tau:_tau | Float64 -> L.dorghr ~layout ~n ~ilo ~ihi ~a:_a ~lda ~tau:_tau in