# Demonstrating how to use tp_envutils on a simple grid example

## The environment:
- A 2D grid where the player can move up/down/left/right
- One block is food, one block is an enemy
- Eating the food gives a positive reward (+10), hitting the enemy gives a penalty (-10), and every other move costs -1 to encourage going straight for the food

## The RL - genetic evolution
1. 100 agents run through 10 episodes of this environment
2. At the end of each generation, the 10 best are kept
3. The remaining 90 are replaced by copies of those 10 (each parent is chosen with a probability based on how well it did relative to the others), and the copies are mutated slightly

## Scaling the environment
In this example, the environment is first set to a 4x4 grid and, after 15 generations, increased to 10x10. This is done to speed up learning, as starting out on a 10x10 grid takes longer. The principle still holds on the larger grid, because the agents only observe the relative offsets to the food and the enemy, so after increasing the size and running for 10 more generations, the agents perform well.

In [1]:
from tp_envutils import Env, Block, Agent, PopulationTakeTop
import numpy as np
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model
from tensorflow.keras.constraints import min_max_norm
from tqdm import tqdm

In [2]:
# Seed NumPy's global random generator so NumPy-based randomness repeats between runs.
np.random.seed(0)
# Grid width (X) and height (Y); get_start_state and display_env read these
# module-level names each time they are called.
X = Y = 4

In [3]:
def get_model():
    """Build and compile a fresh network for one agent.

    The network takes a 4-value float32 input, passes it through one
    8-unit tanh layer and ends in a 4-unit softmax layer.

    Returns:
        The compiled Keras ``Model``.
    """
    observation = Input(shape=(4,), dtype=np.float32)

    # NOTE(review): min_max_norm bounds the norm of each unit's incoming
    # weight vector, not the individual weights - confirm this is intended.
    features = Dense(
        8,
        activation="tanh",
        kernel_constraint=min_max_norm(-1., 1.),
    )(observation)
    action_probs = Dense(
        4,
        activation="softmax",
        kernel_constraint=min_max_norm(-1., 1.),
    )(features)

    network = Model(observation, action_probs)
    network.compile(optimizer='adam', loss='categorical_crossentropy')
    return network

### Defining the methods for the Env constructor
The whole point of the Env class is to get rid of boilerplate code, and provide hooks for all the parts where your learning process varies.

Here we provide only what's unique to this particular environment / agent / learning approach configuration

For example, if you wanted to use DQN, you could update your Q-values inside the Env, either in take_action or by defining call_every and doing it there

**To check which methods / parameters you have available, you can call the Env.help() static method**

In [4]:
def take_action(state,action):
    """Move the player block and score the position it ends up in.

    The first entry of ``state['blocks']`` is moved via its ``action``
    method, and ``state['obs']`` is shifted in place by the distance the
    block actually travelled, so the stored offsets stay relative to it.

    Args:
        state: dict with a ``'blocks'`` sequence and an ``'obs'`` array whose
            first row holds four offsets (get_start_state fills them as
            food-minus-player followed by enemy-minus-player).
        action: value forwarded unchanged to the player block's ``action``.

    Returns:
        A ``(state, reward, done)`` tuple: ``10``/``True`` when the first two
        offsets are both zero, ``-10``/``True`` when the last two are both
        zero, otherwise ``-1``/``False``.
    """
    player = state['blocks'][0]
    before_x, before_y = player.x, player.y
    player.action(action)

    # Moving the player by (+dx, +dy) shrinks every "target - player" offset
    # by the same amount, hence old-minus-new.
    shift_x = before_x - player.x
    shift_y = before_y - player.y
    state['obs'] += [shift_x, shift_y, shift_x, shift_y]

    if not np.any(state['obs'][0, :2]):
        return state, 10, True
    if not np.any(state['obs'][0, 2:]):
        return state, -10, True
    return state, -1, False

def get_observation(state):
    """Return the observation stored under the ``'obs'`` key of ``state``."""
    observation = state['obs']
    return observation
    
