diff --git a/torchvision/datasets/cifar.py b/torchvision/datasets/cifar.py index 19e9fc450f5..44cde4ba54f 100644 --- a/torchvision/datasets/cifar.py +++ b/torchvision/datasets/cifar.py @@ -50,7 +50,7 @@ class CIFAR10(data.Dataset): def __init__(self, root, train=True, transform=None, target_transform=None, download=False): - self.root = root + self.root = os.path.expanduser(root) self.transform = transform self.target_transform = target_transform self.train = train # training set or test set diff --git a/torchvision/datasets/coco.py b/torchvision/datasets/coco.py index 4882a33d220..3d767d3d4d8 100644 --- a/torchvision/datasets/coco.py +++ b/torchvision/datasets/coco.py @@ -44,7 +44,7 @@ class CocoCaptions(data.Dataset): """ def __init__(self, root, annFile, transform=None, target_transform=None): from pycocotools.coco import COCO - self.root = root + self.root = os.path.expanduser(root) self.coco = COCO(annFile) self.ids = list(self.coco.imgs.keys()) self.transform = transform diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index e1fbe0476ad..6cc04c1c924 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -23,6 +23,7 @@ def find_classes(dir): def make_dataset(dir, class_to_idx): images = [] + dir = os.path.expanduser(dir) for target in os.listdir(dir): d = os.path.join(dir, target) if not os.path.isdir(d): diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index 75f7993c872..6f7495141a0 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -36,7 +36,7 @@ class MNIST(data.Dataset): test_file = 'test.pt' def __init__(self, root, train=True, transform=None, target_transform=None, download=False): - self.root = root + self.root = os.path.expanduser(root) self.transform = transform self.target_transform = target_transform self.train = train # training set or test set diff --git a/torchvision/datasets/phototour.py b/torchvision/datasets/phototour.py index 3da10ef0347..d44449bb6a9 100644 --- a/torchvision/datasets/phototour.py +++ b/torchvision/datasets/phototour.py @@ -49,7 +49,7 @@ class PhotoTour(data.Dataset): matches_files = 'm50_100000_100000_0.txt' def __init__(self, root, name, train=True, transform=None, download=False): - self.root = root + self.root = os.path.expanduser(root) self.name = name self.data_dir = os.path.join(root, name) self.data_down = os.path.join(root, '{}.zip'.format(name)) diff --git a/torchvision/datasets/stl10.py b/torchvision/datasets/stl10.py index e498e763e82..908543b6c1a 100644 --- a/torchvision/datasets/stl10.py +++ b/torchvision/datasets/stl10.py @@ -44,7 +44,7 @@ class STL10(CIFAR10): def __init__(self, root, split='train', transform=None, target_transform=None, download=False): - self.root = root + self.root = os.path.expanduser(root) self.transform = transform self.target_transform = target_transform self.split = split # train/test/unlabeled set diff --git a/torchvision/datasets/svhn.py b/torchvision/datasets/svhn.py index f53b1652d95..5a093bd2792 100644 --- a/torchvision/datasets/svhn.py +++ b/torchvision/datasets/svhn.py @@ -38,7 +38,7 @@ class SVHN(data.Dataset): def __init__(self, root, split='train', transform=None, target_transform=None, download=False): - self.root = root + self.root = os.path.expanduser(root) self.transform = transform self.target_transform = target_transform self.split = split # training set or test set or extra set diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index a6fe8883ae2..962eb4f9461 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -21,6 +21,7 @@ def check_integrity(fpath, md5): def download_url(url, root, filename, md5): from six.moves import urllib + root = os.path.expanduser(root) fpath = os.path.join(root, filename) try: