In [1]:
"""
Author: Sebastian Vendt, University of Ulm

This script implements the four different neural networks proposed in the paper 
Recurrent Convolutional Neural Networks: A Better Model of Biological Object Recognition
of Courtney J. Spoerer et al. 

BNet:  the bottom up network with two hidden convolutional layers
BLNet: the BNet including lateral connections within the hidden layers
BTNet: the BNet including top down connections from the second hidden layer to the first
BLTNet:the BNet including top down and lateral connections 

"""

using Flux, Statistics
using Flux: crossentropy, onecold
using Printf, BSON
import LinearAlgebra: norm
using NNlib
using FeedbackConvNets

include("./dataManager.jl")
using .dataManager: make_batch

using Base
norm(x::TrackedArray{T}) where T = sqrt(sum(abs2.(x)) + eps(T)) 

######################
# PARAMETERS
######################
const batch_size = 100
const momentum = 0.9f0
const lambda = 0.0005f0
init_learning_rate = 0.01f0
learning_rate = init_learning_rate
const epochs = 5
const decay_rate = 0.1f0
const decay_step = 40
# number of timesteps the network is unrolled
const time_steps = 4
# end of PARAMETERS

hidden = Dict(
    "l1" => zeros(Float32, 32, 32, 32, 10),
    "l2" => zeros(Float32, 16, 16, 32, 10)
    )

function adapt_learnrate(epoch_idx)
    return init_learning_rate * decay_rate^(epoch_idx / decay_step)
end

function trainReccurentNet(reccurent_model, train_set, test_set)
    function loss(x, y)
        loss_val = 0.0f0
        for i in 1:time_steps
            loss_val += crossentropy(reccurent_model(x), y) + lambda * sum(norm, params(reccurent_model))
        end
        Flux.reset!(reccurent_model)
        return loss_val
    end
    
    function accuracy(test_set::Tuple)
        for i in 1:time_steps-1
            y_hat = reccurent_model(test_set[1])
        end
        acc = mean(onecold(reccurent_model(test_set[1])) .== onecold(test_set[2]))
        Flux.reset!(reccurent_model)
        return acc
    end
    
    opt = Momentum(learning_rate, momentum)
    for i in 1:epochs
        Flux.train!(loss, params(reccurent_model), train_set, opt)
        opt.eta = adapt_learnrate(i)
        acc = accuracy(test_set)
        @printf("Accuracy %f in epoch %d\n", acc, i)
        flush(Base.stdout)
    end
    @printf("final accuracy: %d\n", accuracy(test_set))
end

function trainFeedforwardNet(feedforward_model, train_set, test_set)
    function accuracy(test_set::Tuple)
        return mean(onecold(feedforward_model(test_set[1])) .== onecold(test_set[2]))
    end
    
    opt = Momentum(learning_rate, momentum)
    for i in 1:epochs
        Flux.train!((x, y) -> crossentropy(feedforward_model(x), y) + lambda * sum(norm, params(feedforward_model)), 
            params(feedforward_model), train_set, opt)
        opt.eta = adapt_learnrate(i)
        acc = accuracy(test_set)
        @printf("Accuracy %f in epoch %d\n", acc, i)
        flush(Base.stdout)
    end
    @printf("final accuracy: %d\n", accuracy(test_set))
end

@printf("Constructing models...\n")
BModel = spoerer_model_b(Float32, inputsize=(32, 32))
BKModel = spoerer_model_bk(Float32, inputsize=(32, 32))
BFModel = spoerer_model_bf(Float32, inputsize=(32, 32))
BLChain = spoerer_model_bl(Float32, inputsize=(32, 32), kernel=(3, 3), features=32)
BTChain = spoerer_model_bt(Float32, inputsize=(32, 32), kernel=(3, 3), features=32)
BLTChain = spoerer_model_bt(Float32, inputsize=(32, 32), kernel=(3, 3), features=32)

BLModel = Flux.Recur(BLChain, hidden)
BTModel = Flux.Recur(BTChain, hidden)
BLTModel = Flux.Recur(BLTChain, hidden)


train_set, mean_img, std_img = make_batch("../digitclutter/digitclutter/50_4digits.mat", batch_size=10)
# test_set needs to have the same batchsize as the train_set due to model state init
test_set, temp1, temp2 = make_batch("../digitclutter/digitclutter/50_4digits.mat", batch_size=10)

@printf("loaded %d batches of size %d\n", length(train_set), size(train_set[1][1], 4))

@info("Training BModel\n")
# trainFeedforwardNet(BModel, train_set, test_set[1])
@info("Training BKModel\n")
# trainFeedforwardNet(BKModel, train_set, test_set[1])
@info("Training BFModel\n")
# trainFeedforwardNet(BFModel, train_set, test_set[1])

@info("Training BLModel\n")
trainReccurentNet(BLModel, train_set, test_set[1])
@info("Training BTModel\n")
trainReccurentNet(BTModel, train_set, test_set[1])
@info("Training BLTModel\n")
trainReccurentNet(BLTModel, train_set, test_set[1])




Constructing models...
Reading .mat file form source ../digitclutter/digitclutter/50_4digits.mat
calculate mean and standart deviation of dataset
Creating batches
Reading .mat file form source ../digitclutter/digitclutter/50_4digits.mat
calculate mean and standart deviation of dataset
Creating batches
loaded 5 batches of size 10


┌ Info: Training BModel
└ @ Main In[1]:118


Accuracy 0.200000 in epoch 1
Accuracy 0.000000 in epoch 2
Accuracy 0.000000 in epoch 3
Accuracy 0.000000 in epoch 4
Accuracy 0.000000 in epoch 5
final accuracy: 0


┌ Info: Training BKModel
└ @ Main In[1]:120


Accuracy 0.100000 in epoch 1
Accuracy 0.000000 in epoch 2
Accuracy 0.000000 in epoch 3
Accuracy 0.000000 in epoch 4
Accuracy 0.000000 in epoch 5
final accuracy: 0


┌ Info: Training BFModel
└ @ Main In[1]:122


Accuracy 0.400000 in epoch 1
Accuracy 0.000000 in epoch 2
Accuracy 0.000000 in epoch 3
Accuracy 0.000000 in epoch 4
Accuracy 0.000000 in epoch 5
final accuracy: 0


┌ Info: Training BLModel
└ @ Main In[1]:125


Accuracy 0.100000 in epoch 1
Accuracy 0.000000 in epoch 2
Accuracy 0.000000 in epoch 3
Accuracy 0.000000 in epoch 4
Accuracy 0.000000 in epoch 5
final accuracy: 0


┌ Info: Training BTModel
└ @ Main In[1]:127


Accuracy 0.000000 in epoch 1
Accuracy 0.000000 in epoch 2
Accuracy 0.000000 in epoch 3
Accuracy 0.000000 in epoch 4
Accuracy 0.000000 in epoch 5
final accuracy: 0


┌ Info: Training BLTModel
└ @ Main In[1]:129


Accuracy 0.000000 in epoch 1
Accuracy 0.000000 in epoch 2
Accuracy 0.000000 in epoch 3
Accuracy 0.000000 in epoch 4
Accuracy 0.000000 in epoch 5
final accuracy: 0
