
nngraph-based VAE

This update adds nngraph and cleans up the custom modules.

The total speedup is about 30% (only a single backward pass is needed now),
and convergence is much faster (due to the use of ReLU and Adam).
y0ast committed Dec 1, 2015
1 parent 84df2c9 commit 5fc1405a03dcaf0c3b1ade6986e9c16d1151f6c4
Showing with 267 additions and 363 deletions.
  1. +16 −8 GaussianCriterion.lua
  2. +19 −7 KLDCriterion.lua
  3. +0 −12 LinearVA.lua
  4. +0 −34 Reparametrize.lua
  5. +36 −0 Sampler.lua
  6. +40 −0 VAE.lua
  7. +0 −131 binaryva.lua
  8. +0 −139 continuousva.lua
  9. +0 −20 convert.py
  10. BIN datasets/test_32x32.t7
  11. BIN datasets/train_32x32.t7
  12. +22 −12 load.lua
  13. +134 −0 main.lua
GaussianCriterion.lua
@@ -3,19 +3,27 @@ require 'nn'
 local GaussianCriterion, parent = torch.class('nn.GaussianCriterion', 'nn.Criterion')

 function GaussianCriterion:updateOutput(input, target)
-    -- Verify again for correct handling of 0.5 multiplication
-    local Gelement = torch.add(input[2],math.log(2 * math.pi)):mul(-0.5)
-    Gelement:add(-1,torch.add(target,-1,input[1]):pow(2):cdiv(torch.exp(input[2])):mul(0.5))
+    -- negative log-likelihood: 0.5 * log(sigma^2) + 0.5 * log(2pi) + 0.5 * (x - mu)^2 / sigma^2
+    -- input[1] = mu
+    -- input[2] = log(sigma^2)
+    local Gelement = torch.mul(input[2],0.5):add(0.5 * math.log(2 * math.pi))
+    Gelement:add(torch.add(target,-1,input[1]):pow(2):cdiv(torch.exp(input[2])):mul(0.5))

     self.output = torch.sum(Gelement)
     return self.output
 end

 function GaussianCriterion:updateGradInput(input, target)
-    -- Verify again for correct handling of 0.5 multiplication
-    self.gradInput = {}
-    self.gradInput[1] = torch.exp(-input[2]):cmul(torch.add(target,-1,input[1]))
-    self.gradInput[2] = torch.exp(-input[2]):cmul(torch.add(target,-1,input[1]):pow(2)):add(-0.5)
+    self.gradInput = {}
+
+    -- d/dmu: -(x - mu) / sigma^2, with 1 / sigma^2 = exp(-log(sigma^2))
+    self.gradInput[1] = torch.exp(-input[2]):cmul(torch.add(target,-1,input[1])):mul(-1)
+
+    -- d/dlog(sigma^2): 0.5 - 0.5 * (x - mu)^2 / sigma^2
+    self.gradInput[2] = torch.exp(-input[2]):cmul(torch.add(target,-1,input[1]):pow(2)):mul(-1):add(1):mul(0.5)

     return self.gradInput
 end
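The sign conventions above are easy to get wrong, so a finite-difference check is a cheap safeguard. The snippet below is a sketch along those lines, not part of this commit; it assumes GaussianCriterion.lua is on the Lua package path.

-- Sketch (not from this commit): finite-difference check of the mu gradient.
require 'torch'
require 'GaussianCriterion'

local crit = nn.GaussianCriterion()
local mu, log_var, target = torch.randn(5), torch.randn(5), torch.randn(5)
local grad = crit:updateGradInput({mu, log_var}, target)

local eps = 1e-4
local mu_plus, mu_minus = mu:clone(), mu:clone()
mu_plus[1] = mu_plus[1] + eps
mu_minus[1] = mu_minus[1] - eps
local numeric = (crit:updateOutput({mu_plus, log_var}, target)
              - crit:updateOutput({mu_minus, log_var}, target)) / (2 * eps)

print(grad[1][1], numeric)  -- should agree to within ~1e-6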
KLDCriterion.lua
@@ -1,17 +1,29 @@
 local KLDCriterion, parent = torch.class('nn.KLDCriterion', 'nn.Criterion')

-function KLDCriterion:updateOutput(input, target)
-    -- 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
-    local KLDelement = (input[2] + 1):add(-1,torch.pow(input[1],2)):add(-1,torch.exp(input[2]))
-    self.output = 0.5 * torch.sum(KLDelement)
+function KLDCriterion:updateOutput(mean, log_var)
+    -- Appendix B of the VAE paper: KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
+    local mean_sq = torch.pow(mean, 2)
+    local KLDelements = log_var:clone()
+
+    KLDelements:exp():mul(-1)
+    KLDelements:add(-1, mean_sq)
+    KLDelements:add(1)
+    KLDelements:add(log_var)
+
+    self.output = -0.5 * torch.sum(KLDelements)
+
     return self.output
 end

-function KLDCriterion:updateGradInput(input, target)
+function KLDCriterion:updateGradInput(mean, log_var)
     self.gradInput = {}

-    self.gradInput[1] = (-input[1]):clone()
-    self.gradInput[2] = (-torch.exp(input[2])):add(1):mul(0.5)
+    -- d/dmu: mu
+    self.gradInput[1] = mean:clone()
+
+    -- d/dlog(sigma^2): -0.5 * (1 - sigma^2)
+    self.gradInput[2] = torch.exp(log_var):mul(-1):add(1):mul(-0.5)

     return self.gradInput
 end
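Because the KL term has a closed form, its value can be checked directly against the Appendix B expression. A quick sketch, again not part of the commit:

-- Sketch (not from this commit): compare nn.KLDCriterion with the closed form.
require 'torch'
require 'nn'
require 'KLDCriterion'

local crit = nn.KLDCriterion()
local mean, log_var = torch.randn(10), torch.randn(10)

-- -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2), computed directly
local expected = -0.5 * torch.sum(
    log_var:clone():add(1):add(-1, torch.pow(mean, 2)):add(-1, torch.exp(log_var)))

print(crit:updateOutput(mean, log_var), expected)  -- should match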
LinearVA.lua
This file was deleted.

Reparametrize.lua
This file was deleted.
Sampler.lua
@@ -0,0 +1,36 @@
+-- Based on the JoinTable module
+require 'nn'
+
+local Sampler, parent = torch.class('nn.Sampler', 'nn.Module')
+
+function Sampler:__init()
+    parent.__init(self)
+    self.gradInput = {}
+end
+
+function Sampler:updateOutput(input)
+    -- input[1] = mu, input[2] = log(sigma^2)
+    self.eps = self.eps or input[1].new()
+    self.eps:resizeAs(input[1]):copy(torch.randn(input[1]:size()))
+
+    -- z = mu + exp(0.5 * log(sigma^2)) * eps
+    self.output = self.output or input[2].new()
+    self.output:resizeAs(input[2]):copy(input[2])
+    self.output:mul(0.5):exp():cmul(self.eps)
+    self.output:add(input[1])
+
+    return self.output
+end
+
+function Sampler:updateGradInput(input, gradOutput)
+    -- dz/dmu = 1
+    self.gradInput[1] = self.gradInput[1] or input[1].new()
+    self.gradInput[1]:resizeAs(gradOutput):copy(gradOutput)
+
+    -- dz/dlog(sigma^2) = 0.5 * exp(0.5 * log(sigma^2)) * eps
+    self.gradInput[2] = self.gradInput[2] or input[2].new()
+    self.gradInput[2]:resizeAs(gradOutput):copy(input[2])
+    self.gradInput[2]:mul(0.5):exp():mul(0.5):cmul(self.eps)
+    self.gradInput[2]:cmul(gradOutput)
+
+    return self.gradInput
+end
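nn.Sampler is the reparameterization trick from the VAE paper: z = mu + sigma * eps with eps ~ N(0, I), which keeps the sampling step differentiable with respect to both encoder heads. A minimal usage sketch, not part of the commit:

-- Sketch (not from this commit): drawing latent samples with nn.Sampler.
require 'torch'
require 'nn'
require 'Sampler'

local sampler = nn.Sampler()
local mean = torch.zeros(4, 20)     -- batch of 4, 20 latent dimensions
local log_var = torch.zeros(4, 20)  -- log(sigma^2) = 0, i.e. unit variance

-- z = mean + exp(0.5 * log_var) * eps; for these inputs z ~ N(0, I)
local z = sampler:forward({mean, log_var})
print(z:size())  -- 4x20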
VAE.lua
@@ -0,0 +1,40 @@
+require 'torch'
+require 'nn'
+
+local VAE = {}
+
+function VAE.get_encoder(input_size, hidden_layer_size, latent_variable_size)
+    -- The encoder: maps x to the mean and log-variance of q(z|x)
+    local encoder = nn.Sequential()
+    encoder:add(nn.Linear(input_size, hidden_layer_size))
+    encoder:add(nn.ReLU(true))
+
+    local mean_logvar = nn.ConcatTable()
+    mean_logvar:add(nn.Linear(hidden_layer_size, latent_variable_size))
+    mean_logvar:add(nn.Linear(hidden_layer_size, latent_variable_size))
+    encoder:add(mean_logvar)
+
+    return encoder
+end
+
+function VAE.get_decoder(input_size, hidden_layer_size, latent_variable_size, continuous)
+    -- The decoder: maps z back to data space
+    local decoder = nn.Sequential()
+    decoder:add(nn.Linear(latent_variable_size, hidden_layer_size))
+    decoder:add(nn.ReLU(true))
+
+    if continuous then
+        -- Gaussian outputs: mean and log-variance per dimension
+        local mean_logvar = nn.ConcatTable()
+        mean_logvar:add(nn.Linear(hidden_layer_size, input_size))
+        mean_logvar:add(nn.Linear(hidden_layer_size, input_size))
+        decoder:add(mean_logvar)
+    else
+        -- Bernoulli outputs: per-pixel probabilities
+        decoder:add(nn.Linear(hidden_layer_size, input_size))
+        decoder:add(nn.Sigmoid(true))
+    end
+
+    return decoder
+end
+
+return VAE
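The single backward pass mentioned in the commit message comes from wiring these pieces into one nngraph gModule. The sketch below shows one plausible wiring; main.lua's actual graph isn't shown in this diff, and the 784/400/20 sizes are illustrative MNIST-style choices rather than values taken from the commit.

-- Sketch (not from this commit): one possible nngraph wiring of the VAE.
require 'nngraph'
require 'Sampler'
local VAE = require 'VAE'

local encoder = VAE.get_encoder(784, 400, 20)
local decoder = VAE.get_decoder(784, 400, 20, false)  -- Bernoulli decoder

local input = nn.Identity()()
local mean, log_var = encoder(input):split(2)  -- the ConcatTable's two heads
local z = nn.Sampler()({mean, log_var})
local reconstruction = decoder(z)

-- One gModule: a single forward/backward pass covers encoder, sampler and decoder
local model = nn.gModule({input}, {reconstruction, mean, log_var})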
binaryva.lua
This file was deleted.

continuousva.lua
This file was deleted.