def get_start_state():
    """Create a start state in which player, enemy and food occupy distinct cells.

    Three blocks are created from the module-level grid size ``X``, ``Y``.
    The enemy is re-created until it no longer shares a cell with the food,
    then the player is re-created until it shares a cell with neither.

    Returns:
        A dict with
        ``'blocks'``: the tuple ``(player, enemy, food)``, and
        ``'obs'``: a one-row float64 array built from ``food - player``
        followed by ``enemy - player``.
    """
    # Creation order (player, enemy, food) is kept as in the original so any
    # random draws inside Block happen in the same sequence.
    p, e, f = Block(X, Y), Block(X, Y), Block(X, Y)

    while (e.x, e.y) == (f.x, f.y):
        e = Block(X, Y)
    while (p.x, p.y) in ((f.x, f.y), (e.x, e.y)):
        p = Block(X, Y)

    # np.float64 replaces the np.float alias, which was removed in NumPy 1.24.
    obs = np.asarray([(*(f - p), *(e - p))], dtype=np.float64)
    return {'blocks': (p, e, f), 'obs': obs}

def display_env(state):
    """Render the grid as a small uint8 colour image.

    Args:
        state: dict whose ``'blocks'`` entry is ``(player, enemy, food)``,
            each exposing integer ``x`` and ``y`` attributes.

    Returns:
        A ``(Y, X, 3)`` uint8 array, zero everywhere except one coloured
        cell per block: ``(255,0,0)`` for the player, ``(0,0,255)`` for the
        enemy and ``(0,255,0)`` for the food.
    """
    player, enemy, food = state['blocks'][0], state['blocks'][1], state['blocks'][2]

    # Cells are addressed as [y, x], so the row axis must have Y entries and
    # the column axis X entries; the original (X, Y, 3) shape only worked
    # because the grid is square.
    # NOTE(review): assumes Block(X, Y) keeps x < X and y < Y - confirm.
    env = np.zeros((Y, X, 3), dtype=np.uint8)
    env[enemy.y, enemy.x] = (0, 0, 255)
    env[food.y, food.x] = (0, 255, 0)
    # The player is drawn last so it stays visible if it shares a cell.
    env[player.y, player.x] = (255, 0, 0)
    return env

# Hand the environment-specific hooks defined above to the generic Env.
env = Env(
    # State lifecycle hooks.
    get_start_state=get_start_state,
    get_observation=get_observation,
    take_action=take_action,
    get_action=None,  # provided in the Agent class
    # Rendering hook.
    display_env=display_env,
    # Run settings.
    steps_per_ep=30,
    stat_every=10,
    printing=False,
)

In [5]:
TOTAL_AGENTS = 100
EPISODES_PER_GENERATION = 10

# Every agent gets its own freshly built network; all of them share `env`.
agents = [
    Agent(get_model(), env)
    for _ in tqdm(range(TOTAL_AGENTS), ascii=True, unit='agents created')
]

pop = PopulationTakeTop(
    agents,
    ep_per_gen=EPISODES_PER_GENERATION,
    take_top=10,
    mutation_chance=0.15,
)

100%|##########| 100/100 [00:07<00:00, 13.65agents created/s]


First running on the 4x4 grid for 15 generations

In [6]:
# Evolve the population for 15 generations on the initial grid size.
for _generation in range(15):
    pop.evolve()

100%|##########| 100/100 [01:18<00:00,  1.20agents/s]


Generation 1, best results:
{'rewards': [-13, -30, 10, -11, 10, 9, -30, 10, -30, -30], 'sumrewards': -105, 'aggr': {'avg': [-10.5], 'min': [-30], 'max': [10]}}
Generation 1, averages of top 10 agents:
[[-10.5], [-12.2], [-12.3], [-12.4], [-12.6], [-14.2], [-14.4], [-14.4], [-14.5], [-14.6]]


100%|##########| 100/100 [00:13<00:00,  6.95agents/s]


