In [115]:
torch.setdefaulttensortype('torch.FloatTensor')
torch.manualSeed(666)

In [116]:
mnist = require 'mnist'

train = mnist.traindataset()
test = mnist.testdataset()

local mean, std

for _,set in ipairs{train, test} do
    set.data = set.data:float()
    
    if not mean then
        mean = train.data:mean()
        std = train.data:std()
    end
    set.data:add(-mean)
    set.data:div(std)
    
    set.label:add(1)
end

collectgarbage()

In [117]:
require 'nn'

net = nn.Sequential()

net:add(nn.SpatialConvolution(1, 64, 5, 5, 1, 1, 0, 0))
net:add(nn.SpatialBatchNormalization(64))
net:add(nn.ReLU())
net:add(nn.SpatialConvolution(64, 64, 3, 3, 1, 1, 0, 0))
net:add(nn.SpatialBatchNormalization(64))
net:add(nn.ReLU())
net:add(nn.SpatialConvolution(64, 64, 3, 3, 1, 1, 0, 0))
net:add(nn.SpatialBatchNormalization(64))
net:add(nn.ReLU())

net:add(nn.SpatialMaxPooling(2, 2, 2, 2))

net:add(nn.SpatialConvolution(64, 64, 3, 3, 1, 1, 0, 0))
net:add(nn.SpatialBatchNormalization(64))
net:add(nn.ReLU())
net:add(nn.SpatialConvolution(64, 64, 3, 3, 1, 1, 0, 0))
net:add(nn.SpatialBatchNormalization(64))
net:add(nn.ReLU())
net:add(nn.SpatialConvolution(64, 64, 3, 3, 1, 1, 0, 0))
net:add(nn.SpatialBatchNormalization(64))

net:add(nn.View(64, -1):setNumInputDims(3))
net:add(nn.Mean(3, 3))
net:add(nn.BatchNormalization(64))

net:add(nn.Linear(64, 10))

params, gradParams = net:getParameters()

collectgarbage()

In [118]:
require 'optim'

datasetIdx = 1

optimState = {
    learningRate = 1e-2,
    momentum = 0.9,
    nesterov = true,
    dampening = 0
}

crit = nn.CrossEntropyCriterion()

In [121]:
logger = optim.Logger('/tmp/mnist-acc.log')
logger:setNames{'Training loss'}
logger:style{'-'}
-- logger:showPlot(false)

In [156]:
function fillBatch(batch, labels, set, currIdx, idxs)
    for k = 1,batch:size(1) do
        batch[k]:copy(set.data[idxs and idxs[currIdx] or currIdx])
        labels[k] = set.label[idxs and idxs[currIdx] or currIdx]
        
        currIdx = currIdx + 1
        if currIdx > set.size then currIdx = 1 end
    end
    
    return currIdx
end

In [138]:
batchSize = 128
batch = torch.FloatTensor(batchSize, 1, 28, 28)
labels = torch.IntTensor(batchSize)

