In [1]:
import pickle
import random
from collections import namedtuple

import gym
import numpy as np
import ptan
import torch
import torch.optim as optim
from ptan.agent import float32_preprocessor

from util import PGN, RewardNet

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
from matplotlib import pyplot as plt

In [2]:
GAMMA = 0.99
LEARNING_RATE = 0.01
EPISODES_TO_TRAIN = 4
DEMO_BATCH = 50
seed = 0
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)

EpisodeStep = namedtuple('EpisodeStep', field_names=['state', 'action', 'reward', 'next_state'])
Trajectory = namedtuple('Trajectory', field_names=['prob', 'episode_steps'])

In [3]:
def calc_qvals(rewards):
    res = []
    sum_r = 0.0
    for r in reversed(rewards):
        sum_r *= GAMMA
        sum_r += r
        res.append(sum_r)
    return list(reversed(res))


def process_demonstrations(demo_samples):
    traj_states, traj_actions, traj_qvals, traj_prob = [], [], [], []
    for traj in demo_samples:
        states, actions, rewards, qvals = [], [], [], []
        traj_prob.append(traj.prob)
        for step in traj.episode_steps:
            states.append(step.state)
            actions.append(step.action)
            rewards.append(step.reward)
        qvals.extend(calc_qvals(rewards))

        traj_states.append(states)
        traj_actions.append(actions)
        traj_qvals.append(qvals)
    traj_states = np.array(traj_states, dtype=np.object)
    traj_actions = np.array(traj_actions, dtype=np.object)
    traj_qvals = np.array(traj_qvals, dtype=np.object)
    traj_prob = np.array(traj_prob, dtype=np.float)
    return {'states': traj_states, 'actions': traj_actions, 'qvals': traj_qvals, 'traj_probs': traj_prob}

In [4]:
env = gym.make('CartPole-v1')
agent_net = PGN(env.observation_space.shape[0], env.action_space.n)
reward_net = RewardNet(env.observation_space.shape[0] + 1)
agent = ptan.agent.PolicyAgent(agent_net, preprocessor=float32_preprocessor, apply_softmax=True)
exp_source = ptan.experience.ExperienceSourceFirstLast(env, agent, gamma=GAMMA)
optimizer_agent = optim.Adam(agent_net.parameters(), lr=LEARNING_RATE)
optimizer_reward = optim.Adam(reward_net.parameters(), lr=1e-2, weight_decay=1e-4)

In [5]:
with open('demonstrations.list.pkl', 'rb') as f:
    demonstrations = pickle.load(f)
assert (len(demonstrations) > DEMO_BATCH)
print(f'Number of demonstrations: {len(demonstrations)}')
demonstrations = process_demonstrations(demonstrations)

Number of demonstrations: 100


In [10]:
len(batch_states), len(batch_qvals)

(104, 104)

In [6]:
total_rewards = []
step_idx = 0
done_episodes = 0

batch_episodes = 0
batch_states, batch_actions, batch_qvals = [], [], []
cur_rewards = []
loss_rwd = 0.

for step_idx, exp in enumerate(exp_source):
    batch_states.append(exp.state)
    batch_actions.append(int(exp.action))
    x = torch.cat([float32_preprocessor(exp.state), float32_preprocessor([int(exp.action)])]).view(1, -1)
    reward = reward_net(x)
    cur_rewards.append(reward.item())

    if exp.last_state is None:
        batch_qvals.extend(calc_qvals(cur_rewards))
        cur_rewards.clear()
        batch_episodes += 1

    new_rewards = exp_source.pop_total_rewards()
    if new_rewards:
        done_episodes += 1
        reward = new_rewards[0]
        total_rewards.append(reward)
        mean_rewards = float(np.mean(total_rewards[-100:]))
        writer.add_scalar('reward', reward, done_episodes)
        writer.add_scalar('mean_reward', mean_rewards, done_episodes)
        writer.add_scalar('loss_rwd', loss_rwd, done_episodes) 
        print(f'{step_idx}: reward: {reward:6.2f}, mean_100: {mean_rewards:6.2f}, '
              f'episodes: {done_episodes}, reward function loss: {loss_rwd:6.4f}')
        
        if done_episodes%100 == 0 or mean_rewards>=500:
            N_samp = 20
            S1 = np.linspace(-5, 5, N_samp)
            S2 = np.linspace(-3.1457, 3.1457, N_samp)
            S3 = 0*S1
            S4 = 0*S1
            
            S5 = np.ones(N_samp)
            
            for action in [0,1]:
                Reward = np.zeros((N_samp,N_samp))
                for i in range(N_samp):
                    for j in range(N_samp):
                        state = [S1[i], S2[j], 0, 0]
