|
18 | 18 | TensorDictParams, |
19 | 19 | ) |
20 | 20 | from tensordict.nn import ( |
| 21 | + CompositeDistribution, |
21 | 22 | dispatch, |
22 | 23 | ProbabilisticTensorDictModule, |
23 | 24 | ProbabilisticTensorDictSequential, |
|
33 | 34 | _clip_value_loss, |
34 | 35 | _GAMMA_LMBDA_DEPREC_ERROR, |
35 | 36 | _reduce, |
| 37 | + _sum_td_features, |
36 | 38 | default_value_kwargs, |
37 | 39 | distance_loss, |
38 | 40 | ValueEstimators, |
@@ -462,9 +464,13 @@ def reset(self) -> None: |
462 | 464 |
|
463 | 465 | def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: |
464 | 466 | try: |
465 | | - entropy = dist.entropy() |
| 467 | + if isinstance(dist, CompositeDistribution): |
| 468 | + kwargs = {"aggregate_probabilities": False, "include_sum": False} |
| 469 | + else: |
| 470 | + kwargs = {} |
| 471 | + entropy = dist.entropy(**kwargs) |
466 | 472 | if is_tensor_collection(entropy): |
467 | | - entropy = entropy.get(dist.entropy_key) |
| 473 | + entropy = _sum_td_features(entropy) |
468 | 474 | except NotImplementedError: |
469 | 475 | x = dist.rsample((self.samples_mc_entropy,)) |
470 | 476 | log_prob = dist.log_prob(x) |
@@ -497,13 +503,20 @@ def _log_weight( |
497 | 503 | if isinstance(action, torch.Tensor): |
498 | 504 | log_prob = dist.log_prob(action) |
499 | 505 | else: |
500 | | - maybe_log_prob = dist.log_prob(tensordict) |
501 | | - if not isinstance(maybe_log_prob, torch.Tensor): |
502 | | - # In some cases (Composite distribution with aggregate_probabilities toggled off) the returned type may not |
503 | | - # be a tensor |
504 | | - log_prob = maybe_log_prob.get(self.tensor_keys.sample_log_prob) |
| 506 | + if isinstance(dist, CompositeDistribution): |
| 507 | + is_composite = True |
| 508 | + kwargs = { |
| 509 | + "inplace": False, |
| 510 | + "aggregate_probabilities": False, |
| 511 | + "include_sum": False, |
| 512 | + } |
505 | 513 | else: |
506 | | - log_prob = maybe_log_prob |
| 514 | + is_composite = False |
| 515 | + kwargs = {} |
| 516 | + log_prob = dist.log_prob(tensordict, **kwargs) |
| 517 | + if is_composite and not isinstance(prev_log_prob, TensorDict): |
| 518 | + log_prob = _sum_td_features(log_prob) |
| 519 | + log_prob = log_prob.view_as(prev_log_prob)
507 | 520 |
|
508 | 521 | log_weight = (log_prob - prev_log_prob).unsqueeze(-1) |
509 | 522 | kl_approx = (prev_log_prob - log_prob).unsqueeze(-1) |
@@ -598,6 +611,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: |
598 | 611 | advantage = (advantage - loc) / scale |
599 | 612 |
|
600 | 613 | log_weight, dist, kl_approx = self._log_weight(tensordict) |
| 614 | + if is_tensor_collection(log_weight): |
| 615 | + log_weight = _sum_td_features(log_weight) |
| 616 | + log_weight = log_weight.view(advantage.shape) |
601 | 617 | neg_loss = log_weight.exp() * advantage |
602 | 618 | td_out = TensorDict({"loss_objective": -neg_loss}, batch_size=[]) |
603 | 619 | if self.entropy_bonus: |
@@ -1149,16 +1165,19 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: |
1149 | 1165 | kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist) |
1150 | 1166 | except NotImplementedError: |
1151 | 1167 | x = previous_dist.sample((self.samples_mc_kl,)) |
1152 | | - previous_log_prob = previous_dist.log_prob(x) |
1153 | | - current_log_prob = current_dist.log_prob(x) |
| 1168 | + if isinstance(previous_dist, CompositeDistribution): |
| 1169 | + kwargs = { |
| 1170 | + "aggregate_probabilities": False, |
| 1171 | + "inplace": False, |
| 1172 | + "include_sum": False, |
| 1173 | + } |
| 1174 | + else: |
| 1175 | + kwargs = {} |
| 1176 | + previous_log_prob = previous_dist.log_prob(x, **kwargs) |
| 1177 | + current_log_prob = current_dist.log_prob(x, **kwargs) |
1154 | 1178 | if is_tensor_collection(current_log_prob): |
1155 | | - previous_log_prob = previous_log_prob.get( |
1156 | | - self.tensor_keys.sample_log_prob |
1157 | | - ) |
1158 | | - current_log_prob = current_log_prob.get( |
1159 | | - self.tensor_keys.sample_log_prob |
1160 | | - ) |
1161 | | - |
| 1179 | + previous_log_prob = _sum_td_features(previous_log_prob) |
| 1180 | + current_log_prob = _sum_td_features(current_log_prob) |
1162 | 1181 | kl = (previous_log_prob - current_log_prob).mean(0) |
1163 | 1182 | kl = kl.unsqueeze(-1) |
1164 | 1183 | neg_loss = neg_loss - self.beta * kl |
|
0 commit comments