In [1]:
import torch
import os
import warnings


In [2]:

# --- Assuming your project structure allows these imports ---
try:
    # Import necessary classes from your project
    from polarbert.config import PolarBertConfig
    from polarbert.time_embed_polarbert import PolarBertModel # Pre-training model class
    from polarbert.te_finetuning import PolarBertFinetuner # Fine-tuning model class
except ImportError as e:
    print(f"Error importing project modules: {e}")
    print("Please ensure your notebook's environment can access the polarbert package.")
    # You might need to add the source directory to sys.path
    # import sys
    # sys.path.append('/path/to/your/PolarBERT/src') # Adjust path if needed
    # from polarbert.config import PolarBertConfig
    # ... etc.
    raise e


In [3]:

# --- Define Checkpoint Paths ---
# Locations of the two checkpoint files compared in this notebook. Later cells
# pass each path to PolarBertConfig.from_checkpoint and to
# load_state_dict_from_checkpoint.
# NOTE(review): these are absolute, machine-specific paths; adjust them before
# running the notebook in another environment.
pretrained_ckpt_path = "/groups/pheno/inar/PolarBERT/checkpoints/te_polarbert_100M_2ep_250414-011046/last.ckpt"  # weights loaded into PolarBertModel
finetuned_ckpt_path = "/groups/pheno/inar/PolarBERT/checkpoints/finetune_direction_kaggle_250428-114235/last.ckpt"  # weights loaded into PolarBertFinetuner


In [8]:

# --- Helper Function to Load State Dict ---
# (Handles potential 'state_dict' nesting in checkpoint file)
def load_state_dict_from_checkpoint(ckpt_path):
    """Read a checkpoint file on CPU and return the state dict it holds.

    Args:
        ckpt_path: Path to the checkpoint file.

    Returns:
        The value stored under the ``'state_dict'`` key when the loaded object
        contains that key; otherwise the loaded object itself, which is then
        taken to be the state dict.

    Raises:
        FileNotFoundError: If ``ckpt_path`` is not an existing file.
        Exception: Any failure while loading is printed and then re-raised
            unchanged.
    """
    # Guard clause: fail early with a clear message for a bad path
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    try:
        # weights_only=True is used for security
        loaded = torch.load(ckpt_path, map_location='cpu', weights_only=True)
        # Unwrap the nested layout when present, else use the object as-is
        return loaded['state_dict'] if 'state_dict' in loaded else loaded
    except Exception as e:
        print(f"Error loading state dictionary from {ckpt_path}: {e}")
        raise

# --- 1. Load and Inspect Pre-trained Model ---
# Rebuilds the pre-training architecture from its config, loads the checkpoint
# weights non-strictly and lists the model's top-level submodules.
print("--- Loading Pre-trained Model ---")
print(f"Checkpoint: {pretrained_ckpt_path}")

try:
    # Config lookup is delegated to PolarBertConfig.from_checkpoint
    # (assumes a config.yaml sits in the same directory as the checkpoint)
    pretrained_config = PolarBertConfig.from_checkpoint(pretrained_ckpt_path)
    print("Pre-trained config loaded.")

    # Fresh instance of the pre-training architecture; weights are loaded below
    pretrained_model = PolarBertModel(pretrained_config)
    print("Pre-trained model (PolarBertModel) instantiated.")

    pretrained_state_dict = load_state_dict_from_checkpoint(pretrained_ckpt_path)

    # strict=False: mismatching keys are returned for reporting instead of raising
    missing_keys, unexpected_keys = pretrained_model.load_state_dict(
        pretrained_state_dict,
        strict=False,
    )
    print("Pre-trained state dict loaded into model instance.")
    if missing_keys:
        print(f"  Missing keys: {missing_keys}")
    if unexpected_keys:
        # Ideally there are none
        print(f"  Unexpected keys: {unexpected_keys}")

    print("\nPre-trained Model Named Modules:")
    found_modules = []
    for name, module in pretrained_model.named_modules():
        # Keep only un-dotted (top-level) names and the explicitly listed ones
        if '.' in name and name not in ('dom_head', 'charge_head', 'embedding', 'final_norm'):
            continue

        module_info = f"  {name:<30} | {module.__class__.__name__}"
        if isinstance(module, torch.nn.Linear):
            # Parameter counts are reported for Linear layers only
            num_params = sum(param.numel() for param in module.parameters())
            module_info += f" | Params: {num_params:,}"
        print(module_info)
        found_modules.append(name)

    if len(found_modules) == 0:
        print("  (Could not list modules - check model definition)")