Generation 2, best results:
{'rewards': [-10, -30, -30, 9, 9, -11, 9, 10, 9, -30], 'sumrewards': -65, 'aggr': {'avg': [-6.5], 'min': [-30], 'max': [10]}}
Generation 2, averages of top 10 agents:
[[-6.5], [-6.7], [-6.9], [-10.5], [-12.1], [-12.3], [-12.3], [-12.5], [-12.6], [-12.7]]


100%|##########| 100/100 [00:12<00:00,  7.79agents/s]


Generation 3, best results:
{'rewards': [-10, 10, 9, 10, -30, -11, 9, 10, 10, 7], 'sumrewards': 14, 'aggr': {'avg': [1.4], 'min': [-30], 'max': [10]}}
Generation 3, averages of top 10 agents:
[[1.4], [-1.3], [-6.5], [-6.7], [-8.1], [-8.2], [-8.7], [-10.3], [-10.4], [-10.4]]


100%|##########| 100/100 [00:11<00:00,  8.01agents/s]


Generation 4, best results:
{'rewards': [9, 10, 9, -30, -30, 9, 9, -30, 10, 10], 'sumrewards': -24, 'aggr': {'avg': [-2.4], 'min': [-30], 'max': [10]}}
Generation 4, averages of top 10 agents:
[[-2.4], [-4.6], [-4.6], [-5.7], [-6.6], [-8.4], [-8.5], [-8.7], [-9.1], [-10.2]]


100%|##########| 100/100 [00:12<00:00,  6.44agents/s]


Generation 5, best results:
{'rewards': [10, 9, -30, -30, 9, 10, 9, -30, 9, 9], 'sumrewards': -25, 'aggr': {'avg': [-2.5], 'min': [-30], 'max': [10]}}
Generation 5, averages of top 10 agents:
[[-2.5], [-2.9], [-2.9], [-5.5], [-6.4], [-6.7], [-8.2], [-8.3], [-8.5], [-8.7]]


100%|##########| 100/100 [00:10<00:00,  7.41agents/s]


Generation 6, best results:
{'rewards': [9, 9, 7, 10, 8, 9, 7, -30, 9, -30], 'sumrewards': 8, 'aggr': {'avg': [0.8], 'min': [-30], 'max': [10]}}
Generation 6, averages of top 10 agents:
[[0.8], [0.6], [-1.3], [-2.6], [-3.5], [-4.4], [-4.5], [-5.2], [-5.4], [-5.6]]


100%|##########| 100/100 [00:10<00:00,  9.68agents/s]


Generation 7, best results:
{'rewards': [10, 9, 9, -11, -10, 7, 7, -30, 10, 9], 'sumrewards': 10, 'aggr': {'avg': [1.0], 'min': [-30], 'max': [10]}}
Generation 7, averages of top 10 agents:
[[1.0], [-0.9], [-1.3], [-2.7], [-2.7], [-2.8], [-4.4], [-4.9], [-5.0], [-6.6]]


100%|##########| 100/100 [00:10<00:00, 10.08agents/s]


Generation 8, best results:
{'rewards': [7, 9, 10, 8, 7, 8, 6, -30, 9, 9], 'sumrewards': 43, 'aggr': {'avg': [4.3], 'min': [-30], 'max': [10]}}
Generation 8, averages of top 10 agents:
[[4.3], [1.0], [0.9], [0.3], [-0.9], [-1.0], [-3.0], [-3.0], [-3.1], [-3.1]]


100%|##########| 100/100 [00:09<00:00, 10.41agents/s]


Generation 9, best results:
{'rewards': [9, 9, 5, 8, 6, -10, 7, 9, 9, 10], 'sumrewards': 62, 'aggr': {'avg': [6.2], 'min': [-10], 'max': [10]}}
Generation 9, averages of top 10 agents:
[[6.2], [5.2], [2.7], [0.7], [0.6], [-0.5], [-0.8], [-1.2], [-1.4], [-1.6]]


100%|##########| 100/100 [00:09<00:00, 10.19agents/s]


