Decision Transformer: Reinforcement Learning via Sequence Modeling
Recreated by : Austin Runkle, Fatih Bozdogan, Haocheng Cao

In this project we will be implementing the decision transformer and comparing its preformance
to an existing RL model TD learning

IMPORTS

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import gymnasium as gym
import numpy as np

#import gymnasium as gym
#from gym.wrappers import AtariPreprocessing, FrameStack

import matplotlib.pyplot as plt
import time

Decision Transformer - For Continuous Action

# Decision Transformer (Discrete) — What’s included

This block implements a **Decision Transformer for discrete action spaces** (e.g., Atari). It follows the “(rtg, state, previous action)” tokenization and uses **cross-entropy** to predict the next action at every action token.

## Tokenization
For a context window of length **K**, we create 3 tokens per step:
1. `rtg_t`  (return-to-go at step t)
2. `s_t`    (state at step t)
3. `a_{t-1}` (previous action; use a special START id for t=0)

So the full sequence is:
`[rtg_0, s_0, a_{-1},  rtg_1, s_1, a_0,  ...,  rtg_{K-1}, s_{K-1}, a_{K-2}]`.

At every **action token** position, the model predicts the current action `a_t`.

## Modules
- **CNNStateEncoder**: encodes stacked Atari frames `(4×84×84)` into a `d_model` vector.
- **CausalTransformer**: `nn.TransformerEncoder` with a **causal mask**.
- **DecisionTransformerDiscrete**:
  - Embeddings for RTG (`Linear(1→d)`), state (CNN or `Linear`), action (`Embedding(num_actions+1, d)` with an extra **START** id), timestep (`Embedding`), and token type (rtg/state/action).
  - Interleaves tokens `[rtg, s, a] * K`, applies a causal transformer, then selects hidden states at **action token** positions and projects to action logits.
- **step_mask_to_token_mask**: expands a step mask `(B,K)` to a token mask `(B,3K)`.
- **compute_dt_loss**: cross-entropy over action tokens, ignoring padded steps.
- **compute_rtg**: utility to compute returns-to-go from rewards.
- **dt_sample_action**: inference helper that returns the predicted action for the last step of a window.

## Shapes (batch-first)
- `rtg`:            `(B, K, 1)`
- `states`:         pixel `(B, K, 4, 84, 84)` **or** vector `(B, K, state_dim)`
- `actions_in`:     `(B, K)` longs — these are `a_{t-1}`, with `-1` at `t=0`
- `actions_target`: `(B, K)` longs — labels are `a_t`
- `timesteps`:      `(B, K)` longs
- `step_mask`:      `(B, K)` bool (True=valid, False=pad)
- `attention_mask`: `(B, 3K)` bool (True=keep). Use `step_mask_to_token_mask(step_mask)` to build it.

## Training loop (sketch)
1. Prepare offline windows from rollouts (or a dataset):
   - Compute `rtg` with `compute_rtg(rewards, gamma)`.
   - Build `(rtg, states, actions_in, actions_target, timesteps, step_mask)` for each window.
2. Forward:
   ```python
   loss, logits = compute_dt_loss(model, batch)
   ```
3. Backward:  
   ```python
   optimizer.zero_grad()
   loss.backward()
   optimizer.step()
   ```

## Inference
Use ``dt_sample_action`` to **autoregressively** pick the next action given the current window:  
```python
a = dt_sample_action(model, rtg_seq, state_seq, action_in_seq, timestep_seq)
```
Make sure the first ``action_in_seq[0] == -1`` (START).

## Suggested hyperparameters (starting points)
- ``d_model=256``,``n_layers=4``, ``n_heads=4``, ``dropout=0.1``
- ``K in [20, 30]``
- ``rtg_scale=1000.0``
- Optimizer: AdamW with ``lr=1e-4 ~ 3e-4``, ``weight_decay=0.1``
- Batch size depends on GPU memory (start with 16)


