<a href="https://colab.research.google.com/github/zyh1996saa/-/blob/main/tfagent_env.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install tf-agents

Collecting tf-agents
  Downloading tf_agents-0.10.0-py3-none-any.whl (1.3 MB)
[?25l[K     |▎                               | 10 kB 21.8 MB/s eta 0:00:01[K     |▌                               | 20 kB 26.7 MB/s eta 0:00:01[K     |▊                               | 30 kB 14.2 MB/s eta 0:00:01[K     |█                               | 40 kB 6.1 MB/s eta 0:00:01[K     |█▎                              | 51 kB 6.9 MB/s eta 0:00:01[K     |█▌                              | 61 kB 7.2 MB/s eta 0:00:01[K     |█▉                              | 71 kB 6.5 MB/s eta 0:00:01[K     |██                              | 81 kB 7.3 MB/s eta 0:00:01[K     |██▎                             | 92 kB 6.1 MB/s eta 0:00:01[K     |██▌                             | 102 kB 6.6 MB/s eta 0:00:01[K     |██▉                             | 112 kB 6.6 MB/s eta 0:00:01[K     |███                             | 122 kB 6.6 MB/s eta 0:00:01[K     |███▎                            | 133 kB 6.6 MB/s eta 0:00:0

In [2]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import tensorflow as tf
import numpy as np

from tf_agents.environments import py_environment
from tf_agents.environments import tf_environment
from tf_agents.environments import tf_py_environment
from tf_agents.environments import utils
from tf_agents.specs import array_spec
from tf_agents.environments import wrappers
from tf_agents.environments import suite_gym
from tf_agents.trajectories import time_step as ts

In [3]:
# Load the stock CartPole task and print each part of its interface:
# the action spec, followed by every field of the time-step spec.
environment = suite_gym.load('CartPole-v0')
cartpole_step_spec = environment.time_step_spec()

print('action_spec:', environment.action_spec())
print('time_step_spec.observation:', cartpole_step_spec.observation)
print('time_step_spec.step_type:', cartpole_step_spec.step_type)
print('time_step_spec.discount:', cartpole_step_spec.discount)
print('time_step_spec.reward:', cartpole_step_spec.reward)

action_spec: BoundedArraySpec(shape=(), dtype=dtype('int64'), name='action', minimum=0, maximum=1)
time_step_spec.observation: BoundedArraySpec(shape=(4,), dtype=dtype('float32'), name='observation', minimum=[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], maximum=[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38])
time_step_spec.step_type: ArraySpec(shape=(), dtype=dtype('int32'), name='step_type')
time_step_spec.discount: BoundedArraySpec(shape=(), dtype=dtype('float32'), name='discount', minimum=0.0, maximum=1.0)
time_step_spec.reward: ArraySpec(shape=(), dtype=dtype('float32'), name='reward')


In [4]:
class CardGameEnv(py_environment.PyEnvironment):
  """Toy single-player card game exposed as a TF-Agents Python environment.

  The observation is the running sum of the cards drawn so far, starting at 0.
  Action 0 draws a card whose value is a uniformly random integer in [1, 10]
  and adds it to the sum; action 1 ends the round. The episode terminates when
  the round is ended or when the sum reaches 21 or more. The terminal reward
  is `sum - 21` when the sum is at most 21 and -21 when it exceeds 21; every
  non-terminal step yields a reward of 0.
  """

  # Sum the player is trying to approach without exceeding.
  _TARGET_SUM = 21

  def __init__(self):
    """Builds the action/observation specs and the initial game state."""
    # Run the base-class initializer so its internal bookkeeping exists even
    # if a caller uses step() or current_time_step() before reset().
    super().__init__()
    self._action_spec = array_spec.BoundedArraySpec(
        shape=(), dtype=np.int32, minimum=0, maximum=1, name='action')
    self._observation_spec = array_spec.BoundedArraySpec(
        shape=(1,), dtype=np.int32, minimum=0, name='observation')
    self._state = 0
    self._episode_ended = False

  def action_spec(self):
    """Returns the spec of the scalar int32 action (0 = draw, 1 = stop)."""
    return self._action_spec

  def observation_spec(self):
    """Returns the spec of the shape-(1,) int32 observation (current sum)."""
    return self._observation_spec

  def _reset(self):
    """Starts a new episode with a sum of 0 and returns the first time step."""
    self._state = 0
    self._episode_ended = False
    return ts.restart(np.array([self._state], dtype=np.int32))

  def _step(self, action):
    """Applies `action` and returns the resulting time step.

    Args:
      action: 0 to draw a new card, 1 to end the round.

    Returns:
      A termination time step when the round ends or the sum reaches 21 or
      more, a transition time step with zero reward otherwise, or the first
      time step of a new episode if the previous step ended the episode.

    Raises:
      ValueError: If `action` is neither 0 nor 1.
    """
    if self._episode_ended:
      # The last action ended the episode. Ignore the current action and start
      # a new episode.
      return self.reset()

    if action == 1:
      self._episode_ended = True
    elif action == 0:
      new_card = np.random.randint(1, 11)
      self._state += new_card
    else:
      raise ValueError('`action` should be 0 or 1.')

    # Reaching or passing the target by drawing also ends the episode. Record
    # that here so the next call starts a fresh episode instead of drawing
    # another card on top of a finished game.
    if self._state >= self._TARGET_SUM:
      self._episode_ended = True

    observation = np.array([self._state], dtype=np.int32)
    if self._episode_ended:
      if self._state <= self._TARGET_SUM:
        reward = self._state - self._TARGET_SUM
      else:
        reward = -self._TARGET_SUM
      return ts.termination(observation, reward)
    return ts.transition(observation, reward=0.0, discount=1.0)

In [5]:
# Instantiate the custom environment and run the TF-Agents validation utility
# against it for 5 episodes. No output is recorded for this cell, so the
# check presumably passed silently.
environment = CardGameEnv()
utils.validate_py_environment(environment, episodes=5)

In [10]:
# Scripted episode: request three cards, then end the round, printing every
# time step and the total reward collected along the way.
get_new_card_action = np.array(0, dtype=np.int32)
end_round_action = np.array(1, dtype=np.int32)

environment = CardGameEnv()
time_step = environment.reset()
print(time_step)
# Accumulate as a plain float. Binding `time_step.reward` directly and then
# using `+=` would add into that NumPy array in place, silently changing the
# reward stored on the first time step.
cumulative_reward = float(time_step.reward)

for _ in range(3):
  time_step = environment.step(get_new_card_action)
  print(time_step)
  cumulative_reward += float(time_step.reward)

time_step = environment.step(end_round_action)
print(time_step)
cumulative_reward += float(time_step.reward)
print('Final Reward = ', cumulative_reward)

TimeStep(
{'discount': array(1., dtype=float32),
 'observation': array([0], dtype=int32),
 'reward': array(0., dtype=float32),
 'step_type': array(0, dtype=int32)})
TimeStep(
{'discount': array(1., dtype=float32),
 'observation': array([1], dtype=int32),
 'reward': array(0., dtype=float32),
 'step_type': array(1, dtype=int32)})
TimeStep(
{'discount': array(1., dtype=float32),
 'observation': array([5], dtype=int32),
 'reward': array(0., dtype=float32),
 'step_type': array(1, dtype=int32)})
TimeStep(
{'discount': array(1., dtype=float32),
 'observation': array([10], dtype=int32),
 'reward': array(0., dtype=float32),
 'step_type': array(1, dtype=int32)})
TimeStep(
{'discount': array(0., dtype=float32),
 'observation': array([10], dtype=int32),
 'reward': array(-11., dtype=float32),
 'step_type': array(2, dtype=int32)})
Final Reward =  -11.0
