In [4]:
# Exploring neural-network regression

using Flux
using Random

# Seed the global RNG so repeated runs of this notebook are reproducible.
# NOTE(review): assumes Flux's default weight initializers draw from the global
# RNG, which is what makes the printed losses repeatable — confirm for this Flux version.
Random.seed!(0)

"""
    nn_regression(X, Y, lambda; numiters=40, hidden=10, lr=0.001, printevery=1000)

Fit a fully connected ReLU network mapping rows of `X` to rows of `Y` by
per-sample gradient descent on the regularized squared error

    ‖model(x) - y‖² + lambda * Σₗ ‖Wₗ‖²_F,

where the penalty is the sum of squared entries of every layer's weight matrix
(biases are not penalized).

The network has three hidden `Dense` layers of width `hidden` with `relu`
activations, followed by a linear output layer of width `size(Y, 2)`.

# Arguments
- `X`: inputs, one sample per row; `size(X, 2)` is the input dimension.
- `Y`: targets, one sample per row (a vector has `size(Y, 2) == 1`).
- `lambda`: weight of the squared-weight penalty.

# Keywords
- `numiters`: number of passes (epochs) over the data.
- `hidden`: width of each hidden layer.
- `lr`: learning rate of the `Descent` optimizer.
- `printevery`: print the unregularized training RMSE every `printevery`
  callback invocations (presumably one per optimizer step, i.e. per sample).

# Returns
The trained `Flux.Chain`.

# Throws
- `DimensionMismatch` if `X` and `Y` have different numbers of rows.
"""
function nn_regression(X, Y, lambda; numiters=40, hidden=10, lr=0.001, printevery=1000)
    # `zip` would silently drop unmatched rows, so check the sample counts agree.
    size(X, 1) == size(Y, 1) || throw(DimensionMismatch(
        "X has $(size(X, 1)) rows but Y has $(size(Y, 1)); they must match"))
    d = size(X, 2)
    m = size(Y, 2)

    model = Chain(
        Dense(d, hidden, relu),
        Dense(hidden, hidden, relu),
        Dense(hidden, hidden, relu),
        Dense(hidden, m, identity))

    # One (x, y) pair per row of X and Y.
    data = zip(eachrow(X), eachrow(Y))

    # Squared error for one sample. `sum(abs2, ⋅)` equals the squared 2-norm
    # without depending on `LinearAlgebra.norm`, which this cell does not load.
    loss(x, y) = sum(abs2, model(x) .- y)

    # Sum of squared weight entries over all layers. Kept as an array
    # comprehension over layer indices, the form the original trained with.
    reg() = sum([sum(abs2, model[i].weight) for i in 1:length(model)])

    cost(x, y) = loss(x, y) + lambda * reg()

    opt = Descent(lr)

    # Root-mean-square error given the per-sample *squared* errors.
    rmse(sqerrs) = sqrt(sum(sqerrs) / length(sqerrs))

    # A Ref avoids boxing a captured variable that the callback reassigns.
    calls = Ref(0)
    function callback()
        if calls[] % printevery == 0
            println("Loss: $(rmse([loss(x, y) for (x, y) in data]))")
        end
        calls[] += 1
    end

    ps = Flux.params(model)
    # Explicit epoch loop: equivalent to the deprecated `Flux.@epochs` macro,
    # including its per-epoch "Epoch i" log message.
    for epoch in 1:numiters
        @info "Epoch $epoch"
        Flux.train!(cost, ps, data, opt; cb=callback)
    end
    return model
end

nn_regression (generic function with 1 method)

In [5]:
# Load the course-provided JSON reader and the regression dataset.
include("readclassjson.jl")
data = readclassjson("nn_regression.json")

# Train/test splits: U_* are inputs (rows are samples, as nn_regression expects);
# v_* are the targets (v_test is displayed below as a 1000-element Vector{Float64}).
U_train = data["U_train"]
U_test = data["U_test"]
v_train = data["v_train"]
v_test = data["v_test"]

