Skip to content

Commit 27fac2c

Browse files
committedFeb 4, 2025
Resolved log(0) error in KL divergence TheAlgorithms#12233
1 parent 6c92c5a commit 27fac2c

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed
 

‎machine_learning/loss_functions.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,9 @@ def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float
659659
if len(y_true) != len(y_pred):
660660
raise ValueError("Input arrays must have the same length.")
661661

662-
kl_loss = y_true * np.log(y_true / y_pred)
662+
kl_loss = np.concatenate((y_true[None, :], y_pred[None, :]))
663+
kl_loss = kl_loss[:, ~np.any(kl_loss == 0, axis=0)]
664+
kl_loss = kl_loss[0] * np.log(kl_loss[0] / kl_loss[1])
663665
return np.sum(kl_loss)
664666

665667

0 commit comments

Comments
 (0)
Failed to load comments.