except Exception as e:
    # Best-effort cell: report the failure with a traceback and carry on
    print(f"!!! Error loading pre-trained model: {e}")
    import traceback
    traceback.print_exc()


--- Loading Pre-trained Model ---
Checkpoint: /groups/pheno/inar/PolarBERT/checkpoints/te_polarbert_100M_2ep_250414-011046/last.ckpt
Loading config associated with checkpoint last.ckpt
Loading configuration from: /groups/pheno/inar/PolarBERT/checkpoints/te_polarbert_100M_2ep_250414-011046/config.yaml
Validating configuration...
Pre-trained config loaded.
INFO: Concatenated embeddings directly match model embedding dim. No projection layer used.
Pre-trained model (PolarBertModel) instantiated.
Pre-trained state dict loaded into model instance.

Pre-trained Model Named Modules:
                                 | PolarBertModel
  embedding                      | IceCubeTimeEmbedding
  transformer_blocks             | ModuleList
  final_norm                     | RMSNorm
  dom_head                       | Linear | Params: 1,326,120
  charge_head                    | Linear | Params: 257


In [9]:

# --- 2. Load and Inspect Fine-tuned Model ---
# Mirrors the pre-trained inspection: build the fine-tuning architecture, load
# the complete fine-tuned state dict non-strictly and list selected submodules.
print("\n--- Loading Fine-tuned Model ---")
print(f"Checkpoint: {finetuned_ckpt_path}")

try:
    # Config lookup is delegated to PolarBertConfig.from_checkpoint
    # (assumes a config.yaml sits in the same directory as the checkpoint)
    finetuned_config = PolarBertConfig.from_checkpoint(finetuned_ckpt_path)
    print("Fine-tuned config loaded.")

    # pretrained_checkpoint_path=None keeps the constructor from loading a
    # backbone itself; the whole fine-tuned state dict is loaded right after
    finetuned_model = PolarBertFinetuner(finetuned_config, pretrained_checkpoint_path=None)
    print("Fine-tuned model (PolarBertFinetuner) instantiated.")

    finetuned_state_dict = load_state_dict_from_checkpoint(finetuned_ckpt_path)

    # strict=False: mismatching keys are returned for reporting instead of raising
    missing_keys_ft, unexpected_keys_ft = finetuned_model.load_state_dict(
        finetuned_state_dict,
        strict=False,
    )
    print("Fine-tuned state dict loaded into model instance.")
    if missing_keys_ft:
        # Ideally there are none
        print(f"  Missing keys: {missing_keys_ft}")
    if unexpected_keys_ft:
        # Ideally there are none
        print(f"  Unexpected keys: {unexpected_keys_ft}")

    print("\nFine-tuned Model Named Modules:")
    found_modules_ft = []
    for name, module in finetuned_model.named_modules():
        # Candidates: un-dotted names, 'backbone'/'prediction_head' themselves,
        # single-dot names containing 'backbone.', or anything containing
        # 'prediction_head.'. Parentheses spell out the and/or precedence.
        if not (
            '.' not in name
            or name in ('backbone', 'prediction_head')
            or ('backbone.' in name and name.count('.') == 1)
            or 'prediction_head.' in name
        ):
            continue

        module_info = f"  {name:<30} | {module.__class__.__name__}"
        if isinstance(module, torch.nn.Linear):
            num_params = sum(param.numel() for param in module.parameters())
            module_info += f" | Params: {num_params:,}"

        # Deeper backbone entries belonging to the transformer blocks or the
        # embedding are not printed
        if (
            name.startswith('backbone')
            and name.count('.') > 1
            and ('transformer_blocks.' in name or 'embedding.' in name)
        ):
            continue

        print(module_info)
        found_modules_ft.append(name)

    if len(found_modules_ft) == 0:
        print("  (Could not list modules - check model definition)")


