Skip to content

Commit

Permalink
Minor bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Tim Thatcher committed Mar 6, 2016
1 parent 4ab669f commit 1dec013
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 73 deletions.
74 changes: 1 addition & 73 deletions src/qda.jl
Expand Up @@ -40,7 +40,7 @@ function qda!{T<:BlasReal,U<:Integer}(
γ::Nullable{T}
)
H = center_classes!(X, M, y)
isnull(λ) ? class_whiteners!(H, y, k, γ, get(λ)) : class_whiteners!(H, y, k, γ)
isnull(λ) ? class_whiteners!(H, y, γ) : class_whiteners!(H, y, γ, get(λ))
end

doc"`qda(X, y; M, lambda, gamma, priors)` Fits a regularized quadratic discriminant model to the
Expand Down Expand Up @@ -83,7 +83,6 @@ function discriminants_qda{T<:BlasReal}(
end
end
δ

end

doc"`discriminants(Model, Z)` Uses `Model` on input `Z` to product the class discriminants."
Expand All @@ -95,74 +94,3 @@ doc"`classify(Model, Z)` Uses `Model` on input `Z`."
function classify{T<:BlasReal}(mod::ModelQDA{T}, Z::Matrix{T})
mapslices(indmax, discriminants(mod, Z), 2)
end

# Create an array of class scatter matrices
# H is centered data matrix (with respect to class means)
# y is one-based vector of class IDs
#=
function class_covariances{T<:BlasReal,U<:Integer}(H::Matrix{T}, y::Vector{U},
n_k::Vector{Int64} = class_counts(y))
k = length(n_k)
Σ_k = Array(Array{T,2}, k) # Σ_k[i] = H_i'H_i/(n_i-1)
for i = 1:k
Σ_k[i] = BLAS.syrk!('U', 'T', one(T)/(n_k[i]-1), H[y .== i,:], zero(T), Array(T,p,p))
symml!(Σ_k[i])
end
Σ_k
end
=#




# Use eigendecomposition to generate class whitening transform
# Σ_k is array of references to each Σ_i covariance matrix
# λ is regularization parameter in [0,1]. λ = 0 is no regularization.
#=
function class_whiteners!{T<:BlasReal}(Σ_k::Vector{Matrix{T}}, γ::T)
for i = 1:length(Σ_k)
tol = eps(T) * prod(size(Σ_k[i])) * maximum(Σ_k[i])
Λ_i, V_i = LAPACK.syev!('V', 'U', Σ_k[i]) # Overwrite Σ_k with V such that VΛVᵀ = Σ_k
if γ > 0
λ_avg = mean(Λ_i) # Shrink towards average eigenvalue
for j = 1:length(Λ_i)
Λ_i[j] = (1-γ)*Λ_i[j] + γ*λ_avg # Σ = VΛVᵀ => (1-γ)Σ + γI = V((1-γ)Λ + γI)Vᵀ
end
end
all(Λ_i .>= tol) || error("Rank deficiency detected in class $i with tolerance $tol.")
scale!(V_i, one(T) ./ sqrt(Λ_i)) # Scale V so it whitens H*V where H is centered X
end
Σ_k
end
=#

# Fit regularized quadratic discriminant model. Returns whitening matrices for all classes.
# X in uncentered data matrix
# M is matrix of class means (one per row)
# y is one-based vector of class IDs
# λ is regularization parameter in [0,1]. λ = 0 is no regularization. See documentation.
# γ is shrinkage parameter in [0,1]. γ = 0 is no shrinkage. See documentation.
#=
function qda!{T<:BlasReal,U<:Integer}(X::Matrix{T}, M::Matrix{T}, y::Vector{U}, λ::T, γ::T)
k = maximum(y)
n_k = class_counts(y, k)
n, p = size(X)
H = center_classes!(X, M, y)
w_σ = 1 ./ vec(sqrt(var(H, 1))) # Variance normalizing factor for columns of H
scale!(H, w_σ)
Σ_k = class_covariances(H, y, n_k)
if λ > 0
Σ = scale!(H'H, one(T)/(n-1))
for i = 1:k
regularize!(Σ_k[i], λ, Σ)
end
end
W_k = class_whiteners!(Σ_k, γ)
for i = 1:k
scale!(w_σ, W_k[i]) # scale rows of W_k
end
W_k
end
=#


15 changes: 15 additions & 0 deletions test/test_qda.jl
Expand Up @@ -46,3 +46,18 @@ for T in FloatingPointTypes
end
end
end

info("Testing ", MOD.qda!)
for T in FloatingPointTypes
X_tmp = copy(convert(Matrix{T}, X))
M_tmp = convert(Matrix{T}, M)
H_tmp = convert(Matrix{T}, H)

W_k = MOD.qda!(copy(X_tmp), copy(M_tmp), y, Nullable{T}(), Nullable{T}())
W_k_tmp = MOD.class_whiteners!(copy(H_tmp), y, Nullable{T}())

for i in eachindex(W_k)
@test_approx_eq W_k[i] W_k_tmp[i]
end

end

0 comments on commit 1dec013

Please sign in to comment.