Skip to content

Commit

Permalink
Removing in-place updates for batchnorm running variance. This is a m…
Browse files Browse the repository at this point in the history
…emory leak.
  • Loading branch information
Ragav Venkatesan committed Mar 3, 2017
1 parent cdc13af commit db578fe
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
6 changes: 4 additions & 2 deletions yann/layers/batch_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,9 @@ def __init__ ( self,

mean = theano.tensor.unbroadcast(mean,0)
var = theano.tensor.unbroadcast(var,0)
var = var + 0.000001
self.updates[self.running_mean] = mean
self.updates[self.running_var] = var + 0.001
self.updates[self.running_var] = var

self.inference = batch_normalization_test (
inputs = input,
Expand Down Expand Up @@ -237,8 +238,9 @@ def __init__ ( self,

mean = theano.tensor.unbroadcast(mean,0)
var = theano.tensor.unbroadcast(var,0)
var = var + 0.0000001
self.updates[self.running_mean] = mean
self.updates[self.running_var] = var + 0.001
self.updates[self.running_var] = var

self.inference = batch_normalization_test (
inputs = input,
Expand Down
3 changes: 2 additions & 1 deletion yann/layers/conv_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,9 @@ def __init__ ( self,

mean = theano.tensor.unbroadcast(mean,0)
var = theano.tensor.unbroadcast(var,0)
var = var + 0.000001
self.updates[self.running_mean] = mean
self.updates[self.running_var] = var + 0.001
self.updates[self.running_var] = var

batch_norm_inference = batch_normalization_test (
inputs = pool_out + \
Expand Down
3 changes: 2 additions & 1 deletion yann/layers/fully_connected.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,9 @@ def __init__ (self,

mean = theano.tensor.unbroadcast(mean,0)
var = theano.tensor.unbroadcast(var,0)
var = var + 0. 000001
self.updates[self.running_mean] = mean
self.updates[self.running_var] = var + 0.001
self.updates[self.running_var] = var

batch_norm_inference = batch_normalization_test (inputs = linear_fit,
gamma = self.gamma,
Expand Down

0 comments on commit db578fe

Please sign in to comment.