diff --git a/UnitTests/test_core.py b/UnitTests/test_core.py index 4dab0f4..f038612 100644 --- a/UnitTests/test_core.py +++ b/UnitTests/test_core.py @@ -1,9 +1,9 @@ import unittest -from unittest.mock import patch, mock_open +from unittest.mock import patch from tempfile import TemporaryDirectory import os -import glob +import math import random import numpy as np import pandas as pd @@ -115,11 +115,13 @@ def __init__(self): self.conv1 = nn.Conv2d(1, 20, 5) self.bn1 = nn.BatchNorm2d(20) - def is_normal_distribution(self, tensor, mean, std, num_std=3): - # Calculate Z-score - z_scores = (tensor - mean) / std - # Check if values are within num_std standard deviations - return torch.all(torch.abs(z_scores) < num_std).item() + + def is_normal_distribution(self, tensor, mean, std, tolerance=0.01): + # Check if the mean of the tensor is close to the expected mean + mean_close = math.isclose(tensor.mean().item(), mean, abs_tol=tolerance) + # Check if the standard deviation of the tensor is close to the expected std + std_close = math.isclose(tensor.std().item(), std, abs_tol=tolerance) + return mean_close and std_close def setUp(self): self.net = test_NetworkWeights.MockNetwork()