Supporting Functions

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ------------------------------
# 1) Pixel state encoder (for 4x84x84 Atari stacks)
# ------------------------------
class CNNStateEncoder(nn.Module):
    """
    Input:  (B*K, C=4, H=84, W=84)
    Output: (B*K, d_model)
    """
    def __init__(self, in_channels=4, d_model=256):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),           nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),           nn.ReLU(),
        )
        # For 84x84 input, this yields 64x7x7 after the conv stack
        self.fc = nn.Linear(64 * 7 * 7, d_model)

    def forward(self, x):
        # x: (B*K, C, H, W) | uint8 or float
        x = x / 255.0
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


# ------------------------------
# 2) Causal Transformer (GPT-style)
# ------------------------------
class CausalTransformer(nn.Module):
    def __init__(self, d_model=256, n_layers=4, n_heads=4, dropout=0.1):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads,
            dim_feedforward=4 * d_model, dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        # x: (B, L, d)
        # attn_mask: (L, L)  True = block, False = allow
        # key_padding_mask: (B, L) True = pad(ignored), False = keep
        return self.encoder(x, mask=attn_mask, src_key_padding_mask=key_padding_mask)


def build_causal_mask(L, device):
    """
    Upper-triangular True (=block future), lower incl. diagonal False (=allow past/self).
    Shape: (L, L)
    """
    return torch.triu(torch.ones(L, L, device=device), diagonal=1).bool()


# ------------------------------
# 3) Decision Transformer for DISCRETE actions
#    token order per step: [rtg_t, s_t, a_{t-1}], predict action at each action token
# ------------------------------
class DecisionTransformerDiscrete(nn.Module):
    """
    Works for discrete action spaces (e.g., Atari).
    - Pixel states: (B, K, 4, 84, 84) with pixel_inputs=True
    - Vector states: set pixel_inputs=False and in_channels=state_dim
    """
    def __init__(
        self,
        num_actions,
        d_model=256,
        max_timestep=4096,
        n_layers=4,
        n_heads=4,
        dropout=0.1,
        pixel_inputs=True,
        in_channels=4,      # if pixel_inputs=False, set to state_dim
        rtg_scale=1000.0    # scale RTG to ~[0,1] range for stability
    ):
        super().__init__()
        self.num_actions = num_actions
        self.d_model = d_model
        self.pixel_inputs = pixel_inputs
        self.rtg_scale = rtg_scale

        if pixel_inputs:
            self.state_encoder = CNNStateEncoder(in_channels=in_channels, d_model=d_model)
        else:
            self.state_proj = nn.Linear(in_channels, d_model)  # in_channels=state_dim

        # Action embedding: reserve an extra id for "START" (a_{-1})
        self.embed_action = nn.Embedding(num_actions + 1, d_model)

        # RTG, time, and token-type embeddings
        self.embed_rtg  = nn.Linear(1, d_model)
        self.embed_time = nn.Embedding(max_timestep, d_model)
        # token types: 0=rtg, 1=state, 2=action
        self.embed_tokentype = nn.Embedding(3, d_model)

        self.transformer = CausalTransformer(
            d_model=d_model, n_layers=n_layers, n_heads=n_heads, dropout=dropout
        )
        self.action_head = nn.Linear(d_model, num_actions)
        self.layer_norm = nn.LayerNorm(d_model)

    def _interleave(self, rtg_emb, state_emb, action_emb):
        # rtg_emb/state_emb/action_emb: (B, K, d) -> (B, 3K, d)
        B, K, D = rtg_emb.shape
        seq = torch.stack([rtg_emb, state_emb, action_emb], dim=2)  # (B, K, 3, d)
        seq = seq.view(B, K * 3, D)
        return seq

    def forward(self, rtg, states, actions, timesteps, attention_mask=None):
        """
        Args:
          rtg:       (B, K, 1)  float
          states:    (B, K, C, H, W)  or (B, K, state_dim)
          actions:   (B, K)  long; these are a_{t-1}, with -1 at t=0 (START)
          timesteps: (B, K)  long; either window-local [0..K-1] or real env steps
          attention_mask (optional): (B, 3K) bool; True=keep, False=pad/ignore

        Returns:
          logits: (B, K, num_actions) — per-step action logits (predict a_t at each action token)
        """
        B, K = actions.shape
        device = actions.device

        t_embed = self.embed_time(timesteps)  # (B, K, d)

        # 1) RTG embedding (scaled) + time + token type
        rtg_in  = rtg / self.rtg_scale
        rtg_emb = self.embed_rtg(rtg_in) \
                  + t_embed \
                  + self.embed_tokentype(torch.zeros_like(actions))  # 0=rtg

        # 2) State embedding
        if self.pixel_inputs:
            s = states.view(B * K, *states.shape[2:])   # (B*K, C, H, W)
            s_emb = self.state_encoder(s).view(B, K, self.d_model)
        else:
            s_emb = self.state_proj(states)             # (B, K, d)
        s_emb = s_emb \
                + t_embed \
                + self.embed_tokentype(torch.ones_like(actions))     # 1=state

        # 3) Action embedding (a_{t-1}); fill -1 with START id = num_actions
        a_ids = actions.clone()
        a_ids = torch.where(a_ids < 0, torch.full_like(a_ids, self.num_actions), a_ids)
        a_emb = self.embed_action(a_ids) \
                + t_embed \
                + self.embed_tokentype(torch.full_like(actions, 2))  # 2=action

        # 4) Interleave [rtg, s, a] * K  -> (B, 3K, d)
        x = self._interleave(rtg_emb, s_emb, a_emb)
        x = self.layer_norm(x)

        # 5) Causal attention
        L = x.size(1)
        causal_mask = build_causal_mask(L, device)

        # Convert token mask to key_padding_mask if provided
        key_padding_mask = None
        if attention_mask is not None:
            key_padding_mask = ~attention_mask  # True=pad

        # 6) Transformer + pick hidden states at action token positions
        out = self.transformer(x, attn_mask=causal_mask, key_padding_mask=key_padding_mask)
        # action token positions: 2,5,8,... -> 3*i + 2
        idx = torch.arange(K, device=device) * 3 + 2
        idx = idx.unsqueeze(0).expand(B, -1)  # (B, K)
        a_hidden = out.gather(dim=1, index=idx.unsqueeze(-1).expand(B, K, self.d_model))
        logits = self.action_head(a_hidden)   # (B, K, num_actions)
        return logits


