In [5]:
# Transformer hyperparameters - you will choose these.
# TRANSFORMER_D_MODEL must equal (n_channels * GCN out_features) from the previous step.
# NOTE(review): assumes the GCN output per time step is flattened across channels before
# reaching the transformer — confirm against STGTEncoder.
N_CHANNELS = 62          # n_channels from the previous step
GCN_OUT_FEATURES = 64    # GCN out_features per channel from the previous step
TRANSFORMER_D_MODEL = N_CHANNELS * GCN_OUT_FEATURES  # 62 * 64 = 3968
TRANSFORMER_NHEAD = 8         # Must divide TRANSFORMER_D_MODEL evenly (3968 / 8 = 496 per head)
TRANSFORMER_NUM_LAYERS = 3    # Number of TransformerEncoderLayer blocks
TRANSFORMER_DIM_FEEDFORWARD = 4 * TRANSFORMER_D_MODEL  # Typically 2x or 4x d_model
TRANSFORMER_DROPOUT = 0.1

# With a 4x feed-forward, each encoder layer holds roughly 12 * d_model**2 weights
# (~189M at d_model=3968, ~567M for 3 layers). Consider a smaller GCN_OUT_FEATURES
# or dim_feedforward if memory is tight.

# Fail fast here instead of deep inside nn.MultiheadAttention.
assert TRANSFORMER_D_MODEL % TRANSFORMER_NHEAD == 0, (
    f"TRANSFORMER_NHEAD={TRANSFORMER_NHEAD} must divide "
    f"TRANSFORMER_D_MODEL={TRANSFORMER_D_MODEL} evenly"
)

In [6]:
import torch
import torch.nn as nn

class TemporalTransformer(nn.Module):
    """Stack of ``nn.TransformerEncoderLayer`` blocks applied along the time axis.

    Tensors are sequence-first (``batch_first=False``). The input is expected to be
    ``(n_time_steps, batch_size, d_model)``, and the output has the same shape.

    NOTE(review): no positional encoding is added here, so self-attention has no
    notion of time-step order. If order matters, positional information must be
    injected before this module is called (e.g. in STGTEncoder) — confirm.

    Args:
        d_model: Embedding size of each time step. Must be divisible by ``nhead``.
        nhead: Number of attention heads.
        num_encoder_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden size of each layer's feed-forward network.
        dropout: Dropout probability used inside each encoder layer.
    """

    def __init__(self, d_model, nhead, num_encoder_layers, dim_feedforward, dropout=0.1):
        super().__init__()

        # A single encoder block; nn.TransformerEncoder stacks num_encoder_layers copies of it.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=False,  # (Sequence Length, Batch Size, Embedding Dimension)
        )

        # PyTorch only uses nested tensors when batch_first=True. Disabling them
        # explicitly avoids the "enable_nested_tensor is True, but ..." warning
        # without changing the computation.
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_encoder_layers,
            enable_nested_tensor=False,
        )

    def forward(self, src, mask=None, src_key_padding_mask=None):
        """Encode a sequence of per-time-step embeddings.

        Args:
            src: Tensor of shape ``(n_time_steps, batch_size, d_model)``. Any reshaping
                into this layout must happen in the caller, before this module.
            mask: Optional ``(n_time_steps, n_time_steps)`` attention mask, e.g. a
                causal mask. ``None`` (the default) lets every step attend to every step.
            src_key_padding_mask: Optional ``(batch_size, n_time_steps)`` boolean mask
                where ``True`` marks padded steps to ignore. ``None`` means no padding.

        Returns:
            Tensor with the same shape as ``src``.
        """
        return self.transformer_encoder(
            src, mask=mask, src_key_padding_mask=src_key_padding_mask
        )


In [7]:
# Smoke test: TemporalTransformer must preserve the (n_time_steps, batch_size, d_model) shape.

if __name__ == "__main__":
    print("--- Testing TemporalTransformer ---")

    torch.manual_seed(0)  # reproducible weights and dummy input on re-run

    # Use the hyperparameters defined earlier
    d_model = TRANSFORMER_D_MODEL
    nhead = TRANSFORMER_NHEAD
    num_encoder_layers = TRANSFORMER_NUM_LAYERS
    dim_feedforward = TRANSFORMER_DIM_FEEDFORWARD
    dropout = TRANSFORMER_DROPOUT

    # Instantiate the TemporalTransformer
    temporal_transformer = TemporalTransformer(
        d_model=d_model,
        nhead=nhead,
        num_encoder_layers=num_encoder_layers,
        dim_feedforward=dim_feedforward,
        dropout=dropout,
    )
    print(f"TemporalTransformer instantiated with d_model={d_model}, nhead={nhead}, num_layers={num_encoder_layers}")

    # Dummy input in the shape (n_time_steps, batch_size, d_model)
    dummy_n_time_steps = 400  # Your EEG time steps
    dummy_batch_size = 4
    dummy_d_model = d_model  # Must match TRANSFORMER_D_MODEL

    dummy_input_sequence = torch.randn(dummy_n_time_steps, dummy_batch_size, dummy_d_model)
    print(f"Dummy input sequence shape: {dummy_input_sequence.shape}")

    # This is a shape check only. eval() disables dropout, and no_grad() avoids keeping
    # activations for backprop, which are large at this d_model and sequence length.
    temporal_transformer.eval()
    with torch.no_grad():
        output_sequence = temporal_transformer(dummy_input_sequence)

    print(f"Output sequence shape of TemporalTransformer: {output_sequence.shape}")

    # The output shape should equal the input shape. Assert (rather than just print)
    # so that "Run All" stops on a regression instead of silently continuing.
    expected_shape = (dummy_n_time_steps, dummy_batch_size, dummy_d_model)
    assert output_sequence.shape == expected_shape, (
        f"TemporalTransformer test FAILED: Expected {expected_shape}, got {output_sequence.shape}"
    )
    print("TemporalTransformer test passed: Output shape matches expected shape!")

    print("--- TemporalTransformer Test Complete ---")

--- Testing TemporalTransformer ---
TemporalTransformer instantiated with d_model=3968, nhead=8, num_layers=3
Dummy input sequence shape: torch.Size([400, 4, 3968])
Output sequence shape of TemporalTransformer: torch.Size([400, 4, 3968])
TemporalTransformer test passed: Output shape matches expected shape!
--- TemporalTransformer Test Complete ---
