Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix MNIST url redirection #2775

Merged
merged 2 commits into from
Mar 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,7 @@ install:
- pip install numpy
# TODO replace with torch_stable before release
# - pip install torch==1.8.0+cpu torchvision==0.9.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
# TODO replace with torch_test once torchvision binaries are released
# - pip install torch torchvision -f https://download.pytorch.org/whl/test/cpu/torch_test.html
# This is the last nightly release of 1.8.0 before splitting to 1.9.0.
- pip install --pre torch==1.8.0.dev20210210+cpu torchvision==0.9.0.dev20210210+cpu -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
- pip install torch==1.8.0+cpu torchvision==0.9.0 -f https://download.pytorch.org/whl/test/cpu/torch_test.html
- pip install .[test]
- pip install coveralls
- pip freeze
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ Make sure that the models come from the same release version of the [Pyro source
### Installing Pyro dev branch

For recent features you can install Pyro from source.
Pyro's dev branch requires PyTorch [nightly builds](https://pytorch.org/get-started/locally/).
Pyro's dev branch currently requires PyTorch [test builds](https://pytorch.org/get-started/locally/).

**Install PyTorch nightly:**
**Install PyTorch test:**

```sh
pip install numpy
pip install --pre torch==1.8.0.dev20210210 torchvision==0.9.0.dev20210210 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
pip install torch==1.8.0 torchvision==0.9.0 -f https://download.pytorch.org/whl/test/cpu/torch_test.html
```

**Install Pyro using pip:**
Expand Down
10 changes: 2 additions & 8 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,11 +216,5 @@ def setup(app):
# TODO replace with torch_stable before release
# os.system('pip install torch==1.8.0+cpu torchvision==0.9.0+cpu '
# '-f https://download.pytorch.org/whl/torch_stable.html')
# TODO replace with torch_test once torchvision binaries are released
# os.system('pip install torch torchvision '
# '-f https://download.pytorch.org/whl/test/cpu/torch_test.html')
# This is the last nightly release of 1.8.0 before splitting to 1.9.0.
os.system('pip install --pre '
'torch==1.8.0.dev20210210+cpu '
'torchvision==0.9.0.dev20210210+cpu '
'-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html')
os.system('pip install torch==1.8.0+cpu torchvision==0.9.0 '
'-f https://download.pytorch.org/whl/test/cpu/torch_test.html')
3 changes: 2 additions & 1 deletion examples/cvae/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, functional

from pyro.contrib.examples.util import MNIST


class CVAEMNIST(Dataset):
def __init__(self, root, train=True, transform=None, download=False):
Expand Down
3 changes: 1 addition & 2 deletions examples/vae/utils/mnist_cached.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

from pyro.contrib.examples.util import get_data_directory
from pyro.contrib.examples.util import MNIST, get_data_directory

# This file contains utilities for caching, transforming and splitting MNIST data
# efficiently. By default, a PyTorch DataLoader will apply the transform every epoch
Expand Down
12 changes: 5 additions & 7 deletions pyro/contrib/examples/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,12 @@
import os
import sys

import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

from pyro.distributions.torch_patch import patch_dependency


@patch_dependency('torchvision.datasets.MNIST', torchvision)
class _MNIST(getattr(MNIST, '_pyro_unpatched', MNIST)):
class MNIST(datasets.MNIST):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for getting rid of this!

# For older torchvision.
urls = [
"https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/train-images-idx3-ubyte.gz",
Expand All @@ -40,7 +35,10 @@ def get_data_loader(dataset_name,
if not dataset_transforms:
dataset_transforms = []
trans = transforms.Compose([transforms.ToTensor()] + dataset_transforms)
dataset = getattr(datasets, dataset_name)
if dataset_name == "MNIST":
dataset = MNIST
else:
dataset = getattr(datasets, dataset_name)
print("downloading data")
dset = dataset(root=data_dir,
train=is_training_set,
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
'jupyter>=1.0.0',
'graphviz>=0.8',
'matplotlib>=1.3',
'torchvision>=0.9.0.dev20210210',
'torchvision>=0.9.0',
'visdom>=0.1.4',
'pandas',
'scikit-learn',
Expand Down Expand Up @@ -89,7 +89,7 @@
'numpy>=1.7',
'opt_einsum>=2.3.2',
'pyro-api>=0.1.1',
'torch>=1.8.0.dev20210210',
'torch>=1.8.0',
'tqdm>=4.36',
],
extras_require={
Expand Down