# ------------------------------
# 4) Step mask -> token mask (each step expands to 3 tokens)
# ------------------------------
def step_mask_to_token_mask(step_mask):
    """
    step_mask:  (B, K) bool, True=valid step, False=padding
    return:     (B, 3K) bool
    """
    B, K = step_mask.shape
    return step_mask.unsqueeze(-1).expand(B, K, 3).reshape(B, K * 3)


# ------------------------------
# 5) Cross-entropy loss over action tokens (ignoring padded steps)
# ------------------------------
def compute_dt_loss(model, batch, ignore_index=-100):
    """
    batch must contain:
      rtg:            (B, K, 1) float
      states:         (B, K, C, H, W) or (B, K, state_dim)
      actions_in:     (B, K) long  — a_{t-1}, with -1 at t=0
      actions_target: (B, K) long  — labels a_t
      timesteps:      (B, K) long
      step_mask:      (B, K) bool (True=valid, False=pad)
    Returns:
      loss (scalar), logits (B, K, num_actions)
    """
    rtg       = batch['rtg']
    states    = batch['states']
    actions_in= batch['actions_in']
    targets   = batch['actions_target']
    timesteps = batch['timesteps']
    step_mask = batch.get('step_mask', torch.ones_like(actions_in, dtype=torch.bool))

    token_mask = step_mask_to_token_mask(step_mask)
    logits = model(rtg, states, actions_in, timesteps, attention_mask=token_mask)

    B, K, A = logits.shape
    logits_flat  = logits.reshape(B * K, A)
    targets_flat = targets.reshape(B * K)
    # ignore padded steps
    targets_flat_masked = targets_flat.masked_fill(~step_mask.view(-1), ignore_index)
    loss = F.cross_entropy(logits_flat, targets_flat_masked, ignore_index=ignore_index)
    return loss, logits