#                         action = 1
                        x = torch.cat([float32_preprocessor(state), float32_preprocessor([int(action)])]).view(1, -1)
                        r = reward_net(x)
                        Reward[i,j] = r

                X, Y = np.meshgrid(S1, S2)
                Z = Reward
                fig = plt.figure()
                ax = plt.axes(projection='3d')
                ax.plot_surface(X, Y, Z, rstride=1, cstride=1,
                                cmap='viridis', edgecolor='none')
                ax.set_title('surface');
                ax.set_xlabel('x')
                ax.set_ylabel('y')
                ax.set_zlabel('z');
                ax.view_init(azim=90, elev=90)
                if action ==0:
                    writer.add_figure('x1 vs x2 with a = 0', fig, global_step=done_episodes/100)
                if action ==1:
                    writer.add_figure('x1 vs x2 with a = 1', fig, global_step=done_episodes/100)    
            
        if mean_rewards >= 500:
            print(f'Solved in {step_idx} steps and {done_episodes} episodes!')
            torch.save(agent_net.state_dict(), 'cartpole_learner.mod')
            torch.save(reward_net.state_dict(), 'cartpole-v1_reward_func.mod')
            break
        

    if batch_episodes < EPISODES_TO_TRAIN:
        continue

    states_v = torch.FloatTensor(batch_states)
    batch_actions_t = torch.LongTensor(batch_actions)
    batch_qvals_v = torch.FloatTensor(batch_qvals)

    # reward function learning
    demo_states = demonstrations['states']
    demo_actions = demonstrations['actions']
    demo_probs = demonstrations['traj_probs']
    for rf_i in range(10):
        # Check is preprocessing is required 
        selected = np.random.choice(len(demonstrations), DEMO_BATCH)
        print(selected)
        demo_states = demo_states[selected]
        demo_actions = demo_actions[selected]
#         demo_probs = demo_probs[selected]
        demo_batch_states, demo_batch_actions = [], []
        for idx in range(len(demo_states)):
            demo_batch_states.extend(demo_states[idx])
            demo_batch_actions.extend(demo_actions[idx])
        demo_batch_states = torch.FloatTensor(demo_batch_states)
        demo_batch_actions = torch.FloatTensor(demo_batch_actions)
        D_demo = torch.cat([demo_batch_states, demo_batch_actions.view(-1, 1)], dim=-1)
        D_samp = torch.cat([states_v, batch_actions_t.float().view(-1, 1)], dim=-1)
        D_samp = torch.cat([D_demo, D_samp])
        # dummy importance weights - fix later
        z = torch.ones((D_samp.shape[0], 1))

        # objective
        D_demo_out = reward_net(D_demo)
        D_samp_out = reward_net(D_samp)
        D_samp_out = z * torch.exp(D_samp_out)
        loss_rwd = torch.mean(D_demo_out) - torch.log(torch.mean(D_samp_out))
        loss_rwd = -loss_rwd  # for maximization

        # update parameters
        optimizer_reward.zero_grad()
        loss_rwd.backward()
        optimizer_reward.step()

    # agent
    optimizer_agent.zero_grad()
    logits_v = agent_net(states_v)
    log_prob_v = torch.log_softmax(logits_v, dim=1)
    # REINFORCE
    log_prob_actions_v = batch_qvals_v * log_prob_v[range(len(batch_states)), batch_actions_t]
    loss_v = -log_prob_actions_v.mean()

    loss_v.backward()
    optimizer_agent.step()

    batch_episodes = 0
    batch_states.clear()
    batch_actions.clear()
    batch_qvals.clear()
