In [7]:
import torch
import torch.nn as nn

In [11]:
def custom_collate_fn(batch):
    """
    Custom collate function for DataLoader to handle variable-sized subjets and particles.

    Each item of ``batch`` is unpacked as a 4-tuple
    ``(features, subjets, subjet_mask, particle_mask)``. The ``features``
    tensors are stacked directly, so they must all share one shape. Every
    ``subjets`` tensor is zero-padded at the end of its first three
    dimensions (named subjets, features and length by the local variables)
    up to the per-batch maximum. The subjet mask is zero-padded along its
    last dimension by the number of added subjets, and the particle mask
    along its last two dimensions by the added subjets and added length.

    Progress and shape information is printed to stdout for debugging.

    Args:
        batch: Sequence of 4-tuples of tensors as described above.

    Returns:
        Tuple ``(features, subjets, subjet_masks, particle_masks)`` of
        stacked tensors with a new leading batch dimension.
    """
    # Imported locally so the function does not rely on a module-level ``F``
    # alias, which the visible imports (``torch``, ``torch.nn as nn``) do not
    # provide.
    import torch.nn.functional as F

    print("\n--- Starting custom_collate_fn ---")

    features, subjets, subjet_masks, particle_masks = zip(*batch)

    features = torch.stack(features)
    print(f"Features shape after stacking: {features.shape}")

    max_subjets = max(s.size(0) for s in subjets)
    max_subjet_features = max(s.size(1) for s in subjets)
    max_subjet_length = max(s.size(2) for s in subjets)

    print(f"Max subjets: {max_subjets}")
    print(f"Max subjet features: {max_subjet_features}")
    print(f"Max subjet length: {max_subjet_length}")

    padded_subjets = []
    padded_subjet_masks = []
    padded_particle_masks = []

    for s, sm, pm in zip(subjets, subjet_masks, particle_masks):
        pad_subjets = max_subjets - s.size(0)
        pad_features = max_subjet_features - s.size(1)
        pad_length = max_subjet_length - s.size(2)

        print(f"Padding required - Subjets: {pad_subjets}, Features: {pad_features}, Length: {pad_length}")

        # F.pad takes (left, right) pairs starting from the LAST dimension,
        # hence the order length -> features -> subjets.
        padded_s = F.pad(s, (0, pad_length, 0, pad_features, 0, pad_subjets), "constant", 0)
        padded_subjets.append(padded_s)

        padded_sm = F.pad(sm, (0, pad_subjets), "constant", 0)
        padded_subjet_masks.append(padded_sm)

        padded_pm = F.pad(pm, (0, pad_length, 0, pad_subjets), "constant", 0)
        padded_particle_masks.append(padded_pm)

        print(f"Padded subjets shape: {padded_s.shape}")
        print(f"Padded subjet mask shape: {padded_sm.shape}")
        print(f"Padded particle mask shape: {padded_pm.shape}")

    subjets = torch.stack(padded_subjets)
    subjet_masks = torch.stack(padded_subjet_masks)
    particle_masks = torch.stack(padded_particle_masks)

    print(f"\nFinal stacked subjets shape: {subjets.shape}")
    print(f"Final stacked subjet masks shape: {subjet_masks.shape}")
    print(f"Final stacked particle masks shape: {particle_masks.shape}")

    print("--- End of custom_collate_fn ---\n")

    return features, subjets, subjet_masks, particle_masks

