# Original Code here:
# https://github.com/pytorch/examples/blob/master/mnist/main.py
import argparse
import os
import tempfile

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from filelock import FileLock
from torchvision import datasets, transforms

import ray
from ray import train, tune
from ray.train import Checkpoint
from ray.tune.schedulers import AsyncHyperBandScheduler

# Change these values if you want the training to run quicker or slower.
EPOCH_SIZE = 512
TEST_SIZE = 256


class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        self.fc = nn.Linear(192, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 3))
        x = x.view(-1, 192)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)
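

# Note (added for reference): the 192 above follows from the tensor shapes --
# a 28x28 MNIST image passes the 3x3 conv (-> 26x26), then a 3x3 max pool
# (-> 8x8), giving 3 * 8 * 8 = 192 features. A quick sanity check:
#
#     >>> ConvNet()(torch.zeros(1, 1, 28, 28)).shape
#     torch.Size([1, 10])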


def train_func(model, optimizer, train_loader, device=None):
    device = device or torch.device("cpu")
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # Cap each "epoch" at EPOCH_SIZE samples so one training iteration stays short.
        if batch_idx * len(data) > EPOCH_SIZE:
            return
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()


def test_func(model, data_loader, device=None):
    device = device or torch.device("cpu")
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(data_loader):
            # Evaluate on at most TEST_SIZE samples to keep evaluation fast.
            if batch_idx * len(data) > TEST_SIZE:
                break
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return correct / total


def get_data_loaders(batch_size=64):
    mnist_transforms = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )

    # We add FileLock here because multiple workers will want to
    # download data, and this may cause overwrites since
    # DataLoader is not threadsafe.
    with FileLock(os.path.expanduser("~/data.lock")):
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                "~/data", train=True, download=True, transform=mnist_transforms
            ),
            batch_size=batch_size,
            shuffle=True,
        )
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                "~/data", train=False, download=True, transform=mnist_transforms
            ),
            batch_size=batch_size,
            shuffle=True,
        )
    return train_loader, test_loader


def train_mnist(config):
    should_checkpoint = config.get("should_checkpoint", False)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    train_loader, test_loader = get_data_loaders()
    model = ConvNet().to(device)

    optimizer = optim.SGD(
        model.parameters(), lr=config["lr"], momentum=config["momentum"]
    )

    # Each loop pass trains for one capped epoch, evaluates, and reports to Tune;
    # every train.report() call counts as one training_iteration.
    while True:
        train_func(model, optimizer, train_loader, device)
        acc = test_func(model, test_loader, device)
        metrics = {"mean_accuracy": acc}

        # Report metrics (and possibly a checkpoint)
        if should_checkpoint:
            with tempfile.TemporaryDirectory() as tempdir:
                torch.save(model.state_dict(), os.path.join(tempdir, "model.pt"))
                train.report(metrics, checkpoint=Checkpoint.from_directory(tempdir))
        else:
            train.report(metrics)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
    parser.add_argument(
        "--cuda", action="store_true", default=False, help="Enables GPU training"
    )
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing"
    )
    args, _ = parser.parse_known_args()

    ray.init(num_cpus=2 if args.smoke_test else None)

    # for early stopping
    sched = AsyncHyperBandScheduler()

    resources_per_trial = {"cpu": 2, "gpu": int(args.cuda)}  # set this for GPUs
    tuner = tune.Tuner(
        tune.with_resources(train_mnist, resources=resources_per_trial),
        tune_config=tune.TuneConfig(
            metric="mean_accuracy",
            mode="max",
            scheduler=sched,
            num_samples=1 if args.smoke_test else 50,
        ),
        run_config=train.RunConfig(
            name="exp",
            stop={
                "mean_accuracy": 0.98,
                "training_iteration": 5 if args.smoke_test else 100,
            },
        ),
        param_space={
            "lr": tune.loguniform(1e-4, 1e-2),
            "momentum": tune.uniform(0.1, 0.9),
        },
    )
    results = tuner.fit()

    print("Best config is:", results.get_best_result().config)

    assert not results.errors
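

# Example invocations (assuming this file is saved as mnist_pytorch.py):
#
#     python mnist_pytorch.py --smoke-test   # quick run: 1 trial, a few iterations
#     python mnist_pytorch.py --cuda         # request 1 GPU per trial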