In [1]:
using Pkg

# Load package environment.
Pkg.activate(".")
Pkg.instantiate()

[32m[1mActivating[22m[39m environment at `~/.julia/dev/ConvCNP/Project.toml`


In [2]:
using Revise
using Flux
using Flux.Tracker
using Statistics
using Printf
using Random

In [111]:
struct SetConv{T<:AbstractVector{<:Real}}
    log_scales::T
    density::Bool
end 

@Flux.treelike SetConv

ErrorException: invalid redefinition of constant SetConv

In [112]:
"""
What do each of these parameters mean?
"""
function set_conv(in_channels::Int, scale::Float64; density::Bool=true)
    # Add one to `in_channels` to account for the density channel.
    density && (in_channels += 1)
    return SetConv(param(log.(scale .* ones(in_channels)), density)
end

LoadError: syntax: missing comma or ) in argument list

In [113]:
rbf(dists2::Real) = exp(-0.5 * dists2)

compute_dists2(x, y) = compute_dists2(x, y, Val(size(x, 2)))

compute_dists2(x, y, ::Val{1}) = (x .- permutedims(y, (2, 1, 3))).^2

function compute_dists2(x, y, d::Val)
    y = permutedims(y, (2, 1, 3))
    return sum(x.^2; dims=2) .+ sum(y.^2; dims=1) .- 2 .* batched_mul(x, y)
end

compute_dists2 (generic function with 3 methods)

In [114]:
uti

layer

In [115]:
abstract type Discretisation end

# Doesn't need to be performant.
struct UniformDiscretisation1d <: Discretisation
    points_per_unit::Real
    margin::Real
    multiple::Integer
end

function (d::UniformDiscretisation1d)(xs::AbstractArray...)
    x = cat(xs...; dims=1)
    range_lower = minimum(x) - d.margin
    range_upper = maximum(x) + d.margin
    num_points = Integer(round((range_upper - range_lower) * d.points_per_unit))
    num_points = Integer(ceil(num_points / d.multiple) * d.multiple)
    disc = collect(range(range_lower, range_upper; length=num_points))
    return repeat(disc, 1, 1, size(x, 3))
end

ErrorException: invalid redefinition of constant UniformDiscretisation1d

In [116]:
struct ConvCNP
    discretisation::Discretisation
    encoder::SetConv
    conv::Chain
    decoder::SetConv
end

@Flux.treelike ConvCNP

In [117]:
function (model::ConvCNP)(
    x_context::AbstractArray,
    y_context::AbstractArray,
    x_target::AbstractArray
)
    x_discretisation = model.discretisation(x_context, x_target)
    encoding = model.encoder(x_context, y_context, x_discretisation)
    latent = model.conv(encoding)
    if size(encoding, 1) != size(latent, 1)
        error("Conv net changed the discretisation size from $(size(encoding, 1)) to $(size(latent, 1)).")
    end
    return model.decoder(x_discretisation, latent, x_target)
end

In [118]:
function convcnp_1d(conv; points_per_unit=64, margin=0.1, multiple=1)
    scale = 2 / points_per_unit
    return ConvCNP(
        UniformDiscretisation1d(points_per_unit, margin, multiple),
        set_conv(1, scale; density=true),
        conv,
        set_conv(1, scale; density=false)
    )
end

convcnp_1d (generic function with 2 methods)

In [119]:
round_odd(x) = Integer(ceil((x - 1) / 2) * 2 + 1)

function kernel_size(receptive_field, points_per_unit, num_layers)
    receptive_points = receptive_field * points_per_unit
    return round_odd(1 + (receptive_points - 1) / num_layers)
end

padding(kernel_size) = Integer(floor(kernel_size / 2))

function get_conv(
    receptive_field,
    num_layers,
    num_channels;
    points_per_unit=64,
    in_channels=2,
    out_channels=1,
    dimensionality=1,
)
    k = kernel_size(receptive_field, points_per_unit, num_layers)
    p = padding(k)
    
    # Repeat the kernel size `dimensionality` many times to construct the convolution kernel.
    ks = ntuple(_ -> k, dimensionality)
    
    # Build layers of the conv net.
    layers = []
    push!(layers, Conv(ks, in_channels=>num_channels, pad=p, relu))
    for i = 1:(num_layers - 2)
        push!(layers, DepthwiseConv(ks, num_channels=>num_channels, pad=p, relu))
    end
    push!(layers, Conv(ks, num_channels=>out_channels, pad=p, relu))
    
    return Chain(layers...)
end

get_conv (generic function with 2 methods)

In [122]:
model = convcnp_1d(get_conv(1.0, 5, 3))

ConvCNP(UniformDiscretisation1d(64.0, 0.1, 1), SetConv{TrackedArray{…,Array{Float64,1}}}([-3.4657359027997265, -3.4657359027997265] (tracked), true), Chain(Conv((15,), 2=>3, relu), DepthwiseConv((15,), 3=>3, relu), DepthwiseConv((15,), 3=>3, relu), DepthwiseConv((15,), 3=>3, relu), Conv((15,), 3=>1, relu)), SetConv{TrackedArray{…,Array{Float64,1}}}([-3.4657359027997265] (tracked), false))

In [124]:
rng = MersenneTwister(2)
x = randn(rng, 5, 1, 2)
y = randn(rng, 5, 1, 2)
t = randn(rng, 13, 1, 2)

model(x, y, t)

Tracked 13×1×2 Array{Float64,3}:
[:, :, 1] =
 0.0023353587010365197
 0.009421277296582764 
 0.020942118548850825 
 5.4561536389380815e-6
 1.3718779582419487e-7
 0.0006910548731020154
 0.0008483649562197064
 0.0005137920800159766
 6.774942520272679e-6 
 4.35416908377211e-8  
 0.09686607064037359  
 0.41040594661470764  
 0.009194023003166501 

[:, :, 2] =
 0.13561144738673922   
 0.01947123478509225   
 4.692895089595109e-7  
 0.01559701320221661   
 2.0419223134600564e-10
 0.041937200775655215  
 1.6467060419001765e-6 
 0.0005038522616987228 
 9.037252345179225e-22 
 0.0005148482299264626 
 0.79518790529436      
 6.153596180833405e-7  
 0.006544258053640553  