In [1]:
import rlcard
import torch
from src import DoubleDQNAgent, ModifiedGinRummyEnv
import numpy as np

In [2]:
!pip install statsmodels
import statsmodels.stats.api as sms



In [3]:
# load checkpoints from our different models. rand-rule is mixed adversary model.

dqn_rand_1 = DoubleDQNAgent.from_checkpoint(torch.load("checkpoints/homegrown/y34n75js/devout-sweep-4/dqn_11500000.pt"))
dqn_rand_2 = DoubleDQNAgent.from_checkpoint(torch.load("checkpoints/homegrown/y34n75js/vital-sweep-4/dqn_11250000.pt"))
dqn_rand_3 = DoubleDQNAgent.from_checkpoint(torch.load("checkpoints/homegrown/y34n75js/autumn-sweep-6/dqn_11950000.pt"))

dqn_rand_rule_1 = DoubleDQNAgent.from_checkpoint(torch.load("checkpoints/homegrown/f4k75mzd/rural-sweep-3/dqn_11450000.pt"))
dqn_rand_rule_2 = DoubleDQNAgent.from_checkpoint(torch.load("checkpoints/homegrown/f4k75mzd/dainty-sweep-5/dqn_10850000.pt"))
dqn_rand_rule_3 = DoubleDQNAgent.from_checkpoint(torch.load("checkpoints/homegrown/f4k75mzd/radiant-sweep-6/dqn_10800000.pt"))

dqn_rule_1 = DoubleDQNAgent.from_checkpoint(torch.load("checkpoints/homegrown/3beipkbe/gentle-sweep-4/dqn_11900000.pt"))
dqn_rule_2 = DoubleDQNAgent.from_checkpoint(torch.load("checkpoints/homegrown/3beipkbe/worthy-sweep-5/dqn_12000019.pt"))
dqn_rule_1 = DoubleDQNAgent.from_checkpoint(torch.load("checkpoints/homegrown/3beipkbe/gentle-sweep-4/dqn_10850000.pt"))

In [3]:
# the rule agent needs to be downloaded from the RLCard repo, as it's inaccessible through the package
!wget https://raw.githubusercontent.com/datamllab/rlcard/master/rlcard/models/gin_rummy_rule_models.py

--2024-04-25 10:52:28--  https://raw.githubusercontent.com/datamllab/rlcard/master/rlcard/models/gin_rummy_rule_models.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8000::154, 2606:50c0:8002::154, 2606:50c0:8003::154, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8000::154|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5072 (5.0K) [text/plain]
Saving to: ‘gin_rummy_rule_models.py’


2024-04-25 10:52:28 (130 MB/s) - ‘gin_rummy_rule_models.py’ saved [5072/5072]



In [4]:
from gin_rummy_rule_models import GinRummyNoviceRuleAgent

In [8]:
agents = [[dqn_rand_1, dqn_rand_2, dqn_rand_3], [dqn_rand_rule_1, dqn_rand_rule_2, dqn_rand_rule_3], [dqn_rule_1, dqn_rule_2], [rlcard.agents.RandomAgent(num_actions=110)], [GinRummyNoviceRuleAgent()]]
env = ModifiedGinRummyEnv({'allow_step_back': False, 'seed': None})

In [26]:
# for generating reward table
from rlcard.utils import tournament
r = []
s = [] #rlcard.make("gin-rummy")

for agent in agents:
    r_ = []
    s_ = []
    for adversary in agents:
        t = 0
        s__ = []
        for a in agent:
            for ad in adversary:
                env.set_agents([a, ad])
                t_ = tournament(env, 200)[0]
                t += t_
                s__.append(t_)
        r_.append(t/(len(agent) * len(adversary)))
        c = sms.DescrStatsW(s__).tconfint_mean()
        s_.append((c[1]-c[0])/2)
    r.append(r_)
    s.append(s_)

In [27]:
print(np.array(r))
print(np.array(s))

[[-0.22376111 -0.18334444 -0.09309167  0.16953333 -0.25426667]
 [-0.19082778 -0.16327222 -0.123825    0.13108333 -0.24296667]
 [-0.40419167 -0.39409167 -0.4497625  -0.174875   -0.4088    ]
 [-0.55688333 -0.56743333 -0.575825   -0.5222     -0.594     ]
 [ 0.08695     0.10093333  0.1472      0.18885    -0.0221    ]]
[[0.11163221 0.10647668 0.1343987  0.05350208 0.34849142]
 [0.13768412 0.13066632 0.20723741 0.06651581 0.39447278]
 [0.06998022 0.08522962 0.03181341 0.83892717 0.13786232]
 [0.03465402 0.02851091 0.27540699        nan        nan]
 [0.16599585 0.17613055 0.09720247        nan        nan]]


