Skip to content

Commit

Permalink
Speed up lr_mult=0 by skipping the gradient computation.
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed Apr 21, 2016
1 parent b5a238a commit a949bfa
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
2 changes: 1 addition & 1 deletion examples/cifar10-convnet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: cifar10_convnet.py
# File: cifar10-convnet.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import tensorflow as tf
Expand Down
3 changes: 1 addition & 2 deletions tensorpack/dataflow/prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ def __init__(self, ds, nr_prefetch, nr_proc=1):
x.start()

def get_data(self):
    """Yield datapoints fetched by the prefetch worker processes.

    Pulls exactly ``self._size`` datapoints from the shared queue
    (one epoch worth of data) and yields them to the caller.

    Yields:
        A datapoint (whatever the worker processes put on the queue).
    """
    # Iterate over the dataset size, not a local counter: each get()
    # blocks until a worker has produced the next datapoint.
    for _ in range(self._size):
        dp = self.queue.get()
        yield dp

Expand Down
10 changes: 7 additions & 3 deletions tensorpack/tfutils/gradproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,18 @@ def __init__(self, multipliers):
self.multipliers = multipliers

def _process(self, grads):
# TODO use None for zero can speed up (or not)?
ret = []
for grad, var in grads:
varname = var.op.name
for regex, val in self.multipliers:
if re.search(regex, varname):
# always match against the whole name
if not regex.endswith('$'):
regex = regex + '$'

if re.match(regex, varname):
logger.info("Apply lr multiplier {} for {}".format(val, varname))
ret.append((grad * val, var))
if val != 0: # skip zero to speed up
ret.append((grad * val, var))
break
else:
ret.append((grad, var))
Expand Down

0 comments on commit a949bfa

Please sign in to comment.