In [4]:
import unittest
import jax
import jax.numpy as jnp
from thesis import reinforce, binomial_f, grad_log_binomial
from thesis import s_to_random_walk, grad_log_random_walk, random_walk_f
from thesis import count_neighbors, update_stochastic, run_game, game_of_life, reward, finite_difference_estimate

class TestReinforceGeneric(unittest.TestCase):

    def reinforce_test(self, n, p, num_samples, sample_fn, grad_log_prob_fn, expected_grad, delta=0.1):
        key = jax.random.PRNGKey(0)
        grad_est, _, _ = reinforce(key, n, p, num_samples, sample_fn, grad_log_prob_fn)
        self.assertAlmostEqual(grad_est, expected_grad, delta=delta)

    def test_binomial(self):
        self.reinforce_test(
            n=2,
            p=0.3,
            num_samples=1000000,
            sample_fn=binomial_f,
            grad_log_prob_fn=grad_log_binomial,
            expected_grad=2  # d/dp E[X] = n
        )

    def test_random_walk(self):
        self.reinforce_test(
            n=2,
            p=0.3,
            num_samples=1000000,
            sample_fn=random_walk_f,
            grad_log_prob_fn=grad_log_random_walk,
            expected_grad=4  # d/dp E[W_n] = 4 for n=2
        )
        
    #testing stochastic game of life using finite difference
    def test_game_of_life(self):
            key = jax.random.PRNGKey(42)
            p = 0.5
            T = 10
            N = 10
            num_samples = 100_000
            tolerance = 0.05  # 5%
    
            reinforce_grad = game_of_life(key, p, T, N, num_samples)
            fd_grad = finite_difference_estimate(p, key, T, N, num_samples)
    
            relative_error = jnp.abs(reinforce_grad - fd_grad) / jnp.abs(fd_grad)
            print(f"REINFORCE: {reinforce_grad:.4f}, FD: {fd_grad:.4f}, Error: {relative_error:.4%}")
    
            self.assertLessEqual(relative_error, tolerance, "Gradient estimate is not within 5% tolerance")


# Script entry point: hand control to unittest's command-line runner, which
# collects and runs the TestCase classes defined in this module.
if __name__ == "__main__":
    unittest.main()


ModuleNotFoundError: No module named 'thesis'