Added test for checking parameters
theogf committed Mar 12, 2020
1 parent 4c8c816 commit 3377137
Showing 3 changed files with 66 additions and 21 deletions.
3 changes: 2 additions & 1 deletion Project.toml
@@ -25,11 +25,12 @@ julia = "1.0"

[extras]
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
-test = ["Random", "Test", "FiniteDifferences", "Zygote", "PDMats", "Kronecker"]
+test = ["Random", "Test", "FiniteDifferences", "Zygote", "PDMats", "Kronecker", "Flux"]
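Flux enters only as a test dependency here, wired through [extras] and [targets]. Assuming test/runtests.jl includes the new test/test_flux.jl file (this commit does not touch the test runner), the suite can be exercised in the usual way; a minimal sketch:

using Pkg
Pkg.test("KernelFunctions")  # resolves the test target, Flux included, before running the tests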
37 changes: 17 additions & 20 deletions src/trainable.jl
@@ -1,42 +1,39 @@
-using .Flux: trainable
-
-Flux.trainable(::Kernel) = () # By default no parameters are returned
-Flux.trainable(::Transform) = ()
+import .Flux.trainable

### Base Kernels

-Flux.trainable(k::ConstantKernel) = (k.c,)
+trainable(k::ConstantKernel) = (k.c,)

-Flux.trainable(k::GammaExponentialKernel) = (γ,)
+trainable(k::GammaExponentialKernel) = (k.γ,)

-Flux.trainable(k::GammaRationalQuadraticKernel) = (k.α, k.γ)
+trainable(k::GammaRationalQuadraticKernel) = (k.α, k.γ)

-Flux.trainable(k::MaternKernel) = (k.ν,)
+trainable(k::MaternKernel) = (k.ν,)

-Flux.trainable(k::LinearKernel) = (k.c,)
+trainable(k::LinearKernel) = (k.c,)

-Flux.trainable(k::PolynomialKernel) = (k.d, k.c)
+trainable(k::PolynomialKernel) = (k.d, k.c)

-Flux.trainable(k::RationalQuadraticKernel) = (k.α,)
+trainable(k::RationalQuadraticKernel) = (k.α,)

#### Composite kernels

-Flux.trainable(κ::KernelProduct) = k.kernels
+trainable(κ::KernelProduct) = κ.kernels

-Flux.trainable(κ::KernelSum) = (κ.weights, κ.kernels) #To check
+trainable(κ::KernelSum) = (κ.weights, κ.kernels) #To check

-Flux.trainable(κ::ScaledKernel) = (κ.σ, κ.kernel)
+trainable(κ::ScaledKernel) = (κ.σ, κ.kernel)

-Flux.trainable(κ::TransformedKernel) = (κ.transform, κ.kernel)
+trainable(κ::TransformedKernel) = (κ.transform, κ.kernel)

### Transforms

-Flux.trainable(t::ARDTransform) = (t.v,)
+trainable(t::ARDTransform) = (t.v,)

-Flux.trainable(t::ChainTransform) = t.transforms
+trainable(t::ChainTransform) = t.transforms

-Flux.trainable(t::FunctionTransform) = (t.f,)
+trainable(t::FunctionTransform) = (t.f,)

-Flux.trainable(t::LowRankTransform) = (t.proj,)
+trainable(t::LowRankTransform) = (t.proj,)

-Flux.trainable(t::ScaleTransform) = (t.s,)
+trainable(t::ScaleTransform) = (t.s,)
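These trainable methods are what let Flux.params walk a kernel and collect its hyperparameters, which is what the new test file below verifies. A minimal sketch of the intended usage (not part of this commit, reusing the same constructors as the tests):

using KernelFunctions, Flux  # loading Flux alongside KernelFunctions makes these methods available

ν = 2.0; s = 2.0
k = transform(MaternKernel(ν=ν), s)  # wraps the Matern kernel in a ScaleTransform(s)
ps = Flux.params(k)                  # collects [ν] and [s] through the `trainable` definitions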
47 changes: 47 additions & 0 deletions test/test_flux.jl
@@ -0,0 +1,47 @@
using KernelFunctions
using Test
using Flux

@testset "Params" begin
ν = 2.0; c = 3.0; d = 2.0; γ = 2.0; α = 2.5
kc = ConstantKernel(c=c)
@test all(params(kc) .== params([c]))
km = MaternKernel(ν=ν)
@test all(params(km) .== params([ν]))
kl = LinearKernel(c=c)
@test all(params(kl) .== params([c]))
kge = GammaExponentialKernel(γ=γ)
@test all(params(kge) .== params([γ]))
kgr = GammaRationalQuadraticKernel(γ=γ, α=α)
@test all(params(kgr) .== params([α], [γ]))
kp = PolynomialKernel(c=c, d=d)
@test all(params(kp) .== params([d], [c]))
kr = RationalQuadraticKernel(α=α)
@test all(params(kr) .== params([α]))

k = km + kc
@test all(params(k) .== params([k.weights], km, kc))

k = km * kc
@test all(params(k) .== params(km, kc))

s = 2.0
k = transform(km, s)
@test all(params(k) .== params([s], km))

v = [2.0]
k = transform(kc, v)
@test all(params(k) .== params(v, kc))

P = rand(3, 2)
k = transform(km,LowRankTransform(P))
@test all(params(k) .== params(P, km))

k = transform(km, LowRankTransform(P) ∘ ScaleTransform(s))
@test all(params(k) .== params([s], P, km))

c = Chain(Dense(3, 2))
k = transform(km, FunctionTransform(c))
@test all(params(k) .== params(c, km))

end
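Beyond these equality checks on params, the practical payoff is that the collected parameters can be handed to a Flux optimiser. A rough, hypothetical sketch of that workflow (the data X, the stand-in loss, and the optimiser choice are illustrative assumptions, not part of this commit):

using KernelFunctions, Flux

X = rand(3, 10)                                       # hypothetical data: 3 features, 10 observations
k = transform(SqExponentialKernel(), ScaleTransform(1.0))
ps = Flux.params(k)                                   # kernel hyperparameters gathered via `trainable`

loss() = sum(kernelmatrix(k, X, obsdim=2))            # stand-in objective, only to show the mechanics
gs = Flux.gradient(loss, ps)                          # gradients w.r.t. the collected parameters
Flux.Optimise.update!(ADAM(0.1), ps, gs)              # one optimisation step on the kernel parameters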
