In [87]:
import numpy as np
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributions as dis
from taxi_env import TaxiEnv
%load_ext autoreload
%autoreload 2
from tools import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [88]:
# Compute device for the policy network. Defaults to CPU; set USE_CUDA = True to
# use a GPU when one is available (falls back to CPU otherwise).
USE_CUDA = False
device = torch.device("cuda" if USE_CUDA and torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [None]:
class PyTorchPolicy:
    """Stochastic softmax policy over a discrete action space.

    A two-layer MLP maps a 1-D state vector to one logit per action; actions are
    sampled from the categorical distribution defined by those logits.

    Args:
        state_size: Length of the state vector given to ``get_action``/``update``.
        action_size: Number of discrete actions.
        lr: Adam learning rate.

    Attributes:
        policy: The ``nn.Sequential`` network, moved to the module-level ``device``.
        optimizer: Adam optimizer over ``policy``'s parameters.
        criterion: Cross-entropy loss used by ``update``.
    """

    def __init__(self, state_size, action_size, lr = 0.001):
        self.state_size = state_size
        self.action_size = action_size

        # Hidden width is half the input width, but never zero: for state_size < 2
        # a zero-width hidden layer would make the output independent of the state.
        hidden_size = max(state_size // 2, 1)
        self.policy = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, action_size),
        ).to(device)

        self.optimizer = optim.Adam(self.policy.parameters(), lr = lr)
        self.criterion = nn.CrossEntropyLoss()

    def get_action(self, state):
        """Sample an action for ``state``.

        Args:
            state: Sequence of numbers of length ``state_size``.

        Returns:
            ``(action, log_prob)``: the sampled action index as a Python int and
            its log-probability as a 0-d tensor still attached to the autograd
            graph (so it can be used in a policy-gradient loss).
        """
        state = torch.tensor(state, dtype = torch.float, device = device)
        # Parameterizing by logits lets Categorical normalize with log-softmax,
        # which is more numerically stable than softmax followed by log.
        m = dis.Categorical(logits = self.policy(state))
        action = m.sample()
        return action.item(), m.log_prob(action)

    def update(self, state, action, reward):
        """Take one single-sample policy-gradient step.

        Minimizes ``reward * CrossEntropy(logits, action)`` = ``reward * -log pi(action | state)``,
        raising the probability of ``action`` when ``reward`` > 0 and lowering it
        when ``reward`` < 0.

        Args:
            state: Sequence of numbers of length ``state_size``.
            action: Index of the action taken.
            reward: Scalar weight for the step.
        """
        self.optimizer.zero_grad()

        state = torch.tensor(state, dtype = torch.float, device = device)
        # Keep the target on the same device as the logits (no CPU round-trip).
        action = torch.tensor(action, dtype = torch.long, device = device)

        loss = reward * self.criterion(self.policy(state), action)

        loss.backward()
        self.optimizer.step()


def _discounted_returns(rewards, gamma):
    """Return G_t = r_t + gamma * G_{t+1} for every step t, in step order."""
    returns = []
    G = 0
    for r in reversed(rewards):
        G = r + gamma * G
        returns.append(G)
    # Built back-to-front with append (O(n)) instead of insert(0, ...) (O(n^2)).
    returns.reverse()
    return returns


def _reinforce_loss(log_probs, rewards, gamma):
    """REINFORCE loss ``-sum_t log pi(a_t | s_t) * G_t`` with standardized returns.

    Args:
        log_probs: Per-step 0-d log-probability tensors of the actions taken.
        rewards: Per-step (shaped) rewards, same length as ``log_probs``.
        gamma: Discount factor.
    """
    log_probs = torch.stack(log_probs)
    # Match the log-probs' dtype/device so the product also works when the
    # policy network lives on a GPU.
    returns = torch.tensor(_discounted_returns(rewards, gamma),
                           dtype = log_probs.dtype, device = log_probs.device)
    if returns.numel() > 1:
        returns = (returns - returns.mean()) / (returns.std() + 1e-9)
    else:
        # torch.std of a single element is NaN (n - 1 = 0 degrees of freedom);
        # dividing by it would turn the loss, the gradients and then every
        # parameter into NaN. A lone return carries no relative signal: use 0.
        returns = torch.zeros_like(returns)
    return (log_probs * -returns).sum()


