<a href="https://colab.research.google.com/github/thegallier/timeseries/blob/main/timeseries2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

# Define the Dataset
class MaskedOrderDataset(Dataset):
    def __init__(self, num_samples, mask_prob=0.15, max_seq_len=50):
        self.num_samples = num_samples
        self.N_sec_ids = 100  # Number of unique securities
        self.N_distance_bins = 10  # Number of distance bins
        self.mask_prob = mask_prob
        self.max_seq_len = max_seq_len

        # Generate synthetic data
        self.timestamps = torch.randint(0, 1000, (num_samples,))
        self.security_ids = torch.randint(0, self.N_sec_ids, (num_samples,))
        self.buy_sell = torch.randint(0, 2, (num_samples,))  # 0: buy, 1: sell
        self.add_modify_delete = torch.randint(0, 3, (num_samples,))  # 0: add, 1: modify, 2: delete
        self.quantity = torch.randint(1, 1000, (num_samples,)).float()
        self.price = torch.rand(num_samples) * 100  # Prices between 0 and 100
        self.distance = torch.randint(0, self.N_distance_bins, (num_samples,))

        # Sort the data by timestamps (ascending)
        sorted_indices = torch.argsort(self.timestamps)
        self.timestamps = self.timestamps[sorted_indices]
        self.security_ids = self.security_ids[sorted_indices]
        self.buy_sell = self.buy_sell[sorted_indices]
        self.add_modify_delete = self.add_modify_delete[sorted_indices]
        self.quantity = self.quantity[sorted_indices]
        self.price = self.price[sorted_indices]
        self.distance = self.distance[sorted_indices]

        # Define mask indices
        self.security_id_mask = self.N_sec_ids
        self.buy_sell_mask = 2
        self.add_modify_delete_mask = 3
        self.distance_mask = self.N_distance_bins

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Get sequence from idx to end, up to max_seq_len
        end_idx = min(idx + self.max_seq_len, self.num_samples)
        seq_len = end_idx - idx
        sequence = {
            'timestamp': self.timestamps[idx:end_idx],
            'security_id': self.security_ids[idx:end_idx],
            'buy_sell': self.buy_sell[idx:end_idx],
            'add_modify_delete': self.add_modify_delete[idx:end_idx],
            'quantity': self.quantity[idx:end_idx],
            'price': self.price[idx:end_idx],
            'distance': self.distance[idx:end_idx],
        }

        # Apply masking logic to the first element (idx)
        sample = {key: sequence[key][0] for key in sequence}
        masked_sample = sample.copy()
        target = {}

        for key in ['security_id', 'buy_sell', 'add_modify_delete', 'quantity', 'price', 'distance']:
            if torch.rand(1).item() < self.mask_prob:
                # Mask the feature
                target[key] = sample[key]
                if key == 'security_id':
                    masked_sample[key] = self.security_id_mask
                elif key == 'buy_sell':
                    masked_sample[key] = self.buy_sell_mask
                elif key == 'add_modify_delete':
                    masked_sample[key] = self.add_modify_delete_mask
                elif key == 'distance':
                    masked_sample[key] = self.distance_mask
                elif key in ['quantity', 'price']:
                    masked_sample[key] = 0.0  # For continuous features, use 0.0 as masked value
            else:
                target[key] = None  # Not masked

        # Return masked_sample, target, and the sequence starting from idx
        return masked_sample, target, sequence

