In [1]:
import pickle
import random
from collections import namedtuple

import gym
import numpy as np
import ptan
import torch
import torch.optim as optim
from ptan.agent import float32_preprocessor

from util import PGN, RewardNet

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
from matplotlib import pyplot as plt

In [2]:
GAMMA = 0.99
LEARNING_RATE = 0.01
EPISODES_TO_TRAIN = 4
DEMO_BATCH = 50
seed = 0
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)

EpisodeStep = namedtuple('EpisodeStep', field_names=['state', 'action', 'reward', 'next_state'])
Trajectory = namedtuple('Trajectory', field_names=['prob', 'episode_steps'])

In [3]:
def calc_qvals(rewards):
    res = []
    sum_r = 0.0
    for r in reversed(rewards):
        sum_r *= GAMMA
        sum_r += r
        res.append(sum_r)
    return list(reversed(res))


def process_demonstrations(demo_samples):
    traj_states, traj_actions, traj_qvals, traj_prob = [], [], [], []
    for traj in demo_samples:
        states, actions, rewards, qvals = [], [], [], []
        traj_prob.append(traj.prob)
        for step in traj.episode_steps:
            states.append(step.state)
            actions.append(step.action)
            rewards.append(step.reward)
        qvals.extend(calc_qvals(rewards))

        traj_states.append(states)
        traj_actions.append(actions)
        traj_qvals.append(qvals)
    traj_states = np.array(traj_states, dtype=np.object)
    traj_actions = np.array(traj_actions, dtype=np.object)
    traj_qvals = np.array(traj_qvals, dtype=np.object)
    traj_prob = np.array(traj_prob, dtype=np.float)
    return {'states': traj_states, 'actions': traj_actions, 'qvals': traj_qvals, 'traj_probs': traj_prob}

In [4]:
env = gym.make('CartPole-v1')
agent_net = PGN(env.observation_space.shape[0], env.action_space.n)
reward_net = RewardNet(env.observation_space.shape[0] + 1)
agent = ptan.agent.PolicyAgent(agent_net, preprocessor=float32_preprocessor, apply_softmax=True)
exp_source = ptan.experience.ExperienceSourceFirstLast(env, agent, gamma=GAMMA)
optimizer_agent = optim.Adam(agent_net.parameters(), lr=LEARNING_RATE)
optimizer_reward = optim.Adam(reward_net.parameters(), lr=1e-2, weight_decay=1e-4)

In [5]:
with open('demonstrations.list.pkl', 'rb') as f:
    demonstrations = pickle.load(f)
assert (len(demonstrations) > DEMO_BATCH)
print(f'Number of demonstrations: {len(demonstrations)}')
demonstrations = process_demonstrations(demonstrations)

Number of demonstrations: 100


In [6]:
total_rewards = []
step_idx = 0
done_episodes = 0

batch_episodes = 0
batch_states, batch_actions, batch_qvals = [], [], []
cur_rewards = []
loss_rwd = 0.

for step_idx, exp in enumerate(exp_source):
    batch_states.append(exp.state)
    batch_actions.append(int(exp.action))
    x = torch.cat([float32_preprocessor(exp.state), float32_preprocessor([int(exp.action)])]).view(1, -1)
    reward = reward_net(x)
    cur_rewards.append(reward.item())

    if exp.last_state is None:
        batch_qvals.extend(calc_qvals(cur_rewards))
        cur_rewards.clear()
        batch_episodes += 1

    new_rewards = exp_source.pop_total_rewards()
    if new_rewards:
        done_episodes += 1
        reward = new_rewards[0]
        total_rewards.append(reward)
        mean_rewards = float(np.mean(total_rewards[-100:]))
        writer.add_scalar('reward', reward, done_episodes)
        writer.add_scalar('mean_reward', mean_rewards, done_episodes)
        writer.add_scalar('loss_rwd', loss_rwd, done_episodes) 
        print(f'{step_idx}: reward: {reward:6.2f}, mean_100: {mean_rewards:6.2f}, '
              f'episodes: {done_episodes}, reward function loss: {loss_rwd:6.4f}')
        
        if done_episodes%100 == 0 or mean_rewards>=500:
            N_samp = 20
            S1 = np.linspace(-5, 5, N_samp)
            S2 = np.linspace(-3.1457, 3.1457, N_samp)
            S3 = 0*S1
            S4 = 0*S1
            
            S5 = np.ones(N_samp)
            
            for action in [0,1]:
                Reward = np.zeros((N_samp,N_samp))
                for i in range(N_samp):
                    for j in range(N_samp):
                        state = [S1[i], S2[j], 0, 0]