function randomShuffle(t)
  for i = 1,#t do
    local j = math.random(i, #t)
    t[i], t[j] = t[j], t[i]
  end
end

idx = {}
for i = 1,train.data:size(1) do
    idx[i] = i
end

for iter = 1,216 do
    local oldIdx = datasetIdx
    datasetIdx = fillBatch(batch, labels, train, datasetIdx, idx)
    if datasetIdx < oldIdx then
        print('New epoch')
        randomShuffle(idx)
        datasetIdx = 1
    end
    
    timer = torch.Timer()
    local loss = crit:forward(net:forward(batch), labels)
    crit:backward(net.output, labels)
    net:zeroGradParameters()
    net:backward(batch, crit.gradInput)
    print(timer:time().real .. ' sec')
    
    local feval = function(x)
        return batchLoss, gradParams
    end

    optim.sgd(feval, params, optimState)
    
    logger:add{loss}
    logger:plot()
    
    collectgarbage() 
end

3.7428669929504 sec	


3.8177711963654 sec	


3.6353871822357 sec	


3.8255701065063 sec	


3.4845149517059 sec	


3.8213200569153 sec	


4.1183891296387 sec	


3.8600790500641 sec	


4.0470020771027 sec	


4.1586771011353 sec	


4.6148080825806 sec	


4.1103229522705 sec	


4.0613470077515 sec	


3.8621652126312 sec	


3.9348700046539 sec	


3.7357759475708 sec	


3.9672770500183 sec	


3.7051191329956 sec	


3.6770942211151 sec	


4.0556600093842 sec	


3.5472049713135 sec	


3.7855041027069 sec	




3.799017906189 sec	


3.8076100349426 sec	


3.6947710514069 sec	


3.9989001750946 sec	


3.8418128490448 sec	


3.7173519134521 sec	


3.7665469646454 sec	


3.9063141345978 sec	


4.0137658119202 sec	


3.6399519443512 sec	


3.7904071807861 sec	


3.4006040096283 sec	


3.5161001682281 sec	


3.9140059947968 sec	


3.6963529586792 sec	


3.9906239509583 sec	


3.628427028656 sec	


3.8164949417114 sec	


4.5324461460114 sec	


4.1312961578369 sec	


3.8290498256683 sec	


3.718456029892 sec	


4.0324001312256 sec	


3.6380350589752 sec	


3.5962538719177 sec	


3.6068332195282 sec	


3.7317268848419 sec	


3.4868340492249 sec	


3.7073531150818 sec	


3.8295249938965 sec	


3.8143770694733 sec	


3.9496388435364 sec	


3.8532478809357 sec	


3.3888788223267 sec	


3.7892379760742 sec	




3.7902410030365 sec	


3.7840809822083 sec	


3.7818210124969 sec	


3.9604690074921 sec	


3.7962081432343 sec	


3.7352981567383 sec	


4.1072828769684 sec	


3.7782020568848 sec	


3.8672440052032 sec	


3.544224023819 sec	


3.8334360122681 sec	


4.3112349510193 sec	


3.7981469631195 sec	


3.7749481201172 sec	


3.5278651714325 sec	


3.6367018222809 sec	


3.5067238807678 sec	


3.6754989624023 sec	


3.9314029216766 sec	


4.0563941001892 sec	


4.116681098938 sec	


4.5268688201904 sec	


4.4030010700226 sec	


4.6239750385284 sec	


4.2581949234009 sec	


4.2787621021271 sec	


3.6963901519775 sec	


4.0396339893341 sec	


3.9739410877228 sec	


3.930135011673 sec	


4.007719039917 sec	


4.2263112068176 sec	


3.6999909877777 sec	


3.6168520450592 sec	


3.9311480522156 sec	


3.8180041313171 sec	


3.80823802948 sec	


4.3725209236145 sec	


3.9630630016327 sec	


3.8758950233459 sec	


4.0865461826324 sec	


4.0027639865875 sec	


4.0498569011688 sec	


3.8722290992737 sec	




3.9268620014191 sec	


3.7351050376892 sec	


3.9398739337921 sec	


4.2678861618042 sec	


4.3625528812408 sec	


4.1377182006836 sec	


4.4361548423767 sec	


3.9870519638062 sec	


4.2545919418335 sec	


3.8598780632019 sec	


4.0901539325714 sec	


3.9357750415802 sec	


4.1904740333557 sec	


4.0565149784088 sec	


4.1539881229401 sec	


4.1398301124573 sec	


3.8455331325531 sec	


4.169440984726 sec	


4.3633410930634 sec	


4.1441221237183 sec	


4.0843331813812 sec	


3.9815628528595 sec	


4.3074090480804 sec	


4.1145489215851 sec	


4.1839108467102 sec	


4.258859872818 sec	


4.1089549064636 sec	


4.3005709648132 sec	


3.9391560554504 sec	


4.1714358329773 sec	


3.9913988113403 sec	


4.1355979442596 sec	


4.401654958725 sec	


4.4819850921631 sec	


4.3008420467377 sec	


4.0840318202972 sec	


4.249596118927 sec	


4.3944530487061 sec	


4.2615120410919 sec	


4.4358229637146 sec	


4.3780200481415 sec	


4.6183559894562 sec	


4.1638898849487 sec	


4.1894679069519 sec	


4.1142601966858 sec	


4.4282259941101 sec	


4.4675540924072 sec	


4.3455538749695 sec	


4.6274130344391 sec	


4.5576708316803 sec	


4.5133781433105 sec	


5.0589230060577 sec	


4.6414160728455 sec	


4.7061719894409 sec	


4.4476358890533 sec	


4.225349187851 sec	


4.6586880683899 sec	


4.5711531639099 sec	


4.5059118270874 sec	


4.5565450191498 sec	


4.5757610797882 sec	


4.2839660644531 sec	


4.5589768886566 sec	


4.5473771095276 sec	


4.7804811000824 sec	


4.3874509334564 sec	


4.4595370292664 sec	


4.8002691268921 sec	


4.6731770038605 sec	


4.6608300209045 sec	


4.6957530975342 sec	


4.3701410293579 sec	


4.5438928604126 sec	


4.9340870380402 sec	


4.7161200046539 sec	


4.8309450149536 sec	




4.7065460681915 sec	


4.7108240127563 sec	


5.0069620609283 sec	


4.3762259483337 sec	


4.8729031085968 sec	


4.9488689899445 sec	


4.5722880363464 sec	


4.6783931255341 sec	


4.9038767814636 sec	


4.7110209465027 sec	


4.9838569164276 sec	


4.7487030029297 sec	


4.5725982189178 sec	


4.9300599098206 sec	


4.910071849823 sec	


4.9162390232086 sec	


4.9735350608826 sec	


5.1201920509338 sec	


5.3425378799438 sec	


5.0124809741974 sec	


5.2273149490356 sec	


5.4940741062164 sec	


5.2283320426941 sec	


5.1735470294952 sec	


4.8936779499054 sec	


5.1400470733643 sec	


5.1807520389557 sec	


5.0340600013733 sec	


5.0195231437683 sec	


5.2961649894714 sec	


4.781674861908 sec	




5.1871540546417 sec	


5.2872149944305 sec	


5.2169251441956 sec	


5.3063199520111 sec	


5.0324149131775 sec	


4.7635970115662 sec	


4.8305861949921 sec	


New epoch	


4.8851311206818 sec	


In [170]:
-- Test
require 'xlua'

local testBatchSize = 100
assert(test.size % testBatchSize == 0)

local correctCount = 0
local batchCount = 0
local testDatasetIdx = 1
local batch = torch.FloatTensor(testBatchSize, 1, 28, 28)
local labels = torch.LongTensor(testBatchSize)

repeat
    batchCount = batchCount + 1
    xlua.progress(testBatchSize * batchCount, test.size)
    
    testDatasetIdx = fillBatch(batch, labels, test, testDatasetIdx)
    net:forward(batch)
    local answer = select(2, net.output:max(2))
    correctCount = correctCount + answer:eq(labels):sum()
    
until testDatasetIdx < testBatchSize

print(correctCount / test.size * 100 .. '%')



99.04%	


In [141]:
net:clearState()
torch.save('mnist-convnet.t7', net)
torch.save('mnist-convnet-optimstate.t7', optimState)