In [1]:
import os
import random
from dataclasses import dataclass

import numpy as np
import torch
from datasets import load_dataset, load_from_disk
from transformers import DecisionTransformerConfig, DecisionTransformerModel, Trainer, TrainingArguments


In [None]:
# from accelerate import Accelerator

# accelerator = Accelerator()
# device = accelerator.device


In [2]:
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device


In [2]:
# Show the active CUDA device index (use torch.cuda.set_device(i) beforehand to pin a GPU).
# Guarded so this cell does not raise on CPU-only machines.
torch.cuda.current_device() if torch.cuda.is_available() else "cpu"


0

### Step 1: Load the input data

In [3]:
# Load the saved dataset; later cells read its 'state_mean', 'state_std' and 'train' entries.
dataset = load_from_disk("data/dataset/")


In [4]:
# Read the state normalization statistics stored next to the 'train' split.
# Must run before the next cell rebinds `dataset` to dataset['train'].
state_mean = dataset['state_mean']
state_std = dataset['state_std']


In [5]:
# Rebind `dataset` to the training split. Not idempotent: re-run the load cell before re-running this one.
dataset = dataset['train']


In [6]:
# (number of trajectories, number of fields per trajectory) -- len(dataset[0]) counts keys, not timesteps.
len(dataset), len(dataset[0])


(985, 4)

In [7]:
# Fields stored for each trajectory.
dataset[0].keys()


dict_keys(['actions', 'dones', 'observations', 'rewards'])

In [8]:
# Action vector size, read from the first timestep of the first trajectory.
act_dim = len(dataset[0]["actions"][0])
act_dim


29

In [9]:
# Observation vector size, read from the first timestep of the first trajectory.
state_dim = len(dataset[0]["observations"][0])
state_dim


291

In [10]:
# Keep only the first state_dim entries so the statistics match the observation size.
# NOTE(review): assumes the stored stats are at least state_dim long and that extra entries are unused -- confirm.
state_mean = state_mean[:state_dim]
state_std = state_std[:state_dim]


In [11]:
# Number of timesteps in the first trajectory.
len(dataset[0]['observations'])


100

In [15]:
# First two action vectors of the first trajectory.
dataset[0]['actions'][:2]