#                         action = 1
                        x = torch.cat([float32_preprocessor(state), float32_preprocessor([int(action)])]).view(1, -1)
                        r = reward_net(x)
                        Reward[i,j] = r

                X, Y = np.meshgrid(S1, S2)
                Z = Reward
                fig = plt.figure()
                ax = plt.axes(projection='3d')
                ax.plot_surface(X, Y, Z, rstride=1, cstride=1,
                                cmap='viridis', edgecolor='none')
                ax.set_title('surface');
                ax.set_xlabel('x')
                ax.set_ylabel('y')
                ax.set_zlabel('z');
                ax.view_init(azim=90, elev=90)
                if action ==0:
                    writer.add_figure('x1 vs x2 with a = 0', fig, global_step=done_episodes/100)
                if action ==1:
                    writer.add_figure('x1 vs x2 with a = 1', fig, global_step=done_episodes/100)    
            
        if mean_rewards >= 500:
            print(f'Solved in {step_idx} steps and {done_episodes} episodes!')
            torch.save(agent_net.state_dict(), 'cartpole_learner.mod')
            torch.save(reward_net.state_dict(), 'cartpole-v1_reward_func.mod')
            break
        

    if batch_episodes < EPISODES_TO_TRAIN:
        continue

    states_v = torch.FloatTensor(batch_states)
    batch_actions_t = torch.LongTensor(batch_actions)
    batch_qvals_v = torch.FloatTensor(batch_qvals)

    # reward function learning
    demo_states = demonstrations['states']
    demo_actions = demonstrations['actions']
    demo_probs = demonstrations['traj_probs']
    for rf_i in range(10):
        selected = np.random.choice(len(demonstrations), DEMO_BATCH)
        demo_states = demo_states[selected]
        demo_actions = demo_actions[selected]
        demo_probs = demo_probs[selected]
        demo_batch_states, demo_batch_actions = [], []
        for idx in range(len(demo_states)):
            demo_batch_states.extend(demo_states[idx])
            demo_batch_actions.extend(demo_actions[idx])
        demo_batch_states = torch.FloatTensor(demo_batch_states)
        demo_batch_actions = torch.FloatTensor(demo_batch_actions)
        D_demo = torch.cat([demo_batch_states, demo_batch_actions.view(-1, 1)], dim=-1)
        D_samp = torch.cat([states_v, batch_actions_t.float().view(-1, 1)], dim=-1)
        D_samp = torch.cat([D_demo, D_samp])
        # dummy importance weights - fix later
        z = torch.ones((D_samp.shape[0], 1))

        # objective
        D_demo_out = reward_net(D_demo)
        D_samp_out = reward_net(D_samp)
        D_samp_out = z * torch.exp(D_samp_out)
        loss_rwd = torch.mean(D_demo_out) - torch.log(torch.mean(D_samp_out))
        loss_rwd = -loss_rwd  # for maximization

        # update parameters
        optimizer_reward.zero_grad()
        loss_rwd.backward()
        optimizer_reward.step()

    # agent
    optimizer_agent.zero_grad()
    logits_v = agent_net(states_v)
    log_prob_v = torch.log_softmax(logits_v, dim=1)
    # REINFORCE
    log_prob_actions_v = batch_qvals_v * log_prob_v[range(len(batch_states)), batch_actions_t]
    loss_v = -log_prob_actions_v.mean()

    loss_v.backward()
    optimizer_agent.step()

    batch_episodes = 0
    batch_states.clear()
    batch_actions.clear()
    batch_qvals.clear()
env.close()
writer.close()

