🚀 Feature
Support parallelized/asynchronous execution of ops on CPU.
PyTorch currently supports asynchronous execution on GPU, where ops are executed in parallel as long as there are no data dependencies between them. However, there is no straightforward way to do this on CPU, even on machines with a large number of cores.
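For context, a minimal illustration of the asymmetry (the shapes are arbitrary placeholders, not ML-Agents code): CUDA ops are queued on a stream and return control to the host immediately, while CPU ops block the calling thread until they finish.

```python
import torch

if torch.cuda.is_available():
    x = torch.randn(1024, 1024, device="cuda")
    y = x @ x                 # returns immediately; the kernel is queued on the current CUDA stream
    torch.cuda.synchronize()  # the host blocks here until all queued GPU work has completed

x_cpu = torch.randn(1024, 1024)
y_cpu = x_cpu @ x_cpu         # the calling thread blocks until the result is computed
```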
Motivation
We are interested in using PyTorch for Reinforcement Learning (the ML-Agents Toolkit). Currently (from what I understand), PyTorch parallelizes work across CPU cores in one of three main ways:
- Intra-op parallelism using OpenMP and associated optimizations
- Inter-op parallelism using JIT fork and wait
- Using Python multithreading (which must contend with the GIL) and multiprocessing.
Method 1 works well for typical supervised learning tasks, where the network is large and each op is quite expensive. However, none of the three makes much sense for Reinforcement Learning workloads, where (a) there can be multiple networks (e.g. policy and value), (b) each network is fairly small (2-3 layers, <200 hidden units), and (c) they often share a loss function, which makes multiprocessing very complicated.
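To make method 2 concrete, here is a minimal sketch of inter-op parallelism as it exists today (the two sub-networks and their sizes are placeholders, not ML-Agents code). Note that `torch.jit.fork` only runs in parallel from inside TorchScript, which is part of the refactoring burden mentioned under Alternatives:

```python
import torch
import torch.nn as nn

torch.set_num_interop_threads(4)  # size of the inter-op pool; must be set before parallel work starts

class TwoHeads(nn.Module):
    """Runs two small sub-networks concurrently on the inter-op thread pool."""

    def __init__(self):
        super().__init__()
        # Placeholder networks roughly the size described below (2-3 layers, <200 hidden units).
        self.policy = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 4))
        self.value = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 1))

    def forward(self, obs):
        fut = torch.jit.fork(self.policy, obs)  # schedule the policy forward pass asynchronously
        v = self.value(obs)                     # run the value forward pass on the current thread
        p = torch.jit.wait(fut)                 # block until the forked task finishes
        return p, v

model = torch.jit.script(TwoHeads())  # fork only executes in parallel inside TorchScript
p, v = model(torch.randn(256, 64))
```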
For reference, here is the pseudo-code for the update function of the Soft Actor-Critic (SAC) algorithm (full code linked under Additional context).
```python
# Sample observations, dones, rewards from experience replay buffer
observations, next_observations, dones, rewards = sample_batch()

# Evaluate current policy on sampled observations
(
    sampled_actions,
    log_probs,
    entropies,
    sampled_values,
) = policy.sample_actions_and_values(observations)

# Evaluate Q networks on observations and actions
q1p_out, q2p_out = value_network(observations, sampled_actions)
q1_out, q2_out = value_network(observations, actions)

# Evaluate target network on next observations
with torch.no_grad():
    target_values = target_network(next_observations)

# Evaluate losses
q1_loss, q2_loss = sac_q_loss(q1_out, q2_out, target_values, dones, rewards)
value_loss = sac_value_loss(log_probs, sampled_values, q1p_out, q2p_out)
policy_loss = sac_policy_loss(log_probs, q1p_out)
entropy_loss = sac_entropy_loss(log_probs)
```
Note that `policy` contains a policy and a critic network, and each `value_network` consists of two Q networks, so the code above performs a total of 7 forward passes and 3 backward passes, sequentially. As the networks are quite small, intra-op parallelism is not very effective (in fact, setting `num_threads` to 1 is most performant).
In fact, even after CPU/environment-variable/thread tuning, the resulting PyTorch code is about 2x slower than the equivalent TensorFlow code on a 6-core CPU, with roughly 35% of the time spent in backprop, 45% in the forward passes, and 20% in the optimizer step functions. No such performance gap is observed between PyTorch and TensorFlow on GPU.
We've tried alternatives 2 (JIT forking) and 3 (multithreading) but saw no more than about a 5% improvement. We believe this is because the overhead of spawning new threads outweighs the benefit for such lightweight network evaluations.
Pitch
Support the same asynchronous execution mechanism that CUDA has for CPU, respecting `num_interop_threads`, or provide a straightforward way to implicitly parallelize multiple networks.
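For reference, a minimal sketch of the existing thread-count knobs the pitch refers to (both are real APIs today; to our understanding the inter-op pool is currently only exercised on CPU via TorchScript fork/wait rather than implicitly):

```python
import torch

torch.set_num_threads(1)          # intra-op (OpenMP/MKL) threads; 1 was fastest for these small nets
torch.set_num_interop_threads(6)  # inter-op pool that implicit CPU parallelism could respect

print(torch.get_num_threads(), torch.get_num_interop_threads())
```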
Alternatives
Using JIT fork - Often requires considerable refactoring. Furthermore, in our testing it did not provide much benefit, possibly because the overhead of forking new threads for small networks outweighs the gain.
Using multithreading - The ops in these networks are so small that the Python GIL is thrashed when trying to parallelize multiple networks (sketched below).
Using multiprocessing - Requires significantly more complex code structure, especially because the loss functions are shared between multiple networks.
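A rough sketch of the kind of multithreading attempt described above (the two networks are placeholder stand-ins): each op only releases the GIL while it executes in C++, so with ops this small the Python-side dispatch dominates and the threads largely serialize on the GIL.

```python
import threading

import torch
import torch.nn as nn

# Placeholder stand-ins for two small networks evaluated concurrently.
policy = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 4))
value = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 1))

obs = torch.randn(256, 64)
results = {}

def run(name, net):
    # The GIL is only released while each individual op runs in C++,
    # so tiny ops spend most of their time contending for it in Python.
    results[name] = net(obs)

threads = [
    threading.Thread(target=run, args=("policy", policy)),
    threading.Thread(target=run, args=("value", value)),
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```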
Additional alternatives/suggestions greatly welcome!
Additional context
Code in question: https://github.com/Unity-Technologies/ml-agents/blob/376168bbb55deac540a572617fb70effffe98cd5/ml-agents/mlagents/trainers/sac/optimizer_torch.py#L438