In [2]:
import torch
import torch.nn as F

In [3]:
# Scratch LSTM: 10 input features, 20 hidden units, 2 stacked layers.
# Note: `F` is an alias for `torch.nn` in this cell, not `torch.nn.functional`.
rnn = F.LSTM(input_size=10, hidden_size=20, num_layers=2)

In [4]:
# Bare expression as the last line of a notebook cell: displays the module's repr.
rnn

LSTM(10, 20, num_layers=2)

In [5]:
import torch
import torch.nn as nn

class SequentialLSTMModel(nn.Module):
    """LSTM over observation sequences with a per-timestep linear action head.

    The recurrent state is stored on the module (``self.hidden`` and
    ``self.cell``) and is carried over from one ``forward`` call to the next
    until ``reset_states`` is called.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        """
        Initialize the LSTM model that processes sequences of observations.

        Args:
            input_size (int): Size of each input observation
            hidden_size (int): Number of features in the hidden state
            output_size (int): Size of output action space
            num_layers (int): Number of stacked LSTM layers
        """
        super().__init__()

        self.hidden_size = hidden_size
        # Kept so reset_states can build states with the correct leading dimension.
        self.num_layers = num_layers

        # Main LSTM layer to process sequences
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,  # Use (batch, seq, feature) format
        )

        # Linear layer to project LSTM output to action logits
        self.action_head = nn.Linear(hidden_size, output_size)

        # Recurrent state; None means "let the LSTM start from zeros".
        self.hidden = None
        self.cell = None

    def reset_states(self, batch_size=1, device='cpu'):
        """
        Reset the hidden and cell states to zeros.

        Args:
            batch_size (int): Batch size of the sequences that will be fed next
            device: Device on which to allocate the state tensors
        """
        # nn.LSTM expects states shaped (num_layers, batch, hidden_size),
        # even when batch_first=True.
        state_shape = (self.num_layers, batch_size, self.hidden_size)
        self.hidden = torch.zeros(state_shape, device=device)
        self.cell = torch.zeros(state_shape, device=device)

    def forward(self, x):
        """
        Forward pass through the network.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, input_size)
        Returns:
            outputs (torch.Tensor): Action logits for each timestep, of shape
                (batch_size, seq_len, output_size)
        """
        # If the states were never initialized, pass None so the LSTM falls
        # back to its default zero initial state instead of failing.
        if self.hidden is None or self.cell is None:
            state = None
        else:
            state = (self.hidden, self.cell)

        # Process the sequence and keep the final states for the next call.
        # NOTE: the stored states are not detached, so they still reference the
        # autograd graph of this call; call reset_states() between independent
        # sequences.
        lstm_out, (self.hidden, self.cell) = self.lstm(x, state)

        # Project LSTM outputs to action space for each timestep
        return self.action_head(lstm_out)

# Demo: build the model and push one random batch through it.
if __name__ == "__main__":
    # Dimensions used for the demo run
    BATCH_SIZE, SEQ_LENGTH = 32, 500
    INPUT_SIZE = 100  # features per observation
    HIDDEN_SIZE = 512
    OUTPUT_SIZE = 10  # number of actions

    model = SequentialLSTMModel(INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE)

    # Start from a fresh recurrent state sized for this batch
    model.reset_states(batch_size=BATCH_SIZE)

    # Random observations shaped (batch, sequence, feature)
    dummy_input = torch.randn(BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
    outputs = model(dummy_input)

    print(f"Input shape: {dummy_input.shape}")
    print(f"Output shape: {outputs.shape}")

Input shape: torch.Size([32, 500, 100])
Output shape: torch.Size([32, 500, 10])
