tf.function block thread #51244
Comments
@yangtao121 In order to expedite the trouble-shooting process, please provide a complete code snippet to reproduce the issue reported here. Please also let us know which version of TF you are using. Thanks!
My TensorFlow version is 2.5.0. This code starts the workers with multiprocessing:

```python
threads = [mp.Process(target=self.rolling, args=[env_name]) for _ in range(self.multi_worker_num)]
```

The `rolling` method:
```python
def rolling(self, env_name):
    env = gym.make(env_name).unwrapped
    action_queue = Queue()
    worker = Worker(
        env=env,
        env_args=self.env_args,
        hyper_parameter=self.hyper_parameter,
        get_data_queue=action_queue
    )
    policy = Gaussian_policy(output_queue=action_queue)
    critic = Critic()
    for _ in range(self.epochs):
        policy.load_model('data/policy.h5')
        critic.load_model('data/critic.h5')
        self.ROLLING_EVENT.wait()
        worker.update(policy, critic)
        batch = worker.runner()
        self.batches.put(batch)
        if self.batches.qsize() < self.multi_worker_num - 1:
            self.UPDATE_EVENT.wait()
        else:
            self.roll_flag = 0
            self.UPDATE_EVENT.set()
            self.ROLLING_EVENT.clear()
```

The `runner` method:
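One thing worth checking here: `threading.Event` attributes on `self` are not shared between `multiprocessing.Process` workers; the `ROLLING_EVENT`/`UPDATE_EVENT` handshake only works across processes if they are `multiprocessing.Event` objects created before the workers start and passed to each child. A minimal sketch of that handshake (hypothetical `worker` function, not this project's code):

```python
import multiprocessing as mp

def worker(rolling_evt, update_evt, out):
    # Hypothetical worker: block until the trainer allows rollouts,
    # publish a result, then signal that a batch is ready.
    rolling_evt.wait()
    out.put("rolled")
    update_evt.set()

if __name__ == "__main__":
    # Events must be multiprocessing.Event objects created *before*
    # the workers start, and handed to each child explicitly.
    rolling_evt, update_evt = mp.Event(), mp.Event()
    q = mp.Queue()
    p = mp.Process(target=worker, args=(rolling_evt, update_evt, q))
    p.start()
    rolling_evt.set()   # release the worker
    update_evt.wait()   # wait for its batch
    print(q.get())
    p.join()
```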
```python
def runner(self):
    batches = []
    for i in range(self.trajs):
        collector = Collector(observation_dims=self.obs_dims, action_dims=self.act_dims,
                              episode_length=self.steps)
        state = self.env.reset()
        for t in range(self.steps):
            state = state.reshape(1, -1)
            action, prob = self.policy.get_action(state)
            action_ = action * 2
            state_, reward, done, _ = self.env.step(action_)
            collector.store(state, action, reward, prob)
            state = state_
            if (t + 1) % self.batch_size == 0 or t == self.steps - 1:
                observations, reward = collector.get_current_data()
                value_ = self.critic.get_value(state_.reshape(1, -1))
                values = self.critic.get_value(observations)
                gae, target = gae_target(self.gamma, self.lambada, reward, values, value_, done)
                collector.get_gae_target(gae, target)
        batches.append(collector)
    return batches
```

I found that when I use multiprocessing, the thread blocks at the tf.function. The complete code can be found in my GitHub repository. Thanks!
I tried to run this code in multiple processes via `multiprocessing.Process`: the `tf.function` would block the thread.
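A common cause of this symptom is that TensorFlow's runtime state (thread pools, CUDA context) does not survive a `fork()`, so a child process that inherits an already-initialized TensorFlow can deadlock on its first `tf.function` call. A minimal sketch of the usual workaround, using the `"spawn"` start method and importing TensorFlow only inside the worker (hypothetical `rolling`/`run_workers` names, not the project's code):

```python
import multiprocessing as mp

def rolling(env_name, out_queue):
    # Import TensorFlow *inside* the child, not at module level: with
    # "spawn", each worker gets a fresh interpreter and a fresh TF
    # runtime instead of a forked (deadlock-prone) copy.
    # import tensorflow as tf
    out_queue.put(f"{env_name}: done")

def run_workers(env_name="Pendulum-v0", n=2):
    ctx = mp.get_context("spawn")  # start fresh interpreters, no fork()
    q = ctx.Queue()
    procs = [ctx.Process(target=rolling, args=(env_name, q)) for _ in range(n)]
    for p in procs:
        p.start()
    results = [q.get() for _ in range(n)]
    for p in procs:
        p.join()
    return results

if __name__ == "__main__":
    print(run_workers())
```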