except Exception as e:
    # Best-effort cell: report the failure with a traceback and carry on
    print(f"!!! Error loading fine-tuned model: {e}")
    import traceback
    traceback.print_exc()

# Reading guide for the two module listings
print("\n--- Inspection Complete ---")
print("Compare the named modules above.")
print("Key things to look for:")
print(" - Pre-trained model: Should have `embedding`, `transformer_blocks`, `final_norm`, `dom_head`, `charge_head`.")
print(" - Fine-tuned model: Should have `backbone` (containing embedding, blocks, norm) and `prediction_head` (likely an nn.Sequential). It should *not* have top-level `dom_head` or `charge_head`.")


--- Loading Fine-tuned Model ---
Checkpoint: /groups/pheno/inar/PolarBERT/checkpoints/finetune_direction_kaggle_250428-114235/last.ckpt
Loading config associated with checkpoint last.ckpt
Loading configuration from: /groups/pheno/inar/PolarBERT/checkpoints/finetune_direction_kaggle_250428-114235/config.yaml
Validating configuration...
Fine-tuned config loaded.
INFO: Concatenated embeddings directly match model embedding dim. No projection layer used.
No pretrained checkpoint provided or 'new' specified. Training backbone from scratch.
Initialized head for task: direction with hidden size: 1024
Fine-tuned model (PolarBertFinetuner) instantiated.
Fine-tuned state dict loaded into model instance.

Fine-tuned Model Named Modules:
                                 | PolarBertFinetuner
  backbone                       | PolarBertModel
  backbone.embedding             | IceCubeTimeEmbedding
  backbone.transformer_blocks    | ModuleList
  backbone.final_norm            | RMSNorm
  backbone.dom

Using device: cuda

Setting up validation dataloader...
Directory: /groups/pheno/inar/icecube_kaggle/memmaped_eval_1M_127
Event limit: 100000
Batch size: 1024
Dataset sliced to 100000 events.
Validation dataloader created.
Estimated number of validation batches: 96

Calculating DOM loss for: Pre-trained Model (DOM Head)


Eval DOM Loss (Pre-trained Model (DOM Head)):   0%|          | 0/96 [00:00<?, ?it/s]

  return torch.as_tensor(data)


Evaluation complete for Pre-trained Model (DOM Head). Batches with valid DOM loss: 96/96

Calculating DOM loss for: Fine-tuned Model (Backbone DOM Head)


Eval DOM Loss (Fine-tuned Model (Backbone DOM Head)):   0%|          | 0/96 [00:00<?, ?it/s]

  return torch.as_tensor(data)


Evaluation complete for Fine-tuned Model (Backbone DOM Head). Batches with valid DOM loss: 97/97

--- DOM Loss Comparison ---
Validation Dataset: /groups/pheno/inar/icecube_kaggle/memmaped_eval_1M_127 (100000 events)
Pre-trained Model DOM Loss : 2.084680
Fine-tuned Model DOM Loss  : 9.516705 (Using DOM Head within its backbone)
---------------------------

Note: DOM loss increased after fine-tuning, which might be expected if the fine-tuning task diverges significantly.


In [12]:
import torch