# ------------------------------
# 6) Returns-to-go utility
# ------------------------------
def compute_rtg(rewards, gamma=1.0):
    """
    rewards: (T,) list/tensor
    return:  (T, 1) where rtg[t] = sum_{t' >= t} gamma^{t'-t} * r[t']
    """
    if not torch.is_tensor(rewards):
        rewards = torch.tensor(rewards, dtype=torch.float32)
    T = rewards.shape[0]
    rtg = torch.zeros(T, dtype=torch.float32)
    running = 0.0
    for t in reversed(range(T)):
        running = float(rewards[t]) + gamma * running
        rtg[t] = running
    return rtg.unsqueeze(-1)  # (T, 1)


# ------------------------------
# 7) Inference: sample the next action from the last step
# ------------------------------
@torch.no_grad()
def dt_sample_action(model, rtg_seq, state_seq, action_in_seq, timestep_seq):
    """
    Single-window input (predict at the final step):
      rtg_seq:       (K, 1)
      state_seq:     (K, C, H, W) or (K, state_dim)
      action_in_seq: (K,) long; index 0 should be -1 (START)
      timestep_seq:  (K,) long
    Returns:
      int: predicted discrete action a_t at the last position
    """
    model.eval()
    rtg       = rtg_seq.unsqueeze(0)
    states    = state_seq.unsqueeze(0)
    actions   = action_in_seq.unsqueeze(0)
    timesteps = timestep_seq.unsqueeze(0)
    logits = model(rtg, states, actions, timesteps)   # (1, K, A)
    return int(torch.argmax(logits[0, -1]).item())


Online v. Offline RL: 
* Online: Learn from experience
* Offline RL: Learn from shown experience

Decision Transformer Function - Provided Primarily from the Decision Transformer Paper

In [3]:
# R, s, a, t: returns -to -go , states , actions , or timesteps
# K: context length ( length of each input to DecisionTransformer )
# transformer : transformer with causal masking (GPT)
# embed_s , embed_a , embed_R : linear embedding layers
# embed_t : learned episode positional embedding
# pred_a : linear action prediction layer
# main model
def DecisionTransformer (R , s , a , t ):
    # compute embeddings for tokens
    pos_embedding = embed_t ( t ) # per - timestep ( note : not per - token )
    s_embedding = embed_s ( s ) + pos_embedding
    a_embedding = embed_a ( a ) + pos_embedding
    R_embedding = embed_R ( R ) + pos_embedding
    # interleave tokens as (R_1 , s_1 , a_1 , ... , R_K , s_K )
    input_embeds = stack ( R_embedding , s_embedding , a_embedding )
    # use transformer to get hidden states
    hidden_states = transformer ( input_embeds = input_embeds )
    # select hidden states for action prediction tokens
    a_hidden = unstack ( hidden_states ). actions
    # predict action
    return pred_a ( a_hidden )
# training loop
for (R , s , a , t ) in dataloader : # dims : ( batch_size , K, dim )
    a_preds = DecisionTransformer (R , s , a , t )
    loss = mean (( a_preds - a )**2) # L2 loss for continuous actions
    optimizer . zero_grad (); loss . backward (); optimizer . step ()
# evaluation loop
target_return = 1 # for instance , expert - level return
R , s , a , t , done = [ target_return ] , [ env . reset ()] , [] , [1] , False
while not done : # autoregressive generation / sampling
    # sample next action
    action = DecisionTransformer (R , s , a , t )[ -1] # for cts actions
    new_s , r , done , _ = env . step ( action )
    # append new tokens to sequence
    R = R + [ R [ -1] - r] # decrement returns -to -go with reward
    s , a , t = s + [ new_s ] , a + [ action ] , t + [ len ( R )]
    R , s , a , t = R [ - K :] , ... # only keep context length of K

NameError: name 'dataloader' is not defined

Neural Network for Q Learning Atari using convolution neural network
https://docs.pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html <- Building Neural Networks
https://docs.pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html <- training classifier
https://arxiv.org/pdf/1312.5602 <- confusion matrix sizing

