Skip to content

Commit

Permalink
Support non-Diagonal diagonal-like types in fmul_shared!
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed Feb 11, 2019
1 parent 3ec0493 commit d45035a
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 19 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SIMD = "2.3"

[extras]
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["Parameters", "Test"]
2 changes: 2 additions & 0 deletions src/basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ simdable(T::Type{<:SparseMatrixCSC{Tv,Ti}}) where {Tv,Ti} = allsimdable(Tv, Ti)
allsimdable() = true
allsimdable(x, args...) = simdable(x) && allsimdable(args...)

const DiagonalLike = Union{Diagonal,UniformScaling,Number}

asdiag(A::Diagonal, n) = A.diag
asdiag(A::UniformScaling, n) = asdiag(A.λ, n)
asdiag(a::Number, n) = Fill(a, n, n)
Expand Down
47 changes: 44 additions & 3 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,11 @@ julia> fmul_shared!((Y1, Y2), (D1, S1'), (D2, S2'), X1) === (Y1, Y2)
true
```
"""
@inline function fmul_shared!(Yβ_, rhs...)
@inline function fmul_shared!(Yβ_, rhs_...)
check_fmul_shared_args(Yβ_, rhs_)
= canonicalize_Yβ(Yβ_)
if length(rhs) == 0
elseif is_shared_simd3(rhs)
rhs = canonicalize_rhs(rhs_)
if is_shared_simd3(rhs)
return fmul_shared_simd3!(Yβ, rhs...)
elseif is_shared_simd2(rhs)
return _fmul_shared_simd!(Val(4), Yβ, butlast(rhs), rhs[end])
Expand All @@ -219,6 +220,46 @@ end
# It would be nice to have some computation graph exectuor on top of
# fmul*!, but it can be done later.

@inline function check_fmul_shared_args(Yβ, rhs)
if length(rhs) == 0
throw(ArgumentError("fmul_shared! needs one or more `rhs` arguments"))
end

# Fan-out case
ifisa Union{Tuple{Vararg{AbstractMatrix}},
Tuple{Vararg{Tuple{AbstractMatrix,Number}}}}
if rhs[end] isa AbstractVecOrMat
if length(Yβ) != length(rhs) - 1
throw(ArgumentError("""
Detected call signature:
fmul_shared!((Yβ1, ..., Yβm), (D1, S1'), ..., (Dn, Sn'), X)
with `m = $(length(Yβ))` and `n = $(length(rhs) - 1)`. Note that `m` and
`n` must match."""))
end
else
if length(Yβ) != length(rhs)
throw(ArgumentError("""
Detected call signature:
fmul_shared!((Yβ1, ..., Yβm), (D1, S1', X1), ..., (Dn, Sn', Xn))
with `m = $(length(Yβ))` and `n = $(length(rhs))`. Note that `m` and
`n` must match."""))
end
end
end

# TODO: check array size

return
end

@inline canonicalize_rhs(rhs) = map(canonicalize_term, rhs)

@inline canonicalize_term(term) = term
@inline function canonicalize_term(DSX::Tuple{DiagonalLike,Any,Vararg})
D, S = DSX
return (Diagonal(asdiag(D, size(S, 1))), Base.tail(DSX)...)
end

@inline canonicalize_Yβ(Yβ::Tuple{AbstractMatrix,Number}) =
@inline canonicalize_Yβ(Y::AbstractMatrix) = (Y, false)

Expand Down
1 change: 1 addition & 0 deletions test/preamble.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ using Test
using SparseArrays
using LinearAlgebra
using Random
using Parameters
using SparseXX
66 changes: 51 additions & 15 deletions test/test_matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,40 +42,80 @@ end
@test C D * C0
end

@testset "fmul_shared!" begin
fmul_shared_test_params = let params = []
m = 10
n = 3
D1 = Diagonal(randn(m))
D2 = Diagonal(randn(m))
D3 = Diagonal(randn(m))

S1 = sprandn(m, m, 0.3)
S2 = spshared(S1)
S3 = spshared(S1)
randn!(nonzeros(S2))
randn!(nonzeros(S3))
push!(params, (
label = "default",
D1 = Diagonal(randn(m)),
D2 = Diagonal(randn(m)),
D3 = Diagonal(randn(m)),
S1 = S1,
S2 = S2,
S3 = S3,
X1 = randn(m, n),
X2 = randn(m, n),
X3 = randn(m, n),
))

S1 = sprandn(m, m, 0.3)
S2 = spshared(S1)
S3 = spshared(S1)
randn!(nonzeros(S2))
randn!(nonzeros(S3))
X1 = randn(m, n)
X2 = randn(m, n)
X3 = randn(m, n)
push!(params, (
label = "mixed D type",
D1 = randn(),
D2 = randn() * I,
D3 = Diagonal(randn(m)),
S1 = S1,
S2 = S2,
S3 = S3,
X1 = randn(m, n),
X2 = randn(m, n),
X3 = randn(m, n),
))

params
end

@testset "is_shared_simd" begin
@unpack D1, D2, D3, S1, S2, S3, X1, X2, X3 = fmul_shared_test_params[1]

@test SparseXX.is_shared_simd3(((D1, S1', X1), (D2, S2', X2)))
@test SparseXX.is_shared_simd2(((D1, S1'), (D2, S2'), X1))
@test SparseXX.is_shared_simd3(((D1, S1', X1),
(D2, S2', X2),
(D3, S3', X3)))
@test SparseXX.is_shared_simd2(((D1, S1'),
(D2, S2'),
(D3, S3'),
X1))
end

@testset "fmul_shared! $(p.label)" for p in fmul_shared_test_params
@unpack D1, D2, D3, S1, S2, S3, X1, X2, X3 = p

Y = fmul_shared!(zero(X1), (D1, S1', X1), (D2, S2', X2))
@test Y D1 * S1' * X1 + D2 * S2' * X2

Y1, Y2 = fmul_shared!((zero(X1), zero(X1)), (D1, S1', X1), (D2, S2', X2))
@test Y1 D1 * S1' * X1
@test Y2 D2 * S2' * X2

@test SparseXX.is_shared_simd2(((D1, S1'), (D2, S2'), X1))
Y = fmul_shared!(zero(X1), (D1, S1'), (D2, S2'), X1)
@test Y (D1 * S1' + D2 * S2') * X1

Y1, Y2 = fmul_shared!((zero(X1), zero(X1)), (D1, S1'), (D2, S2'), X1)
@test Y1 D1 * S1' * X1
@test Y2 D2 * S2' * X1

@test SparseXX.is_shared_simd3(((D1, S1', X1),
(D2, S2', X2),
(D3, S3', X3)))
Y = fmul_shared!(zero(X1),
(D1, S1', X1),
(D2, S2', X2),
Expand All @@ -90,10 +130,6 @@ end
@test Y2 D2 * S2' * X2
@test Y3 D3 * S3' * X3

@test SparseXX.is_shared_simd2(((D1, S1'),
(D2, S2'),
(D3, S3'),
X1))
Y = fmul_shared!(zero(X1),
(D1, S1'),
(D2, S2'),
Expand Down

0 comments on commit d45035a

Please sign in to comment.