In [1]:
import jax.random as jrn
import jax.numpy as jnp
import jax.lax as jla
import jax
import random

import jaxtyping as jtp

from tqdm import trange
from jax import tree_util

from game.run import dummy_history

# TODO

do not calculate full bool targets from last history

In [2]:
player_total = 7
game_len = 15

key = jrn.PRNGKey(random.randint(0, 1000))
history = dummy_history(key, player_total, game_len, prob_vote=0.7)


No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [3]:
def rate_votes(state):
    """
    """

    def rate(player: int, winner, voted, roles):
        vote = voted[0, player]

        rating = jnp.zeros([2])
        rating = rating.at[vote.astype(int)].set(1)

        # invert in case opponent wins
        role = roles[0, player]
        role = role != 0 # iff player is F
        
        invs = winner.argmax() != role
        rating = jla.select(invs, 1 - rating, rating)

        # if no winner, set rating to 0.5
        skip = winner.sum() == 0
        rating = jla.select(skip, rating.at[:].set(0.5), rating)
        
        return rating

    rate_vmap = jax.vmap(rate, in_axes=(None, None, 0, 0))
    rate_vmap_vmap = jax.vmap(rate_vmap, in_axes=(0, None, None, None))

    player_total = state["roles"].shape[-1]
    players = jnp.arange(player_total)
    
    winner = state["winner"][-1, 0]
    
    voted = state["voted"][1:]
    roles = state["roles"][1:]

    return rate_vmap_vmap(players, winner, voted, roles) # type: ignore


player = random.randint(0, player_total)
time = random.randint(0, game_len)

print("player", player)
print("time  ", time)
print("role  ", history["roles"][time, 0, player])

print("winner", *history["winner"][-1, 0].astype(int))

print("voted ", *history["voted"][-1, :, player].astype(int)[::-1])

print("rating\n", rate_votes(history)[player].astype(float)) #.shape

player 6
time   4
role   1
winner 0 1
voted  1 1 1 1 1 0 0 0 0 0 0 0 0 0 0
rating
 [[0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]]


In [4]:
def rate_presi_disc(state):
    """
    """

    def rate(player: int, winner, presi_shown, chanc_shown, roles):
        discarded = presi_shown[0] - chanc_shown[0]
        discarded = discarded.argmax(axis=-1)

        rating = jnp.zeros([2])
        rating = rating.at[discarded.astype(int)].set(1)

        # invert in case opponent wins
        role = roles[0, player]
        role = role != 0 # iff player is F

        invs = winner.argmax() != role
        rating = jla.select(invs, 1 - rating, rating)

        # if no winner, set rating to 0.5
        skip = winner.sum() == 0
        
        # also skip, if there are no cards (vote did not get through)
        skip |= presi_shown[0].sum() == 0

        # or president has no choice
        skip |= presi_shown[0, 0].sum() == 0
        skip |= presi_shown[0, 1].sum() == 0

        rating = jla.select(skip, rating.at[:].set(0.5), rating)

        return rating

    rate_vmap = jax.vmap(rate, in_axes=(None, None, 0, 0, 0))
    rate_vmap_vmap = jax.vmap(rate_vmap, in_axes=(0, None, None, None, None))

    player_total = state["roles"].shape[-1]
    players = jnp.arange(player_total)

    winner = state["winner"][-1, 0]

    presi_shown = state["presi_shown"][1:]
    chanc_shown = state["chanc_shown"][1:]
    roles = state["roles"][1:]

    return rate_vmap_vmap(players, winner, presi_shown, chanc_shown, roles) # type: ignore


player = random.randint(0, player_total)
time = random.randint(0, game_len)

print("player", player)
print("time  ", time)
print("role  ", history["roles"][time, 0, player])

print("winner", *history["winner"][-1, 0].astype(int))

print("presi_shown ", *history["presi_shown"][-1].astype(int)[::-1])
print("chanc_shown ", *history["chanc_shown"][-1].astype(int)[::-1])

print("rating\n", rate_presi_disc(history)[player].astype(float)) #.shape

player 7
time   9
role   1
winner 0 1
presi_shown  [0 0] [1 2] [1 2] [0 3] [2 1] [0 3] [2 1] [0 3] [3 0] [1 2] [2 1] [0 0] [1 2] [1 2] [2 1]
chanc_shown  [0 0] [0 2] [0 2] [0 2] [1 1] [0 2] [2 0] [0 2] [2 0] [0 2] [2 0] [0 0] [0 2] [0 2] [1 1]
rating
 [[0.5 0.5]
 [1.  0. ]
 [1.  0. ]
 [0.5 0.5]
 [1.  0. ]
 [0.5 0.5]
 [0.  1. ]
 [0.5 0.5]
 [0.5 0.5]
 [1.  0. ]
 [0.  1. ]
 [0.5 0.5]
 [1.  0. ]
 [1.  0. ]
 [1.  0. ]]


In [22]:
def rate_chanc_disc(state):
    """
    """

    def rate(player: int, winner, chanc_shown, roles, board):
        non_discarded = board[0] - board[1]
        discarded = chanc_shown[0] - non_discarded[0]
        discarded = discarded.argmax(axis=-1)

        rating = jnp.zeros([2])
        rating = rating.at[discarded.astype(int)].set(1)

        # invert in case opponent wins
        role = roles[0, player]
        role = role != 0 # iff player is F

        invs = winner.argmax() != role
        rating = jla.select(invs, 1 - rating, rating)

        # if no winner, set rating to 0.5
        skip = winner.sum() == 0
        
        # also skip, if there are no cards (vote did not get through)
        skip |= chanc_shown[0].sum() == 0

        # or president has no choice
        skip |= chanc_shown[0, 0].sum() == 0
        skip |= chanc_shown[0, 1].sum() == 0

        rating = jla.select(skip, rating.at[:].set(0.5), rating)

        return rating #skip#non_discarded

    rate_vmap = jax.vmap(rate, in_axes=(None, None, 0, 0, 0))
    rate_vmap_vmap = jax.vmap(rate_vmap, in_axes=(0, None, None, None, None))

    player_total = state["roles"].shape[-1]
    players = jnp.arange(player_total)

    winner = state["winner"][-1, 0]

    chanc_shown = state["chanc_shown"][1:]
    roles = state["roles"][1:]
    board = state["board"][1:]

    return rate_vmap_vmap(players, winner, chanc_shown, roles, board) # type: ignore


player = random.randint(0, player_total)
time = random.randint(0, game_len)

print("player", player)
print("time  ", time)
print("role  ", history["roles"][time, 0, player])

print("winner", *history["winner"][-1, 0].astype(int))

print("chanc_shown ", *history["chanc_shown"][-1].astype(int)[::-1])
print("board       ", *history["board"][-1].astype(int)[::-1])

print("tracker", *history["tracker"][-1].astype(int)[::-1])

print("rating\n", rate_chanc_disc(history)[player].astype(float)) #.shape

player 2
time   12
role   0
winner 0 1
chanc_shown  [0 0] [0 2] [0 2] [0 2] [1 1] [0 2] [2 0] [0 2] [2 0] [0 2] [2 0] [0 0] [0 2] [0 2] [1 1]
board        [0 0] [0 1] [0 2] [0 3] [0 4] [0 5] [1 5] [1 6] [2 6] [2 7] [3 7] [3 7] [3 8] [3 9] [ 3 10]
tracker 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0
rating
 [[0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.  1. ]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.  1. ]]
