In [3]:
using LinearAlgebra
using Plots
import Base: getproperty, \, show

In [4]:

"""
    LU_Fac{T<:Real}

Container for an LU factorization kept in packed form.

# Fields
- `lu`: a single matrix holding the factorization data.
- `p`: an integer vector stored alongside the factors (used as a row permutation
  by the solvers in this file).
"""
struct LU_Fac{T<:Real}
    lu::Matrix{T}
    p::Vector{Int}
end


"""
    getproperty(F::LU_Fac, d::Symbol)

Property access for `LU_Fac`.

`F.L` wraps the packed matrix as a `UnitLowerTriangular` and `F.U` wraps it as an
`UpperTriangular`; both are wrappers around the stored `lu` matrix. Any other
symbol is forwarded to `getfield`, so the real fields stay reachable.
"""
function getproperty(F::LU_Fac, d::Symbol)
    # Read the field directly instead of going through `getproperty` again.
    packed = getfield(F, :lu)
    d === :L && return UnitLowerTriangular(packed)
    d === :U && return UpperTriangular(packed)
    return getfield(F, d)
end


"""
    propertynames(F::LU_Fac, private::Bool=false)

Return the property names of an `LU_Fac`.

By default only the virtual properties `:L` and `:U` are listed; with
`private = true` the struct's own field names are listed first, followed by
`:L` and `:U`.
"""
# Qualified with `Base.` on purpose: `propertynames` is not part of the file's
# `import Base: ...` list, so an unqualified definition would create a new,
# unrelated function instead of extending the one used by tab completion/reflection.
function Base.propertynames(F::LU_Fac, private::Bool=false)
    virtual = (:L, :U)
    return private ? (fieldnames(typeof(F))..., virtual...) : virtual
end

"""
    show(io::IO, mime::MIME"text/plain", F::LU_Fac)

Multi-line display of an `LU_Fac`: prints the `L` factor under the label `L = `,
then a blank line, then the `U` factor under the label `U = `. Each factor is
rendered with its own `text/plain` `show` method.
"""
function show(io::IO, mime::MIME"text/plain", F::LU_Fac)
    lower, upper = F.L, F.U
    print(io, "L = ")
    show(io, mime, lower)
    print(io, "\n\nU = ")
    return show(io, mime, upper)
end


show (generic function with 346 methods)

In [494]:

"""
    _factorize!(A, i, j, piv)

Unblocked, in-place LU factorization with row pivoting of the panel `A[i, j]`.

For each panel column the entry of largest magnitude at or below the diagonal is
searched over all remaining rows of `A`; if it lies in another row, the two rows
are exchanged across *every* column of `A` and the same exchange is recorded in
`piv`. The sub-diagonal part of the column is then scaled by the inverse pivot
(the multipliers) and a rank-1 update is applied to the rest of the panel.

# Arguments
- `A`: matrix that is overwritten.
- `i`: row range of the panel.
- `j`: column range of the panel.
- `piv`: permutation vector, updated in place with every row exchange.

Returns `nothing`. A zero pivot is not detected.
"""
function _factorize!(
        A::AbstractMatrix,
        i::OrdinalRange,
        j::OrdinalRange,
        piv::AbstractArray{<:Integer, 1}
    )
    Aij = @view A[i, j]
    n, m = size(Aij)
    s, t = size(A)
    row_offset = first(i) - 1
    col_offset = first(j) - 1
    # BLAS needs the scalar in the element type of the matrix.
    minus_one = -one(eltype(A))

    for k = 1:m
        # Global coordinates of the k-th diagonal entry of the panel.
        row_k = k + row_offset
        col_k = k + col_offset

        # Find the pivot: the search must start from the diagonal entry of the
        # *global* column, not from column `k` of `A`.
        pivot = row_k
        max_elem = abs(A[pivot, col_k])
        for r = row_k+1:s
            row_elem = abs(A[r, col_k])
            if row_elem > max_elem
                max_elem = row_elem
                pivot = r
            end
        end

        if row_k != pivot
            piv[row_k], piv[pivot] = piv[pivot], piv[row_k]
            # Exchange the full rows so already-computed factors stay consistent.
            for c = 1:t
                A[pivot, c], A[row_k, c] = A[row_k, c], A[pivot, c]
            end
        end

        # Multipliers: scale the sub-diagonal part of column k by 1 / pivot.
        Akkinv = inv(Aij[k, k])
        l = @view(Aij[k+1:n, k])
        BLAS.scal!(length(l), Akkinv, l, 1)
        # Rank-1 update of the remaining part of the panel.
        @views BLAS.ger!(minus_one, Aij[k+1:n, k], Aij[k, k+1:m], Aij[k+1:n, k+1:m])
    end
    return nothing
