In [None]:
!pip install torch numpy tensorboard flask onnx onnx_tf tensorflow tensorflow_probability tensorflowjs
!mkdir build && cd build && cmake -D CUDAToolkit_ROOT=/usr/local/cuda-12.0 .. && make -j && cd ..

In [None]:
!pip install -e .
!pip install -e PantheonRL/overcookedgym/human_aware_rl/overcooked_ai

Restart the kernel and clear all outputs here so that the packages installed above are picked up, then continue running the cells below.

In [None]:
# Set MADRONA_MWGPU_KERNEL_CACHE in this kernel's process environment.
# NOTE(review): presumably the location Madrona uses to cache compiled GPU
# kernels, and presumably it has to be set before the imports below — confirm.
%env MADRONA_MWGPU_KERNEL_CACHE=/tmp/madrona_over_cache
import os
import overcooked_ai_py
from train.MAPPO.main_player import MainPlayer

from train.config import get_config
from pathlib import Path

from train.env_utils import generate_env

from train.partner_agents import CentralizedAgent

import torch
import time

In [None]:
# Start from the project's defaults: nothing is passed on the command line
# (presumably get_config() returns an argparse parser — confirm).
args = get_config().parse_args("")

# Notebook-specific overrides, applied on top of the defaults in one pass.
_arg_overrides = {
    # run / environment
    "num_env_steps": 10000000,
    "pop_size": 1,
    "episode_length": 200,
    "env_length": 200,
    "env_name": "overcooked",
    "seed": 1,
    "over_layout": "simple",
    "run_dir": "sp",
    "restored": 0,
    "cuda": True,
    # rollout / optimisation
    "n_rollout_threads": 500,
    "ppo_epoch": 5,
    "layer_N": 2,
    "hidden_size": 64,
    "lr": 2e-2,
    "critic_lr": 2e-2,
    "entropy_coef": 0.0,
    # "linear_lr_decay": True,
    # observation variant
    "use_baseline": False,
}
for _arg_name, _arg_value in _arg_overrides.items():
    setattr(args, _arg_name, _arg_value)

In [None]:
# Run on the GPU only when one is available *and* it was requested in args.
if torch.cuda.is_available() and args.cuda:
    device = 'cuda'
else:
    device = 'cpu'

# Vectorised environments; they are placed on the CPU whenever the model is.
envs = generate_env(
    args.env_name,
    args.n_rollout_threads,
    args.over_layout,
    use_env_cpu=(device == 'cpu'),
    use_baseline=args.use_baseline,
)

# For overcooked the layout name is used as the results sub-directory name.
if args.env_name == 'overcooked':
    args.hanabi_name = args.over_layout
else:
    args.hanabi_name = args.env_name

# Results live under train/<name>/results/<run_dir>/<seed>.
run_dir = f"train/{args.hanabi_name}/results/{args.run_dir}/{args.seed}"
os.makedirs(run_dir, exist_ok=True)

# Keep a textual dump of the arguments next to the results of this run.
Path(run_dir, "args.txt").write_text(str(args), encoding="UTF-8")

# Everything MainPlayer is given when it is constructed.
config = dict(
    all_args=args,
    envs=envs,
    device=device,
    num_agents=2,
    run_dir=Path(run_dir),
)

In [None]:
# Wall-clock timer around construction of the learner and the whole run.
start = time.time()
# Build the learner from the config dict assembled in the previous cell.
ego = MainPlayer(config)
# The partner wraps the very same learner object, then is registered with the envs.
# NOTE(review): presumably this makes the run self-play (run_dir is "sp") and the
# second argument (1) is the partner's player index — confirm against
# train.partner_agents.CentralizedAgent.
partner = CentralizedAgent(ego, 1)
envs.add_partner_agent(partner)
# NOTE(review): presumably runs the full training loop (args.num_env_steps steps)
# and writes checkpoints under run_dir — confirm in MainPlayer.run.
ego.run()
end = time.time()
print(f"Total time taken: {end - start} seconds")

Testing the quality of the trained agent

In [None]:
import torch.nn as nn
from torch.distributions.categorical import Categorical


class Policy(nn.Module):
    """Greedy policy assembled from the pieces of a trained actor.

    The convolutional trunk is taken from ``actor.base.cnn.cnn`` and the action
    head from ``actor.act``; both are reused as-is, nothing is copied.
    """

    def __init__(self, actor):
        """Store references to the actor's trunk and action head."""
        super().__init__()

        self.act_layer = actor.act
        self.base = actor.base.cnn.cnn

    def forward(self, x: torch.Tensor):
        """Return the first element of the action head's deterministic output.

        ``x`` is cast to float and its last axis is moved to position 1 before
        it is fed to the trunk.
        """
        channels_first = x.float().permute(0, 3, 1, 2)
        features = self.base(channels_first)
        head_output = self.act_layer(features, deterministic=True)
        return head_output[0]

In [None]:
# Restore the trained weights from <run_dir>/models into a fresh MainPlayer.
run_dir = Path(run_dir)
args.model_dir = str(run_dir / 'models')

config = {
    'all_args': args,
    'envs': envs,
    'device': device,
    'num_agents': 2,
    'run_dir': run_dir
}

ego = MainPlayer(config)
ego.restore()
torch_network = Policy(ego.policy.actor)

