Synchronized Batch Normalization example #85

Merged
@@ -21,7 +21,8 @@ def get_args(monitor_path='tmp.monitor', max_iter=234300, model_save_path='tmp.m
"""
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", "-b", type=int, default=batch_size)
parser.add_argument("--batch-size", "-b", type=int, default=batch_size,
help="Local batch size, e.g., batch size per worker.")
parser.add_argument("--learning-rate", "-l",
type=float, default=learning_rate)
parser.add_argument("--monitor-path", "-m",
@@ -31,9 +32,6 @@ def get_args(monitor_path='tmp.monitor', max_iter=234300, model_save_path='tmp.m
parser.add_argument("--val-iter", "-j", type=int, default=100)
parser.add_argument("--weight-decay", "-w",
type=float, default=weight_decay)
parser.add_argument("--sync-weight-every-itr",
type=int, default=100,
help="Sync weights every specified iteration. NCCL uses the ring all reduce, so gradients in each device are not exactly same. When it is accumulated in the weights, the weight values in each device diverge.")
parser.add_argument("--device-id", "-d", type=str, default='0',
help='Device ID the training runs on. This is only valid if you specify `-c cudnn`.')
parser.add_argument("--type-config", "-t", type=str, default='float',
@@ -52,4 +50,6 @@ def get_args(monitor_path='tmp.monitor', max_iter=234300, model_save_path='tmp.m
"'cifar100_resnet23'")
parser.add_argument("--with-all-reduce-callback", action='store_true',
help="Use all_reduce_callback API instead of all_reduce")
parser.add_argument('--sync-bn', action='store_true',
help="Use Synchronized batch normalization.")
return parser.parse_args()
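
A note on the `--sync-weight-every-itr` option removed above: as its help text explains, NCCL ring all-reduce can leave results on each device that differ in the last few bits, and those tiny differences accumulate in the weights over many updates. The option controlled how often the weights were re-averaged across devices to reset that gap (see the all_reduce block removed further down in this diff). A toy NumPy illustration of the drift (not part of this PR; the step size and perturbation are made up for illustration):

```python
import numpy as np

# Two "devices" start from the same weight and apply gradients that differ only
# by a tiny floating-point discrepancy, as ring all-reduce can produce.
w_dev0, w_dev1, lr = 1.0, 1.0, 0.01
for step in range(1000):
    grad = 0.1
    w_dev0 -= lr * (grad + 1e-8)
    w_dev1 -= lr * (grad - 1e-8)
print(abs(w_dev0 - w_dev1))  # ~2e-7: the gap grows with the number of updates
```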
@@ -37,7 +37,16 @@ def categorical_error(pred, label):
return (pred_label != label.flat).mean()


def resnet23_prediction(image, test=False, rng=None, ncls=10, nmaps=64, act=F.relu):
def batch_normalization(h, test=False, comm=None, group="world"):
if comm is None:
h = PF.batch_normalization(h, batch_stat=not test)
else:
h = PF.sync_batch_normalization(
h, comm=comm, group=group, batch_stat=not test)
return h
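
The new batch_normalization helper above picks plain per-device batch normalization when no communicator is given and PF.sync_batch_normalization otherwise, so the same network code serves both the single-device and the synchronized case. A minimal usage sketch (the shape and the calls below are illustrative assumptions, not part of the diff, and rely on the module's existing imports being in scope):

```python
import nnabla as nn

x = nn.Variable((8, 64, 16, 16))  # (local batch, channels, H, W); shape chosen for illustration
h = batch_normalization(x, test=False, comm=None)  # single device: PF.batch_normalization
# With the communicator created in train(), batch statistics are instead computed
# across every worker in the "world" group, i.e. over the global batch:
# h = batch_normalization(x, test=False, comm=comm, group="world")
```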


def resnet23_prediction(image, test=False, rng=None, ncls=10, nmaps=64, act=F.relu, comm=None, group="world"):
"""
Construct ResNet 23
"""
@@ -49,21 +58,21 @@ def res_unit(x, scope_name, rng, dn=False):
with nn.parameter_scope("conv1"):
h = PF.convolution(x, C / 2, kernel=(1, 1), pad=(0, 0),
with_bias=False)
h = PF.batch_normalization(h, batch_stat=not test)
h = batch_normalization(h, test=test, comm=comm, group=group)
h = act(h)
# Conv -> BN -> Nonlinear
with nn.parameter_scope("conv2"):
h = PF.convolution(h, C / 2, kernel=(3, 3), pad=(1, 1),
with_bias=False)
h = PF.batch_normalization(h, batch_stat=not test)
h = batch_normalization(h, test=test, comm=comm, group=group)
h = act(h)
# Conv -> BN
with nn.parameter_scope("conv3"):
h = PF.convolution(h, C, kernel=(1, 1), pad=(0, 0),
with_bias=False)
h = PF.batch_normalization(h, batch_stat=not test)
h = batch_normalization(h, test=test, comm=comm, group=group)
# Residual -> Nonlinear
h = act(F.add2(h, x, inplace=True))
h = act(F.add2(h, x, inplace=False))
# Maxpooling
if dn:
h = F.max_pooling(h, kernel=(2, 2), stride=(2, 2))
@@ -78,7 +87,7 @@ def res_unit(x, scope_name, rng, dn=False):
image.need_grad = False
h = PF.convolution(image, nmaps, kernel=(3, 3),
pad=(1, 1), with_bias=False)
h = PF.batch_normalization(h, batch_stat=not test)
h = batch_normalization(h, test=test, comm=comm, group=group)
h = act(h)

h = res_unit(h, "conv2", rng, False) # -> 32x32
@@ -70,16 +70,6 @@ def train():
n_train_samples = 50000
n_valid_samples = 10000
bs_valid = args.batch_size
rng = np.random.RandomState(313)
if args.net == "cifar10_resnet23":
prediction = functools.partial(
resnet23_prediction, rng=rng, ncls=10, nmaps=64, act=F.relu)
data_iterator = data_iterator_cifar10

if args.net == "cifar100_resnet23":
prediction = functools.partial(
resnet23_prediction, rng=rng, ncls=100, nmaps=384, act=F.elu)
data_iterator = data_iterator_cifar100

# Create Communicator and Context
extension_module = "cudnn"
@@ -93,22 +83,34 @@ def train():
ctx.device_id = str(device_id)
nn.set_default_context(ctx)

# Model
rng = np.random.RandomState(313)
comm_syncbn = comm if args.sync_bn else None
if args.net == "cifar10_resnet23":
prediction = functools.partial(
resnet23_prediction, rng=rng, ncls=10, nmaps=32, act=F.relu, comm=comm_syncbn)
data_iterator = data_iterator_cifar10
if args.net == "cifar100_resnet23":
prediction = functools.partial(
resnet23_prediction, rng=rng, ncls=100, nmaps=384, act=F.elu, comm=comm_syncbn)
data_iterator = data_iterator_cifar100

# Create training graphs
test = False
image_train = nn.Variable((args.batch_size, 3, 32, 32))
label_train = nn.Variable((args.batch_size, 1))
pred_train = prediction(image_train, test)
pred_train = prediction(image_train, test=False)
pred_train.persistent = True
loss_train = loss_function(pred_train, label_train)
error_train = F.mean(F.top_n_error(pred_train, label_train, axis=1))
loss_train = (loss_function(pred_train, label_train) /
n_devices).apply(persistent=True)
error_train = F.mean(F.top_n_error(
pred_train, label_train, axis=1)).apply(persistent=True)
loss_error_train = F.sink(loss_train, error_train)
input_image_train = {"image": image_train, "label": label_train}

# Create validation graph
test = True
image_valid = nn.Variable((bs_valid, 3, 32, 32))
label_valid = nn.Variable((args.batch_size, 1))
pred_valid = prediction(image_valid, test)
pred_valid = prediction(image_valid, test=True)
error_valid = F.mean(F.top_n_error(pred_valid, label_valid, axis=1))
input_image_valid = {"image": image_valid, "label": label_valid}
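
About the new loss scaling in the training graph above: backward_and_all_reduce sums gradients over devices, so dividing each worker's mean loss by n_devices makes that sum equal to the gradient of the mean loss over the global batch. A quick NumPy check of the identity (not part of the diff):

```python
import numpy as np

n_devices, local_bs = 4, 8
losses = np.random.rand(n_devices, local_bs)        # per-sample losses, one row per device
per_device = losses.mean(axis=1) / n_devices        # what each worker back-propagates
print(np.isclose(per_device.sum(), losses.mean()))  # True: summing over devices recovers
                                                    # the mean loss over the global batch
```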

@@ -127,14 +129,16 @@ def train():
monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
monitor_err = MonitorSeries("Training error", monitor, interval=10)
monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
monitor_verr = MonitorSeries("Test error", monitor, interval=1)
monitor_verr = MonitorSeries("Validation error", monitor, interval=1)
monitor_vtime = MonitorTimeElapsed("Validation time", monitor, interval=1)

# Data Iterator
rng = np.random.RandomState(device_id)
_, tdata = data_iterator(args.batch_size, True, rng)
vsource, vdata = data_iterator(args.batch_size, False)

# Training-loop
ve = nn.Variable()
for i in range(int(args.max_iter / n_devices)):
@@ -178,16 +182,11 @@ def train():

# Backward/AllReduce
backward_and_all_reduce(
loss_train, comm, with_all_reduce_callback=args.with_all_reduce_callback)
loss_error_train, comm, with_all_reduce_callback=args.with_all_reduce_callback)

# Solvers update
solver.update()

# Synchronize by averaging the weights over devices using allreduce
if (i + 1) % args.sync_weight_every_itr == 0:
weights = [x.data for x in nn.get_parameters().values()]
comm.all_reduce(weights, division=True, inplace=True)

# Linear Warmup
if i <= warmup_iter:
lr = base_lr + warmup_slope * i
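
For context on the call above: backward_and_all_reduce(loss_error_train, comm, ...) is a helper defined elsewhere in this example. The sketch below shows the usual nnabla pattern such a helper follows; it is an assumption about its internals, not the repository's actual code, and when with_all_reduce_callback is set the reduction can instead be overlapped with backpropagation via the communicator's all-reduce callback:

```python
import nnabla as nn

def backward_and_all_reduce_sketch(loss, comm):
    # Local backpropagation on this device.
    loss.backward(clear_buffer=True)
    # Sum every parameter's gradient across devices; combined with the
    # loss / n_devices scaling above, this gives the global-batch gradient.
    grads = [p.grad for p in nn.get_parameters().values()]
    comm.all_reduce(grads, division=False, inplace=True)
```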
@@ -198,6 +197,8 @@ def train():
monitor_err.add(i * n_devices, error_train.d.copy())
monitor_time.add(i * n_devices)

if device_id == 0:
nn.save_parameters(os.path.join(
args.model_save_path,