In [4]:
# convolution neural network to work with atari
class QLearningNetwork(nn.Module):
    def __init__(self, num_actions):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size = 8, stride = 4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size = 4, stride = 2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size = 3, stride = 1),
            nn.ReLU()
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(7 * 7 * 64, 512),
            nn.ReLU(),
            nn.Linear(512, num_actions)
        )

    def forward(self, x):
        x = x / 255.0
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)
        return self.fc_layers(x)

Temporal Difference Learning - Q learning agent

In [5]:
# TODO - Adapt to be a Q-learning agent <- Neural network
class TD_QLearningAgent(BaseAgent):
    def agent_init(self, agent_info={}):
        self.rand_generator = np.random.RandomState(agent_info.get("seed"))
        # Discount factor (gamma) to use in the updates.
        self.discount = agent_info.get("discount")
        # The learning rate or step size parameter (alpha) to use in updates.
        self.step_size = agent_info.get("step_size")

        self.num_states = agent_info.get("num_states")
        self.num_actions = agent_info.get("num_actions")

        # initialize the neural network

        # This line is drawn from PyTorch documentation
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.q_net = QLearningNetwork(self.num_actions).to(self.device)
        self.optimizer = optim.SGD(self.q_net.parameters(), lr = .001, momentum = .9)
        self.loss_fn = nn.MSELoss()

        # initialize the agent init state and agent to none
        self.state = None
        self.action = None
        
    def agent_start(self, state):
        tensor = torch.tensor(state, dtype = torch.float32, device = self.device).unsqueeze(0)
        q_values = self.q_net(tensor)
        action = torch.argmax(q_values, dim = 1).item()
        self.last_state = state
        self.last_action = action
        return action

    def agent_step(self, reward, state):
        # get the current and next state as tensor
        cur_state = torch.tensor(self.last_state, dtype = torch.float32, device = self.device).unsqueeze(0)
        next_state = torch.tensor(state, dtype = torch.float32, device = self.device).unsqueeze(0)

        q_values = self.q_net(cur_state)
        next_q = self.q_net(next_state)

        loss = self.loss_fn(q_values[0, self.last_action], (reward + self.discount * torch.max(next_q)).detach())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # e greedy next action
        action = torch.argmax(next_q, dim = 1).item()
        self.last_state = state
        self.last_action = action
        return action

    def agent_end(self, reward):
        # for agent_end compute just the last action 
        cur_state = torch.tensor(self.last_state, dtype = torch.float32, device = self.device).unsqueeze(0)
        q_values = self.q_net(cur_state)
        loss = self.loss_fn((q_values[0, self.last_action]), torch.tensor(reward, device = self.device))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def agent_cleanup(self):        
        self.last_state = None
        self.last_action = None

NameError: name 'BaseAgent' is not defined

Cliff Walk Test