# One action slot per player (2) and per parallel environment.
actions = torch.zeros((2, args.n_rollout_threads, 1), dtype=torch.long, device=device)

# Evaluation only: disable autograd so that no graph is built (and no
# activations are kept alive) for any of the forward passes below.
with torch.no_grad():
    state1, state2 = envs.n_reset()
    scores = torch.zeros(args.n_rollout_threads, device=device)
    for _ in range(args.env_length):
        # The same network controls both players, each from its own observation.
        actions[0, :, :] = torch_network(state1.obs)
        actions[1, :, :] = torch_network(state2.obs)
        (state1, state2), reward, _, _ = envs.n_step(actions)
        # Accumulate the first reward row per environment.
        # NOTE(review): assumes reward is indexed (player, env) — confirm.
        scores += reward[0, :]

# Histogram of final scores: {score: number of environments that reached it}.
score_vals, counts = torch.unique(scores, return_counts=True)
print({x.item() : y.item() for x, y in zip(score_vals, counts)})

Viewing and interacting with trained agents

In [None]:
import onnx
from onnx_tf.backend import prepare

class SimplePolicy(nn.Module):
    """Export-friendly policy that outputs action probabilities.

    Combines the actor's convolutional trunk (``actor.base.cnn.cnn``) with the
    linear layer of its action head (``actor.act.action_out.linear``).
    """

    def __init__(self, actor):
        """Store references to the actor's trunk and final linear layer."""
        super().__init__()

        self.act_layer = actor.act.action_out.linear
        self.base = actor.base.cnn.cnn

    def forward(self, x: torch.Tensor):
        """Return a softmax over dimension 1 of the linear layer's output.

        The last axis of ``x`` is moved to position 1 before the trunk is
        applied; no dtype conversion is done here.
        """
        features = self.base(x.permute(0, 3, 1, 2))
        logits = self.act_layer(features)
        return logits.softmax(dim=1)

import subprocess

# A single environment is enough to obtain one example observation for tracing.
# NOTE: this mutates the shared ``args`` object, so cells above that read
# args.n_rollout_threads will see 1 if they are re-run afterwards.
args.n_rollout_threads = 1

# Built with the same device/baseline settings as the training environments so
# that the example observation matches what the restored network was trained on
# and lives on the same device as the network.
s_envs = generate_env(
    args.env_name,
    args.n_rollout_threads,
    args.over_layout,
    use_env_cpu=(device == 'cpu'),
    use_baseline=args.use_baseline,
)

args.hanabi_name = args.over_layout if args.env_name == 'overcooked' else args.env_name

# Probability-output variant of the restored actor (``ego`` comes from the
# evaluation cell above).
torch_network = SimplePolicy(ego.policy.actor)

# Example input: the first player's observation after a reset, as float.
vobs, _ = s_envs.n_reset()
obs = vobs.obs.to(dtype=torch.float)

print("*" * 20, " TORCH ", "*" * 20)

print(torch_network)

print(obs.shape)

print(torch_network(obs))

print("*" * 20, " ONNX ", "*" * 20)
onnx_model_path = str(run_dir / "models" / f"MAPPO_{args.over_layout}_agent.onnx")

input_name = 'input'  # 'ppo_agent/ppo2_model/Ob'

# Step 1: PyTorch -> ONNX, traced with the example observation.
torch.onnx.export(torch_network,
                  obs,
                  onnx_model_path,
                  export_params=True,
                  input_names=[input_name],
                  output_names=['output'],
                  opset_version=10)

# Reload the exported file and validate it before converting any further.
onnx_model = onnx.load(onnx_model_path)

print(onnx_model.graph.input[0])

onnx.checker.check_model(onnx_model)

print("*" * 20, " TF ", "*" * 20)
# Step 2: ONNX -> TensorFlow SavedModel.
tf_rep = prepare(onnx_model)
tf_model_dir = str(run_dir / 'models' / f'MAPPO_{args.over_layout}_agent')
tf_rep.export_graph(tf_model_dir)

# Step 3: SavedModel -> TF.js graph model, written into the demo's static assets.
tfjs_model_dir = f"overcooked_demo/static/assets/MAPPO_{args.over_layout}_agent"
tfjs_convert_command = [
    "tensorflowjs_converter",
    "--input_format=tf_saved_model",
    "--output_format=tfjs_graph_model",
    "--signature_name=serving_default",
    "--saved_model_tags=serve",
    tf_model_dir,
    tfjs_model_dir,
]

# An argument list needs no shell quoting, and check=True raises
# CalledProcessError if the converter fails instead of silently leaving the demo
# without a (fresh) model.
subprocess.run(tfjs_convert_command, check=True)

In [None]:
from flask import Flask, send_file

# Root the app at the demo directory so that its static/ folder and index.html
# are resolved there. Passing root_path explicitly replaces the previous
# chdir-in / chdir-out round trip, which only worked because Flask falls back to
# the current working directory when it cannot locate the importing module, and
# which left the notebook in the wrong directory if Flask() raised.
app = Flask(__name__, root_path=os.path.abspath('overcooked_demo'))


@app.route('/')
def root():
    """Serve the demo's landing page, ``index.html`` under the app root."""
    return send_file('index.html')


if __name__ == '__main__':
    # Blocks this cell until the server is stopped. host="0.0.0.0" listens on all
    # network interfaces, so the demo is reachable from other machines.
    app.run(debug=False, host="0.0.0.0")