# In [1]:
import collections
import itertools
import logging
import random
import sys
from collections import deque

from pathlib import Path

import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

import gym
import pfrl

import torch
from tqdm import tqdm

from torch import nn
from gym.wrappers import RescaleAction

from util.ppo import PPO_KL
from util.modules import  ortho_init, BetaPolicyModel
from conditioned_trp_env.envs.conditioned_trp_env import FoodClass

sns.set()
sns.set_context("talk")

logging.basicConfig(level=logging.INFO, stream=sys.stdout)
logger = logging.getLogger(__name__)


##########################################################
# Seed
##########################################################

seed = 100

# Feed the same seed to every RNG this script touches: Python's `random`,
# NumPy's global generator, and torch (CPU generator plus the CUDA ones).
# The call order is kept as in the original cell.
for _seed_rng in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
    torch.cuda.manual_seed,
):
    _seed_rng(seed)

def make_env():
    """Build the wrapped ``SmallLowGearAntCTRP-v0`` environment.

    The base environment is created through ``gym.make`` with a fixed set of
    keyword options, then wrapped twice:

    1. ``RescaleAction(env, 0, 1)`` so actions are expressed in ``[0, 1]``
       (presumably to match the output range of the Beta policy -- confirm
       against ``BetaPolicyModel``).
    2. ``pfrl.wrappers.CastObservationToFloat32`` so observations are cast to
       float32.

    Returns:
        The fully wrapped environment.
    """
    env_options = {
        "max_episode_steps": np.inf,
        "internal_reset": "setpoint",
        "n_bins": 20,
        "sensor_range": 16,
        "enable_metabolic": True,
    }
    base_env = gym.make(
        "conditioned_trp_env:SmallLowGearAntCTRP-v0",
        **env_options,
    )
    rescaled_env = RescaleAction(base_env, 0, 1)
    return pfrl.wrappers.CastObservationToFloat32(rescaled_env)

# Instantiate an environment so the network dimensions can be read from its
# observation and action spaces.
env = make_env()

obs_space, action_space = env.observation_space, env.action_space
obs_size, action_size = obs_space.low.size, action_space.low.size

# Actor: project-local Beta policy with two hidden layers (256 and 64 units).
policy = BetaPolicyModel(
    obs_size=obs_size,
    action_size=action_size,
    hidden1=256,
    hidden2=64,
)

# Critic: obs -> 256 -> 64 -> 1 MLP with Tanh after each hidden layer.
# Layers are created in the same order and land at the same Sequential
# indices (0..4) as a hand-written stack, so checkpoint keys are unchanged.
_critic_layers = []
_in_features = obs_size
for _width in (256, 64):
    _critic_layers += [nn.Linear(_in_features, _width), nn.Tanh()]
    _in_features = _width
_critic_layers.append(nn.Linear(_in_features, 1))
value_func = nn.Sequential(*_critic_layers)

# Branched feeds the same input to both heads.
model = pfrl.nn.Branched(policy, value_func)

# Adam with default hyper-parameters.
opt = torch.optim.Adam(params=model.parameters())

# gpu=-1: presumably selects CPU execution, following the PFRL convention --
# confirm against PPO_KL.
agent = PPO_KL(model=model, optimizer=opt, gpu=-1)


# Restore the trained weights from the saved checkpoint directory.
agent.load(dirname="data/result_trp_therm_oct2021/trp-homeostatic_shaped2021-09-30-14-54-11/150000000_finish")

# Fresh environment instance for the rendered rollout.
env = make_env()

# NOTE(review): the agent is queried in its default mode. If PPO_KL follows
# the PFRL Agent interface, wrapping the loop in `with agent.eval_mode():`
# is the usual way to run a pure evaluation -- confirm before changing, since
# it may alter how actions are selected.
try:
    # Start from a fixed internal state with two "red" objects and no "blue"
    # ones, so the rollout begins from a reproducible configuration.
    obs = env.reset(initial_internal=(0.3, -0.2),
                    object_positions={"blue": [],
                                      "red": [(4, 1), (4, 8)]})

    while True:
        action = agent.act(obs)
        obs, reward, done, info = env.step(action)
        env.render()

        # Stepping an environment after it reports `done` is undefined in
        # the gym API, so stop the rollout instead of continuing.
        if done:
            logger.info("Episode finished: info=%s", info)
            break
finally:
    # Always release the simulator/viewer. This also runs when the render
    # window is closed (SystemExit) or the cell is interrupted; previously
    # the close() call sat after an endless loop and was unreachable.
    env.close()

# ---- Captured notebook output (not code) ----
#   logger.warn(
#
#
# Creating window glfw
#
#
# SystemExit: 0
#
#   warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