# --- Ensure models are loaded from previous cells ---
# Compares the DOM head of the pre-trained model with the one inside the
# fine-tuned model's backbone (exact equality of weights and biases).
if 'pretrained_model' in locals() and 'finetuned_model' in locals():
    print("--- Comparing DOM Head Parameters ---")

    # The fine-tuned model keeps its dom_head under the 'backbone' attribute
    dom_head_pretrained = pretrained_model.dom_head
    dom_head_finetuned = finetuned_model.backbone.dom_head

    # Both heads must be plain nn.Linear layers for the comparison below
    if not isinstance(dom_head_pretrained, torch.nn.Linear):
        print("ERROR: Pre-trained model does not have an nn.Linear dom_head.")
    elif not isinstance(dom_head_finetuned, torch.nn.Linear):
        print("ERROR: Fine-tuned model's backbone does not have an nn.Linear dom_head.")
    else:
        # Weights: exact element-wise equality
        weights_equal = torch.equal(dom_head_pretrained.weight.data, dom_head_finetuned.weight.data)
        print(f"DOM Head Weights are the same: {weights_equal}")

        # Biases: only comparable when both layers actually have one
        both_have_bias = dom_head_pretrained.bias is not None and dom_head_finetuned.bias is not None
        neither_has_bias = dom_head_pretrained.bias is None and dom_head_finetuned.bias is None
        if both_have_bias:
            biases_equal = torch.equal(dom_head_pretrained.bias.data, dom_head_finetuned.bias.data)
            print(f"DOM Head Biases are the same:  {biases_equal}")
        elif neither_has_bias:
            print("DOM Head Biases are the same:  True (both are None)")
        else:
            print("DOM Head Biases are the same:  False (one has bias, the other doesn't)")

        # Overall verdict: weights match and the biases are either both absent
        # or both present and equal
        if weights_equal and (neither_has_bias or (both_have_bias and biases_equal)):
            print("\nConclusion: The DOM head parameters are identical in both models.")
            print("The difference in DOM loss is likely due to differences in the backbone representations feeding into the head.")
        else:
            print("\nConclusion: The DOM head parameters are DIFFERENT between the two models.")
            print("This difference in head weights likely contributes significantly to the difference in DOM loss.")
            if not weights_equal:
                # Size of the discrepancy: mean absolute weight difference
                weight_diff = (dom_head_pretrained.weight.data - dom_head_finetuned.weight.data).abs().mean()
                print(f"  Mean absolute difference in weights: {weight_diff.item():.4e}")
else:
    print("Models not found in local scope. Please re-run the model loading cells first.")

--- Comparing DOM Head Parameters ---
DOM Head Weights are the same: False
DOM Head Biases are the same:  False

Conclusion: The DOM head parameters are DIFFERENT between the two models.
This difference in head weights likely contributes significantly to the difference in DOM loss.
  Mean absolute difference in weights: 6.6025e-02


In [13]:
import torch
import torch.nn as nn # Ensure nn is imported

# --- Ensure models are loaded from previous cells ---
# Rebinds the fine-tuned backbone's pre-training heads to the very same module
# objects held by the pre-trained model, then checks the result.
if 'pretrained_model' in locals() and 'finetuned_model' in locals():
    print("--- Replacing Pre-training Heads in Fine-tuned Model's Backbone ---")

    # --- 1. Verify Original Heads (Optional but Recommended) ---
    print("Original heads BEFORE replacement:")
    # Object identities of the two dom_head modules before the swap
    print(f"  Pretrained dom_head ID: {id(pretrained_model.dom_head)}")
    print(f"  Finetuned backbone dom_head ID: {id(finetuned_model.backbone.dom_head)}")
    # Whether the two weight tensors already coincide
    weights_originally_equal = torch.equal(
        pretrained_model.dom_head.weight.data,
        finetuned_model.backbone.dom_head.weight.data,
    )
    print(f"  DOM Head weights originally equal: {weights_originally_equal}")

    # --- 2. Perform the Replacement ---
    # After the assignment the backbone attribute refers to the same nn.Module
    # object as the pre-trained model's attribute (no copy is made).
    print("\nAssigning pre-trained heads to fine-tuned model's backbone...")
    try:
        # Each head is only reassigned when both models expose the attribute
        if not (hasattr(finetuned_model.backbone, 'dom_head') and hasattr(pretrained_model, 'dom_head')):
            print("  - Could not assign dom_head (missing in source or destination).")
        else:
            finetuned_model.backbone.dom_head = pretrained_model.dom_head
            print("  - Assigned pretrained dom_head.")

        if not (hasattr(finetuned_model.backbone, 'charge_head') and hasattr(pretrained_model, 'charge_head')):
            print("  - Could not assign charge_head (missing in source or destination).")
        else:
            finetuned_model.backbone.charge_head = pretrained_model.charge_head
            print("  - Assigned pretrained charge_head.")
        print("Assignment complete.")

    except Exception as e:
        print(f"ERROR during head replacement: {e}")

    # --- 3. Verify Replacement ---
    print("\nVerifying heads AFTER replacement:")
    # The two IDs printed here are expected to coincide for dom_head
    print(f"  Pretrained dom_head ID: {id(pretrained_model.dom_head)}")
    print(f"  Finetuned backbone dom_head ID (should match above): {id(finetuned_model.backbone.dom_head)}")

    # Parameter-level check of the DOM head
    try:
        weights_now_equal = torch.equal(
            pretrained_model.dom_head.weight.data,
            finetuned_model.backbone.dom_head.weight.data,
        )
        biases_now_equal = torch.equal(
            pretrained_model.dom_head.bias.data,
            finetuned_model.backbone.dom_head.bias.data,
        )
        print(f"  DOM Head weights now equal: {weights_now_equal}")
        print(f"  DOM Head biases now equal:  {biases_now_equal}")
        if not (weights_now_equal and biases_now_equal):
            print("  Verification FAILED: Parameters are still different after assignment?!")
        else:
            print("  Verification successful: Fine-tuned backbone now uses identical pre-trained DOM head object.")
    except Exception as e:
        print(f"Error during verification: {e}")


    # --- 4. Re-evaluate DOM Loss (Optional) ---
    # The DOM loss calculation can now be re-run on the modified finetuned_model
    # to see whether its loss moves closer to the pre-trained model's loss.
    # The dataloader may need to be re-instantiated first.
    print("\n(Optional) Re-evaluating DOM loss on modified fine-tuned model...")
    # Example: Assuming calculate_dom_loss function and val_loader exist
    # val_loader_re = DataLoader(val_dataset, ...) # Recreate dataloader if needed
    # avg_dom_loss_finetuned_modified = calculate_dom_loss(finetuned_model, val_loader_re, device)
    # print(f"Pre-trained Model DOM Loss        : {avg_dom_loss_pretrained:.6f}") # From previous run
    # print(f"Modified Fine-tuned Model DOM Loss: {avg_dom_loss_finetuned_modified:.6f}")