# Define the Model
class OrderModel(nn.Module):
    def __init__(self, N_sec_ids, N_distance_bins, embedding_dim=32, hidden_dim=64):
        super(OrderModel, self).__init__()

        # Mask indices
        self.security_id_mask = N_sec_ids
        self.buy_sell_mask = 2
        self.add_modify_delete_mask = 3
        self.distance_mask = N_distance_bins

        # Embedding layers for categorical features
        self.security_id_embedding = nn.Embedding(N_sec_ids + 1, embedding_dim, padding_idx=self.security_id_mask)
        self.buy_sell_embedding = nn.Embedding(3, embedding_dim, padding_idx=self.buy_sell_mask)  # 0,1,2(mask)
        self.add_modify_delete_embedding = nn.Embedding(4, embedding_dim, padding_idx=self.add_modify_delete_mask)  # 0,1,2,3(mask)
        self.distance_embedding = nn.Embedding(N_distance_bins + 1, embedding_dim, padding_idx=self.distance_mask)

        # Linear layers for continuous features
        self.quantity_linear = nn.Linear(1, embedding_dim)
        self.price_linear = nn.Linear(1, embedding_dim)

        # For timestamp (we don't mask timestamp)
        self.timestamp_linear = nn.Linear(1, embedding_dim)

        # LSTM for sequences
        self.lstm_input_dim = embedding_dim * 7  # Number of features
        self.lstm = nn.LSTM(input_size=self.lstm_input_dim, hidden_size=hidden_dim, batch_first=True)

        # Final fully connected layers
        self.fc1 = nn.Linear(embedding_dim * 7 + hidden_dim, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 128)
        self.relu = nn.ReLU()

        # Output heads
        self.security_id_head = nn.Linear(128, N_sec_ids)
        self.buy_sell_head = nn.Linear(128, 2)
        self.add_modify_delete_head = nn.Linear(128, 3)
        self.distance_head = nn.Linear(128, N_distance_bins)
        self.quantity_head = nn.Linear(128, 1)
        self.price_head = nn.Linear(128, 1)

    def forward(self, x, seq_data, seq_lengths):
        # x is a dictionary of features for the masked sample
        # seq_data is a dictionary of sequences
        # seq_lengths is a list of sequence lengths

        # Process the masked sample
        timestamp = x['timestamp'].unsqueeze(1).float()
        security_id = x['security_id']
        buy_sell = x['buy_sell']
        add_modify_delete = x['add_modify_delete']
        quantity = x['quantity'].unsqueeze(1)
        price = x['price'].unsqueeze(1)
        distance = x['distance']

        # Embeddings
        timestamp_emb = self.timestamp_linear(timestamp)
        security_id_emb = self.security_id_embedding(security_id)
        buy_sell_emb = self.buy_sell_embedding(buy_sell)
        add_modify_delete_emb = self.add_modify_delete_embedding(add_modify_delete)
        quantity_emb = self.quantity_linear(quantity)
        price_emb = self.price_linear(price)
        distance_emb = self.distance_embedding(distance)

        # Concatenate embeddings
        sample_emb = torch.cat([timestamp_emb,
                                security_id_emb,
                                buy_sell_emb,
                                add_modify_delete_emb,
                                quantity_emb,
                                price_emb,
                                distance_emb], dim=1)  # Shape: (batch_size, embedding_dim * 7)

        # Process the sequence data
        # For each feature in seq_data, get embeddings
        batch_size = timestamp.shape[0]

        seq_timestamp = seq_data['timestamp'].float()
        seq_security_id = seq_data['security_id']
        seq_buy_sell = seq_data['buy_sell']
        seq_add_modify_delete = seq_data['add_modify_delete']
        seq_quantity = seq_data['quantity']
        seq_price = seq_data['price']
        seq_distance = seq_data['distance']

        # Embeddings for sequence data
        seq_timestamp_emb = self.timestamp_linear(seq_timestamp.unsqueeze(-1))  # (batch_size, seq_len, embedding_dim)
        seq_security_id_emb = self.security_id_embedding(seq_security_id)
        seq_buy_sell_emb = self.buy_sell_embedding(seq_buy_sell)
        seq_add_modify_delete_emb = self.add_modify_delete_embedding(seq_add_modify_delete)
        seq_quantity_emb = self.quantity_linear(seq_quantity.unsqueeze(-1))
        seq_price_emb = self.price_linear(seq_price.unsqueeze(-1))
        seq_distance_emb = self.distance_embedding(seq_distance)

        # Concatenate sequence embeddings
        seq_emb = torch.cat([seq_timestamp_emb,
                             seq_security_id_emb,
                             seq_buy_sell_emb,
                             seq_add_modify_delete_emb,
                             seq_quantity_emb,
                             seq_price_emb,
                             seq_distance_emb], dim=2)  # Shape: (batch_size, seq_len, embedding_dim * 7)

        # Pack the sequences
        packed_seq_emb = nn.utils.rnn.pack_padded_sequence(seq_emb, seq_lengths, batch_first=True, enforce_sorted=False)

        # Pass through LSTM
        packed_output, (h_n, c_n) = self.lstm(packed_seq_emb)

        # Get the last hidden state for each sequence
        seq_context = h_n.squeeze(0)  # Shape: (batch_size, hidden_dim)

        # Combine the sample embedding with the sequence context
        combined = torch.cat([sample_emb, seq_context], dim=1)  # Shape: (batch_size, embedding_dim * 7 + hidden_dim)

        # Forward pass
        x = self.fc1(combined)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)

        # Outputs
        outputs = {
            'security_id': self.security_id_head(x),
            'buy_sell': self.buy_sell_head(x),
            'add_modify_delete': self.add_modify_delete_head(x),
            'distance': self.distance_head(x),
            'quantity': self.quantity_head(x),
            'price': self.price_head(x)
        }

        return outputs

