In [113]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributions as dis
from taxi_env import TaxiEnv
%load_ext autoreload
%autoreload 2
from tools import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [114]:
# Training is pinned to the CPU on purpose. To pick up a GPU when one exists, swap in:
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
device  # echo the selected device as this cell's output

device(type='cpu')

In [None]:
class PyTorchPolicy:
    """Linear softmax policy over a discrete action space, optimized with Adam.

    The network is a single ``nn.Linear(state_size, action_size)`` that maps a
    state feature vector to action logits; it is placed on the module-level
    ``device``.

    Args:
        state_size: length of the feature vector given to ``get_action``/``update``.
        action_size: number of discrete actions.
        lr: Adam learning rate.
    """

    def __init__(self, state_size, action_size, lr = 0.001):
        self.state_size = state_size
        self.action_size = action_size

        self.policy = nn.Sequential(
            nn.Linear(state_size, action_size)
        ).to(device)

        self.optimizer = optim.Adam(self.policy.parameters(), lr = lr)
        # Used by `update`; kept as a public attribute for callers that read it.
        self.criterion = nn.CrossEntropyLoss()

    def _distribution(self, state):
        """Return the Categorical action distribution for `state` (features on the last dim)."""
        state = torch.as_tensor(state, dtype = torch.float, device = device)
        # Handing raw logits to Categorical lets it normalize with log-softmax,
        # which is numerically stable (softmax-then-log can underflow a
        # probability to 0 and yield a -inf log-prob). It also normalizes over
        # the last dim, so a batch of states is handled correctly.
        return dis.Categorical(logits = self.policy(state))

    def get_action(self, state):
        """Sample an action for a single state.

        Returns:
            (action, log_prob): the sampled action as a Python int, and its
            log-probability as a 0-dim tensor still attached to the autograd
            graph so a policy-gradient loss can backpropagate through it.
        """
        m = self._distribution(state)
        action = m.sample()
        return action.item(), m.log_prob(action)

    def update(self, state, action, reward):
        """Take one optimizer step on ``reward * CrossEntropy(logits(state), action)``.

        Cross-entropy of the logits against the taken action equals
        ``-log pi(action | state)``, so a positive `reward` makes `action`
        more likely and a negative one makes it less likely.
        """
        self.optimizer.zero_grad()

        state = torch.as_tensor(state, dtype = torch.float, device = device)
        # Build the target on the same device as the logits rather than
        # copying the logits back to the CPU.
        action = torch.as_tensor(action, dtype = torch.long, device = device)

        loss = reward * self.criterion(self.policy(state), action)

        loss.backward()
        self.optimizer.step()


def _run_episode(env, policy_model, pickup_bonus):
    """Roll out one episode with `policy_model`; return (log_probs, shaped_rewards) per step."""
    obs, _ = env.reset()
    state = get_state(obs)
    log_probs, rewards = [], []
    done = False
    while not done:
        action, log_prob = policy_model.get_action(state)
        obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        state = get_state(obs)
        # Reward shaping: bonus on every step where the env reports the passenger as picked up.
        if env.passenger_picked_up:
            reward += pickup_bonus
        log_probs.append(log_prob)
        rewards.append(reward)
    return log_probs, rewards


def _discounted_returns(rewards, gamma):
    """Return G_t = r_t + gamma * G_{t+1} for every step t, in step order.

    Accumulates back-to-front with append + reverse, which is O(n); prepending
    with ``list.insert(0, ...)`` would be O(n^2) for long episodes.
    """
    returns = []
    running = 0.0
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    returns.reverse()
    return returns


def _normalize_returns(returns, eps = 1e-9):
    """Standardize a 1-D returns tensor to zero mean and unit (sample) std.

    The sample std of a single element is undefined (``Tensor.std`` returns
    NaN), which would turn the loss and then every parameter into NaN. For a
    one-step episode only the mean is subtracted, giving a zero return.
    """
    centered = returns - returns.mean()
    if returns.numel() < 2:
        return centered
    return centered / (returns.std() + eps)


