In [1]:
import gym
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
import torch

from numpy.random import default_rng
# Fixed seed so that a fresh "Restart & Run All" draws the same perturbations
# and environment seeds from `rng`. This seeds only this NumPy generator;
# torch's own RNG is not seeded here.
SEED = 0
rng = default_rng(SEED)

import seagul.envs
from seagul.nn import MLP
from seagul.rl.ars import ARSAgent

%matplotlib tk

%load_ext autoreload
%autoreload 2



In [2]:
# Environment used by the hand-rolled rollouts in the cells below.
env_id = 'tree_simple-v0'
env = gym.make(env_id)

pygame 2.0.1 (SDL 2.0.14, Python 3.7.4)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
def visualize_net(net, channel=0):
    """Plot one softmax output channel of ``net`` as a 3-D surface.

    The network is evaluated point by point on a fixed grid: 200 values of x in
    [0, 10] and 100 values of y in [-1, 2]. Each grid point is passed to
    ``net`` as a float32 tensor ``[x, y]``; the result is pushed through a
    softmax and entry ``channel`` is stored as the surface height.

    Args:
        net: callable taking a 1-D float32 tensor of length 2 and returning a
            tensor that can be indexed with ``channel`` after the softmax.
        channel: index of the softmax output to plot.

    Returns:
        None. A new matplotlib figure with a 3-D axis is created as a side
        effect and becomes the current figure.
    """
    x = np.linspace(0, 10, 200)
    y = np.linspace(-1, 2, 100)

    C = np.zeros([x.shape[0], y.shape[0]])

    # Give the softmax axis explicitly; relying on the implicit dimension is
    # deprecated in torch and emits a warning on every call.
    softmax = torch.nn.Softmax(dim=-1)

    # Only the values are needed for plotting, so skip autograd bookkeeping.
    with torch.no_grad():
        for i in range(C.shape[0]):
            for j in range(C.shape[1]):
                point = torch.tensor([x[i], y[j]], dtype=torch.float32)
                C[i, j] = softmax(net(point))[channel].item()

    # indexing='ij' keeps X/Y aligned with C, whose first axis follows x.
    X, Y = np.meshgrid(x, y, indexing='ij')

    fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

    # Face colours follow the height, normalised by the largest value in C.
    ax.plot_surface(X, Y, C, facecolors=cm.Spectral(C / np.amax(C)), alpha=.5)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('C')


def do_rollout_train(env, policy, seed=None):
    """Run one episode of ``env`` under ``policy`` and record what was observed.

    Args:
        env: object exposing ``seed(seed)``, ``reset() -> obs`` and
            ``step(action) -> (obs, reward, done, info)``.
        policy: callable mapping an observation to an action.
        seed: forwarded to ``env.seed``. ``None`` is passed through unchanged;
            any other value is converted with ``int()`` first.

    Returns:
        A tuple ``(state_arr, reward_sum)``. ``state_arr`` stacks, along axis 0,
        a copy of the observation seen before each action. ``reward_sum`` is
        the sum of all step rewards as a plain Python scalar.
    """
    state_list = []
    reward_list = []

    # Seeds drawn from a NumPy generator are NumPy integers, not Python ints.
    # NOTE(review): older gym seeding helpers reject non-`int` seeds; confirm
    # against the installed gym version. Converting is harmless either way.
    env.seed(None if seed is None else int(seed))
    obs = env.reset()
    done = False
    while not done:
        # Copy so that an env which reuses its observation buffer cannot
        # overwrite states that were already recorded.
        state_list.append(np.copy(obs))
        action = policy(obs)
        obs, reward, done, _ = env.step(action)
        reward_list.append(reward)

    state_arr = np.stack(state_list)
    reward_sum = np.sum(reward_list).item()

    return state_arr, reward_sum

In [4]:
# Leading dimension of the observation and action spaces.
obs_size, act_size = env.observation_space.shape[0], env.action_space.shape[0]

# Two fixed linear weight matrices: all zeros except the first row,
# which is +5 in W1 and -5 in W2.
W1 = np.zeros((obs_size, act_size))
W1[0, :] = 5
W2 = np.zeros_like(W1)
W2[0, :] = -5

n_param = W1.size   # number of entries in one weight matrix
n_delta = 2048      # number of random perturbation directions per matrix
exp_noise = 0.025   # scale applied to each perturbation direction