# Collate function for DataLoader
def collate_fn(batch):
    # batch is a list of (masked_sample, target, sequence)
    batch_size = len(batch)
    masked_samples = [item[0] for item in batch]
    targets = [item[1] for item in batch]
    sequences = [item[2] for item in batch]

    # Convert masked_samples to tensors
    batch_data = {}
    batch_target = {}
    seq_data = {}
    seq_lengths = []

    # Process data
    for key in ['timestamp', 'security_id', 'buy_sell', 'add_modify_delete', 'quantity', 'price', 'distance']:
        batch_data[key] = torch.tensor([sample[key] for sample in masked_samples])

    # Process sequences
    # Sequences are variable-length
    # For each key, we have a list of sequences
    for key in ['timestamp', 'security_id', 'buy_sell', 'add_modify_delete', 'quantity', 'price', 'distance']:
        seq_list = [torch.tensor(seq[key]) for seq in sequences]
        seq_padded = nn.utils.rnn.pad_sequence(seq_list, batch_first=True, padding_value=0)
        seq_data[key] = seq_padded
    # Record lengths
    seq_lengths = [len(seq['timestamp']) for seq in sequences]

    # Convert seq_lengths to tensor
    seq_lengths = torch.tensor(seq_lengths, dtype=torch.long)

    # Process target (exclude 'timestamp' as it's not in target)
    for key in ['security_id', 'buy_sell', 'add_modify_delete', 'quantity', 'price', 'distance']:
        batch_target[key] = []
        for target in targets:
            if target[key] is not None:
                batch_target[key].append(target[key])
            else:
                batch_target[key].append(-100 if key not in ['quantity', 'price'] else 0.0)

        # Convert to tensor and create mask
        if key in ['quantity', 'price']:
            # For continuous features
            batch_target[key + '_mask'] = torch.tensor([t != 0.0 for t in batch_target[key]], dtype=torch.bool)
            batch_target[key] = torch.tensor(batch_target[key], dtype=torch.float)
        else:
            # For categorical features
            batch_target[key + '_mask'] = torch.tensor([t != -100 for t in batch_target[key]], dtype=torch.bool)
            batch_target[key] = torch.tensor(batch_target[key], dtype=torch.long)

    return batch_data, batch_target, seq_data, seq_lengths

# Training and Testing Loop
def train_model():
    # Create Dataset and DataLoader
    dataset = MaskedOrderDataset(num_samples=10000)
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)

    # Create model
    model = OrderModel(N_sec_ids=dataset.N_sec_ids, N_distance_bins=dataset.N_distance_bins)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Loss functions
    criterion_ce = nn.CrossEntropyLoss(ignore_index=-100)
    criterion_l1 = nn.L1Loss()

    # Training loop
    for epoch in range(5):
        model.train()
        total_loss = 0.0
        for batch_data, batch_target, seq_data, seq_lengths in dataloader:
            optimizer.zero_grad()

            outputs = model(batch_data, seq_data, seq_lengths)

            loss = 0.0

            # Compute loss for security_id
            target = batch_target['security_id']
            mask = batch_target['security_id_mask']
            if mask.any():
                loss_sec_id = criterion_ce(outputs['security_id'], target)
                loss += loss_sec_id

            # Compute loss for buy_sell
            target = batch_target['buy_sell']
            mask = batch_target['buy_sell_mask']
            if mask.any():
                loss_buy_sell = criterion_ce(outputs['buy_sell'], target)
                loss += loss_buy_sell

            # Compute loss for add_modify_delete
            target = batch_target['add_modify_delete']
            mask = batch_target['add_modify_delete_mask']
            if mask.any():
                loss_amd = criterion_ce(outputs['add_modify_delete'], target)
                loss += loss_amd

            # Compute loss for distance
            target = batch_target['distance']
            mask = batch_target['distance_mask']
            if mask.any():
                loss_distance = criterion_ce(outputs['distance'], target)
                loss += loss_distance

            # Compute loss for quantity
            target = batch_target['quantity']
            mask = batch_target['quantity_mask']
            if mask.any():
                output = outputs['quantity'].squeeze()
                loss_quantity = criterion_l1(output[mask], target[mask])
                loss += loss_quantity

            # Compute loss for price
            target = batch_target['price']
            mask = batch_target['price_mask']
            if mask.any():
                output = outputs['price'].squeeze()
                loss_price = criterion_l1(output[mask], target[mask])
                loss += loss_price

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader)}")

if __name__ == "__main__":
    train_model()

  seq_list = [torch.tensor(seq[key]) for seq in sequences]


KeyboardInterrupt: 

