-
Notifications
You must be signed in to change notification settings - Fork 389
/
dataloader.py
55 lines (46 loc) · 2.01 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
import torch
import torchvision
import torchvision.transforms as transforms
def get_data_loaders(args):
    """Return the CIFAR-10 training and test data loaders.

    When both ``args.trainloader`` and ``args.testloader`` are set they are
    treated as paths to objects previously written with ``torch.save`` and
    are loaded from disk as-is. If either one is unset, the function falls
    through and builds fresh loaders over ``torchvision.datasets.CIFAR10``
    rooted at ``./data`` (requested with ``download=True``).

    Args:
        args: Namespace-like object. ``trainloader`` and ``testloader`` are
            always read. When loaders are built, ``raw_data``, ``noaug``,
            ``ngpu`` and ``batch_size`` are read as well.

    Returns:
        A ``(trainloader, testloader)`` tuple.

    Raises:
        FileNotFoundError: If both saved-loader paths are given and one of
            them does not exist.
    """
    if args.trainloader and args.testloader:
        # Explicit raises instead of ``assert`` so the validation is not
        # stripped when Python runs with ``-O``.
        if not os.path.exists(args.trainloader):
            raise FileNotFoundError('trainloader does not exist')
        if not os.path.exists(args.testloader):
            raise FileNotFoundError('testloader does not exist')
        # NOTE(security): torch.load unpickles its input, which can execute
        # arbitrary code. Only point these arguments at trusted files.
        # NOTE(review): newer PyTorch releases may default to
        # ``weights_only=True`` and refuse to unpickle a full DataLoader;
        # confirm against the installed version.
        trainloader = torch.load(args.trainloader)
        testloader = torch.load(args.testloader)
        return trainloader, testloader

    # Statistics are written on the 0-255 scale and divided by 255 because
    # they are applied after ToTensor (presumably the CIFAR-10 per-channel
    # mean/std -- confirm).
    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]],
    )

    if args.raw_data:
        # Raw data: tensors only, no normalization of the training images.
        train_steps = [transforms.ToTensor()]
    elif not args.noaug:
        # With data augmentation.
        train_steps = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]
    else:
        # No data augmentation.
        train_steps = [
            transforms.ToTensor(),
            normalize,
        ]
    transform_train = transforms.Compose(train_steps)

    # The test transform is always normalized, including when raw_data is set.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # Worker processes and pinned memory are only requested when args.ngpu
    # is truthy.
    loader_kwargs = {'num_workers': 2, 'pin_memory': True} if args.ngpu else {}

    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)

    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=args.batch_size, shuffle=True, **loader_kwargs)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=args.batch_size, shuffle=False, **loader_kwargs)
    return trainloader, testloader