
tf.function block thread #51244

Open
yangtao121 opened this issue Aug 5, 2021 · 2 comments
Assignees
Labels: comp:tf.function (tf.function related issues), stat:awaiting tensorflower (Status - Awaiting response from tensorflower), TF 2.5 (Issues related to TF 2.5), type:others (issues not falling in bug, performance, support, build and install or feature)

Comments

yangtao121 commented Aug 5, 2021
I tried to run this code in multiple processes via multiprocessing.Process:

@tf.function
def get_action(self, obs):
    mu, sigma = self.Model(obs)

    dist = tfp.distributions.Normal(mu, sigma)
    action = tf.squeeze(dist.sample(), axis=0)
    prob = tf.squeeze(dist.prob(action), axis=0)

    return action, prob

The tf.function blocks the thread.

@yangtao121 yangtao121 added the type:others issues not falling in bug, perfromance, support, build and install or feature label Aug 5, 2021
@kumariko kumariko assigned kumariko and unassigned saikumarchalla Aug 6, 2021
@kumariko kumariko added the comp:tf.function tf.function related issues label Aug 6, 2021
kumariko commented Aug 6, 2021

@yangtao121 In order to expedite troubleshooting, please provide a complete code snippet that reproduces the issue reported here, and let us know which version of TF you are using. Thanks!

@kumariko kumariko added the stat:awaiting response Status - Awaiting response from author label Aug 6, 2021
yangtao121 (Author) commented:

My TensorFlow version is 2.5.0.

This code starts the worker processes (multiprocessing):

threads = [mp.Process(target=self.rolling, args=[env_name]) for _ in range(self.multi_worker_num)]

The self.rolling() function is shown below:

 def rolling(self, env_name):
        env = gym.make(env_name).unwrapped
        action_queue = Queue()
        worker = Worker(
            env=env,
            env_args=self.env_args,
            hyper_parameter=self.hyper_parameter,
            get_data_queue=action_queue
        )
        policy = Gaussian_policy(output_queue=action_queue)
        critic = Critic()

        for _ in range(self.epochs):
            policy.load_model('data/policy.h5')
            critic.load_model('data/critic.h5')
            self.ROLLING_EVENT.wait()
            worker.update(policy, critic)
            batch = worker.runner()
            self.batches.put(batch)
            if self.batches.qsize() < self.multi_worker_num - 1:
                self.UPDATE_EVENT.wait()
            else:
                self.roll_flag = 0
                self.UPDATE_EVENT.set()
                self.ROLLING_EVENT.clear()
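
A side note on the synchronization above: if ROLLING_EVENT and UPDATE_EVENT are threading.Event objects, they will not synchronize across mp.Process workers, because each forked or spawned process gets its own copy. Cross-process phase coordination needs multiprocessing primitives. A minimal sketch (placeholder worker, not the code from this issue):

```python
import multiprocessing as mp

def phase_worker(rolling_event, update_event, out_queue):
    rolling_event.wait()    # block until the parent opens the rolling phase
    out_queue.put("batch")  # stand-in for worker.runner()
    update_event.wait()     # block until the parent finishes the update phase

batch = None

if __name__ == "__main__":
    # mp.Event (not threading.Event) is shared across process boundaries.
    rolling = mp.Event()
    update = mp.Event()
    q = mp.Queue()
    p = mp.Process(target=phase_worker, args=(rolling, update, q))
    p.start()
    rolling.set()   # open the rolling phase
    batch = q.get() # collect the worker's batch
    update.set()    # signal that the update is done
    p.join()
    print(batch)
```

Passing the events explicitly through args (rather than reading them off self) makes the sharing visible and works under both fork and spawn start methods.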

The worker is used for sampling data, and the policy tells the worker what to do; the policy gives the action via:

    @tf.function
    def get_action(self, obs):
        # tf.print(obs)
        # print(obs)
        mu, sigma = self.Model(obs)

        dist = tfp.distributions.Normal(mu, sigma)
        action = tf.squeeze(dist.sample(), axis=0)
        prob = tf.squeeze(dist.prob(action), axis=0)
        # print(action)

        return action, prob
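
An aside on the memory side of the report: when @tf.function is kept, every new input shape or Python value triggers a retrace, and unbounded retracing is a common cause of slow memory growth. Pinning the trace with input_signature keeps one graph for all calls. A minimal sketch with a placeholder model and a hypothetical obs_dim (not the Model from this issue, and sampling with tf.random.normal instead of tfp to stay self-contained):

```python
import numpy as np
import tensorflow as tf

obs_dim = 3  # hypothetical observation size

# input_signature pins the traced graph to one (shape, dtype) pattern,
# so repeated calls reuse the same trace instead of building new ones.
@tf.function(input_signature=[tf.TensorSpec(shape=(1, obs_dim), dtype=tf.float32)])
def get_action(obs):
    # Placeholder for self.Model(obs): derive mu/sigma directly from obs.
    mu, sigma = obs * 0.1, tf.ones_like(obs)
    # Reparameterized sample from Normal(mu, sigma).
    eps = tf.random.normal(tf.shape(mu))
    action = tf.squeeze(mu + sigma * eps, axis=0)
    return action

action = get_action(np.zeros((1, obs_dim), dtype=np.float32))
print(action.shape)
```

With the signature fixed, calling get_action with a differently shaped array raises an error instead of silently building another graph, which makes shape bugs visible early.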

worker.runner():

    def runner(self):
        # print('start')
        batches = []
        for i in range(self.trajs):
            collector = Collector(observation_dims=self.obs_dims, action_dims=self.act_dims,
                                  episode_length=self.steps)
            state = self.env.reset()
            # print(i)

            for t in range(self.steps):
                state = state.reshape(1, -1)
                print(state)
                action, prob = self.policy.get_action(state)

                action_ = action * 2

                state_, reward, done, _ = self.env.step(action_)
                collector.store(state, action, reward, prob)
                state = state_

                if (t + 1) % self.batch_size == 0 or t == self.steps - 1:
                    observations, reward = collector.get_current_data()
                    value_ = self.critic.get_value(state_.reshape(1, -1))
                    values = self.critic.get_value(observations)

                    gae, target = gae_target(self.gamma, self.lambada, reward, values, value_, done)

                    collector.get_gae_target(gae, target)

            batches.append(collector)

        return batches
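
The gae_target call above is not defined in this thread; for context, a common formulation of Generalized Advantage Estimation (an assumption about what gae_target computes, not the issue's actual code) looks like this:

```python
import numpy as np

def gae_target(gamma, lam, rewards, values, last_value, done):
    """Generalized Advantage Estimation over one batch of transitions.

    delta_t = r_t + gamma * V(s_{t+1}) - V(s_t); the advantage accumulates
    discounted deltas backwards through time with decay gamma * lam.
    """
    rewards = np.asarray(rewards, dtype=np.float64)
    values = np.asarray(values, dtype=np.float64)
    # Bootstrap with the critic's value of the last state, or 0 if the
    # episode terminated.
    next_values = np.append(values[1:], 0.0 if done else last_value)
    deltas = rewards + gamma * next_values - values
    gae = np.zeros_like(rewards)
    acc = 0.0
    for t in reversed(range(len(rewards))):
        acc = deltas[t] + gamma * lam * acc
        gae[t] = acc
    target = gae + values  # return targets for the critic
    return gae, target

g, tgt = gae_target(0.99, 0.95, [1.0, 0.0, 1.0], [0.5, 0.4, 0.3], 0.2, False)
print(g, tgt)
```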

I found that when I use multiprocessing, the thread blocks at action, prob = self.policy.get_action(state) in worker.runner(), with no error output. However, when I remove the @tf.function decorator above policy.get_action(), it works. But removing tf.function causes a serious memory leak.
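
A frequent culprit for this kind of silent hang (an assumption here, not confirmed in this thread) is that multiprocessing defaults to fork on Linux: a forked child inherits TensorFlow's internal locks and runtime state mid-flight, and the first graph execution in the child can deadlock. Using the spawn start method, and importing TensorFlow inside the child, gives each worker a fresh interpreter. A minimal sketch with a placeholder worker and a hypothetical env name:

```python
import multiprocessing as mp

def rollout_worker(env_name, out_queue):
    # In real code, import tensorflow here, inside the child, so the
    # spawned interpreter initializes its own runtime state from scratch.
    out_queue.put("rolled out " + env_name)

results = []

if __name__ == "__main__":
    # 'spawn' launches fresh interpreters instead of forking the parent,
    # so children do not inherit TensorFlow's locks or graph state.
    ctx = mp.get_context("spawn")
    queue = ctx.Queue()
    procs = [ctx.Process(target=rollout_worker, args=("Pendulum-v1", queue))
             for _ in range(2)]
    for p in procs:
        p.start()
    results = [queue.get() for _ in procs]
    for p in procs:
        p.join()
    print(results)
```

Note that with spawn, everything the child needs must be picklable or re-importable, which is why the worker takes its inputs as plain arguments rather than capturing TF objects from the parent.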

The complete code can be found on my GitHub.

Thanks

@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Aug 8, 2021
@kumariko kumariko added the TF 2.5 Issues related to TF 2.5 label Aug 9, 2021
@kumariko kumariko assigned Saduf2019 and unassigned kumariko Aug 9, 2021
@Saduf2019 Saduf2019 assigned jvishnuvardhan and unassigned Saduf2019 Aug 9, 2021
@jvishnuvardhan jvishnuvardhan added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Aug 23, 2021