1000-element Vector{Float64}:
  5.619063082978003
  5.43409993716697
  6.5944813438503935
  6.7284037018898255
  4.827989666209208
  5.406970869153291
  7.64457117258805
  3.7510062926491696
  7.879489390335793
  7.604758850850491
 12.212700338618268
  2.3017074095914385
  5.871629622937523
  ⋮
  4.154572706869771
  6.608893356950651
  9.438660448130161
  5.24610110826202
  3.8811424190941715
  4.045154637100237
  6.601107812749124
  4.377159353976358
  5.761956886896048
  7.695242571955377
  6.557001900223517
  2.738346742946826

In [8]:
using LinearAlgebra

"""
    rms(y, y_hat)

Root-mean-square of the elementwise difference `y_hat - y`, normalized by the
number of rows `size(y, 1)`. The result is symmetric in its two arguments.
"""
function rms(y, y_hat)
    residual = y_hat .- y
    return sqrt(sum(abs2, residual) / size(y, 1))
end
lambda = 1  # weight of the squared-weight penalty passed to nn_regression below

"""
    predictall(model, U)

Apply `model` to each row of `U` and concatenate the per-row outputs into one
vector (row 1's outputs first, then row 2's, and so on). Returns `Float64[]`
when `U` has no rows.
"""
function predictall(model, U)
    preds = [model(u) for u in eachrow(U)]
    # `reduce(vcat, ...)` avoids splatting one argument per row into `vcat`.
    return isempty(preds) ? Float64[] : reduce(vcat, preds)
end


model = nn_regression(U_train, v_train, lambda)
# RMS prediction error on the training set (in-sample) and the held-out test set.
train_rms = let yhat = predictall(model, U_train)
    rms(yhat, v_train)
end
test_rms = let yhat = predictall(model, U_test)
    rms(yhat, v_test)
end

Loss: 4.841715169660015
Loss: 1.8801530396166115
Loss: 2.105426626957191
Loss: 2.122705946719942
Loss: 2.171714469617606
Loss: 2.223052343015288
Loss: 2.3183070964406602
Loss: 2.292327684276827
Loss: 2.275308051837562
Loss: 2.2714534389181185
Loss: 2.2954873972521965
Loss: 2.291306915996526
Loss: 2.2716327802881815
Loss: 2.2737652078054205
Loss: 2.2845821909486155
Loss: 2.291796494134358
Loss: 2.2702749324183573
Loss: 2.2754619503732414
Loss: 2.2792587488363805
Loss: 2.291208925996707
Loss: 2.2697891922729307
Loss: 2.2764379356875017
Loss: 2.2763348499569824
Loss: 2.289779829515993
Loss: 2.2696178554450954
Loss: 2.276842514550623
Loss: 2.274600056215032
Loss: 2.28787150190992
Loss: 2.269557756339213
Loss: 2.276867080475522
Loss: 2.2734952112907463
Loss: 2.2857833516092283
Loss: 2.269536041973917
Loss: 2.276673483089321
Loss: 2.2727430994541815
Loss: 2.2837214583997576
Loss: 2.269527602033994
Loss: 2.2763742615156812
Loss: 2.272201011236125
Loss: 2.2817999907299824
Loss: 2.2695239488808

┌ Info: Epoch 1
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 2
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 3
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 4
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 5
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 6
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 7
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 8
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 9
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 10
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 11
└ @ Main C:\Users\surface\.julia

2.188187332974078

In [10]:
# Print the training RMS, then the test RMS, one per line.
foreach(println, (train_rms, test_rms))

2.269550212363815
2.188187332974078


In [None]:
# Form: y_hat = g^4(g^3(g^2(g^1(x)))), where g^i(z) = act_i.(A_i * z + b_i)
# There are 4 layers (Flux stores each weight matrix as out × in).
# Layer 1: 30 input, 10 output, relu activation, A1: 10×30, b1: 10×1
# Layer 2: 10 input, 10 output, relu activation, A2: 10×10, b2: 10×1
# Layer 3: 10 input, 10 output, relu activation, A3: 10×10, b3: 10×1
# Output layer: 10 input, 1 output, identity activation,
# A4: 1×10, b4: 1
# 541 scalar parameters in total: 310 + 110 + 110 + 11.

