In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from polarbert.config import PolarBertConfig

In [None]:
import yaml
config_file = "/groups/pheno/inar/PolarBERT/configs/polarbert_new.yaml"

# Smoke test for the config pipeline: load + validate the YAML, derive the
# runtime training parameters and echo a few fields. The listed exceptions are
# reported instead of raised, so on failure `config` may be missing or only
# partially set up for the cells below.
try:
    config = PolarBertConfig.from_yaml(config_file)

    # Runtime params need the number of batches per epoch. len(train_loader)
    # can be tricky with an IterableDataset, so this is a PLACEHOLDER: replace
    # it with the real per-epoch batch count used by the training script.
    num_batches_per_epoch = 10_000
    config.calculate_runtime_params(num_batches_per_epoch)

    print("\n--- Calculated Runtime Params ---")
    print(f"Total Steps: {config.training.total_steps}")
    print(f"Grad Accum Steps: {config.training.gradient_accumulation_steps}")
    print(f"Effective Batch Size: {config.training.effective_batch_size}")
    print(f"Final pct_start for Scheduler: {config.training.pct_start:.4f}")

    print("\n--- Accessing Config ---")
    print(f"Model Name: {config.model.model_name}")
    print(f"Time Embedding Dim: {config.model.embedding.time_embedding_dim}")
    print(f"Optimizer Beta1: {config.training.adam_beta1}")
    print(f"Checkpoint Dir: {config.training.checkpoint.dirpath}")

    # Other entry points (disabled):
    #   config.save_yaml('saved_config_copy.yaml')
    #   cp_path = 'checkpoints/your_model_run/epoch=01-step=1000.ckpt'
    #   config_from_cp = PolarBertConfig.from_checkpoint(cp_path)
    #   print(f"\nLoaded config from checkpoint dir. Project: {config_from_cp.training.logging.project}")
except (FileNotFoundError, ValueError, yaml.YAMLError, TypeError) as e:
    print(f"\nError loading/validating/calculating config: {e}")

Loading configuration from: /groups/pheno/inar/PolarBERT/configs/polarbert_new.yaml
Validating configuration...
Calculating runtime training parameters...
  Logical Batch Size: 4096
  Max Per Device Batch Size: 4096
  Calculated Per Device Batch Size: 4096
  Gradient Accumulation Steps: 1
  Effective Batch Size: 4096
  Steps per Epoch (logical): 10000
  Total Steps: 200000
  Overriding pct_start based on warm_up_steps: 0.2000 -> 0.0050

--- Calculated Runtime Params ---
Total Steps: 200000
Grad Accum Steps: 1
Effective Batch Size: 4096
Final pct_start for Scheduler: 0.0050

--- Accessing Config ---
Model Name: polarbert_time_embed
Time Embedding Dim: 128
Optimizer Beta1: 0.9
Checkpoint Dir: checkpoints


In [9]:
from pprint import pprint
# Full nested-dict view of the loaded config (keys printed in sorted order).
config_dict = config.to_dict()
pprint(config_dict, sort_dicts=True)

{'data': {'max_per_device_batch_size': 4096,
          'num_workers': 1,
          'persistent_workers': True,
          'pin_memory': False,
          'train_dir': '/groups/pheno/inar/icecube_kaggle/memmaped_100M_127',
          'train_events': 100000000,
          'val_dir': '/groups/pheno/inar/icecube_kaggle/memmaped_eval_1M_127',
          'val_events': 200000},
 'model': {'embedding': {'aux_embedding_dim': 4,
                         'charge_embedding_dim': 16,
                         'charge_vocab_size': 128,
                         'dom_embedding_dim': 108,
                         'dom_vocab_size': 5162,
                         'embedding_dim': 256,
                         'embedding_projection': False,
                         'masking_charges': False,
                         'masking_doms': True,
                         'masking_prob': 0.25,
                         'masking_times': False,
                         'time_embedding_dim': 128,
                         'tim

config parameters
Embedding:
- time_embedding_dim: 128
- dom_embedding_dim: 111 # with aux (1) and charge (16) will be 128
- charge_embedding_dim: 16
- time_vocab_size: 52000 # includes padding, masking, and overflow tokens
- dom_vocab_size: 5162 # includes padding and masking tokens
- charge_num_bins: 32 # includes padding and masking tokens
- masking_doms: True
- masking_times: True
- masking_charges: True
- masking_prob: 0.25 
Model
- num_layers: 8
- num_heads: 4 # assert sum emb dim is divisible by num_heads
- hidden_size: 1024
- ffd_type: "SwiGLU" # or "MLP"
- use_rope: False
- use_positional_embedding: False # assert use_rope is False

data:
  max_per_device_batch_size: 1024 # Maximum batch size that fits on the GPU
  train_dir: '/path/to/train/data'
  val_dir: '/path/to/val/data'
  train_events: 10000000
  val_events: 200000
  pin_memory: false
  num_workers: 1 # Should remain 1 when using an IterableDataset
  persistent_workers: true

training:
  max_epochs: 1
  logical_batch_size: 200 # Batch size used for training (will use gradient accumulation if necessary)
  val_check_interval: 0.1
  gpus: 1
  precision: "16-mixed"
  gradient_clip_val: 2.0
  max_lr: 3e-3
  optimizer: "AdamW"
  optimizer_kwargs:
    betas: [0.9, 0.95]
    eps: 1e-8
    weight_decay: 0.1
  scheduler: 'onecycle'
  warm_up_steps: null  # Number of steps for learning rate warm-up
  pct_start: 0.2       # Percentage of training for warm-up when warm_up_steps is not provided
  div_factor: 25.0
  final_div_factor: 1e4
  
logging:
  project: 'PolarBERT'
  checkpoint:
    dirpath: 'checkpoints'
    save_top_k: 1
    monitor: 'val/full_loss'
    mode: 'min'
    save_last: true
    save_final: true
  


MAX_TIME_VOCAB = 51996   # Max relative time value allowed (0 to 51996 ns)
TIME_PADDING_IDX = 51997 # Index for padded time positions
TIME_MASK_IDX = 51998    # Reserved index for MASK token in time vocab
# TIME_UNUSED_IDX = 51999 # Reserved index

DOM_PADDING_IDX = 0      # Padding index used for DOM IDs in input data
DOM_MASK_IDX = 5161      # Reserved index for MASK token in DOM vocab (5160 real DOMs)
DOM_VOCAB_SIZE = 5160 + 2 # 5160 DOMs + PAD + MASK


In [None]:
import torch
import torch.nn as nn
import numpy as np # Only needed for np.inf if torch.inf is not available

# --- Define Constants (Using PAD=0, MASK=LAST_IDX scheme) ---
# Every feature vocabulary reserves index 0 for [PAD] and its last index for
# [MASK]; real data values are shifted by +1 so they lie strictly in between.
# NOTE: a later cell of this notebook redefines several of these names (and
# EnhancedIceCubeEmbedding itself) with a different layout -- cell order matters.

PAD_IDX = 0

# Time Vocab
TIME_VOCAB_SIZE = 52000 # Total size: 0=PAD, 1..51998=Data(0..51997ns), 51999=MASK
TIME_MASK_IDX = TIME_VOCAB_SIZE - 1 # 51999
# Data index = relative time + 1 must stay below TIME_MASK_IDX, so the longest
# representable relative time is TIME_MASK_IDX - 2; longer times are clipped to it
# (otherwise clipped times would land on the [MASK] index).
MAX_TIME_DURATION = TIME_MASK_IDX - 2 # 51997 ns (data indices 1 to 51998)

# DOM Vocab
DOM_VOCAB_SIZE = 5162 # 0=PAD, 1..5160=Data(IDs 1..5160), 5161=MASK
DOM_MASK_IDX = DOM_VOCAB_SIZE - 1 # 5161

# Charge Vocab (Example: 32 bins -> indices 0..31 for data)
CHARGE_NUM_BINS = 32 # Number of actual data bins
CHARGE_VOCAB_SIZE = CHARGE_NUM_BINS + 2 # 0=PAD, 1..32=Data(bins 0..31), 33=MASK
CHARGE_MASK_IDX = CHARGE_VOCAB_SIZE - 1 # 33

# Auxiliary Vocab (Example: 2 categories -> indices 0, 1 for data)
AUX_NUM_CATS = 2
AUX_VOCAB_SIZE = AUX_NUM_CATS + 2 # 0=PAD, 1=Data(False), 2=Data(True), 3=MASK
AUX_MASK_IDX = AUX_VOCAB_SIZE - 1 # 3


class EnhancedIceCubeEmbedding(nn.Module):
    """
    Embedding layer using index replacement for masking specific features.

    Each pulse feature (DOM ID, time, charge, auxiliary flag) is converted to an
    integer index and looked up in its own ``nn.Embedding``. Indices: 0=[PAD],
    LAST_IDX=[MASK]. The sub-embeddings are concatenated without a projection
    layer, so their dims must sum to ``embedding_dim``; a learned CLS vector is
    prepended to every sequence.

    Args:
        config: experiment config. The embedding sub-config is taken from
            ``config.embedding`` if that attribute exists, otherwise from
            ``config.model.embedding`` (the PolarBertConfig layout).
        masking: if True, non-auxiliary, non-padded pulses are selected with
            probability ``mask_prob`` and the enabled features
            (``masking_doms`` / ``masking_times`` / ``masking_charges``) are
            replaced by that feature's [MASK] index. ``mask_prob`` comes from
            ``config.training.mask_prob`` when set, otherwise from
            ``masking_prob`` of the embedding sub-config.

    Raises:
        ValueError: if the sub-embedding dims do not sum to ``embedding_dim``.
    """
    def __init__(self, config, masking=False):
        super().__init__()
        self.config = config # Store the main experiment config object
        self.masking = masking

        embedding_cfg = self._resolve_embedding_cfg(config)
        self.embedding_cfg = embedding_cfg
        self.mask_prob = self._resolve_mask_prob(config, embedding_cfg) if masking else 0.0

        # --- Embedding Layers (All use padding_idx=PAD_IDX=0) ---
        self.dom_embedding = nn.Embedding(
            DOM_VOCAB_SIZE,
            embedding_cfg.dom_embedding_dim,
            padding_idx=PAD_IDX
        )
        self.time_embedding = nn.Embedding(
            TIME_VOCAB_SIZE,
            embedding_cfg.time_embedding_dim,
            padding_idx=PAD_IDX
        )
        self.charge_embedding = nn.Embedding(
            CHARGE_VOCAB_SIZE,
            embedding_cfg.charge_embedding_dim,
            padding_idx=PAD_IDX
        )
        self.aux_embedding = nn.Embedding(
            AUX_VOCAB_SIZE,
            embedding_cfg.aux_embedding_dim,
            padding_idx=PAD_IDX
        )

        # --- Validate that the concatenation matches the transformer width ---
        total_sub_embed_dim = (
            embedding_cfg.dom_embedding_dim +
            embedding_cfg.time_embedding_dim +
            embedding_cfg.charge_embedding_dim +
            embedding_cfg.aux_embedding_dim
        )
        self.embedding_dim = embedding_cfg.embedding_dim
        if total_sub_embed_dim != self.embedding_dim:
            raise ValueError(
                f"Sum of sub-embedding dimensions ({total_sub_embed_dim}) does not match "
                f"target embedding.embedding_dim ({self.embedding_dim}). "
                f"Add a projection layer or adjust dimensions."
            )

        # --- CLS Token (must match the final embedding dimension) ---
        self.cls_embedding = nn.Parameter(torch.randn(1, 1, self.embedding_dim))

    @staticmethod
    def _resolve_embedding_cfg(config):
        """Return ``config.embedding`` if present, else ``config.model.embedding``."""
        embedding_cfg = getattr(config, "embedding", None)
        if embedding_cfg is None:
            embedding_cfg = config.model.embedding
        return embedding_cfg

    @staticmethod
    def _resolve_mask_prob(config, embedding_cfg):
        """Return ``config.training.mask_prob`` if set, else ``embedding_cfg.masking_prob``."""
        mask_prob = getattr(getattr(config, "training", None), "mask_prob", None)
        if mask_prob is None:
            mask_prob = embedding_cfg.masking_prob
        return mask_prob

    @staticmethod
    def _time_indices(time_normalized, padding_mask):
        """
        Map normalized absolute pulse times to time-vocab indices.

        Times are de-normalized (``t = t_norm * 3e4 + 1e4``) and rounded to
        integers first. They are then made relative to the earliest non-padded
        pulse of each event, clipped to ``[0, MAX_TIME_DURATION]``, and shifted
        by +1 so that index 0 stays [PAD]. Padded positions get ``PAD_IDX``.
        Every result lies in ``[0, TIME_MASK_IDX - 1]``, so data never collides
        with [MASK].

        Args:
            time_normalized: (batch, seq_len) float tensor.
            padding_mask: (batch, seq_len) bool tensor, True = padding.
        Returns:
            (batch, seq_len) long tensor of indices.
        """
        time_abs = torch.round(time_normalized * 3e4 + 1e4).long()
        # Exclude padding from the per-event minimum via the largest long value.
        sentinel = torch.iinfo(time_abs.dtype).max
        t_min = time_abs.masked_fill(padding_mask, sentinel).amin(dim=1, keepdim=True)
        # All-padding events: every position is overwritten with PAD below anyway.
        t_min = t_min.masked_fill(t_min == sentinel, 0)
        # TODO: map durations beyond MAX_TIME_DURATION to a dedicated overflow token instead of clipping.
        time_relative = torch.clamp(time_abs - t_min, min=0, max=MAX_TIME_DURATION)
        return (time_relative + 1).masked_fill(padding_mask, PAD_IDX)

    def forward(self, input_batch):
        """
        Embed a padded batch of pulses.

        Args:
            input_batch: tuple ``(x, l)``; only ``x`` is used. ``x`` is
                (batch, seq_len, n_features) with columns 0 = normalized time,
                1 = normalized log-charge, 2 = aux - 0.5, 3 = DOM id + 1
                (0 marks padding).
        Returns:
            full_embedding: (batch, seq_len + 1, embedding_dim), CLS first.
            final_padding_mask: (batch, seq_len + 1) bool, True = padding
                (the CLS position is never padding).
            output_mask: (batch, seq_len) bool of pulses chosen for masking,
                or None when masking is disabled.
        """
        x, _ = input_batch
        batch_size = x.shape[0]
        device = x.device
        embedding_cfg = self.embedding_cfg

        # --- 1. Original padding mask (DOM column == 0) ---
        padding_mask = (x[:, :, 3] == PAD_IDX)

        # --- 2. Time -> indices in [0, TIME_MASK_IDX - 1] ---
        time_indices = self._time_indices(x[:, :, 0], padding_mask)

        # --- 3. DOM ID -> indices (input is already 0 for pad, sensor_id + 1 otherwise) ---
        dom_indices = x[:, :, 3].long()

        # --- 4. Charge -> indices (PLACEHOLDER: every real pulse uses bin 0) ---
        # TODO: charge_float = 10**(x[:, :, 1] * 3.0); quantize into CHARGE_NUM_BINS bins
        #       and clamp to [0, CHARGE_NUM_BINS - 1] before the +1 shift.
        charge_bin_idx = torch.zeros_like(dom_indices)
        charge_indices = (charge_bin_idx + 1).masked_fill(padding_mask, PAD_IDX)

        # --- 5. Auxiliary flag -> indices (-0.5 -> 1 (False), 0.5 -> 2 (True)) ---
        aux_base_idx = torch.clamp(torch.round(x[:, :, 2] + 0.5).long(), 0, AUX_NUM_CATS - 1)
        aux_indices = (aux_base_idx + 1).masked_fill(padding_mask, PAD_IDX)

        # --- 6. Masking by index replacement ---
        output_mask = None
        if self.masking:
            # Only non-auxiliary (aux == -0.5), non-padded pulses may be masked.
            is_non_auxiliary = (x[:, :, 2] == -0.5)
            random_mask = torch.rand(is_non_auxiliary.shape, device=device) < self.mask_prob
            output_mask = is_non_auxiliary & random_mask & ~padding_mask

            if embedding_cfg.masking_doms:
                dom_indices = dom_indices.masked_fill(output_mask, DOM_MASK_IDX)
            if embedding_cfg.masking_times:
                time_indices = time_indices.masked_fill(output_mask, TIME_MASK_IDX)
            if embedding_cfg.masking_charges:
                charge_indices = charge_indices.masked_fill(output_mask, CHARGE_MASK_IDX)
            # Aux masking (AUX_MASK_IDX) is not wired to a config flag yet.

        # --- 7. Lookups + concatenation (no projection layer) ---
        combined_embeds = torch.cat([
            self.dom_embedding(dom_indices),
            self.time_embedding(time_indices),
            self.charge_embedding(charge_indices),
            self.aux_embedding(aux_indices),
        ], dim=2)

        # --- 8. Prepend CLS token ---
        cls_token_expanded = self.cls_embedding.expand(batch_size, -1, -1)
        full_embedding = torch.cat([cls_token_expanded, combined_embeds], dim=1)

        # --- 9. Attention padding mask including the (never padded) CLS slot ---
        final_padding_mask = torch.cat([
            torch.zeros(batch_size, 1, dtype=torch.bool, device=device),
            padding_mask,
        ], dim=1)

        return full_embedding, final_padding_mask, output_mask

# Enhanced Embedding Layer for PolarBERT

**Goal:** Improve model performance and physical relevance by replacing the simple linear projection of features with learned embeddings, particularly for time and charge, while correctly handling relative time and masking.

**Branch:** `eat/time-embedding`

**Key Changes & Implementation Steps (GPU-based):**

1.  **Input Processing (Inside `EnhancedIceCubeEmbedding.forward`):**
    * Receive batch `(x, l)` from the dataloader (where `x` contains normalized float features).
    * Perform all subsequent transformations on the GPU using PyTorch tensor operations.

2.  **Time Feature (`x[:, :, 0]`):**
    * **Relative Time Calculation:**
        * Invert normalization: `t_float = t_norm * 3e4 + 1e4`.
        * Calculate `t_min` per event, carefully ignoring padding values (mask with `inf` before `torch.min`).
        * Calculate relative time: `t_relative = t_float - t_min`.
    * **Integer Conversion & Clipping:**
        * Round `t_relative` to the nearest integer (`torch.round().long()`).
        * Clip the integer time to a predefined maximum (`MAX_TIME_VOCAB`, e.g., 51996) using `torch.clamp()`.
    * **Padding Index:** Use `torch.where()` to replace values in originally padded positions with a dedicated `TIME_PADDING_IDX` (e.g., 51997).
    * **Embedding:** Use an `nn.Embedding(time_vocab_size, time_embed_dim, padding_idx=TIME_PADDING_IDX)` layer with the resulting integer indices. `time_vocab_size` needs to accommodate `MAX_TIME_VOCAB`, `PAD`, `MASK`, etc. (e.g., 52000).

3.  **Charge Feature (`x[:, :, 1]`):**
    * **TODO:**
        * Define a quantization strategy (e.g., log binning, quantile binning based on data analysis).
        * Invert normalization: `log10_charge = charge_norm * 3.0`, then `charge = 10**log10_charge`.
        * Apply quantization function to get integer bin indices.
        * Handle padding.
        * Use an `nn.Embedding` layer for charge bins.

4.  **Auxiliary Feature (`x[:, :, 2]`):**
    * Can use a small `nn.Embedding` (e.g., 2-3 embeddings for True/False/Padding) or keep a simple linear projection. Needs careful handling of the input value (`aux - 0.5`).
    * Crucially, the *original* value of this feature is needed to determine which pulses *not* to mask during the masking step.

5.  **DOM ID Feature (`x[:, :, 3]`):**
    * Input is `sensor_id + 1`, padding is `0`.
    * Use `nn.Embedding(dom_vocab_size, dom_embed_dim, padding_idx=0)`. `dom_vocab_size` includes actual IDs + PAD + MASK token (e.g., 5160 + 2).

6.  **Combining Embeddings:**
    * Concatenate the embeddings for Time, DOM, Charge, and Auxiliary.
    * Use a final `nn.Linear` layer to project the concatenated vector to the model's target `embedding_dim`.

7.  **Special Tokens & Masking:**
    * **CLS Token:** Prepend a learned `cls_embedding` parameter (`nn.Parameter`) to the sequence *after* processing pulse embeddings.
    * **Masking:**
        * If `masking=True`, calculate a boolean `output_mask` based on `mask_prob`, ensuring auxiliary pulses (`x[:, :, 2] == 0.5`) and padded pulses are *not* masked.
        * Define a learned `mask_token_embed` parameter (`nn.Parameter` of size `embedding_dim`).
        * Use `torch.where(output_mask.unsqueeze(-1), mask_token_embed, projected_embeddings)` to replace the embeddings of masked positions with the learned mask embedding *after* the projection layer.
    * **Padding Mask:** Generate the final `final_padding_mask` (for the transformer attention) including the prepended CLS token (which is never padded).

8.  **Output:** The `forward` method should return `full_embedding`, `final_padding_mask`, and `output_mask` (which is `None` when masking is disabled).

In [7]:
from torch import nn
from torch.nn import functional as F
import torch

In [10]:
# Toy table to inspect padding_idx: the padding row (index 2) starts as zeros
# and receives no gradient updates.
emb = nn.Embedding(num_embeddings=10, embedding_dim=4, padding_idx=2)

In [12]:
emb(torch.tensor([0], dtype=torch.long))  # index 0 is not the padding row (2), so this shows a random init

tensor([[-2.2103,  1.2373,  1.5075,  1.8231]], grad_fn=<EmbeddingBackward0>)

In [3]:
from polarbert.pretraining import get_dataloaders, load_and_process_config

In [10]:
# Load the IT config (dict-style access), then shrink it for quick batch inspection.
config = load_and_process_config('/groups/pheno/inar/PolarBERT/configs/polarbert_IT.yaml')
config['data']['train_events'] = int(1e5)
config['training']['per_device_batch_size'] = 32

In [11]:
# Build the train/val loaders for the 'kaggle' dataset type from the shrunken config.
train_loader, val_loader = get_dataloaders(config, dataset_type="kaggle")

In [12]:
# Pull a single batch for inspection. Unlike `for ...: break`, this raises
# StopIteration on an empty loader instead of silently keeping a stale `batch`
# from an earlier run.
batch = next(iter(train_loader))

In [None]:
# TODO: improve the docs to specify what the columns in the data are (current encoding below).

# T_evt[:len(selected_idx), 0] = (time[selected_idx] - 1e4) / 3e4
# T_evt[:len(selected_idx), 1] = np.log10(charge[selected_idx]) / 3.0
# T_evt[:len(selected_idx), 2] = auxiliary[selected_idx] - 0.5 # aux = True is BAD, so 0.5 is bad
# T_evt[:len(selected_idx), 3] = sensor_id[selected_idx] + 1  # +1 is needed since we use 0 for padding

In [15]:
# Batch layout: ((x, l), (y, c)). x holds the padded pulse features and l presumably
# the per-event lengths; y and c are the targets (TODO confirm their meaning).
[x, l], [y, c] = batch

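Note that padding fills every feature column with zeros, so a padded pulse de-normalizes to a time of 10000 ns (0 · 3e4 + 1e4), as seen in the output below.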

In [35]:
import torch
# De-normalize the time column of the first event: t = t_norm * 3e4 + 1e4 (ns).
times = x[0, :, 0] * 3e4 + 1e4
print(times)
# Round to whole nanoseconds and cast to int32.
times_rounded = times.round().int()
times_rounded

tensor([ 6019.2871,  6103.5156,  6169.4336,  6367.1875,  6376.3428,  6425.7812,
         6520.9961,  6605.2246,  6989.7461,  7495.1172,  7745.9717,  8077.3926,
         8168.9453,  8253.1738,  8359.3750,  9252.0137,  9328.0029,  9365.9971,
         9516.1436,  9547.9580,  9658.0508,  9720.0771,  9878.0059, 10086.9746,
        10100.9941, 10102.0244, 10128.9746, 10144.9971, 10184.0215, 10319.9766,
        10567.1689, 10846.8633, 11046.4473, 11126.0986, 11320.1904, 12532.3486,
        12817.9932, 13537.5977, 13958.7402, 14094.2383, 14101.5625, 14156.4941,
        15240.4785, 15280.7617, 16134.0332, 10000.0000, 10000.0000, 10000.0000,
        10000.0000, 10000.0000, 10000.0000, 10000.0000, 10000.0000, 10000.0000,
        10000.0000, 10000.0000, 10000.0000, 10000.0000, 10000.0000, 10000.0000,
        10000.0000, 10000.0000, 10000.0000, 10000.0000, 10000.0000, 10000.0000,
        10000.0000, 10000.0000, 10000.0000, 10000.0000, 10000.0000, 10000.0000,
        10000.0000, 10000.0000, 10000.00

tensor([ 6019,  6104,  6169,  6367,  6376,  6426,  6521,  6605,  6990,  7495,
         7746,  8077,  8169,  8253,  8359,  9252,  9328,  9366,  9516,  9548,
         9658,  9720,  9878, 10087, 10101, 10102, 10129, 10145, 10184, 10320,
        10567, 10847, 11046, 11126, 11320, 12532, 12818, 13538, 13959, 14094,
        14102, 14156, 15240, 15281, 16134, 10000, 10000, 10000, 10000, 10000,
        10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000,
        10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000,
        10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000,
        10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000,
        10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000,
        10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000,
        10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000,
        10000, 10000, 10000, 10000, 10000, 10000, 10000], dtype=

In [36]:
import torch
import torch.nn as nn
import numpy as np # Only needed for np.inf if torch.inf is not available

# --- Define Constants ---
# NOTE: this cell redefines the constants and EnhancedIceCubeEmbedding from the
# earlier cell with a different layout -- whichever cell ran last wins.
# Time vocabulary (time_vocab_size = 52000): relative times 0..MAX_TIME_VOCAB are
# used directly as indices, followed by [PAD] and [MASK]; 51999 is left unused.
MAX_TIME_VOCAB = 51996                 # Largest relative time kept (ns); longer times are clipped
TIME_PADDING_IDX = MAX_TIME_VOCAB + 1  # 51997: index for padded time positions
TIME_MASK_IDX = MAX_TIME_VOCAB + 2     # 51998: reserved for the [MASK] token
# TIME_UNUSED_IDX = 51999 # Reserved index

# DOM vocabulary: 0=[PAD], 1..5160 = sensor_id + 1, 5161=[MASK]
DOM_PADDING_IDX = 0
DOM_VOCAB_SIZE = 5160 + 2              # 5160 DOMs + PAD + MASK
DOM_MASK_IDX = DOM_VOCAB_SIZE - 1      # 5161


class EnhancedIceCubeEmbedding(nn.Module):
    """
    Pulse embedding for IceCube events built from learned lookup tables.

    Time (relative to the earliest pulse of the event) and DOM ID go through
    ``nn.Embedding`` tables. Charge and the auxiliary flag are still linear
    projections of their normalized scalar value (placeholders). The four
    sub-embeddings are concatenated and projected to ``embedding_dim``, and a
    CLS vector is prepended. All preprocessing happens inside ``forward`` on
    the input's device. With ``masking=True``, randomly chosen non-auxiliary,
    non-padded pulses have their projected embedding replaced by a learned
    [MASK] vector.

    Config keys read (dict-style): ``training.mask_prob`` (only when masking),
    ``model.embedding_dim``, and the optional
    ``model.{dom,time,charge,aux}_embed_dim`` (defaults 64 / 128 / 32 / 32).
    """
    def __init__(self, config, masking=False):
        super().__init__()
        self.config = config
        self.masking = masking
        if masking:
            self.mask_prob = config['training']['mask_prob']
        else:
            self.mask_prob = 0.0

        model_cfg = config['model']
        embedding_dim = model_cfg['embedding_dim']
        # Widths of the individual sub-embeddings (optional config keys).
        dom_embed_dim = model_cfg.get('dom_embed_dim', 64)
        time_embed_dim = model_cfg.get('time_embed_dim', 128)
        charge_embed_dim = model_cfg.get('charge_embed_dim', 32)  # placeholder size
        aux_embed_dim = model_cfg.get('aux_embed_dim', 32)        # placeholder size

        # NOTE: the creation order below fixes RNG consumption (initial weights)
        # and the state_dict key order -- keep it stable.
        self.dom_embedding = nn.Embedding(DOM_VOCAB_SIZE, dom_embed_dim, padding_idx=DOM_PADDING_IDX)
        self.time_embedding = nn.Embedding(52000, time_embed_dim, padding_idx=TIME_PADDING_IDX)
        # TODO: replace with quantized charge bins + nn.Embedding.
        self.charge_proj = nn.Linear(1, charge_embed_dim)  # input: log10(charge) / 3.0
        # TODO: could become nn.Embedding once aux (-0.5 / 0.5) is mapped to indices.
        self.aux_proj = nn.Linear(1, aux_embed_dim)        # input: aux - 0.5

        if self.masking:
            # Learned vector substituted for masked pulses after the projection.
            self.mask_token_embed = nn.Parameter(torch.randn(1, 1, embedding_dim))

        sub_embed_dims = (dom_embed_dim, time_embed_dim, charge_embed_dim, aux_embed_dim)
        self.projection = nn.Linear(sum(sub_embed_dims), embedding_dim)

        self.cls_embedding = nn.Parameter(torch.randn(1, 1, embedding_dim))

    def forward(self, input_batch):
        """
        Embed a padded batch of pulses.

        Args:
            input_batch: tuple ``(x, l)``; only ``x`` is used. Columns of ``x``:
                0 = normalized time, 1 = log10(charge) / 3, 2 = aux - 0.5,
                3 = sensor_id + 1 (0 marks padding).
        Returns:
            ``(full_embedding, final_padding_mask, output_mask)`` with shapes
            (batch, seq_len + 1, embedding_dim), (batch, seq_len + 1) and
            (batch, seq_len). ``output_mask`` is None unless masking is enabled.
        """
        x, _ = input_batch
        batch_size, seq_len, _ = x.shape
        device = x.device

        # Padding is marked by a 0 in the DOM column.
        padding_mask = x[:, :, 3] == DOM_PADDING_IDX

        # Time: de-normalize, subtract the earliest real pulse, round, clip.
        abs_time = x[:, :, 0] * 3e4 + 1e4
        event_t0 = abs_time.masked_fill(padding_mask, float('inf')).amin(dim=1, keepdim=True)
        event_t0 = event_t0.masked_fill(torch.isinf(event_t0), 0.0)  # all-padding events
        rel_time = torch.round(abs_time - event_t0).long().clamp(min=0, max=MAX_TIME_VOCAB)
        time_embeds = self.time_embedding(rel_time.masked_fill(padding_mask, TIME_PADDING_IDX))

        # DOM IDs are already 0 for padding and sensor_id + 1 otherwise.
        dom_embeds = self.dom_embedding(x[:, :, 3].long())

        # Placeholders: linear projections of the scalar features (kept as (..., 1)).
        charge_embeds = self.charge_proj(x[:, :, 1:2])
        aux_embeds = self.aux_proj(x[:, :, 2:3])

        pulse_embeds = self.projection(
            torch.cat([dom_embeds, time_embeds, charge_embeds, aux_embeds], dim=2)
        )

        output_mask = None
        if self.masking:
            # Only non-auxiliary (aux == -0.5), non-padded pulses are eligible.
            is_non_auxiliary = x[:, :, 2] == -0.5
            random_mask = torch.rand(is_non_auxiliary.shape, device=device) < self.mask_prob
            output_mask = is_non_auxiliary & random_mask & ~padding_mask
            pulse_embeds = torch.where(
                output_mask.unsqueeze(-1),
                self.mask_token_embed.expand(batch_size, seq_len, -1),
                pulse_embeds,
            )

        cls_tokens = self.cls_embedding.expand(batch_size, -1, -1)
        full_embedding = torch.cat([cls_tokens, pulse_embeds], dim=1)
        # The CLS slot is never padding.
        final_padding_mask = torch.cat([padding_mask.new_zeros((batch_size, 1)), padding_mask], dim=1)

        return full_embedding, final_padding_mask, output_mask