In [1]:
# --- IMPORTANT: Run this cell first in your Jupyter Notebook ---
# The %run commands below execute the entire gcn_layer.ipynb and temporal_transformer.ipynb notebooks,
# including any test code they contain (which is why their test output appears below this cell).
# This makes the classes (GCNLayer, TemporalTransformer) and variables (like TRANSFORMER_D_MODEL)
# defined in those notebooks available in this notebook's global scope.
%run gcn_layer.ipynb
%run temporal_transformer.ipynb

--- Testing GCNLayer ---
Loaded A_hat with shape: torch.Size([62, 62])
GCNLayer instantiated with in_features=1, out_features=64
Dummy input X shape: torch.Size([4, 62, 1])
Output shape of GCNLayer: torch.Size([4, 62, 64])
GCNLayer test passed: Output shape matches expected shape!
--- GCNLayer Test Complete ---
--- Simulating GCNLayer processing in STGTEncoder ---
Loaded all_eeg_tensors from './eeg_tensors.pt'. Shape: torch.Size([28000, 62, 400])
Loaded A_hat with shape: torch.Size([62, 62])

Instantiated GCNLayer with in_features=1, out_features=64
Simulating a batch of EEG data with shape: torch.Size([4, 62, 400])

Processing each time step with GCNLayer...

Finished processing all time steps with GCNLayer.
Shape of the final GCN sequence output: torch.Size([4, 400, 62, 64])
Simulation successful: Final GCN sequence output shape matches expected!

This `final_gcn_sequence_output` is what will typically be fed into your Temporal Transformer.
The next step in building the STGT Encoder 



TemporalTransformer instantiated with d_model=3968, nhead=8, num_layers=3
Dummy input sequence shape: torch.Size([400, 4, 3968])
Output sequence shape of TemporalTransformer: torch.Size([400, 4, 3968])
TemporalTransformer test passed: Output shape matches expected shape!
--- TemporalTransformer Test Complete ---


In [None]:
# --- Verify that the classes and variables are now accessible ---
# Uncomment the lines below to confirm that the %run cell above defined these names
# (a NameError here means that cell has not been run in this kernel).
# print(f"GCNLayer class available: {GCNLayer is not None}")
# print(f"TemporalTransformer class available: {TemporalTransformer is not None}")
# print(f"TRANSFORMER_D_MODEL available: {TRANSFORMER_D_MODEL}")

import torch
import torch.nn as nn

