**TODO** Write proper tests using pytest

In [1]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
import numpy as np
import os
import yaml # Make sure PyYAML is installed

# --- Assume these are in your project structure ---
# Make sure paths are correct or install your package
try:
    from polarbert.config import PolarBertConfig
    # Assuming the new model is PolarBertModel and uses IceCubeTimeEmbedding
    from polarbert.time_embed_polarbert import PolarBertModel
    from polarbert.time_embedding import IceCubeTimeEmbedding # Or the actual name you used
except ImportError as e:
    print(f"Error importing necessary modules: {e}")
    print("Please ensure your classes PolarBertConfig, PolarBertModel, IceCubeTimeEmbedding are accessible.")
    # You might need to adjust sys.path or install your package (`pip install -e .`)
    # import sys
    # sys.path.append('/path/to/your/src') # Example if not installed
    # from config import PolarBertConfig
    # from model import PolarBertModel
    # from time_embedding import IceCubeTimeEmbedding

In [2]:

# --- Test Setup ---
print("--- Sanity Check Setup ---")
# 1. Load Config
# Make sure this path points to your YAML file
config_file = "/groups/pheno/inar/PolarBERT/configs/polarbert_new.yaml"
try:
    config = PolarBertConfig.from_yaml(config_file)
    print(f"Successfully loaded config from {config_file}")
except Exception as e:
    print(f"FATAL: Could not load config file: {e}")
    # Stop execution or handle error
    raise e