In [11]:
# Display the fitted network's layers and parameter counts.
model

Chain(
  Dense(30 => 10, relu),                # 310 parameters
  Dense(10 => 10, relu),                # 110 parameters
  Dense(10 => 10, relu),                # 110 parameters
  Dense(10 => 1),                       # 11 parameters
)                   # Total: 8 arrays, 541 parameters, 2.613 KiB.

In [14]:
# Log-spaced regularization weights from 1e-3 to 1.
lambdas = 10 .^ range(-3, 0, length=10)
# One test-set RMS per weight; sized from `lambdas` so the sweep length is
# defined in a single place.
test_rmss = zeros(length(lambdas))

for (i, lam) in enumerate(lambdas)
    # A loop-local name keeps the sweep from overwriting the global `model`
    # trained in an earlier cell.
    fitted = nn_regression(U_train, v_train, lam)
    test_rmss[i] = rms(predictall(fitted, U_test), v_test)
end

┌ Info: Epoch 1
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 2
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 3
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 4
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 5
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 6
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 7
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 8
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 9
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 10
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 11
└ @ Main C:\Users\surface\.julia

Loss: 6.0548675820324975
Loss: 1.8406131831204395
Loss: 1.8837775309967593
Loss: 1.8098091966578755
Loss: 1.7186090530994296
Loss: 1.7104272563126484
Loss: 1.7797614856363
Loss: 1.7349536426223455
Loss: 1.6618444739924902
Loss: 1.6608883102170149
Loss: 1.732632863687367
Loss: 1.701225411443413
Loss: 1.637282360755988
Loss: 1.634733854956913
Loss: 1.7130676522675627
Loss: 1.6775053468886605
Loss: 1.6065389981863865
Loss: 1.6114060884171881
Loss: 1.6744115458850743
Loss: 1.641382959061765
Loss: 1.5778793189305005
Loss: 1.5707195329943617
Loss: 1.5946104152384402
Loss: 1.6051149687971134
Loss: 1.530865958580635
Loss: 1.5097778000117623
Loss: 1.5308312612804698
Loss: 1.5750918639875653
Loss: 1.5016294151107987
Loss: 1.4778437117338499
Loss: 1.4849641813968926
Loss: 1.5702141526034792
Loss: 1.4596694423709657
Loss: 1.4566831052332936
Loss: 1.4555672258097374
Loss: 1.5561933386543971
Loss: 1.4554030092178036
Loss: 1.4339121148945786
Loss: 1.4313879606421056
Loss: 1.5279656025194126
Loss: 1.4

┌ Info: Epoch 24
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 25
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 26
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 27
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 28
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 29
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 30
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 31
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 32
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 33
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 34
└ @ Main C:\Users\surfa


Loss: 1.0758540734972002
Loss: 1.0788081745602032
Loss: 1.02303220095074
Loss: 1.0427505058140598
Loss: 1.0748476829308156
Loss: 1.0759460817821067
Loss: 1.0233066676770346
Loss: 1.0460723613232041
Loss: 1.0483811395489626
Loss: 1.0620351232198069
Loss: 1.0045431481110507
Loss: 1.0282469571129016
Loss: 1.0396300534763534
Loss: 1.0538517771413791
Loss: 0.9944591525419928
Loss: 1.020694709695837
Loss: 1.0292309544066345
Loss: 1.0310981708445048
Loss: 0.9929303231746581
Loss: 1.0443571575854007
Loss: 1.0258148866447796
Loss: 1.0226324490814653
Loss: 0.9904410077066242
Loss: 1.0280034470818702
Loss: 1.0045318395763168
Loss: 1.020301349680202
Loss: 0.9856773006648375
Loss: 1.0167274042789467
Loss: 0.9973923973333059
Loss: 1.0141597940140588
Loss: 0.988569293135713
Loss: 1.0109698589606546
Loss: 0.9940715601672822
Loss: 1.0115740426679822
Loss: 0.9769863503977033
Loss: 1.007268188887393
Loss: 0.9939613235536477
Loss: 0.9945845860528549
Loss: 0.9761609202161396
Loss: 0.9964480149164502
Loss:

┌ Info: Epoch 7
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 8
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 9
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 10
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 11
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 12
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 13
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 14
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 15
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 16
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 17
└ @ Main C:\Users\surface\


Loss: 1.6475850069554019
Loss: 1.6442749223613853
Loss: 1.7403378379891856
Loss: 1.7190917160078896
Loss: 1.6471949586985115
Loss: 1.6404612071445046
Loss: 1.739292401285152
Loss: 1.7132138489352247
Loss: 1.645033810743161
Loss: 1.6403881247139291
Loss: 1.7278797954250522
Loss: 1.7142010586571632
Loss: 1.6446362862609611
Loss: 1.639374144192217
Loss: 1.7306993936209525
Loss: 1.7168844114728175
Loss: 1.64428873396068
Loss: 1.6361993332897176
Loss: 1.721298226189103
Loss: 1.7138550120116165
Loss: 1.6445273268797729
Loss: 1.6346984926065955
Loss: 1.7237158926219134
Loss: 1.7123630948631678
Loss: 1.6436409218421173
Loss: 1.6318941182226274
Loss: 1.718387857847863
Loss: 1.7063121803777175
Loss: 1.6449654810167658
Loss: 1.6324612813979147
Loss: 1.7208641386832235
Loss: 1.7073964247641107
Loss: 1.6452274401412668
Loss: 1.6315727430940348
Loss: 1.723579072292973
Loss: 1.7043990129890858
Loss: 1.6413223690343635
Loss: 1.6316221766048593
Loss: 1.7089562693581668
Loss: 1.7017796904831495
Loss: 1

Loss: 1.8136514463157878
Loss: 1.8091596569040358
Loss: 1.7851750966373932
Loss: 1.7629183259234522
Loss: 1.814959404644531
Loss: 1.8088402956420395
Loss: 1.785534352166524
Loss: 1.7634062782052864
Loss: 1.814828076926232
Loss: 1.80961777990122
Loss: 1.7855406191150072
Loss: 1.7631238472573232
Loss: 1.8127926516475055
Loss: 1.8085239982204617
Loss: 1.786204950016244
Loss: 1.7627642528233147
Loss: 1.8124417778468511
Loss: 1.8072647792158485
Loss: 1.7860563646109873
Loss: 1.761944707340539
Loss: 1.8117160732434832
Loss: 1.8082470520286185
Loss: 1.7864637972324928
Loss: 1.763430972607808
Loss: 1.812360724255816
Loss: 1.8068705624098476
Loss: 1.7860614657825338
Loss: 1.763183428352864
Loss: 1.8091366334760142
Loss: 1.8079647857473364
Loss: 1.7857485272271423
Loss: 1.7625705911362424
Loss: 1.811095970528342
Loss: 1.8065865251348348
Loss: 1.7862211469130342
Loss: 1.7631671072190742
Loss: 1.812176229825011
Loss: 1.8059815428031178
Loss: 5.594910868060176
Loss: 1.7912405478099207
Loss: 1.91236

┌ Info: Epoch 30
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 31
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 32
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 33
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 34
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 35
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 36
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 37
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 38
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 39
└ @ Main C:\Users\surface\.julia\packages\Flux\6Q5r4\src\optimise\train.jl:154
┌ Info: Epoch 40
└ @ Main C:\Users\surfa

In [15]:
# Report the lowest test RMS over the sweep and the regularization weight that
# achieved it (first occurrence on ties).
let result = findmin(test_rmss)
    best_rms, best_idx = result
    println(best_rms)
    println(lambdas[best_idx])
end

1.0917593852676348
0.004641588833612777
