-
Notifications
You must be signed in to change notification settings - Fork 0
/
iris.rb
67 lines (49 loc) · 2.61 KB
/
iris.rb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
require_relative 'mlp'
require_relative 'dataset'
# --------------- Dataset preparation -----------------
iris_dataset = Dataset.get_iris
# Random split: train_size samples for training, valid_size of the remainder
# for validation, and whatever is left over is held out as the test set.
train_size, valid_size = 100, 30
train, other = Chainer::Datasets.split_dataset_random(iris_dataset, train_size)
valid, test = Chainer::Datasets.split_dataset_random(other, valid_size)

# --------------- Iterators -----------------
batch_size = 4
# Local alias instead of a top-level constant, so running this script does not
# define Object::SerialIterator as a side effect.
serial_iterator = Chainer::Iterators::SerialIterator
train_iter = serial_iterator.new(train, batch_size)
# Validation: a single, ordered pass each time the evaluator runs.
valid_iter = serial_iterator.new(valid, batch_size, repeat: false, shuffle: false)

# --------------- Network -----------------
# The network trained here and the one rebuilt for inference below must share
# the same architecture for the snapshot to load into it, so both are built
# by this single factory.
build_predictor = -> { MLP.new(hidden_nodes_size: 100, output_size: 3) }
predictor = build_predictor.call

# Snapshot file name for a given epoch. Used both by the Snapshot extension
# (writer) and by the inference section (reader) so the two cannot drift apart.
snapshot_name = ->(epoch) { format('snapshot_epoch-%02d', epoch) }

# --------------- Updater -----------------
model = Chainer::Links::Model::Classifier.new(predictor)
optimizer = Chainer::Optimizers::MomentumSGD.new(lr: 0.01)
optimizer.setup(model)
updater = Chainer::Training::StandardUpdater.new(train_iter, optimizer)

# --------------- Trainer -----------------
# Timestamped output directory so successive runs do not overwrite each other.
output_dir = "results/iris_result_#{Time.now.strftime('%Y%m%d_%H%M%S')}"
epoch_size = 30
trainer = Chainer::Training::Trainer.new(updater, stop_trigger: [epoch_size, 'epoch'], out: output_dir)

# --------------- Trainer extensions -----------------
# Local alias instead of a top-level constant (avoids defining Object::Extensions).
extensions = Chainer::Training::Extensions
trainer.extend(extensions::Evaluator.new(valid_iter, model, device: -1), name: 'val')
trainer.extend(extensions::LogReport.new(trigger: [1, 'epoch'], log_name: 'log'))
# One snapshot per epoch, named after the updater's epoch counter at save time.
filename_proc = ->(t) { snapshot_name.call(t.updater.epoch) }
trainer.extend(extensions::Snapshot.new(filename_proc: filename_proc), trigger: [1, 'epoch'])
entries = %w[epoch iteration main/loss main/accuracy val/main/loss val/main/accuracy elapsed_time]
trainer.extend(extensions::PrintReport.new(entries))
trainer.extend(extensions::ProgressBar.new)

# --------------- Training -----------------
trainer.run

# --------------- Inference -----------------
# Rebuild a fresh network and restore the weights saved after the final epoch,
# demonstrating how to load a model back from a snapshot file.
predictor = build_predictor.call
snapshot_filename = File.join(output_dir, snapshot_name.call(epoch_size))
# Selects the predictor's parameters inside the serialized trainer
# (presumably updater -> model "main" -> @predictor; see the path segments).
path = '/updater/model:main/@predictor/'
Chainer::Serializers::MarshalDeserializer.load_file(snapshot_filename, predictor, path: path)

puts '-' * 100
pass_count = 0
test.size.times do |i|
  variables, answer = test[i]
  prediction = predictor.(variables).data.argmax # index of the highest output score
  pass_count += 1 if prediction == answer
  printf("test%03d: prediction = %d, answer = %d\n", i + 1, prediction, answer)
end
puts "accuracy: #{pass_count * 100.0 / test.size}"