In [5]:
import gymnasium as gym
import numpy as np
import angorapy as ap

In [6]:
class MyTask(gym.Env):
    """Grid task in which an agent walks from a fixed start cell onto a random goal cell.

    Both positions are integer coordinate pairs. The agent always starts at
    ``START_POSITION`` and each coordinate is clamped to ``[0, GRID_MAX]``.
    Goal coordinates are drawn from ``{0, 1, 2, 3, 7, 8, 9}``, so the goal can
    never coincide with the start cell.

    Observation: ``[goal_0, goal_1, agent_0, agent_1]`` as an int64 vector.

    Actions (``Discrete(4)``):
        0: +1 along coordinate 0
        1: -1 along coordinate 0
        2: +1 along coordinate 1
        3: -1 along coordinate 1

    Reward: ``-0.5 - euclidean_distance(agent, goal)`` on every step, or
    ``5.0`` on the step that reaches the goal (which also terminates the episode).
    """

    GRID_MAX = 10  # largest value a coordinate of the agent may take
    START_POSITION = (5, 5)

    # action -> (coordinate index, step direction)
    _MOVES = {0: (0, 1), 1: (0, -1), 2: (1, 1), 3: (1, -1)}

    def __init__(self):
        super().__init__()

        self.action_space = gym.spaces.Discrete(4, start=0)
        # The bounds must cover every value an observation can take (0..GRID_MAX);
        # otherwise observations are not members of the declared space.
        self.observation_space = gym.spaces.Box(
            low=0, high=self.GRID_MAX, shape=(4,), dtype=np.int64
        )

        self.agent_position = np.array(self.START_POSITION, dtype=np.int64)
        self.goal_position = self._sample_goal()

    def _sample_goal(self):
        """Draw a goal cell; coordinates 4-6 are excluded so it never equals the start."""
        possible_coords = np.concatenate([np.arange(4), np.arange(7, 10)])

        # Use the environment's own generator so that reset(seed=...) makes
        # goal placement reproducible.
        return self.np_random.choice(possible_coords, size=2).astype(np.int64)

    def _get_obs(self):
        """Return a fresh array ``[goal_0, goal_1, agent_0, agent_1]``."""
        return np.concatenate([self.goal_position, self.agent_position])

    def reset(self, *, seed=None, options=None, **kwargs):
        """Put the agent back on the start cell and sample a new goal.

        Args:
            seed: optional seed forwarded to ``gym.Env.reset`` to (re)seed ``self.np_random``.
            options: accepted for Gymnasium API compatibility; unused.
            **kwargs: accepted for backward compatibility; ignored.

        Returns:
            Tuple ``(observation, info)`` where ``info`` is an empty dict.
        """
        super().reset(seed=seed)

        self.agent_position = np.array(self.START_POSITION, dtype=np.int64)
        self.goal_position = self._sample_goal()

        return self._get_obs(), {}

    def step(self, action):
        """Move the agent by one cell and score the new position.

        Args:
            action: integer in ``[0, 3]``; see the class docstring for the mapping.

        Returns:
            Tuple ``(observation, reward, terminated, truncated, info)``.
            ``terminated`` is True only when the agent stands on the goal.
            ``truncated`` is always False because this environment enforces no
            time limit of its own.
        """
        assert action in range(4), f"invalid action {action!r}; expected an integer in [0, 3]"

        axis, delta = self._MOVES[int(action)]
        self.agent_position[axis] = np.clip(self.agent_position[axis] + delta, 0, self.GRID_MAX)

        terminated = bool(np.array_equal(self.agent_position, self.goal_position))
        if terminated:
            reward = 5.0
        else:
            distance = np.linalg.norm(self.agent_position - self.goal_position)
            reward = float(-0.5 - distance)

        return self._get_obs(), reward, terminated, False, {}


# Register the task under an id so it can be instantiated by name.
# Re-running this cell re-registers the same id, which only produces an
# "Overriding environment ... already in registry" warning.
gym.register(
    id="MyTask-v0",
    entry_point=MyTask,
    kwargs=dict(),
)

  logger.warn(f"Overriding environment {new_spec.id} already in registry.")


In [7]:
env = ap.make_task("MyTask-v0")

