-
Notifications
You must be signed in to change notification settings - Fork 13
/
Nn.jl
1139 lines (935 loc) · 40.2 KB
/
Nn.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"Part of [BetaML](https://github.com/sylvaticus/BetaML.jl). Licence is MIT."
"""
BetaML.Nn module
Implement the functionality required to define an artificial Neural Network, train it with data, forecast data and assess its performance.
Common types of layers and optimisation algorithms are already provided, but you can define your own by subtyping the `AbstractLayer` and `OptimisationAlgorithm` abstract types, respectively.
The module provides the following types or functions. Use `?[type or function]` to access their full signature and detailed documentation:
# Model definition:
- `DenseLayer`: Classical feed-forward layer with user-defined activation function
- `DenseNoBiasLayer`: Classical layer without the bias parameter
- `VectorFunctionLayer`: Parameterless layer whose activation function runs over the ensemble of its nodes rather than on each one individually
- `NeuralNetworkEstimator`: Build the chained network and define a cost function
Each layer can use a default activation function, one of the functions provided in the `Utils` module (`relu`, `tanh`, `softmax`,...) or you can specify your own function. The derivative of the activation function can optionally be provided, in which case training will be quicker, although this difference tends to vanish with bigger datasets.
You can alternatively implement your own layer by defining a new type as a subtype of the abstract type `AbstractLayer`. Each user-implemented layer must define the following methods:
- A suitable constructor
- `forward(layer,x)`
- `backward(layer,x,next_gradient)`
- `get_params(layer)`
- `get_gradient(layer,x,next_gradient)`
- `set_params!(layer,w)`
- `size(layer)`
# Model fitting:
- `fit!(nn,X,Y)`: fitting function
- `fitting_info(nn)`: Default callback function during fitting
- `SGD`: The classical optimisation algorithm
- `ADAM`: A faster moment-based optimisation algorithm
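To define your own optimisation algorithm, define a subtype of `OptimisationAlgorithm` and implement for it the method `single_update!(θ,▽,opt_alg::YourAlgorithm;n_epoch,n_batch,n_batches,xbatch,ybatch)` (the keyword-only `single_update!(θ,▽;opt_alg,...)` forwards to it) and, if needed, an `init_optalg!(⋅)` method specific to it.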
# Model predictions and assessment:
- `predict(nn)` or `predict(nn,X)`: Return the output given the data
While high-level functions operating on the dataset expect it to be in the standard format (n_records × n_dimensions matrices) it is customary to represent the chain of a neural network as a flow of column vectors, so all low-level operations (operating on a single datapoint) expect both the input and the output as a column vector.
"""
module Nn
#import Base.Threads.@spawn
using Random, LinearAlgebra, StaticArrays, LoopVectorization, Zygote, ProgressMeter, Reexport, DocStringExtensions
import Distributions: Uniform
using ForceImport
@force using ..Api
@force using ..Utils
import Base.size
import Base: +, -, *, /, sum, sqrt
import Base.show
# module own functions
export AbstractLayer, forward, backward, get_params, get_gradient, set_params!, size, preprocess! # layer API
#export forward_old, backward_old, get_gradient_old
export DenseLayer, DenseNoBiasLayer, VectorFunctionLayer, ScalarFunctionLayer # Available layers
export ConvLayer, ReshaperLayer, PoolingLayer
export init_optalg!, single_update! # Optimizers API
export SGD,ADAM, DebugOptAlg # Available optimizers
export Learnable, fitting_info, NeuralNetworkEstimator, NNHyperParametersSet, NeuralNetworkEstimatorOptionsSet # NN API
# export get_nparams, NN, buildNetwork, predict, loss, train!, getindex, show # old
# for working on gradient as e.g [([1.0 2.0; 3.0 4.0], [1.0,2.0,3.0]),([1.0,2.0,3.0],1.0)]
"""
    Learnable(data)

Structure representing the learnable parameters of a layer or its gradient.

The learnable parameters of a layer are given in the form of an N-tuple of `Array{Float64,N2}`,
where N2 can change from one element to another (e.g. a layer can have its first parameter being
a matrix and its second one being a vector).

We wrap the tuple in its own structure partly for some efficiency gain, but above all to define
standard mathematical operations on the gradients without doing "type piracy" with respect to
Base tuples. All the operations below work element-wise on the wrapped arrays and return a new
`Learnable`.
"""
mutable struct Learnable
    data::Tuple{Vararg{Array{Float64,N} where N}}
    function Learnable(data)
        return new(data)
    end
end

# Sum of any number of `Learnable` objects, parameter array by parameter array.
# The operands are not modified: `+=` rebinds each entry of `values` to a freshly allocated array.
function +(items::Learnable...)
    values = collect(items[1].data)
    N = length(values)
    @inbounds for item in items[2:end]
        @inbounds @simd for n in 1:N
            values[n] += item.data[n]
        end
    end
    return Learnable(Tuple(values))
end
sum(items::Learnable...) = +(items...)

# Unary negation. Without this method a single-argument call `-x` is caught by the variadic
# `-(items::Learnable...)` below, which would silently return its operand unchanged.
-(item::Learnable) = Learnable(map(-, item.data))

# Difference `items[1] - items[2] - ...`, parameter array by parameter array.
function -(items::Learnable...)
    values = collect(items[1].data)
    N = length(values)
    @inbounds for item in items[2:end]
        @inbounds @simd for n in 1:N
            values[n] -= item.data[n]
        end
    end
    return Learnable(Tuple(values))
end

# Element-wise (Hadamard) product of `Learnable` objects, parameter array by parameter array.
function *(items::Learnable...)
    values = collect(items[1].data)
    N = length(values)
    @inbounds for item in items[2:end]
        @inbounds @simd for n in 1:N
            values[n] = values[n] .* item.data[n]
        end
    end
    return Learnable(Tuple(values))
end

# Operations between a `Learnable` and a scalar, applied to every element of every array
+(item::Learnable,sc::Number) = Learnable(Tuple([item.data[i] .+ sc for i in 1:length(item.data)]))
+(sc::Number, item::Learnable) = +(item,sc)
-(item::Learnable,sc::Number) = Learnable(Tuple([item.data[i] .- sc for i in 1:length(item.data)]))
-(sc::Number, item::Learnable) = (-(item,sc)) * -1
*(item::Learnable,sc::Number) = Learnable(item.data .* sc)
*(sc::Number, item::Learnable) = Learnable(sc .* item.data)
/(item::Learnable,sc::Number) = Learnable(item.data ./ sc)
/(sc::Number,item::Learnable) = Learnable(Tuple([sc ./ item.data[i] for i in 1:length(item.data)]))
sqrt(item::Learnable) = Learnable(Tuple([sqrt.(item.data[i]) for i in 1:length(item.data)]))
# Element-wise division between two `Learnable` objects with the same structure
/(item1::Learnable,item2::Learnable) = Learnable(Tuple([item1.data[i] ./ item2.data[i] for i in 1:length(item1.data)]))
#=
# not needed ??
function Base.iterate(iter::Learnable, state=(iter.data[1], 1))
element, count = state
if count > length(iter)
return nothing
elseif count == length(iter)
return (element, (iter.data[count], count + 1))
end
return (element, (iter.data[count+1], count + 1))
end
Base.length(iter::Learnable) = length(iter.data)
#Base.eltype(iter::Learnable) = Int
=#
## Skeleton for the layer functionality.
# The actual layer implementations are in the `default_layers/*.jl` files included below.
"""
    AbstractLayer

Abstract supertype of all neural network layers.

A user-defined layer subtypes it and implements `forward`, `backward`, `get_params`,
`get_gradient`, `set_params!` and `size` (see their docstrings). `preprocess!` can optionally
be overridden as well.
"""
abstract type AbstractLayer end
# NOTE(review): presumably the supertype of layers carrying recurrent state (e.g. the `RNNLayer`
# included below) — confirm against default_layers/RNNLayer.jl
abstract type RecursiveLayer <: AbstractLayer end
include("default_layers/DenseLayer.jl")
include("default_layers/DenseNoBiasLayer.jl")
include("default_layers/VectorFunctionLayer.jl")
include("default_layers/ScalarFunctionLayer.jl")
include("default_layers/ConvLayer.jl")
include("default_layers/PoolingLayer.jl")
include("default_layers/ReshaperLayer.jl")
include("default_layers/RNNLayer.jl")
"""
    forward(layer, x)

Compute the output of `layer` for the input `x` (forward pass).

Every concrete layer type must provide its own method; this fallback only raises an error.

# Parameters:
* `layer`: Worker layer
* `x`: Input to the layer

# Return:
- An Array{T,1} of the prediction (even for a scalar)
"""
forward(layer::AbstractLayer, x) =
    error("Not implemented for this kind of layer. Please implement `forward(layer,x)`.")
"""
    backward(layer, x, next_gradient)

Backpropagate through `layer`: compute the derivative of the loss with respect to the
layer inputs.

Every concrete layer type must provide its own method; this fallback only raises an error.

# Parameters:
* `layer`: Worker layer
* `x`: Input to the layer
* `next_gradient`: Derivative of the overall loss with respect to the input of the next layer (output of this layer)

# Return:
* The evaluated gradient of the loss with respect to this layer's inputs
"""
backward(layer::AbstractLayer, x, next_gradient) =
    error("Not implemented for this kind of layer. Please implement `backward(layer,x,next_gradient)`.")
"""
    get_params(layer)

Return the current value of the layer's trainable parameters.

Every concrete layer type must provide its own method; this fallback only raises an error.

# Parameters:
* `layer`: Worker layer

# Return:
* The current value of the layer's trainable parameters as a tuple of matrices. It is up to you
  to decide how to organise this tuple, as long as you are consistent with the `get_gradient()`
  and `set_params()` functions. Note that starting from BetaML 0.2.2 this tuple needs to be
  wrapped in its `Learnable` type.
"""
get_params(layer::AbstractLayer) =
    error("Not implemented for this kind of layer. Please implement `get_params(layer)`.")
"""
    get_gradient(layer, x, next_gradient)

Backpropagate through `layer`: compute the derivative of the loss with respect to the layer's
trainable parameters (weights).

Every concrete layer type with parameters must provide its own method; this fallback only
raises an error.

# Parameters:
* `layer`: Worker layer
* `x`: Input to the layer
* `next_gradient`: Derivative of the overall loss with respect to the input of the next layer (output of this layer)

# Return:
* The evaluated gradient of the loss with respect to this layer's trainable parameters as a
  tuple of matrices. It is up to you to decide how to organise this tuple, as long as you are
  consistent with the `get_params()` and `set_params()` functions. Note that starting from
  BetaML 0.2.2 this tuple needs to be wrapped in its `Learnable` type.
"""
get_gradient(layer::AbstractLayer, x, next_gradient) =
    error("Not implemented for this kind of layer. Please implement `get_gradient(layer,x,next_gradient)`.")
"""
    set_params!(layer, w)

Overwrite the trainable parameters of `layer` with the values in `w`.

Every concrete layer type with parameters must provide its own method; this fallback only
raises an error.

# Parameters:
* `layer`: Worker layer
* `w`: The new parameters to set (Learnable)

# Notes:
* The format of the tuple wrapped by `Learnable` must be consistent with those of the
  `get_params()` and `get_gradient()` functions.
"""
set_params!(layer::AbstractLayer, w) =
    error("Not implemented for this kind of layer. Please implement `set_params!(layer,w)`.")
"""
    size(layer)

Return the size of the layer as (size in input, size in output), both given as tuples.

Every concrete layer type must provide its own method; this fallback only raises an error.

# Notes:
* You need to use `import Base.size` before defining this function for your layer
"""
size(layer::AbstractLayer) =
    error("Not implemented for this kind of layer. Please implement `size(layer)`.")
"""
    get_nparams(layer)

Return the number of trainable parameters of a layer.

It doesn't need to be implemented by each layer type, as it relies on `get_params()`. The result
is the total number of elements of the arrays wrapped by the returned `Learnable`.
"""
function get_nparams(layer::AbstractLayer)
    nP = 0
    for p in get_params(layer).data
        # `length` rather than `*(size(p)...)`: a 0-dimensional parameter array has
        # `size(p) == ()` and `*()` has no method, while its `length` is correctly 1.
        nP += length(p)
    end
    return nP
end
"""
$(TYPEDSIGNATURES)

Preprocess the layer using only information known at layer creation (i.e. no data info used).

Some layers use this hook to cache computations that don't require the data. It is called at
the beginning of `fit!`; for example, `ConvLayer` uses it to store the ids of the convolution.
This default method does nothing.

# Notes:
- as it doesn't depend on data, it is not reset by `reset!`
"""
preprocess!(layer::AbstractLayer) = nothing
# ------------------------------------------------------------------------------
# NN-related functions
"""
    NN

Low-level representation of a Neural Network. Use the model `NeuralNetworkEstimator` instead.

# Fields:
* `layers`: Array of layer objects
* `cf`: Cost function
* `dcf`: Derivative of the cost function (`nothing` to compute it by automatic differentiation)
* `trained`: Control flag for trained networks
* `name`: Name of the network (shown in the training messages)
"""
mutable struct NN
    layers::Vector{AbstractLayer}
    cf::Function
    dcf::Union{Function,Nothing}
    trained::Bool
    name::String
end
"""
    buildNetwork(layers, cf; dcf, name)

Create a new (untrained) feed-forward neural network from its layers and cost function.

!!! warning
    This function has been de-exported in BetaML 0.9.
    Use the model [`NeuralNetworkEstimator`](@ref) instead.

# Parameters:
* `layers`: Array of layer objects
* `cf`: Cost function
* `dcf`: Derivative of the cost function [def: the known derivative of `cf`, if any, otherwise `nothing`]
* `name`: Name of the network [def: "Neural Network"]

# Notes:
* Even if the network ends with a single output node, the cost function and its derivative should always expect y and ŷ as column vectors.
"""
buildNetwork(layers, cf; dcf=match_known_derivatives(cf), name="Neural Network") =
    NN(layers, cf, dcf, false, name)
"""
    predict(nn::NN, x)

Low-level network predictions. Use instead `predict(m::NeuralNetworkEstimator)`.

Each record (row) of `x` is passed through all the layers in turn; the outputs are collected
as the rows of an (n × d_out) matrix.

# Parameters:
* `nn`: Worker network
* `x`: Input to the network (n × d)
"""
function predict(nn::NN, x)
    nrecords = size(x)[1]
    outsize = size(nn.layers[end])[2]
    if length(outsize) != 1
        error("The last NN layer should always ve a single dimension vector. Eventually use `ReshaperLayer` to reshape its output as a vector.")
    end
    ŷ = zeros(nrecords, outsize[1])
    for r in 1:nrecords
        h = selectdim(x, 1, r) # view of the r-th record
        for layer in nn.layers
            h = forward(layer, h)
        end
        ŷ[r, :] = h
    end
    return ŷ
end
"""
    loss(fnn, x, y)

Low level function that computes the average network loss on a test set (or on a single
(1 × d) data point).

# Parameters:
* `fnn`: Worker network
* `x`: Input to the network (n) or (n x d)
* `y`: Label input (n) or (n x d)

# Return:
* The average over the records of `nn.cf(y[i,:], ŷᵢ)`
"""
function loss(nn::NN,x,y)
    y = makematrix(y)
    # Only the number of records is needed. Destructuring `(n,d) = size(x)` would throw
    # for a vector `x`, even though (n) inputs are documented as supported.
    n = size(x,1)
    ϵ = 0.0
    for i in 1:n
        ŷ = predict(nn,x[i,:]')[1,:]
        ϵ += nn.cf(y[i,:],ŷ)
    end
    return ϵ/n
end
"""
    get_params(nn)

Retrieve the current trainable parameters (weights) of the whole network.

# Parameters:
* `nn`: Worker network

# Notes:
* The output is a vector with one element per layer, each being what `get_params(layer)`
  returns for that layer
"""
@inline get_params(nn::NN) = map(get_params, nn.layers)
"""
    get_gradient(nn, x, y)

Low level function that retrieves the current gradient of the weights (i.e. the derivative of
the cost with respect to the weights) for a single record. Unexported in BetaML >= v0.9.

# Parameters:
* `nn`: Worker network
* `x`: Input to the network (d,1)
* `y`: Label input (d,1)

# Notes:
* The output is a vector with one `Learnable` per layer, holding the gradient of the cost with
  respect to that layer's parameters
* When `nn.dcf` is `nothing`, the derivative of the cost is computed by automatic
  differentiation (Zygote)
"""
function get_gradient(nn::NN,x::Union{T,AbstractArray{T,N1}},y::Union{T2,AbstractArray{T2,N2}}) where { T <: Number, T2 <: Number, N1, N2}
    nLayers = length(nn.layers)

    # Step 1: Forward pass. forwardStack[i] is the input of layer i, forwardStack[end] the network output
    forwardStack = Vector{Array{Float64}}(undef,nLayers+1)
    forwardStack[1] = x
    @inbounds for (i,l) in enumerate(nn.layers)
        forwardStack[i+1] = forward(l,forwardStack[i])
    end

    # Step 2: Backpropagation pass. backwardStack[i] is the derivative of the cost with respect to forwardStack[i]
    backwardStack = Vector{Array{Float64}}(undef,nLayers+1)
    if isnothing(nn.dcf)
        backwardStack[end] = gradient(nn.cf,y,forwardStack[end])[2] # using AD from Zygote
    else
        backwardStack[end] = nn.dcf(y,forwardStack[end]) # adding dϵ_dHatY
    end
    @inbounds for lidx in nLayers:-1:1
        backwardStack[lidx] = backward(nn.layers[lidx],forwardStack[lidx],backwardStack[lidx+1])
    end

    # Step 3: Gradient of each layer's parameters, given its input and the derivative of the cost w.r.t. its output
    dWs = Array{Learnable,1}(undef,nLayers)
    @inbounds for lidx in 1:nLayers
        dWs[lidx] = get_gradient(nn.layers[lidx],forwardStack[lidx],backwardStack[lidx+1])
    end
    return dWs
end
"""
    get_batchgradient(nn, xbatch, ybatch)

Compute, record by record, the gradient of the cost with respect to the weights over a batch.

# Parameters:
* `nn`: Worker network
* `xbatch`: Input to the network (n,d)
* `ybatch`: Label input (n,d)

# Notes:
* The output has one element per record (row) of the batch, each being the per-layer vector
  of `Learnable` gradients returned by `get_gradient(nn,x,y)`
"""
function get_batchgradient(nn,xbatch::AbstractArray{T,N1},ybatch::AbstractArray{T2,N2}) where {T <: Number, T2 <: Number, N1, N2}
    # Records are processed sequentially on purpose: with Julia 1.6 a plain loop turned out to be
    # faster than `Threads.@threads` with 4 threads (see
    # https://discourse.julialang.org/t/drop-of-performances-with-julia-1-6-0-for-interpolationkernels/58085).
    # Multi-threading may be worth reconsidering once that performance issue is solved.
    return Vector{Learnable}[get_gradient(nn, xbatch[r, :], ybatch[r, :]) for r in 1:size(xbatch, 1)]
end
"""
    set_params!(nn, w)

Overwrite the weights of every layer of the network.

# Parameters:
* `nn`: Worker network
* `w`: The new weights to set, one element per layer (in the format returned by `get_params(nn)`)
"""
function set_params!(nn::NN, w)
    for (k, layer) in enumerate(nn.layers)
        set_params!(layer, w[k])
    end
    return nothing
end
"""
    get_nparams(nn)

Return the number of trainable parameters of the neural network, i.e. the sum of
`get_nparams(layer)` over its layers (0 for a network without layers).
"""
get_nparams(nn::NN) = mapreduce(get_nparams, +, nn.layers; init=0)
"""
    preprocess!(nn::NN)

Call `preprocess!` on each layer of the network, in order.
"""
preprocess!(nn::NN) = foreach(preprocess!, nn.layers)
"""
    getindex(nn::NN, i::AbstractArray)

Return a new `NN` made of the layers of `nn` selected by `i` (e.g. `nn[1:2]`). The new network
keeps the cost function, its derivative, the trained flag and the name of the original one.
The layer objects themselves are shared with `nn`, not copied.
"""
Base.getindex(n::NN, i::AbstractArray) = NN(n.layers[i], n.cf, n.dcf, n.trained, n.name)
# ------------------------------------------------------------------------------
# Optimisation-related functions
"""
    OptimisationAlgorithm

Abstract type representing an Optimisation algorithm.

Currently supported algorithms:
- `SGD` (Stochastic) Gradient Descent
- `ADAM` The ADAM algorithm, an adaptive moment estimation optimiser.

See `?[Name OF THE ALGORITHM]` for their details.

You can implement your own optimisation algorithm using a subtype of `OptimisationAlgorithm`.
Implement its constructor and the update method
`single_update!(θ,▽,opt_alg::YourAlgorithm;n_epoch,n_batch,n_batches,xbatch,ybatch)`
(type `?single_update!` for details). Optionally, also implement an `init_optalg!` method
(type `?init_optalg!`) if the algorithm needs to be initialised before training.
"""
abstract type OptimisationAlgorithm end
# NOTE(review): presumably defines the default optimisers listed above (SGD, ADAM, ...) — confirm
include("Nn_default_optalgs.jl")
"""
    fitting_info(nn,x,y;n,n_batches,epochs,verbosity,n_epoch,n_batch)

Default callback function to display information during training, depending on the verbosity level.

# Parameters:
* `nn`: Worker network
* `x`: Batch input to the network (batch_size,d)
* `y`: Batch label input (batch_size,d)
* `n`: Size of the full training set
* `n_batches` : Number of batches per epoch
* `epochs`: Number of epochs defined for the training
* `verbosity`: Verbosity level defined for the training (NONE,LOW,STD,HIGH,FULL)
* `n_epoch`: Counter of the current epoch
* `n_batch`: Counter of the current batch

# Return:
* Always `false`, i.e. it never asks to stop the training

# Notes:
* Reporting the error (loss of the network) is expensive. Use `verbosity=NONE` for better performance
"""
function fitting_info(nn,x,y;n,n_batches,epochs,verbosity,n_epoch,n_batch)
    if verbosity == NONE
        return false # doesn't stop the training
    end
    # Number of progress messages over the whole training for each verbosity level.
    # With LOW (0 messages) the reporting period `ceil(epochs/0)` is infinite, so only the first
    # epoch is reported; with FULL every batch is reported regardless of this value.
    nMsgDict = Dict(LOW => 0, STD => 10, HIGH => 100, FULL => n)
    nMsgs = nMsgDict[verbosity]
    # Report on the last batch of the first epoch and of every `ceil(epochs/nMsgs)`-th epoch
    if verbosity == FULL || ( n_batch == n_batches && ( n_epoch == 1 || n_epoch % ceil(epochs/nMsgs) == 0))
        ϵ = loss(nn,x,y)
        println("Training.. \t avg ϵ on (Epoch $n_epoch Batch $n_batch): \t $(ϵ)")
    end
    return false
end
"""
    train!(nn,x,y;epochs,batch_size,sequential,opt_alg,verbosity,cb,rng)

Low level function that trains a neural network with the given x,y data.

!!! warning
    This function is deprecated and has been unexported in BetaML v0.9.
    Use the model [`NeuralNetworkEstimator`](@ref) instead.

# Parameters:
* `nn`: Worker network
* `x`: Training input to the network (records x dimensions)
* `y`: Label input (records x dimensions)
* `epochs`: Number of passages over the training set [def: `100`]
* `batch_size`: Size of each individual batch [def: `min(size(x,1),32)`]
* `sequential`: Whether to run all data sequentially instead of randomly [def: `false`]
* `opt_alg`: The optimisation algorithm to update the gradient at each batch [def: `ADAM()`]
* `verbosity`: A verbosity parameter for the trade off information / efficiency [def: `STD`]
* `cb`: A callback to provide information. [def: `fitting_info`]
* `rng`: Random Number Generator (see [`FIXEDSEED`](@ref)) [default: `Random.GLOBAL_RNG`]

# Return:
- A named tuple with the following information
  - `epochs`: Number of epochs actually ran
  - `ϵ_epochs`: The average error on each epoch (if `verbosity > LOW`)
  - `θ_epochs`: The parameters at each epoch (if `verbosity > STD`)

# Notes:
- Currently supported algorithms:
  - `SGD`, the classical (Stochastic) Gradient Descent optimiser
  - `ADAM`, an adaptive moment estimation optimiser
- Look at the individual optimisation algorithm (`?[Name OF THE ALGORITHM]`) for info on its parameters, e.g. [`?SGD`](@ref SGD) for the Stochastic Gradient Descent.
- You can implement your own optimisation algorithm using a subtype of `OptimisationAlgorithm` and implementing its constructor and the update function `single_update!(⋅)` (type `?single_update!` for details).
- You can implement your own callback function, although the one provided by default is already pretty generic (its output depends on the `verbosity` parameter). See [`fitting_info`](@ref) for information on the cb parameters.
- Both the callback function and the [`single_update!`](@ref) function of the optimisation algorithm can be used to stop the training algorithm, respectively returning `true` or `stop=true`.
- The verbosity can be set to any of `NONE`,`LOW`,`STD`,`HIGH`,`FULL`.
- The update is done computing the average gradient for each batch and then calling `single_update!` to let the optimisation algorithm perform the parameters update
- Computing the loss on the whole training set is expensive. It is computed after each epoch only when `verbosity > LOW`; with `verbosity == LOW` it is computed once at the end for the final message.
"""
function train!(nn::NN,x,y; epochs=100, batch_size=min(size(x,1),32), sequential=false, verbosity::Verbosity=STD, cb=fitting_info, opt_alg::OptimisationAlgorithm=ADAM(),rng = Random.GLOBAL_RNG)
    if verbosity > STD
        @codelocation
    end
    preprocess!(nn)
    n = size(x,1) # number of records of the whole training set
    batch_size = min(n,batch_size)
    if verbosity > NONE # Note that these are "Verbosity" enum values: to compare with numbers use e.g. Int(NONE)
        println("***\n*** Training $(nn.name) for $epochs epochs with algorithm $(typeof(opt_alg)).")
    end
    # The full-dataset loss is only tracked when it is going to be reported
    ϵ_epoch = verbosity >= STD ? loss(nn,x,y) : Inf
    θ_epoch = get_params(nn)
    ϵ_epochs = Float64[]
    θ_epochs = []

    init_optalg!(opt_alg;θ=get_params(nn),batch_size=batch_size,x=x,y=y)

    # Refresh interval (in seconds) of the progress bar
    if verbosity == NONE
        showTime = typemax(Float64)
    elseif verbosity <= LOW
        showTime = 50
    elseif verbosity <= STD
        showTime = 1
    elseif verbosity <= HIGH
        showTime = 0.5
    else
        showTime = 0.2
    end

    @showprogress showTime "Training the Neural Network..." for t in 1:epochs
        batches = batch(n,batch_size,sequential=sequential,rng=rng)
        n_batches = length(batches)
        if t == 1
            if verbosity >= STD
                push!(ϵ_epochs,ϵ_epoch)
            end
            if verbosity > STD
                push!(θ_epochs,θ_epoch)
            end
        end
        for (i,batch_idx) in enumerate(batches)
            xbatch = x[batch_idx, :]
            ybatch = y[batch_idx, :]
            θ = get_params(nn)
            gradients = get_batchgradient(nn,xbatch,ybatch)
            ▽ = sum(gradients) / length(batch_idx) # average gradient of the batch
            res = single_update!(θ,▽;n_epoch=t,n_batch=i,n_batches=n_batches,xbatch=xbatch,ybatch=ybatch,opt_alg=opt_alg)
            set_params!(nn,res.θ)
            # `n` is the size of the whole training set, as documented in `fitting_info`
            cbOut = cb(nn,xbatch,ybatch,n=n,n_batches=n_batches,epochs=epochs,verbosity=verbosity,n_epoch=t,n_batch=i)
            if res.stop == true || cbOut == true
                nn.trained = true
                return (epochs=t,ϵ_epochs=ϵ_epochs,θ_epochs=θ_epochs)
            end
        end
        if verbosity >= STD
            ϵ_epoch = loss(nn,x,y)
            push!(ϵ_epochs,ϵ_epoch)
        end
        if verbosity > STD
            θ_epoch = get_params(nn)
            push!(θ_epochs,θ_epoch)
        end
    end

    if verbosity > NONE
        # With LOW verbosity the loss was not tracked during the epochs, so it must be computed now.
        # With higher verbosity it is already up to date after the last epoch.
        if verbosity <= LOW
            ϵ_epoch = loss(nn,x,y)
        end
        println("Training of $epochs epoch completed. Final epoch error: $(ϵ_epoch).");
    end
    nn.trained = true
    return (epochs=epochs,ϵ_epochs=ϵ_epochs,θ_epochs=θ_epochs)
end
"""
    single_update!(θ,▽;n_epoch,n_batch,n_batches,xbatch,ybatch,opt_alg)

Perform the parameters update based on the average batch gradient.

# Parameters:
- `θ`: Current parameters
- `▽`: Average gradient of the batch
- `n_epoch`: Count of current epoch
- `n_batch`: Count of current batch
- `n_batches`: Number of batches per epoch
- `xbatch`: Data associated to the current batch
- `ybatch`: Labels associated to the current batch
- `opt_alg`: The Optimisation algorithm to use for the update

# Notes:
- This keyword-only entry point forwards to
  `single_update!(θ,▽,opt_alg;n_epoch,n_batch,n_batches,xbatch,ybatch)`. Each optimisation
  algorithm implements its own version by defining that method for its own type.
- Most parameters are not used by any optimisation algorithm. They are provided
  to support the largest possible class of optimisation algorithms
- Some optimisation algorithms may change their internal structure in this function
- `train!` reads the fields `θ` (the updated parameters) and `stop` (`true` to stop the training)
  of the returned value
"""
function single_update!(θ,▽;n_epoch,n_batch,n_batches,xbatch,ybatch,opt_alg::OptimisationAlgorithm)
    return single_update!(θ,▽,opt_alg;n_epoch=n_epoch,n_batch=n_batch,n_batches=n_batches,xbatch=xbatch,ybatch=ybatch)
end
# Fallback reached when an optimisation algorithm does not define its own update method
function single_update!(θ,▽,opt_alg::OptimisationAlgorithm;n_epoch,n_batch,n_batches,xbatch,ybatch)
    error("single_update!() not implemented for the optimisation algorithm $(typeof(opt_alg))")
end
"""
    init_optalg!(opt_alg;θ,batch_size,x,y,rng)

Initialise the optimisation algorithm before the training starts.

# Parameters:
- `opt_alg`: The Optimisation algorithm to use
- `θ`: Current parameters
- `batch_size`: The size of the batch
- `x`: The training (input) data
- `y`: The training "labels" to match
- `rng`: Random Number Generator (see [`FIXEDSEED`](@ref)) [default: `Random.GLOBAL_RNG`]

# Notes:
- Only a few optimizers need this function and consequently override it. This default method
  does nothing, so if you write your own optimizer and don't need to initialise it, you don't
  have to override this method.
"""
function init_optalg!(opt_alg::OptimisationAlgorithm; θ, batch_size, x, y, rng = Random.GLOBAL_RNG)
    return nothing
end
#=
if rShuffle
# random shuffle x and y
ridx = shuffle(1:size(x)[1])
x = x[ridx, :]
y = y[ridx , :]
end
ϵ = 0
#η = dyn_η ? 1/(1+t) : η
ηₜ = η(t)*λ
for i in 1:size(x)[1]
xᵢ = x[i,:]'
yᵢ = y[i,:]'
W = get_params(nn)
dW = get_gradient(nn,xᵢ,yᵢ)
newW = gradientDescentSingleUpdate(W,dW,ηₜ)
set_params!(nn,newW)
ϵ += loss(nn,xᵢ,yᵢ)
end
if nMsgs != 0 && (t % ceil(epochs/nMsgs) == 0 || t == 1 || t == epochs)
println("Avg. error after epoch $t : $(ϵ/size(x)[1])")
end
if abs(ϵl/size(x)[1] - ϵ/size(x)[1]) < (tol * abs(ϵl/size(x)[1]))
if nMsgs != 0
println((tol * abs(ϵl/size(x)[1])))
println("*** Avg. error after epoch $t : $(ϵ/size(x)[1]) (convergence reached")
end
converged = true
break
else
ϵl = ϵ
end
end
if nMsgs != 0 && converged == false
println("*** Avg. error after epoch $epochs : $(ϵ/size(x)[1]) (convergence not reached)")
end
nn.trained = true
end
=#
# ------------------------------------------------------------------------------
# V2 Api
#$([println(\"- $(i)\" for i in subtypes(AbstractLayer)])
# $(subtypes(AbstractLayer))
#
"""
**`$(TYPEDEF)`**
Hyperparameters for the `Feedforward` neural network model
## Parameters:
$(FIELDS)
To know the available layers type `subtypes(AbstractLayer)` and then type `?LayerName` for information on how to use each layer.
"""
Base.@kwdef mutable struct NNHyperParametersSet <: BetaMLHyperParametersSet
    "Array of layer objects [def: `nothing`, i.e. basic network]. See `subtypes(BetaML.AbstractLayer)` for supported layers"
    layers::Union{Array{AbstractLayer,1},Nothing} = nothing
    """Loss (cost) function [def: `squared_cost`]
    It must always assume `y` and `ŷ` to be (n x d) matrices, using `dropdims` inside if needed.
    """
    loss::Union{Nothing,Function} = squared_cost
    "Derivative of the loss function [def: `dsquared_cost` if `loss==squared_cost`, `nothing` otherwise, i.e. use the known derivative of the loss (e.g. of the squared cost) or autodiff]"
    dloss::Union{Function,Nothing} = nothing
    "Number of epochs, i.e. passes through the whole training sample [def: `100`]"
    epochs::Int64 = 100
    "Size of each individual batch [def: `32`]"
    batch_size::Int64 = 32
    "The optimisation algorithm to update the gradient at each batch [def: `ADAM()`]"
    opt_alg::OptimisationAlgorithm = ADAM()
    "Whether to randomly shuffle the data at each iteration (epoch) [def: `true`]"
    shuffle::Bool = true
    """
    The method - and its parameters - to employ for hyperparameters autotuning.
    See [`SuccessiveHalvingSearch`](@ref) for the default method.
    To implement automatic hyperparameter tuning during the (first) `fit!` call simply set `autotune=true` and eventually change the default `tunemethod` options (including the parameter ranges, the resources to employ and the loss function to adopt).
    """
    tunemethod::AutoTuneMethod = SuccessiveHalvingSearch(
        hpranges = Dict(
            "epochs"     => [50,100,150],
            "batch_size" => [2,4,8,16,32],
            "opt_alg"    => [SGD(λ=2),SGD(λ=1),SGD(λ=3),ADAM(λ=0.5),ADAM(λ=1),ADAM(λ=0.25)],
            "shuffle"    => [false,true],
        ),
        multithreads = false,
    )
end
"""
NeuralNetworkEstimatorOptionsSet
A struct defining the options used by the Feedforward neural network model
## Parameters:
$(FIELDS)
"""
Base.@kwdef mutable struct NeuralNetworkEstimatorOptionsSet
    "Cache the results of the fitting stage, so as to allow `predict(mod)` [default: `true`]. Set it to `false` to save memory for large data."
    cache::Bool = true
    "An optional title and/or description for this model"
    descr::String = ""
    "The verbosity level to be used in training or prediction (see [`Verbosity`](@ref)) [default: `STD`]"
    verbosity::Verbosity = STD
    "A callback function to provide information during training [def: `fitting_info`]"
    cb::Function = fitting_info
    "Option for hyper-parameters autotuning [def: `false`, i.e. no autotuning performed]. If activated, autotuning is performed on the first `fit!()` call. Control auto-tuning through the option `tunemethod` (see the model hyper-parameters)"
    autotune::Bool = false
    "Random Number Generator (see [`FIXEDSEED`](@ref)) [default: `Random.GLOBAL_RNG`]"
    rng::AbstractRNG = Random.GLOBAL_RNG
end
Base.@kwdef mutable struct NeuralNetworkEstimatorLearnableParameters <: BetaMLLearnableParametersSet
    # The underlying `NN` network structure; `nothing` by default (no network assigned yet).
    nnstruct::Union{NN,Nothing} = nothing
end
"""
**`NeuralNetworkEstimator`**
A "feedforward" neural network (supervised).
For the parameters see [`NNHyperParametersSet`](@ref).
# Notes:
- data must be numerical
- the label can be a _n-records_ vector or a _n-records_ by _n-dimensions_ matrix, but the result is always a matrix.
- For one-dimension regressions drop the unnecessary dimension with `dropdims(ŷ,dims=2)`
- For classification tasks the columns should normally be interpreted as the probabilities for each category
# Examples:
- Classification...
```julia
julia> using BetaML
julia> X = [1.8 2.5; 0.5 20.5; 0.6 18; 0.7 22.8; 0.4 31; 1.7 3.7];
julia> y = ["a","b","b","b","b","a"];
julia> ohmod = OneHotEncoder()
A OneHotEncoder BetaMLModel (unfitted)
julia> y_oh = fit!(ohmod,y)
6×2 Matrix{Bool}:
1 0
0 1
0 1
0 1
0 1
1 0
julia> layers = [DenseLayer(2,6),DenseLayer(6,2),VectorFunctionLayer(2,f=softmax)];
julia> m = NeuralNetworkEstimator(layers=layers,opt_alg=ADAM(),epochs=300,verbosity=LOW)
NeuralNetworkEstimator - A Feed-forward neural network (unfitted)
julia> ŷ_prob = fit!(m,X,y_oh)
***
*** Training for 300 epochs with algorithm ADAM.
Training.. avg ϵ on (Epoch 1 Batch 1): 0.4116936481380642
Training of 300 epoch completed. Final epoch error: 0.44308719831108734.
6×2 Matrix{Float64}:
0.853198 0.146802
0.0513715 0.948629
0.0894273 0.910573
0.0367079 0.963292
0.00548038 0.99452
0.808334 0.191666
julia> ŷ = inverse_predict(ohmod,ŷ_prob)
6-element Vector{String}:
"a"
"b"
"b"
"b"
"b"
"a"
```
- Regression...
```julia
julia> using BetaML
julia> X = [1.8 2.5; 0.5 20.5; 0.6 18; 0.7 22.8; 0.4 31; 1.7 3.7];
julia> y = 2 .* X[:,1] .- X[:,2] .+ 3;
julia> layers = [DenseLayer(2,6),DenseLayer(6,6),DenseLayer(6,1)];
julia> m = NeuralNetworkEstimator(layers=layers,opt_alg=ADAM(),epochs=3000,verbosity=LOW)
NeuralNetworkEstimator - A Feed-forward neural network (unfitted)
julia> ŷ = fit!(m,X,y);
***
*** Training for 3000 epochs with algorithm ADAM.
Training.. avg ϵ on (Epoch 1 Batch 1): 33.30063874270561
Training of 3000 epoch completed. Final epoch error: 34.61265465430473.
julia> hcat(y,ŷ)
6×2 Matrix{Float64}:
4.1 4.11015
-16.5 -16.5329
-13.8 -13.8381
-18.4 -18.3876
-27.2 -27.1667
2.7 2.70542
```
"""
mutable struct NeuralNetworkEstimator <: BetaMLSupervisedModel
    # Hyperparameters (layers, loss and its derivative, epochs, batch size, optimiser, shuffle, tuning method)
    hpar::NNHyperParametersSet
    # Options (cache, description, verbosity, callback, autotune flag, random number generator)
    opt::NeuralNetworkEstimatorOptionsSet
    # Learnable parameters, i.e. the wrapper around the `NN` network structure (`nnstruct`)
    par::Union{Nothing,NeuralNetworkEstimatorLearnableParameters}
    # Cached results of the fitting stage (see the `cache` option); `nothing` at construction.
    # NOTE(review): presumably the predictions on the training data — confirm in `fit!`.
    cres::Union{Nothing,AbstractArray}
    # Whether the model has been fitted; `fit!` runs autotuning only while this is `false`.
    # NOTE(review): presumably set to `true` at the end of `fit!` — confirm.
    fitted::Bool
    # Free-form information dictionary, created empty by the keyword constructor
    info::Dict{String,Any}
end
function NeuralNetworkEstimator(;kwargs...)
    # Start from default hyperparameters/options and empty learnable parameters. `info` is
    # created directly with the declared key/value types of the `info` field.
    m = NeuralNetworkEstimator(NNHyperParametersSet(),NeuralNetworkEstimatorOptionsSet(),NeuralNetworkEstimatorLearnableParameters(),nothing,false,Dict{String,Any}())
    # Keywords may only target the fields of the hyperparameters, options and learnable
    # parameters sets. Scanning every field of `m` would also match the internal fields of
    # the `info` Dict (e.g. `count`, `keys`) and silently corrupt it.
    settable = (m.hpar, m.opt, m.par)
    for (kw,kwv) in kwargs
        found = false
        for fobj in settable
            if kw in fieldnames(typeof(fobj))
                setproperty!(fobj,kw,kwv)
                found = true
            end
        end
        found || error("Keyword \"$kw\" is not part of this model.")
    end
    # Special correction for NNHyperParametersSet: unless `dloss` is given explicitly,
    # derive it from the (default or user-provided) `loss` through the known derivatives
    # (e.g. `dsquared_cost` for `squared_cost`, `nothing` — i.e. autodiff — otherwise).
    if !haskey(kwargs, :dloss)
        m.hpar.dloss = match_known_derivatives(m.hpar.loss)
    end
    return m
end
function fit!(m::NeuralNetworkEstimator,X,Y)
(m.fitted) || autotune!(m,(X,Y))
# Parameter alias..
layers = m.hpar.layers
loss = m.hpar.loss
dloss = m.hpar.dloss
epochs = m.hpar.epochs
batch_size = m.hpar.batch_size
opt_alg = m.hpar.opt_alg
shuffle = m.hpar.shuffle
cache = m.opt.cache
descr = m.opt.descr
verbosity = m.opt.verbosity
cb = m.opt.cb