[jit] Fix future type annotation in python #26578

@wanchaol

We should fix the Future type annotation so that it also works in plain Python. In Python, fork should immediately execute the function on a single thread and return a Future of the result, which wait simply unpacks (the same way tracing does).

from torch import nn, Tensor
from torch.jit import Future  # assumed import location for the Future annotation
from typing import List
import torch

class Encoder(nn.Module):
    def forward(self, x):
        return x

class EncoderEnsemble2(nn.Module):
    def __init__(self, encoders : List[nn.Module]):
        super().__init__()
        self.encoders = nn.ModuleList(encoders)

    def forward(self, x):
        futures = torch.jit.annotate(
            List[Future[Tensor]], []
        )
        for encoder in self.encoders:
            futures.append(
                torch.jit._fork(encoder, x)
            )

        all_outputs = []
        for future in futures:
            all_outputs.append(torch.jit._wait(future))
        return all_outputs

model = EncoderEnsemble2([Encoder()])
output = model(torch.randn(3, 4))
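
The eager behavior proposed above (fork runs the function immediately on one thread, wait unpacks the result) can be sketched in pure Python. This is only an illustration under stated assumptions, not how torch.jit implements it: EagerFuture, eager_fork, and eager_wait are hypothetical names that do not exist in PyTorch.

```python
from typing import Callable, Generic, List, TypeVar

T = TypeVar("T")


class EagerFuture(Generic[T]):
    """Hypothetical stand-in for a Future that is usable in plain Python.

    Inheriting from Generic[T] is what makes the subscripted annotation
    EagerFuture[int] valid at Python runtime.
    """

    def __init__(self, value: T) -> None:
        self._value = value

    def wait(self) -> T:
        # The result already exists, so waiting just unpacks it.
        return self._value


def eager_fork(fn: Callable[..., T], *args, **kwargs) -> EagerFuture[T]:
    # Hypothetical eager fork: run fn right away on the calling thread
    # and wrap the finished result in a future.
    return EagerFuture(fn(*args, **kwargs))


def eager_wait(future: EagerFuture[T]) -> T:
    # Hypothetical eager wait: unpack the already computed result.
    return future.wait()


# Same fork-then-wait pattern as the forward() method above,
# with plain integers in place of tensors.
futures: List[EagerFuture[int]] = []
for k in range(3):
    futures.append(eager_fork(lambda v, k=k: v * k, 10))

results = [eager_wait(f) for f in futures]
```

Because the work is done inside eager_fork, the order of side effects matches a plain sequential loop, which is the same trade-off the issue describes for tracing.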

cc @ezyang @gchanan @zou3519 @suo

Metadata

Labels

jit-backlog
oncall: jit (Add this issue/PR to JIT oncall triage queue)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