# --- Define the STGTEncoder Class ---
# This class combines the GCNLayer and TemporalTransformer for spatio-temporal EEG processing.
class STGTEncoder(nn.Module):
    """Spatio-temporal EEG encoder: a GCN applied per time step, then a Transformer over time.

    Each time step of a trial is passed through ``GCNLayer`` (spatial mixing across
    channels), the per-channel features of that step are flattened into one token,
    the token sequence is passed through ``TemporalTransformer``, and the result is
    mean-pooled over time into a single embedding per trial.

    ``GCNLayer`` and ``TemporalTransformer`` are not defined in this cell; they must
    already exist in the notebook namespace (see the ``%run`` cell at the top).

    Args:
        in_channels: Number of EEG channels / graph nodes (e.g. 62).
        num_time_steps: Number of time steps consumed per trial (e.g. 400).
        gcn_out_features: Features produced by the GCN for each channel.
        transformer_d_model: Token width of the Transformer. Must equal
            ``in_channels * gcn_out_features``.
        transformer_nhead: Passed to ``TemporalTransformer`` as ``nhead``.
        transformer_num_layers: Passed as ``num_encoder_layers``.
        transformer_dim_feedforward: Passed as ``dim_feedforward``.
        transformer_dropout: Passed as ``dropout``.

    Raises:
        ValueError: If ``transformer_d_model != in_channels * gcn_out_features``.
    """

    def __init__(self, in_channels, num_time_steps, gcn_out_features,
                 transformer_d_model, transformer_nhead, transformer_num_layers,
                 transformer_dim_feedforward, transformer_dropout=0.1):
        super().__init__()

        # --- Validate hyperparameters before building any sub-layer ---
        # Each time step's token is the flattened (in_channels x gcn_out_features)
        # GCN output, so the Transformer width is fully determined by those two values.
        expected_transformer_d_model = in_channels * gcn_out_features
        if transformer_d_model != expected_transformer_d_model:
            raise ValueError(
                f"transformer_d_model ({transformer_d_model}) must be equal to "
                f"in_channels ({in_channels}) * gcn_out_features ({gcn_out_features}) "
                f"= {expected_transformer_d_model}. Please check your hyperparameters."
            )

        self.in_channels = in_channels          # e.g., 62 for EEG channels
        self.num_time_steps = num_time_steps    # e.g., 400 for EEG samples per trial
        self.transformer_d_model = transformer_d_model

        # --- GCN Layer (Spatial Processing) ---
        # Takes 1 input feature (raw signal amplitude) and outputs 'gcn_out_features'.
        self.gcn_layer = GCNLayer(in_features=1, out_features=gcn_out_features)

        # --- Temporal Transformer (Temporal Processing) ---
        self.temporal_transformer = TemporalTransformer(
            d_model=self.transformer_d_model,
            nhead=transformer_nhead,
            num_encoder_layers=transformer_num_layers,
            dim_feedforward=transformer_dim_feedforward,
            dropout=transformer_dropout
        )

    def forward(self, eeg_batch_data, adj_matrix):
        """Encode a batch of EEG trials into one embedding per trial.

        Args:
            eeg_batch_data: Tensor of shape (batch_size, in_channels, T) with
                T >= num_time_steps. Only the first ``num_time_steps`` steps are
                used; any trailing steps are ignored.
            adj_matrix: Adjacency matrix handed unchanged to ``GCNLayer`` on every
                time step (presumably (in_channels, in_channels) -- confirm against
                the GCNLayer definition).

        Returns:
            Tensor of shape (batch_size, transformer_d_model), assuming the
            Transformer preserves the shape of its input.

        Raises:
            ValueError: If ``eeg_batch_data`` is not 3-D, has the wrong number of
                channels, or has fewer than ``num_time_steps`` time steps.
        """
        if eeg_batch_data.dim() != 3:
            raise ValueError(
                "eeg_batch_data must be 3-D (batch_size, in_channels, num_time_steps), "
                f"got shape {tuple(eeg_batch_data.shape)}."
            )
        batch_size, num_channels, available_steps = eeg_batch_data.shape
        if num_channels != self.in_channels:
            raise ValueError(
                f"eeg_batch_data has {num_channels} channels, expected {self.in_channels}."
            )
        if available_steps < self.num_time_steps:
            raise ValueError(
                f"eeg_batch_data has {available_steps} time steps, "
                f"expected at least {self.num_time_steps}."
            )

        # --- Step 1: Spatial Processing with GCN for each Time Step ---
        # unbind(dim=2) yields one (batch_size, in_channels) slice per time step;
        # unsqueeze(2) adds the single input feature expected by the GCN (in_features=1).
        # NOTE(review): folding time into the batch dimension would avoid this Python
        # loop, but that depends on GCNLayer internals not visible here.
        used_steps = eeg_batch_data[:, :, :self.num_time_steps]
        gcn_outputs_sequence = [
            self.gcn_layer(step_signal.unsqueeze(2), adj_matrix)
            for step_signal in used_steps.unbind(dim=2)
        ]

        # --- Step 2: Prepare GCN Outputs for Temporal Transformer ---
        # Stack to (batch_size, num_time_steps, in_channels, gcn_out_features), then
        # flatten the last two dims so each time step becomes a single token vector.
        stacked_gcn_output = torch.stack(gcn_outputs_sequence, dim=1)
        transformer_input_flat = stacked_gcn_output.reshape(
            batch_size, self.num_time_steps, self.transformer_d_model
        )

        # Sequence-first layout (num_time_steps, batch_size, transformer_d_model);
        # TemporalTransformer is expected to take this order (batch_first=False).
        transformer_input_permuted = transformer_input_flat.permute(1, 0, 2)

        # --- Step 3: Temporal Processing with Transformer ---
        transformer_output_sequence = self.temporal_transformer(transformer_input_permuted)

        # --- Step 4: Aggregate Final Embedding ---
        # Average over the time axis (dim 0) to get one fixed-size vector per trial.
        final_eeg_embedding = transformer_output_sequence.mean(dim=0)

        return final_eeg_embedding