In [6]:
def generate_sequences(model, past_data, security_id, num_timesteps, beam_width, num_samples, device='cpu'):
    """
    Generate sequences using beam search.

    Args:
        model: Trained model.
        past_data: Dictionary containing past sequences.
        security_id: The security ID to generate sequences for.
        num_timesteps: Number of timesteps to generate.
        beam_width: Beam width.
        num_samples: Number of samples to generate.
        device: Device to run the computations on.

    Returns:
        generated_sequences: A list of generated sequences.
    """
    model.eval()
    with torch.no_grad():
        # Initialize the beam with the past_data
        # Each beam entry is a tuple (sequence, cumulative_log_prob)
        beam = [({'timestamp': past_data['timestamp'],
                  'position': past_data['position'],
                  'security_id': past_data['security_id'],
                  'buy_sell': past_data['buy_sell'],
                  'add_modify_delete': past_data['add_modify_delete'],
                  'quantity': past_data['quantity'],
                  'price': past_data['price'],
                  'distance': past_data['distance']},
                 0.0)]  # Start with log probability 0

        for t in range(num_timesteps):
            new_beam = []
            for seq, cum_log_prob in beam:
                # Prepare input data
                # Use the last max_seq_len elements
                seq_len = len(seq['timestamp'])
                start_idx = max(0, seq_len - model.lstm_input_dim) if hasattr(model, 'lstm_input_dim') else 0
                seq_data = {key: torch.tensor(seq[key][start_idx:]).unsqueeze(0).to(device) for key in seq}
                seq_lengths = torch.tensor([seq_data['timestamp'].shape[1]]).to(device)

                # Create a sample with masked features (except security_id and timestamp)
                sample = {
                    'timestamp': torch.tensor([seq['timestamp'][-1] + 1]).to(device),
                    'position': torch.tensor([seq['position'][-1] + 1]).to(device),
                    'security_id': torch.tensor([security_id]).to(device),
                    'buy_sell': torch.tensor([model.buy_sell_mask if hasattr(model, 'buy_sell_mask') else 2]).to(device),
                    'add_modify_delete': torch.tensor([model.add_modify_delete_mask if hasattr(model, 'add_modify_delete_mask') else 3]).to(device),
                    'quantity': torch.tensor([0.0]).to(device),
                    'price': torch.tensor([0.0]).to(device),
                    'distance': torch.tensor([model.distance_mask if hasattr(model, 'distance_mask') else 10]).to(device)
                }

                # Predict missing features
                outputs = model(sample, seq_data, seq_lengths)

                # For each categorical feature, get top N predictions
                candidates = []

                # Adjust topk for buy_sell
                if 'buy_sell' in outputs:
                    buy_sell_probs = F.softmax(outputs['buy_sell'], dim=-1)
                    num_buy_sell_classes = buy_sell_probs.size(-1)
                    buy_sell_topk_k = min(beam_width, num_buy_sell_classes)
                    buy_sell_topk = torch.topk(buy_sell_probs, buy_sell_topk_k)
                else:
                    buy_sell_topk_k = 1
                    buy_sell_topk = None

                # Adjust topk for add_modify_delete
                if 'add_modify_delete' in outputs:
                    add_modify_delete_probs = F.softmax(outputs['add_modify_delete'], dim=-1)
                    num_amd_classes = add_modify_delete_probs.size(-1)
                    amd_topk_k = min(beam_width, num_amd_classes)
                    amd_topk = torch.topk(add_modify_delete_probs, amd_topk_k)
                else:
                    amd_topk_k = 1
                    amd_topk = None

                # Adjust topk for distance
                if 'distance' in outputs:
                    distance_probs = F.softmax(outputs['distance'], dim=-1)
                    num_distance_classes = distance_probs.size(-1)
                    distance_topk_k = min(beam_width, num_distance_classes)
                    distance_topk = torch.topk(distance_probs, distance_topk_k)
                else:
                    distance_topk_k = 1
                    distance_topk = None

                # For continuous features, we can use the predicted value directly
                quantity_pred = outputs['quantity'].item() if 'quantity' in outputs else 0.0
                price_pred = outputs['price'].item() if 'price' in outputs else 0.0

                # Generate combinations of top predictions
                for i in range(buy_sell_topk_k):
                    for j in range(amd_topk_k):
                        for k in range(distance_topk_k):
                            new_seq = {key: seq[key] + [sample[key].item()] for key in seq}
                            if buy_sell_topk is not None:
                                new_seq['buy_sell'][-1] = buy_sell_topk.indices[0][i].item()
                            else:
                                new_seq['buy_sell'][-1] = seq['buy_sell'][-1]  # Use previous value

                            if amd_topk is not None:
                                new_seq['add_modify_delete'][-1] = amd_topk.indices[0][j].item()
                            else:
                                new_seq['add_modify_delete'][-1] = seq['add_modify_delete'][-1]  # Use previous value

                            if distance_topk is not None:
                                new_seq['distance'][-1] = distance_topk.indices[0][k].item()
                            else:
                                new_seq['distance'][-1] = seq['distance'][-1]  # Use previous value

                            new_seq['quantity'][-1] = quantity_pred
                            new_seq['price'][-1] = price_pred

                            # Compute new cumulative log probability
                            log_prob = cum_log_prob
                            if buy_sell_topk is not None:
                                log_prob += torch.log(buy_sell_topk.values[0][i] + 1e-9).item()
                            if amd_topk is not None:
                                log_prob += torch.log(amd_topk.values[0][j] + 1e-9).item()
                            if distance_topk is not None:
                                log_prob += torch.log(distance_topk.values[0][k] + 1e-9).item()

                            candidates.append((new_seq, log_prob))

                # Keep top beam_width candidates
                candidates.sort(key=lambda x: x[1], reverse=True)
                new_beam.extend(candidates[:beam_width])

            # Keep top beam_width sequences
            new_beam.sort(key=lambda x: x[1], reverse=True)
            beam = new_beam[:beam_width]

        # After generation, sample from the beam to get the required number of samples
        generated_sequences = [seq for seq, _ in beam[:num_samples]]
        return generated_sequences