In [5]:
def _rollout_perturbations(env, base_W, deltas, noise, seeds):
    """Roll out every +/- perturbed copy of ``base_W``.

    Args:
        env: environment handed to ``do_rollout_train``.
        base_W: weight matrix to perturb; it is flattened before use.
        deltas: array of shape (n_directions, base_W.size) of perturbation
            directions.
        noise: scalar multiplied onto ``deltas``.
        seeds: one environment seed per perturbed weight vector
            (2 * n_directions entries).

    Returns:
        ``(perturbed, returns, observations)``. ``perturbed`` holds the
        ``+noise*delta`` rows followed by the ``-noise*delta`` rows; the other
        two are lists with one entry per row, in the same order.
    """
    flat_W = base_W.flatten()
    scaled = deltas * noise
    perturbed = np.concatenate((flat_W + scaled, flat_W - scaled))

    returns = []
    observations = []
    for Ws, seed in zip(perturbed, seeds):
        # Ws is a flat 1-D vector, so `.T` is a no-op and `Ws.T @ x` is a dot
        # product. Binding Ws as a default keeps the lambda tied to this row.
        policy = lambda x, Ws=Ws: Ws.T @ x
        o, R = do_rollout_train(env, policy, seed)
        returns.append(R)
        observations.append(o)

    return perturbed, returns, observations


# Draw order is kept as: directions for W1, env seeds, directions for W2.
deltas = rng.standard_normal((n_delta, n_param))
# One seed per perturbed weight vector. The same seeds are reused for W2 so
# that rollout i of each candidate starts from the same env seed.
seeds = rng.integers(2**32, size=(2 * n_delta))
pm_W1, R1, O1 = _rollout_perturbations(env, W1, deltas, exp_noise, seeds)

deltas = rng.standard_normal((n_delta, n_param))
pm_W2, R2, O2 = _rollout_perturbations(env, W2, deltas, exp_noise, seeds)


In [6]:
Xtrain = []
Ytrain = []

# R1/O1 and R2/O2 are index-aligned: entry i of each comes from a rollout that
# used seeds[i]. For every pair, keep the observations of whichever rollout
# earned the larger return and label each kept observation with the winner:
# 1 when the W1-based rollout was strictly better, 0 otherwise (ties -> 0).
for r1, r2, obs1, obs2 in zip(R1, R2, O1, O2):
    if r1 > r2:
        winner_obs, label = obs1, 1
    else:
        winner_obs, label = obs2, 0

    # extend() adds one entry per observation row, without building throwaway
    # lists just for their append side effects.
    Xtrain.extend(winner_obs)
    Ytrain.extend([label] * len(winner_obs))
        

In [7]:
# Observations as a float32 matrix, one row per recorded observation.
Xtrain = np.array(Xtrain, dtype=np.float32)
# Class labels pinned to int64 rather than the platform default integer:
# torch's CrossEntropyLoss expects long class-index targets.
# NOTE(review): assumes fit_model converts these arrays to tensors without
# changing dtype; confirm against seagul.nn.fit_model.
Ytrain = np.array(Ytrain, dtype=np.int64)

In [8]:
from seagul.nn import fit_model

# Two-output classifier over raw observations.
# NOTE(review): the trailing (2, 16) arguments are presumably the hidden-layer
# count and width; confirm against seagul.nn.MLP.
n_obs = env.observation_space.shape[0]
classifier = MLP(n_obs, 2, 2, 16)

# Output surface of the untrained network.
visualize_net(classifier)
plt.title("Before")

# Fit on the labelled observations; the per-step losses come back for plotting.
# NOTE(review): the positional 50 is presumably the epoch count; confirm.
loss_fn = torch.nn.CrossEntropyLoss()
loss_hist = fit_model(classifier, Xtrain, Ytrain, 50, batch_size=2048, loss_fn=loss_fn)

# Same surface after fitting.
visualize_net(classifier)
plt.title("After")

# Loss curve in its own figure.
plt.figure()
plt.plot(loss_hist)

100%|██████████| 50/50 [00:07<00:00,  6.69it/s]


[<matplotlib.lines.Line2D at 0x7f1bb0a87f50>]

In [15]:
from seagul.nn import MLP
import seagul.envs

import gym
import copy
from seagul.rl.ars.ars_switching import ARSSwitchingAgent, ARSSwitchingModel
from seagul.rl.ars.ars_torch import ARSTorchModel