Generation 10, best results:
{'rewards': [10, 8, 8, 9, 9, 8, 10, -30, 9, 9], 'sumrewards': 50, 'aggr': {'avg': [5.0], 'min': [-30], 'max': [10]}}
Generation 10, averages of top 10 agents:
[[5.0], [2.9], [2.4], [0.6], [0.5], [0.5], [0.3], [-0.7], [-0.8], [-1.1]]


100%|##########| 100/100 [00:08<00:00, 10.59agents/s]


Generation 11, best results:
{'rewards': [-10, 9, 9, 7, 8, 6, 10, 8, 8, 8], 'sumrewards': 63, 'aggr': {'avg': [6.3], 'min': [-10], 'max': [10]}}
Generation 11, averages of top 10 agents:
[[6.3], [5.2], [4.7], [4.7], [4.4], [3.9], [3.2], [3.1], [3.0], [2.7]]


100%|##########| 100/100 [00:08<00:00, 11.95agents/s]


Generation 12, best results:
{'rewards': [9, 10, 10, 6, 8, 9, 9, 8, 8, 9], 'sumrewards': 86, 'aggr': {'avg': [8.6], 'min': [6], 'max': [10]}}
Generation 12, averages of top 10 agents:
[[8.6], [6.4], [6.4], [4.7], [4.4], [3.0], [2.9], [2.7], [2.4], [1.5]]


100%|##########| 100/100 [00:07<00:00, 14.29agents/s]


Generation 13, best results:
{'rewards': [9, -12, 6, 9, 8, 7, 9, 6, -10, 10], 'sumrewards': 42, 'aggr': {'avg': [4.2], 'min': [-12], 'max': [10]}}
Generation 13, averages of top 10 agents:
[[4.2], [4.1], [3.6], [2.9], [2.8], [2.8], [2.7], [1.4], [1.4], [1.1]]


100%|##########| 100/100 [00:07<00:00, 14.06agents/s]


Generation 14, best results:
{'rewards': [10, 10, 7, 9, 9, 9, 9, 7, 10, 8], 'sumrewards': 88, 'aggr': {'avg': [8.8], 'min': [7], 'max': [10]}}
Generation 14, averages of top 10 agents:
[[8.8], [5.1], [4.8], [4.6], [3.3], [3.2], [2.9], [2.9], [2.6], [2.6]]


100%|##########| 100/100 [00:06<00:00, 15.33agents/s]


Generation 15, best results:
{'rewards': [10, 7, 9, 10, 10, 8, 8, 9, 8, 7], 'sumrewards': 86, 'aggr': {'avg': [8.6], 'min': [7], 'max': [10]}}
Generation 15, averages of top 10 agents:
[[8.6], [6.6], [6.5], [6.5], [6.3], [5.3], [5.1], [5.0], [5.0], [4.9]]


Then change to 10x10 and run for 10 more generations

In [7]:
# Rebinding the module-level grid size is enough to enlarge the environment:
# get_start_state and display_env look X and Y up on every call.
X = Y = 10
for _generation in range(10):
    pop.evolve()

100%|##########| 100/100 [00:13<00:00,  7.18agents/s]


Generation 16, best results:
{'rewards': [-30, 7, -13, 6, 7, -30, 5, 6, 7, 9], 'sumrewards': -26, 'aggr': {'avg': [-2.6], 'min': [-30], 'max': [9]}}
Generation 16, averages of top 10 agents:
[[-2.6], [-2.7], [-3.0], [-6.0], [-7.0], [-7.4], [-9.3], [-10.1], [-10.5], [-11.3]]


100%|##########| 100/100 [00:14<00:00,  7.39agents/s]


Generation 17, best results:
{'rewards': [4, -30, 3, 8, 7, -30, 4, -30, 4, 2], 'sumrewards': -58, 'aggr': {'avg': [-5.8], 'min': [-30], 'max': [8]}}
Generation 17, averages of top 10 agents:
[[-5.8], [-5.8], [-8.2], [-8.8], [-9.3], [-9.7], [-9.7], [-10.8], [-11.5], [-11.7]]


