Add SWA to PyTorch mainline #35032

Closed · wants to merge 27 commits
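For orientation, here is a minimal usage sketch of the torch.optim.swa_utils API that the tests in this diff exercise (AveragedModel, SWALR with anneal_epochs/swa_lr, and update_bn). The model, data, optimizer, and schedule lengths below are placeholders for illustration, not taken from the PR.

import torch
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn

# Placeholder model, data, and schedule lengths; only the swa_utils calls
# mirror the API exercised by the tests below.
model = torch.nn.Linear(5, 2)
dataset = torch.utils.data.TensorDataset(torch.randn(100, 5))
loader = torch.utils.data.DataLoader(dataset, batch_size=10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

swa_model = AveragedModel(model)                       # running average of model's parameters
swa_scheduler = SWALR(optimizer, anneal_epochs=5, swa_lr=0.01)

swa_start, epochs = 5, 15
for epoch in range(epochs):
    for (x,) in loader:
        optimizer.zero_grad()
        model(x).sum().backward()
        optimizer.step()
    if epoch >= swa_start:
        swa_model.update_parameters(model)             # fold current weights into the average
        swa_scheduler.step()                           # anneal the learning rate toward swa_lr

# Recompute BatchNorm running statistics for the averaged model
update_bn(loader, swa_model)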
323 changes: 323 additions & 0 deletions test/test_optim.py
@@ -13,6 +13,7 @@
from torch.optim.lr_scheduler import LambdaLR, MultiplicativeLR, StepLR, \
MultiStepLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, \
_LRScheduler, CyclicLR, CosineAnnealingWarmRestarts, OneCycleLR
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests, \
skipIfRocm

@@ -1344,6 +1345,80 @@ def test_CosineAnnealingWarmRestarts_lr3(self):
targets[1] += [eta_min + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2]
self._test_interleaved_CosineAnnealingWarmRestarts(scheduler, targets, epochs)

def test_swalr_no_anneal(self):
epochs, swa_start, swa_lr = 10, 5, 0.01
initial_lrs = [group['lr'] for group in self.opt.param_groups]
targets = [[lr] * (swa_start + 1) + [swa_lr] * (epochs - swa_start - 1)
for lr in initial_lrs]
swa_scheduler = SWALR(self.opt, anneal_epochs=1, swa_lr=swa_lr)
self._test_swalr(swa_scheduler, None, targets, swa_start, epochs)

def test_swalr_cosine_anneal_after_multiplicative(self):
# same swa_lr for different param_groups
epochs, swa_start, swa_lr, anneal_epochs = 15, 5, 0.01, 5
mult_factor = 0.9
scheduler = MultiplicativeLR(self.opt, lr_lambda=lambda epoch: mult_factor)
swa_scheduler = SWALR(self.opt, anneal_epochs=anneal_epochs, swa_lr=swa_lr)

def anneal_coef(t):
if t + 1 >= anneal_epochs:
return 0.
return (1 + math.cos(math.pi * (t + 1) / anneal_epochs)) / 2

initial_lrs = [group['lr'] for group in self.opt.param_groups]
targets_before_swa = [[lr * mult_factor**i for i in range(swa_start + 1)]
for lr in initial_lrs]
swa_epochs = epochs - swa_start - 1
targets = [lrs + [lrs[-1] * anneal_coef(t) + swa_lr * (1 - anneal_coef(t)) for t in range(swa_epochs)]
for lrs in targets_before_swa]

self._test_swalr(swa_scheduler, scheduler, targets, swa_start, epochs)

def test_swalr_linear_anneal_after_multiplicative(self):
# separate swa_lr for different param_groups
epochs, swa_start, swa_lrs, anneal_epochs = 15, 5, [0.01, 0.02], 4
mult_factor = 0.9
scheduler = MultiplicativeLR(self.opt, lr_lambda=lambda epoch: mult_factor)
swa_scheduler = SWALR(self.opt, anneal_epochs=anneal_epochs,
anneal_strategy="linear", swa_lr=swa_lrs)

def anneal_coef(t):
if t + 1 >= anneal_epochs:
return 0.
return 1 - (t + 1) / anneal_epochs

initial_lrs = [group['lr'] for group in self.opt.param_groups]
targets_before_swa = [[lr * mult_factor**i for i in range(swa_start + 1)]
for lr in initial_lrs]
swa_epochs = epochs - swa_start - 1
targets = [lrs + [lrs[-1] * anneal_coef(t) + swa_lr * (1 - anneal_coef(t)) for t in range(swa_epochs)]
for lrs, swa_lr in zip(targets_before_swa, swa_lrs)]

self._test_swalr(swa_scheduler, scheduler, targets, swa_start, epochs)

def _test_swalr(self, swa_scheduler, scheduler, targets, swa_start, epochs):
for epoch in range(epochs):
for param_group, target in zip(self.opt.param_groups, targets):
self.assertAlmostEqual(target[epoch], param_group['lr'],
msg='LR is wrong in epoch {}: expected {}, got {}'.format(
epoch, target[epoch], param_group['lr']), delta=1e-5)
if epoch >= swa_start:
swa_scheduler.step()
elif scheduler is not None:
scheduler.step()

def test_swalr_hypers(self):
# Test that SWALR raises errors for incorrect hyper-parameters
with self.assertRaisesRegex(ValueError, "anneal_strategy must"):
swa_scheduler = SWALR(self.opt, anneal_strategy="exponential", swa_lr=1.)

with self.assertRaisesRegex(ValueError, "anneal_epochs must"):
swa_scheduler = SWALR(self.opt, anneal_epochs=-1, swa_lr=1.)
with self.assertRaisesRegex(ValueError, "anneal_epochs must"):
swa_scheduler = SWALR(self.opt, anneal_epochs=1.7, swa_lr=1.)
with self.assertRaisesRegex(ValueError, "swa_lr must"):
swa_scheduler = SWALR(self.opt, swa_lr=[1., 0.1, 0.01])

def test_step_lr_state_dict(self):
self._check_scheduler_state_dict(
lambda: StepLR(self.opt, gamma=0.1, step_size=3),
@@ -1404,6 +1479,11 @@ def test_CosineAnnealingWarmRestarts_lr_state_dict(self):
lambda: CosineAnnealingWarmRestarts(self.opt, T_0=10, T_mult=2),
lambda: CosineAnnealingWarmRestarts(self.opt, T_0=100))

def test_swa_lr_state_dict(self):
self._check_scheduler_state_dict(
lambda: SWALR(self.opt, anneal_epochs=3, swa_lr=0.5),
lambda: SWALR(self.opt, anneal_epochs=10, anneal_strategy="linear", swa_lr=5.))

def _check_scheduler_state_dict(self, constr, constr2, epochs=10):
scheduler = constr()
for _ in range(epochs):
@@ -1547,5 +1627,248 @@ def test_cosine_then_cyclic(self):

self.assertLessEqual(last_lr, max_lr)


class SWATestDNN(torch.nn.Module):
def __init__(self, input_features):
super(SWATestDNN, self).__init__()
self.n_features = 100
self.fc1 = torch.nn.Linear(input_features, self.n_features)
self.bn = torch.nn.BatchNorm1d(self.n_features)

def compute_preactivation(self, x):
return self.fc1(x)

def forward(self, x):
x = self.fc1(x)
x = self.bn(x)
return x


class SWATestCNN(torch.nn.Module):
def __init__(self, input_channels):
super(SWATestCNN, self).__init__()
self.n_features = 10
self.conv1 = torch.nn.Conv2d(input_channels, self.n_features, kernel_size=3, padding=1)
self.bn = torch.nn.BatchNorm2d(self.n_features, momentum=0.3)

def compute_preactivation(self, x):
return self.conv1(x)

def forward(self, x):
x = self.conv1(x)
x = self.bn(x)
return x


class TestSWAUtils(TestCase):

def _test_averaged_model(self, net_device, swa_device):
dnn = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3),
torch.nn.ReLU(),
torch.nn.MaxPool2d(kernel_size=2),
torch.nn.BatchNorm2d(5, momentum=0.3),
torch.nn.Conv2d(5, 2, kernel_size=3),
torch.nn.ReLU(),
torch.nn.Linear(5, 5),
torch.nn.ReLU(),
torch.nn.Linear(5, 10)
).to(net_device)

averaged_dnn = AveragedModel(dnn, device=swa_device)
averaged_params = [torch.zeros_like(param) for param in dnn.parameters()]
n_updates = 10
for i in range(n_updates):
for p, p_avg in zip(dnn.parameters(), averaged_params):
p.detach().add_(torch.randn_like(p))
p_avg += p.detach() / n_updates
averaged_dnn.update_parameters(dnn)

for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
self.assertAlmostEqual(p_avg, p_swa)
# Check that AveragedModel is on the correct device
self.assertTrue(p_swa.device == swa_device)
self.assertTrue(p.device == net_device)
self.assertTrue(averaged_dnn.n_averaged.device == swa_device)

def test_averaged_model_all_devices(self):
cpu = torch.device("cpu")
self._test_averaged_model(cpu, cpu)
if torch.cuda.is_available():
cuda = torch.device(0)
self._test_averaged_model(cuda, cpu)
self._test_averaged_model(cpu, cuda)
self._test_averaged_model(cuda, cuda)

def test_averaged_model_mixed_device(self):
if not torch.cuda.is_available():
return
dnn = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3),
torch.nn.Linear(5, 10)
)
dnn[0].cuda()
dnn[1].cpu()
averaged_dnn = AveragedModel(dnn)
averaged_params = [torch.zeros_like(param) for param in dnn.parameters()]
n_updates = 10
for i in range(n_updates):
for p, p_avg in zip(dnn.parameters(), averaged_params):
p.detach().add_(torch.randn_like(p))
p_avg += p.detach() / n_updates
averaged_dnn.update_parameters(dnn)

for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
self.assertAlmostEqual(p_avg, p_swa)
# Check that AveragedModel is on the correct device
self.assertTrue(p_avg.device == p_swa.device)

def test_averaged_model_state_dict(self):
dnn = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3),
torch.nn.Linear(5, 10)
)
averaged_dnn = AveragedModel(dnn)
averaged_dnn2 = AveragedModel(dnn)
n_updates = 10
for i in range(n_updates):
for p in dnn.parameters():
p.detach().add_(torch.randn_like(p))
averaged_dnn.update_parameters(dnn)
averaged_dnn2.load_state_dict(averaged_dnn.state_dict())
for p_swa, p_swa2 in zip(averaged_dnn.parameters(), averaged_dnn2.parameters()):
self.assertAlmostEqual(p_swa, p_swa2)
self.assertTrue(averaged_dnn.n_averaged == averaged_dnn2.n_averaged)

def test_averaged_model_exponential(self):
# Test AveragedModel with EMA as avg_fn
dnn = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3),
torch.nn.Linear(5, 10)
)
alpha = 0.9

def avg_fn(p_avg, p, n_avg):
return alpha * p_avg + (1 - alpha) * p
averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn)
averaged_params = [torch.zeros_like(param) for param in dnn.parameters()]
n_updates = 10
for i in range(n_updates):
updated_averaged_params = []
for p, p_avg in zip(dnn.parameters(), averaged_params):
p.detach().add_(torch.randn_like(p))
if i == 0:
updated_averaged_params.append(p.clone())
else:
updated_averaged_params.append((p_avg * alpha +
p * (1 - alpha)).clone())
averaged_dnn.update_parameters(dnn)
averaged_params = updated_averaged_params

for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
self.assertAlmostEqual(p_avg, p_swa)
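Outside this test, the same avg_fn hook could be used to keep an exponential moving average of a model during training. A minimal sketch, assuming the (averaged_param, current_param, num_averaged) signature shown in the test above; the names ema_avg_fn and ema_model are placeholders.

import torch
from torch.optim.swa_utils import AveragedModel

model = torch.nn.Linear(5, 10)   # placeholder model
alpha = 0.9

def ema_avg_fn(p_avg, p, n_averaged):
    # Exponential moving average instead of the default equal-weight mean.
    return alpha * p_avg + (1 - alpha) * p

ema_model = AveragedModel(model, avg_fn=ema_avg_fn)
# Called after each optimizer step (or each epoch) during training:
ema_model.update_parameters(model)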

def _test_update_bn(self, dnn, dl_x, dl_xy, cuda):

preactivation_sum = torch.zeros(dnn.n_features)
preactivation_squared_sum = torch.zeros(dnn.n_features)
if cuda:
preactivation_sum = preactivation_sum.cuda()
preactivation_squared_sum = preactivation_squared_sum.cuda()
total_num = 0
for x in dl_x:
x = x[0]
if cuda:
x = x.cuda()

dnn.forward(x)
preactivations = dnn.compute_preactivation(x)
if len(preactivations.shape) == 4:
preactivations = preactivations.transpose(1, 3)
preactivations = preactivations.contiguous().view(-1, dnn.n_features)
total_num += preactivations.shape[0]

preactivation_sum += torch.sum(preactivations, dim=0)
preactivation_squared_sum += torch.sum(preactivations**2, dim=0)

preactivation_mean = preactivation_sum / total_num
preactivation_var = preactivation_squared_sum / total_num
preactivation_var = preactivation_var - preactivation_mean**2

update_bn(dl_xy, dnn, device=x.device)
self.assertEqual(preactivation_mean, dnn.bn.running_mean)
self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1)

def _reset_bn(module):
if issubclass(module.__class__,
torch.nn.modules.batchnorm._BatchNorm):
module.running_mean = torch.zeros_like(module.running_mean)
module.running_var = torch.ones_like(module.running_var)

Review thread on the two lines above:

Contributor: Could be something like

    nn.init.zeros_(bn.running_mean)
    nn.init.ones_(bn.running_var)

which would save a minor allocation and would be more consistent with the general nn initialization code.

Contributor: Thanks for catching this. Do you want to open a pull request for this?

Contributor: I think it should be decided first whether we actually want to do this in place or not (what the "in-placeness" contract should be).

Contributor: That's fair. What is your concern in particular here?
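A minimal sketch of the reviewer's suggested in-place variant, using torch.nn.init.zeros_/ones_ on the BatchNorm buffers; the helper name _reset_bn_inplace is hypothetical and not part of the PR.

import torch

def _reset_bn_inplace(module):
    # Reviewer's suggested variant: reset the running statistics in place
    # rather than allocating new tensors with zeros_like / ones_like.
    if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
        torch.nn.init.zeros_(module.running_mean)
        torch.nn.init.ones_(module.running_var)

# Usage would mirror the test helper: dnn.apply(_reset_bn_inplace)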
# reset batch norm and run update_bn again
dnn.apply(_reset_bn)
update_bn(dl_xy, dnn, device=x.device)
self.assertEqual(preactivation_mean, dnn.bn.running_mean)
self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1)
# using the dl_x loader instead of dl_xy
dnn.apply(_reset_bn)
update_bn(dl_x, dnn, device=x.device)
self.assertEqual(preactivation_mean, dnn.bn.running_mean)
self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1)

def test_update_bn_dnn(self):
# Test update_bn for a fully-connected network with BatchNorm1d
objects, input_features = 100, 5
x = torch.rand(objects, input_features)
y = torch.rand(objects)
ds_x = torch.utils.data.TensorDataset(x)
ds_xy = torch.utils.data.TensorDataset(x, y)
dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True)
dl_xy = torch.utils.data.DataLoader(ds_xy, batch_size=5, shuffle=True)
dnn = SWATestDNN(input_features=input_features)
dnn.train()
self._test_update_bn(dnn, dl_x, dl_xy, False)
if torch.cuda.is_available():
dnn = SWATestDNN(input_features=input_features)
dnn.train()
self._test_update_bn(dnn.cuda(), dl_x, dl_xy, True)
self.assertTrue(dnn.training)

def test_update_bn_cnn(self):
# Test update_bn for a convolutional network with BatchNorm2d
objects = 100
input_channels = 3
height, width = 5, 5
x = torch.rand(objects, input_channels, height, width)
y = torch.rand(objects)
ds_x = torch.utils.data.TensorDataset(x)
ds_xy = torch.utils.data.TensorDataset(x, y)
dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True)
dl_xy = torch.utils.data.DataLoader(ds_xy, batch_size=5, shuffle=True)
dnn = SWATestCNN(input_channels=input_channels)
dnn.train()
self._test_update_bn(dnn, dl_x, dl_xy, False)
if torch.cuda.is_available():
dnn = SWATestCNN(input_channels=input_channels)
dnn.train()
self._test_update_bn(dnn.cuda(), dl_x, dl_xy, True)
self.assertTrue(dnn.training)

def test_bn_update_eval_momentum(self):
# check that update_bn preserves eval mode
objects = 100
input_channels = 3
height, width = 5, 5
x = torch.rand(objects, input_channels, height, width)
ds_x = torch.utils.data.TensorDataset(x)
dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True)
dnn = SWATestCNN(input_channels=input_channels)
dnn.eval()
update_bn(dl_x, dnn)
self.assertFalse(dnn.training)

# check that momentum is preserved
self.assertEqual(dnn.bn.momentum, 0.3)


if __name__ == '__main__':
run_tests()