
Model sharing across torch Processes #2103

@JanRocketMan

Description


❓ Questions and Help

Hi, I have a model that needs to be shared across different processes, e.g. one process trains it while another collects new training data for it (as in RL). In regular PyTorch this can easily be achieved with something like:

import torch
import torch.nn as nn
from torch.multiprocessing import Process

device = torch.device("cpu")  # or "cuda"

# Create shared model
shared_model = nn.Linear(512, 32).to(device)
shared_model.share_memory()

# Pass it during creation (trailing comma makes args a tuple)
some_process = Process(
    target=some_fn,
    args=(shared_model,)
)
some_process.start()

# do something else with shared_model

# Subprocess can write to / read from shared_model
This is unavailable for XLA tensors, since they don't have storage.

I could try pushing the state back and forth through a Queue, or even cast all weights to numpy arrays and share those, but I'm afraid that would be far too messy and inefficient. Are there any other approaches? For example, is there a way to recast XLA tensors to native (CPU) tensors only for the exchange, and then cast them back for efficient computation?
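To make the question concrete, here is the kind of workaround I have in mind (a rough sketch only; push_to_shared / pull_from_shared are hypothetical helper names, and it assumes plain .cpu() / .copy_() transfers plus xm.xla_device()): keep the shared copy on CPU in shared memory, and copy weights to and from the XLA device only around each exchange.

import torch
import torch_xla.core.xla_model as xm

def push_to_shared(xla_model, shared_cpu_model):
    # Copy current XLA weights into the CPU model that lives in shared memory.
    with torch.no_grad():
        for cpu_p, xla_p in zip(shared_cpu_model.parameters(), xla_model.parameters()):
            cpu_p.copy_(xla_p.cpu())

def pull_from_shared(xla_model, shared_cpu_model):
    # Refresh the XLA copy from the shared CPU weights.
    device = xm.xla_device()
    with torch.no_grad():
        for xla_p, cpu_p in zip(xla_model.parameters(), shared_cpu_model.parameters()):
            xla_p.copy_(cpu_p.to(device))

My worry is that a full parameter copy on every exchange may dominate the run time, which is why I'm asking whether there is a cheaper, supported way.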
