forked from shogun-toolbox/shogun
-
Notifications
You must be signed in to change notification settings - Fork 0
/
feedforward_net_classification.sg
44 lines (36 loc) · 1.69 KB
/
feedforward_net_classification.sg
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# File handles for the train/test feature and label data used by this example.
# Paths are relative ("../../data/..."), so they resolve against the directory
# the example is run from.
File f_feats_train = csv_file("../../data/classifier_binary_2d_nonlinear_features_train.dat")
File f_feats_test = csv_file("../../data/classifier_binary_2d_nonlinear_features_test.dat")
File f_labels_train = csv_file("../../data/classifier_binary_2d_nonlinear_labels_train.dat")
File f_labels_test = csv_file("../../data/classifier_binary_2d_nonlinear_labels_test.dat")
# Presumably seeds the global random number generator with a fixed value (1)
# so that runs are reproducible -- TODO confirm against Math:init_random.
Math:init_random(1)
# Build Features and Labels objects from the four data files opened above
# (one features/labels pair for training, one for testing).
#![create_features]
Features features_train = features(f_feats_train)
Features features_test = features(f_feats_test)
Labels labels_train = labels(f_labels_train)
Labels labels_test = labels(f_labels_test)
#![create_features]
# Read the "num_features" value from the training features and create the
# "NeuralNetwork" machine with the training labels and its hyper-parameters
# (L2 coefficient, hidden dropout, epoch limit, and the gd_* settings).
# num_feats is reused below as the number of neurons in the input layer.
# NOTE(review): gd_mini_batch_size is set to num_feats (the feature count),
# not a sample-based value -- confirm this is intentional.
#![create_instance]
int num_feats = features_train.get_int("num_features")
Machine network = machine("NeuralNetwork", labels=labels_train, auto_quick_initialize=True, l2_coefficient=0.01, dropout_hidden=0.5, max_num_epochs=50, gd_mini_batch_size=num_feats, gd_learning_rate=0.1, gd_momentum=0.9)
#![create_instance]
# Append three layers to the network's "layers" in this order:
#   1. a NeuralInputLayer with num_feats neurons,
#   2. a NeuralRectifiedLinearLayer with 10 neurons,
#   3. a NeuralSoftmaxLayer with 2 neurons.
# The order of the add() calls is significant; do not reorder them.
#![add_layers]
NeuralLayer input = layer("NeuralInputLayer", num_neurons=num_feats)
network.add("layers", input)
NeuralLayer relu = layer("NeuralRectifiedLinearLayer", num_neurons=10)
network.add("layers", relu)
NeuralLayer softmax = layer("NeuralSoftmaxLayer", num_neurons=2)
network.add("layers", softmax)
#![add_layers]
# Train the network on the training features, then apply it to the test
# features to obtain the predicted labels.
#![train_and_apply]
network.train(features_train)
Labels labels_predict = network.apply(features_test)
#![train_and_apply]
# Read the network's "params" entry as a real-valued vector after training.
#![get_params]
RealVector parameters = network.get_real_vector("params")
#![get_params]
# Create an "AccuracyMeasure" evaluation and score the predicted labels
# against the test labels; the result is stored as a real number.
#![evaluate_accuracy]
Evaluation eval = evaluation("AccuracyMeasure")
real accuracy = eval.evaluate(labels_predict, labels_test)
#![evaluate_accuracy]
# Additional integration testing variables: the "labels" vector of the
# predictions, extracted as a plain real vector.
RealVector output = labels_predict.get_real_vector("labels")