Permalink
Switch branches/tags
Nothing to show
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
106 lines (77 sloc) 2.59 KB
-- Joost van Amersfoort - <joost@joo.st>
require 'torch'
require 'nn'
require 'optim'
local VAE = require 'VAE'
require 'KLDPenalty'
require 'GaussianCriterion'
require 'Sampler'
--For loading data files
require 'load'
local continuous = false
data = load(continuous)
local input_size = data.train:size(2)
local latent_variable_size = 20
local hidden_layer_size = 400
local batch_size = 100
torch.manualSeed(1)
local encoder = VAE.get_encoder(input_size, hidden_layer_size, latent_variable_size)
local decoder = VAE.get_decoder(input_size, hidden_layer_size, latent_variable_size, continuous)
local KLD = nn.KLDPenalty()
local model = nn.Sequential()
model:add(encoder)
model:add(KLD)
model:add(nn.Sampler())
model:add(decoder)
if continuous then
criterion = nn.GaussianCriterion()
else
criterion = nn.BCECriterion()
criterion.sizeAverage = false
end
local parameters, gradients = model:getParameters()
local config = {
learningRate = 0.001
}
local state = {}
epoch = 0
while true do
epoch = epoch + 1
local lowerbound = 0
local tic = torch.tic()
local shuffle = torch.randperm(data.train:size(1))
-- This batch creation is inspired by szagoruyko CIFAR example.
local indices = torch.randperm(data.train:size(1)):long():split(batch_size)
indices[#indices] = nil
local N = #indices * batch_size
local tic = torch.tic()
for t,v in ipairs(indices) do
xlua.progress(t, #indices)
local inputs = data.train:index(1,v)
local opfunc = function(x)
if x ~= parameters then
parameters:copy(x)
end
model:zeroGradParameters()
local reconstruction = model:forward(inputs)
local err = criterion:forward(reconstruction, inputs)
local df_dw = criterion:backward(reconstruction, inputs)
model:backward(inputs, df_dw)
local batchlowerbound = err + KLD.loss
return batchlowerbound, gradients
end
x, batchlowerbound = optim.adam(opfunc, parameters, config, state)
lowerbound = lowerbound + batchlowerbound[1]
end
print("Epoch: " .. epoch .. " Lowerbound: " .. lowerbound/N .. " time: " .. torch.toc(tic))
if lowerboundlist then
lowerboundlist = torch.cat(lowerboundlist,torch.Tensor(1,1):fill(lowerbound/N),1)
else
lowerboundlist = torch.Tensor(1,1):fill(lowerbound/N)
end
if epoch % 2 == 0 then
torch.save('save/parameters.t7', parameters)
torch.save('save/state.t7', state)
torch.save('save/lowerbound.t7', torch.Tensor(lowerboundlist))
end
end