Decision Transformer: Reinforcement Learning via Sequence Modeling
Recreated by : Austin Runkle, Fatih Bozdogan

In this project we will be implementing the decision transformer and comparing its preformance
to an existing RL model TD learning

IMPORTS

In [4]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

import gymnasium as gym
from gym.wrappers import AtariPreprocessing, FrameStack

import matplotlib.pyplot as plt
import time

Decision Transformer - For Continuous Action

In [9]:
# For continuous action, we use an MSE
# https://github.com/kzl/decision-transformer



Supporting Functions

Online v. Offline RL: 
* Online: Learn from experience
* Offline RL: Learn from shown experience

Decision Transformer Function

In [5]:
# R, s, a, t: returns -to -go , states , actions , or timesteps
# K: context length ( length of each input to DecisionTransformer )
# transformer : transformer with causal masking (GPT)
# embed_s , embed_a , embed_R : linear embedding layers
# embed_t : learned episode positional embedding
# pred_a : linear action prediction layer
# main model
def DecisionTransformer (R , s , a , t ):
    # compute embeddings for tokens
    pos_embedding = embed_t ( t ) # per - timestep ( note : not per - token )
    s_embedding = embed_s ( s ) + pos_embedding
    a_embedding = embed_a ( a ) + pos_embedding
    R_embedding = embed_R ( R ) + pos_embedding
    # interleave tokens as (R_1 , s_1 , a_1 , ... , R_K , s_K )
    input_embeds = stack ( R_embedding , s_embedding , a_embedding )
    # use transformer to get hidden states
    hidden_states = transformer ( input_embeds = input_embeds )
    # select hidden states for action prediction tokens
    a_hidden = unstack ( hidden_states ). actions
    # predict action
    return pred_a ( a_hidden )
# training loop
for (R , s , a , t ) in dataloader : # dims : ( batch_size , K, dim )
    a_preds = DecisionTransformer (R , s , a , t )
    loss = mean (( a_preds - a )**2) # L2 loss for continuous actions
    optimizer . zero_grad (); loss . backward (); optimizer . step ()
# evaluation loop
target_return = 1 # for instance , expert - level return
R , s , a , t , done = [ target_return ] , [ env . reset ()] , [] , [1] , False
while not done : # autoregressive generation / sampling
    # sample next action
    action = DecisionTransformer (R , s , a , t )[ -1] # for cts actions
    new_s , r , done , _ = env . step ( action )
    # append new tokens to sequence
    R = R + [ R [ -1] - r] # decrement returns -to -go with reward
    s , a , t = s + [ new_s ] , a + [ action ] , t + [ len ( R )]
    R , s , a , t = R [ - K :] , ... # only keep context length of K

NameError: name 'dataloader' is not defined

Neural Network for Q Learning Atari using convolution neural network
https://docs.pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html <- Building Neural Networks
https://docs.pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html <- training classifier
https://arxiv.org/pdf/1312.5602 <- confusion matrix sizing

In [6]:
# convolution neural network to work with atari
class QLearningNetwork(nn.Module):
    def __init__(self, num_actions):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size = 8, stride = 4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size = 4, stride = 2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size = 3, stride = 1),
            nn.ReLU()
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(7 * 7 * 64, 512),
            nn.ReLU(),
            nn.Linear(512, num_actions)
        )

    def forward(self, x):
        x = x / 255.0
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)
        return self.fc_layers(x)

Temporal Difference Learning - Q learning agent

In [7]:
# TODO - Adapt to be a Q-learning agent <- Neural network
class TD_QLearningAgent(BaseAgent):
    def agent_init(self, agent_info={}):
        self.rand_generator = np.random.RandomState(agent_info.get("seed"))
        # Discount factor (gamma) to use in the updates.
        self.discount = agent_info.get("discount")
        # The learning rate or step size parameter (alpha) to use in updates.
        self.step_size = agent_info.get("step_size")

        self.num_states = agent_info.get("num_states")
        self.num_actions = agent_info.get("num_actions")

        # initialize the neural network

        # This line is drawn from PyTorch documentation
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.q_net = QLearningNetwork(self.num_actions).to(self.device)
        self.optimizer = optim.SGD(self.q_net.parameters(), lr = .001, momentum = .9)
        self.loss_fn = nn.MSELoss()

        # initialize the agent init state and agent to none
        self.state = None
        self.action = None
        
    def agent_start(self, state):
        tensor = torch.tensor(state, dtype = torch.float32, device = self.device).unsqueeze(0)
        q_values = self.q_net(tensor)
        action = torch.argmax(q_values, dim = 1).item()
        self.last_state = state
        self.last_action = action
        return action

    def agent_step(self, reward, state):
        # get the current and next state as tensor
        cur_state = torch.tensor(self.last_state, dtype = torch.float32, device = self.device).unsqueeze(0)
        next_state = torch.tensor(state, dtype = torch.float32, device = self.device).unsqueeze(0)

        q_values = self.q_net(cur_state)
        next_q = self.q_net(next_state)

        loss = self.loss_fn(q_values[0, self.last_action], (reward + self.discount * torch.max(next_q)).detach())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # e greedy next action
        action = torch.argmax(next_q, dim = 1).item()
        self.last_state = state
        self.last_action = action
        return action

    def agent_end(self, reward):
        # for agent_end compute just the last action 
        cur_state = torch.tensor(self.last_state, dtype = torch.float32, device = self.device).unsqueeze(0)
        q_values = self.q_net(cur_state)
        loss = self.loss_fn((q_values[0, self.last_action]), torch.tensor(reward, device = self.device))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def agent_cleanup(self):        
        self.last_state = None
        self.last_action = None

NameError: name 'BaseAgent' is not defined

Training Agents on Atari - THERE IS A LINK ON THE SLIDES FROM 11/2 

https://ale.farama.org/environments/complete_list/

In [8]:
env = gym.make("Breakout",obs_type = "rgb", frame_skip = 1, repeat_action_probability = 0, full_action_space = False)
env.reset()

# get the number of actions
num_actions = env.action_space.n

# get the actions associated with inputs
# for breakout
# 0 = Back
# 1 = launch
# 2 = left
# 3 = right
meaning = env.unwrapped.get_action_meanings()

# for testing 
obs, reward, terminated, truncated, info = env.step(0)

NameNotFound: Environment `Breakout` doesn't exist.

Compare Results