else:
    print("Models not found in local scope. Please re-run the model loading cells first.")

--- Replacing Pre-training Heads in Fine-tuned Model's Backbone ---
Original heads BEFORE replacement:
  Pretrained dom_head ID: 139796762497808
  Finetuned backbone dom_head ID: 139796762734800
  DOM Head weights originally equal: False

Assigning pre-trained heads to fine-tuned model's backbone...
  - Assigned pretrained dom_head.
  - Assigned pretrained charge_head.
Assignment complete.

Verifying heads AFTER replacement:
  Pretrained dom_head ID: 139796762497808
  Finetuned backbone dom_head ID (should match above): 139796762497808
  DOM Head weights now equal: True
  DOM Head biases now equal:  True
  Verification successful: Fine-tuned backbone now uses identical pre-trained DOM head object.

(Optional) Re-evaluating DOM loss on modified fine-tuned model...


In [15]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm # Use notebook version of tqdm
import numpy as np
import warnings
import os

# --- Make sure project imports are available ---
try:
    # Need config, base model, finetuner model structure (to check instance type)
    from polarbert.config import PolarBertConfig
    from polarbert.time_embed_polarbert import PolarBertModel
    from polarbert.te_finetuning import PolarBertFinetuner # Only needed for isinstance check

    # Need dataset and default transforms used for pre-training validation
    from polarbert.icecube_dataset import IceCubeDataset # Assuming Kaggle type uses this
    from polarbert.te_pretraining import default_transform, default_target_transform

except ImportError as e:
    print(f"Error importing dataset/transform modules: {e}")
    print("Please ensure they are accessible.")
    raise e

# --- Ensure Models Exist from Previous Steps ---
# Hard requirement: both model objects must already exist in this namespace.
if 'pretrained_model' not in locals():
    print("ERROR: 'pretrained_model' not found. Please load the pre-trained model first.")
    raise NameError("pretrained_model not defined")

if 'finetuned_model' not in locals():
    print("ERROR: 'finetuned_model' not found. Please load the fine-tuned model first.")
    raise NameError("finetuned_model not defined")