[[-0.06046953797340393,
  0.1278674453496933,
  -0.0035203518345952034,
  -0.10394623875617981,
  0.055450163781642914,
  0.13418617844581604,
  -0.13469792902469635,
  -0.11420702934265137,
  0.07225970178842545,
  -0.11960054934024811,
  -0.08815392851829529,
  -0.15141969919204712,
  0.13859893381595612,
  0.08930575102567673,
  0.03528134524822235,
  -0.10486803948879242,
  -0.054753512144088745,
  0.0004257550463080406,
  -0.34499096870422363,
  0.03908330947160721,
  0.14673146605491638,
  -0.13041988015174866,
  0.1698758453130722,
  -0.03221307322382927,
  0.021925274282693863,
  -0.04682639241218567,
  0.06890705972909927,
  -0.03945721685886383,
  -0.26306822896003723],
 [-0.06046953797340393,
  0.1278674453496933,
  -0.0035203518345952034,
  -0.10394623875617981,
  0.055450163781642914,
  0.13418617844581604,
  -0.13469792902469635,
  -0.11420702934265137,
  0.07225970178842545,
  -0.11960054934024811,
  -0.08815392851829529,
  -0.15141969919204712,
  0.13859893381595612,
  

### Step 2: Define a custom DataCollator for the transformers Trainer class

In [12]:
@dataclass
class DecisionTransformerGymDataCollator:
    """Build Decision Transformer training batches from whole-trajectory records.

    ``__call__`` ignores the content of ``features`` and uses only its length
    as the batch size. For each batch element it samples a trajectory from
    ``dataset`` with probability proportional to the trajectory's number of
    timesteps. It then takes a window of at most ``max_len`` timesteps from a
    uniformly random start. Windows are left-padded to ``max_len``, states are
    normalized with ``(s - state_mean) / state_std`` (padding included), and
    returns-to-go are divided by ``scale``. Sampling uses the global ``random``
    and ``np.random`` generators, so seed them for reproducible batches.

    Args:
        dataset: indexable sequence of trajectories. Each item provides
            "observations", "actions", "rewards" and "dones", with one entry
            per timestep.
        state_mean: per-dimension state mean. It must broadcast against an
            array of shape (1, max_len, state_dim).
        state_std: per-dimension state std, with the same broadcasting
            requirement.

    ``__call__`` returns a dict with these entries:
        - float tensor "states" of shape (batch, max_len, state_dim);
        - float tensor "actions" of shape (batch, max_len, act_dim);
        - float tensors "rewards" and "returns_to_go" of shape (batch, max_len, 1);
        - float tensor "attention_mask" of shape (batch, max_len);
        - long tensor "timesteps" of shape (batch, max_len).
    """

    return_tensors: str = "pt"
    max_len: int = 20  # length of the trajectory window sampled for each batch element
    state_dim: int = 291  # size of state space (recomputed from the dataset in __init__)
    act_dim: int = 29  # size of action space (recomputed from the dataset in __init__)
    # Timesteps are clipped to max_ep_len - 1. NOTE(review): 985 equals len(dataset)
    # (the number of trajectories), not an episode length; it only needs to stay
    # below the model config's max_ep_len.
    max_ep_len: int = 985
    scale: float = 1000.0  # returns-to-go are divided by this
    p_sample: np.ndarray = None  # probability of sampling each trajectory (proportional to its length)
    n_traj: int = 0  # number of trajectories in the dataset

    def __init__(self, dataset, state_mean, state_std) -> None:
        self.act_dim = len(dataset[0]['actions'][0])
        self.state_dim = len(dataset[0]['observations'][0])
        self.dataset = dataset
        self.state_mean = state_mean
        self.state_std = state_std
        self.n_traj = len(self.dataset)

        # Weight each trajectory by its number of timesteps. Combined with a uniform
        # window start in __call__, this makes every timestep equally likely to start
        # a window. Use the per-step "rewards" length: len(self.dataset[i]) would
        # count the record's fields, not its timesteps.
        traj_lens = np.array(
            [len(self.dataset[i]["rewards"]) for i in range(self.n_traj)],
            dtype=np.float64,
        )
        self.p_sample = traj_lens / traj_lens.sum()

    def _discount_cumsum(self, x, gamma):
        """Return y with y[t] = sum_k gamma**k * x[t + k] (discounted suffix sums of x)."""
        discount_cumsum = np.zeros_like(x)
        discount_cumsum[-1] = x[-1]
        for t in reversed(range(x.shape[0] - 1)):
            discount_cumsum[t] = x[t] + gamma * discount_cumsum[t + 1]
        return discount_cumsum

    def __call__(self, features):
        batch_size = len(features)
        # Sample trajectories with replacement, weighted by length. `features` only
        # tells us how many samples the Trainer wants.
        batch_inds = np.random.choice(
            np.arange(self.n_traj),
            size=batch_size,
            replace=True,
            p=self.p_sample,
        )
        s, a, r, d, rtg, timesteps, mask = [], [], [], [], [], [], []

        for ind in batch_inds:
            feature = self.dataset[int(ind)]
            # Uniform window start; windows near the end are shorter than max_len.
            si = random.randint(0, len(feature["rewards"]) - 1)

            # get sequences from dataset
            s.append(np.array(feature["observations"][si : si + self.max_len]).reshape(1, -1, self.state_dim))
            a.append(np.array(feature["actions"][si : si + self.max_len]).reshape(1, -1, self.act_dim))
            r.append(np.array(feature["rewards"][si : si + self.max_len]).reshape(1, -1, 1))
            # dones are padded below but are not part of the returned batch.
            d.append(np.array(feature["dones"][si : si + self.max_len]).reshape(1, -1))
            timesteps.append(np.arange(si, si + s[-1].shape[1]).reshape(1, -1))
            timesteps[-1][timesteps[-1] >= self.max_ep_len] = self.max_ep_len - 1  # padding cutoff
            # Undiscounted return-to-go from each step of the window, trimmed to the
            # window length so it has one entry per state.
            rtg.append(
                self._discount_cumsum(np.array(feature["rewards"][si:]), gamma=1.0)[
                    : s[-1].shape[1]
                ].reshape(1, -1, 1)
            )
            if rtg[-1].shape[1] < s[-1].shape[1]:
                # Only reachable if "rewards" is shorter than "observations".
                rtg[-1] = np.concatenate([rtg[-1], np.zeros((1, 1, 1))], axis=1)

            # Left-pad every sequence to max_len and normalize states.
            tlen = s[-1].shape[1]
            pad = self.max_len - tlen
            s[-1] = np.concatenate([np.zeros((1, pad, self.state_dim)), s[-1]], axis=1)
            s[-1] = (s[-1] - self.state_mean) / self.state_std
            a[-1] = np.concatenate(
                [np.ones((1, pad, self.act_dim)) * -10.0, a[-1]],
                axis=1,
            )
            r[-1] = np.concatenate([np.zeros((1, pad, 1)), r[-1]], axis=1)
            d[-1] = np.concatenate([np.ones((1, pad)) * 2, d[-1]], axis=1)
            rtg[-1] = np.concatenate([np.zeros((1, pad, 1)), rtg[-1]], axis=1) / self.scale
            timesteps[-1] = np.concatenate([np.zeros((1, pad)), timesteps[-1]], axis=1)
            # 0 for padding, 1 for real timesteps.
            mask.append(np.concatenate([np.zeros((1, pad)), np.ones((1, tlen))], axis=1))

        s = torch.from_numpy(np.concatenate(s, axis=0)).float()
        a = torch.from_numpy(np.concatenate(a, axis=0)).float()
        r = torch.from_numpy(np.concatenate(r, axis=0)).float()
        d = torch.from_numpy(np.concatenate(d, axis=0))
        rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).float()
        timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).long()
        mask = torch.from_numpy(np.concatenate(mask, axis=0)).float()

        return {
            "states": s,
            "actions": a,
            "rewards": r,
            "returns_to_go": rtg,
            "timesteps": timesteps,
            "attention_mask": mask,
        }


