diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index f224e86cf..14a79bb2e 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -688,9 +688,9 @@ def denoise( noise_pred, dim=tuple(range(1, len(noise_pred.shape))), keepdim=True ) # Iterate through batch - for noise_norm, max_new_norm, noise in zip(noise_norms, max_new_norms, noise_pred): + for i, (noise_norm, max_new_norm) in enumerate(zip(noise_norms, max_new_norms)): if noise_norm >= max_new_norm: - noise = noise * (max_new_norm / noise_norm) + noise_pred[i] = noise_pred[i] * (max_new_norm / noise_norm) else: noise_pred = noise_pred_cond