end

# Start index of the final block when `1:n` is cut into blocks of length `b`:
# a full-size block if `b` divides `n`, otherwise the shorter remainder block.
function _last_block(n, b)
    leftover = n % b
    return leftover == 0 ? n - b + 1 : n - leftover + 1
end

"""
    blocked_lu!(A::AbstractMatrix{T}, b::Integer) where T <: AbstractFloat

Blocked, in-place LU factorization of `A` with row pivoting, using block size `b`.

Each step factorizes a panel of `b` columns with `_factorize!`, solves a
unit-lower-triangular system for the block row to its right, and subtracts the
product of the two off-diagonal blocks from the trailing submatrix. The last
(possibly shorter) block is factorized on its own.

# Arguments
- `A`: matrix to factorize; it is overwritten.
- `b`: block size, must be at least 1.

# Returns
An `LU_Fac{T}` built from `A` and the permutation vector.

# Throws
- `ArgumentError` if `b < 1`.

NOTE(review): all index ranges stop at `min(size(A)...)`, so non-square input
looks unsupported — confirm before relying on it.
"""
function blocked_lu!(A::AbstractMatrix{T}, b::Integer) where T <: AbstractFloat
    b >= 1 || throw(ArgumentError("block size b must be positive, got $b"))
    n = min(size(A)...)
    piv = collect(1:n)
    # Nothing to factorize; avoids building invalid index ranges below.
    n == 0 && return LU_Fac{T}(A, piv)

    # BLAS scalars must have the element type of the matrix.
    alpha = one(T)
    last_block = _last_block(n, b)
    for i = 1:b:last_block-b
        k = i:i+b-1
        l = i+b:n
        # Factorize the current panel (all remaining rows, b columns).
        _factorize!(A, i:n, k, piv)
        Lkk = @view A[k, k]
        Ukl = @view A[k, l]
        Llk = @view A[l, k]
        All = @view A[l, l]
        # Ukl = Lkk \ Ukl   (Lkk is unit lower triangular)
        BLAS.trsm!('L', 'L', 'N', 'U', alpha, Lkk, Ukl)
        # All = All - Llk * Ukl
        BLAS.gemm!('N', 'N', -alpha, Llk, Ukl, alpha, All)
    end
    _factorize!(A, last_block:n, last_block:n, piv)
    return LU_Fac{T}(A, piv)
