In [None]:
# Change the working directory two levels up (presumably to the project root so that
# the `alpha_one` package can be imported -- TODO confirm).
# NOTE: the path is relative, so re-running this cell without restarting the kernel
# moves up two more levels each time.
%cd ../..
# (Re)load the autoreload extension and use mode 2, which re-imports all modules
# before each cell execution so edits to project code are picked up without a restart.
%reload_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from statistics import mean

from alpha_one.data.replay import ReplayDataManager
from alpha_one.model.model_manager import AlphaOneCheckpointManager

In [None]:
# Identify which training run to inspect: both values are passed to the
# checkpoint manager to locate the stored data.
game_name = "leduc_poker"   # game identifier
run_name = "LP-local-39"    # name of the specific run for that game

In [None]:
# Checkpoint manager for the selected game/run.
model_manager = AlphaOneCheckpointManager(game_name, run_name)

# The observation model has its own manager, which in turn hands out the
# replay data manager used to read the stored training samples.
observation_model_manager = model_manager.observation_model_manager
observation_data_manager = observation_model_manager.get_replay_data_manager()

In [None]:
# Load the stored replay samples of the observation model into a buffer object;
# later cells read its `.data` sequence.
# NOTE(review): the meaning of the argument -1 is defined by `load_replays`
# (presumably "latest" or "all" replays) -- confirm against the data manager.
observation_buffer = observation_data_manager.load_replays(-1)

In [None]:
# Bucket the most recent samples (at most the last 10000 buffer entries) by the
# string form of their observation, so samples with the same input end up together.
grouped_samples = defaultdict(list)
recent_samples = observation_buffer.data[-10000:]
for sample in recent_samples:
    observation_key = str(sample.observation)
    grouped_samples[observation_key].append(sample)

In [None]:
# For every observation group: average the target policies of its samples and
# compute the Shannon entropy (natural log) of that averaged distribution.
entropies = []
for samples in grouped_samples.values():
    group_size = len(samples)

    # Element-wise mean of the policies; the vector length is taken from the mask.
    mean_policy = np.zeros(len(samples[0].legals_mask))
    for sample in samples:
        mean_policy += sample.policy
    mean_policy /= group_size

    # Keep only strictly positive probabilities so that log() is well defined.
    support = mean_policy[mean_policy > 0]
    entropies.append(np.sum(- support * np.log(support)))

In [None]:
# Average target-policy entropy over all observation groups (displayed as cell output).
mean_entropy = mean(entropies)
mean_entropy

In [None]:
# For each observation group, count how many distinct target policies (restricted to
# the masked entries) occur among its samples.
#
# - The mask is coerced to bool so that a mask stored as 0/1 integers is applied as a
#   mask instead of being interpreted by NumPy as a list of integer indices.
# - Policies are compared as exact value tuples rather than via str(ndarray), whose
#   text is rounded to the print precision and elided for long arrays, which could
#   merge policies that are actually different.
state_ambiguities = [
    len({
        tuple(np.asarray(sample.policy)[np.asarray(sample.legals_mask, dtype=bool)].tolist())
        for sample in samples
    })
    for samples in grouped_samples.values()
]

In [None]:
# Tally how many observation groups share each ambiguity level
# (key: number of distinct policies, value: number of groups).
histogram = defaultdict(int)
for ambiguity in state_ambiguities:
    histogram[ambiguity] = histogram[ambiguity] + 1

In [None]:
# Bar chart of the ambiguity histogram: x = number of distinct policies seen for an
# observation, y = number of observations with that count.
ambiguity_levels = list(histogram.keys())
group_counts = list(histogram.values())

plt.title("State ambiguities")
plt.bar(ambiguity_levels, group_counts)
plt.xlabel("#Different policies per input observation")
plt.ylabel("Frequency")

In [None]:
# Keep only the groups whose (shared) observation has a second-to-last entry >= 5.
# NOTE(review): presumably that entry encodes game progress, hence "late game" --
# confirm against the observation encoding.
late_game_groups = []
for samples in grouped_samples.values():
    first_observation = samples[0].observation
    if first_observation[-2] >= 5:
        late_game_groups.append(samples)

In [None]:
# Indices (into `late_game_groups`) of groups containing more than 40 samples;
# these are candidates for a closer look.
large_group_indices = [
    position
    for position, members in enumerate(late_game_groups)
    if len(members) > 40
]
large_group_indices

In [None]:
# Index into `late_game_groups` of the group (information set) to inspect.
# Hard-coded for this run's data; it must be a valid index of `late_game_groups`.
group_id = 89

# For every sample of the group, find the position of the first entry equal to 1
# within the masked policy; that position is used as the "true state" id.
# An IndexError is raised if a sample has no entry equal to 1.
true_states = []
for sample in late_game_groups[group_id]:
    # Coerce the mask to bool so that a mask stored as 0/1 integers is applied as a
    # mask instead of being interpreted by NumPy as a list of integer indices.
    state_mask = np.asarray(sample.legals_mask, dtype=bool)
    masked_policy = np.asarray(sample.policy)[state_mask]
    true_states.append(np.where(masked_policy == 1)[0][0])

# Count how often each true state occurs in the group.
true_states_histogram = defaultdict(int)
for true_state in true_states:
    true_states_histogram[true_state] += 1

plt.title(f"Distribution of true states for information set {group_id}")
plt.xlabel("State ID")
plt.ylabel("Frequency")
plt.bar(true_states_histogram.keys(), true_states_histogram.values())

In [None]:
# Show a 16-entry slice (positions -32 to -17) of the observation of the first
# sample in the selected group.
reference_sample = late_game_groups[group_id][0]
print(reference_sample.observation[-32:-16])