12: reward:  12.00, mean_100:  12.00, episodes: 1, reward function loss: 0.0000
29: reward:  17.00, mean_100:  14.50, episodes: 2, reward function loss: 0.0000
60: reward:  31.00, mean_100:  20.00, episodes: 3, reward function loss: 0.0000
73: reward:  13.00, mean_100:  18.25, episodes: 4, reward function loss: 0.0001
104: reward:  31.00, mean_100:  20.80, episodes: 5, reward function loss: 0.0001
126: reward:  22.00, mean_100:  21.00, episodes: 6, reward function loss: 0.0001
137: reward:  11.00, mean_100:  19.57, episodes: 7, reward function loss: 0.0001
151: reward:  14.00, mean_100:  18.88, episodes: 8, reward function loss: -0.0002
172: reward:  21.00, mean_100:  19.11, episodes: 9, reward function loss: -0.0002
210: reward:  38.00, mean_100:  21.00, episodes: 10, reward function loss: -0.0002
269: reward:  59.00, mean_100:  24.45, episodes: 11, reward function loss: -0.0002
300: reward:  31.00, mean_100:  25.00, episodes: 12, reward function loss: -0.0003
316: reward:  16.00, mea

4734: reward:  41.00, mean_100:  47.34, episodes: 100, reward function loss: -0.0014
4762: reward:  28.00, mean_100:  47.50, episodes: 101, reward function loss: -0.0014
4796: reward:  34.00, mean_100:  47.67, episodes: 102, reward function loss: -0.0014
4858: reward:  62.00, mean_100:  47.98, episodes: 103, reward function loss: -0.0014
4926: reward:  68.00, mean_100:  48.53, episodes: 104, reward function loss: -0.0029
4953: reward:  27.00, mean_100:  48.49, episodes: 105, reward function loss: -0.0029
4982: reward:  29.00, mean_100:  48.56, episodes: 106, reward function loss: -0.0029
5013: reward:  31.00, mean_100:  48.76, episodes: 107, reward function loss: -0.0029
5083: reward:  70.00, mean_100:  49.32, episodes: 108, reward function loss: -0.0029
5145: reward:  62.00, mean_100:  49.73, episodes: 109, reward function loss: -0.0029
5167: reward:  22.00, mean_100:  49.57, episodes: 110, reward function loss: -0.0029
5206: reward:  39.00, mean_100:  49.37, episodes: 111, reward fun

12437: reward:  68.00, mean_100:  77.03, episodes: 200, reward function loss: -0.0047
12488: reward:  51.00, mean_100:  77.26, episodes: 201, reward function loss: -0.0047
12547: reward:  59.00, mean_100:  77.51, episodes: 202, reward function loss: -0.0047
12618: reward:  71.00, mean_100:  77.60, episodes: 203, reward function loss: -0.0047
12676: reward:  58.00, mean_100:  77.50, episodes: 204, reward function loss: -0.0053
12741: reward:  65.00, mean_100:  77.88, episodes: 205, reward function loss: -0.0053
12812: reward:  71.00, mean_100:  78.30, episodes: 206, reward function loss: -0.0053
12882: reward:  70.00, mean_100:  78.69, episodes: 207, reward function loss: -0.0053
12947: reward:  65.00, mean_100:  78.64, episodes: 208, reward function loss: -0.0059
13015: reward:  68.00, mean_100:  78.70, episodes: 209, reward function loss: -0.0059
13062: reward:  47.00, mean_100:  78.95, episodes: 210, reward function loss: -0.0059
13095: reward:  33.00, mean_100:  78.89, episodes: 211

26435: reward: 244.00, mean_100: 142.22, episodes: 296, reward function loss: -0.0136
26665: reward: 230.00, mean_100: 143.93, episodes: 297, reward function loss: -0.0136
26889: reward: 224.00, mean_100: 145.57, episodes: 298, reward function loss: -0.0136
27077: reward: 188.00, mean_100: 147.08, episodes: 299, reward function loss: -0.0136
27245: reward: 168.00, mean_100: 148.08, episodes: 300, reward function loss: -0.0146
27479: reward: 234.00, mean_100: 149.91, episodes: 301, reward function loss: -0.0146
27690: reward: 211.00, mean_100: 151.43, episodes: 302, reward function loss: -0.0146
27888: reward: 198.00, mean_100: 152.70, episodes: 303, reward function loss: -0.0146
28117: reward: 229.00, mean_100: 154.41, episodes: 304, reward function loss: -0.0160
28347: reward: 230.00, mean_100: 156.06, episodes: 305, reward function loss: -0.0160
28584: reward: 237.00, mean_100: 157.72, episodes: 306, reward function loss: -0.0160
28839: reward: 255.00, mean_100: 159.57, episodes: 307

