In [8]:
import jax.random as jrn
import jax.numpy as jnp
import jax.lax as jla
import jax
import random

import jaxtyping as jtp

from tqdm import trange
from jax import tree_util

from game.run import dummy_history

# TODO

do not calculate full bool targets from last history

In [54]:
player_total = 7
game_len = 15

key = jrn.PRNGKey(random.randint(0, 1000))
batched_history = dummy_history(key, player_total, game_len, prob_vote=0.7)


In [78]:
def rate_votes(state):
    """
    """

    def rate(player: int, winner, voted, roles):
        vote = voted[0, player]

        rating = jnp.zeros([2])
        rating = rating.at[vote.astype(int)].set(1)

        # invert in case opponent wins
        role = roles[0, player]
        role = role != 0 # iff player is F
        
        invs = winner.argmax() != role
        rating = jla.select(invs, 1 - rating, rating)

        # if no winner, set rating to 0.5
        skip = winner.sum() == 0
        rating = jla.select(skip, rating.at[:].set(0.5), rating)
        
        return rating

    rate_vmap = jax.vmap(rate, in_axes=(None, None, 0, 0))
    rate_vmap_vmap = jax.vmap(rate_vmap, in_axes=(0, None, None, None))

    player_total = state["roles"].shape[-1]
    players = jnp.arange(player_total)
    
    winner = state["winner"][-1, 0]
    
    voted = state["voted"][1:]
    roles = state["roles"][1:]

    return rate_vmap_vmap(players, winner, voted, roles) # type: ignore


player = random.randint(0, player_total)
time = random.randint(0, game_len)

print("player", player)
print("time  ", time)
print("role  ", batched_history["roles"][time, 0, player])

print("winner", *batched_history["winner"][-1, 0].astype(int))

print("voted ", *batched_history["voted"][-1, :, player].astype(int))

print("rating", *rate_votes(batched_history)[player].argmax(axis=-1).astype(int)[::-1]) #.shape

player 4
time   13
role   0
winner 0 1
voted  1 1 1 1 0 1 0 0 1 1 1 1 1 1 0
rating 0 0 0 0 1 0 1 1 0 0 0 0 0 0 1


In [4]:
def vote_rating(player: int, endwin, voted, roles):
    """
    """
    vote = voted[0, player]

    rating = jnp.zeros([2])
    rating = rating.at[vote.astype(int)].set(1)

    # invert in case opponent wins
    role = roles[0, player]
    role = role != 0 # iff player is F
    
    invs = endwin.argmax() != role
    rating = jla.select(invs, 1 - rating, rating)

    # if no winner, set rating to 0.5
    skip = endwin.sum() == 0
    rating = jla.select(skip, rating.at[:].set(0.5), rating)
    
    return rating


print("voted  ", history["voted"][0, player].astype(int))

vote_rating(player, endwin, history["voted"], history["roles"])

voted   1


DeviceArray([0., 1.], dtype=float32)

In [5]:
vote_rating_vmap = jax.vmap(vote_rating, in_axes=(None, None, 0, 0), out_axes=0)

print(batched_history["voted"][-1, :, player].astype(int))

vote_rating_vmap(player, endwin, batched_history["voted"][1:], batched_history["roles"][1:])

[1 1 1 1 1 1 1 1 0 1 1 1 0 1 1]


DeviceArray([[0., 1.],
             [0., 1.],
             [1., 0.],
             [0., 1.],
             [0., 1.],
             [0., 1.],
             [1., 0.],
             [0., 1.],
             [0., 1.],
             [0., 1.],
             [0., 1.],
             [0., 1.],
             [0., 1.],
             [0., 1.],
             [0., 1.]], dtype=float32)

In [6]:
vote_rating_vmap_vmap = jax.vmap(vote_rating_vmap, in_axes=(0, None, None, None), out_axes=0)

players = jnp.arange(5)

vote_rating_vmap_vmap(players, endwin, batched_history["voted"][1:], batched_history["roles"][1:]).shape

(5, 15, 2)

In [6]:
player = 1

print("voted  ", *history["voted"][:, player].astype(int))
print("winner ", *history["winner"].astype(int))
print("roles  ", *history["roles"][0].astype(int))

voted   0 0 0 1 1 1 1 0 0 0 0
winner  [1 0] [0 0] [0 0] [0 0] [0 0] [0 0] [0 0] [0 0] [0 0] [0 0] [0 0]
roles   2 0 0 0 1


In [7]:
def presi_disc_rating(player, endwin, roles, presi_shown, chanc_shown):
    raise NotImplementedError