end
"""
    _factorize!(A, i, j, piv)

Unblocked, in-place LU factorization with row pivoting of the panel `A[i, j]`.

For each panel column the entry of largest magnitude at or below the diagonal is
searched over all remaining rows of `A`; if it lies in another row, the two rows
are exchanged across *every* column of `A` and the same exchange is recorded in
`piv`. The sub-diagonal part of the column is then scaled by the inverse pivot
(the multipliers) and a rank-1 update is applied to the rest of the panel.

# Arguments
- `A`: matrix that is overwritten.
- `i`: row range of the panel.
- `j`: column range of the panel.
- `piv`: permutation vector, updated in place with every row exchange.

Returns `nothing`. A zero pivot is not detected.
"""
function _factorize!(
        A::AbstractMatrix,
        i::OrdinalRange,
        j::OrdinalRange,
        piv::AbstractArray{<:Integer, 1}
    )
    Aij = @view A[i, j]
    n, m = size(Aij)
    s, t = size(A)
    row_offset = first(i) - 1
    col_offset = first(j) - 1
    # BLAS needs the scalar in the element type of the matrix.
    minus_one = -one(eltype(A))

    for k = 1:m
        # Global coordinates of the k-th diagonal entry of the panel.
        row_k = k + row_offset
        col_k = k + col_offset

        # Find the pivot: the search must start from the diagonal entry of the
        # *global* column, not from column `k` of `A`.
        pivot = row_k
        max_elem = abs(A[pivot, col_k])
        for r = row_k+1:s
            row_elem = abs(A[r, col_k])
            if row_elem > max_elem
                max_elem = row_elem
                pivot = r
            end
        end

        if row_k != pivot
            piv[row_k], piv[pivot] = piv[pivot], piv[row_k]
            # Exchange the full rows so already-computed factors stay consistent.
            for c = 1:t
                A[pivot, c], A[row_k, c] = A[row_k, c], A[pivot, c]
            end
        end

        # Multipliers: scale the sub-diagonal part of column k by 1 / pivot.
        Akkinv = inv(Aij[k, k])
        l = @view(Aij[k+1:n, k])
        BLAS.scal!(length(l), Akkinv, l, 1)
        # Rank-1 update of the remaining part of the panel.
        @views BLAS.ger!(minus_one, Aij[k+1:n, k], Aij[k, k+1:m], Aij[k+1:n, k+1:m])
    end
    return nothing
end

# Start index of the final block when `1:n` is cut into blocks of length `b`:
# a full-size block if `b` divides `n`, otherwise the shorter remainder block.
function _last_block(n, b)
    leftover = n % b
    return leftover == 0 ? n - b + 1 : n - leftover + 1
end

"""
    blocked_lu!(A::AbstractMatrix{T}, b::Integer) where T <: AbstractFloat

Blocked, in-place LU factorization of `A` with row pivoting, using block size `b`.

Each step factorizes a panel of `b` columns with `_factorize!`, solves a
unit-lower-triangular system for the block row to its right, and subtracts the
product of the two off-diagonal blocks from the trailing submatrix. The last
(possibly shorter) block is factorized on its own.

# Arguments
- `A`: matrix to factorize; it is overwritten.
- `b`: block size, must be at least 1.

# Returns
An `LU_Fac{T}` built from `A` and the permutation vector.

# Throws
- `ArgumentError` if `b < 1`.

NOTE(review): all index ranges stop at `min(size(A)...)`, so non-square input
looks unsupported — confirm before relying on it.
"""
function blocked_lu!(A::AbstractMatrix{T}, b::Integer) where T <: AbstractFloat
    b >= 1 || throw(ArgumentError("block size b must be positive, got $b"))
    n = min(size(A)...)
    piv = collect(1:n)
    # Nothing to factorize; avoids building invalid index ranges below.
    n == 0 && return LU_Fac{T}(A, piv)

    # BLAS scalars must have the element type of the matrix.
    alpha = one(T)
    last_block = _last_block(n, b)
    for i = 1:b:last_block-b
        k = i:i+b-1
        l = i+b:n
        # Factorize the current panel (all remaining rows, b columns).
        _factorize!(A, i:n, k, piv)
        Lkk = @view A[k, k]
        Ukl = @view A[k, l]
        Llk = @view A[l, k]
        All = @view A[l, l]
        # Ukl = Lkk \ Ukl   (Lkk is unit lower triangular)
        BLAS.trsm!('L', 'L', 'N', 'U', alpha, Lkk, Ukl)
        # All = All - Llk * Ukl
        BLAS.gemm!('N', 'N', -alpha, Llk, Ukl, alpha, All)
    end
    _factorize!(A, last_block:n, last_block:n, piv)
    return LU_Fac{T}(A, piv)