In [7]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.utils.tensorboard import SummaryWriter
import random
from collections import defaultdict
import numpy as np
import os

# Define the Dataset
class MaskedOrderDataset(Dataset):
    def __init__(self, num_samples, mask_prob=0.15, max_seq_len=50):
        self.num_samples = num_samples
        self.N_sec_ids = 100  # Number of unique securities
        self.N_distance_bins = 10  # Number of distance bins
        self.mask_prob = mask_prob
        self.max_seq_len = max_seq_len

        # Generate synthetic data
        self.timestamps = torch.randint(0, 1000, (num_samples,))
        self.security_ids = torch.randint(0, self.N_sec_ids, (num_samples,))
        self.buy_sell = torch.randint(0, 2, (num_samples,))  # 0: buy, 1: sell
        self.add_modify_delete = torch.randint(0, 3, (num_samples,))  # 0: add, 1: modify, 2: delete
        self.quantity = torch.randint(1, 1000, (num_samples,)).float()
        self.price = torch.rand(num_samples) * 100  # Prices between 0 and 100
        self.distance = torch.randint(0, self.N_distance_bins, (num_samples,))

        # Sort the data by timestamps (ascending)
        sorted_indices = torch.argsort(self.timestamps)
        self.timestamps = self.timestamps[sorted_indices]
        self.security_ids = self.security_ids[sorted_indices]
        self.buy_sell = self.buy_sell[sorted_indices]
        self.add_modify_delete = self.add_modify_delete[sorted_indices]
        self.quantity = self.quantity[sorted_indices]
        self.price = self.price[sorted_indices]
        self.distance = self.distance[sorted_indices]

        # Define mask indices
        self.security_id_mask = self.N_sec_ids
        self.buy_sell_mask = 2
        self.add_modify_delete_mask = 3
        self.distance_mask = self.N_distance_bins

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Get sequence from idx to end, up to max_seq_len
        end_idx = min(idx + self.max_seq_len, self.num_samples)
        seq_len = end_idx - idx
        sequence = {
            'timestamp': self.timestamps[idx:end_idx],
            'security_id': self.security_ids[idx:end_idx],
            'buy_sell': self.buy_sell[idx:end_idx],
            'add_modify_delete': self.add_modify_delete[idx:end_idx],
            'quantity': self.quantity[idx:end_idx],
            'price': self.price[idx:end_idx],
            'distance': self.distance[idx:end_idx],
            'position': torch.arange(seq_len)  # Row number
        }

        # Apply masking logic to the first element (idx)
        sample = {key: sequence[key][0] for key in sequence}
        masked_sample = sample.copy()
        target = {}

        # List of features to potentially mask
        maskable_keys = ['security_id', 'buy_sell', 'add_modify_delete', 'quantity', 'price', 'distance']

        # Decide how many features to mask (at least one)
        num_features_to_mask = random.randint(1, len(maskable_keys))

        # Randomly select features to mask
        features_to_mask = random.sample(maskable_keys, num_features_to_mask)

        for key in maskable_keys:
            if key in features_to_mask:
                # Mask the feature
                target[key] = sample[key]
                if key == 'security_id':
                    masked_sample[key] = self.security_id_mask
                elif key == 'buy_sell':
                    masked_sample[key] = self.buy_sell_mask
                elif key == 'add_modify_delete':
                    masked_sample[key] = self.add_modify_delete_mask
                elif key == 'distance':
                    masked_sample[key] = self.distance_mask
                elif key in ['quantity', 'price']:
                    masked_sample[key] = 0.0  # For continuous features, use 0.0 as masked value
            else:
                target[key] = None  # Not masked

        # Return masked_sample, target, and the sequence starting from idx
        return masked_sample, target, sequence

# Learnable Positional Encoder
class PositionalEncoder(nn.Module):
    def __init__(self, embedding_dim):
        super(PositionalEncoder, self).__init__()
        self.timestamp_encoder = nn.Linear(1, embedding_dim)
        self.position_encoder = nn.Embedding(1000, embedding_dim)  # Assume max 1000 positions

    def forward(self, timestamps, positions):
        timestamp_emb = self.timestamp_encoder(timestamps.unsqueeze(-1).float())
        position_emb = self.position_encoder(positions)
        return timestamp_emb + position_emb

# Base Model Class
class BaseModel(nn.Module):
    def __init__(self):
        super(BaseModel, self).__init__()

    def forward(self, *args, **kwargs):
        raise NotImplementedError

