In [2]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from tensordict import MemoryMappedTensor, TensorDict

In [57]:
# Scratch check: project random (3, 5, feature) batches to 2-d embeddings after
# flattening batch and time into a single (1, 15, feature) sequence per modality.
state_proj = nn.Linear(10, 2)
state_emb = state_proj(torch.randn(3, 5, 10).reshape(1, -1, 10))

return_proj = nn.Linear(1, 2)
return_emb = return_proj(torch.randn(3, 5, 1).reshape(1, -1, 1))

action_proj = nn.Linear(1, 2)
action_emb = action_proj(torch.randn(3, 5, 1).reshape(1, -1, 1))

print(state_emb, return_emb, action_emb)

tensor([[[-0.0378, -0.2457],
         [ 0.9079,  0.9024],
         [ 0.5429, -1.0382],
         [ 0.8087, -0.1141],
         [-0.4617, -1.7037],
         [-0.2021, -0.6088],
         [-0.6994, -1.3875],
         [-0.0610,  0.0864],
         [ 0.6182, -1.9175],
         [-0.5597, -0.6317],
         [ 0.7547, -1.0077],
         [ 0.4440, -0.0141],
         [ 1.0847, -0.1549],
         [-0.6506, -1.3555],
         [-0.2390,  0.4642]]], grad_fn=<ViewBackward0>) tensor([[[-0.0869,  0.9546],
         [ 1.4300,  0.3155],
         [ 0.8465,  0.5613],
         [ 0.2271,  0.8223],
         [ 0.4929,  0.7103],
         [ 0.1141,  0.8699],
         [ 0.5938,  0.6678],
         [-0.4756,  1.1184],
         [ 0.1470,  0.8561],
         [-1.1627,  1.4079],
         [ 2.0324,  0.0617],
         [-0.6266,  1.1820],
         [ 1.4739,  0.2970],
         [ 2.1132,  0.0276],
         [ 2.2273, -0.0205]]], grad_fn=<ViewBackward0>) tensor([[[-0.4875, -0.8032],
         [-0.5098, -0.4745],
         [-0.5289,

In [58]:
# Interleave the three modalities per timestep, then regroup the flattened
# 15 timesteps into 3 sequences of 5 timesteps x 3 tokens each.
stacked = torch.stack((state_emb, return_emb, action_emb), dim=1)  # (1, 3, 15, 2)
per_step = stacked.permute(0, 2, 1, 3)                             # (1, 15, 3, 2)
res = per_step.reshape(3, 3 * 5, 2)
res

tensor([[[-0.0378, -0.2457],
         [-0.0869,  0.9546],
         [-0.4875, -0.8032],
         [ 0.9079,  0.9024],
         [ 1.4300,  0.3155],
         [-0.5098, -0.4745],
         [ 0.5429, -1.0382],
         [ 0.8465,  0.5613],
         [-0.5289, -0.1935],
         [ 0.8087, -0.1141],
         [ 0.2271,  0.8223],
         [-0.5602,  0.2681],
         [-0.4617, -1.7037],
         [ 0.4929,  0.7103],
         [-0.5155, -0.3902]],

        [[-0.2021, -0.6088],
         [ 0.1141,  0.8699],
         [-0.5927,  0.7480],
         [-0.6994, -1.3875],
         [ 0.5938,  0.6678],
         [-0.5058, -0.5332],
         [-0.0610,  0.0864],
         [-0.4756,  1.1184],
         [-0.5554,  0.1975],
         [ 0.6182, -1.9175],
         [ 0.1470,  0.8561],
         [-0.5284, -0.2002],
         [-0.5597, -0.6317],
         [-1.1627,  1.4079],
         [-0.5834,  0.6108]],

        [[ 0.7547, -1.0077],
         [ 2.0324,  0.0617],
         [-0.4558, -1.2700],
         [ 0.4440, -0.0141],
         [

In [98]:
# Shape of one interleaved sequence: (3 tokens * 5 timesteps, embedding dim).
res[0].size()

torch.Size([15, 2])

In [99]:
# NOTE(review): `mask` is not defined in any cell visible in this notebook, so this
# presumably relied on leftover kernel state and would raise NameError on a fresh
# kernel — define it explicitly or remove this cell.
mask[0].shape

torch.Size([15])

In [55]:
# LayerNorm over the last (embedding) dimension leaves the tensor shape unchanged.
torch.nn.LayerNorm(normalized_shape=2)(res).size()

torch.Size([3, 15, 2])

In [None]:
class DecisionTransformer(nn.Module):
    def __init__(self,
                 state_dim=51,
                 action_dim=1,
                 max_context_length=48,
                 max_ep_length=48,                 
                 model_dim=128
                 ):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_context_length = max_context_length
        self.max_ep_length = max_ep_length
        self.model_dim = model_dim


        self.embed_timestep = nn.Embedding(self.max_ep_length, self.model_dim)
        self.embed_return = torch.nn.Linear(1, self.model_dim)
        self.embed_state = torch.nn.Linear(self.state_dim, self.model_dim)
        self.embed_action = torch.nn.Linear(self.action_dim, self.model_dim)
        self.embed_ln = nn.LayerNorm(self.model_dim)

        self.predict_action = nn.Sequential(
            nn.Linear(self.model_dim, self.action_dim),
            nn.Tanh()
        )
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=self.model_dim,
            nhead=8,
            batch_first=True,
        )
        self.transformer = nn.TransformerDecoder(decoder_layer=decoder_layer, num_layers=6)

    def forward(self, states, actions, returns_to_go, timesteps, padding_mask=None):
        batch_size, seq_length = states.shape[0], states.shape[1]

        if padding_mask is None:
            padding_mask = torch.ones((batch_size, seq_length), dtype=torch.float32)
        
        state_embeddings = self.embed_state(states)
        action_embeddings = self.embed_action(actions)
        returns_embeddings = self.embed_return(returns_to_go)
        time_embeddings = self.embed_timestep(timesteps)

        # time embeddings are treated similar to positional embeddings
        state_embeddings = state_embeddings + time_embeddings
        action_embeddings = action_embeddings + time_embeddings
        returns_embeddings = returns_embeddings + time_embeddings

        stacked_inputs = torch.stack(
            (returns_embeddings, state_embeddings, action_embeddings), dim=1
        ).permute(0, 2, 1, 3).reshape(batch_size, 3*seq_length, self.model_dim)
        stacked_inputs = self.embed_ln(stacked_inputs)

        dummy_memory = torch.zeros(1, seq_length, self.model_dim)

        causal_mask = torch.triu(torch.full((3*seq_length, 3*seq_length), float('-inf')), diagonal=1)

        stacked_padding_mask = torch.stack((padding_mask,padding_mask,padding_mask), dim=1).permute(0,2,1).reshape(batch_size,3*seq_length)

        x = self.transformer(tgt=stacked_inputs,
                             memory=dummy_memory, 
                             tgt_mask=causal_mask,
                             tgt_key_padding_mask=stacked_padding_mask)
        x = x.reshape(batch_size, seq_length, 3, self.model_dim).permute(0, 2, 1, 3)

        return self.predict_action(x[:,1])

    def get_action(self, states, actions, rtg, timesteps):

        # Add batch dimension and reshape to [1, seq_len, state_dim] so input matches Transformer input format
        states = states.reshape(1, -1, self.state_dim)
        actions = actions.reshape(1, -1, self.action_dim)
        rtg = rtg.reshape(1, -1, 1)
        timesteps = timesteps.reshape(1, -1)

        if self.max_context_length is not None:
            states = states[:,-self.max_context_length:]
            actions = actions[:,-self.max_context_length:]
            rtg = rtg[:,-self.max_context_length:]
            timesteps = timesteps[:,-self.max_context_length:]

            # pad all tokens to sequence length
            padding_mask = torch.cat([torch.zeros(self.max_context_length-states.shape[1], dtype=torch.float32), 
                                      torch.ones(states.shape[1],dtype=torch.long)]).reshape(1,-1)
            states = torch.cat([torch.zeros((states.shape[0], self.max_context_length-states.shape[1], self.state_dim)), 
                                states], dim=1)
            actions = torch.cat([torch.zeros((actions.shape[0], self.max_context_length - actions.shape[1], self.action_dim)),
                                 actions], dim=1)
            rtg = torch.cat([torch.zeros((rtg.shape[0], self.max_context_length-rtg.shape[1], 1)), 
                             rtg], dim=1)
            timesteps = torch.cat([torch.zeros((timesteps.shape[0], self.max_context_length-timesteps.shape[1]), dtype=torch.long),
                                   timesteps], dim=1)
        else:
            padding_mask = None

        action_preds = self.forward(states, actions, rtg, timesteps, padding_mask)
        
        return action_preds[0,-1]

In [12]:
# Random single-episode batch for a smoke test: 48 steps, 51 state features,
# scalar action and scalar return-to-go per step.
states = torch.randn((1, 48, 51))
actions = torch.randn((1, 48, 1))
returns_to_go = torch.randn((1, 48, 1))
timesteps = torch.arange(48)[None, :]

In [55]:
# Smoke test: an untrained model should return a single action vector.
DecisionTransformer().get_action(
    states=states,
    actions=actions,
    rtg=returns_to_go,
    timesteps=timesteps,
)

tensor([0.4032], grad_fn=<SelectBackward0>)

# Prepare Train Data

In [None]:
# NOTE: weights_only=False unpickles arbitrary Python objects; only load files
# from a trusted source.
train_td_array = torch.load(
    f='../outputs/battery_optimization_solution_1.pt',
    weights_only=False,
)

In [None]:
for elem in train_td_array:
    # Return-to-go at step t is the sum of rewards from step t to the end of
    # the episode: a reverse cumulative sum along dim 0 (assumed to be time).
    rewards = elem['next', 'reward']
    rtg = rewards.flip(0).cumsum(0).flip(0)
    # Append a terminal zero (same dtype/device as the rewards) so the shifted
    # "next" view has a value after the last step.
    rtg = torch.cat((rtg, rtg.new_zeros((1, *rtg.shape[1:]))), dim=0)
    # Slice relative to the episode length instead of hard-coding 48 steps.
    elem['return_to_go'] = rtg[:-1]
    elem['next', 'return_to_go'] = rtg[1:]