docs/source/reference/trainers.rst (5 changes: 3 additions & 2 deletions)
@@ -29,9 +29,9 @@ The :obj:`trainer.train()` method can be sketched as follows:
... self._post_steps_hook() # "post_steps"
... self._post_steps_log_hook(batch) # "post_steps_log"

There are 9 hooks that can be used in a trainer loop: :obj:`"batch_process"`, :obj:`"pre_optim_steps"`,
There are 10 hooks that can be used in a trainer loop: :obj:`"batch_process"`, :obj:`"pre_optim_steps"`,
:obj:`"process_optim_batch"`, :obj:`"post_loss"`, :obj:`"post_steps"`, :obj:`"post_optim"`, :obj:`"pre_steps_log"`,
:obj:`"post_steps_log"` and :obj:`"post_optim_log"`. They are indicated in the comments where they are applied.
:obj:`"post_steps_log"`, :obj:`"post_optim_log"` and :obj:`"optimizer"`. They are indicated in the comments where they are applied.
Hooks can be split into 3 categories: **data processing** (:obj:`"batch_process"` and :obj:`"process_optim_batch"`),
**logging** (:obj:`"pre_steps_log"`, :obj:`"post_optim_log"` and :obj:`"post_steps_log"`) and **operations** hooks
(:obj:`"pre_optim_steps"`, :obj:`"post_loss"`, :obj:`"post_optim"` and :obj:`"post_steps"`).
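For context, here is a minimal sketch of how the new :obj:`"optimizer"` hook is used, based on the usage exercised by the test changes below. It assumes :obj:`OptimizerHook` is exported from :obj:`torchrl.trainers` alongside the other hooks listed in this file; the models, learning rates and the pre-built ``trainer`` (constructed with ``optimizer=None``) are illustrative only, not part of this PR.

    import torch
    from torch import nn
    from torchrl.trainers import OptimizerHook  # assumption: exported like the other hooks

    model1 = nn.Linear(10, 20)
    model2 = nn.Linear(10, 20)

    # One optimizer per loss component: each hook steps only its own parameters,
    # driven by the loss entries it is told to consume ("loss_1", "loss_2" here).
    hook1 = OptimizerHook(torch.optim.SGD(model1.parameters(), lr=1e-3), loss_components=["loss_1"])
    hook2 = OptimizerHook(torch.optim.Adam(model2.parameters(), lr=1e-4), loss_components=["loss_2"])

    # `trainer` is assumed to be a Trainer built with optimizer=None:
    # hook1.register(trainer, name="optimizer1")
    # hook2.register(trainer, name="optimizer2")

As the tests below show, each registered hook also reports its gradient norm in the output TensorDict under a ``grad_norm_<index>`` key.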
@@ -139,6 +139,7 @@ Trainer and hooks
BatchSubSampler
CountFramesLog
LogReward
OptimizerHook
Recorder
ReplayBuffer
RewardNormalizer
test/test_trainer.py (161 changes: 159 additions & 2 deletions)
@@ -40,6 +40,7 @@
CountFramesLog,
LogReward,
mask_batch,
OptimizerHook,
ReplayBufferTrainer,
RewardNormalizer,
SelectKeys,
@@ -82,15 +83,18 @@ class MockingLossModule(nn.Module):
pass


def mocking_trainer(file=None) -> Trainer:
_mocking_optim = MockingOptim()


def mocking_trainer(file=None, optimizer=_mocking_optim) -> Trainer:
trainer = Trainer(
MockingCollector(),
*[
None,
]
* 2,
loss_module=MockingLossModule(),
optimizer=MockingOptim(),
optimizer=optimizer,
save_trainer_file=file,
)
trainer._pbar_str = OrderedDict()
@@ -472,6 +476,159 @@ def make_storage():
TensorDict.load_state_dict = TensorDict_load_state_dict


class TestOptimizer:
@staticmethod
def _setup():
torch.manual_seed(0)
x = torch.randn(5, 10)
model1 = nn.Linear(10, 20)
model2 = nn.Linear(10, 20)
td = TensorDict(
{
"loss_1": model1(x).sum(),
"loss_2": model2(x).sum(),
},
batch_size=[],
)
model1_params = list(model1.parameters())
model2_params = list(model2.parameters())
all_params = model1_params + model2_params
return model1_params, model2_params, all_params, td

def test_optimizer_set_as_argument(self):
_, _, all_params, td = self._setup()

optimizer = torch.optim.SGD(all_params, lr=1e-3)
trainer = mocking_trainer(optimizer=optimizer)

params_before = [torch.clone(p) for p in all_params]
td_out = trainer._optimizer_hook(td)
params_after = all_params

assert "grad_norm_0" in td_out.keys()
assert all(
not torch.equal(p_before, p_after)
for p_before, p_after in zip(params_before, params_after)
)

def test_optimizer_set_as_hook(self):
_, _, all_params, td = self._setup()

optimizer = torch.optim.SGD(all_params, lr=1e-3)
trainer = mocking_trainer(optimizer=None)
hook = OptimizerHook(optimizer)
hook.register(trainer)

params_before = [torch.clone(p) for p in all_params]
td_out = trainer._optimizer_hook(td)
params_after = all_params

assert "grad_norm_0" in td_out.keys()
assert all(
not torch.equal(p_before, p_after)
for p_before, p_after in zip(params_before, params_after)
)

def test_optimizer_no_optimizer(self):
_, _, all_params, td = self._setup()

trainer = mocking_trainer(optimizer=None)

params_before = [torch.clone(p) for p in all_params]
td_out = trainer._optimizer_hook(td)
params_after = all_params

assert not [key for key in td_out.keys() if key.startswith("grad_norm_")]
assert all(
torch.equal(p_before, p_after)
for p_before, p_after in zip(params_before, params_after)
)

def test_optimizer_hook_loss_components_empty(self):
model = nn.Linear(10, 20)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
with pytest.raises(ValueError, match="loss_components list cannot be empty"):
OptimizerHook(optimizer, loss_components=[])

def test_optimizer_hook_loss_components_partial(self):
model1_params, model2_params, all_params, td = self._setup()

optimizer = torch.optim.SGD(all_params, lr=1e-3)
trainer = mocking_trainer(optimizer=None)
hook = OptimizerHook(optimizer, loss_components=["loss_1"])
hook.register(trainer)

model1_params_before = [torch.clone(p) for p in model1_params]
model2_params_before = [torch.clone(p) for p in model2_params]
td_out = trainer._optimizer_hook(td)
model1_params_after = model1_params
model2_params_after = model2_params

assert "grad_norm_0" in td_out.keys()
assert all(
not torch.equal(p_before, p_after)
for p_before, p_after in zip(model1_params_before, model1_params_after)
)
assert all(
torch.equal(p_before, p_after)
for p_before, p_after in zip(model2_params_before, model2_params_after)
)

def test_optimizer_hook_loss_components_none(self):
model1_params, model2_params, all_params, td = self._setup()

optimizer = torch.optim.SGD(all_params, lr=1e-3)
trainer = mocking_trainer(optimizer=None)
hook = OptimizerHook(optimizer, loss_components=None)
hook.register(trainer)

model1_params_before = [torch.clone(p) for p in model1_params]
model2_params_before = [torch.clone(p) for p in model2_params]
td_out = trainer._optimizer_hook(td)
model1_params_after = model1_params
model2_params_after = model2_params

assert "grad_norm_0" in td_out.keys()
assert all(
not torch.equal(p_before, p_after)
for p_before, p_after in zip(model1_params_before, model1_params_after)
)
assert all(
not torch.equal(p_before, p_after)
for p_before, p_after in zip(model2_params_before, model2_params_after)
)

def test_optimizer_multiple_hooks(self):
model1_params, model2_params, _, td = self._setup()

trainer = mocking_trainer(optimizer=None)

optimizer1 = torch.optim.SGD(model1_params, lr=1e-3)
hook1 = OptimizerHook(optimizer1, loss_components=["loss_1"])
hook1.register(trainer, name="optimizer1")

optimizer2 = torch.optim.Adam(model2_params, lr=1e-4)
hook2 = OptimizerHook(optimizer2, loss_components=["loss_2"])
hook2.register(trainer, name="optimizer2")

model1_params_before = [torch.clone(p) for p in model1_params]
model2_params_before = [torch.clone(p) for p in model2_params]
td_out = trainer._optimizer_hook(td)
model1_params_after = model1_params
model2_params_after = model2_params

assert "grad_norm_0" in td_out.keys()
assert "grad_norm_1" in td_out.keys()
assert all(
not torch.equal(p_before, p_after)
for p_before, p_after in zip(model1_params_before, model1_params_after)
)
assert all(
not torch.equal(p_before, p_after)
for p_before, p_after in zip(model2_params_before, model2_params_after)
)


class TestLogReward:
@pytest.mark.parametrize("logname", ["a", "b"])
@pytest.mark.parametrize("pbar", [True, False])