Fix broken EMA in fairseq
Summary: EMA has been broken since D33649708 (995c204) due to an indentation error.

Reviewed By: cruvadom

Differential Revision: D33809223

fbshipit-source-id: c6c4d0d327443bfea787817040e1832eef0f50e4
Vimal Manohar authored and facebook-github-bot committed Jan 27, 2022
1 parent 4a7835b commit 1b61bba
Showing 2 changed files with 65 additions and 6 deletions.
10 changes: 5 additions & 5 deletions fairseq/models/ema/ema.py
@@ -185,11 +185,11 @@ def step(self, new_model, updates=None):
             self._set_decay(
                 0 if updates < self.config.ema_start_update else self.config.ema_decay
             )
-            if updates is not None and self.config.ema_update_freq > 1:
-                self.update_freq_counter += 1
-                if self.update_freq_counter >= self.config.ema_update_freq:
-                    self._step_internal(new_model, updates)
-                    self.update_freq_counter = 0
+        if self.config.ema_update_freq > 1:
+            self.update_freq_counter += 1
+            if self.update_freq_counter >= self.config.ema_update_freq:
+                self._step_internal(new_model, updates)
+                self.update_freq_counter = 0
         else:
             self._step_internal(new_model, updates)

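For context, the sketch below mirrors the control flow that the fix restores in EMA.step(): with the default ema_update_freq of 1, the else branch must run on every call, while the pre-fix indentation tied the whole block to the decay-setting branch, so _step_internal could be skipped entirely. This is a simplified, stand-alone illustration; the _EMASketch class and its counters are invented here and are not the fairseq implementation.

class _EMASketch:
    """Toy stand-in for fairseq's EMA; only the update-frequency gating is modeled."""

    def __init__(self, ema_update_freq=1):
        self.ema_update_freq = ema_update_freq  # mirrors EMAConfig.ema_update_freq
        self.update_freq_counter = 0
        self.internal_steps = 0  # counts how often _step_internal ran

    def _step_internal(self, new_model, updates):
        # The real method applies the EMA weight update; here we just count calls.
        self.internal_steps += 1

    def step(self, new_model, updates=None):
        if self.ema_update_freq > 1:
            # Apply the EMA update only every ema_update_freq calls.
            self.update_freq_counter += 1
            if self.update_freq_counter >= self.ema_update_freq:
                self._step_internal(new_model, updates)
                self.update_freq_counter = 0
        else:
            # Default path (ema_update_freq == 1): update on every call.
            # Under the broken indentation this call could be skipped whenever
            # updates was passed, which is what left the EMA weights frozen.
            self._step_internal(new_model, updates)


ema = _EMASketch(ema_update_freq=1)
for step_num in range(3):
    ema.step(new_model=None, updates=step_num)
assert ema.internal_steps == 3
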
61 changes: 60 additions & 1 deletion tests/test_ema.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import unittest
+from unittest.mock import patch
 from copy import deepcopy
 from dataclasses import dataclass
 from typing import Optional
@@ -36,9 +37,10 @@ class EMAConfig(object):
     ema_start_update: int = 0
     ema_fp32: bool = False
     ema_seed_model: Optional[str] = None
+    ema_update_freq: int = 1
 
 
-class TestEMAGPU(unittest.TestCase):
+class TestEMA(unittest.TestCase):
     def assertTorchAllClose(self, x, y, atol=1e-8, rtol=1e-5, msg=None):
         diff = x.float() - y.float()
         diff_norm = torch.norm(diff)
@@ -104,6 +106,63 @@ def test_ema(self):
             ema_param = ema_state_dict[key]
             self.assertTrue(torch.allclose(ema_param, param))
 
+        # Check that step_internal is called once
+        with patch.object(
+            ema, "_step_internal", return_value=None
+        ) as mock_method:
+            ema.step(model)
+            mock_method.assert_called_once_with(model, None)
+
+    def _test_ema_start_update(self, updates):
+        model = DummyModule()
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+        state = deepcopy(model.state_dict())
+        config = EMAConfig(ema_start_update=1)
+        ema = EMA(model, config)
+
+        # EMA step
+        x = torch.randn(32)
+        y = model(x)
+        loss = y.sum()
+        loss.backward()
+        optimizer.step()
+
+        ema.step(model, updates=updates)
+        ema_state_dict = ema.get_model().state_dict()
+
+        self.assertEqual(ema.get_decay(), 0 if updates == 0 else config.ema_decay)
+
+        for key, param in model.state_dict().items():
+            ema_param = ema_state_dict[key]
+            prev_param = state[key]
+
+            if "version" in key:
+                # Do not decay a model.version pytorch param
+                continue
+            if updates == 0:
+                self.assertTorchAllClose(
+                    ema_param,
+                    param,
+                )
+            else:
+                self.assertTorchAllClose(
+                    ema_param,
+                    config.ema_decay * prev_param + (1 - config.ema_decay) * param,
+                )
+
+        # Check that step_internal is called once
+        with patch.object(
+            ema, "_step_internal", return_value=None
+        ) as mock_method:
+            ema.step(model, updates=updates)
+            mock_method.assert_called_once_with(model, updates)
+
+    def test_ema_before_start_update(self):
+        self._test_ema_start_update(updates=0)
+
+    def test_ema_after_start_update(self):
+        self._test_ema_start_update(updates=1)
+
     def test_ema_fp32(self):
         model = DummyModule().half()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
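The new tests guard against this regression with unittest.mock.patch.object, asserting that a call to step() actually reaches _step_internal. Below is a minimal, self-contained illustration of that pattern; the Worker class is a toy stand-in invented for this sketch, not fairseq code.

import unittest
from unittest.mock import patch


class Worker:
    # Toy stand-in for the EMA class; names are illustrative only.
    def _step_internal(self, model, updates):
        pass

    def step(self, model, updates=None):
        self._step_internal(model, updates)


class TestPatchObject(unittest.TestCase):
    def test_step_reaches_internal(self):
        worker = Worker()
        # patch.object temporarily replaces the bound method with a mock
        # that records its calls, then restores it when the block exits.
        with patch.object(worker, "_step_internal", return_value=None) as mock_method:
            worker.step("model", updates=3)
        mock_method.assert_called_once_with("model", 3)


if __name__ == "__main__":
    unittest.main()

The real tests in tests/test_ema.py additionally verify the EMA arithmetic itself, ema_param = ema_decay * prev_param + (1 - ema_decay) * param, before and after ema_start_update; running them with, for example, pytest tests/test_ema.py from the repository root (assuming pytest is installed) exercises both paths against the actual implementation.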
