Skip to content

Commit

Permalink
Merge 1d5b4fe into 2d846fe
Browse files Browse the repository at this point in the history
  • Loading branch information
theogf committed Mar 11, 2020
2 parents 2d846fe + 1d5b4fe commit 7dae3d5
Show file tree
Hide file tree
Showing 12 changed files with 70 additions and 10 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand All @@ -16,6 +17,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
Compat = "2.2, 3"
Distances = "0.8"
MacroTools = "0.5"
Requires = "1.0.1"
SpecialFunctions = "0.8, 0.9, 0.10"
StatsBase = "0.32"
Expand Down
3 changes: 2 additions & 1 deletion src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ export kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!, kappa
export transform
export params, duplicate, set! # Helpers

export Kernel
export Kernel, BaseKernel, @kernel
export ConstantKernel, WhiteKernel, ZeroKernel
export SqExponentialKernel, ExponentialKernel, GammaExponentialKernel
export ExponentiatedKernel
Expand Down Expand Up @@ -46,6 +46,7 @@ for k in ["exponential","matern","polynomial","constant","rationalquad","exponen
end
include("kernels/transformedkernel.jl")
include("kernels/scaledkernel.jl")
include("kernels/kernel_macro.jl")
include("matrix/kernelmatrix.jl")
include("kernels/kernelsum.jl")
include("kernels/kernelproduct.jl")
Expand Down
4 changes: 2 additions & 2 deletions src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ _scale(t::ScaleTransform, metric::Euclidean, x, y) = first(t.s) * evaluate(metr
_scale(t::ScaleTransform, metric::Union{SqEuclidean,DotProduct}, x, y) = first(t.s)^2 * evaluate(metric, x, y)
_scale(t::ScaleTransform, metric, x, y) = evaluate(metric, apply(t, x), apply(t, y))

printshifted(io::IO::Kernel,shift::Int) = print(io,"")
Base.show(io::IO::Kernel) = print(io,nameof(typeof(κ)))
printshifted(io::IO, κ::Kernel, shift::Int) = print(io, "")
Base.show(io::IO, κ::Kernel) = print(io, nameof(typeof(κ)))

### Syntactic sugar for creating matrices and using kernel functions
for k in subtypes(BaseKernel)
Expand Down
24 changes: 24 additions & 0 deletions src/kernels/kernel_macro.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using MacroTools: @capture

"""
"""
macro kernel(expr::Expr,arg=nothing)
@capture(expr,(scale_*k_ | k_)) || throw(error("@kernel first arguments should be of the form `σ*Kernel()` or `Kernel()`"))
t = if @capture(arg,kw_=val_)
if kw == :l
val
elseif kw == :t
val
else
throw(error("The additional argument could not be intepreted. Please see documentation of `@kernel`"))
end
else
arg
end
if isnothing(scale)
return esc(:(transform($k,$t)))
else
return esc(:($scale*transform($k,$t)))
end
end
11 changes: 8 additions & 3 deletions src/kernels/transformedkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ struct TransformedKernel{Tk<:Kernel,Tr<:Transform} <: Kernel
transform::Tr
end

function TransformedKernel(k::TransformedKernel,t::Transform)
TransformedKernel(kernel(k),tk.transform)
end
"""
```julia
transform(k::BaseKernel, t::Transform) (1)
Expand All @@ -15,11 +18,13 @@ end
"""
transform

transform(k::BaseKernel, t::Transform) = TransformedKernel(k, t)
transform(k::Kernel, t::Transform) = TransformedKernel(k, t)

transform(k::Kernel, ρ::Real) = TransformedKernel(k, ScaleTransform(ρ))

transform(k::BaseKernel, ρ::Real) = TransformedKernel(k, ScaleTransform(ρ))
transform(k::Kernel,ρ::AbstractVector) = TransformedKernel(k, ARDTransform(ρ))

transform(k::BaseKernel,ρ::AbstractVector) = TransformedKernel(k, ARDTransform(ρ))
transform(k::BaseKernel,::Nothing) = k

kernel(κ) = κ.kernel

Expand Down
2 changes: 2 additions & 0 deletions src/transform/ardtransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,5 @@ apply(t::ARDTransform,x::AbstractVector{<:Real};obsdim::Int=defaultobs) = t.v .*
_transform(t::ARDTransform,X::AbstractMatrix{<:Real},obsdim::Int=defaultobs) = obsdim == 1 ? t.v'.*X : t.v .* X

Base.isequal(t::ARDTransform,t2::ARDTransform) = isequal(t.v,t2.v)

Base.show(io::IO, t::ARDTransform) = print(io,"ARD Transform, ρ = $(t.v)")
13 changes: 10 additions & 3 deletions src/transform/chaintransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ params(t::ChainTransform) = (params.(t.transforms))
duplicate(t::ChainTransform,θ) = ChainTransform(duplicate.(t.transforms,θ))


Base.:(t₁::Transform,t₂::Transform) = ChainTransform([t₂,t₁])
Base.:(t::Transform,tc::ChainTransform) = ChainTransform(vcat(tc.transforms,t)) #TODO add test
Base.:(tc::ChainTransform,t::Transform) = ChainTransform(vcat(t,tc.transforms))
Base.:(t₁::Transform, t₂::Transform) = ChainTransform([t₂, t₁])
Base.:(t::Transform, tc::ChainTransform) = ChainTransform(vcat(tc.transforms, t)) #TODO add test
Base.:(tc::ChainTransform, t::Transform) = ChainTransform(vcat(t, tc.transforms))

function Base.show(io::IO, tc::ChainTransform)
print(io,"Chain Transform : $(first(tc.transforms))")
for t in tc.transforms[2:end]
print(io, " |> $t")
end
end
2 changes: 1 addition & 1 deletion src/transform/scaletransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ apply(t::ScaleTransform,x::AbstractVecOrMat;obsdim::Int=defaultobs) = first(t.s)

Base.isequal(t::ScaleTransform,t2::ScaleTransform) = isequal(first(t.s),first(t2.s))

Base.show(io::IO,t::ScaleTransform) = print(io,"Scale Transform s=$(first(t.s))")
Base.show(io::IO,t::ScaleTransform) = print(io,"Scale Transform, s = $(first(t.s))")
2 changes: 2 additions & 0 deletions src/transform/selecttransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,5 @@ function apply(t::SelectTransform, x::AbstractVector{<:Real}; obsdim::Int = defa
end

_transform(t::SelectTransform,X::AbstractMatrix{<:Real},obsdim::Int=defaultobs) = obsdim == 2 ? view(X,t.select,:) : view(X,:,t.select)

Base.show(io::IO, t::SelectTransform) = print(io, "Selected Dimensions : $(t.select)")
4 changes: 4 additions & 0 deletions src/transform/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ params(t::IdentityTransform) = nothing

apply(t::IdentityTransform, x; obsdim::Int=defaultobs) = x #TODO add test

Transform::Real) = ScaleTransform(ρ)
Transform::AbstractVector) = ARDTransform(ρ)
Transform(t::Transform) = t

### TODO Maybe defining adjoints could help but so far it's not working


Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using Random
include("test_kernelmatrix.jl")
include("test_approximations.jl")
include("test_constructors.jl")
include("test_macro.jl")
# include("test_AD.jl")
include("test_transform.jl")
include("test_distances.jl")
Expand Down
12 changes: 12 additions & 0 deletions test/test_macro.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
using KernelFunctions
using Test

@testset "Kernel Macro" begin
@test (@kernel SqExponentialKernel()) isa SqExponentialKernel
@test (@kernel 3.0*SqExponentialKernel()) isa ScaledKernel{SqExponentialKernel,Float64}
@test (@kernel 3.0*SqExponentialKernel() l=3.0) isa ScaledKernel{TransformedKernel{SqExponentialKernel,ScaleTransform{Float64}},Float64}
@test (@kernel 3.0*SqExponentialKernel() 3.0) isa ScaledKernel{TransformedKernel{SqExponentialKernel,ScaleTransform{Float64}},Float64}
@test (@kernel 3.0*SqExponentialKernel() l=[3.0]) isa ScaledKernel{TransformedKernel{SqExponentialKernel,ARDTransform{Float64,1}},Float64}
@test (@kernel 3.0*SqExponentialKernel() LowRankTransform(rand(3,2))) isa ScaledKernel{TransformedKernel{SqExponentialKernel,LowRankTransform{Array{Float64,2}}},Float64}
@test (@kernel (3.0*SqExponentialKernel()+5.0*Matern32Kernel()) 3.0) isa TransformedKernel{KernelSum,ScaleTransform{Float64}}
end

0 comments on commit 7dae3d5

Please sign in to comment.