Skip to content

Commit

Permalink
added her
Browse files Browse the repository at this point in the history
  • Loading branch information
ranocha committed Jun 26, 2018
1 parent 072e6f9 commit 05e1714
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/L2/L2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ include("sbmv.jl")
include("trmv.jl")
include("trsv.jl")
include("ger.jl")
include("her.jl")
2 changes: 1 addition & 1 deletion src/L2/ger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ for (func, elty) in [(:CLBlastSger, Float32), (:CLBlastDger, Float64),

@eval function ger!::Number, x::cl.CLArray{$elty}, y::cl.CLArray{$elty},
A::cl.CLArray{$elty,2};
queue::cl.CmdQueue=cl.queue(y))
queue::cl.CmdQueue=cl.queue(A))
# check and convert arguments
m, n = size(A)
if length(x) != m || length(y) != n
Expand Down
4 changes: 2 additions & 2 deletions src/L2/hemv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ for (func, elty) in [(:CLBlastChemv, Complex64), (:CLBlastZhemv, Complex128)]
# check and convert arguments
m, n = size(A)
if m != n
throw(DimensionMismatch("A has dimensions $(size(A)) but must be square."))
throw(DimensionMismatch("`A` has dimensions $(size(A)) but must be square."))
end
if length(x) != n || length(y) != n
throw(DimensionMismatch("A has dimensions $(size(A)), x has length $(length(x)) and y has length $(length(y))."))
throw(DimensionMismatch("`A` has dimensions $(size(A)), `x` has length $(length(x)) and `y` has length $(length(y))."))
end
if uplo == 'U'
triangle = CLBlastTriangleUpper
Expand Down
62 changes: 62 additions & 0 deletions src/L2/her.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@

for (func, elty, relty) in [(:CLBlastCher, Complex64, Float32), (:CLBlastZher, Complex128, Float64)]

@eval function $func(layout::CLBlastLayout, triangle::CLBlastTriangle,
n::Integer,
alpha::$relty,
x_buffer::cl.CL_mem, x_offset::Integer, x_inc::Integer,
a_buffer::cl.CL_mem, a_offset::Integer, a_ld::Integer,
queue::cl.CmdQueue, event::cl.Event)
err = ccall(
($(string(func)), libCLBlast),
cl.CL_int,
(Cint, Cint, Csize_t, $relty, Ptr{Void}, Csize_t, Csize_t, Ptr{Void}, Csize_t, Csize_t,
Ptr{Void}, Ptr{Void}),
Cint(layout), Cint(triangle), n, alpha, x_buffer, x_offset, x_inc,
a_buffer, a_offset, a_ld, Ref(queue), Ref(event)
)
if err != cl.CL_SUCCESS
println(STDERR, "Calling function $(string($func)) failed!")
throw(CLBlastError(err))
end
return err
end

@eval function her!(uplo::Char, α::Number, x::cl.CLArray{$elty},
A::cl.CLArray{$elty,2};
queue::cl.CmdQueue=cl.queue(A))
# check and convert arguments
m, n = size(A)
if m != n
throw(DimensionMismatch("`A` has dimensions $(size(A)) but must be square."))
end
if length(x) != n
throw(DimensionMismatch("`A` has dimensions $(size(A)) and `x` has length $(length(x))."))
end
if uplo == 'U'
triangle = CLBlastTriangleUpper
elseif uplo == 'L'
triangle = CLBlastTriangleLower
else
throw(ArgumentError("Upper/lower marker `uplo` is $(uplo) but only 'U' and 'L' are allowed."))
end
alpha = convert($relty, α)
layout = CLBlastLayoutColMajor

# output event
event = cl.Event(C_NULL)

$func(layout, triangle,
n,
alpha,
pointer(x), 0, 1,
pointer(A), 0, size(A,1),
queue, event)

# wait for kernel
cl.wait(event)

nothing
end

end
25 changes: 25 additions & 0 deletions test/L2_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -300,3 +300,28 @@ end
@test_throws DimensionMismatch CLBlast.ger!(α, y_cl, x_cl, A_cl, queue=queue)
end
end

@testset "her!" begin
for elty in elty_L1
elty <: Complex || continue

A = rand(elty, n_L2, n_L2)
A_cl = cl.CLArray(queue, A)
x = rand(elty, n_L2)
x_cl = cl.CLArray(queue, x)
α = real(rand(elty))

is_linux() && elty == Complex64 && continue

for uplo in ['U', 'L']
CLBlast.her!(uplo, α, x_cl, A_cl, queue=queue)
LinAlg.BLAS.her!(uplo, α, x, A)
@test cl.to_host(A_cl, queue=queue) A
@test cl.to_host(x_cl, queue=queue) x
end

y = rand(elty, m_L2)
y_cl = cl.CLArray(queue, y)
@test_throws DimensionMismatch CLBlast.her!('U', α, y_cl, A_cl, queue=queue)
end
end

0 comments on commit 05e1714

Please sign in to comment.