-
Notifications
You must be signed in to change notification settings - Fork 5.5k
/
pytorch_utils.py
31 lines (26 loc) · 973 Bytes
/
pytorch_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
import torchvision
import os
DEFAULT_IMAGE_SIZE = 224


def build_torch_dataset(root_dir, batch_size, shuffle=False, num_workers=None):
    """Build a DataLoader over an ImageFolder with random crop/flip augmentation.

    Each image is passed through RandomResizedCrop (to DEFAULT_IMAGE_SIZE),
    RandomHorizontalFlip, and ToTensor before batching.

    Args:
        root_dir: Directory passed to ``torchvision.datasets.ImageFolder``.
        batch_size: Number of samples per batch produced by the loader.
        shuffle: Forwarded to ``DataLoader(shuffle=...)``.
        num_workers: Number of DataLoader worker processes. When ``None``,
            defaults to ``os.cpu_count()``, or to 0 (load in the main
            process) if the CPU count cannot be determined.

    Returns:
        A ``torch.utils.data.DataLoader`` over the transformed ImageFolder.
    """
    if num_workers is None:
        # os.cpu_count() returns None when the count is undetermined, but
        # DataLoader needs a non-negative int; fall back to main-process loading.
        num_workers = os.cpu_count() or 0
    # Note(swang): This is a different order from tf.data.
    # torch: decode -> randCrop+resize -> randFlip
    # tf.data: decode -> randCrop -> randFlip -> resize
    transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.RandomResizedCrop(
                size=DEFAULT_IMAGE_SIZE,
                scale=(0.05, 1.0),
                # 1.33 approximates 4/3; kept as-is to preserve behavior.
                ratio=(0.75, 1.33),
            ),
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.ToTensor(),
        ]
    )
    data = torchvision.datasets.ImageFolder(root_dir, transform=transform)
    return torch.utils.data.DataLoader(
        data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
    )