Skip to content

Commit

Permalink
use generator arg in dataloader in cevae (#3264)
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjankowiak committed Aug 24, 2023
1 parent 0e82cad commit 1048a8b
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion pyro/contrib/cevae/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,12 @@ def fit(
self.whiten = PreWhitener(x)

dataset = TensorDataset(x, t, y)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
generator=torch.Generator(device=x.device),
)
logger.info("Training with {} minibatches per epoch".format(len(dataloader)))
num_steps = num_epochs * len(dataloader)
optim = ClippedAdam(
Expand Down

0 comments on commit 1048a8b

Please sign in to comment.