
update
ppwwyyxx committed May 1, 2016
1 parent aed3438 commit b059ce4
Showing 4 changed files with 16 additions and 2 deletions.
4 changes: 3 additions & 1 deletion tensorpack/dataflow/common.py
@@ -154,6 +154,7 @@ def __init__(self, ds, func):
:param ds: a :mod:`DataFlow` instance.
:param func: a function that takes an original datapoint and returns a new
datapoint. Return None to skip this datapoint.
Note that if `func` skips datapoints, `ds.size()` will no longer be accurate.
"""
super(MapData, self).__init__(ds)
self.func = func
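
As a quick illustration of the skip-on-None behavior documented above, here is a minimal sketch; the Numbers source is hypothetical, written against the get_data()/size() DataFlow interface used here:

from tensorpack.dataflow import DataFlow
from tensorpack.dataflow.common import MapData

class Numbers(DataFlow):
    # hypothetical source yielding datapoints [0], [1], ..., [9]
    def get_data(self):
        for i in range(10):
            yield [i]

    def size(self):
        return 10

# keep only even numbers: the mapping func returns None to skip a datapoint
ds = MapData(Numbers(), lambda dp: dp if dp[0] % 2 == 0 else None)
print([dp for dp in ds.get_data()])  # [[0], [2], [4], [6], [8]]
print(ds.size())                     # still 10 -- the size is not adjusted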
@@ -170,7 +171,8 @@ def __init__(self, ds, func, index=0):
"""
:param ds: a :mod:`DataFlow` instance.
:param func: a function that takes a datapoint component dp[index] and returns a
new value of dp[index]. Return None to skip this datapoint.
Note that if `func` skips datapoints, `ds.size()` will no longer be accurate.
"""
super(MapDataComponent, self).__init__(ds)
self.func = func
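MapDataComponent applies the same idea to a single component of each datapoint; reusing the hypothetical Numbers source from the sketch above:

from tensorpack.dataflow.common import MapDataComponent

# square component 0 of every datapoint; returning None would skip it entirely
ds = MapDataComponent(Numbers(), lambda x: x * x, index=0)
print([dp for dp in ds.get_data()])  # [[0], [1], [4], [9], ..., [81]]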
7 changes: 7 additions & 0 deletions tensorpack/models/_common.py
@@ -16,6 +16,13 @@
# make sure each layer is only logged once
_layer_logged = set()

def disable_layer_logging():
class ContainEverything:
def __contains__(self, x):
return True
# rebind the module-level set via globals() so every layer appears already logged
globals()['_layer_logged'] = ContainEverything()
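
This works because Python's `in` operator dispatches to __contains__, so after the call every layer name tests as already logged and the once-per-layer logging above is silenced:

disable_layer_logging()
print('conv0' in _layer_logged)   # True -- membership now always succeeds
print(object() in _layer_logged)  # True for any value whatsoever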

def layer_register(summary_activation=False, log_shape=True):
"""
Register a layer.
6 changes: 5 additions & 1 deletion tensorpack/predict.py
@@ -115,9 +115,13 @@ def __init__(self, idx, gpuid, inqueue, outqueue, config):
self.config = config

def run(self):
logger.info("Worker {} use GPU {}".format(self.idx, self.gpuid))
os.environ['CUDA_VISIBLE_DEVICES'] = self.gpuid
G = tf.Graph() # build a graph for each process, because they don't need to share anything
with G.as_default(), tf.device('/gpu:0'):
if self.idx != 0:
from tensorpack.models._common import disable_layer_logging
disable_layer_logging()
self.func = get_predict_func(self.config)
if self.idx == 0:
describe_model()
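
Because CUDA_VISIBLE_DEVICES is set before the graph is built, each process sees only its own GPU as '/gpu:0'. A sketch of how such workers might be launched, one per GPU; PredictWorker is a stand-in name, and the queues, config and nr_gpu are assumed to exist:

import multiprocessing

inqueue, outqueue = multiprocessing.Queue(), multiprocessing.Queue()
workers = [PredictWorker(idx, str(idx), inqueue, outqueue, config)
           for idx in range(nr_gpu)]
for w in workers:
    w.start()  # worker 0 describes the model; the others log nothing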
@@ -173,13 +177,13 @@ def get_result(self):
die_cnt = 0
while True:
res = self.result_queue.get()
if res[0] != DIE:
yield res[1]
else:
die_cnt += 1
if die_cnt == self.nr_gpu:
break
pbar.update()
self.inqueue_proc.join()
self.inqueue_proc.terminate()
for p in self.workers:
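The loop above is a standard sentinel pattern: every worker pushes one DIE marker when it finishes, and the consumer stops after collecting one per GPU; pbar.update() now sits below the check so the bar only advances after an item has been classified. A self-contained sketch with illustrative names:

import multiprocessing

DIE = '__DIE__'  # sentinel marking a finished worker

def worker(outq):
    for i in range(3):
        outq.put(('RESULT', i))
    outq.put((DIE, None))

if __name__ == '__main__':
    nr_worker = 2
    q = multiprocessing.Queue()
    procs = [multiprocessing.Process(target=worker, args=(q,))
             for _ in range(nr_worker)]
    for p in procs:
        p.start()
    die_cnt = 0
    while True:
        tag, val = q.get()
        if tag != DIE:
            print(val)
        else:
            die_cnt += 1
            if die_cnt == nr_worker:
                break
    for p in procs:
        p.join()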
1 change: 1 addition & 0 deletions tensorpack/train/trainer.py
@@ -4,6 +4,7 @@

import tensorflow as tf
import threading
import time
import copy
import re
import functools
