This repository has been archived by the owner on Mar 3, 2020. It is now read-only.
/
mnist.lua
105 lines (93 loc) · 3.14 KB
/
mnist.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--
-- Load the torchnet framework.
local tnt = require 'torchnet'

-- Parse command-line options: the only flag selects GPU vs. CPU training.
local cli = torch.CmdLine()
cli:option('-usegpu', false, 'use gpu for training')
local config = cli:parse(arg)

-- Announce which device the run will use.
local device = config.usegpu and 'GPU' or 'CPU'
print(string.format('running on %s', device))
-- function that sets up the dataset iterator:
-- Build a dataset iterator for the given MNIST split.
-- `mode` is 'train' or 'test'; it selects mnist.traindataset() or
-- mnist.testdataset(). The closure body runs inside the worker thread
-- spawned by ParallelDatasetIterator (here a single thread).
local function getIterator(mode)
   return tnt.ParallelDatasetIterator{
      nthread = 1,
      init    = function() require 'torchnet' end,
      closure = function()
         -- Fetch the requested split and flatten each 28x28 image into a
         -- 784-dim double vector scaled into [0, 1).
         local mnist = require 'mnist'
         local split = mnist[mode .. 'dataset']()
         local n = split.data:size(1)
         split.data = split.data
            :reshape(n, split.data:size(2) * split.data:size(3))
            :double()
            :div(256)

         -- Expose samples through a ListDataset (replace this by your own
         -- dataset) and group them into fixed-size batches.
         local samples = tnt.ListDataset{
            list = torch.range(1, n):long(),
            load = function(idx)
               -- labels are 0-based; shift to 1-based class indices
               return {
                  input  = split.data[idx],
                  target = torch.LongTensor{split.label[idx] + 1},
               } -- sample contains input and target
            end,
         }
         return tnt.BatchDataset{batchsize = 128, dataset = samples}
      end,
   }
end
-- set up logistic regressor:
-- FIX: 'nn' was used below but never required; it was only available as a
-- global pulled in transitively by `require 'torchnet'`. Require it
-- explicitly so the dependency is visible and robust.
local nn = require 'nn'

-- Single linear layer 784 -> 10: multinomial logistic regression over the
-- flattened MNIST pixels.
local net = nn.Sequential():add(nn.Linear(784, 10))
local criterion = nn.CrossEntropyCriterion()

-- set up training engine:
-- meter tracks the running average loss; clerr tracks top-1 class error.
local engine = tnt.SGDEngine()
local meter = tnt.AverageValueMeter()
local clerr = tnt.ClassErrorMeter{topk = {1}}

-- Reset both meters at the start of every epoch so reported values are
-- per-epoch rather than cumulative.
engine.hooks.onStartEpoch = function(state)
   meter:reset()
   clerr:reset()
end

-- After each forward pass through the criterion, accumulate loss and
-- classification error; print progress only during training (engine:test
-- runs the same hook with state.training == false).
engine.hooks.onForwardCriterion = function(state)
   meter:add(state.criterion.output)
   clerr:add(state.network.output, state.sample.target)
   if state.training then
      print(string.format('avg. loss: %2.4f; avg. error: %2.4f',
         meter:value(), clerr:value{k = 1}))
   end
end
-- Optionally move the model, criterion, and incoming samples to the GPU.
if config.usegpu then
   require 'cunn'
   net = net:cuda()
   criterion = criterion:cuda()

   -- Reusable device-side buffers: copying each batch into them avoids
   -- allocating fresh GPU tensors per sample.
   local gpu_input  = torch.CudaTensor()
   local gpu_target = torch.CudaTensor()
   engine.hooks.onSample = function(state)
      gpu_input:resize(state.sample.input:size()):copy(state.sample.input)
      gpu_target:resize(state.sample.target:size()):copy(state.sample.target)
      state.sample.input  = gpu_input
      state.sample.target = gpu_target
   end -- alternatively, this logic can be implemented via a TransformDataset
end
-- Fit the model on the training split.
local train_opts = {
   network   = net,
   iterator  = getIterator('train'),
   criterion = criterion,
   lr        = 0.2,
   maxepoch  = 5,
}
engine:train(train_opts)

-- Evaluate on the held-out test split. Reset the meters first so the
-- numbers printed below cover the test pass only.
meter:reset()
clerr:reset()
engine:test{
   network   = net,
   iterator  = getIterator('test'),
   criterion = criterion,
}
print(string.format('test loss: %2.4f; test error: %2.4f',
   meter:value(), clerr:value{k = 1}))