Skip to content

Commit 31c98a2

Browse files
committed
Update gaussian_diffusion.py
1 parent fa7cb13 commit 31c98a2

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/modeling/diffusion/gaussian_diffusion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def training_losses(self, model, x_start_never_used, t, model_kwargs=None, noise
243243
"""
244244
assert "input_ids" in model_kwargs
245245
input_ids = model_kwargs.pop("input_ids").to(t.device)
246-
x_start_mean = model.model.get_embeds(input_ids)
246+
x_start_mean = model.model.module.get_embeds(input_ids)
247247

248248
std = _extract_into_tensor(
249249
self.sqrt_one_minus_alphas_cumprod,
@@ -256,7 +256,7 @@ def training_losses(self, model, x_start_never_used, t, model_kwargs=None, noise
256256
if noise is None:
257257
noise = th.randn_like(x_start)
258258
x_t = self.q_sample(x_start, t, noise=noise) # reparametrization trick.
259-
get_logits = model.model.get_logits
259+
get_logits = model.model.module.get_logits
260260

261261
terms = {}
262262

0 commit comments

Comments
 (0)