# Order Model with Positional Encoder
class OrderModel(BaseModel):
    def __init__(self, N_sec_ids, N_distance_bins, embedding_dim=32, hidden_dim=64):
        super(OrderModel, self).__init__()

        # Mask indices
        self.security_id_mask = N_sec_ids
        self.buy_sell_mask = 2
        self.add_modify_delete_mask = 3
        self.distance_mask = N_distance_bins

        # Embedding layers for categorical features
        self.security_id_embedding = nn.Embedding(N_sec_ids + 1, embedding_dim, padding_idx=self.security_id_mask)
        self.buy_sell_embedding = nn.Embedding(3, embedding_dim, padding_idx=self.buy_sell_mask)  # 0,1,2(mask)
        self.add_modify_delete_embedding = nn.Embedding(4, embedding_dim, padding_idx=self.add_modify_delete_mask)  # 0,1,2,3(mask)
        self.distance_embedding = nn.Embedding(N_distance_bins + 1, embedding_dim, padding_idx=self.distance_mask)

        # Linear layers for continuous features
        self.quantity_linear = nn.Linear(1, embedding_dim)
        self.price_linear = nn.Linear(1, embedding_dim)

        # Positional Encoder
        self.positional_encoder = PositionalEncoder(embedding_dim)

        # LSTM for sequences
        self.lstm_input_dim = embedding_dim * 7  # Number of features
        self.lstm = nn.LSTM(input_size=self.lstm_input_dim, hidden_size=hidden_dim, batch_first=True)

        # Final fully connected layers
        self.fc1 = nn.Linear(embedding_dim * 7 + hidden_dim, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 128)
        self.relu = nn.ReLU()

        # Output heads
        self.security_id_head = nn.Linear(128, N_sec_ids)
        self.buy_sell_head = nn.Linear(128, 2)
        self.add_modify_delete_head = nn.Linear(128, 3)
        self.distance_head = nn.Linear(128, N_distance_bins)
        self.quantity_head = nn.Linear(128, 1)
        self.price_head = nn.Linear(128, 1)

    def forward(self, x, seq_data, seq_lengths):
        # x is a dictionary of features for the masked sample
        # seq_data is a dictionary of sequences
        # seq_lengths is a list of sequence lengths

        batch_size = x['timestamp'].shape[0]

        # Process the masked sample
        timestamp = x['timestamp']
        position = x['position']
        security_id = x['security_id']
        buy_sell = x['buy_sell']
        add_modify_delete = x['add_modify_delete']
        quantity = x['quantity'].unsqueeze(1)
        price = x['price'].unsqueeze(1)
        distance = x['distance']

        # Embeddings
        positional_emb = self.positional_encoder(timestamp, position)
        security_id_emb = self.security_id_embedding(security_id)
        buy_sell_emb = self.buy_sell_embedding(buy_sell)
        add_modify_delete_emb = self.add_modify_delete_embedding(add_modify_delete)
        quantity_emb = self.quantity_linear(quantity)
        price_emb = self.price_linear(price)
        distance_emb = self.distance_embedding(distance)

        # Concatenate embeddings
        sample_emb = torch.cat([positional_emb,
                                security_id_emb,
                                buy_sell_emb,
                                add_modify_delete_emb,
                                quantity_emb,
                                price_emb,
                                distance_emb], dim=1)  # Shape: (batch_size, embedding_dim * 7)

        # Process the sequence data
        # For each feature in seq_data, get embeddings

        seq_timestamp = seq_data['timestamp']
        seq_position = seq_data['position']
        seq_security_id = seq_data['security_id']
        seq_buy_sell = seq_data['buy_sell']
        seq_add_modify_delete = seq_data['add_modify_delete']
        seq_quantity = seq_data['quantity']
        seq_price = seq_data['price']
        seq_distance = seq_data['distance']

        # Embeddings for sequence data
        seq_positional_emb = self.positional_encoder(seq_timestamp, seq_position)
        seq_security_id_emb = self.security_id_embedding(seq_security_id)
        seq_buy_sell_emb = self.buy_sell_embedding(seq_buy_sell)
        seq_add_modify_delete_emb = self.add_modify_delete_embedding(seq_add_modify_delete)
        seq_quantity_emb = self.quantity_linear(seq_quantity.unsqueeze(-1))
        seq_price_emb = self.price_linear(seq_price.unsqueeze(-1))
        seq_distance_emb = self.distance_embedding(seq_distance)

        # Concatenate sequence embeddings
        seq_emb = torch.cat([seq_positional_emb,
                             seq_security_id_emb,
                             seq_buy_sell_emb,
                             seq_add_modify_delete_emb,
                             seq_quantity_emb,
                             seq_price_emb,
                             seq_distance_emb], dim=2)  # Shape: (batch_size, seq_len, embedding_dim * 7)

        # Pack the sequences
        packed_seq_emb = pack_padded_sequence(seq_emb, seq_lengths.cpu(), batch_first=True, enforce_sorted=False)

        # Pass through LSTM
        packed_output, (h_n, c_n) = self.lstm(packed_seq_emb)

        # Get the last hidden state for each sequence
        seq_context = h_n[-1]  # Shape: (batch_size, hidden_dim)

        # Combine the sample embedding with the sequence context
        combined = torch.cat([sample_emb, seq_context], dim=1)  # Shape: (batch_size, embedding_dim * 7 + hidden_dim)

        # Forward pass
        x = self.fc1(combined)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)

        # Outputs
        outputs = {
            'security_id': self.security_id_head(x),
            'buy_sell': self.buy_sell_head(x),
            'add_modify_delete': self.add_modify_delete_head(x),
            'distance': self.distance_head(x),
            'quantity': self.quantity_head(x),
            'price': self.price_head(x)
        }

        return outputs

