# -*- coding: utf-8 -*-
"""
@date: 2020/9/14 8:36 PM
@file: main.py
@author: zj
@description: Mixed-precision (AMP) training with DistributedDataParallel across multiple GPUs, launched via torch.multiprocessing.spawn.
"""
import os
import argparse
from datetime import datetime

import torch
import torch.nn as nn
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.cuda.amp import GradScaler, autocast

from mixed_multi_gpu.model import ConvNet
from mixed_multi_gpu.data import build_dataloader


def train(gpu, args):
    # Global rank of this process: node index * GPUs per node + local GPU index.
    rank = args.nr * args.gpus + gpu
    dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=rank)
    torch.manual_seed(0)
    torch.cuda.set_device(gpu)
    device = torch.device(f'cuda:{gpu}' if torch.cuda.is_available() else 'cpu')

    model = ConvNet().to(device)
    # Wrap the model so gradients are synchronized across processes.
    model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(model.parameters(), 1e-4)
    # Data loading code
    data_loader = build_dataloader(world_size=args.world_size, rank=rank)

    # Creates a GradScaler once at the beginning of training.
    scaler = GradScaler()

    start = datetime.now()
    total_step = len(data_loader)
    for epoch in range(args.epochs):
        for i, (images, labels) in enumerate(data_loader):
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            # Runs the forward pass with autocasting.
            with autocast():
                outputs = model(images)
                loss = criterion(outputs, labels)
            # Scales the loss and calls backward() to create scaled gradients.
            scaler.scale(loss).backward()
            # Unscales gradients, then calls optimizer.step() (skipped if
            # the gradients contain infs or NaNs).
            scaler.step(optimizer)
            # Updates the scale factor for the next iteration.
            scaler.update()

            if (i + 1) % 100 == 0 and gpu == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                    epoch + 1, args.epochs, i + 1, total_step, loss.item()))
    if gpu == 0:
        print("Training complete in: " + str(datetime.now() - start))


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-g', '--gpus', default=1, type=int,
                        help='number of gpus per node (default: 1)')
    parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N',
                        help='number of machines (default: 1)')
    parser.add_argument('-nr', '--nr', default=0, type=int,
                        help='ranking within the nodes (default: 0)')
    parser.add_argument('-e', '--epochs', default=2, type=int, metavar='N',
                        help='number of total epochs to run (default: 2)')
    args = parser.parse_args()

    # train(0, args)
    args.world_size = args.gpus * args.nodes
    # Rendezvous address for init_method='env://'; every spawned process
    # connects to the master at this address/port.
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '18888'
    # Spawn one training process per GPU on this node.
    mp.spawn(train, nprocs=args.gpus, args=(args,))


if __name__ == '__main__':
    main()
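# Example launch (a sketch, assuming a single machine with two visible GPUs;
# the flag names come from the argparse definitions above):
#
#   python main.py --gpus 2 --nodes 1 --nr 0 --epochs 2
#
# Note that MASTER_ADDR/MASTER_PORT are hardcoded to 127.0.0.1:18888, so a
# multi-node run would additionally require pointing MASTER_ADDR at node 0
# and launching with --nodes 2 and --nr 0 / --nr 1 on the respective machines.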