In [None]:
import pickle
from copy import deepcopy
from pathlib import Path

import numpy as np

from stochastic_offline_envs.envs.offline_envs.connect_four_offline_env import ConnectFourOfflineEnv
from stochastic_offline_envs.samplers.trajectory_sampler import Trajectory

#### Helper functions

In [None]:
def is_optimal_traj(traj, optimal_traj):
    """Return True if `traj` exactly reproduces `optimal_traj`.

    Two trajectories match when they have the same number of observations,
    every observation's ``'grid'`` entry is equal in both shape and values,
    and their ``actions`` compare equal.

    Args:
        traj: trajectory exposing ``obs`` (a sequence of dicts with a
            ``'grid'`` key) and ``actions``.
        optimal_traj: reference trajectory with the same layout.

    Returns:
        bool: True on an exact match, False otherwise.
    """
    if len(traj.obs) != len(optimal_traj.obs):
        return False

    # np.array_equal also compares shapes, so grids of different shapes can
    # neither broadcast into a spurious match nor raise on newer NumPy
    # (which `np.all(a == b)` would do).
    for ob, opt_ob in zip(traj.obs, optimal_traj.obs):
        if not np.array_equal(ob['grid'], opt_ob['grid']):
            return False

    if traj.actions != optimal_traj.actions:
        return False

    return True

In [None]:
def check_num_optimal(trajs, optimal_trajectory):
    """Count optimal trajectories and average the final reward over a dataset.

    Args:
        trajs: iterable of trajectories. Each must expose a non-empty
            ``rewards`` sequence plus the fields read by ``is_optimal_traj``.
        optimal_trajectory: reference trajectory to compare against.

    Returns:
        tuple: ``(count, mean_final_reward)``. ``count`` is the number of
        trajectories equal to ``optimal_trajectory``. ``mean_final_reward`` is
        the mean of ``traj.rewards[-1]`` over *all* trajectories, not just the
        optimal ones. An empty ``trajs`` yields ``(0, nan)`` instead of raising
        ZeroDivisionError.
    """
    # Materialise so generators work and `len` is well defined.
    trajs = list(trajs)
    if not trajs:
        return 0, float('nan')

    count = sum(1 for traj in trajs if is_optimal_traj(traj, optimal_trajectory))
    cum_score = sum(traj.rewards[-1] for traj in trajs)

    return count, cum_score / len(trajs)

In [None]:
def get_optimal_traj(env, with_adv=True, max_steps=100):
    """Roll out ``env.optimal_step`` from a fresh reset and record it as a Trajectory.

    Args:
        env: environment exposing ``reset()``, ``optimal_step(state)`` and a
            ``step(action)`` that returns ``(state, reward, done, info)``.
        with_adv: if True, store the per-step ``info`` objects returned by
            ``env.step`` in ``Trajectory.infos``; otherwise ``infos`` is None.
        max_steps: maximum number of steps to roll out before giving up
            (defaults to the previous hard-coded cap of 100).

    Returns:
        Trajectory: one obs/action/reward (and info) entry per step taken,
        with ``policy_infos=None``.

    Raises:
        ValueError: if the episode has not finished after ``max_steps`` steps.
    """
    obs_, actions_, rewards_, adv_infos_ = [], [], [], []

    state = env.reset()
    for _ in range(max_steps):
        # Snapshot everything; presumably the env may reuse/mutate these
        # objects between steps (deepcopy makes the record independent).
        obs_.append(deepcopy(state))

        action = env.optimal_step(state)
        actions_.append(deepcopy(action))

        state, reward, done, adv_info = env.step(action)
        adv_infos_.append(deepcopy(adv_info))
        rewards_.append(deepcopy(reward))

        if done:
            return Trajectory(
                obs=obs_,
                actions=actions_,
                rewards=rewards_,
                infos=adv_infos_ if with_adv else None,
                policy_infos=None,
            )

    raise ValueError("No optimal trajectory found.")

