Skip to content

Commit

Permalink
support eval mode
Browse files Browse the repository at this point in the history
  • Loading branch information
wenwei202 committed Aug 24, 2018
1 parent fbb230f commit 769dd72
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 12 deletions.
27 changes: 15 additions & 12 deletions models/noisy_relu.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,23 @@ def __init__(self, a, b, inplace=False):
self.register_buffer('noise', None)

def forward(self, input):
# in case the batch size decreases at the last iteration of an epch
shape = list(input.size())
if self._buffers['device_flag'].is_cuda:
if self._buffers['noise'] is not None:
self._buffers['noise'].uniform_().mul_(self.b-self.a).add_(self.a)
if self.training:
# in case the batch size decreases at the last iteration of an epch
shape = list(input.size())
if self._buffers['device_flag'].is_cuda:
if self._buffers['noise'] is not None:
self._buffers['noise'].uniform_().mul_(self.b-self.a).add_(self.a)
else:
self._buffers['noise'] = torch.cuda.FloatTensor(input.size()).uniform_().mul_(self.b-self.a).add_(self.a)
else:
self._buffers['noise'] = torch.cuda.FloatTensor(input.size()).uniform_().mul_(self.b-self.a).add_(self.a)
if self._buffers['noise'] is not None:
self._buffers['noise'].uniform_().mul_(self.b-self.a).add_(self.a)
else:
self._buffers['noise'] = torch.FloatTensor(input.size()).uniform_().mul_(self.b-self.a).add_(self.a)
#print('noise:', self._buffers['noise'])
return F.threshold(input * torch.autograd.Variable(self._buffers['noise'][0:shape[0]]), 0, 0, self.inplace)
else:
if self._buffers['noise'] is not None:
self._buffers['noise'].uniform_().mul_(self.b-self.a).add_(self.a)
else:
self._buffers['noise'] = torch.FloatTensor(input.size()).uniform_().mul_(self.b-self.a).add_(self.a)
#print('noise:', self._buffers['noise'])
return F.threshold(input * torch.autograd.Variable(self._buffers['noise'][0:shape[0]]), 0, 0, self.inplace)
return F.threshold(input, 0, 0, self.inplace)

def extra_repr(self):
inplace_str = ', inplace' if self.inplace else ''
Expand Down
20 changes: 20 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,23 @@
o.backward()
print('input.grad: ', z.grad)
print('-------------------')

print('=========== eval ============')

m.eval()
for i in range(2):
if 1==i:
x = np.random.rand(5,2).astype('f') - 0.5
else:
x = np.random.rand(1,2).astype('f') - 0.5
x = torch.from_numpy(x)

z = x.cuda()
z = torch.autograd.Variable(z)
m.cuda()

print('input: ', z)

res = m(z)
print('output: ', res)
print('-------------------')

0 comments on commit 769dd72

Please sign in to comment.