In [8]:
def create_random_masks(batch_size, num_subjets, num_features, subjet_length, context_scale=0.7):
    """Build complementary random context/target masks over subjets.

    For every sample a fresh random permutation of the subjet indices is
    drawn. The first ``int(num_subjets * context_scale)`` indices become the
    context subjets and the remainder become the target subjets. A selected
    subjet is marked with ones across all of its features and positions.

    Args:
        batch_size: Number of mask pairs to generate.
        num_subjets: Size of the first mask dimension.
        num_features: Size of the second mask dimension.
        subjet_length: Size of the third mask dimension.
        context_scale: Fraction of subjets assigned to the context mask.

    Returns:
        Tuple ``(context_masks, target_masks)``, each of shape
        ``(batch_size, num_subjets, num_features, subjet_length)``.
    """
    print(f"Creating random masks with batch_size={batch_size}, num_subjets={num_subjets}")

    mask_shape = (num_subjets, num_features, subjet_length)
    # The split point does not depend on the sample, so compute it once.
    split_at = int(num_subjets * context_scale)

    per_sample_context, per_sample_target = [], []
    for _ in range(batch_size):
        shuffled = torch.randperm(num_subjets)

        ctx = torch.zeros(mask_shape)
        ctx[shuffled[:split_at]] = 1
        per_sample_context.append(ctx)

        tgt = torch.zeros(mask_shape)
        tgt[shuffled[split_at:]] = 1
        per_sample_target.append(tgt)

    stacked_context = torch.stack(per_sample_context)
    stacked_target = torch.stack(per_sample_target)
    return stacked_context, stacked_target

def normalize_features(features, feature_names, config, jet_type="Jets"):
    """Return a normalized copy of ``features``, one column per feature name.

    The method for each feature is read from
    ``config["INPUTS"]["SEQUENTIAL"][jet_type]`` and defaults to ``"none"``
    when the feature is not listed:

    * ``"normalize"``     -- subtract the column mean and divide by the column
      standard deviation (a deviation not greater than 1e-8 is replaced by 1.0).
    * ``"log_normalize"`` -- apply ``torch.log1p`` to the column.
    * anything else       -- the column is left unchanged.

    Column ``i`` is ``features[:, i]``, where ``i`` is the position of the
    name in ``feature_names``. The input tensor itself is not modified.

    Args:
        features: Tensor indexed as ``features[:, i]``.
        feature_names: Iterable of feature names, in column order.
        config: Nested mapping holding the per-feature method names.
        jet_type: Key selecting the method table inside ``config``.

    Returns:
        A new tensor of the same shape as ``features``.
    """
    print(f"Normalizing features with shape: {features.shape}")
    result = features.clone()

    for col, name in enumerate(feature_names):
        method = config["INPUTS"]["SEQUENTIAL"][jet_type].get(name, "none")
        print(f"Normalizing feature '{name}' with method '{method}'")

        if method == "log_normalize":
            result[:, col] = torch.log1p(features[:, col])
            continue
        if method != "normalize":
            # Unknown or "none" method: keep the cloned values as they are.
            continue

        values = features[:, col]
        center = values.mean()
        spread = values.std()
        # Guard against division by a (near-)zero or NaN deviation.
        if not spread > 1e-8:
            spread = 1.0
        result[:, col] = (values - center) / spread

    print(f"Normalized features shape: {result.shape}")
    return result



In [9]:
def print_jet_details(jet, name):
    """Print a short debugging summary of ``jet`` under the label ``name``.

    The summary contains the tensor shape, the number of non-zero entries,
    and the slice ``jet[0, :5, :5]``.

    Args:
        jet: Tensor supporting the index ``[0, :5, :5]``.
        name: Label used in the heading line.
    """
    heading = f"\n{name} Jet Details:"
    print(heading)
    print(f"Shape: {jet.shape}")
    print(f"Non-zero elements: {torch.count_nonzero(jet)}")

    print("\nFirst few elements:")
    leading_block = jet[0, :5, :5]
    print(leading_block)

def check_raw_data(subjets_data, jet_index=0):
    """Print a debugging summary of the raw subjets stored for one jet.

    Reports how many subjets ``subjets_data[jet_index]`` holds, then the
    feature names and feature values of its first subjet (read from that
    subjet's ``'features'`` mapping).

    Args:
        subjets_data: Indexable collection of per-jet subjet sequences.
        jet_index: Which jet to inspect.
    """
    print(f"\n--- Checking Raw Data for Jet {jet_index} ---")

    subjets_of_jet = subjets_data[jet_index]
    print(f"Number of subjets: {len(subjets_of_jet)}")

    # Only the first subjet is used as a representative sample.
    first_features = subjets_of_jet[0]['features']
    feature_names = list(first_features.keys())
    print(f"Subjet features: {feature_names}")
    print(f"Sample subjet feature values: {first_features}")