# Soft check: warn unless the backbone's dom_head is the very same object as the
# pre-trained model's dom_head (i.e. the head replacement has been performed).
if not (
    hasattr(finetuned_model.backbone, 'dom_head')
    and finetuned_model.backbone.dom_head is pretrained_model.dom_head
):
    print("WARNING: It looks like the dom_head in the fine-tuned model hasn't been replaced")
    print("         with the pre-trained one yet. Run the replacement cell first for a fair comparison.")
    # Optionally, perform the replacement here again if desired:
    # print("Performing head replacement now...")
    # finetuned_model.backbone.dom_head = pretrained_model.dom_head
    # finetuned_model.backbone.charge_head = pretrained_model.charge_head

# --- 1. Evaluation Setup ---
# Validation data used for the pre-training loss benchmark
val_data_dir = '/groups/pheno/inar/icecube_kaggle/memmaped_eval_1M_127'
dataset_type = 'kaggle'
eval_batch_size = 1024
val_events_limit = 100_000  # upper bound on the number of validation events

# Runtime settings: single loader worker, GPU when one is available
num_workers = 1
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# lambda_charge is taken from the pre-trained config (ideally identical in both
# configs); the config is re-read from the checkpoint if an earlier cell did
# not leave it in the namespace.
if 'pretrained_config' not in locals():
    print("Loading pre-trained config to get lambda_charge...")
    pretrained_config = PolarBertConfig.from_checkpoint(pretrained_ckpt_path)  # requires the path variable
lambda_charge = pretrained_config.model.lambda_charge
print(f"Using lambda_charge: {lambda_charge}")


# --- 2. Create Pre-training Validation Dataloader ---
print("\nSetting up PRE-TRAINING validation dataloader...")
print(f"Directory: {val_data_dir}")
print(f"Event limit: {val_events_limit}")
print(f"Batch size: {eval_batch_size}")

try:
    # Full dataset with the default pre-training transforms
    # (the default target transform returns numpy y, c)
    pretrain_val_dataset_full = IceCubeDataset(
        data_dir=val_data_dir,
        batch_size=eval_batch_size,
        transform=default_transform,
        target_transform=default_target_transform,
    )
    # Restrict to the first val_events_limit events
    pretrain_val_dataset = pretrain_val_dataset_full.slice(0, val_events_limit)
    # batch_size=None: the dataset is constructed with its own batch_size, so
    # the loader performs no additional batching
    pretrain_val_loader = DataLoader(
        pretrain_val_dataset,
        batch_size=None,
        num_workers=num_workers,
        pin_memory=False,
        persistent_workers=(num_workers > 0),
    )
    num_batches = len(pretrain_val_dataset)
    print(f"Pre-training validation dataloader created ({num_batches} batches).")
except Exception as e:
    # Nothing downstream can run without the loader: report, then propagate
    print(f"Error creating pre-training validation dataloader: {e}")
    raise e

# --- 3. Define Pre-training Loss Evaluation Function ---