In [6]:
# Create empty CliffWalkEnvironment class.
class CliffWalkEnvironment:
    def __init__(self):
        self.reward_state_term = (None, None, None)
    
    def env_init(self, env_info={}):
        reward = None
        state = None
        termination = None
        self.reward_state_term = (reward, state, termination)
        self.grid_h = env_info.get("grid_height", 4) 
        self.grid_w = env_info.get("grid_width", 12)
        self.start_loc = (self.grid_h - 1, 0)
        self.goal_loc = (self.grid_h - 1, self.grid_w - 1)
        self.cliff = [(self.grid_h - 1, i) for i in range(1, (self.grid_w - 1))]

    def state(self):
        return loc[0] * self.grid_w + loc[1]

    def env_start(self):
        """The first method called when the episode starts, called before the
        agent starts.
    
        Returns:
            The first state from the environment.
        """
        reward = 0
        # agent_loc will hold the current location of the agent
        self.agent_loc = self.start_loc
        # state is the one dimensional state representation of the agent location.
        state = self.state(self.agent_loc)
        termination = False
        self.reward_state_term = (reward, state, termination)
        return self.reward_state_term[1]

    def isInBounds(x, y, width, height):
        return 0 <= y < width and 0 <= x < height

    
    def env_step(self, action):
        """A step taken by the environment.
    
        Args:
            action: The action taken by the agent
    
        Returns:
            (float, state, Boolean): a tuple of the reward, state,
                and boolean indicating if it's terminal.
        """
        
        x, y = self.agent_loc
    
        # UP
        if action == 0:
            x = x - 1
            
        # LEFT
        elif action == 1:
            y = y - 1
            
        # DOWN
        elif action == 2:
            x = x + 1
            
        # RIGHT
        elif action == 3:
            y = y + 1
            
        else: 
            raise Exception(str(action) + " not in recognized actions [0: Up, 1: Left, 2: Down, 3: Right]!")
    
        # if the action takes the agent out-of-bounds
        # then the agent stays in the same state
        if not isInBounds(x, y, self.grid_w, self.grid_h):
            x, y = self.agent_loc
            
        # assign the new location to the environment object
        self.agent_loc = (x, y)
        
        # by default, assume -1 reward per step and that we did not terminate
        reward = -1
        terminal = False
        
        # assign the reward and terminal variables 
        # - if the agent falls off the cliff (don't forget to reset agent location!)
        # - if the agent reaches the goal state
        
        if self.agent_loc in self.cliff:
            reward = -100
            self.agent_loc = self.start_loc
    
        if (self.agent_loc == self.goal_loc):
            terminal = True
    
        # update
        self.reward_state_term = (reward, self.state(self.agent_loc), terminal)
        return self.reward_state_term

https://gymnasium.farama.org/api/spaces/ <- this is for the wrapper to the grid world

In [7]:
class DiscreteEnv:
    def __init__(self, n):
        self.n = n
    def sample(self):
        return np.random.randint(self.n)

class CliffWalk:
    def __init__(self, env):
        self.env = env
        # 4 actions
        self.action_space = DiscreteSpace(4)
        # grid height, grid width
        self.observation_space = DiscreteEnv(4 * 12)

    def step(self, action):
        result = self.env.env_step(action)


        reward = result[0]
        state = result[0]
        terminated = result[2]
        
        info = {}

        return state, reward, terminated, False, info


    def reset(self):
        state = self.env.env_start()
        info = {}
        return state, info
        

    

In [8]:
env_info = {"grid_height": 4, "grid_width": 12, "seed": 0}
agent_info = {"discount": 1, "step_size": 0.01, "seed": 0}

# The Optimal Policy that strides just along the cliff
policy = np.ones(shape=(env_info['grid_width'] * env_info['grid_height'], 4)) * 0.25
policy[36] = [1, 0, 0, 0]
for i in range(24, 35):
    policy[i] = [0, 0, 0, 1]
policy[35] = [0, 0, 1, 0]

agent_info.update({"policy": policy})

true_values_file = "optimal_policy_value_fn.npy"
_ = run_experiment(env_info, agent_info, num_episodes=5000, experiment_name="Policy Evaluation on Optimal Policy",
                   plot_freq=500, true_values_file=true_values_file)

plt.show()

NameError: name 'run_experiment' is not defined

Training Agents on Atari 

https://ale.farama.org/environments/complete_list/

In [9]:
env = gym.make("Breakout",obs_type = "rgb", frame_skip = 1, repeat_action_probability = 0, full_action_space = False)
env.reset()

# get the number of actions
num_actions = env.action_space.n

# get the actions associated with inputs
# for breakout
# 0 = Back
# 1 = launch
# 2 = left
# 3 = right
meaning = env.unwrapped.get_action_meanings()

# for testing 
obs, reward, terminated, truncated, info = env.step(0)

NameNotFound: Environment `Breakout` doesn't exist.

Compare Results