In [5]:
def vote_data(player: int, winner, voted, roles, **_):
    """
    """
    role = roles[0, player]
    role = role != 0 # only check if F

    # 0: L win, 1: F win, 
    # also 0 for no winner
    wins = winner[0].argmax()
    invs = wins != role

    # iff there is a winner
    skip = winner[0].sum() > 0

    acts = voted[:, player].astype(jnp.int32)
    rang = jnp.arange(acts.shape[0])

    data = jnp.zeros((*acts.shape, 2))
    data = data.at[rang, acts].set(1)

    data = jla.select(invs, 1 - data, data)
    data = jla.select(skip, data, data.at[:].set(0.5))

    return data


vote_data(player, **history)

Array([[0., 1.],
       [0., 1.],
       [0., 1.],
       [1., 0.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [1., 0.],
       [0., 1.],
       [0., 1.]], dtype=float32)

In [6]:
print("presi", *history["presi"].astype(int))
print("shown", *history["presi_shown"].astype(int))
print("given", *history["chanc_shown"].astype(int))

presi 0 3 1 0 3 1 0 3 2 1 0
shown [0 3] [2 1] [1 2] [0 0] [2 1] [1 2] [1 2] [1 2] [2 1] [0 3] [1 2]
given [0 2] [2 0] [0 2] [0 0] [1 1] [1 1] [0 2] [0 2] [1 1] [0 2] [0 2]


In [8]:
def presi_discard_data(player: int, roles, winner, presi_shown, chanc_shown, **_):
    """
    """
    role = roles[0, player]
    role = role != 0 # only check if F

    # 0: L win, 1: F win, 
    # also 0 for no winner
    wins = winner[0].argmax()
    invs = wins != role

    acts = presi_shown - chanc_shown
    acts = acts.argmax(axis=-1)

    skip = presi_shown.sum(axis=-1) == 0

    data = jnp.zeros((*acts.shape, 2))

    rang = jnp.arange(acts.shape[0])
    data = data.at[rang, acts].set(1)

    data = jnp.where(skip[:, None], 0.5, data)
    data = jla.select(invs, 1 - data, data) # type: ignore

    return data


presi_discard_data(player, **history)

Array([[0. , 1. ],
       [0. , 1. ],
       [1. , 0. ],
       [0.5, 0.5],
       [1. , 0. ],
       [0. , 1. ],
       [1. , 0. ],
       [1. , 0. ],
       [1. , 0. ],
       [0. , 1. ],
       [1. , 0. ]], dtype=float32)

In [18]:
print("shown", *history["chanc_shown"].astype(int))
print("board", *history["board"].astype(int))
print("disc ", *history["disc"].astype(int))

shown [0 2] [2 0] [0 2] [0 0] [1 1] [1 1] [0 2] [0 2] [1 1] [0 2] [0 2]
board [3 7] [3 6] [2 6] [2 5] [2 5] [1 5] [0 5] [0 4] [0 3] [0 2] [0 1]
disc  [0 2] [3 5] [2 4] [1 3] [1 3] [0 2] [5 5] [4 4] [3 3] [1 3] [1 1]


In [28]:
def chanc_discard_data(player: int, roles, winner, chanc_shown, board, **_):
    """
    """
    role = roles[0, player]
    role = role != 0 # only check if F

    # 0: L win, 1: F win, 
    # also 0 for no winner
    wins = winner[0].argmax()
    invs = wins != role

    acts = board[:-1] - board[1:]
    skip = acts.sum(axis=-1) == 0

    acts = acts.argmax(axis=-1)

    data = jnp.zeros((*acts.shape, 2))
 
    rang = jnp.arange(acts.shape[0])
    data = data.at[rang, acts].set(1)

    data = jnp.where(skip[:, None], 0.5, data)
    data = jla.select(invs, 1 - data, data) # type: ignore

    last = jnp.array([0.5, 0.5])[None]
    data = jnp.concatenate([data, last], axis=0)

    return acts, skip.astype(int), data


chanc_discard_data(player, **history)

(Array([1, 0, 1, 0, 0, 0, 1, 1, 1, 1], dtype=int32),
 Array([0, 0, 0, 1, 0, 0, 0, 0, 0, 0], dtype=int32),
 Array([[0. , 1. ],
        [1. , 0. ],
        [0. , 1. ],
        [0.5, 0.5],
        [1. , 0. ],
        [1. , 0. ],
        [0. , 1. ],
        [0. , 1. ],
        [0. , 1. ],
        [0. , 1. ],
        [0.5, 0.5]], dtype=float32))