In [9]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# Load and preprocess the original data
message_org = pd.read_csv(
    filepath_or_buffer='LOBSTER_SampleFile_AAPL_2012-06-21_5/AAPL_2012-06-21_34200000_57600000_message_5.csv',
    names=['Time', 'Type', 'Order ID', 'Size', 'Price', 'Direction']
    ).drop(['Time','Order ID'], axis=1)

orderbook_org = pd.read_csv(
    filepath_or_buffer='LOBSTER_SampleFile_AAPL_2012-06-21_5/AAPL_2012-06-21_34200000_57600000_orderbook_5.csv',
    names=['A1P', 'A1V', 'B1P', 'B1V', 'A2P', 'A2V', 'B2P', 'B2V', 'A3P', 'A3V', 'B3P', 'B3V', 'A4P', 'A4V', 'B4P', 'B4V',
           'A5P', 'A5V', 'B5P', 'B5V']
           ).loc[:, ['A2P', 'A2V', 'A1P', 'A1V', 'B1P', 'B1V', 'B2P', 'B2V']]

orderbook_org.drop_duplicates(inplace=True)
message_org = message_org.loc[orderbook_org.index]
orderbook_org.reset_index(inplace=True, drop=True)
message_org.reset_index(inplace=True, drop=True)

In [10]:
# Initialize the LOB state, agent's cash, and inventory
initial_cash = 1e6
orderbook_org['Cash'] = initial_cash  # Initial cash in hand
initial_inv = 1e4
orderbook_org['Inventory'] = initial_inv   # Initial inventory

execute_index = message_org['Type'] == 4
cash_change = (- (message_org[execute_index]['Direction'] * 
                  message_org[execute_index]['Size'] * 
                  message_org[execute_index]['Price'] / 1e4)).cumsum()
inventory_change = (message_org[execute_index]['Direction'] * 
                    message_org[execute_index]['Size']).cumsum()

orderbook_org['Cash'] += cash_change 
orderbook_org['Inventory'] += inventory_change

orderbook_org.loc[0,'Cash'] = initial_cash
orderbook_org.loc[0,'Inventory'] = initial_inv
orderbook_org.ffill(inplace=True)

orderbook_org = pd.concat([message_org, orderbook_org],axis=1)

del initial_cash, initial_inv, execute_index, cash_change, inventory_change

In [11]:
# Normalize orderbook data and transform to tensor
orderbook_norm = orderbook_org.copy()
for column in ['Size', 'Price', 'A2P', 'A2V', 'A1P', 'A1V', 'B1P', 'B1V', 'B2P', 'B2V', 'Cash', 'Inventory']:
    orderbook_norm[column] = (orderbook_norm[column] - orderbook_org[column].mean()) / orderbook_org[column].std()

lob_data = torch.tensor(
    orderbook_norm.to_numpy(),
    dtype=torch.float32
)

# Placeholder for action data preparation
# actions_data = torch.tensor(
#     message_org.to_numpy(),
#     dtype=torch.float32
# )


In [12]:
# Define the Autoencoder architecture
class LOBEncoder(nn.Module):
    def __init__(self, input_dim, state_dim):
        super(LOBEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, state_dim)
        )
    
    def forward(self, x):
        return self.encoder(x)

class LOBDecoder(nn.Module):
    def __init__(self, state_dim, output_dim):  # action_dim,
        super(LOBDecoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.Linear(state_dim, 64),  # + action_dim
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
        )
    
    def forward(self, s):  # ,u
        # s = torch.cat((s, u), dim=1)
        return self.decoder(s)

class LOBPredictor(nn.Module):
    def __init__(self, input_dim, state_dim, output_dim):  # action_dim,
        super(LOBPredictor, self).__init__()
        self.encoder = LOBEncoder(input_dim, state_dim)
        self.decoder = LOBDecoder(state_dim, output_dim)  # ,action_dim
    
    def forward(self, o):  # ,u
        s = self.encoder(o)
        o_next_pred = self.decoder(s)  # ,u
        return o_next_pred

# Example dimensions
input_dim = lob_data.shape[1]  # Observation dimension
state_dim = 8  # Agent state dimension
# action_dim = 4  # Action dimension (Type, Size, Price, Direction)
output_dim = lob_data.shape[1]  # Next observation dimension (same as input_dim)