def check_processed_data(processed_subjets, batch_index=0):
    """Print a debugging summary of one processed subjets array.

    For a three-dimensional input, the three sizes are reported as number of
    subjets, number of features and subjet length, followed by the first
    (up to five) values of every feature of the first subjet. Any other
    number of dimensions only produces an "unexpected shape" message.

    Args:
        processed_subjets: Array-like exposing ``.shape`` and indexing.
        batch_index: Only used to label the heading line.
    """
    print(f"\n--- Checking Processed Data for Batch Item {batch_index} ---")
    print(f"Processed shape: {processed_subjets.shape}")

    dims = processed_subjets.shape
    if len(dims) != 3:
        print("Unexpected shape for processed subjets")
        return

    subjet_count, feature_count, length = dims
    print(f"Number of subjets: {subjet_count}")
    print(f"Number of features: {feature_count}")
    print(f"Subjet length: {length}")

    print("\nFirst few values of each feature:")
    for feature_idx in range(feature_count):
        preview = processed_subjets[0, feature_idx, :5]
        print(f"Feature {feature_idx}: {preview}")

def inspect_subjets(subjets, num_samples=5):
    """Print pT, eta, phi and particle count for the first few subjets.

    At most ``num_samples`` entries are shown (fewer when ``subjets`` is
    shorter). Each entry must provide a ``'features'`` mapping with the keys
    ``'pT'``, ``'eta'``, ``'phi'`` and ``'num_ptcls'``; the first three are
    printed with two decimals.

    Args:
        subjets: Indexable sequence of subjet records.
        num_samples: Upper bound on the number of records printed.
    """
    print("\n--- Inspecting Subjets ---")

    shown = min(num_samples, len(subjets))
    for idx in range(shown):
        record = subjets[idx]
        print(f"\nSubjet {idx}:")

        feats = record['features']
        print(f"  pT: {feats['pT']:.2f}")
        print(f"  eta: {feats['eta']:.2f}")
        print(f"  phi: {feats['phi']:.2f}")
        print(f"  num_ptcls: {feats['num_ptcls']}")

class DimensionCheckLayer(torch.nn.Module):
    """Pass-through module that warns when its input has an unexpected rank.

    The input is always returned unchanged; a warning is printed to stdout
    when the number of dimensions differs from ``expected_dims``.
    """

    def __init__(self, name, expected_dims):
        """Store the label used in warnings and the expected dimension count."""
        super().__init__()
        self.expected_dims = expected_dims
        self.name = name

    def forward(self, x):
        """Return ``x`` as-is, printing a warning on a dimension mismatch."""
        actual_dims = len(x.shape)
        if actual_dims == self.expected_dims:
            return x

        print(f"WARNING: {self.name} has {actual_dims} dimensions, expected {self.expected_dims}")
        return x

In [10]:
def visualize_training_loss(train_losses):
    """Plot the given training losses as a single labelled line and show it.

    The values are plotted against their position in ``train_losses``, with
    the x-axis labelled 'Epoch' and the y-axis labelled 'Loss'.

    Relies on a ``plt`` name being available in the enclosing namespace
    (presumably ``matplotlib.pyplot`` -- it is not imported in the visible
    part of this file; confirm it is imported elsewhere).

    Args:
        train_losses: Sequence of loss values accepted by ``plt.plot``.
    """
    print("Visualizing training loss")

    figure_size = (10, 5)
    plt.figure(figsize=figure_size)
    plt.plot(train_losses, label='Training Loss')

    # Axis annotations are independent of each other; set them together.
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss over Epochs')

    plt.legend()
    plt.show()