❓ Questions and Help
Hi, I have a model that needs to be shared across different processes - e.g. one trains it while another collects new training data for it (as in RL). In traditional PyTorch this is easily achieved with something like:
```python
import torch.nn as nn
from torch.multiprocessing import Process

# Create shared model (device and some_fn defined elsewhere)
shared_model = nn.Linear(512, 32).to(device)
shared_model.share_memory()

# Pass it to the subprocess at creation time
some_process = Process(
    target=some_fn,
    args=(shared_model,),  # note the trailing comma: args must be a tuple
)
some_process.start()

# do something else with shared_model
# The subprocess can read from / write to shared_model
```

This is unavailable for XLA tensors, since they don't have storage.
Now I could try pushing this state back and forth through a Queue, or even cast all weights to numpy arrays and share those, but I'm afraid this would be way too messy and inefficient. Are there any other approaches? For example, is there any way to recast XLA tensors to native tensors only for the exchange, and then cast them back for efficient computation?
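For clarity, here is a rough sketch of the kind of "cast for exchange, cast back for compute" workaround I have in mind, where a shared CPU copy of the model serves only as a transport buffer and the XLA copy is used for actual computation. `xm.xla_device()` is the standard torch_xla API; the helper names (`publish_weights`, `pull_weights`) are just placeholders:

```python
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# The "real" model lives on the XLA device for fast computation
xla_model = nn.Linear(512, 32).to(device)

# A CPU copy with shared memory acts purely as an exchange buffer
shared_cpu_model = nn.Linear(512, 32)
shared_cpu_model.share_memory()

def publish_weights():
    # Copy XLA weights into the shared CPU buffer (XLA -> CPU transfer);
    # load_state_dict copies in place, so the shared storage is preserved
    cpu_state = {k: v.cpu() for k, v in xla_model.state_dict().items()}
    shared_cpu_model.load_state_dict(cpu_state)

def pull_weights():
    # In the other process: read the shared CPU weights back onto the XLA device
    xla_model.load_state_dict(
        {k: v.to(device) for k, v in shared_cpu_model.state_dict().items()}
    )
```

The CPU buffer would only be touched at synchronization points, but I'm unsure whether these repeated XLA <-> CPU transfers are acceptable in practice or whether there is a more direct mechanism.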