In [None]:
def inject_optimal_traj(env, trajs, with_adv=True, opt_ratio=0.5, opt_n=None, optimal_traj=None):
    """Return a new list where some non-optimal trajectories are replaced by the optimal one.

    Replacement walks ``trajs`` in order and overwrites the *first*
    non-optimal trajectories it meets. Which entries get replaced therefore
    depends on the dataset's ordering; nothing is shuffled.

    Args:
        env: environment passed to ``get_optimal_traj`` when ``optimal_traj``
            is not supplied.
        trajs: list of trajectories. It is not modified. Kept entries are
            shared with the input, not copied.
        with_adv: forwarded to ``get_optimal_traj``.
        opt_ratio: fraction of ``len(trajs)`` to replace when ``opt_n`` is None.
        opt_n: absolute number of replacements; overrides ``opt_ratio``.
        optimal_traj: optional precomputed reference trajectory. Supplying it
            avoids re-rolling ``env`` on every call. It also guarantees the
            same reference is used here and in any later
            ``check_num_optimal`` call.

    Returns:
        list: same length as ``trajs``. At most ``num_injection`` entries are
        replaced; fewer if ``trajs`` contains fewer non-optimal trajectories.
    """
    if optimal_traj is None:
        optimal_traj = get_optimal_traj(env, with_adv=with_adv)
    num_injection = int(len(trajs) * opt_ratio) if opt_n is None else opt_n

    new_trajs = []
    for traj in trajs:
        if num_injection > 0 and not is_optimal_traj(traj, optimal_traj):
            # Independent copy per slot so injected entries don't alias each other.
            new_trajs.append(deepcopy(optimal_traj))
            num_injection -= 1
        else:
            new_trajs.append(traj)

    return new_trajs

#### Dataset stats sweep (disabled: the cell below is commented out; uncomment to run)

In [None]:
# Disabled sweep over the dataset grid. For every (learner_prob, adv_prob)
# pair it loads "c4data_mdp_{learner_prob}_mdp_{adv_prob}" and stores
# check_num_optimal's (optimal-trajectory count, mean final reward) under the
# key "{learner_prob}_{adv_prob}". The reference optimal trajectory is
# generated only once, from the first dataset's test env.
# NOTE(review): `count = 0` below is never used. Uncomment the block to re-run.
# optimal_traj = None
# results_dict = {}

# for learner_prob in ["50", "45", "40", "35", "30", "25", "20", "15", "10", "5", "1", "0"]:
#     for adv_prob in ["50", "45", "40", "35", "30", "25", "20", "15", "10", "5", "1", "0"]:
#         count = 0
#         d_name = f"c4data_mdp_{learner_prob}_mdp_{adv_prob}"
#         task = ConnectFourOfflineEnv(test_regen_prob=0.0, data_name=d_name)
#         env = task.test_env_cls()

#         if not optimal_traj:
#             optimal_traj = get_optimal_traj(env)

#         results_dict[f"{learner_prob}_{adv_prob}"] = check_num_optimal(task.trajs, optimal_traj)

# for k, v in results_dict.items():
#     print(f"{k}: {v}")

#### Build datasets: inject N copies of the optimal trajectory into the base dataset and save one file per N

In [None]:
# Source dataset suffix; it is loaded as 'c4data_' + base_dataset below.
base_dataset = 'mdp_35_mdp_30'

# Whether the injected optimal trajectory keeps the per-step infos from env.step.
with_adv = True

# Number of optimal trajectories to inject; one output dataset per value.
n_opt_trajectories = [0, 1, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 4000]

In [None]:
task = ConnectFourOfflineEnv(test_regen_prob=0.0, data_name='c4data_' + base_dataset)
env = task.test_env_cls()
optimal_traj = get_optimal_traj(env, with_adv=with_adv)

# Compute the stats once and reuse them for both printouts.
base_stats = check_num_optimal(task.trajs, optimal_traj)
print(f"Dataset stats: {base_stats}")
print(f"{base_stats[0]} out of {len(task.trajs)} trajectories are optimal.")
print("===============================")

out_dir = Path('../offline_data_modified')
# Create the output folder if missing, instead of failing inside open().
out_dir.mkdir(parents=True, exist_ok=True)

for opt_n in n_opt_trajectories:
    revised_trajs = inject_optimal_traj(env, task.trajs, with_adv=with_adv, opt_n=opt_n)
    revised_stats = check_num_optimal(revised_trajs, optimal_traj)
    print(f"Dataset stats with opt_n {opt_n}: {revised_stats}")
    print(f"{revised_stats[0]} out of {len(revised_trajs)} trajectories are optimal.")

    # Sanity-check the observation layout before writing the dataset.
    for traj in revised_trajs:
        assert isinstance(traj.obs, list)
        assert all('grid' in ob for ob in traj.obs)
        assert all('move_str' in ob for ob in traj.obs)

    # `with` guarantees the file is flushed and closed, even if dump raises.
    with open(out_dir / f'c4data_{opt_n}_{base_dataset}.ds', 'wb') as f:
        pickle.dump(revised_trajs, f)