predictor = LOBPredictor(input_dim, state_dim, output_dim)  # ,action_dim

# Hyperparameters
learning_rate = 1e-3
num_epochs = 200
batch_size = 1024

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(predictor.parameters(), lr=learning_rate)


In [13]:
# Prepare input (observations), actions, and target (next observations)
observations = lob_data[:-1, :]  # All except the last time step
# actions = actions_data[:-1, :]  # Corresponding actions for each observation
next_observations = lob_data[1:, :]  # All except the first time step

# DataLoader for batching
dataset = torch.utils.data.TensorDataset(observations, next_observations)  # ,actions
train_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)

# Training loop
for epoch in range(num_epochs):
    for obs, next_obs in train_loader:  # ,act
        # Forward pass
        predicted_next_obs = predictor(obs)  # ,act
        loss = criterion(predicted_next_obs, next_obs)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')


Epoch [1/200], Loss: 0.6594
Epoch [2/200], Loss: 0.6321
Epoch [3/200], Loss: 0.5388
Epoch [4/200], Loss: 0.5100
Epoch [5/200], Loss: 0.4319
Epoch [6/200], Loss: 0.4473
Epoch [7/200], Loss: 0.4192
Epoch [8/200], Loss: 0.4141
Epoch [9/200], Loss: 0.4251
Epoch [10/200], Loss: 0.4036
Epoch [11/200], Loss: 0.4153
Epoch [12/200], Loss: 0.3970
Epoch [13/200], Loss: 0.4298
Epoch [14/200], Loss: 0.3868
Epoch [15/200], Loss: 0.3722
Epoch [16/200], Loss: 0.3717
Epoch [17/200], Loss: 0.3766
Epoch [18/200], Loss: 0.3833
Epoch [19/200], Loss: 0.3892
Epoch [20/200], Loss: 0.3732
Epoch [21/200], Loss: 0.3686
Epoch [22/200], Loss: 0.3636
Epoch [23/200], Loss: 0.3628
Epoch [24/200], Loss: 0.3626
Epoch [25/200], Loss: 0.3662
Epoch [26/200], Loss: 0.3647
Epoch [27/200], Loss: 0.4014
Epoch [28/200], Loss: 0.3701
Epoch [29/200], Loss: 0.3788
Epoch [30/200], Loss: 0.3589
Epoch [31/200], Loss: 0.3954
Epoch [32/200], Loss: 0.3606
Epoch [33/200], Loss: 0.3741
Epoch [34/200], Loss: 0.3590
Epoch [35/200], Loss: 0

In [15]:
# Evaluation and Visualization
# Get a batch of data
data_iter = iter(train_loader)
obs_data, next_obs_data = next(data_iter)  # ,act_data

# Predict the next observation using the predictor
with torch.no_grad():
    predicted_next_obs_data = predictor(obs_data)  # ,act_data

# Denormalize the data for visualization
next_obs_data = pd.DataFrame(next_obs_data.numpy(), columns=list(orderbook_org.columns))
predicted_next_obs_data = pd.DataFrame(predicted_next_obs_data.numpy(), columns=list(orderbook_org.columns))

for column in ['Size', 'Price', 'A2P', 'A2V', 'A1P', 'A1V', 'B1P', 'B1V', 'B2P', 'B2V', 'Cash', 'Inventory']:
    next_obs_data[column] = next_obs_data[column] * orderbook_org[column].std() + orderbook_org[column].mean()
    predicted_next_obs_data[column] = predicted_next_obs_data[column] * orderbook_org[column].std() + orderbook_org[column].mean()

# Visualization (commented out for non-interactive environments)
# import matplotlib.pyplot as plt
# # Plot next observation (target) vs predicted next observation
# fig, ax = plt.subplots(2, 1, figsize=(10, 5))
# # Plot next observation (target)
# ax[0].plot(next_obs_data.values.flatten(), label='Next Observation (Target)')
# ax[0].set_title('Next Observation (Target)')
# # Plot predicted next observation
# ax[1].plot(predicted_next_obs_data.values.flatten(), label='Predicted Next Observation')
# ax[1].set_title('Predicted Next Observation')
# plt.tight_layout()
# plt.show()

