# Transformer ILQR Training

This notebook demonstrates the step-by-step training process of the Transformer ILQR model. The notebook:

- Loads ILQR log data from a pickle file
- Processes the data into a pandas DataFrame
- Splits the data into training and testing sets
- Trains and saves the Transformer model using the training set

After training, the loss history is available for external plotting.

In [3]:
import sys
import os
import pandas as pd
import pickle

# Import required functions from transformer_training.py
module_path = os.path.abspath(os.getcwd())
if module_path not in sys.path:
    sys.path.insert(0, module_path)
print(f"Module path: {module_path}")

from transformer_training import load_ilqr_logs, process_ilqr_logs, split_data, train_transformer

print('Imports successful.')

Module path: /Users/justin/PycharmProjects/quattro-transformer-ilqr/examples/cartpole/training
Imports successful.


## 1. Load ILQR Logs
Specify the path to your ILQR logs pickle file.

In [6]:
# Define the path to the combined ILQR logs file
log_file_path = "combined_ilqr_logs_range_-0.500_0.500_angle_-0.500_0.500.pkl"

ilqr_logs = load_ilqr_logs(log_file_path)
if ilqr_logs is None:
    raise ValueError('Failed to load ILQR logs')


Loading ILQR logs from combined_ilqr_logs_range_-0.500_0.500_angle_-0.500_0.500.pkl...
Loaded 14694 log entries.
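
For readers without `transformer_training.py` at hand, here is a minimal sketch of what a loader with this behavior could look like. The name `load_logs_sketch`, its error handling, and the placeholder entries are assumptions made for illustration; the real `load_ilqr_logs` may work differently. The sketch returns `None` on failure, which is the case the cell above checks for.

```python
import os
import pickle
import tempfile


def load_logs_sketch(path):
    """Hypothetical loader: return the unpickled object, or None on failure."""
    try:
        with open(path, "rb") as f:
            return pickle.load(f)
    except (OSError, pickle.UnpicklingError, EOFError) as exc:
        print(f"Could not load {path}: {exc}")
        return None


# Round-trip demonstration with a throwaway file (the entries are placeholders,
# not the real ILQR log schema).
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "demo_logs.pkl")
    with open(demo_path, "wb") as f:
        pickle.dump([{"entry": 0}, {"entry": 1}], f)
    demo_logs = load_logs_sketch(demo_path)
    missing_logs = load_logs_sketch(os.path.join(tmp, "does_not_exist.pkl"))

print(f"Loaded {len(demo_logs)} log entries.")
```

Unpickling can execute arbitrary code, so only load pickle files from a source you trust.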


## 2. Process ILQR Logs
Convert the raw log entries into a pandas DataFrame for further processing.

In [None]:
df = process_ilqr_logs(ilqr_logs)
print('Sample data from logs:')
print(df.head())
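
The log schema is not shown in this notebook, so the following is only a sketch of the general idea: a list of per-entry dictionaries maps directly onto DataFrame rows. The field names `state` and `control` and the values are hypothetical placeholders, and `process_ilqr_logs` may do additional work.

```python
import pandas as pd

# Placeholder entries; the field names are hypothetical, not the real log schema.
demo_logs = [
    {"state": [0.0, 0.0, 0.1, 0.0], "control": [0.5]},
    {"state": [0.1, 0.0, 0.0, 0.0], "control": [-0.2]},
]

# Each dictionary becomes one row; each key becomes one column.
demo_df = pd.DataFrame(demo_logs)
print(demo_df.shape)
```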

## 3. Split the Data
Shuffle and split the DataFrame into training and testing sets.

In [8]:
train_df, test_df = split_data(df, train_fraction=0.8, random_state=42)

print('Training set size:', len(train_df))
print('Test set size:', len(test_df))

Training set size: 11755
Test set size: 2939
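
The reported sizes are consistent with truncating 14694 × 0.8 = 11755.2 to 11755 training rows and leaving the remaining 2939 for testing. A shuffle-then-slice split that reproduces these sizes can be sketched as follows. `split_sketch` is a hypothetical stand-in, and whether `split_data` is implemented this way is an assumption.

```python
import pandas as pd


def split_sketch(df, train_fraction=0.8, random_state=42):
    """Hypothetical shuffle-then-slice split; not the actual split_data."""
    # Shuffle all rows reproducibly, then cut at the truncated train size.
    shuffled = df.sample(frac=1.0, random_state=random_state).reset_index(drop=True)
    n_train = int(len(shuffled) * train_fraction)
    return shuffled.iloc[:n_train], shuffled.iloc[n_train:]


# Stand-in DataFrame with the same number of rows as the loaded logs.
demo_df = pd.DataFrame({"row_id": range(14694)})
demo_train, demo_test = split_sketch(demo_df)
print(len(demo_train), len(demo_test))
```

Fixing `random_state` keeps the split identical across runs, so results from different training runs stay comparable.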


## 4. Train and Save the Transformer Model
Train the Transformer ILQR model using the training DataFrame. You can adjust hyperparameters as needed.

In [9]:
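# Set training hyperparameters
num_epochs = 1        # Number of passes over the training set
batch_size = 128      # Samples per gradient step
learning_rate = 1e-4  # Optimizer step size
prompt_len = 1        # Prompt length passed to train_transformer
d_model = 256         # Transformer embedding width
nhead = 8             # Attention heads (d_model must be divisible by nhead)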

model_wrapper = train_transformer(train_df,
                                  num_epochs=num_epochs,
                                  batch_size=batch_size,
                                  learning_rate=learning_rate,
                                  prompt_len=prompt_len,
                                  d_model=d_model,
                                  nhead=nhead)

print('Training completed.')

# Save the trained model
model_wrapper.save("cartpole")
print('Model saved successfully.')

Transformer Loaded. Device is: mps


Processing dataset: 100%|██████████| 11755/11755 [00:00<00:00, 50134.17it/s]


TransformerPredictor(
  (state_embed): Linear(in_features=4, out_features=256, bias=True)
  (control_embed): Linear(in_features=5, out_features=256, bias=True)
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (output_linear): Linear(in_features=256, out_features=5, bias=True)
  (transformer_decoder): TransformerEncoder(
    (layers): ModuleList(
      (0-2): 3 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=256, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=256, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, 

                                                                         

Epoch 1/1, Train Loss: 0.430058
Training completed.
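
Two quantities follow directly from the hyperparameters above, and it helps to check them before changing `d_model`, `nhead`, or `batch_size`. Multi-head attention requires `d_model` to be divisible by `nhead`. The step count below assumes the final partial batch is kept, which may not match how `train_transformer` batches the data.

```python
import math

d_model, nhead = 256, 8
batch_size, n_train = 128, 11755  # n_train is the training set size reported above

# Multi-head attention splits d_model evenly across the heads.
assert d_model % nhead == 0, "d_model must be divisible by nhead"
head_dim = d_model // nhead

# Gradient steps per epoch, assuming the last partial batch is not dropped.
steps_per_epoch = math.ceil(n_train / batch_size)

print(f"head_dim={head_dim}, steps_per_epoch={steps_per_epoch}")
```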




## 5. Loss History
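The training loss history is stored in `model_wrapper.train_loss_history`; with `num_epochs = 1` it holds a single value.
A test loss history is printed only if the wrapper exposes `test_loss_history`. You can plot these values externally.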

In [11]:
print('Train loss history:', model_wrapper.train_loss_history)
if hasattr(model_wrapper, 'test_loss_history'):
    print('Test loss history:', model_wrapper.test_loss_history)

Train loss history: [0.4300576684754841]
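
One way to plot the history externally is sketched below with matplotlib. The list is copied from the printed output above rather than read from `model_wrapper`, and the output filename `train_loss.png` is a placeholder. With a single epoch the curve is a single point, so a marker is used to keep it visible.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend, so no display is needed
import matplotlib.pyplot as plt

# Value copied from the printed train loss history above.
train_loss_history = [0.4300576684754841]

epochs = range(1, len(train_loss_history) + 1)
fig, ax = plt.subplots()
ax.plot(epochs, train_loss_history, marker="o", label="train")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_xticks(list(epochs))
ax.legend()
fig.savefig("train_loss.png")  # placeholder output path
```

In the notebook itself, `train_loss_history` can be replaced with `model_wrapper.train_loss_history` after training with more epochs.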