100%|##########| 100/100 [00:13<00:00,  6.54agents/s]


Generation 18, best results:
{'rewards': [-30, 8, 3, 7, -3, -10, 10, 7, 6, -4], 'sumrewards': -6, 'aggr': {'avg': [-0.6], 'min': [-30], 'max': [10]}}
Generation 18, averages of top 10 agents:
[[-0.6], [-4.2], [-4.4], [-5.7], [-6.3], [-7.4], [-8.0], [-8.7], [-8.7], [-8.9]]


100%|##########| 100/100 [00:13<00:00,  7.69agents/s]


Generation 19, best results:
{'rewards': [9, 7, -10, 2, 8, 2, 6, 8, 3, -30], 'sumrewards': 5, 'aggr': {'avg': [0.5], 'min': [-30], 'max': [9]}}
Generation 19, averages of top 10 agents:
[[0.5], [-1.2], [-3.5], [-5.3], [-5.6], [-6.3], [-6.9], [-7.3], [-8.3], [-9.4]]


100%|##########| 100/100 [00:12<00:00,  7.20agents/s]


Generation 20, best results:
{'rewards': [3, 8, 7, 6, 10, -15, 4, -1, 4, 6], 'sumrewards': 32, 'aggr': {'avg': [3.2], 'min': [-15], 'max': [10]}}
Generation 20, averages of top 10 agents:
[[3.2], [-4.1], [-4.2], [-4.5], [-4.6], [-4.8], [-5.1], [-5.2], [-5.5], [-5.7]]


100%|##########| 100/100 [00:12<00:00,  8.27agents/s]


Generation 21, best results:
{'rewards': [-30, 5, 4, 8, -30, 10, 6, 0, 3, 2], 'sumrewards': -22, 'aggr': {'avg': [-2.2], 'min': [-30], 'max': [10]}}
Generation 21, averages of top 10 agents:
[[-2.2], [-2.3], [-5.3], [-5.7], [-5.9], [-6.0], [-6.1], [-6.2], [-6.3], [-6.5]]


100%|##########| 100/100 [00:12<00:00,  9.12agents/s]


Generation 22, best results:
{'rewards': [4, 2, 4, 2, 6, 6, -1, -1, 8, 5], 'sumrewards': 35, 'aggr': {'avg': [3.5], 'min': [-1], 'max': [8]}}
Generation 22, averages of top 10 agents:
[[3.5], [1.2], [-0.3], [-2.2], [-3.7], [-3.7], [-4.0], [-4.4], [-5.7], [-6.2]]


100%|##########| 100/100 [00:11<00:00,  8.55agents/s]


Generation 23, best results:
{'rewards': [4, -12, 10, 1, 7, -11, 2, -30, 9, 1], 'sumrewards': -19, 'aggr': {'avg': [-1.9], 'min': [-30], 'max': [10]}}
Generation 23, averages of top 10 agents:
[[-1.9], [-2.7], [-3.0], [-3.2], [-3.6], [-4.3], [-4.6], [-4.6], [-4.9], [-5.3]]


100%|##########| 100/100 [00:11<00:00,  9.28agents/s]


Generation 24, best results:
{'rewards': [10, 5, 6, -30, 8, 5, 4, 10, 4, 2], 'sumrewards': 24, 'aggr': {'avg': [2.4], 'min': [-30], 'max': [10]}}
Generation 24, averages of top 10 agents:
[[2.4], [1.3], [1.1], [1.0], [-0.5], [-1.8], [-1.9], [-2.2], [-2.4], [-2.6]]


100%|##########| 100/100 [00:10<00:00,  9.53agents/s]


Generation 25, best results:
{'rewards': [6, 5, 7, 0, 2, -14, 7, 8, 6, 7], 'sumrewards': 34, 'aggr': {'avg': [3.4], 'min': [-14], 'max': [8]}}
Generation 25, averages of top 10 agents:
[[3.4], [3.1], [2.7], [2.2], [1.1], [-0.4], [-1.4], [-1.5], [-1.5], [-1.6]]
