# Noise Evaluation

This notebook demonstrates a few tasks that can be performed with the noise model entities.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np

In [3]:
from enum import Enum

In [4]:
from ase.envs.graph import Graph
from ase.actors.graph import GraphActor, GraphAction
from ase.tools.noise import UniformNoiseModel
from ase.tools.order import Order

## Posterior Sampling

In this section, we demonstrate posterior sampling from the noise model. Given a probability distribution and an outcome realised from it, the goal is to sample noise that is consistent with that outcome, i.e., noise under which the model would have produced the same outcome. We repeat each experiment many times to reduce the chance that a passing check is a fluke.
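
One way to realise such a noise model, and not necessarily how `UniformNoiseModel` is implemented, is inverse-transform sampling: draw a noise value u ~ Uniform[0, 1) and pick the first outcome whose cumulative probability exceeds u. The posterior over u given an observed outcome k is then uniform on the interval [F(k-1), F(k)), where F is the cumulative distribution and F(-1) = 0. The sketch below assumes this scheme; `choose` and `sample_posterior` are hypothetical helpers, not part of `ase.tools.noise`.

```python
import numpy as np


def choose(u, probs):
    """Map a uniform noise value u in [0, 1) to an outcome index via the inverse CDF."""
    cdf = np.cumsum(probs)
    # First index whose cumulative probability exceeds u; the clip guards against
    # floating-point round-off leaving cdf[-1] slightly below 1.
    return int(min(np.searchsorted(cdf, u, side="right"), len(probs) - 1))


def sample_posterior(probs, realised, rng):
    """Sample u uniformly from the interval of noise values that yield `realised`."""
    cdf = np.cumsum(probs)
    low = cdf[realised - 1] if realised > 0 else 0.0
    high = cdf[realised]
    return rng.uniform(low, high)


rng = np.random.default_rng(0)
distribution = [0.1, 0.25, 0.45, 0.2]
for _ in range(100):
    # draw an outcome, resample noise conditioned on it, and replay
    realised = int(rng.choice(len(distribution), p=distribution))
    u = sample_posterior(distribution, realised, rng)
    assert choose(u, distribution) == realised
```

Every noise value in [F(k-1), F(k)) maps back to k, so any draw from that interval reproduces the observation; this is the same property the proof-of-concept cell checks for `UniformNoiseModel`.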

### Proof of Concept

For a naive test of posterior sampling, we hard-code a probability distribution and repeatedly draw samples from it. For each sample, we draw a noise instance from the posterior and check that choosing from the distribution with this noise reproduces the realised outcome.

In [5]:
num_repeats = 100
distribution = [0.1, 0.25, 0.45, 0.2]

In [6]:
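for _ in range(num_repeats):
    # draw an outcome index from the hard-coded distribution
    realised = np.random.choice(len(distribution), p=distribution)

    # condition a fresh noise model on that outcome
    noise = UniformNoiseModel()
    noise = noise.sample_posterior(probs=distribution, realised=realised)

    # choosing with the posterior noise must reproduce the realised outcome
    assert realised == noise.choice(probs=distribution)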

### Trajectory Sampling

For a more advanced test of posterior sampling, we turn to the toy graph environment. To validate the implementation, we follow these steps:
1. We sample a trajectory;
2. We sample new noise from the posterior, conditioning on the trajectory;
3. Using the new noise and the same initial state, we sample another trajectory;
4. We check that the states of the new trajectory match those of the original trajectory.

The noise-sampling procedure also has internal asserts that (partially) ensure the correctness of the implementation.
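
The four steps above can be mirrored on a hypothetical toy process that is not the `Graph` environment: a single agent picks one of three actions from a state-dependent distribution using inverse-CDF uniform noise, and the action deterministically shifts an integer state. Because these transitions are deterministic, only action noise is needed, whereas `Graph` also uses environment noise. The names `action_probs`, `rollout`, `choose` and `posterior_noise` are illustrative assumptions.

```python
import numpy as np


def choose(u, probs):
    """Inverse-CDF mapping from a uniform noise value to an outcome index."""
    cdf = np.cumsum(probs)
    return int(min(np.searchsorted(cdf, u, side="right"), len(probs) - 1))


def posterior_noise(probs, realised, rng):
    """Uniform noise restricted to the interval that produces `realised`."""
    cdf = np.cumsum(probs)
    low = cdf[realised - 1] if realised > 0 else 0.0
    return rng.uniform(low, cdf[realised])


def action_probs(state):
    """Hypothetical state-dependent policy over three actions."""
    p = np.array([1.0, 2.0, 3.0]) + state % 3
    return p / p.sum()


def rollout(noise, initial_state=0):
    """Hypothetical toy dynamics: action index a moves the state by a - 1."""
    states, actions = [initial_state], []
    for u in noise:
        a = choose(u, action_probs(states[-1]))
        actions.append(a)
        states.append(states[-1] + a - 1)
    return states, actions


rng = np.random.default_rng(0)
for _ in range(100):
    # 1. sample a trajectory with fresh noise
    states, actions = rollout(rng.uniform(size=5))
    # 2. resample each step's noise conditioned on the observed state and action
    new_noise = [posterior_noise(action_probs(s), a, rng) for s, a in zip(states, actions)]
    # 3. replay from the same initial state; 4. compare the states
    new_states, _ = rollout(new_noise, initial_state=states[0])
    assert new_states == states
```

Conditioning each step only on the state and action observed at that step suffices here because the toy dynamics are deterministic; with stochastic transitions, the environment noise would also have to be conditioned on the observed next state.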

In [7]:
num_repeats = 100
order = Order([GraphAction.up.value, GraphAction.down.value, GraphAction.straight.value])
env = Graph(num_agents=6, num_levels=3, num_columns=3, act_noise_model=UniformNoiseModel(order=order))
agents = [GraphActor(id=agent_id) for agent_id in range(env.num_agents)]

In [8]:
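for _ in range(num_repeats):
    # steps 1-2: sample a trajectory, then action and environment noise conditioned on it
    traj = env.sample_trajectory(agents)
    act_noise, env_noise = env.sample_noise_from_posterior(trajectory=traj, agents=agents)
    # steps 3-4: replay from the same initial state with the posterior noise and compare states
    traj_counterfactual = env.sample_trajectory(agents=agents, act_noise=act_noise, env_noise=env_noise, initial_state=traj.states[0])
    assert traj.states == traj_counterfactual.states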

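Finally, we check that the counterfactual trajectory still matches the observed one when the noise model assumes a different total order from the one used to generate the data. The assumed order is a random shuffle, so it may occasionally coincide with the true order.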
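
A minimal sketch of why a mismatched order need not break the replay, assuming (hypothetically) that the order only decides how outcomes are laid out as consecutive sub-intervals of [0, 1) before the inverse CDF is applied. The posterior is computed under the assumed order, so the resampled noise reproduces the observed outcome under that same order, even though the noise value itself generally differs from the one that generated the data. The helpers `choose` and `posterior_noise` are hypothetical and not part of `ase.tools.noise`.

```python
import numpy as np


def choose(u, probs, order):
    """Inverse CDF where outcomes occupy consecutive sub-intervals of [0, 1) following `order`."""
    cdf = np.cumsum(np.asarray(probs)[order])
    pos = int(min(np.searchsorted(cdf, u, side="right"), len(order) - 1))
    return order[pos]


def posterior_noise(probs, realised, order, rng):
    """Uniform noise on the sub-interval that `order` assigns to `realised`."""
    cdf = np.cumsum(np.asarray(probs)[order])
    pos = order.index(realised)
    low = cdf[pos - 1] if pos > 0 else 0.0
    return rng.uniform(low, cdf[pos])


rng = np.random.default_rng(0)
probs = [0.2, 0.5, 0.3]
true_order = [0, 1, 2]
for _ in range(100):
    # generate the observation under the true order
    realised = choose(rng.uniform(), probs, true_order)
    # infer and replay the noise under a randomly shuffled order
    assumed_order = [int(x) for x in rng.permutation(len(probs))]
    u = posterior_noise(probs, realised, assumed_order, rng)
    assert choose(u, probs, assumed_order) == realised
```

In this sketch, a wrong order changes the inferred noise and therefore what alternative noise values would lead to; only the observed outcome is guaranteed to be reproduced.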

In [9]:
# sample trajectories with correct total order
order = Order([GraphAction.up.value, GraphAction.down.value, GraphAction.straight.value])
env_true = Graph(num_agents=6, num_levels=3, num_columns=3, act_noise_model=UniformNoiseModel(order=order))
trajectories = [env_true.sample_trajectory(agents) for _ in range(num_repeats)]

In [10]:
for traj in trajectories:
    # shuffle the order
    actions = [GraphAction.up.value, GraphAction.down.value, GraphAction.straight.value]
    np.random.shuffle(actions)
    shuffled_order = Order(actions)

    # create the environment with the shuffled order
    env = Graph(num_agents=6, num_levels=3, num_columns=3, act_noise_model=UniformNoiseModel(order=shuffled_order))

    # sample the counterfactual trajectory and compare it with the observed one
    act_noise, env_noise = env.sample_noise_from_posterior(trajectory=traj, agents=agents)
    traj_counterfactual = env.sample_trajectory(agents=agents, act_noise=act_noise, env_noise=env_noise, initial_state=traj.states[0])
    assert traj.states == traj_counterfactual.states

## Noise Monotonicity

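In this section, we check that the noise model is monotonic with respect to a given total order: a smaller noise value never selects an outcome that `order.compare` ranks above the outcome selected by a larger noise value.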
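
Under the hypothetical interval layout used for inverse-CDF sampling (outcomes placed on [0, 1) in the sequence given by the order), monotonicity follows directly: sweeping the noise upward walks through the sub-intervals in order, so the rank of the selected outcome can only stay the same or increase. The sketch below checks this on a grid of noise values, using the distribution and fruit order from `In [12]`; `choose` and `rank` are illustrative names, not the library's API.

```python
import numpy as np


def choose(u, probs, order):
    """Ordered inverse CDF: outcomes occupy consecutive sub-intervals of [0, 1) in `order`."""
    cdf = np.cumsum(np.asarray(probs)[order])
    pos = int(min(np.searchsorted(cdf, u, side="right"), len(order) - 1))
    return order[pos]


distribution = [0.1, 0.25, 0.45, 0.2]
order = [2, 0, 3, 1]  # strawberry, apple, grapes, banana
rank = {outcome: i for i, outcome in enumerate(order)}

# Sweep noise values in increasing order; ranks of the chosen outcomes must never decrease.
noise_grid = np.linspace(0.0, 0.999, 1000)
ranks = [rank[choose(u, distribution, order)] for u in noise_grid]
assert all(r1 <= r2 for r1, r2 in zip(ranks, ranks[1:]))
```

A pairwise check such as the one in the cell below is implied by this sweep: for any two noise values, the smaller one lies earlier on the grid and therefore selects an outcome of equal or lower rank.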

In [11]:
class Fruit(int, Enum):
    apple = 0
    banana = 1
    strawberry = 2
    grapes = 3

In [12]:
num_repeats = 5
distribution = [0.1, 0.25, 0.45, 0.2]
order = Order([Fruit.strawberry.value, Fruit.apple.value, Fruit.grapes.value, Fruit.banana.value])

In [13]:
for _ in range(num_repeats):
    noise = UniformNoiseModel(order=order)
    noise_1 = noise.sample()
    noise_2 = noise.sample()

    # outcomes selected by each noise instance
    choice_1 = noise_1.choice(probs=distribution)
    choice_2 = noise_2.choice(probs=distribution)

    # the order of the choices must agree with the order of the noise values
    if noise_1.noise < noise_2.noise:
        assert order.compare(choice_1, choice_2) <= 0
    if noise_1.noise > noise_2.noise:
        assert order.compare(choice_1, choice_2) >= 0