In [4]:
using Pkg

# Activate the project in the current directory and install its dependencies.
Pkg.activate(".")
Pkg.instantiate()

[32m[1mActivating[22m[39m environment at `~/.julia/dev/ConvCNP/Project.toml`


In [57]:
using Flux
using Statistics
using Printf
using Flux.Tracker

In [67]:
# Model: two stacked `Dense` layers mapping 2 -> 5 -> 2 features.
# No activation function is passed to either layer.
m = Chain(Dense(2, 5), Dense(5, 2));

# Ground-truth coefficients of the linear map that the model is trained to reproduce.
true_mat = randn(2, 2)

# Training objective: mean squared error between the model output for `x` and the target `y`.
function loss(x, y)
    prediction = m(x)
    return Flux.mse(prediction, y)
end;

In [68]:
"""
    generate_batch(true_mat; batch_size::Int=100)

Generate a batch of data from a `(2, 2)` matrix of coefficients `true_mat` of batch size `batch_size`.

# Returns
- `Array{Tuple{Array{Float64, 2}, Array{Float64, 2}}}`: Batch of data.
"""
function generate_batch(true_mat; batch_size::Int=100)
    data = Array{Tuple{Array{Float64, 2}, Array{Float64, 2}}}(undef, batch_size)

    for i = 1:batch_size
        x = randn(2, 2)
        data[i] = (x, true_mat * x)
    end
    
    return data
end

generate_batch

In [69]:
"""
    test_loss(true_mat; batch_size::Int=10)

Compute the loss of a batch of data from a `(2, 2)` matrix of coefficients `true_mat` of batch size `batch_size`.

# Returns
- `Float64`: Test loss.
"""
function test_loss(true_mat; batch_size::Int=10)
    total_loss = 0.0 
    
    for d = generate_batch(true_mat; batch_size=batch_size)
        total_loss += loss(d...)
    end
    
    return total_loss / batch_size
end


test_loss

In [71]:
# Build the optimiser once, outside the epoch loop. Constructing `ADAM()` inside the loop
# hands `train!` a brand-new optimiser at every epoch, so nothing it accumulates carries
# over from one epoch to the next.
opt = ADAM()

# Perform 10 epochs, each on a freshly generated batch of training data.
for epoch = 1:10
    @printf("Epoch: %d\n", epoch)
    data = generate_batch(true_mat)
    Flux.train!(
        loss,
        params(m),
        data,
        opt,
        # Show the loss of test data at most once per second. The throttled callback is
        # rebuilt for every epoch, so each epoch gets its own reporting window.
        cb=Flux.throttle(() -> @printf("Test loss: %.3e\n", Tracker.data(test_loss(true_mat))), 1)
    )
end

Epoch: 1
Test loss: 3.629e-05
Epoch: 2
Test loss: 2.858e-05
Epoch: 3
Test loss: 2.472e-05
Epoch: 4
Test loss: 2.631e-05
Epoch: 5
Test loss: 1.745e-05
Epoch: 6
Test loss: 2.654e-05
Epoch: 7
Test loss: 2.812e-05
Epoch: 8
Test loss: 2.389e-05
Epoch: 9
Test loss: 2.599e-05
Epoch: 10
Test loss: 5.679e-05