# --- Test and Verify the Full STGTEncoder ---
if __name__ == "__main__":
    # Smoke test: load real EEG data and the adjacency matrix, run a small batch
    # through STGTEncoder, and check only the shape of the resulting embedding.
    print("--- Testing STGTEncoder ---")

    # --- 1. Load Data ---
    # Adjust these paths to where your files are located.
    # NOTE(review): ADJ_MATRIX_PATH is an absolute, machine-specific path; it is kept
    # as-is so existing setups keep working, but it will need changing elsewhere.
    EEG_TENSORS_PATH = "./eeg_tensors.pt"
    ADJ_MATRIX_PATH = "/home/sanu/Projects/Capstone_EEG/adj_matrix.pt"

    try:
        # torch.load unpickles the file, so only load files from a trusted source.
        # map_location="cpu" keeps the data on the same device as the encoder below,
        # which is never moved off the CPU in this test.
        all_eeg_data = torch.load(EEG_TENSORS_PATH, map_location="cpu")
        print(f"Loaded all_eeg_tensors. Shape: {all_eeg_data.shape}")
        A_hat = torch.load(ADJ_MATRIX_PATH, map_location="cpu")
        print(f"Loaded A_hat. Shape: {A_hat.shape}")
    except FileNotFoundError as e:
        print(f"Error loading required files: {e}")
        print("Please ensure 'eeg_tensors.pt' and 'adj_matrix.pt' exist at the specified paths.")
        # Re-raise instead of calling exit(): in a notebook exit() shuts down the
        # kernel, and as a script it would report success (status 0) on failure.
        raise

    # --- 2. Define Hyperparameters for STGTEncoder ---
    in_channels = 62            # Number of EEG channels
    num_time_steps = 400        # Number of time steps in each EEG trial
    gcn_out_features = 64       # Number of output features from the GCNLayer per channel

    # Transformer's d_model must match (in_channels * gcn_out_features) for the flattening
    transformer_d_model = in_channels * gcn_out_features # 62 * 64 = 3968
    transformer_nhead = 8
    transformer_num_layers = 3
    transformer_dim_feedforward = 4 * transformer_d_model
    transformer_dropout = 0.1

    # --- 3. Instantiate STGTEncoder ---
    stgt_encoder = STGTEncoder(
        in_channels=in_channels,
        num_time_steps=num_time_steps,
        gcn_out_features=gcn_out_features,
        transformer_d_model=transformer_d_model,
        transformer_nhead=transformer_nhead,
        transformer_num_layers=transformer_num_layers,
        transformer_dim_feedforward=transformer_dim_feedforward,
        transformer_dropout=transformer_dropout
    )
    print(f"\nSTGTEncoder instantiated with gcn_out_features={gcn_out_features} and transformer_d_model={transformer_d_model}")

    # --- 4. Prepare Dummy Batch from Loaded Data ---
    batch_size = 4
    if all_eeg_data.shape[0] < batch_size:
        print(f"Warning: Not enough samples ({all_eeg_data.shape[0]}) for batch size {batch_size}. Using all available samples.")
        batch_size = all_eeg_data.shape[0]

    # Select a batch of EEG data (the first 'batch_size' samples)
    dummy_eeg_batch = all_eeg_data[:batch_size, :, :]
    print(f"Dummy EEG batch shape: {dummy_eeg_batch.shape}")

    # --- 5. Pass through STGTEncoder ---
    print("\nRunning STGTEncoder forward pass...")
    # Only the output shape is inspected, so skip building the autograd graph.
    with torch.no_grad():
        final_embedding = stgt_encoder(dummy_eeg_batch, A_hat)

    # --- 6. Print Output Shape ---
    print(f"Shape of the final EEG embedding from STGTEncoder: {final_embedding.shape}")

    # Expected output shape: (batch_size, transformer_d_model)
    expected_output_shape = (batch_size, transformer_d_model)
    if final_embedding.shape == expected_output_shape:
        print("STGTEncoder test passed: Output embedding shape matches expected shape! ✅")
    else:
        print(f"STGTEncoder test FAILED: Expected {expected_output_shape}, got {final_embedding.shape} ❌")

    print("\n--- STGTEncoder Test Complete ---")
    print("This final_embedding is the rich EEG representation ready for the next phases: "
          "Contrastive Alignment, MTR Head, and Text Decoder.")
