diff --git a/tutorial/source/vae.ipynb b/tutorial/source/vae.ipynb index 4140b144a7..b30684bc10 100644 --- a/tutorial/source/vae.ipynb +++ b/tutorial/source/vae.ipynb @@ -98,7 +98,7 @@ "\n", "import numpy as np\n", "import torch\n", - "import torchvision.datasets as dset\n", + "from pyro.contrib.examples.util import MNIST\n", "import torch.nn as nn\n", "import torchvision.transforms as transforms\n", "\n", @@ -133,9 +133,9 @@ " root = './data'\n", " download = True\n", " trans = transforms.ToTensor()\n", - " train_set = dset.MNIST(root=root, train=True, transform=trans,\n", - " download=download)\n", - " test_set = dset.MNIST(root=root, train=False, transform=trans)\n", + " train_set = MNIST(root=root, train=True, transform=trans,\n", + " download=download)\n", + " test_set = MNIST(root=root, train=False, transform=trans)\n", "\n", " kwargs = {'num_workers': 1, 'pin_memory': use_cuda}\n", " train_loader = torch.utils.data.DataLoader(dataset=train_set,\n",