# Additional Models (Assuming models are unchanged, not included here for brevity)

# [Include other models here if needed]

# Collate function for DataLoader
def collate_fn(batch):
    # batch is a list of (masked_sample, target, sequence)
    batch_size = len(batch)
    masked_samples = [item[0] for item in batch]
    targets = [item[1] for item in batch]
    sequences = [item[2] for item in batch]

    # Convert masked_samples to tensors
    batch_data = {}
    batch_target = {}
    seq_data = {}
    seq_lengths = []

    # Process data
    for key in ['timestamp', 'position', 'security_id', 'buy_sell', 'add_modify_delete', 'quantity', 'price', 'distance']:
        batch_data[key] = torch.tensor([sample[key] for sample in masked_samples])

    # Process sequences
    # Sequences are variable-length
    # For each key, we have a list of sequences
    for key in ['timestamp', 'position', 'security_id', 'buy_sell', 'add_modify_delete', 'quantity', 'price', 'distance']:
        seq_list = [torch.tensor(seq[key]) for seq in sequences]
        seq_padded = nn.utils.rnn.pad_sequence(seq_list, batch_first=True, padding_value=0)
        seq_data[key] = seq_padded
    # Record lengths
    seq_lengths = [len(seq['timestamp']) for seq in sequences]

    # Convert seq_lengths to tensor
    seq_lengths = torch.tensor(seq_lengths, dtype=torch.long)

    # Process target (exclude 'timestamp' and 'position' as they're not in target)
    for key in ['security_id', 'buy_sell', 'add_modify_delete', 'quantity', 'price', 'distance']:
        batch_target[key] = []
        for target in targets:
            if target[key] is not None:
                batch_target[key].append(target[key])
            else:
                batch_target[key].append(-100 if key not in ['quantity', 'price'] else 0.0)

        # Convert to tensor and create mask
        if key in ['quantity', 'price']:
            # For continuous features
            batch_target[key + '_mask'] = torch.tensor([t != 0.0 for t in batch_target[key]], dtype=torch.bool)
            batch_target[key] = torch.tensor(batch_target[key], dtype=torch.float)
        else:
            # For categorical features
            batch_target[key + '_mask'] = torch.tensor([t != -100 for t in batch_target[key]], dtype=torch.bool)
            batch_target[key] = torch.tensor(batch_target[key], dtype=torch.long)

    return batch_data, batch_target, seq_data, seq_lengths

# Training and Testing Loop with Checkpointing and TensorBoard
def train_model(model_class, model_params, dataset_params, training_params, save_path=None, load_path=None):
    # Create Dataset and DataLoader
    dataset = MaskedOrderDataset(**dataset_params)
    dataloader = DataLoader(dataset, batch_size=training_params['batch_size'], shuffle=True, collate_fn=collate_fn)

    # Create model
    model = model_class(**model_params).to(training_params['device'])
    optimizer = optim.Adam(model.parameters(), lr=training_params['learning_rate'])

    # Load checkpoint if provided
    start_epoch = 0
    if load_path and os.path.exists(load_path):
        checkpoint = torch.load(load_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        print(f"Loaded checkpoint from '{load_path}' (epoch {start_epoch})")

    # TensorBoard SummaryWriter
    writer = SummaryWriter()

    # Loss functions
    criterion_ce = nn.CrossEntropyLoss(ignore_index=-100)
    criterion_l1 = nn.L1Loss()

    # Training loop
    for epoch in range(start_epoch, training_params['num_epochs']):
        model.train()
        total_loss = 0.0
        for batch_idx, (batch_data, batch_target, seq_data, seq_lengths) in enumerate(dataloader):
            optimizer.zero_grad()

            # Move data to device
            device = training_params.get('device', 'cpu')
            batch_data = {k: v.to(device) for k, v in batch_data.items()}
            batch_target = {k: v.to(device) for k, v in batch_target.items()}
            seq_data = {k: v.to(device) for k, v in seq_data.items()}
            seq_lengths = seq_lengths.to(device)

            outputs = model(batch_data, seq_data, seq_lengths)

            loss = 0.0

            # Compute loss for each feature
            for key in ['security_id', 'buy_sell', 'add_modify_delete', 'distance']:
                if key in outputs:
                    target = batch_target[key]
                    mask = batch_target[key + '_mask']
                    if mask.any():
                        output = outputs[key]
                        loss_feature = criterion_ce(output, target)
                        loss += loss_feature

            for key in ['quantity', 'price']:
                if key in outputs:
                    target = batch_target[key]
                    mask = batch_target[key + '_mask']
                    if mask.any():
                        output = outputs[key].squeeze()
                        loss_feature = criterion_l1(output[mask], target[mask])
                        loss += loss_feature

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            # Log training loss to TensorBoard
            writer.add_scalar('Loss/train_batch', loss.item(), epoch * len(dataloader) + batch_idx)

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}, Loss: {avg_loss}")

        # Log average loss per epoch to TensorBoard
        writer.add_scalar('Loss/train_epoch', avg_loss, epoch+1)

        # Save checkpoint
        if save_path:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }
            torch.save(checkpoint, save_path)
            print(f"Checkpoint saved at '{save_path}'")

    writer.close()
    return model