end

blocked_lu! (generic function with 1 method)

In [495]:


# Return a new array holding the entries of `b` picked in the order given by `piv`.
function swap_rows(b::AbstractArray, piv::AbstractArray{<:Integer})
    return getindex(b, piv)
end


# Return a new matrix whose rows are the rows of `b` picked in the order given by
# `piv`; all columns are kept.
function swap_rows(b::AbstractMatrix, piv::AbstractArray{<:Integer})
    return getindex(b, piv, :)
end


"""
    _solve!(b::AbstractArray{T}, F::LU_Fac{T}) where T<:AbstractFloat

Solve with the packed factors of `F` for a single right-hand side, overwriting `b`.

Runs a unit-lower-triangular solve followed by an upper-triangular solve, both
reading their triangle from `F.lu`. No permutation is applied here, so `b` must
already be row-permuted by the caller. Returns the overwritten `b`.
"""
function _solve!(b::AbstractArray{T}, F::LU_Fac{T}) where T<:AbstractFloat
    # Forward substitution with the unit lower triangle.
    BLAS.trsv!('L', 'N', 'U', F.lu, b)
    # Back substitution with the upper triangle.
    return BLAS.trsv!('U', 'N', 'N', F.lu, b)
end


"""
    _solve!(B::AbstractMatrix{T}, F::LU_Fac{T}) where T<:AbstractFloat

Solve with the packed factors of `F` for several right-hand sides (the columns
of `B`), overwriting `B`.

Runs a unit-lower-triangular solve followed by an upper-triangular solve, both
reading their triangle from `F.lu`. No permutation is applied here, so `B` must
already be row-permuted by the caller. Returns the overwritten `B`.
"""
function _solve!(B::AbstractMatrix{T}, F::LU_Fac{T}) where T<:AbstractFloat
    # The BLAS scalar has to match the element type, so avoid a Float64 literal.
    alpha = one(T)
    BLAS.trsm!('L', 'L', 'N', 'U', alpha, F.lu, B)
    return BLAS.trsm!('L', 'U', 'N', 'N', alpha, F.lu, B)
end


"""
    lu_solve!(b::AbstractArray{T}, A::AbstractMatrix{T}) where T<:AbstractFloat

Factorize `A` with `lu_factorization`, reorder `b` with the factorization's
permutation via `swap_rows`, and run `_solve!` on the reordered right-hand side.
Returns whatever `_solve!` returns.

NOTE(review): `lu_factorization` is not defined in the visible part of this
file — confirm it exists, or switch to the factorization routine that does.
"""
function lu_solve!(b::AbstractArray{T}, A::AbstractMatrix{T}) where T<:AbstractFloat
    F = lu_factorization(A)
    return _solve!(swap_rows(b, F.p), F)
end

"""
    lu_solve!(b::AbstractArray{T}, lu::LU_Fac{T}) where T<:AbstractFloat

Solve using an existing factorization: reorder `b` with the permutation `lu.p`
via `swap_rows`, then run `_solve!` on the reordered right-hand side. Returns
whatever `_solve!` returns.
"""
function lu_solve!(b::AbstractArray{T}, lu::LU_Fac{T}) where T<:AbstractFloat
    permuted = swap_rows(b, lu.p)
    return _solve!(permuted, lu)
end


# Left division by a factorization: forwards to `lu_solve!`, which takes the
# right-hand side first and the factorization second.
\(A::LU_Fac, b::AbstractArray) = lu_solve!(b, A)

In [501]:
# Small 6×6 test matrix. It is replaced by the random matrix a few lines below,
# so it does not take part in the timings.
A = [4 3 4 5 2 1
     3 4 8 1 2 7
     7 8 9 2 1 3
     7 1 2 3 4 5
     1 9 7 3 3 1
     6 5 6 1 2 9.]