def calculate_pretrain_losses(model: torch.nn.Module, dataloader: DataLoader, device: torch.device, lambda_charge: float) -> dict:
    """Average the pre-training losses (DOM, charge, combined) over a dataloader.

    The embedding, transformer blocks, final norm and the two pre-training heads
    are called directly (a manual forward pass), so the same routine handles a
    bare ``PolarBertModel`` and a ``PolarBertFinetuner`` whose components live
    under ``model.backbone``.

    Args:
        model: A ``PolarBertModel`` or ``PolarBertFinetuner``. It is put in eval
            mode and moved to ``device`` in place.
        dataloader: Iterable yielding ``((x, l), y_data)`` batches. ``None``
            batches and batches whose ``y_data`` is ``None`` are skipped;
            ``y_data[1]`` is used as the total-charge target (may be ``None``).
        device: Device on which the forward pass and losses are computed.
        lambda_charge: Weight of the charge loss inside the combined loss.

    Returns:
        dict with keys ``"dom"``, ``"charge"`` and ``"combined"``. Each value is
        the unweighted mean of the per-batch loss over the batches for which
        that loss was available, or ``nan`` when there was no such batch. The
        combined loss (``dom + lambda_charge * charge``) is only accumulated for
        batches where BOTH components are valid, so a missing component is
        never silently counted as zero.

    Raises:
        TypeError: If ``model`` is neither a ``PolarBertFinetuner`` nor a
            ``PolarBertModel``.
    """
    model.eval()
    model.to(device)

    # Locate the module that owns the shared components
    if isinstance(model, PolarBertFinetuner):
        # Heads inside the backbone (possibly swapped in by an earlier cell)
        core = model.backbone
        model_description = "Fine-tuned Model (Swapped Heads)"
    elif isinstance(model, PolarBertModel):
        core = model
        model_description = "Pre-trained Model"
    else:
        raise TypeError("Model type not recognized (expected PolarBertModel or PolarBertFinetuner)")

    embedding_layer = core.embedding
    transformer_blocks = core.transformer_blocks
    final_norm = core.final_norm
    dom_head = core.dom_head
    charge_head = core.charge_head

    # Progress-bar length is derived from the loader that was actually passed
    # in rather than from a notebook-level global; loaders without a length
    # simply get an open-ended bar.
    try:
        total_batches = len(dataloader)
    except TypeError:
        total_batches = None

    total_dom_loss = 0.0
    total_charge_loss = 0.0
    total_combined_loss = 0.0
    batches_with_dom_loss = 0
    batches_with_charge_loss = 0
    processed_batches_combined = 0
    warned_missing_mask = False  # ensures the missing-mask warning fires once

    print(f"\nCalculating Pre-training losses for: {model_description}")

    with torch.no_grad():
        pbar = tqdm(dataloader, total=total_batches, desc=f"Eval Pretrain Loss ({model_description})")
        for batch_idx, batch in enumerate(pbar):
            if batch is None:
                continue
            (x, l), y_data = batch
            if y_data is None:
                continue  # no labels for this batch

            # y_data is indexed as a pair; only the second entry (charge) is used
            true_total_charge_numpy = y_data[1]

            x = x.to(device)
            true_dom_ids = x[:, :, 3].long()  # assumes feature index 3 is sensor_id

            if true_total_charge_numpy is not None:
                true_total_charge = torch.as_tensor(true_total_charge_numpy, device=device)
            else:
                true_total_charge = None

            try:
                # --- Manual forward pass ---
                hidden_states, final_padding_mask, output_mask = embedding_layer((x, l))
                for block in transformer_blocks:
                    hidden_states = block(hidden_states, key_padding_mask=final_padding_mask)
                hidden_states = final_norm(hidden_states)
                cls_embed = hidden_states[:, 0, :]        # first position (CLS)
                sequence_embeds = hidden_states[:, 1:, :]  # everything after CLS

                # --- Predictions ---
                dom_logits = dom_head(sequence_embeds)
                charge_pred = charge_head(cls_embed)

                # --- DOM loss: cross-entropy on the positions selected by output_mask ---
                dom_value = float('nan')
                if output_mask is not None and output_mask.sum() > 0:
                    # IDs are shifted down by one; a resulting -1 is ignored by the loss
                    dom_targets = true_dom_ids - 1
                    dom_value = F.cross_entropy(
                        dom_logits[output_mask],
                        dom_targets[output_mask].long(),
                        ignore_index=-1,
                    ).item()
                    if not np.isnan(dom_value):
                        total_dom_loss += dom_value
                        batches_with_dom_loss += 1
                elif not warned_missing_mask:
                    warnings.warn("output_mask is None or empty. Cannot calculate DOM loss.", RuntimeWarning)
                    warned_missing_mask = True

                # --- Charge loss: MSE against log10 of the clamped total charge ---
                charge_value = float('nan')
                if true_total_charge is not None:
                    true_log_charge = torch.log10(torch.clamp(true_total_charge.float(), min=1e-6))
                    charge_value = F.mse_loss(charge_pred.squeeze(-1), true_log_charge).item()
                    if not np.isnan(charge_value):
                        total_charge_loss += charge_value
                        batches_with_charge_loss += 1

                # --- Combined loss: only when both components are valid ---
                if not (np.isnan(dom_value) or np.isnan(charge_value)):
                    total_combined_loss += dom_value + lambda_charge * charge_value
                    processed_batches_combined += 1

            except Exception as e:
                # Best effort: report the failing batch by its real index and continue
                print(f"\nError during forward/loss calculation in batch {batch_idx}: {e}")
                warnings.warn(f"Skipping batch {batch_idx} due to error.", RuntimeWarning)

            # Running averages on the progress bar
            if batches_with_dom_loss > 0 or batches_with_charge_loss > 0:
                pbar.set_postfix({
                    "Avg DOM": f"{total_dom_loss / batches_with_dom_loss:.4f}" if batches_with_dom_loss > 0 else "NaN",
                    "Avg Chrg": f"{total_charge_loss / batches_with_charge_loss:.4f}" if batches_with_charge_loss > 0 else "NaN"
                })

    # --- Averages (nan when no batch contributed) ---
    avg_dom = (total_dom_loss / batches_with_dom_loss) if batches_with_dom_loss > 0 else float('nan')
    avg_charge = (total_charge_loss / batches_with_charge_loss) if batches_with_charge_loss > 0 else float('nan')
    avg_combined = (total_combined_loss / processed_batches_combined) if processed_batches_combined > 0 else float('nan')

    print(f"Evaluation complete for {model_description}. Batches processed for combined loss: {processed_batches_combined}")

    return {"dom": avg_dom, "charge": avg_charge, "combined": avg_combined}