### HalfCheetah Env

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DecisionTransformerContinuous(nn.Module):
    def __init__(
        self,
        state_dim,
        act_dim,
        hidden_size,
        max_length=20,
        max_ep_len=4096,
        action_tanh=True, # HalfCheetah actions are usually bounded [-1, 1]
        **kwargs
    ):
        super().__init__()
        self.state_dim = state_dim
        self.act_dim = act_dim
        self.max_length = max_length
        self.hidden_size = hidden_size

        # 1. Embeddings (Changed from Discrete)
        # We use Linear layers because states and actions are continuous floats
        self.embed_t = nn.Embedding(max_ep_len, hidden_size)
        self.embed_s = nn.Linear(state_dim, hidden_size)
        self.embed_a = nn.Linear(act_dim, hidden_size)
        self.embed_rtg = nn.Linear(1, hidden_size)

        # 2. Transformer (GPT-2 style)
        self.embed_ln = nn.LayerNorm(hidden_size)
        
        # Note: You can use your existing CausalTransformer logic here, 
        # or standard PyTorch modules.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_size, 
                nhead=4, 
                dim_feedforward=4*hidden_size, 
                dropout=0.1, 
                batch_first=True,
                norm_first=True
            ),
            num_layers=3
        )

        # 3. Prediction Heads
        # We only predict actions for the original DT paper
        self.predict_action = nn.Sequential(
            nn.Linear(hidden_size, act_dim),
            nn.Tanh() if action_tanh else nn.Identity()
        )

    def forward(self, states, actions, returns_to_go, timesteps):
        # states: (B, K, state_dim)
        # actions: (B, K, act_dim)
        # returns_to_go: (B, K, 1)
        # timesteps: (B, K)

        batch_size, seq_length = states.shape[0], states.shape[1]

        # Embeddings
        time_embeddings = self.embed_t(timesteps)
        s_embeddings = self.embed_s(states) + time_embeddings
        a_embeddings = self.embed_a(actions) + time_embeddings
        rtg_embeddings = self.embed_rtg(returns_to_go) + time_embeddings

        # Stack inputs: [RTG, State, Action]
        # Shape becomes (B, 3 * K, hidden_size)
        stacked_inputs = torch.stack(
            (rtg_embeddings, s_embeddings, a_embeddings), dim=1
        ).permute(0, 2, 1, 3).reshape(batch_size, 3 * seq_length, self.hidden_size)
        
        stacked_inputs = self.embed_ln(stacked_inputs)

        # Causal Masking
        # We need to mask future tokens. 
        mask = torch.triu(torch.ones(3 * seq_length, 3 * seq_length), diagonal=1).to(states.device)
        mask = mask.bool()

        # Transformer Forward
        # (Using pytorch's built-in mask handling)
        x = self.transformer(stacked_inputs, mask=mask, is_causal=True)

        # Reshape to retrieve representations corresponding to state embeddings
        # The pattern is [RTG, State, Action]. To predict action, we use the embedding at 'State'
        # Indices: 1, 4, 7, ... -> x[:, 1::3]
        x = x.reshape(batch_size, seq_length, 3, self.hidden_size)
        state_reps = x[:, :, 1, :] # Take the hidden state after the State token

        action_preds = self.predict_action(state_reps)

        return action_preds

In [11]:
def get_batch(batch_size, context_len, state_dim, act_dim, device):
    # This simulates a dataloader. 
    # In a real scenario, you would sample subsequences from a D4RL dataset buffer.
    
    # Generating synthetic data for demonstration
    # (B, K, Dim)
    states = torch.randn(batch_size, context_len, state_dim).to(device)
    actions = torch.randn(batch_size, context_len, act_dim).to(device)
    rtg = torch.randn(batch_size, context_len, 1).to(device)
    timesteps = torch.randint(0, 100, (batch_size, context_len)).to(device)
    
    return states, actions, rtg, timesteps

# Setup Environment
env_name = "HalfCheetah-v4"
env = gym.make(env_name)
state_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]

