Skip to content

Commit

Permalink
bug fix, fuller coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
gabgoh committed Aug 26, 2016
1 parent 55112d0 commit 4c470e0
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 20 deletions.
16 changes: 8 additions & 8 deletions src/SymWoodburyMatrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,9 @@ function plusBDBtx!(o, B::Array{Float64,2}, D, x::Array{Float64,2})
end

# Minor optimization for the rank one case
function plusBDBtx!(o, B::Array{Float64,1}, d::Real, x::Array{Float64,2})
if size(x,1) == 1
axpy!(dot(length(x),B,1,x,1)*d, B, o)
function plusBDBtx!(o, B::Array{Float64,1}, d::Real, x::Union{Array{Float64,2}, SubArray})
if size(x,2) == 1
axpy!(vecdot(B,x)*d, B, o)
else
w = d*gemm('T', 'N' ,reshape(B, size(B,1), 1),x);
gemm!('N','N',1.,B,w,1., o)
Expand All @@ -144,12 +144,12 @@ Base.full{T}(O::SymWoodbury{T}) = full(O.A) + O.B*O.D*O.B'
Base.copy{T}(O::SymWoodbury{T}) = SymWoodbury(copy(O.A), copy(O.B), copy(O.D))

function square(O::SymWoodbury)
A = O.A^2;
AB = O.A*O.B;
Z = [(AB + O.B) (AB - O.B)];
R = O.D*(O.B'*O.B)*O.D/4;
A = O.A^2
AB = O.A*O.B
Z = [(AB + O.B) (AB - O.B)]
R = O.D*(O.B'*O.B)*O.D/4
D = [ O.D/2 + R -R
-R -O.D/2 + R ];
-R -O.D/2 + R ]
SymWoodbury(A, Z, D)
end

Expand Down
29 changes: 17 additions & 12 deletions test/runtests_sym.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Base.Test
using WoodburyMatrices
using Compat:view

srand(123)
n = 5
Expand Down Expand Up @@ -44,19 +45,22 @@ for elty in (Float32, Float64, Complex64, Complex128, Int), AMat in (diagm,)
Z = randn(n,n)
@test_approx_eq full(W*Z) full(W)*Z

v = rand(n, 1)

R = rand(n,n)
@test_approx_eq (2*W)*v 2*(W*v)
@test_approx_eq (W*2)*v 2*(W*v)
@test_approx_eq (W'W)*v full(W)*(full(W)*v)
@test_approx_eq (W*W)*v full(W)*(full(W)*v)
@test_approx_eq (W*W')*v full(W)*(full(W)*v)
@test_approx_eq W[1:3,1:3]*v[1:3] full(W)[1:3,1:3]*v[1:3]
@test_approx_eq full(WoodburyMatrices.conjm(W, R)) R*full(W)*R'
@test_approx_eq full((copy(W)'W)*v) full(W)*(full(W)*v)
@test_approx_eq full(W + A) full(W)+full(A)
@test_approx_eq full(A + W) full(W)+full(A)

for v = (rand(n, 1), view(rand(n,1), 1:n), view(rand(n,2),1:n,1:2))
@test_approx_eq (2*W)*v 2*(W*v)
@test_approx_eq (W*2)*v 2*(W*v)
@test_approx_eq (W'W)*v full(W)*(full(W)*v)
@test_approx_eq (W*W)*v full(W)*(full(W)*v)
@test_approx_eq (W*W')*v full(W)*(full(W)*v)
@test_approx_eq W[1:3,1:3]*v[1:3] full(W)[1:3,1:3]*v[1:3]
@test_approx_eq full(WoodburyMatrices.conjm(W, R)) R*full(W)*R'
@test_approx_eq full((copy(W)'W)*v) full(W)*(full(W)*v)
@test_approx_eq full(W + A) full(W)+full(A)
@test_approx_eq full(A + W) full(W)+full(A)
end

v = rand(n,1)

W2 = convert(Woodbury, W)
@test_approx_eq full(W2) full(W)
Expand Down Expand Up @@ -154,3 +158,4 @@ V = randn(n,1)
# Mismatched sizes
@test_throws DimensionMismatch SymWoodbury(rand(5,5),rand(5,2),rand(2,3))
@test_throws DimensionMismatch SymWoodbury(rand(5,5),rand(5,2),rand(3,3))
@test_throws DimensionMismatch SymWoodbury(rand(5,5),rand(3),1.)

0 comments on commit 4c470e0

Please sign in to comment.