Merge pull request #85 from sony/feature/20190410-synced-batch-normalization

Synchronized Batch Normalization
TakuyaNarihira committed May 9, 2019
2 parents e385646 + ac43b86 commit d0517bd
Showing 3 changed files with 43 additions and 33 deletions.
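
In short, this commit adds an opt-in synchronized batch normalization path to the distributed CIFAR-10/100 example: a new --sync-bn flag, a batch_normalization wrapper in models.py that dispatches to PF.sync_batch_normalization when a communicator is passed, and the corresponding wiring in the training script. Synchronized BN computes the batch mean and variance over the combined batch of all workers rather than each worker's local mini-batch, which matters when the per-device batch is small. A NumPy-only illustration of that difference (the two-worker split below is a stand-in for real devices, not code from this commit):

# Illustrative only (NumPy): per-device BN statistics vs. the pooled
# statistics that synchronized BN computes across all workers.
import numpy as np

rng = np.random.RandomState(0)
global_batch = rng.randn(8, 4)          # 8 samples, 4 channels
workers = np.split(global_batch, 2)     # two hypothetical devices, local batch = 4

local_means = [w.mean(axis=0) for w in workers]   # what plain BN uses on each device
global_mean = global_batch.mean(axis=0)           # what synchronized BN uses everywhere

print("per-device means:", [m.round(3) for m in local_means])
print("global mean     :", global_mean.round(3))
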
8 changes: 4 additions & 4 deletions distributed/cifar10-100/args.py
@@ -21,7 +21,8 @@ def get_args(monitor_path='tmp.monitor', max_iter=234300, model_save_path='tmp.m
"""
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", "-b", type=int, default=batch_size)
parser.add_argument("--batch-size", "-b", type=int, default=batch_size,
help="Local batch size, e.g., batch size per worker.")
parser.add_argument("--learning-rate", "-l",
type=float, default=learning_rate)
parser.add_argument("--monitor-path", "-m",
@@ -31,9 +32,6 @@ def get_args(monitor_path='tmp.monitor', max_iter=234300, model_save_path='tmp.m
parser.add_argument("--val-iter", "-j", type=int, default=100)
parser.add_argument("--weight-decay", "-w",
type=float, default=weight_decay)
parser.add_argument("--sync-weight-every-itr",
type=int, default=100,
help="Sync weights every specified iteration. NCCL uses the ring all reduce, so gradients in each device are not exactly same. When it is accumulated in the weights, the weight values in each device diverge.")
parser.add_argument("--device-id", "-d", type=str, default='0',
help='Device ID the training run on. This is only valid if you specify `-c cudnn`.')
parser.add_argument("--type-config", "-t", type=str, default='float',
@@ -52,4 +50,6 @@ def get_args(monitor_path='tmp.monitor', max_iter=234300, model_save_path='tmp.m
"'cifar100_resnet23'")
parser.add_argument("--with-all-reduce-callback", action='store_true',
help="Use all_reduce_callback API instead of all_reduce")
parser.add_argument('--sync-bn', action='store_true',
help="Use Synchronized batch normalization.")
return parser.parse_args()
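
A sketch of how the new flag is meant to be consumed; it mirrors the training-script change later in this commit, and comm, image, and resnet23_prediction are assumed to come from the example's own setup code:

from args import get_args

args = get_args()
# Pass the communicator into the model only when --sync-bn is given;
# with comm=None the model keeps using per-device batch normalization.
comm_syncbn = comm if args.sync_bn else None
pred = resnet23_prediction(image, test=False, comm=comm_syncbn)
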
21 changes: 15 additions & 6 deletions distributed/cifar10-100/models.py
@@ -37,7 +37,16 @@ def categorical_error(pred, label):
return (pred_label != label.flat).mean()


def resnet23_prediction(image, test=False, rng=None, ncls=10, nmaps=64, act=F.relu):
def batch_normalization(h, test=False, comm=None, group="world"):
if comm is None:
h = PF.batch_normalization(h, batch_stat=not test)
else:
h = PF.sync_batch_normalization(
h, comm=comm, group=group, batch_stat=not test)
return h


def resnet23_prediction(image, test=False, rng=None, ncls=10, nmaps=64, act=F.relu, comm=None, group="world"):
"""
Construct ResNet 23
"""
@@ -49,21 +58,21 @@ def res_unit(x, scope_name, rng, dn=False):
with nn.parameter_scope("conv1"):
h = PF.convolution(x, C / 2, kernel=(1, 1), pad=(0, 0),
with_bias=False)
h = PF.batch_normalization(h, batch_stat=not test)
h = batch_normalization(h, test=test, comm=comm, group=group)
h = act(h)
# Conv -> BN -> Nonlinear
with nn.parameter_scope("conv2"):
h = PF.convolution(h, C / 2, kernel=(3, 3), pad=(1, 1),
with_bias=False)
h = PF.batch_normalization(h, batch_stat=not test)
h = batch_normalization(h, test=test, comm=comm, group=group)
h = act(h)
# Conv -> BN
with nn.parameter_scope("conv3"):
h = PF.convolution(h, C, kernel=(1, 1), pad=(0, 0),
with_bias=False)
h = PF.batch_normalization(h, batch_stat=not test)
h = batch_normalization(h, test=test, comm=comm, group=group)
# Residual -> Nonlinear
h = act(F.add2(h, x, inplace=True))
h = act(F.add2(h, x, inplace=False))
# Maxpooling
if dn:
h = F.max_pooling(h, kernel=(2, 2), stride=(2, 2))
@@ -78,7 +87,7 @@ def res_unit(x, scope_name, rng, dn=False):
image.need_grad = False
h = PF.convolution(image, nmaps, kernel=(3, 3),
pad=(1, 1), with_bias=False)
h = PF.batch_normalization(h, batch_stat=not test)
h = batch_normalization(h, test=test, comm=comm, group=group)
h = act(h)

h = res_unit(h, "conv2", rng, False) # -> 32x32
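
The batch_normalization wrapper above can be smoke-tested in a single process, where comm=None falls back to plain PF.batch_normalization; the multi-device branch additionally needs an initialized communicator. A minimal sketch, assuming it is run from distributed/cifar10-100 with nnabla installed:

import numpy as np
import nnabla as nn
import nnabla.parametric_functions as PF

from models import batch_normalization  # the wrapper added in this commit

x = nn.Variable((4, 8, 16, 16))
with nn.parameter_scope("demo"):
    h = PF.convolution(x, 8, kernel=(3, 3), pad=(1, 1), with_bias=False)
    h = batch_normalization(h, test=False, comm=None)  # comm=None -> PF.batch_normalization
x.d = np.random.randn(*x.shape)
h.forward()
print(h.d.mean(), h.d.std())  # roughly 0 and 1 after normalization
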
@@ -70,16 +70,6 @@ def train():
n_train_samples = 50000
n_valid_samples = 10000
bs_valid = args.batch_size
rng = np.random.RandomState(313)
if args.net == "cifar10_resnet23":
prediction = functools.partial(
resnet23_prediction, rng=rng, ncls=10, nmaps=64, act=F.relu)
data_iterator = data_iterator_cifar10

if args.net == "cifar100_resnet23":
prediction = functools.partial(
resnet23_prediction, rng=rng, ncls=100, nmaps=384, act=F.elu)
data_iterator = data_iterator_cifar100

# Create Communicator and Context
extension_module = "cudnn"
@@ -93,22 +83,34 @@ def train():
ctx.device_id = str(device_id)
nn.set_default_context(ctx)

# Model
rng = np.random.RandomState(313)
comm_syncbn = comm if args.sync_bn else None
if args.net == "cifar10_resnet23":
prediction = functools.partial(
resnet23_prediction, rng=rng, ncls=10, nmaps=32, act=F.relu, comm=comm_syncbn)
data_iterator = data_iterator_cifar10
if args.net == "cifar100_resnet23":
prediction = functools.partial(
resnet23_prediction, rng=rng, ncls=100, nmaps=384, act=F.elu, comm=comm_syncbn)
data_iterator = data_iterator_cifar100

# Create training graphs
test = False
image_train = nn.Variable((args.batch_size, 3, 32, 32))
label_train = nn.Variable((args.batch_size, 1))
pred_train = prediction(image_train, test)
pred_train = prediction(image_train, test=False)
pred_train.persistent = True
loss_train = loss_function(pred_train, label_train)
error_train = F.mean(F.top_n_error(pred_train, label_train, axis=1))
loss_train = (loss_function(pred_train, label_train) /
n_devices).apply(persistent=True)
error_train = F.mean(F.top_n_error(
pred_train, label_train, axis=1)).apply(persistent=True)
loss_error_train = F.sink(loss_train, error_train)
input_image_train = {"image": image_train, "label": label_train}

# Create validation graph
test = True
image_valid = nn.Variable((bs_valid, 3, 32, 32))
label_valid = nn.Variable((args.batch_size, 1))
pred_valid = prediction(image_valid, test)
pred_valid = prediction(image_valid, test=True)
error_valid = F.mean(F.top_n_error(pred_valid, label_valid, axis=1))
input_image_valid = {"image": image_valid, "label": label_valid}

@@ -127,14 +129,16 @@ def train():
monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
monitor_err = MonitorSeries("Training error", monitor, interval=10)
monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
monitor_verr = MonitorSeries("Test error", monitor, interval=1)
monitor_verr = MonitorSeries("Validation error", monitor, interval=1)
monitor_vtime = MonitorTimeElapsed("Validation time", monitor, interval=1)

# Data Iterator
rng = np.random.RandomState(device_id)
_, tdata = data_iterator(args.batch_size, True, rng)
vsource, vdata = data_iterator(args.batch_size, False)

# loss_error_train.forward()

# Training-loop
ve = nn.Variable()
for i in range(int(args.max_iter / n_devices)):
@@ -178,16 +182,11 @@

# Backward/AllReduce
backward_and_all_reduce(
loss_train, comm, with_all_reduce_callback=args.with_all_reduce_callback)
loss_error_train, comm, with_all_reduce_callback=args.with_all_reduce_callback)

# Solvers update
solver.update()

# Synchronize by averaging the weights over devices using allreduce
if (i + 1) % args.sync_weight_every_itr == 0:
weights = [x.data for x in nn.get_parameters().values()]
comm.all_reduce(weights, division=True, inplace=True)

# Linear Warmup
if i <= warmup_iter:
lr = base_lr + warmup_slope * i
@@ -198,6 +197,8 @@ def train():
monitor_err.add(i * n_devices, error_train.d.copy())
monitor_time.add(i * n_devices)

# exit(0)

if device_id == 0:
nn.save_parameters(os.path.join(
args.model_save_path,
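
One way to read the new 1/n_devices scaling of the training loss: if all_reduce sums the per-device gradients, pre-dividing each local loss by the device count makes the summed gradient equal the gradient of the mean loss over all devices. A NumPy check of that identity (the gradient vectors below are stand-ins, not values from this example):

import numpy as np

n_devices = 4
rng = np.random.RandomState(0)
local_grads = [rng.randn(3) for _ in range(n_devices)]  # stand-ins for per-device dL_k/dw

summed_after_scaling = sum(g / n_devices for g in local_grads)  # loss pre-divided, then all_reduce sum
mean_gradient = sum(local_grads) / n_devices                    # gradient of the averaged loss

print(np.allclose(summed_after_scaling, mean_gradient))  # True
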