# Benchmark problem: a random n×n matrix.
n = 3000
A = rand(n,n)
# b = A * rand(n, n)
#display(A)
# @assert F.L * F.U ≈ A[F.p, :]
# Reference timing with LinearAlgebra's `lu`. It has to run first, because the
# next call overwrites `A`.
@time lu(A)
# Timing of the blocked implementation with block size 200 (works in place on `A`).
@time F = blocked_lu!(A, 200)

# @time lu_solve!(copy(b), F)
# F = lu(A)
# @time F\b

  0.549581 seconds (9 allocations: 68.688 MiB, 3.54% gc time)
  0.729231 seconds (12.06 k allocations: 777.203 KiB)


L = 3000×3000 UnitLowerTriangular{Float64,Array{Float64,2}}:
 1.0          ⋅           ⋅          …    ⋅         ⋅         ⋅          ⋅ 
 0.0115676   1.0          ⋅               ⋅         ⋅         ⋅          ⋅ 
 0.941237    0.743226    1.0              ⋅         ⋅         ⋅          ⋅ 
 0.0634882   0.0534254  -0.496469         ⋅         ⋅         ⋅          ⋅ 
 0.900316   -0.163845    0.443745         ⋅         ⋅         ⋅          ⋅ 
 0.940982    0.760262    0.371038    …    ⋅         ⋅         ⋅          ⋅ 
 0.21612     0.233918   -0.423736         ⋅         ⋅         ⋅          ⋅ 
 0.364049    0.866341    0.0976943        ⋅         ⋅         ⋅          ⋅ 
 0.215959    0.0738068  -0.139786         ⋅         ⋅         ⋅          ⋅ 
 0.243579   -0.0405707  -0.187653         ⋅         ⋅         ⋅          ⋅ 
 0.945041    0.738181    0.838442    …    ⋅         ⋅         ⋅          ⋅ 
 0.534606   -0.0264596  -0.270274         ⋅         ⋅         ⋅          ⋅ 
 0.061941    0.420782    0.

In [500]:
# Set the number of threads used by BLAS to 5 (hard-coded for this session).
BLAS.set_num_threads(5)

In [486]:
# Hand-worked 2×2 block LU of a 6×6 matrix with 3×3 blocks, used as a sanity check.
A = [4 3 4 5 2 1
     3 4 8 1 2 7
     7 8 9 2 1 3
     7 1 2 3 4 5
     1 9 7 3 3 1
     6 5 6 1 2 9.]
# Factorize a copy of the leading 3×3 block (indexing with ranges copies, so `A`
# stays intact). `Val(false)` is the pivoting argument — presumably "no
# pivoting"; confirm against the LinearAlgebra version in use.
F = LinearAlgebra.generic_lufact!(A[1:3, 1:3], Val(false))
L11 = F.L
U11 = F.U


# Off-diagonal blocks: U12 solves L11 * U12 = A12, L21 solves L21 * U11 = A21.
U12 = L11 \ A[1:3, 4:end]
L21 = A[4:end, 1:3] / U11
A22 = A[4:end, 4:end]
# Schur complement of the leading block.
S = A22 - L21*U12 
# Factorize the Schur complement the same way as the leading block.
F2 = LinearAlgebra.generic_lufact!(copy(S), Val(false))

# Assemble the packed factors block by block and check that the unit-lower and
# upper triangles multiply back to `A`.
F3 = [F.factors U12; L21 F2.factors]
UnitLowerTriangular(F3) * UpperTriangular(F3) ≈ A
#F3

true

blocked_lu! (generic function with 1 method)

In [330]:
# Scratch check: convert the integer 1 to a Bool.
Bool(1)

true

In [481]:
# Scratch: a small integer range used to inspect range fields.
f = 4:10

4:10

In [483]:
# Scratch: read the `stop` field of the range.
f.stop

10