In [13]:
# The collator reads state_dim/act_dim from the data and precomputes the trajectory sampling weights.
collator = DecisionTransformerGymDataCollator(dataset, state_mean, state_std)


In [24]:
# Check collator outputs

# ret = collator.__call__(dataset)
# ret['states'][0][:2], len(ret['states'][0][0])


### Step 3: Extend Decision Transformer Model to include a loss function

In order to train the model with the 🤗 Trainer class, we first need to ensure the dictionary it returns contains a loss. In this case the loss is the mean squared error (squared L2 distance) between the model's action predictions and the targets, computed only over non-padded timesteps.

We achieve this by making a TrainableDT class, which inherits from the Decision Transformer model.

In [14]:
import random
from dataclasses import dataclass

import numpy as np
import torch

from transformers import DecisionTransformerConfig, DecisionTransformerModel, Trainer, TrainingArguments


In [15]:
class TrainableDT(DecisionTransformerModel):
    """DecisionTransformerModel whose forward pass returns a training loss.

    The Trainer needs a ``"loss"`` entry in the model output. ``forward``
    returns the mean squared error between predicted and target actions,
    counted only at positions where ``attention_mask`` is positive.
    ``original_forward`` exposes the unmodified parent forward pass.
    """

    def __init__(self, config):
        super().__init__(config)

    def forward(self, **kwargs):
        # Element 1 of the parent's output holds the action predictions.
        predicted = super().forward(**kwargs)[1]
        targets = kwargs["actions"]
        valid = kwargs["attention_mask"].reshape(-1) > 0
        n_actions = predicted.shape[2]

        # Flatten the leading dims to one row per timestep, then drop masked-out rows.
        pred_rows = predicted.reshape(-1, n_actions)[valid]
        target_rows = targets.reshape(-1, n_actions)[valid]

        squared_err = (pred_rows - target_rows) ** 2
        return {"loss": torch.mean(squared_err)}

    def original_forward(self, **kwargs):
        """Run the parent DecisionTransformerModel forward pass unchanged."""
        return super().forward(**kwargs)