63824: reward: 416.00, mean_100: 383.28, episodes: 392, reward function loss: -0.0296
64324: reward: 500.00, mean_100: 386.02, episodes: 393, reward function loss: -0.0296
64713: reward: 389.00, mean_100: 387.67, episodes: 394, reward function loss: -0.0296
65213: reward: 500.00, mean_100: 390.22, episodes: 395, reward function loss: -0.0296
65713: reward: 500.00, mean_100: 392.78, episodes: 396, reward function loss: -0.0394
66213: reward: 500.00, mean_100: 395.48, episodes: 397, reward function loss: -0.0394
66532: reward: 319.00, mean_100: 396.43, episodes: 398, reward function loss: -0.0394
67032: reward: 500.00, mean_100: 399.55, episodes: 399, reward function loss: -0.0394
67119: reward:  87.00, mean_100: 398.74, episodes: 400, reward function loss: -0.0098
67512: reward: 393.00, mean_100: 400.33, episodes: 401, reward function loss: -0.0098
68012: reward: 500.00, mean_100: 403.22, episodes: 402, reward function loss: -0.0098
68233: reward: 221.00, mean_100: 403.45, episodes: 403

106554: reward: 500.00, mean_100: 444.24, episodes: 488, reward function loss: 0.0000
107054: reward: 500.00, mean_100: 444.24, episodes: 489, reward function loss: 0.0000
107554: reward: 500.00, mean_100: 445.87, episodes: 490, reward function loss: 0.0000
108054: reward: 500.00, mean_100: 446.46, episodes: 491, reward function loss: 0.0000
108554: reward: 500.00, mean_100: 447.30, episodes: 492, reward function loss: -0.0000
109054: reward: 500.00, mean_100: 447.30, episodes: 493, reward function loss: -0.0000
109554: reward: 500.00, mean_100: 448.41, episodes: 494, reward function loss: -0.0000
110054: reward: 500.00, mean_100: 448.41, episodes: 495, reward function loss: -0.0000
110554: reward: 500.00, mean_100: 448.41, episodes: 496, reward function loss: -0.0131
111054: reward: 500.00, mean_100: 448.41, episodes: 497, reward function loss: -0.0131
111554: reward: 500.00, mean_100: 450.22, episodes: 498, reward function loss: -0.0131
112054: reward: 500.00, mean_100: 450.22, episo

In [19]:
## Testing 
from util import Agent
agent_net.eval()
agent_ = Agent(agent_net, apply_softmax=True, preprocessor=ptan.agent.float32_preprocessor)

for i in range(10):
    state = env.reset()
    Reward = 0
    done = False
    while not done:
        env.render()
        action = agent_(state)
        new_state, reward, done, _ = env.step(int(action))
        Reward += reward
    print("Trial :", i, " Reward: ", Reward)
env.close()

Trial : 0  Reward:  10.0
Trial : 1  Reward:  9.0
Trial : 2  Reward:  11.0
Trial : 3  Reward:  10.0
Trial : 4  Reward:  15.0
Trial : 5  Reward:  10.0
Trial : 6  Reward:  10.0
Trial : 7  Reward:  8.0
Trial : 8  Reward:  8.0
Trial : 9  Reward:  10.0


In [12]:
state

array([ 0.04120645, -0.00432056, -0.01407001, -0.01187006])

In [None]:
%matplotlib
S1 = np.linspace(-5, 5, 100)
S2 = np.linspace(-3.1457, 3.1457, 100)
S3 = 0*S1
S4 = 0*S1
S5 = np.ones(100)
Reward = np.zeros((100,100))
for i in range(100):
    for j in range(100):
        state = [S1[i], S2[j], 0, 0]
        action = 1
        x = torch.cat([float32_preprocessor(state), float32_preprocessor([int(action)])]).view(1, -1)
        r = reward_net(x)
        Reward[i,j] = r
        
fig = plt.figure()
ax = plt.axes(projection='3d')
ax.plot_surface(X, Y, Z, rstride=1, cstride=1,
                cmap='viridis', edgecolor='none')
ax.set_title('surface');
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z');
ax.view_init(azim=90, elev=90)