Skip to content

Commit

Permalink
added gemv!
Browse files Browse the repository at this point in the history
  • Loading branch information
ranocha committed Jun 25, 2018
1 parent 30cf7eb commit ecb9763
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/L2/L2.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include("gemv.jl")
138 changes: 138 additions & 0 deletions src/L2/gemv.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@

for (func, elty) in [(:CLBlastSgemv, Float32), (:CLBlastDgemv, Float64),
(:CLBlastCgemv, Complex64), (:CLBlastZgemv, Complex128)]
#TODO: (:CLBlastHgemv, Float16)

# Low-level binding for the CLBlast GEMV routine named by `func`, generated once per
# element type `elty` by the enclosing loop.
#
# All arguments are forwarded to the C symbol of the same name in `libCLBlast`:
#   layout, a_transpose         -- enums, converted to `Cint`
#   m, n                        -- matrix dimensions, passed as `Csize_t`
#   alpha, beta                 -- scalars of type `elty`, passed by value
#   {a,x,y}_buffer              -- `cl.CL_mem` handles, passed as `Ptr{Void}`
#   {a,x,y}_offset, a_ld, *_inc -- offsets / leading dimension / increments (`Csize_t`)
#   queue, event                -- wrapped in `Ref` and passed as `Ptr{Void}`
#
# Returns the status code on success. On any other status a message is printed to
# `STDERR` and a `CLBlastError` carrying the status is thrown.
@eval function $func(layout::CLBlastLayout, a_transpose::CLBlastTranspose,
                     m::Integer, n::Integer,
                     alpha::$elty,
                     a_buffer::cl.CL_mem, a_offset::Integer, a_ld::Integer,
                     x_buffer::cl.CL_mem, x_offset::Integer, x_inc::Integer,
                     beta::$elty,
                     y_buffer::cl.CL_mem, y_offset::Integer, y_inc::Integer,
                     queue::cl.CmdQueue, event::cl.Event)
    status = ccall(
        ($(string(func)), libCLBlast),
        cl.CL_int,
        (Cint, Cint,                   # layout, a_transpose
         Csize_t, Csize_t,             # m, n
         $elty,                        # alpha
         Ptr{Void}, Csize_t, Csize_t,  # a_buffer, a_offset, a_ld
         Ptr{Void}, Csize_t, Csize_t,  # x_buffer, x_offset, x_inc
         $elty,                        # beta
         Ptr{Void}, Csize_t, Csize_t,  # y_buffer, y_offset, y_inc
         Ptr{Void}, Ptr{Void}),        # queue, event
        Cint(layout), Cint(a_transpose),
        m, n,
        alpha,
        a_buffer, a_offset, a_ld,
        x_buffer, x_offset, x_inc,
        beta,
        y_buffer, y_offset, y_inc,
        Ref(queue), Ref(event)
    )

    # Success path first; everything below is error reporting.
    status == cl.CL_SUCCESS && return status

    println(STDERR, "Calling function $(string($func)) failed!")
    throw(CLBlastError(status))
end