In [16]:
# Model dimensions come from the collator, which reads them from the dataset.
# NOTE(review): max_ep_len is left at the DecisionTransformerConfig default; the collator clips
# timesteps to collator.max_ep_len - 1, which must stay below the config's value -- confirm.
config = DecisionTransformerConfig(state_dim=collator.state_dim, act_dim=collator.act_dim)
model = TrainableDT(config)


### Step 4: Define training hyperparameters and train the model
Here, we define the training hyperparameters and the Trainer that we'll use to train our Decision Transformer model. The transformers Trainer class requires a number of arguments, defined in the TrainingArguments class. We use the same hyperparameters as in the authors' original implementation, but train for fewer iterations.

Training time depends on your hardware and dataset size (the run below finished in under a minute), so on larger datasets you may need to leave it running. Note that the authors train for at least 3 hours, so the results presented here are not as performant as the models hosted on the 🤗 Hub.

In [17]:
# Disable Weights & Biases logging. This env var is deprecated (transformers warns about it);
# the supported alternative is report_to="none" in TrainingArguments.
os.environ["WANDB_DISABLED"] = "true" # we disable weights and biases logging for this tutorial


In [18]:
# Hyperparameters for the Trainer. The collator ignores the dataset rows it is given, so an
# "epoch" is just ceil(len(dataset) / batch_size) batches of freshly sampled windows.
training_args = TrainingArguments(
    output_dir="output/",
    remove_unused_columns=False,  # keep dataset columns; the custom collator handles the data itself
    num_train_epochs=120,
    per_device_train_batch_size=64,
    learning_rate=1e-4,
    weight_decay=1e-4,
    warmup_ratio=0.1,
    optim="adamw_torch",
    max_grad_norm=0.25,
    report_to="none",  # disable W&B and other logging integrations (replaces the deprecated WANDB_DISABLED)
)



Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


In [19]:
# https://discuss.huggingface.co/t/setting-specific-device-for-trainer/784/3
# Device the Trainer will train on.
training_args.device


device(type='cuda', index=0)

In [20]:
# train_dataset only determines how many samples the Trainer requests per epoch: the collator
# uses len(features) as the batch size and samples its own trajectory windows.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collator,
)

trainer.train()


  0%|          | 0/1920 [00:00<?, ?it/s]

{'loss': 0.0095, 'grad_norm': 0.011781188659369946, 'learning_rate': 8.217592592592593e-05, 'epoch': 31.25}
{'loss': 0.001, 'grad_norm': 0.012242179363965988, 'learning_rate': 5.3240740740740744e-05, 'epoch': 62.5}
{'loss': 0.0008, 'grad_norm': 0.006828837562352419, 'learning_rate': 2.4305555555555558e-05, 'epoch': 93.75}
{'train_runtime': 52.8741, 'train_samples_per_second': 2235.499, 'train_steps_per_second': 36.313, 'train_loss': 0.0031153700469682614, 'epoch': 120.0}


TrainOutput(global_step=1920, training_loss=0.0031153700469682614, metrics={'train_runtime': 52.8741, 'train_samples_per_second': 2235.499, 'train_steps_per_second': 36.313, 'total_flos': 2799337488696000.0, 'train_loss': 0.0031153700469682614, 'epoch': 120.0})

In [21]:
# Save the trained model to the 'trained_models' directory.
trainer.save_model('trained_models')
