In [1]:
using Pkg

# Activate the project rooted at the current working directory, then install
# any of its declared dependencies that are not yet available.
Pkg.activate("."); Pkg.instantiate()

[32m[1mActivating[22m[39m environment at `~/.julia/dev/ConvCNP/Project.toml`


In [2]:
using Flux
using Flux.Tracker
using Statistics
using Printf

In [3]:
"""
    ConvDeepSet{T}

Encoder layer that maps observations `(x, y)` onto a discretisation grid with an RBF-weighted
sum. It holds one log-scale per encoded channel: the data channels plus the density channel
that the layer prepends.

# Fields
- `log_scales::T`: Log of the per-channel kernel scales (trainable when created with `param`).
"""
struct ConvDeepSet{T}
    log_scales::T
end

# Register the layer with Flux so that `Flux.params` (of the layer, or of a model containing
# it) collects `log_scales`; without this the scales are never updated during training.
Flux.@treelike ConvDeepSet

In [55]:
"""
    conv_deep_set(in_channels::Integer, scale::Real)

Construct a `ConvDeepSet` for `in_channels` data channels whose kernel scales all start at
`scale`.

One extra scale is allocated for the density channel that the layer prepends to the data,
so the layer holds `in_channels + 1` trainable log-scales.

# Throws
- `ArgumentError`: if `in_channels` is negative.
- `DomainError`: if `scale` is not strictly positive (its logarithm would be `-Inf` or
  undefined).
"""
function conv_deep_set(in_channels::Integer, scale::Real)
    in_channels >= 0 ||
        throw(ArgumentError("in_channels must be non-negative, got $in_channels"))
    scale > 0 || throw(DomainError(scale, "scale must be strictly positive"))

    # Add one to `in_channels` to account for the density channel.
    return ConvDeepSet(param(fill(log(scale), in_channels + 1)))
end

conv_deep_set (generic function with 2 methods)

In [56]:
"""
    compute_dists2(x, t)

Return the element-wise squared differences between `x` and `t`, with `t` having its
first two dimensions swapped before broadcasting. For `x` of size `(n, 1, b)` and `t`
of size `(m, 1, 1)` the result has size `(n, m, b)`.
"""
function compute_dists2(x, t)
    t_row = permutedims(t, (2, 1, 3))
    diffs = x .- t_row
    return diffs .^ 2
end
"""
    rbf(dists2)

Evaluate the exponentiated-quadratic (RBF) kernel `exp(-dists2 / 2)` at a scaled squared
distance. Dividing by the integer `2` (rather than multiplying by `0.5`) keeps `Float32`
inputs in `Float32` instead of promoting them to `Float64`.
"""
rbf(dists2) = exp(-dists2 / 2)

rbf (generic function with 1 method)

In [57]:
"""
    (layer::ConvDeepSet)(x::AbstractArray, y::AbstractArray, t::AbstractArray)

# Arguments
- `x::AbstractArray`: Locations of observed values of shape `(n, batch)`.
- `y::AbstractArray`: Observed values of shape `(n, channels, batch)`.
- `t::AbstractArray`: Discretisation locations of shape `(m)`.

"""
function (layer::ConvDeepSet)(x::AbstractArray, y::AbstractArray, t::AbstractArray)
    # Add extra dimension to `x`.
    # Shape: `(n, 1, batch)`.
    x = reshape(x, size(x, 1), 1, size(x, 2))
    
    # Add extra dimensions to `t`.
    # Shape: `(m, 1, batch)`.
    t = reshape(t, size(t, 1), 1, 1)
    
    # Shape: `(n, m, batch)`.
    dists2 = compute_dists2(x, t)
    
    # Add channel dimension.
    # Shape: `(n, m, channels, batch)`.
    dists2 = reshape(dists2, size(dists2)[1:2]..., 1, size(dists2)[3])
    
    # Apply length scales.
    # Shape: `(n, m, channels, batch)`.
    scales = reshape(exp.(layer.log_scales), 1, 1, length(layer.log_scales), 1)
    dists2 = dists2 ./ scales
    
    # Apply RBF to compute weights.
    weights = rbf.(dists2)
    
    # Add density channel to `y`.
    # Shape: `(n, channels + 1, batch)`.
    density = ones(size(y, 1), 1, size(y, 3))  # TODO: How to do this?
    y = cat(density, y; dims=2)
    
    # Multiply with weights and sum.
    # Shape: `(m, channels + 1, batch)`.
    y = reshape(y, size(y, 1), 1, size(y)[2:end]...)
    enc = dropdims(sum(y .* weights; dims=1); dims=1)
    
    # Divide by the density channel.
    density = enc[:, 1:1, :]
    others = enc[:, 2:end, :] ./ density
    enc = cat(density, others; dims=2)
    
    return enc
end

ConvDeepSet

In [85]:
# Toy batch: 5 observation locations for each of 2 batch elements, 2 data channels per
# observation, and a 12-point discretisation grid shared by the whole batch.
x, y, t = randn(5, 2), randn(5, 2, 2), randn(12)

# Encoder for 2 data channels with every kernel scale initialised to 0.1.
layer = conv_deep_set(2, 0.1)

layer(x, y, t)

Tracked 12×3×2 Array{Float64,3}:
[:, :, 1] =
 1.2128      -0.0406356  1.06981 
 3.43635     -0.175965   0.796478
 0.996007     0.0716996  1.36829 
 1.47365e-5   0.205988   1.77761 
 1.88523     -0.396243   0.467256
 0.00327371   0.156126   1.59973 
 0.118334     0.0765904  1.36125 
 0.00519613   0.149082   1.57677 
 0.0211232    0.122962   1.4951  
 1.63796     -0.409831   0.451599
 3.4462      -0.1768     0.794971
 0.0821575    0.0881809  1.39353 

[:, :, 2] =
 3.09767     -0.521656   0.297971 
 2.53938     -0.477718   0.060935 
 7.66396e-6   0.476543   0.238097 
 0.0124898   -0.49498    0.418953 
 0.36762     -0.0709685  0.0424042
 0.332386    -0.439957   0.496159 
 1.74514     -0.445123   0.489349 
 0.424903    -0.436792   0.500974 
 0.851192    -0.432112   0.508991 
 0.294283    -0.0414143  0.0504719
 2.53132     -0.476744   0.0600217
 1.52223     -0.439808   0.498029 

In [86]:
# Encode the toy batch with a fresh layer and feed the grid encoding through a small 1-D
# CNN. The first convolution takes 3 input channels: the 2 data channels of `y` plus the
# density channel that the encoder prepends.
layer = conv_deep_set(2, 0.1)
net = Chain(
    Conv((5,), 3 => 8, relu; pad = (2,), stride = 1),
    DepthwiseConv((5,), 8 => 8, relu; pad = (2,), stride = 1),
    # NOTE(review): this final `relu` forces every network output to be non-negative.
    Conv((5,), 8 => 1, relu; pad = (2,), stride = 1),
)
layer(x, y, t) |> net

Tracked 12×1×2 Array{Float64,3}:
[:, :, 1] =
 0.0                
 0.0                
 0.0                
 0.5631061866418412 
 0.9921545141867242 
 0.7439694211686797 
 0.22988506158596775
 0.8737294838335093 
 0.0                
 0.05333134061606415
 0.03887617319892675
 0.2341616837918043 

[:, :, 2] =
 0.0                
 0.5902604401254128 
 0.6619833398864128 
 0.28449830644621166
 0.9720939221802982 
 0.0                
 0.15543431320342005
 0.6122817874348238 
 0.06290598215037976
 0.0                
 0.0                
 0.29796178800778467