Closed
Labels: jit-backlog · oncall: jit · triaged
Description
We should fix the `Future` type annotation so that it also works in eager-mode Python: `fork` would immediately execute the function single-threaded and return a `Future` of the result, which `wait` simply unpacks (the same way that tracing did).
from typing import List

import torch
from torch import nn, Tensor
from torch.jit import Future


class Encoder(nn.Module):
    def forward(self, x):
        return x


class EncoderEnsemble2(nn.Module):
    def __init__(self, encoders: List[nn.Module]):
        super().__init__()
        self.encoders = nn.ModuleList(encoders)

    def forward(self, x):
        futures = torch.jit.annotate(List[Future[Tensor]], [])
        for encoder in self.encoders:
            futures.append(torch.jit._fork(encoder, x))
        all_outputs = []
        for future in futures:
            all_outputs.append(torch.jit._wait(future))
        return all_outputs


model = EncoderEnsemble2([Encoder()])
output = model(torch.randn(3, 4))