In [49]:
from model_based.transition_models import _CatchEquiLogits, catch_action_transform, catch_transform
from gymnax.environments.bsuite import Catch
from base_rl.wrappers import FlattenObservationWrapper
import jax.numpy as jnp
import jax
from typing import Generator, Tuple
import random
import distrax
import numpy as np
import jaxtyping as jt

In [50]:
def mock_catch(num_states: int, seed: int = 10) -> Generator[Tuple[jt.Array, jt.Array], None, None]:
    """Yield ``num_states`` ``(observation, action)`` pairs from a random Catch rollout.

    Between two yields the environment is advanced by a random number of steps
    (0 to 5 inclusive), each with an action drawn uniformly from ``{0, 1, 2}``.
    The yielded action is the most recently sampled one; when at least one step
    was taken it is the action that produced the yielded observation.

    Args:
        num_states: Number of pairs to yield.
        seed: Seed for both the JAX PRNG and the Python step-count generator,
            so the whole rollout is reproducible across kernel restarts.

    Yields:
        ``(obs, action)`` where ``obs`` is the observation returned by the
        ``FlattenObservationWrapper``-wrapped environment and ``action`` is a
        scalar integer array.
    """
    # Every random draw gets its own subkey: reusing one key for reset, action
    # sampling and env.step would make those draws correlated.
    key = jax.random.key(seed)
    key, reset_key, action_key = jax.random.split(key, 3)
    # A dedicated, seeded generator keeps the step counts deterministic without
    # touching the state of the global ``random`` module.
    step_count_rng = random.Random(seed)

    env = FlattenObservationWrapper(Catch())
    env_params = env.default_params
    obs, env_state = env.reset(reset_key, env_params)
    # Initial action, so a first sample with zero steps still has a valid action.
    action = jax.random.randint(action_key, shape=(), minval=0, maxval=3)
    for _ in range(num_states):
        for _step in range(step_count_rng.randint(0, 5)):
            key, action_key, step_key = jax.random.split(key, 3)
            action = jax.random.randint(action_key, shape=(), minval=0, maxval=3)
            obs, env_state, _, _, _ = env.step(step_key, env_state, action, env_params)
        yield obs, action

In [143]:

def prediciton_dist(
    stacked_states: jt.Array, initial_state: jt.Array
) -> Tuple[jt.Array, jt.Array]:
    """Measure how far each stacked prediction is from ``initial_state``.

    The last axis of ``stacked_states`` holds two candidate predictions:
    index 0 is treated as the "next" state and index 1 as its mirrored
    counterpart. Every state is a flat vector whose first 45 entries are ball
    logits over a grid with 5 columns (9 rows) and whose remaining entries are
    paddle logits. Positions are taken as the argmax of the respective logits.

    The distance of a prediction is the Euclidean distance between the ball
    positions plus the absolute difference between the paddle positions; it is
    symmetric in its two arguments.

    Only implemented for the C2 group.

    Args:
        stacked_states: Predicted logits with the two candidates on the last axis.
        initial_state: Reference state with the same layout as one candidate.

    Returns:
        ``(next_dist, mirror_dist)``: the distance of the index-0 and of the
        index-1 candidate from ``initial_state``.
    """
    ball_cells = 45  # number of leading entries that encode the ball position
    grid_width = 5   # columns per row of the ball grid

    def _mode(logits: jt.Array) -> jt.Array:
        # The mode of a categorical distribution is the argmax of its logits,
        # so no distribution object is needed.
        return jnp.argmax(logits, axis=-1)

    def _ball_dist(pred_logits: jt.Array, state: jt.Array) -> jt.Array:
        """Euclidean distance between predicted and reference ball cells."""
        assert pred_logits.shape == state.shape
        pred_row, pred_col = divmod(_mode(pred_logits), grid_width)
        ref_row, ref_col = divmod(_mode(state), grid_width)
        return jnp.sqrt((pred_col - ref_col) ** 2 + (pred_row - ref_row) ** 2)

    def _paddle_dist(pred_logits: jt.Array, state: jt.Array) -> jt.Array:
        """Absolute difference between predicted and reference paddle index."""
        assert pred_logits.shape == state.shape
        return jnp.abs(_mode(pred_logits) - _mode(state))

    def _joint_dist(logits: jt.Array, reference: jt.Array) -> jt.Array:
        """Ball distance plus paddle distance for one candidate."""
        ball = _ball_dist(logits[:ball_cells], reference[:ball_cells])
        paddle = _paddle_dist(logits[ball_cells:], reference[ball_cells:])
        return ball + paddle

    next_state = stacked_states[..., 0]
    mirror_state = stacked_states[..., 1]

    next_l1 = _joint_dist(next_state, initial_state)
    mirror_l1 = _joint_dist(mirror_state, initial_state)

    return next_l1, mirror_l1

def vis_from_logits(logits):
    """Print the most likely ball and paddle cells as a 10x5 one-hot grid.

    The first 45 entries of ``logits`` are read as ball logits and the rest as
    paddle logits. The mode of each categorical distribution is marked with
    ``1.0`` in an otherwise zero vector (45 ball cells followed by 5 paddle
    cells), which is then printed reshaped to ``(10, 5)``.
    """
    ball_idx = distrax.Categorical(logits=logits[:45]).mode()
    paddle_idx = distrax.Categorical(logits=logits[45:]).mode()

    ball_plane = jnp.zeros(45).at[ball_idx].set(1.0)
    paddle_row = jnp.zeros(5).at[paddle_idx].set(1.0)

    grid = jnp.concatenate([ball_plane, paddle_row], axis=0).reshape(10, 5)
    print(grid)


In [146]:
# Build the model once; the jitted apply function is shared by every evaluation below.
key = jax.random.PRNGKey(10)
model = _CatchEquiLogits()
model_params = model.init(key, jnp.zeros(50,), jnp.zeros(1))
apply_fn = jax.jit(model.apply)


def _equivariant_pair_sums(state, action):
    """Return the two cross-paired distance sums for ``(state, action)``.

    The model is evaluated on the input and on its transformed counterpart
    (``catch_transform`` / ``catch_action_transform``). Each prediction is
    scored against its own input with ``prediciton_dist``, and the "next"
    distance of one evaluation is paired with the "mirror" distance of the other.

    Returns:
        ``(next + inv_mirror, mirror + inv_next)``.
    """
    logits = apply_fn(model_params, state, action)
    next_l1, mirror_l1 = prediciton_dist(logits, state)

    inv_state = catch_transform(state)
    inv_action = catch_action_transform(action)
    inv_logits = apply_fn(model_params, inv_state, inv_action)
    inv_next_l1, inv_mirror_l1 = prediciton_dist(inv_logits, inv_state)

    return next_l1 + inv_mirror_l1, mirror_l1 + inv_next_l1


state_gen = mock_catch(5)
for state, action in state_gen:
    original_pairs = _equivariant_pair_sums(state, action)
    # Same evaluation, but starting from the transformed input. The loop
    # variables are left untouched instead of being rebound mid-iteration.
    transformed_pairs = _equivariant_pair_sums(
        catch_transform(state), catch_action_transform(action)
    )
    # The smaller pair sum must not depend on which representative of the
    # orbit the evaluation starts from.
    assert min(transformed_pairs) == min(original_pairs), (
        f"equivariance check failed: {transformed_pairs} vs {original_pairs}"
    )
    