In [16]:
def run(env, is_training=False):
    trajectories = [[] for _ in range(env.num_players)]
    state, player_id = env.reset()

    # Loop to play the game
    trajectories[player_id].append(state)
    while not env.is_over():
        # Agent plays
        if not is_training:
            action, _ = env.agents[player_id].eval_step(state)
        else:
            action = env.agents[player_id].step(state)

        # Environment steps
        next_state, next_player_id = env.step(action, env.agents[player_id].use_raw)
        # Save action
        trajectories[player_id].append(action)

        # Set the state and player
        state = next_state
        player_id = next_player_id

        # Save state.
        if not env.game.is_over():
            trajectories[player_id].append(state)

    # Add a final state to all the players
    for player_id in range(env.num_players):
        state = env.get_state(player_id)
        trajectories[player_id].append(state)

    # Payoffs
    payoffs = env.get_payoffs()

    return trajectories, payoffs

def tournament(env, num):
    payoffs = [0 for _ in range(env.num_players)]
    counter = 0
    lens = []
    while counter < num:
        _, _payoffs = env.run(is_training=False)
        lens.append((len(_[0])-2)//2)
        if isinstance(_payoffs, list):
            for _p in _payoffs:
                for i, _ in enumerate(payoffs):
                    payoffs[i] += _p[i]
                counter += 1
        else:
            for i, _ in enumerate(payoffs):
                payoffs[i] += _payoffs[i]
            counter += 1
    for i, _ in enumerate(payoffs):
        payoffs[i] /= counter
    return payoffs, lens

def tournament_wins(env, num):
    # only works for environments with two agents...
    wins = [0.0 for _ in range(env.num_players)]
    counter = 0
    while counter < num:
        _, _payoffs = env.run(is_training=False)
        if isinstance(_payoffs, list):
            for _p in _payoffs:
                for i, _ in enumerate(wins):
                    wins[i] += 1 if _p[i] > p_[1-i] else 0
                counter += 1
        else:
            for i, _ in enumerate(wins):
                wins[i] += 1 if _payoffs[i] > _payoffs[1-i] else 0
            counter += 1
    for i, _ in enumerate(wins):
        wins[i] /= counter
    return wins

In [42]:
# for generating episode length table
l = []
s = []

for agent in agents:
    l_ = []
    s_ = []
    for adversary in agents:
        t = []
        for a in agent:
            for ad in adversary:
                env.set_agents([a, ad])
                _, lens = tournament(env, 200)
                t += lens
        l_.append(np.mean(t))
        c = sms.DescrStatsW(t).tconfint_mean()
        s_.append((c[1]-c[0])/2)
    l.append(l_)
    s.append(s_)

In [43]:
print(np.array(l))
print(np.array(s))

[[53.61611111 48.68611111 51.63666667 38.01666667 26.54833333]
 [46.635      45.30166667 55.21166667 41.91666667 28.455     ]
 [50.57       56.12416667 82.56625    66.1875     26.6525    ]
 [34.89833333 40.46666667 65.27       60.565      22.575     ]
 [26.78166667 28.07666667 28.32       25.97       25.1       ]]
[[1.74540946 1.62216849 2.03873773 1.56376451 1.19628581]
 [1.59455009 1.48882552 2.08917081 1.89133319 1.260551  ]
 [2.01217468 2.10957815 2.23457339 2.48705883 1.46447287]
 [1.52644154 1.85976934 2.35749466 1.15860611 1.63811587]
 [1.22649454 1.19680545 1.64346541 1.83196865 1.65157369]]


In [40]:
np.array(s).shape

(5, 5, 200)

In [27]:
# for generating win-rate table
w = []
s = []

for agent in agents:
    l_ = []
    s_ = []
    for adversary in agents:
        t = []
        for a in agent:
            for ad in adversary:
                env.set_agents([a, ad])
                wins = tournament_wins(env, 500)
                t += [wins[0]]
        l_.append(np.mean(t))
        c = sms.DescrStatsW(t).tconfint_mean()
        s_.append((c[1]-c[0])/2)
    w.append(l_)
    s.append(s_)

In [28]:
print(np.array(w))
print(np.array(s))

[[0.49555556 0.52911111 0.741      0.98133333 0.254     ]
 [0.47       0.49666667 0.715      0.98       0.24266667]
 [0.23433333 0.281      0.491      0.832      0.148     ]
 [0.01266667 0.02866667 0.148      0.464      0.004     ]
 [0.72866667 0.74533333 0.846      0.984      0.478     ]]
[[0.10442795 0.12936505 0.0778537  0.02500644 0.25892295]
 [0.11177977 0.0971628  0.08151352 0.01791337 0.31265943]
 [0.07821124 0.08939073 0.03970625 0.38118614 0.66072265]
 [0.01744801 0.00286844 0.22871169        nan        nan]
 [0.25378763 0.34386544 0.05082482        nan        nan]]