# --- 4. Run Evaluation ---
# Pre-trained model first, on the loader built above
results_pretrained = calculate_pretrain_losses(
    pretrained_model,
    pretrain_val_loader,
    device,
    lambda_charge,
)

# A second loader over the same dataset is created for the second model
pretrain_val_loader_2 = DataLoader(
    pretrain_val_dataset,
    batch_size=None,
    num_workers=num_workers,
    pin_memory=False,
    persistent_workers=(num_workers > 0),
)
# finetuned_model is expected to be the instance with the swapped heads
results_finetuned_swapped = calculate_pretrain_losses(
    finetuned_model,
    pretrain_val_loader_2,
    device,
    lambda_charge,
)

# --- 5. Print Comparison ---
# Side-by-side table of the three averaged losses
print("\n--- Pre-training Loss Benchmark ---")
print(f"Validation Dataset: {val_data_dir} ({val_events_limit} events)")
print("Metric          | Pre-trained Model | Fine-tuned (Swapped Heads)")
print("----------------|-------------------|---------------------------")
print(f"DOM Loss        | {results_pretrained['dom']:<17.6f} | {results_finetuned_swapped['dom']:<25.6f}")
print(f"Charge Loss     | {results_pretrained['charge']:<17.6f} | {results_finetuned_swapped['charge']:<25.6f}")
print(f"Combined Loss* | {results_pretrained['combined']:<17.6f} | {results_finetuned_swapped['combined']:<25.6f}")
print(f"*Combined = DOM + {lambda_charge:.2f} * Charge")
print("----------------------------------------------------------------")

Using device: cuda
Using lambda_charge: 1.0

Setting up PRE-TRAINING validation dataloader...
Directory: /groups/pheno/inar/icecube_kaggle/memmaped_eval_1M_127
Event limit: 100000
Batch size: 1024
Pre-training validation dataloader created (96 batches).

Calculating Pre-training losses for: Pre-trained Model


Eval Pretrain Loss (Pre-trained Model):   0%|          | 0/96 [00:00<?, ?it/s]

  return torch.as_tensor(data)


Evaluation complete for Pre-trained Model. Batches processed for combined loss: 96

Calculating Pre-training losses for: Fine-tuned Model (Swapped Heads)


Eval Pretrain Loss (Fine-tuned Model (Swapped Heads)):   0%|          | 0/96 [00:00<?, ?it/s]

  return torch.as_tensor(data)


Evaluation complete for Fine-tuned Model (Swapped Heads). Batches processed for combined loss: 97

--- Pre-training Loss Benchmark ---
Validation Dataset: /groups/pheno/inar/icecube_kaggle/memmaped_eval_1M_127 (100000 events)
Metric          | Pre-trained Model | Fine-tuned (Swapped Heads)
----------------|-------------------|---------------------------
DOM Loss        | 2.084382          | 4.000487                 
Charge Loss     | 0.000578          | 0.585109                 
Combined Loss* | 2.084960          | 4.585596                 
*Combined = DOM + 1.00 * Charge
----------------------------------------------------------------
