CIFAR: permanent 'data' and 'targets' fields #594

Merged · 1 commit merged on Sep 11, 2018
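With this change, both the training and test splits expose the same `data` and `targets` attributes, so downstream code can read images and labels without branching on `train` or reaching for the split-specific `train_data`/`test_labels` names. A minimal usage sketch, assuming a torchvision build that includes this commit; the `root` path is illustrative:

from torchvision.datasets import CIFAR10

# Both splits now carry `data` (uint8 array in HWC layout) and `targets` (list of ints).
train_set = CIFAR10(root='./cifar', train=True, download=True)
test_set = CIFAR10(root='./cifar', train=False, download=True)

print(train_set.data.shape, len(train_set.targets))  # (50000, 32, 32, 3) 50000
print(test_set.data.shape, len(test_set.targets))    # (10000, 32, 32, 3) 10000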
78 changes: 23 additions & 55 deletions torchvision/datasets/cifar.py
@@ -51,13 +51,6 @@ class CIFAR10(data.Dataset):
         'md5': '5ff9c542aee3614f3951f8cda6e48888',
     }

-    @property
-    def targets(self):
-        if self.train:
-            return self.train_labels
-        else:
-            return self.test_labels
-
     def __init__(self, root, train=True,
                  transform=None, target_transform=None,
                  download=False):
@@ -73,44 +66,30 @@ def __init__(self, root, train=True,
             raise RuntimeError('Dataset not found or corrupted.' +
                                ' You can use download=True to download it')

-        # now load the picked numpy arrays
         if self.train:
-            self.train_data = []
-            self.train_labels = []
-            for fentry in self.train_list:
-                f = fentry[0]
-                file = os.path.join(self.root, self.base_folder, f)
-                fo = open(file, 'rb')
+            downloaded_list = self.train_list
+        else:
+            downloaded_list = self.test_list
+
+        self.data = []
+        self.targets = []
+
+        # now load the picked numpy arrays
+        for file_name, checksum in downloaded_list:
+            file_path = os.path.join(self.root, self.base_folder, file_name)
+            with open(file_path, 'rb') as f:
                 if sys.version_info[0] == 2:
-                    entry = pickle.load(fo)
+                    entry = pickle.load(f)
                 else:
-                    entry = pickle.load(fo, encoding='latin1')
-                self.train_data.append(entry['data'])
+                    entry = pickle.load(f, encoding='latin1')
+                self.data.append(entry['data'])
                 if 'labels' in entry:
-                    self.train_labels += entry['labels']
+                    self.targets.extend(entry['labels'])
                 else:
-                    self.train_labels += entry['fine_labels']
-                fo.close()
+                    self.targets.extend(entry['fine_labels'])

-            self.train_data = np.concatenate(self.train_data)
-            self.train_data = self.train_data.reshape((50000, 3, 32, 32))
-            self.train_data = self.train_data.transpose((0, 2, 3, 1))  # convert to HWC
-        else:
-            f = self.test_list[0][0]
-            file = os.path.join(self.root, self.base_folder, f)
-            fo = open(file, 'rb')
-            if sys.version_info[0] == 2:
-                entry = pickle.load(fo)
-            else:
-                entry = pickle.load(fo, encoding='latin1')
-            self.test_data = entry['data']
-            if 'labels' in entry:
-                self.test_labels = entry['labels']
-            else:
-                self.test_labels = entry['fine_labels']
-            fo.close()
-            self.test_data = self.test_data.reshape((10000, 3, 32, 32))
-            self.test_data = self.test_data.transpose((0, 2, 3, 1))  # convert to HWC
+        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
+        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

         self._load_meta()

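For reference, the consolidated loading path in the hunk above can be exercised on its own. A standalone sketch of the same logic, assuming the CIFAR-10 batches have already been extracted; the `root` path and batch list are illustrative:

import os
import pickle

import numpy as np

root = './cifar/cifar-10-batches-py'          # illustrative path
batch_files = ['data_batch_1', 'test_batch']  # any subset of the pickled batches

data, targets = [], []
for name in batch_files:
    with open(os.path.join(root, name), 'rb') as f:
        entry = pickle.load(f, encoding='latin1')  # batches are Python 2 pickles
    data.append(entry['data'])                     # shape (10000, 3072), uint8
    # CIFAR-10 stores 'labels'; CIFAR-100 stores 'fine_labels'
    targets.extend(entry['labels'] if 'labels' in entry else entry['fine_labels'])

data = np.vstack(data).reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))  # CHW -> HWC
print(data.shape, len(targets))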
@@ -135,10 +114,7 @@ def __getitem__(self, index):
         Returns:
             tuple: (image, target) where target is index of the target class.
         """
-        if self.train:
-            img, target = self.train_data[index], self.train_labels[index]
-        else:
-            img, target = self.test_data[index], self.test_labels[index]
+        img, target = self.data[index], self.targets[index]

         # doing this so that it is consistent with all other datasets
         # to return a PIL Image
@@ -153,10 +129,7 @@ def __getitem__(self, index):
         return img, target

     def __len__(self):
-        if self.train:
-            return len(self.train_data)
-        else:
-            return len(self.test_data)
+        return len(self.data)

     def _check_integrity(self):
         root = self.root
@@ -174,16 +147,11 @@ def download(self):
             print('Files already downloaded and verified')
             return

-        root = self.root
-        download_url(self.url, root, self.filename, self.tgz_md5)
+        download_url(self.url, self.root, self.filename, self.tgz_md5)

         # extract file
-        cwd = os.getcwd()
-        tar = tarfile.open(os.path.join(root, self.filename), "r:gz")
-        os.chdir(root)
-        tar.extractall()
-        tar.close()
-        os.chdir(cwd)
+        with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
+            tar.extractall(path=self.root)

     def __repr__(self):
         fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
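The download() hunk above replaces the chdir-based extraction with a context-managed tarfile. The same pattern in isolation, with an illustrative download directory and archive name:

import os
import tarfile

root = './cifar'                       # illustrative download directory
filename = 'cifar-10-python.tar.gz'    # illustrative archive name

# Extract in place without changing the process working directory.
with tarfile.open(os.path.join(root, filename), 'r:gz') as tar:
    tar.extractall(path=root)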