In [16]:
next_obs_data.head(10)

Unnamed: 0,Type,Size,Price,Direction,A2P,A2V,A1P,A1V,B1P,B1V,B2P,B2V,Cash,Inventory
0,1.0,18.0,5853200.0,1.0,5859800.0,200.0,5859400.0,200.0,5853300.0,18.0,5853200.0,18.000015,999999.875,10000.0
1,1.0,18.0,5859100.0,-1.0,5859400.0,200.0,5859100.0,18.0,5853300.0,18.0,5853200.0,18.000015,999999.875,10000.0
2,1.0,18.0,5859200.0,-1.0,5859200.0,18.000008,5859100.0,18.0,5853300.0,18.0,5853200.0,18.000015,999999.875,10000.0
3,3.0,18.0,5853200.0,1.0,5859200.0,18.000008,5859100.0,18.0,5853300.0,18.0,5853000.0,150.0,999999.875,10000.0
4,3.0,18.0,5859100.0,-1.0,5859300.0,118.000008,5859200.0,18.0,5853300.0,18.0,5853000.0,150.0,999999.875,10000.0
5,3.0,18.0,5859300.0,-1.0,5859300.0,100.000015,5859200.0,18.0,5853300.0,18.0,5853000.0,150.0,999999.875,10000.0
6,3.0,18.0,5859200.0,-1.0,5859400.0,200.0,5859300.0,100.0,5853300.0,18.0,5853000.0,150.0,999999.875,10000.0
7,1.0,18.0,5853600.0,1.0,5859400.0,200.0,5859300.0,100.0,5853600.0,18.0,5853300.0,18.000015,999999.875,10000.0
8,1.0,18.0,5853500.0,1.0,5859400.0,200.0,5859300.0,100.0,5853600.0,18.0,5853500.0,18.000015,999999.875,10000.0
9,1.0,20.000004,5857300.0,1.0,5859400.0,200.0,5859300.0,100.0,5857300.0,20.0,5853600.0,18.000015,999999.875,10000.0


In [17]:
predicted_next_obs_data.head(10)

Unnamed: 0,Type,Size,Price,Direction,A2P,A2V,A1P,A1V,B1P,B1V,B2P,B2V,Cash,Inventory
0,1.208879,65.776443,5855409.0,0.419337,5858183.5,104.298592,5857490.5,188.888184,5854522.5,57.072754,5854189.5,131.437729,1647858.875,9171.865234
1,1.203244,61.185112,5854856.0,0.426575,5857645.0,101.95372,5857010.5,200.67746,5853909.0,31.210205,5853615.5,99.275055,1336266.625,9911.882812
2,1.230187,45.088951,5857642.0,-0.305384,5859181.5,63.849068,5858943.0,18.775482,5855649.5,31.570694,5855317.0,131.716263,1940450.875,8224.232422
3,1.276087,32.153458,5857398.0,-0.358538,5858546.5,10.955643,5858861.5,-11.91655,5855326.5,72.170029,5855076.0,192.6828,2107776.0,7786.164062
4,1.656815,34.892467,5856111.5,0.316319,5857998.0,4.302826,5857900.5,14.694473,5854874.5,72.781281,5854530.0,182.284088,960680.375,9872.416016
5,1.603282,44.189919,5856129.0,-0.602271,5856877.0,58.937805,5856688.0,26.926727,5854205.0,66.442596,5853507.0,226.530273,1286169.875,9402.919922
6,1.647295,41.829876,5856073.5,-0.612603,5856773.0,54.986298,5856628.0,22.759575,5854166.5,69.211823,5853441.0,234.286667,1067542.125,9763.589844
7,1.294078,59.92926,5857052.5,-0.292061,5858614.5,98.145401,5858026.5,55.150826,5855267.0,48.687637,5854726.5,183.619171,1718338.125,8776.001953
8,1.242853,59.487518,5855772.0,0.403244,5858409.5,92.454849,5857790.5,135.750061,5854667.5,34.09726,5854342.0,103.778664,1663115.125,9340.094727
9,1.249893,59.697388,5855744.0,0.403932,5858374.0,91.993561,5857766.5,135.591354,5854645.5,34.599594,5854316.5,104.001938,1625157.375,9406.420898
