According to the paper, the negative term of the contrastive loss is the distance between a negative state z̃_t (randomly sampled from the embeddings at timestep t) and the ground-truth next state z_{t+1}.
However, in line 113 of modules.py, the call with `no_trans=True` effectively computes the distance between the negative sample z̃_t and z_t, rather than z_{t+1}.
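Written out side by side, the two readings differ only in the state the negative sample is compared against. This is a sketch under assumptions: d is the distance used by the energy (e.g. squared Euclidean), T is the transition model, γ is the hinge margin, and the positive term is assumed to be what `self.energy(state, action, next_state)` computes.

```latex
\begin{aligned}
\mathcal{L}_{\text{paper}} &= d\big(z_t + T(z_t, a_t),\, z_{t+1}\big) + \max\big(0,\ \gamma - d(\tilde{z}_t,\, z_{t+1})\big) \\
\mathcal{L}_{\text{code}}  &= d\big(z_t + T(z_t, a_t),\, z_{t+1}\big) + \max\big(0,\ \gamma - d(\tilde{z}_t,\, z_t)\big)
\end{aligned}
```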
```python
def contrastive_loss(self, obs, action, next_obs):
    objs = self.obj_extractor(obs)
    next_objs = self.obj_extractor(next_obs)

    state = self.obj_encoder(objs)            # z_t
    next_state = self.obj_encoder(next_objs)  # z_{t+1}

    # Sample negative state across episodes at random
    batch_size = state.size(0)
    perm = np.random.permutation(batch_size)
    neg_state = state[perm]                   # z̃_t, drawn from the z_t embeddings

    self.pos_loss = self.energy(state, action, next_state)
    zeros = torch.zeros_like(self.pos_loss)

    self.pos_loss = self.pos_loss.mean()
    # Negative energy: neg_state is compared with state (z_t), not next_state (z_{t+1})
    self.neg_loss = torch.max(
        zeros, self.hinge - self.energy(
            state, action, neg_state, no_trans=True)).mean()

    loss = self.pos_loss + self.neg_loss

    return loss
```
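To check that the choice of anchor matters numerically, here is a minimal pure-Python sketch. `energy_no_trans` is a hypothetical stand-in for `self.energy(..., no_trans=True)`, assumed to be a scaled squared distance between its first and third arguments, averaged over objects; the `sigma` and `hinge` defaults are illustrative assumptions, not values taken from the repository.

```python
def energy_no_trans(a, b, sigma=0.5):
    """Hypothetical stand-in for energy(..., no_trans=True).

    a, b: states given as lists of per-object feature vectors.
    Returns a scaled squared distance: summed over features, averaged over objects.
    """
    norm = 0.5 / sigma ** 2
    per_object = [sum((x - y) ** 2 for x, y in zip(ao, bo)) for ao, bo in zip(a, b)]
    return norm * sum(per_object) / len(per_object)


def batch_negative_loss(states, next_states, perm, hinge=1.0, anchor="next"):
    """Mean hinge on the negative energy over a batch.

    anchor="current": compare z̃_t with z_t (the call as quoted above).
    anchor="next":    compare z̃_t with z_{t+1} (the paper's description).
    """
    if anchor not in ("current", "next"):
        raise ValueError("anchor must be 'current' or 'next'")
    # Negatives are permuted z_t embeddings, mirroring neg_state = state[perm].
    neg_states = [states[i] for i in perm]
    anchors = next_states if anchor == "next" else states
    terms = [max(0.0, hinge - energy_no_trans(a, n)) for a, n in zip(anchors, neg_states)]
    return sum(terms) / len(terms)


if __name__ == "__main__":
    states = [[[0.0]], [[1.0]]]       # z_t: two samples, one object, one feature
    next_states = [[[1.0]], [[2.0]]]  # z_{t+1}
    perm = [1, 0]                     # fixed permutation in place of np.random.permutation
    print(batch_negative_loss(states, next_states, perm, anchor="current"))  # 0.0
    print(batch_negative_loss(states, next_states, perm, anchor="next"))     # 0.5
```

On this toy batch the two readings give 0.0 and 0.5, so the choice of anchor is not cosmetic. Note that the negative for the first sample happens to equal its true next state, which only the `"next"` anchor penalizes.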
Thus, I think `next_state`, rather than `state`, should be the first argument of the energy function in the negative term. Please let me know if I am misreading anything.
Thanks.
Hey, do you mean this part of the Figure 4 caption: "One trajectory (in the center) strongly deviates from typical trajectories seen during training, and the model struggles to predict the correct transition."?