diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py index 400e9ebed..93b42a5cd 100644 --- a/sbi/utils/sbiutils.py +++ b/sbi/utils/sbiutils.py @@ -33,24 +33,35 @@ def warn_if_zscoring_changes_data(x: Tensor, duplicate_tolerance: float = 0.1) - # Count unique xs. num_unique = torch.unique(x, dim=0).numel() - # z-score. - zx = (x - x.mean(0)) / x.std(0) - - # Count again and warn on too many new duplicates. - num_unique_z = torch.unique(zx, dim=0).numel() - - if num_unique_z < num_unique * (1 - duplicate_tolerance): + # Check we do have different data in the batch + if num_unique == 1: warnings.warn( - """Z-scoring these simulation outputs resulted in {num_unique_z} unique - datapoints. Before z-scoring, it had been {num_unique}. This can occur due - to numerical inaccuracies when the data covers a large range of values. - Consider either setting `z_score_x=False` (but beware that this can be - problematic for training the NN) or exclude outliers from your dataset. - Note: if you have already set `z_score_x=False`, this warning will still be - displayed, but you can ignore it.""", + """Variance along batches is 0. Therefore, Z-score could not be computed and will be skipped. + Check that data is different along all dimensions.""", UserWarning, ) + # Skip computation. + return + else: + # z-score. + zx = (x - x.mean(0)) / x.std(0) + + # Count again and warn on too many new duplicates. + num_unique_z = torch.unique(zx, dim=0).numel() + + if num_unique_z < num_unique * (1 - duplicate_tolerance): + warnings.warn( + """Z-scoring these simulation outputs resulted in {num_unique_z} unique + datapoints. Before z-scoring, it had been {num_unique}. This can occur due + to numerical inaccuracies when the data covers a large range of values. + Consider either setting `z_score_x=False` (but beware that this can be + problematic for training the NN) or exclude outliers from your dataset. + Note: if you have already set `z_score_x=False`, this warning will still be + displayed, but you can ignore it.""", + UserWarning, + ) + def x_shape_from_simulation(batch_x: Tensor) -> torch.Size: ndims = batch_x.ndim