# High-level matrix-vector product on OpenCL arrays, dispatching to the generated
# low-level wrapper `func` for element type `elty`. The call signature mirrors
# `LinAlg.BLAS.gemv!(tA, α, A, x, β, y)`; `y` is updated on the device.
#
# Arguments
#   tA     -- 'N' (no transpose), 'T' (transpose) or 'C' (conjugate transpose; for
#             real element types this is treated exactly like 'T')
#   α, β   -- scalars, converted to `elty` before the call
#   A      -- `cl.CLArray{elty,2}` of size (m, n), passed as column-major with
#             leading dimension m
#   x, y   -- `cl.CLArray{elty}` vectors; lengths must be (n, m) for 'N' and
#             (m, n) for 'T' / 'C'
#   queue  -- keyword, command queue to run on (defaults to `cl.queue(y)`)
#
# Blocks until the event filled in by the CLBlast call has completed and returns
# `nothing`.
#
# Throws `DimensionMismatch` for inconsistent sizes, `ArgumentError` for an unknown
# `tA`, and whatever the low-level wrapper throws when the library call fails.
@eval function gemv!(tA::Char, α::Number, A::cl.CLArray{$elty,2},
                     x::cl.CLArray{$elty}, β::Number, y::cl.CLArray{$elty};
                     queue::cl.CmdQueue=cl.queue(y))
    # check and convert arguments
    m, n = size(A)
    if tA == 'N'
        a_transpose = CLBlastTransposeNo
        if length(x) != n || length(y) != m
            throw(DimensionMismatch("A has dimensions $(size(A)), x has length $(length(x)) and y has length $(length(y))."))
        end
    elseif tA == 'T' || (tA == 'C' && $elty <: Real)
        # For real element types the conjugate transpose equals the plain transpose.
        a_transpose = CLBlastTransposeYes
        if length(x) != m || length(y) != n
            throw(DimensionMismatch("The transpose of A has dimensions $n, $m, x has length $(length(x)) and y has length $(length(y))."))
        end
    elseif tA == 'C' && $elty <: Complex
        a_transpose = CLBlastTransposeConjugate
        if length(x) != m || length(y) != n
            throw(DimensionMismatch("The adjoint of A has dimensions $n, $m, x has length $(length(x)) and y has length $(length(y))."))
        end
    else
        throw(ArgumentError("Transpose marker `tA` is $(tA) but only 'N', 'T', and 'C' are allowed."))
    end
    alpha = convert($elty, α)
    beta = convert($elty, β)
    layout = CLBlastLayoutColMajor

    # output event; created empty and handed to the library call by reference
    event = cl.Event(C_NULL)

    # zero offsets, unit increments, leading dimension of A equal to its row count
    $func(layout, a_transpose,
          m, n,
          alpha,
          pointer(A), 0, m,
          pointer(x), 0, 1,
          beta,
          pointer(y), 0, 1,
          queue, event)

    # wait for kernel
    cl.wait(event)

    nothing
end

end
#=
function CLBlastSgemv(layout::CLBlastLayout, a_transpose::CLBlastTranspose,
m::Integer, n::Integer,
alpha::Float32,
a_buffer::cl.CL_mem, a_offset::Integer, a_ld::Integer,
x_buffer::cl.CL_mem, x_offset::Integer, x_inc::Integer,
beta::Float32,
y_buffer::cl.CL_mem, y_offset::Integer, y_inc::Integer,
queue::cl.CmdQueue, event::cl.Event)
err = ccall(
("CLBlastSgemv", libCLBlast),
cl.CL_int,
(Cint, Cint, Csize_t, Csize_t, Float32, Ptr{Void}, Csize_t, Csize_t, Ptr{Void}, Csize_t, Csize_t,
Float32, Ptr{Void}, Csize_t, Csize_t, Ptr{Void}, Ptr{Void}),
Cint(layout), Cint(a_transpose), m, n, alpha, a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc,
beta, y_buffer, y_offset, y_inc, Ref(queue), Ref(event)
)
if err != cl.CL_SUCCESS
println(STDERR, "Calling function CLBlastSgemv failed!")
throw(cl.CLError(err))
end
return err
end
function gemv!(tA::Char, α::Number, A::cl.CLArray{Float32,2},
x::cl.CLArray{Float32}, β::Number, y::cl.CLArray{Float32};
queue::cl.CmdQueue=cl.queue(y))
# check and convert arguments
m, n = size(A)
if tA == 'N'
a_transpose = CLBlastTransposeNo
if length(x) != n || length(y) != m
#throw(DimensionMismatch("A has dimensions $(size(A)), x has length $(length(x)) and y has length $(length(y))."))
end
elseif tA == 'T' || tA == 'C'
a_transpose = CLBlastTransposeYes
if length(x) != m || length(y) != n
#throw(DimensionMismatch("The adjoint of A has dimensions $n, $m, x has length $(length(x)) and y has length $(length(y))"))
end
else
#throw(ArgumentError("Transpose marker `tA` is $(tA) but only 'N', 'T', and 'C' are allowed."))
end
alpha = convert(Float32, α)
beta = convert(Float32, β)
layout = CLBlastLayoutRowMajor
# output event
event = cl.Event(C_NULL)
CLBlastSgemv(layout, a_transpose,
m, n,
alpha,
pointer(A), 0, 1,
pointer(x), 0, 1,
beta,
pointer(y), 0, 1,
queue, event)
# wait for kernel
cl.wait(event)
nothing
end
=#
37 changes: 37 additions & 0 deletions test/L2_test.jl
Original file line number Diff line number Diff line change
@@ -1 +1,38 @@
srand(12345)