def train(fuel_limit = 5000, episodes = 5000, lr = 0.001, gamma = 0.99, pickup_bonus = 0.05):
    """Train a PyTorchPolicy on TaxiEnv with REINFORCE (Monte-Carlo policy gradient).

    The policy is updated once per episode with the loss
    ``-sum_t log pi(a_t | s_t) * G_t``, where the discounted returns G_t are
    standardized within the episode. The average (shaped) reward of the last
    10 episodes is printed every 10 episodes.

    Args:
        fuel_limit: currently NOT used (see NOTE below).
        episodes: number of episodes to train for.
        lr: Adam learning rate passed to PyTorchPolicy.
        gamma: discount factor for returns.
        pickup_bonus: shaping reward added on each step after which
            ``env.passenger_picked_up`` is truthy (default matches the
            previously hard-coded 0.05).

    Returns:
        The trained PyTorchPolicy.
    """
    # NOTE(review): `fuel_limit` is accepted but never forwarded to the
    # environment; pass it to TaxiEnv(...) if its constructor supports it — TODO confirm.
    env = TaxiEnv()
    try:
        obs, _ = env.reset()
        state_size = len(get_state(obs))
        action_size = 6  # number of discrete actions this notebook assumes TaxiEnv exposes

        policy_model = PyTorchPolicy(state_size, action_size, lr = lr)
        rewards_per_episode = []

        for episode in range(episodes):
            log_probs, rewards = _run_episode(env, policy_model, pickup_bonus)
            rewards_per_episode.append(sum(rewards))

            log_probs = torch.stack(log_probs)
            returns = torch.tensor(_discounted_returns(rewards, gamma), dtype = torch.float)
            # Keep returns on the same device as the log-probs (matters if the policy runs on a GPU).
            returns = _normalize_returns(returns).to(log_probs.device)

            loss = -(log_probs * returns).sum()

            policy_model.optimizer.zero_grad()
            loss.backward()
            policy_model.optimizer.step()

            if (episode + 1) % 10 == 0:
                print(f"Episode {episode + 1}/{episodes}, Average Reward: {np.mean(rewards_per_episode[-10:]):.3f}")

        return policy_model
    finally:
        # Release the environment even if training raises or is interrupted.
        env.close()

In [116]:
# Hyperparameters spelled out explicitly; the values equal train()'s defaults.
policy_model = train(fuel_limit = 5000, episodes = 5000, lr = 0.001, gamma = 0.99)

Episode 10/5000, Average Reward: -20508.410
Episode 20/5000, Average Reward: -18643.680
Episode 30/5000, Average Reward: -18656.380
Episode 40/5000, Average Reward: -16385.570
Episode 50/5000, Average Reward: -18670.360
Episode 60/5000, Average Reward: -14053.640
Episode 70/5000, Average Reward: -17023.800
Episode 80/5000, Average Reward: -16798.500
Episode 90/5000, Average Reward: -17615.500
Episode 100/5000, Average Reward: -15798.260
Episode 110/5000, Average Reward: -15181.650
Episode 120/5000, Average Reward: -14723.760
Episode 130/5000, Average Reward: -12493.560
Episode 140/5000, Average Reward: -11968.830
Episode 150/5000, Average Reward: -12269.720
Episode 160/5000, Average Reward: -10343.200
Episode 170/5000, Average Reward: -12783.950
Episode 180/5000, Average Reward: -12122.350
Episode 190/5000, Average Reward: -14030.000
Episode 200/5000, Average Reward: -14750.700
Episode 210/5000, Average Reward: -12591.000
Episode 220/5000, Average Reward: -9616.280
Episode 230/5000, Av

In [117]:
# Persist network weights and optimizer state to separate files.
# NOTE(review): if the training cell was interrupted, `policy_model` was never
# reassigned and still refers to whatever an earlier run left in the kernel.
for checkpoint_path, get_state_dict in (
    ("policy", policy_model.policy.state_dict),
    ("optimizer", policy_model.optimizer.state_dict),
):
    torch.save(get_state_dict(), checkpoint_path)