env_name = "tree_simple-v0"
env = gym.make(env_name)

obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]

# Policy network with input_bias enabled and zeros for the two middle arguments.
policy = MLP(obs_dim, act_dim, 0, 0, input_bias=True)
# NOTE(review): this rebinds `classifier` to a freshly constructed network, so
# a classifier fitted in an earlier cell is not the one handed to the agent.
# Confirm that starting from scratch is intended.
classifier = MLP(obs_dim, 2, 2, 16)

# Two independent deep copies of the same starting model.
model = ARSTorchModel(policy)
model_list = [copy.deepcopy(model), copy.deepcopy(model)]

switching_agent = ARSSwitchingAgent(env_name, model_list, classifier, n_delta=64, n_top=64)

# Print the first model's parameters before and after learning.
print(switching_agent.model_list[0].policy.state_dict())
switching_agent.learn(50)
print(switching_agent.model_list[0].policy.state_dict())




OrderedDict([('input_bias', tensor([0., 0.])), ('output_layer.weight', tensor([[ 0.3092, -0.1770]])), ('output_layer.bias', tensor([-0.0574]))])


100%|██████████| 5/5 [00:01<00:00,  3.15it/s]
100%|██████████| 5/5 [00:01<00:00,  3.28it/s]
100%|██████████| 5/5 [00:01<00:00,  3.47it/s]
100%|██████████| 5/5 [00:01<00:00,  3.09it/s]
100%|██████████| 5/5 [00:01<00:00,  3.13it/s]
100%|██████████| 5/5 [00:01<00:00,  3.54it/s]
100%|██████████| 5/5 [00:01<00:00,  3.74it/s]
100%|██████████| 5/5 [00:01<00:00,  3.48it/s]
100%|██████████| 5/5 [00:01<00:00,  3.53it/s]
100%|██████████| 5/5 [00:01<00:00,  3.70it/s]
100%|██████████| 5/5 [00:01<00:00,  3.43it/s]
100%|██████████| 5/5 [00:01<00:00,  3.11it/s]
100%|██████████| 5/5 [00:01<00:00,  3.57it/s]
100%|██████████| 5/5 [00:01<00:00,  3.77it/s]
100%|██████████| 5/5 [00:01<00:00,  3.48it/s]
100%|██████████| 5/5 [00:01<00:00,  3.47it/s]
100%|██████████| 5/5 [00:01<00:00,  3.14it/s]
100%|██████████| 5/5 [00:01<00:00,  3.35it/s]
100%|██████████| 5/5 [00:01<00:00,  3.81it/s]
100%|██████████| 5/5 [00:01<00:00,  3.12it/s]
100%|██████████| 5/5 [00:01<00:00,  3.31it/s]
100%|██████████| 5/5 [00:01<00:00,

OrderedDict([('input_bias', tensor([0., 0.])), ('output_layer.weight', tensor([[ 0.3092, -0.1770]])), ('output_layer.bias', tensor([-0.0574]))])
OrderedDict([('input_bias', tensor([0.0061, 0.0020], dtype=torch.float64)), ('output_layer.weight', tensor([[ 0.3569, -0.1665]], dtype=torch.float64)), ('output_layer.bias', tensor([-0.0455], dtype=torch.float64))])
OrderedDict([('input_bias', tensor([0.0061, 0.0020], dtype=torch.float64)), ('output_layer.weight', tensor([[ 0.3569, -0.1665]], dtype=torch.float64)), ('output_layer.bias', tensor([-0.0455], dtype=torch.float64))])





In [10]:
# Show the parameters of both models held by the agent.
for model_idx in (0, 1):
    print(switching_agent.model_list[model_idx].policy.state_dict())

OrderedDict([('input_bias', tensor([-0.0024,  0.0016], dtype=torch.float64)), ('output_layer.weight', tensor([[ 0.0278, -0.1380]], dtype=torch.float64)), ('output_layer.bias', tensor([0.3035], dtype=torch.float64))])
OrderedDict([('input_bias', tensor([-0.0012,  0.0035], dtype=torch.float64)), ('output_layer.weight', tensor([[ 0.0255, -0.1274]], dtype=torch.float64)), ('output_layer.bias', tensor([0.3054], dtype=torch.float64))])


In [None]:
# One surface plot per classifier output channel.
for out_channel in (0, 1):
    visualize_net(switching_agent.classifier, channel=out_channel)