print(f"Env: {env_name}, State Dim: {state_dim}, Action Dim: {act_dim}")

  logger.deprecation(


Env: HalfCheetah-v4, State Dim: 17, Action Dim: 6


In [12]:

# CONFIGURATION
BATCH_SIZE = 64
CONTEXT_LEN = 20
HIDDEN_SIZE = 128
LR = 1e-4
STEPS = 1000

# 1. Select Device (Crucial for Mac M4)
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# 2. Initialize Model
model = DecisionTransformerContinuous(
    state_dim=state_dim,
    act_dim=act_dim,
    hidden_size=HIDDEN_SIZE,
    max_length=CONTEXT_LEN
).to(device)

optimizer = optim.AdamW(model.parameters(), lr=LR)
loss_fn = nn.MSELoss()

# 3. Train
model.train()
start_time = time.time()

for step in range(STEPS):
    # Fetch data (Replace this with real D4RL data loader later)
    states, true_actions, rtg, timesteps = get_batch(BATCH_SIZE, CONTEXT_LEN, state_dim, act_dim, device)

    # Forward
    action_preds = model(states, true_actions, rtg, timesteps)

    # Loss (Compare predicted action to real action)
    loss = loss_fn(action_preds, true_actions)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 100 == 0:
        print(f"Step {step}: Loss = {loss.item():.4f}")

print(f"Training finished in {time.time()-start_time:.2f}s")

Using device: mps




Step 0: Loss = 1.2753
Step 100: Loss = 1.0472
Step 200: Loss = 1.0224
Step 300: Loss = 1.0267
Step 400: Loss = 1.0039
Step 500: Loss = 0.9943
Step 600: Loss = 1.0207
Step 700: Loss = 1.0118
Step 800: Loss = 1.0466
Step 900: Loss = 1.0409
Training finished in 29.03s


In [13]:
def evaluate_dt(env, model, target_return=1000.0):
    model.eval()
    
    # Initial setup
    state, _ = env.reset()
    
    # We need to track the history for the context window
    states = torch.from_numpy(state).reshape(1, 1, state_dim).float().to(device)
    actions = torch.zeros((1, 1, act_dim)).float().to(device) # Placeholder for first action
    rtg = torch.tensor([[[target_return]]]).float().to(device)
    timesteps = torch.tensor([[0]]).long().to(device)
    
    episode_return = 0
    done = False
    
    while not done:
        # Crop context to max_length
        if states.shape[1] > 20:
            states = states[:, -20:, :]
            actions = actions[:, -20:, :]
            rtg = rtg[:, -20:, :]
            timesteps = timesteps[:, -20:]

        # Predict Action
        with torch.no_grad():
            action_preds = model(states, actions, rtg, timesteps)
            # Take the last predicted action
            action = action_preds[0, -1].cpu().numpy()

        # Step Env
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        episode_return += reward

        # Append to sequence
        # Note: We append the action we JUST took to the sequence for the next step
        cur_action_tensor = torch.from_numpy(action).reshape(1, 1, act_dim).float().to(device)
        next_state_tensor = torch.from_numpy(next_state).reshape(1, 1, state_dim).float().to(device)
        next_rtg = rtg[0, -1, 0] - reward
        next_rtg_tensor = next_rtg.reshape(1, 1, 1).to(device)
        next_timestep_tensor = (timesteps[0, -1] + 1).reshape(1, 1).to(device)

        actions = torch.cat([actions, cur_action_tensor], dim=1)
        states = torch.cat([states, next_state_tensor], dim=1)
        rtg = torch.cat([rtg, next_rtg_tensor], dim=1)
        timesteps = torch.cat([timesteps, next_timestep_tensor], dim=1)

    return episode_return

# Test run
print("Running Evaluation...")
score = evaluate_dt(env, model, target_return=500.0)
print(f"Episode Reward: {score}")

Running Evaluation...
Episode Reward: -55.328165904981475


In [None]:
# Just a random data loader: 
from torch.utils.data import DataLoader, TensorDataset

# Example tensors (replace these with your data)
R_tensor = torch.randn(1000, 1)
s_tensor = torch.randn(1000, 4)
a_tensor = torch.randn(1000, 1)
t_tensor = torch.randn(1000, 1)

dataset = TensorDataset(R_tensor, s_tensor, a_tensor, t_tensor)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)