# Function to load a saved model
def load_model(model_class, model_params, checkpoint_path, device='cpu'):
    model = model_class(**model_params).to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Model loaded from '{checkpoint_path}'")
    return model

# Beam Search Generation Function (Assuming unchanged)

# Hyperparameter Optimization Function (Assuming unchanged)

# Example Usage
if __name__ == "__main__":
    # Training parameters
    training_params = {
        'batch_size': 32,
        'learning_rate': 0.001,
        'num_epochs': 5,
        'device': 'cuda' if torch.cuda.is_available() else 'cpu'
    }

    # Dataset parameters
    dataset_params = {
        'num_samples': 5000,
        'mask_prob': 0.15,
        'max_seq_len': 50
    }

    # Common parameters
    output_keys = ['security_id', 'buy_sell', 'add_modify_delete', 'distance', 'quantity', 'price']
    num_classes_dict = {
        'security_id': 100,
        'buy_sell': 2,
        'add_modify_delete': 3,
        'distance': 10
    }

    # Example: Training OrderModel with checkpointing and TensorBoard
    print("Training OrderModel with checkpointing and TensorBoard...")
    model_params_order = {
        'N_sec_ids': 100,
        'N_distance_bins': 10,
        'embedding_dim': 32,
        'hidden_dim': 64
    }

    # Paths for saving and loading checkpoints
    checkpoint_path = 'order_model_checkpoint.pth'

    model_order = train_model(
        model_class=OrderModel,
        model_params=model_params_order,
        dataset_params=dataset_params,
        training_params=training_params,
        save_path=checkpoint_path,
        load_path=None  # Set to checkpoint_path if you want to load from a checkpoint
    )

    # Example: Loading the trained model
    loaded_model_order = load_model(
        model_class=OrderModel,
        model_params=model_params_order,
        checkpoint_path=checkpoint_path,
        device=training_params['device']
    )

    # Generate sequences using beam search with loaded model
    print("\nGenerating sequences using beam search with loaded OrderModel...")
    # Prepare past_data
    past_data = {
        'timestamp': [1000, 1001, 1002],
        'position': [0, 1, 2],
        'security_id': [5, 5, 5],
        'buy_sell': [0, 1, 0],
        'add_modify_delete': [0, 1, 2],
        'quantity': [500.0, 600.0, 700.0],
        'price': [50.0, 51.0, 52.0],
        'distance': [3, 2, 1]
    }

    # Parameters for generation
    security_id = 5
    num_timesteps = 5
    beam_width = 3
    num_samples = 3

    generated_sequences = generate_sequences(
        model=loaded_model_order,
        past_data=past_data,
        security_id=security_id,
        num_timesteps=num_timesteps,
        beam_width=beam_width,
        num_samples=num_samples,
        device=training_params['device']
    )

    # Compute distribution of buys vs sells
    total_buy = 0
    total_sell = 0
    total_actions = 0

    for seq in generated_sequences:
        buy_sell_seq = seq['buy_sell'][-num_timesteps:]  # Only consider generated timesteps
        total_buy += buy_sell_seq.count(0)
        total_sell += buy_sell_seq.count(1)
        total_actions += len(buy_sell_seq)

    buy_percentage = (total_buy / total_actions) * 100 if total_actions > 0 else 0
    sell_percentage = (total_sell / total_actions) * 100 if total_actions > 0 else 0

    print(f"\nGenerated {num_samples} sequences of {num_timesteps} timesteps each.")
    print(f"Buy actions: {buy_percentage:.2f}%")
    print(f"Sell actions: {sell_percentage:.2f}%")

    # To visualize TensorBoard logs, run the following command in your terminal:
    # tensorboard --logdir runs

Training OrderModel with checkpointing and TensorBoard...


  seq_list = [torch.tensor(seq[key]) for seq in sequences]


Epoch 1, Loss: 377.2614810238978
Checkpoint saved at 'order_model_checkpoint.pth'
Epoch 2, Loss: 349.6525611634467
Checkpoint saved at 'order_model_checkpoint.pth'
Epoch 3, Loss: 335.3304006005548
Checkpoint saved at 'order_model_checkpoint.pth'
Epoch 4, Loss: 316.08162004021324
Checkpoint saved at 'order_model_checkpoint.pth'
Epoch 5, Loss: 296.4181347500746
Checkpoint saved at 'order_model_checkpoint.pth'
Model loaded from 'order_model_checkpoint.pth'

Generating sequences using beam search with loaded OrderModel...

Generated 3 sequences of 5 timesteps each.
Buy actions: 100.00%
Sell actions: 0.00%


  checkpoint = torch.load(checkpoint_path, map_location=device)