# Checks CLBlast.gemv! against LinAlg.BLAS.gemv! on the host for every element type
# in `elty_L1`, for the 'N', 'T' and 'C' transpose markers, and verifies that
# mismatched vector lengths raise `DimensionMismatch`.
@testset "gemv!" begin
    for elty in elty_L1
        A = rand(elty, m_L2, n_L2)
        A_cl = cl.CLArray(queue, A)
        x = rand(elty, n_L2)
        x_cl = cl.CLArray(queue, x)
        y = rand(elty, m_L2)
        y_cl = cl.CLArray(queue, y)
        α = rand(elty)
        β = rand(elty)

        # NOTE(review): Complex64 is skipped on Linux; the reason is not visible here.
        is_linux() && elty == Complex64 && continue

        # 'N': x has length n, y has length m; the transposed variants must reject them
        @test_throws DimensionMismatch CLBlast.gemv!('T', α, A_cl, x_cl, β, y_cl, queue=queue)
        @test_throws DimensionMismatch CLBlast.gemv!('C', α, A_cl, x_cl, β, y_cl, queue=queue)
        CLBlast.gemv!('N', α, A_cl, x_cl, β, y_cl, queue=queue)
        LinAlg.BLAS.gemv!('N', α, A, x, β, y)
        @test cl.to_host(A_cl, queue=queue) ≈ A
        @test cl.to_host(x_cl, queue=queue) ≈ x
        @test cl.to_host(y_cl, queue=queue) ≈ y

        # 'T': roles of x and y are swapped (input y, output x)
        @test_throws DimensionMismatch CLBlast.gemv!('N', α, A_cl, y_cl, β, x_cl, queue=queue)
        CLBlast.gemv!('T', α, A_cl, y_cl, β, x_cl, queue=queue)
        LinAlg.BLAS.gemv!('T', α, A, y, β, x)
        @test cl.to_host(A_cl, queue=queue) ≈ A
        @test cl.to_host(x_cl, queue=queue) ≈ x
        @test cl.to_host(y_cl, queue=queue) ≈ y

        # 'C': same shapes as 'T'
        @test_throws DimensionMismatch CLBlast.gemv!('N', α, A_cl, y_cl, β, x_cl, queue=queue)
        CLBlast.gemv!('C', α, A_cl, y_cl, β, x_cl, queue=queue)
        LinAlg.BLAS.gemv!('C', α, A, y, β, x)
        @test cl.to_host(A_cl, queue=queue) ≈ A
        @test cl.to_host(x_cl, queue=queue) ≈ x
        @test cl.to_host(y_cl, queue=queue) ≈ y
    end
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ end
using CLBlast, OpenCL

const n_L1 = 64
const m_L2 = 60
const n_L2 = 50
const elty_L1 = (Float32, Float64, Complex64, Complex128)

device, ctx, queue = cl.create_compute_context()
Expand Down

0 comments on commit ecb9763

Please sign in to comment.