env.close()
writer.close()

13: reward:  13.00, mean_100:  13.00, episodes: 1, reward function loss: 0.0000
29: reward:  16.00, mean_100:  14.50, episodes: 2, reward function loss: 0.0000
60: reward:  31.00, mean_100:  20.00, episodes: 3, reward function loss: 0.0000
[1 0 3 1 1 3 0 2 2 3 2 3 3 3 2 1 2 2 3 2 3 3 2 2 3 0 1 2 3 2 1 2 1 0 2 2 3
 0 3 2 3 0 0 2 0 2 3 2 2 3]
[0 0 0 1 2 0 1 2 3 2 2 3 1 0 0 0 3 3 0 2 2 0 0 3 0 1 2 0 0 3 1 3 0 2 1 1 1
 3 0 0 0 1 1 2 1 0 0 1 2 1]
[1 3 1 0 0 0 3 3 1 2 1 3 1 3 0 0 3 1 2 0 2 3 3 2 3 1 1 1 3 2 0 0 1 0 3 2 3
 3 2 0 2 3 2 2 1 1 3 3 0 3]
[0 0 2 2 3 1 1 0 3 0 2 3 3 1 3 1 3 2 0 1 1 3 3 3 1 1 1 3 3 2 2 2 0 3 3 1 2
 3 1 3 3 1 1 1 0 2 2 0 1 1]
[0 3 0 0 3 2 1 2 1 2 0 2 0 1 1 1 0 3 0 3 0 0 0 0 1 3 3 0 3 2 2 1 1 1 2 1 1
 0 1 2 1 1 0 0 1 1 2 3 3 1]
[3 3 3 2 2 1 2 3 2 0 3 3 1 1 1 2 1 1 1 2 0 3 1 0 3 1 1 2 0 2 3 0 2 0 2 0 0
 3 3 3 0 0 3 1 3 2 3 0 3 0]
[3 3 1 2 1 0 0 3 0 3 0 0 1 0 2 2 2 0 0 2 0 2 3 2 3 3 0 1 1 1 0 0 3 3 1 3 0
 3 0 1 2 2 0 2 1 3 3 3 1 2]
[1 2 3 2 1 2 3 3 0 0 2 1 3 1 2 2 1 0 2 

KeyboardInterrupt: 

In [None]:
## Testing 
from util import Agent
agent_net.eval()
agent_ = Agent(agent_net, apply_softmax=True, preprocessor=ptan.agent.float32_preprocessor)

for i in range(10):
    state = env.reset()
    Reward = 0
    done = False
    while not done:
        env.render()
        action = agent_(state)
        new_state, reward, done, _ = env.step(int(action))
        Reward += reward
    print("Trial :", i, " Reward: ", Reward)
env.close()

In [None]:
selected, len(demonstrations), DEMO_BATCH

In [None]:
%matplotlib
S1 = np.linspace(-5, 5, 100)
S2 = np.linspace(-3.1457, 3.1457, 100)
S3 = 0*S1
S4 = 0*S1
S5 = np.ones(100)
Reward = np.zeros((100,100))
for i in range(100):
    for j in range(100):
        state = [S1[i], S2[j], 0, 0]
        action = 1
        x = torch.cat([float32_preprocessor(state), float32_preprocessor([int(action)])]).view(1, -1)
        r = reward_net(x)
        Reward[i,j] = r
        
fig = plt.figure()
ax = plt.axes(projection='3d')
ax.plot_surface(X, Y, Z, rstride=1, cstride=1,
                cmap='viridis', edgecolor='none')
ax.set_title('surface');
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z');
ax.view_init(azim=90, elev=90)

In [None]:
demonstrations['states'][0]