# 2. Define Device
if torch.cuda.is_available():
    # Explicitly use cuda:0
    device = torch.device("cuda:0")
    print(f"Using CUDA device 0: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print("CUDA not available, using CPU.")


# 3. Create a Dummy Batch
B, L_orig, Ff = 4, 127, 4 # Batch size, Original Seq Len, Num Features
dummy_x = torch.rand(B, L_orig, Ff, dtype=torch.float32)
# Simulate some padding (last 10 elements for batch 0, last 20 for batch 1 etc.)
dummy_l = torch.tensor([L_orig - 10, L_orig - 20, L_orig - 5, L_orig], dtype=torch.long)
for i in range(B):
    if dummy_l[i] < L_orig:
        dummy_x[i, dummy_l[i]:, :] = 0.0 # Set padding features to 0
        dummy_x[i, dummy_l[i]:, 3] = 0.0 # Ensure DOM ID is 0 for padding

# Simulate some plausible feature values (replace rand if needed)
# Feature 3 (DOM ID): Should be > 0 for non-padded
non_pad_mask_init = (dummy_x[:, :, 3] != 0)
num_non_pad = non_pad_mask_init.sum()
dummy_x[:,:,3][non_pad_mask_init] = torch.randint(1, 5161, size=(num_non_pad,)).float() + 1.0 # DOM ID + 1

# Feature 2 (Aux): -0.5 or 0.5
aux_vals = torch.randint(0, 2, size=(num_non_pad,)).float() - 0.5
dummy_x[:,:,2][non_pad_mask_init] = aux_vals

# Add dummy labels (y, c) - required for _calculate_loss / training_step
# Shape depends on what your _calculate_loss expects from y_data tuple
# Example: y_data = (y_angles_ignored, true_total_charge)
dummy_y_angles = torch.rand(B, 2) # Placeholder, not used in pretrain loss
dummy_charge_target = torch.rand(B) * 1000 + 1 # Example total charges
dummy_y_data = (dummy_y_angles, dummy_charge_target)

dummy_batch = ((dummy_x, dummy_l), dummy_y_data)
print(f"Created dummy batch: x shape {dummy_x.shape}, l shape {dummy_l.shape}")


--- Sanity Check Setup ---
Loading configuration from: /groups/pheno/inar/PolarBERT/configs/polarbert_new.yaml
Validating configuration...
Successfully loaded config from /groups/pheno/inar/PolarBERT/configs/polarbert_new.yaml
Using CUDA device 0: NVIDIA H100 NVL
Created dummy batch: x shape torch.Size([4, 127, 4]), l shape torch.Size([4])


In [3]:
# --- Test 1: Initialization ---
print("\n--- Test 1: Initialization ---")
try:
    # Ensure mask_prob exists in config if masking=True
    # We'll force masking=True for thorough tests
    # config.model.embedding.masking_prob = 0.15 # Make sure it's set
    model = PolarBertModel(config) # Assumes masking is handled by config
    print("PASS: Model initialized successfully.")
except Exception as e:
    print(f"FAIL: Model initialization failed: {e}")
    raise e



--- Test 1: Initialization ---
INFO: Concatenated embeddings directly match model embedding dim. No projection layer used.
PASS: Model initialized successfully.


In [4]:
# --- Test 2: Device Placement ---
print("\n--- Test 2: Device Placement ---")
try:
    model.to(device)
    # Move all parts of the batch to the device
    dummy_x_dev = dummy_x.to(device)
    dummy_l_dev = dummy_l.to(device)
    dummy_y_angles_dev = dummy_y_angles.to(device)
    dummy_charge_target_dev = dummy_charge_target.to(device)
    dummy_batch_dev = ((dummy_x_dev, dummy_l_dev), (dummy_y_angles_dev, dummy_charge_target_dev))

    # --- Debugging Print ---
    model_device = next(model.parameters()).device
    print(f"Target device: {device} (type: {type(device)})")
    print(f"Model parameter device: {model_device} (type: {type(model_device)})")
    # --- End Debugging Print ---

    assert model_device == device # Check model device
    assert dummy_batch_dev[0][0].device == device # Check input x device
    assert dummy_batch_dev[0][1].device == device # Check input l device
    assert dummy_batch_dev[1][1].device == device # Check target c device
    print(f"PASS: Model and batch moved to {device}.")
except Exception as e:
    print(f"FAIL: Moving model/batch to device failed: {e}")
    raise e


--- Test 2: Device Placement ---
Target device: cuda:0 (type: <class 'torch.device'>)
Model parameter device: cuda:0 (type: <class 'torch.device'>)
PASS: Model and batch moved to cuda:0.


In [5]:

# --- Test 3: Forward Pass & Shapes (Eval Mode) ---
print("\n--- Test 3: Forward Pass & Shapes (Eval Mode) ---")
model.eval() # Set to evaluation mode (disables dropout)
try:
    with torch.no_grad(): # No need to track gradients here
        dom_logits, charge_pred, output_mask, seq_padding_mask = model.forward(dummy_batch_dev)

    # Check shapes
    E = config.model.embedding_dim
    num_dom_classes = config.model.embedding.dom_vocab_size - 2
    expected_dom_shape = (B, L_orig, num_dom_classes)
    expected_charge_shape = (B, 1)
    expected_mask_shape = (B, L_orig) # Output mask has original length
    expected_pad_mask_shape = (B, L_orig) # Seq padding mask has original length

    assert dom_logits.shape == expected_dom_shape, f"DOM Logits shape mismatch: Expected {expected_dom_shape}, Got {dom_logits.shape}"
    assert charge_pred.shape == expected_charge_shape, f"Charge Pred shape mismatch: Expected {expected_charge_shape}, Got {charge_pred.shape}"
    # Output mask might be None if masking is off in config/init, handle this
    if hasattr(model.embedding, 'masking') and model.embedding.masking:
         assert output_mask is not None, "Output mask should not be None when masking is enabled"
         assert output_mask.shape == expected_mask_shape, f"Output Mask shape mismatch: Expected {expected_mask_shape}, Got {output_mask.shape}"
    else:
         assert output_mask is None, "Output mask should be None when masking is disabled"
    assert seq_padding_mask.shape == expected_pad_mask_shape, f"Seq Padding Mask shape mismatch: Expected {expected_pad_mask_shape}, Got {seq_padding_mask.shape}"

    # Check device and dtype
    assert dom_logits.device == device, "Output device mismatch"
    # Dtype might depend on precision setting ('16-mixed' might output float16)
    print(f"Output dtypes: dom={dom_logits.dtype}, charge={charge_pred.dtype}")

    print("PASS: Forward pass successful, shapes and device are correct.")
    # Store outputs for next test
    outputs1 = (dom_logits, charge_pred, output_mask, seq_padding_mask)
except Exception as e:
    print(f"FAIL: Forward pass failed: {e}")
    raise e


--- Test 3: Forward Pass & Shapes (Eval Mode) ---
Output dtypes: dom=torch.float32, charge=torch.float32
PASS: Forward pass successful, shapes and device are correct.


In [6]:
# --- Test 4: Batch Consistency (Eval Mode) ---
print("\n--- Test 4: Batch Consistency (Eval Mode) ---")
try:
    print(f"Running first forward pass in eval mode (model.training={model.training})...")
    assert not model.training # Double check eval mode is set
    with torch.no_grad():
        outputs1 = model.forward(dummy_batch_dev)

    print(f"Running second forward pass in eval mode (model.training={model.training})...")
    assert not model.training # Double check eval mode is set
    with torch.no_grad():
        outputs2 = model.forward(dummy_batch_dev)

    # --- Detailed Comparison ---
    dom_logits1, charge_pred1, output_mask1, seq_padding_mask1 = outputs1
    dom_logits2, charge_pred2, output_mask2, seq_padding_mask2 = outputs2

    # 1. Compare Output Masks (Boolean)
    mask_identical = False
    if output_mask1 is not None and output_mask2 is not None:
        mask_identical = torch.equal(output_mask1, output_mask2)
        print(f"Output masks are identical: {mask_identical}")
        assert mask_identical, "Output masks differ between runs!"
    elif output_mask1 is None and output_mask2 is None:
        mask_identical = True # Both None is consistent
        print("Output masks are identical (both None)")
    else:
        print(f"FAIL: Output mask presence mismatch ({type(output_mask1)} vs {type(output_mask2)})")
        assert False, "Output mask presence mismatch"

    # 2. Compare Padding Masks (Boolean)
    pad_mask_identical = torch.equal(seq_padding_mask1, seq_padding_mask2)
    print(f"Seq padding masks are identical: {pad_mask_identical}")
    assert pad_mask_identical, "Seq padding masks differ between runs!"

    # 3. Compare Charge Predictions (Float)
    charge_diff = torch.abs(charge_pred1 - charge_pred2).max().item()
    charge_close = torch.allclose(charge_pred1, charge_pred2, atol=1e-6)
    print(f"Charge predictions are allclose (atol=1e-6): {charge_close} (Max Abs Diff: {charge_diff:.2e})")
    assert charge_close, "Charge predictions differ between runs"

    # 4. Compare DOM Logits (Float) - Fails Here
    dom_diff = torch.abs(dom_logits1 - dom_logits2).max().item()
    dom_close_strict = torch.allclose(dom_logits1, dom_logits2, atol=1e-6)
    dom_close_loose = torch.allclose(dom_logits1, dom_logits2, atol=1e-1) # The one that failed
    print(f"DOM logits are allclose (atol=1e-6): {dom_close_strict} (Max Abs Diff: {dom_diff:.2e})")
    print(f"DOM logits are allclose (atol=1e-1): {dom_close_loose}")

    # Raise the original error if the loose check failed
    assert dom_close_loose, "DOM logits differ between runs (even with atol=1e-1)"

    print("PASS: Model output is consistent for the same input in eval mode.") # Will only reach here if all asserts pass

except Exception as e:
    print(f"FAIL: Consistency check failed: {e}")
    # Don't re-raise immediately, let the prints show results
    # raise e # Optionally re-raise after printing debug info


--- Test 4: Batch Consistency (Eval Mode) ---
Running first forward pass in eval mode (model.training=False)...
Running second forward pass in eval mode (model.training=False)...
Output masks are identical: False
FAIL: Consistency check failed: Output masks differ between runs!


In [7]:

# --- Test 5: Loss Calculation (Train Mode) ---
print("\n--- Test 5: Loss Calculation (Train Mode) ---")
model.train() # Set to training mode
try:
    # Can call _calculate_loss directly or training_step
    # training_step requires a batch_idx
    combined_loss, dom_loss, charge_loss = model._calculate_loss(dummy_batch_dev)
    # loss_from_step = model.training_step(dummy_batch_dev, batch_idx=0)

    # Check if loss is a scalar tensor
    assert isinstance(combined_loss, torch.Tensor), "Loss is not a Tensor"
    assert combined_loss.ndim == 0, f"Loss should be scalar, but has shape {combined_loss.shape}"
    assert combined_loss.requires_grad, "Loss does not require gradients"
    # Check if loss is finite
    assert torch.isfinite(combined_loss), "Loss is not finite!"

    print(f"PASS: Loss calculated successfully: Combined={combined_loss.item():.4f}, DOM={dom_loss.item():.4f}, Charge={charge_loss.item():.4f}")
    loss_for_backward = combined_loss # Store for next test
except Exception as e:
    print(f"FAIL: Loss calculation failed: {e}")
    raise e


--- Test 5: Loss Calculation (Train Mode) ---
PASS: Loss calculated successfully: Combined=14.1082, DOM=8.6455, Charge=5.4627


In [8]:

# --- Test 6: Basic Gradient Flow (Train Mode) ---
print("\n--- Test 6: Basic Gradient Flow (Train Mode) ---")
model.train()
# Use a simple optimizer just for the test
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
try:
    optimizer.zero_grad()
    loss_for_backward.backward()
    optimizer.step() # Optional: check if step works

    # Check if a key parameter has a gradient
    # Example: first weight of the CLS embedding or a weight in the first block's attention
    checked_grad = False
    if model.embedding.cls_embedding.grad is not None:
         print(f"Gradient found for embedding.cls_embedding: mean abs = {model.embedding.cls_embedding.grad.abs().mean().item()}")
         checked_grad = True
    # Check grad of a weight in the first transformer block's attention input projection
    elif hasattr(model, 'transformer_blocks') and len(model.transformer_blocks) > 0 and hasattr(model.transformer_blocks[0], 'attn'):
         attn_layer = model.transformer_blocks[0].attn.mha # Get the underlying MHA
         if attn_layer.in_proj_weight.grad is not None:
              print(f"Gradient found for transformer_blocks[0].attn.mha.in_proj_weight: mean abs = {attn_layer.in_proj_weight.grad.abs().mean().item()}")
              checked_grad = True

    if not checked_grad:
         print("Warning: Could not easily check gradient for standard layers. Check manually.")
         # You might need to inspect other parameters like linear layer weights.
         # Check if *any* parameter has a non-None gradient
         found_any_grad = any(p.grad is not None for p in model.parameters() if p.requires_grad)
         assert found_any_grad, "No gradients found after backward pass!"
         print("OK: At least one parameter received gradients.")


    print("PASS: Backward pass executed and gradients seem to be present.")
    optimizer.zero_grad() # Clean up gradients
except Exception as e:
    print(f"FAIL: Backward pass or gradient check failed: {e}")
    raise e

print("\n--- Sanity Checks Complete ---")


--- Test 6: Basic Gradient Flow (Train Mode) ---
Gradient found for embedding.cls_embedding: mean abs = 0.14659950137138367
PASS: Backward pass executed and gradients seem to be present.

--- Sanity Checks Complete ---