# Sanity check: run a few short episodes with random actions and print what
# the wrapped environment returns as observations.
for episode in range(5):
    # Every episode must begin with a reset; stepping an environment that has
    # already finished is undefined behaviour under the Gymnasium API.
    state, info = env.reset()

    for i in range(10):
        obs, reward, terminated, truncated, _ = env.step(env.action_space.sample())
        print(obs)

        # An episode ends on either flag, not only on termination.
        done = terminated or truncated
        if done:
            break

{'vision': None, 'touch': None, 'proprioception': array([0.00948204, 0.        , 0.0098039 , 0.00969854], dtype=float32), 'goal': None, 'asymmetric': None}
{'vision': None, 'touch': None, 'proprioception': array([ 0.00670149,  0.        ,  0.00693108, -0.99835116], dtype=float32), 'goal': None, 'asymmetric': None}
{'vision': None, 'touch': None, 'proprioception': array([0.00546902, 0.        , 0.00565812, 0.7066112 ], dtype=float32), 'goal': None, 'asymmetric': None}
{'vision': None, 'touch': None, 'proprioception': array([0.00473395, 0.        , 1.7290817 , 0.57699335], dtype=float32), 'goal': None, 'asymmetric': None}
{'vision': None, 'touch': None, 'proprioception': array([0.00423207, 0.        , 1.2234386 , 1.5806075 ], dtype=float32), 'goal': None, 'asymmetric': None}
{'vision': None, 'touch': None, 'proprioception': array([ 0.00386141,  0.        , -0.70614326,  1.2123952 ], dtype=float32), 'goal': None, 'asymmetric': None}
{'vision': None, 'touch': None, 'proprioception': array(

In [None]:
# Training budget. The drill log reports progress as "Cycle k/5", so the first
# argument is the number of cycles.
N_CYCLES = 5
# NOTE(review): presumably optimisation epochs per cycle and the batch size;
# confirm against the angorapy `Agent.drill` signature.
N_EPOCHS = 10
BATCH_SIZE = 512

model_builder = ap.models.get_model_builder("simple", "ffn")
agent = ap.Agent(model_builder, env)
agent.drill(N_CYCLES, N_EPOCHS, BATCH_SIZE)



Drill started using 1 processes for 8 workers of which 1 are optimizers. Worker distribution: [8].
IDs over Workers: [[0, 1, 2, 3, 4, 5, 6, 7]]
IDs over Optimizers: [[0, 1, 2, 3, 4, 5, 6, 7]]
Gathering cycle 0...

                                                                             

[92mBefore Training[0m; r: [91m-1665.67[0m; len: [94m  256.29[0m; n: [94m 24[0m; loss: [[94m  pi  [0m|[94m  v     [0m|[94m  ent [0m]; upd: [94m     0[0m; y.exp: [94m0.000[0m; ; time:  ; time left: [94munknown time[0m; took s [unknown time left]


                                                                

Gathering cycle 1...

                                                                             

[92mCycle     1/5[0m; r: [91m-1407.14[0m; len: [94m  256.00[0m; n: [94m  4[0m; loss: [[94m -0.03[0m|[94m    0.10[0m|[94m  1.37[0m]; upd: [94m   160[0m; ; time: [59.3|0.0|4.9] [92|0|8]; time left: [94m4.5mins[0m; took 66.86s [4.5mins left]


                                                                

Gathering cycle 2...

                                                                             

[92mCycle     2/5[0m; r: [91m -160.58[0m; len: [94m   55.33[0m; n: [94m  3[0m; loss: [[94m -0.01[0m|[94m    0.08[0m|[94m  1.32[0m]; upd: [94m   320[0m; ; time: [61.6|0.0|4.4] [93|0|7]; time left: [94m3.3mins[0m; took 63.74s [3.3mins left]


                                                                

Gathering cycle 3...

                                                                             

[92mCycle     3/5[0m; r: [91m  -18.06[0m; len: [94m   10.00[0m; n: [94m  1[0m; loss: [[94m -0.03[0m|[94m    0.05[0m|[94m  1.28[0m]; upd: [94m   480[0m; ; time: [59.0|0.0|4.1] [94|0|6]; time left: [94m2.2mins[0m; took 64.27s [2.2mins left]


                                                                

Gathering cycle 4...

                                                                             