train.py
# -*- coding: utf-8 -*-
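"""Training script for ShakePyramidNet (PyramidNet with Shake-Drop regularization).

Loads a dataset via datasets.load_dataset (the --label flag appears to select the
number of classes, e.g. 10 or 100), trains with SGD plus Nesterov momentum under a
multi-step learning-rate schedule, logs per-epoch train/test loss and accuracy via
utils.Logger, and periodically writes model/optimizer snapshots to --checkpoint.
"""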
import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from tqdm import tqdm
import utils
from datasets import load_dataset
from models import ShakePyramidNet
def main(args):
    train_loader, test_loader = load_dataset(args.label, args.batch_size)
    model = ShakePyramidNet(depth=args.depth, alpha=args.alpha, label=args.label)
    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True

    opt = optim.SGD(model.parameters(),
                    lr=args.lr,
                    momentum=0.9,
                    weight_decay=args.weight_decay,
                    nesterov=args.nesterov)
    scheduler = optim.lr_scheduler.MultiStepLR(opt, [args.epochs // 2, args.epochs * 3 // 4])
    loss_func = nn.CrossEntropyLoss().cuda()

    headers = ["Epoch", "LearningRate", "TrainLoss", "TestLoss", "TrainAcc.", "TestAcc."]
    logger = utils.Logger(args.checkpoint, headers)
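    # Note: scheduler.step() is invoked at the start of each epoch, which follows the
    # pre-1.1 PyTorch convention; on PyTorch >= 1.1 the call is expected after opt.step().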
    for e in range(args.epochs):
        scheduler.step()
        model.train()
        train_loss, train_acc, train_n = 0, 0, 0
        bar = tqdm(total=len(train_loader), leave=False)
        for x, t in train_loader:
            x, t = Variable(x.cuda()), Variable(t.cuda())
            y = model(x)
            loss = loss_func(y, t)
            opt.zero_grad()
            loss.backward()
            opt.step()

            train_acc += utils.accuracy(y, t).item()
            train_loss += loss.item() * t.size(0)
            train_n += t.size(0)
            bar.set_description("Loss: {:.4f}, Accuracy: {:.2f}".format(
                train_loss / train_n, train_acc / train_n * 100), refresh=True)
            bar.update()
        bar.close()

        model.eval()
        test_loss, test_acc, test_n = 0, 0, 0
        for x, t in tqdm(test_loader, total=len(test_loader), leave=False):
            with torch.no_grad():
                x, t = Variable(x.cuda()), Variable(t.cuda())
                y = model(x)
                loss = loss_func(y, t)
                test_loss += loss.item() * t.size(0)
                test_acc += utils.accuracy(y, t).item()
                test_n += t.size(0)

        if (e + 1) % args.snapshot_interval == 0:
            torch.save({
                "state_dict": model.state_dict(),
                "optimizer": opt.state_dict()
            }, os.path.join(args.checkpoint, "{}.tar".format(e + 1)))
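            # Sketch (not part of this script): a saved snapshot could later be restored with
            #   ckpt = torch.load(path); model.load_state_dict(ckpt["state_dict"]);
            #   opt.load_state_dict(ckpt["optimizer"])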
        lr = opt.param_groups[0]["lr"]
        logger.write(e + 1, lr, train_loss / train_n, test_loss / test_n,
                     train_acc / train_n * 100, test_acc / test_n * 100)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--label", type=int, default=10)
    parser.add_argument("--checkpoint", type=str, default="./checkpoint")
    parser.add_argument("--snapshot_interval", type=int, default=10)
    # For Networks
    parser.add_argument("--depth", type=int, default=110)
    parser.add_argument("--alpha", type=int, default=270)
    # For Training
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--weight_decay", type=float, default=0.0001)
    # argparse's type=bool treats any non-empty string (including "False") as True,
    # so parse the flag explicitly.
    parser.add_argument("--nesterov", type=lambda s: str(s).lower() in ("true", "1", "yes"),
                        default=True)
    parser.add_argument("--epochs", type=int, default=1800)
    parser.add_argument("--batch_size", type=int, default=128)
    args = parser.parse_args()
    main(args)
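
# Example invocation (sketch; dataset handling lives in datasets.load_dataset):
#   python train.py --label 10 --depth 110 --alpha 270 --epochs 1800 --batch_size 128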