def train(fuel_limit = 5000, episodes = 5000, lr = 0.001, gamma = 0.99):
    """Train a PyTorchPolicy on TaxiEnv with REINFORCE and shaped rewards.

    The environment reward is augmented per step with:
      * +0.05 while ``env.passenger_picked_up`` is true, plus +200 the first time
        it is seen in an episode;
      * +50 the first time ``state[i]`` and ``state[i + 1]`` are both 0, for
        i in {0, 2};
      * +1000 when the episode is ``terminated``.

    After every episode the policy is updated once with the REINFORCE loss
    (standardized discounted returns). Every 10 episodes a summary line is
    printed; "Success count" counts episodes shorter than ``fuel_limit`` steps.

    Args:
        fuel_limit: Passed to ``TaxiEnv``.
        episodes: Number of training episodes.
        lr: Adam learning rate for the policy.
        gamma: Discount factor.

    Returns:
        The trained PyTorchPolicy.
    """
    env = TaxiEnv(fuel_limit)
    try:
        obs, _ = env.reset()
        state_size = len(get_state(obs))
        action_size = 6

        policy_model = PyTorchPolicy(state_size, action_size, lr = lr)

        rewards_per_episode = []
        pickup_per_episode = []
        step_per_episode = []

        for episode in range(episodes):
            obs, _ = env.reset()
            state = get_state(obs)

            done = False
            total_reward = 0
            log_probs = []
            rewards = []

            first_pickup = True
            # Only indices 0 and 2 are used below; presumably state[0:2] and
            # state[2:4] are offsets to two targets — TODO confirm against get_state.
            first_visit = [True] * 4

            while not done:
                action, log_prob = policy_model.get_action(state)
                obs, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                # Shaping below is evaluated on the post-step state.
                state = get_state(obs)

                if env.passenger_picked_up:
                    reward += 0.05
                    if first_pickup:
                        reward += 200
                        first_pickup = False
                for i in range(0, 4, 2):
                    if first_visit[i] and state[i] == 0 and state[i + 1] == 0:
                        reward += 50
                        first_visit[i] = False
                if terminated:
                    reward += 1000

                total_reward += reward
                log_probs.append(log_prob)
                rewards.append(reward)

            rewards_per_episode.append(total_reward)
            pickup_per_episode.append(not first_pickup)
            step_per_episode.append(len(rewards))

            loss = _reinforce_loss(log_probs, rewards, gamma)

            policy_model.optimizer.zero_grad()
            loss.backward()
            policy_model.optimizer.step()

            if (episode + 1) % 10 == 0:
                print(f"Episode {episode + 1}/{episodes}, Average Reward: {np.mean(rewards_per_episode[-10:]):07.3f}, Average Step: {np.mean(step_per_episode[-10:]):05f}, Pickup Count: {np.count_nonzero(pickup_per_episode[-10:]):02d}, Success count: {np.count_nonzero(np.array(step_per_episode[-10:]) < fuel_limit):02d}")
    finally:
        # Release the environment even if training is interrupted or fails.
        env.close()
    return policy_model

In [90]:
# Full REINFORCE run with every hyperparameter spelled out (the values equal
# train()'s defaults except fuel_limit). Note: the saved output of this cell ends
# in KeyboardInterrupt, in which case policy_model is NOT (re)assigned.
policy_model = train(fuel_limit=1000, episodes=5000, lr=0.001, gamma=0.99)

Episode 10/5000, Average Reward: -4775.990, Average Step: 1000.000000, Pickup Count: 01, Success count: 00
Episode 20/5000, Average Reward: -4427.140, Average Step: 1000.000000, Pickup Count: 04, Success count: 00
Episode 30/5000, Average Reward: -4402.735, Average Step: 1000.000000, Pickup Count: 04, Success count: 00
Episode 40/5000, Average Reward: -3828.535, Average Step: 921.300000, Pickup Count: 04, Success count: 01
Episode 50/5000, Average Reward: -4513.905, Average Step: 908.400000, Pickup Count: 06, Success count: 01
Episode 60/5000, Average Reward: -4413.900, Average Step: 1000.000000, Pickup Count: 04, Success count: 00
Episode 70/5000, Average Reward: -3969.525, Average Step: 976.400000, Pickup Count: 03, Success count: 01
Episode 80/5000, Average Reward: -3888.135, Average Step: 995.300000, Pickup Count: 05, Success count: 01
Episode 90/5000, Average Reward: -4003.470, Average Step: 1000.000000, Pickup Count: 02, Success count: 00


KeyboardInterrupt: 

In [None]:
# Persist only the state_dicts (network weights, then optimizer state) to the
# files "policy" and "optimizer"; they can be restored with
# `.load_state_dict(torch.load(<file>))` on a freshly constructed model/optimizer.
# NOTE(review): policy_model comes from the training cell above, whose saved run
# was interrupted — confirm it holds the model you intend to save.
for checkpoint_name, component in (("policy", policy_model.policy),
                                   ("optimizer", policy_model.optimizer)):
    torch.save(component.state_dict(), checkpoint_name)