In [None]:
# Batch size (32)
# Number of subjets (20)
# Number of features (8)
# Subjet length (30)

In [1]:
# Standard library imports
import yaml
import json
from tqdm import tqdm

# Third-party imports
import numpy as np
import matplotlib.pyplot as plt

# PyTorch imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# HDF5 handling
import h5py

print("All necessary modules imported successfully.")

All necessary modules imported successfully.


In [2]:
class JetDataset(Dataset):
    """
    Dataset of jets and their subjets loaded from an HDF5 file.

    Each item is a tuple ``(features, subjets, subjet_mask, particle_mask)``.
    Only jets with at least ``min_real_subjets`` non-empty subjets are kept.
    """

    # Class-level defaults so the helper methods also work on an instance that
    # was not built through __init__.
    min_real_subjets = 10
    verbose = False

    def __init__(self, file_path, subset_size=None, transform=None,
                 min_real_subjets=10, verbose=False):
        """
        Initializes the JetDataset with data from an HDF5 file.

        Parameters:
        - file_path (str): Path to the HDF5 file containing the dataset.
        - subset_size (int, optional): Number of samples to keep (applied after filtering).
        - transform (callable, optional): Transform to apply to the features.
        - min_real_subjets (int): Minimum number of non-empty subjets a jet needs to be kept.
        - verbose (bool): If True, print per-item / per-jet debugging information.
        """
        print(f"Initializing JetDataset with file: {file_path}")

        self.transform = transform
        self.min_real_subjets = min_real_subjets
        self.verbose = verbose

        # Load data from HDF5 file; each entry of "subjets" is a JSON document.
        with h5py.File(file_path, 'r') as hdf:
            self.features = torch.tensor(hdf["particles/features"][:], dtype=torch.float32)
            self.subjets = [json.loads(subjet) for subjet in hdf["subjets"][:]]

        print(f"Raw dataset size: {len(self.subjets)} jets")
        print(f"Feature shape: {self.features.shape}")

        # Drop jets that do not have enough real subjets
        self.filter_good_jets()

        if subset_size is not None:
            print(f"Applying subset size: {subset_size}")
            self.features = self.features[:subset_size]
            self.subjets = self.subjets[:subset_size]

        print(f"Final dataset size: {len(self.subjets)} jets")

    def filter_good_jets(self):
        """
        Keeps only the jets with at least ``self.min_real_subjets`` real subjets.

        ``self.subjets`` and ``self.features`` are replaced by their filtered
        versions. When no jet qualifies both become empty instead of raising.
        """
        print("Filtering good jets...")
        keep = [
            i for i, jet in enumerate(self.subjets)
            if self.get_num_real_subjets(jet) >= self.min_real_subjets
        ]

        self.subjets = [self.subjets[i] for i in keep]
        # Index with a LongTensor rather than stacking a Python list of rows:
        # torch.stack([]) raises, whereas an empty index yields a (0, ...) tensor.
        self.features = self.features[torch.tensor(keep, dtype=torch.long)]
        print(f"Filtered to {len(self.subjets)} good jets")

    @staticmethod
    def get_num_real_subjets(jet):
        """
        Counts and returns the number of real subjets in a given jet.

        Parameters:
        - jet (list): A list of subjets where each subjet is a dictionary.

        Returns:
        - int: Number of real subjets (those with num_ptcls > 0).
        """
        return sum(1 for subjet in jet if subjet['features']['num_ptcls'] > 0)

    def __len__(self):
        """
        Returns the total number of jets in the dataset.

        Returns:
        - int: Number of jets.
        """
        return len(self.subjets)

    def __getitem__(self, idx):
        """
        Retrieves the features and subjets for a given index and processes them.

        Relies on the module-level ``normalize_features`` function and ``config``
        dictionary being defined before the first item is requested.

        Parameters:
        - idx (int): Index of the jet to retrieve.

        Returns:
        - tuple: (features, subjets, subjet_mask, particle_mask)
          where `features` is the normalized feature tensor,
          `subjets` is the processed subjets tensor,
          `subjet_mask` is the mask for subjets,
          and `particle_mask` is the mask for particles.
        """
        features = self.features[idx]
        subjets, subjet_mask, particle_mask = self.process_subjets(self.subjets[idx])

        feature_names = ['pT', 'eta', 'phi']
        features = normalize_features(features, feature_names, config, jet_type='Jets')

        if self.transform:
            features = self.transform(features)

        if self.verbose and idx % 1000 == 0:
            print(f"Getting item {idx}")
            print(f"Features shape: {features.shape}")
            print(f"Subjets shape: {subjets.shape}")
            print(f"Subjet mask shape: {subjet_mask.shape}")
            print(f"Particle mask shape: {particle_mask.shape}")
            print(f"Number of non-empty subjets: {subjet_mask.sum().item()}")
            print(f"Number of non-empty particles: {particle_mask.sum().item()}")

        return features, subjets, subjet_mask, particle_mask

    def process_subjets(self, subjets):
        """
        Processes subjets to create tensor representations and masks.

        Parameters:
        - subjets (list): Non-empty list of subjets. Each subjet is either a dictionary
          with 'features' and 'indices', or a flat list
          [pT, eta, phi, num_ptcls, *indices]. A tensor is converted with ``tolist()``.

        Returns:
        - tuple: (subjets, subjet_mask, particle_mask)
          `subjets` has shape (num_subjets, 8, max_len): rows 0-3 hold pT, eta, phi and
          num_ptcls repeated along the last axis, rows 4-7 hold the particle indices
          (padded with -1) repeated four times.
          `subjet_mask` has shape (num_subjets,) and is 0 where num_ptcls == 0.
          `particle_mask` has shape (num_subjets, max_len) and is 1 for every position
          covered by the subjet's own index list.
        """
        if isinstance(subjets, torch.Tensor):
            subjets = subjets.tolist()

        if isinstance(subjets[0], list):
            subjets = [{'features': {'pT': s[0], 'eta': s[1], 'phi': s[2], 'num_ptcls': s[3]}, 'indices': s[4:]} for s in subjets]

        feature_keys = ('pT', 'eta', 'phi', 'num_ptcls')
        max_len = max(len(subjet['indices']) for subjet in subjets)
        subjet_tensors = []
        subjet_mask = []
        particle_mask = []

        for subjet in subjets:
            n_indices = len(subjet['indices'])

            # (4,) -> (4, max_len): every feature value is repeated along the particle axis.
            features = torch.tensor(
                [subjet['features'][k] for k in feature_keys], dtype=torch.float32
            ).unsqueeze(1).expand(-1, max_len)

            # Pad the index list to the common length with -1, then repeat it once per
            # feature row so it can be concatenated below the feature rows.
            indices = torch.tensor(subjet['indices'], dtype=torch.float32)
            if n_indices < max_len:
                indices = torch.nn.functional.pad(indices, (0, max_len - n_indices), 'constant', -1)
            indices = indices.unsqueeze(0).expand(len(feature_keys), -1)

            subjet_tensors.append(torch.cat([features, indices], dim=0))
            subjet_mask.append(0 if subjet['features']['num_ptcls'] == 0 else 1)
            particle_mask.append([1 if pos < n_indices else 0 for pos in range(max_len)])

        subjets = torch.stack(subjet_tensors)
        subjet_mask = torch.tensor(subjet_mask, dtype=torch.float32)
        particle_mask = torch.tensor(particle_mask, dtype=torch.float32)

        if self.verbose:
            print(f"Processed subjets shape: {subjets.shape}")
            print(f"Subjet mask shape: {subjet_mask.shape}")
            print(f"Particle mask shape: {particle_mask.shape}")

        return subjets, subjet_mask, particle_mask


In [43]:
def custom_collate_fn(batch, verbose=False):
    """
    Custom collate function for DataLoader to handle variable-sized subjets and particles.

    Every item is zero-padded up to the largest subjet count, feature count and
    subjet length found in the batch, then the items are stacked. Items that are
    already at the maximum size are kept unchanged (padding by 0 is a no-op), so
    the output always has one row per input item.

    Parameters:
    - batch (list): List of tuples where each tuple contains (features, subjets, subjet_mask, particle_mask).
    - verbose (bool): If True, print the shapes of the collated tensors.

    Returns:
    - tuple: (features, subjets, subjet_masks, particle_masks) where each element is a padded tensor.
    """
    # Unzip the batch to separate features, subjets, subjet_masks, and particle_masks
    features, subjets, subjet_masks, particle_masks = zip(*batch)

    # Stack features along a new leading (batch) dimension
    features = torch.stack(features)

    # Determine the maximum dimensions for padding
    max_subjets = max(s.size(0) for s in subjets)
    max_subjet_features = max(s.size(1) for s in subjets)
    max_subjet_length = max(s.size(2) for s in subjets)

    padded_subjets = []
    padded_subjet_masks = []
    padded_particle_masks = []

    for s, sm, pm in zip(subjets, subjet_masks, particle_masks):
        pad_subjets = max_subjets - s.size(0)
        pad_features = max_subjet_features - s.size(1)
        pad_length = max_subjet_length - s.size(2)

        # F.pad lists the padding from the last dimension backwards.
        # Always pad and append: skipping items that need no padding would
        # silently drop them from the batch.
        padded_subjets.append(
            F.pad(s, (0, pad_length, 0, pad_features, 0, pad_subjets), "constant", 0)
        )
        padded_subjet_masks.append(F.pad(sm, (0, pad_subjets), "constant", 0))
        padded_particle_masks.append(
            F.pad(pm, (0, pad_length, 0, pad_subjets), "constant", 0)
        )

    # Stack the padded tensors
    subjets = torch.stack(padded_subjets)
    subjet_masks = torch.stack(padded_subjet_masks)
    particle_masks = torch.stack(padded_particle_masks)

    if verbose:
        print(f"Features shape after stacking: {features.shape}")
        print(f"Final stacked subjets shape: {subjets.shape}")
        print(f"Final stacked subjet masks shape: {subjet_masks.shape}")
        print(f"Final stacked particle masks shape: {particle_masks.shape}")

    return features, subjets, subjet_masks, particle_masks


In [44]:
import torch
import torch.nn as nn

class IJEPA(nn.Module):
    """
    IJEPA model class that uses a transformer encoder for context and target representations
    and predicts representations for masked parts of the input.
    """

    def __init__(self, input_dim, embed_dim, depth, num_heads, mlp_ratio, dropout=0.1, use_predictor=True):
        """
        Initializes the IJEPA model.

        Parameters:
        - input_dim (int): The dimension of the flattened per-subjet input
          (num_features * subjet_length, see `forward`).
        - embed_dim (int): The dimension of the embedding space.
        - depth (int): The number of layers in the transformer encoder.
        - num_heads (int): The number of attention heads in the transformer encoder.
        - mlp_ratio (float): The ratio of the hidden dimension in the MLP layer relative to the embedding dimension.
        - dropout (float): The dropout rate.
        - use_predictor (bool): Whether to use the predictor module.
        """
        # Zero-argument super(): the previous call named an undefined class
        # (J_JepaBase) and raised NameError on construction.
        super().__init__()
        self.use_predictor = use_predictor

        print(f"Initializing IJEPA with input_dim={input_dim}, embed_dim={embed_dim}, use_predictor={use_predictor}")

        # TODO: Positional embedding

        # Transformer encoders for context and target
        self.context_encoder = self.create_encoder(embed_dim, depth, num_heads, mlp_ratio, dropout)
        self.target_encoder = self.create_encoder(embed_dim, depth, num_heads, mlp_ratio, dropout)

        # Embedding layer shared by the context and target inputs
        self.embedding = nn.Linear(input_dim, embed_dim)

        if use_predictor:
            # Predictor MLP mapping context representations (embed_dim) back to input_dim.
            # NOTE(review): its output therefore has a different last dimension than
            # target_repr (embed_dim); callers must not compare the two element-wise.
            self.predictor = nn.Sequential(
                nn.Linear(embed_dim, embed_dim),
                nn.GELU(),
                nn.Linear(embed_dim, input_dim)  # Ensure the output dimension matches the input dimension
            )

    def create_encoder(self, embed_dim, depth, num_heads, mlp_ratio, dropout):
        """
        Creates a transformer encoder.

        Parameters:
        - embed_dim (int): The dimension of the embedding space.
        - depth (int): The number of layers in the transformer encoder.
        - num_heads (int): The number of attention heads in the transformer encoder.
        - mlp_ratio (float): The ratio of the hidden dimension in the MLP layer relative to the embedding dimension.
        - dropout (float): The dropout rate.

        Returns:
        - nn.TransformerEncoder: A transformer encoder (sequence-first layout).
        """
        return nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads,
                                       dim_feedforward=int(embed_dim * mlp_ratio),
                                       dropout=dropout),
            num_layers=depth
        )

    def forward(self, context, target):
        """
        Forward pass of the IJEPA model.

        Parameters:
        - context (torch.Tensor): Context input of shape
          (batch_size, num_subjets, num_features, subjet_length).
        - target (torch.Tensor): Target input with the same shape as `context`.

        Returns (always a 4-tuple):
        - pred_repr (torch.Tensor or None): Predictor output of shape
          (batch_size, num_subjets, input_dim), or None when use_predictor is False.
        - target_repr (torch.Tensor): The encoded target representations,
          shape (batch_size, num_subjets, embed_dim).
        - context_repr_shape (torch.Size): The shape of the context representations.
        - target_repr_shape (torch.Size): The shape of the target representations.
        """
        print("Starting IJEPA forward pass")
        batch_size, num_subjets, _, _ = context.size()

        print(f"Input context shape: {context.shape}")
        print(f"Input target shape: {target.shape}")

        # Flatten (num_features, subjet_length) into a single vector per subjet.
        # reshape (rather than view) also accepts non-contiguous inputs.
        context = context.reshape(batch_size, num_subjets, -1)
        target = target.reshape(batch_size, num_subjets, -1)

        print(f"Reshaped context shape: {context.shape}")
        print(f"Reshaped target shape: {target.shape}")

        # Apply the embedding layer
        context_emb = self.embedding(context)
        target_emb = self.embedding(target)

        print(f"Embedded context shape: {context_emb.shape}")
        print(f"Embedded target shape: {target_emb.shape}")

        # TODO: Add positional embeddings

        # The encoders are sequence-first, hence the transposes around each call.
        print("Applying context encoder")
        context_repr = self.context_encoder(context_emb.transpose(0, 1)).transpose(0, 1)
        print("Applying target encoder")
        target_repr = self.target_encoder(target_emb.transpose(0, 1)).transpose(0, 1)

        print(f"Context encoder output shape: {context_repr.shape}")
        print(f"Target encoder output shape: {target_repr.shape}")

        pred_repr = None
        if self.use_predictor:
            print("Applying predictor")
            # Predict the representations for the masked parts
            pred_repr = self.predictor(context_repr)

        print(f"Final predicted representation shape: {pred_repr.shape if pred_repr is not None else 'N/A'}")
        print(f"Final target representation shape: {target_repr.shape}")

        return pred_repr, target_repr, context_repr.shape, target_repr.shape

In [45]:
def train_step(model, subjets, subjet_masks, particle_masks, optimizer, device, step):
    """
    Performs a single training step for the IJEPA model.

    The model is expected to return the 4-tuple documented on `IJEPA.forward`:
    (pred_repr, target_repr, context_repr_shape, target_repr_shape), where
    pred_repr holds one flattened (num_features * subjet_length) vector per subjet.

    Parameters:
    - model (nn.Module): The IJEPA model.
    - subjets (torch.Tensor): Input subjets, shape (batch, num_subjets, num_features, subjet_length).
    - subjet_masks (torch.Tensor): Per-subjet mask, shape (batch, num_subjets).
    - particle_masks (torch.Tensor): The mask tensor for the particles (currently unused in the loss).
    - optimizer (torch.optim.Optimizer): The optimizer.
    - device (torch.device): The device to run the training on (CPU or GPU).
    - step (int): The current training step number.

    Returns:
    - float: The loss value for the current training step.

    Raises:
    - ValueError: If the model returns no prediction (built with use_predictor=False).
    """
    print(f"\nStarting training step {step}")

    # Move every input to the target device up front (masks AND data).
    subjets = subjets.to(device)
    subjet_masks = subjet_masks.to(device)
    particle_masks = particle_masks.to(device)

    batch_size, num_subjets, num_features, subjet_length = subjets.size()
    print(f"Input shapes - Subjets: {subjets.shape}, Subjet masks: {subjet_masks.shape}, Particle masks: {particle_masks.shape}")

    # Create random masks for context and target
    context_masks, target_masks = create_random_masks(batch_size, num_subjets, num_features, subjet_length)
    context_masks = context_masks.to(device)
    target_masks = target_masks.to(device)

    print(f"Mask shapes - Context: {context_masks.shape}, Target: {target_masks.shape}")

    # Apply masks to the input subjets
    context_subjets = subjets * context_masks
    target_subjets = subjets * target_masks

    print(f"Context subjets shape: {context_subjets.shape}")
    print(f"Target subjets shape: {target_subjets.shape}")

    # Zero the parameter gradients
    optimizer.zero_grad()

    print("Forwarding through model")
    pred_repr, target_repr, context_repr_shape, target_repr_shape = model(context_subjets, target_subjets)
    if pred_repr is None:
        raise ValueError("train_step requires a model built with use_predictor=True")

    print(f"Predicted representation shape: {pred_repr.shape}")
    print(f"Target representation shape: {target_repr.shape}")

    # The predictor emits one flattened vector per subjet; restore the
    # (num_features, subjet_length) layout so it lines up with the masks and
    # with the plotting helpers, which index [subjet, feature, particle].
    pred_subjets = pred_repr.reshape(subjets.shape)

    # Only score real (non-padding) subjets that were held out as targets.
    combined_mask = target_masks * subjet_masks[:, :, None, None]

    # NOTE(review): the predictor output lives in input space (input_dim) while
    # target_repr lives in embedding space (embed_dim), so the two cannot be
    # compared element-wise. The loss therefore regresses the prediction onto the
    # raw target subjets. If a latent-space JEPA objective is intended, the
    # predictor must emit embed_dim and this loss must be changed accordingly.
    loss = F.mse_loss(pred_subjets * combined_mask, subjets * combined_mask)
    print(f"Calculated loss: {loss.item()}")

    # Backward pass and optimization step
    loss.backward()
    optimizer.step()

    # Print and visualize predictions every 500 steps
    if step % 500 == 0:
        sample_pred = pred_subjets[0].detach().cpu()
        print_jet_details(sample_pred, "Predicted")
        visualize_predictions_vs_ground_truth(subjets[0].cpu(), sample_pred, title=f"Ground Truth vs Predictions (Step {step})")
        print(f"Context representation shape: {context_repr_shape}")
        print(f"Target representation shape: {target_repr_shape}")

    return loss.item()

In [46]:
def create_random_masks(batch_size, num_subjets, num_features, subjet_length, context_scale=0.7):
    """
    Builds complementary context/target masks, one random subjet split per sample.

    For every sample the subjet indices are shuffled; the first
    ``int(num_subjets * context_scale)`` of them form the context and the rest
    form the target. A selected subjet is marked with ones across all of its
    features and positions.

    Parameters:
    - batch_size (int): Number of samples in a batch.
    - num_subjets (int): Number of subjets in each sample.
    - num_features (int): Number of features in each subjet.
    - subjet_length (int): Length of each subjet.
    - context_scale (float): Proportion of subjets to be used as context.

    Returns:
    - context_masks (torch.Tensor): Shape (batch_size, num_subjets, num_features, subjet_length).
    - target_masks (torch.Tensor): Same shape; ones exactly where context_masks is zero.
    """
    print(f"Creating random masks with batch_size={batch_size}, num_subjets={num_subjets}")

    mask_shape = (num_subjets, num_features, subjet_length)
    per_sample_context = []
    per_sample_target = []

    for sample_idx in range(batch_size):
        # One fresh permutation per sample decides the split.
        shuffled = torch.randperm(num_subjets)
        n_context = int(num_subjets * context_scale)

        ctx = torch.zeros(*mask_shape)
        ctx[shuffled[:n_context]] = 1

        tgt = torch.zeros(*mask_shape)
        tgt[shuffled[n_context:]] = 1

        per_sample_context.append(ctx)
        per_sample_target.append(tgt)

        if sample_idx == 0:
            print(f"Sample context mask shape: {ctx.shape}")
            print(f"Sample target mask shape: {tgt.shape}")

    context_masks = torch.stack(per_sample_context)
    target_masks = torch.stack(per_sample_target)
    return context_masks, target_masks

# def normalize_features(features):
#     """
#     Normalizes the features tensor.
    
#     Parameters:
#     - features (torch.Tensor): Input features tensor.
    
#     Returns:
#     - normalized (torch.Tensor): Normalized features tensor.
#     """
#     print("Normalizing features")
#     mean = features.mean(dim=0, keepdim=True)
#     std = features.std(dim=0, keepdim=True)
#     std[std == 0] = 1e-8  # Prevent division by zero
#     normalized = (features - mean) / std
#     print(f"Normalized features shape: {normalized.shape}")
#     return normalized

# Load the normalization settings from config.yaml into the module-level
# `config` dict. It is passed to normalize_features(), which reads
# config["INPUTS"]["SEQUENTIAL"][jet_type] to pick a method per feature.
# The path is relative to the working directory.
with open("config.yaml", "r") as file:
    config = yaml.safe_load(file)

def normalize_features(features, feature_names, config, jet_type="Jets", verbose=False):
    """
    Normalize the input features using methods specified in the YAML configuration file.

    Column ``i`` of `features` is treated as the feature named ``feature_names[i]``.
    The method for each feature is read from
    ``config["INPUTS"]["SEQUENTIAL"][jet_type]``:
    - "normalize": subtract the column mean and divide by the column std
      (the division is skipped when the std is not greater than 1e-8);
    - "log_normalize": apply ``log1p`` to the column;
    - anything else, or a missing entry: the column is left unchanged.

    Parameters:
    - features (Tensor): 2-D input tensor; features are indexed along the second axis.
    - feature_names (list of str): List of feature names corresponding to the columns in `features`.
    - config (dict): Configuration dictionary loaded from the YAML file.
    - jet_type (str): The type of jet ('Jets' or 'Subjets') for which normalization is applied.
    - verbose (bool): If True, print the method chosen for every feature.

    Returns:
    - Tensor: Normalized copy of `features`; the input tensor is not modified.
    """
    methods = config["INPUTS"]["SEQUENTIAL"][jet_type]
    normalized_features = features.clone()

    for i, feature_name in enumerate(feature_names):
        method = methods.get(feature_name, "none")
        if verbose:
            print(f"Normalizing feature'{feature_name}' with method '{method}'")

        column = features[:, i]
        if method == "normalize":
            mean = column.mean()
            std = column.std()
            std = std if std > 1e-8 else 1.0  # to prevent division by zero
            # Write back to column i (previously always column 1).
            normalized_features[:, i] = (column - mean) / std
        elif method == "log_normalize":
            normalized_features[:, i] = torch.log1p(column)

    return normalized_features

def print_jet_details(jet, name):
    """
    Dumps a short textual summary of a jet tensor to stdout.

    The summary contains the tensor shape, the number of non-zero entries and
    the top-left 5x5 corner of the first slice along the leading axis, so the
    tensor must have at least three dimensions.

    Parameters:
    - jet (torch.Tensor): Input jet tensor.
    - name (str): Name of the jet for identification.
    """
    print(f"\n{name} Jet Details:")
    print(f"Shape: {jet.shape}")

    nonzero_count = torch.count_nonzero(jet)
    print(f"Non-zero elements: {nonzero_count}")

    print("\nFirst few elements:")
    leading_corner = jet[0, :5, :5]
    print(leading_corner)

def visualize_jet(subjets, mask=None, title="Jet Visualization"):
    """
    Visualizes the subjets in an eta-phi scatter plot, sized and coloured by pT.

    For every subjet the values at [i, 0, 0], [i, 1, 0] and [i, 2, 0] are read as
    pT, eta and phi. Subjets with pT <= 0, or whose mask entry has no non-zero
    element, are skipped.

    Parameters:
    - subjets (torch.Tensor): Tensor of subjets, indexed as [subjet, feature, position].
    - mask (torch.Tensor, optional): Mask tensor for filtering subjets (indexed by subjet).
    - title (str): Title of the plot.
    """
    fig, ax = plt.subplots(figsize=(10, 10))

    etas, phis, pts = [], [], []
    for i in range(subjets.size(0)):
        if mask is not None and not mask[i].any():
            continue
        pT = subjets[i, 0, 0].item()
        if pT > 0:
            etas.append(subjets[i, 1, 0].item())
            phis.append(subjets[i, 2, 0].item())
            pts.append(pT)

    ax.set_xlabel('Eta')
    ax.set_ylabel('Phi')
    ax.set_title(title)

    if pts:
        # One scatter call with c=pT so the colorbar actually encodes pT.
        # Previously each point was drawn separately without colour data, and
        # the colorbar was built from the first point's collection only.
        points = ax.scatter(etas, phis, s=[value * 5 for value in pts], c=pts, alpha=0.7)
        fig.colorbar(points, ax=ax, label='pT')
    else:
        print(f"No valid points to display for {title}")

    plt.show()

def visualize_context_and_targets(subjets, context_mask, target_mask):
    """
    Visualizes the full jet, its context subjets and its target subjets side by side.

    Each panel is an eta-phi scatter plot sized and coloured by pT. A subjet is
    drawn in a panel when its mask value at [i, 0, 0] is positive and its pT
    (value at [i, 0, 0] of `subjets`) is positive.

    Parameters:
    - subjets (torch.Tensor): Tensor of subjets, indexed as [subjet, feature, position].
    - context_mask (torch.Tensor): Mask tensor for context subjets (same indexing).
    - target_mask (torch.Tensor): Mask tensor for target subjets (same indexing).
    """
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(30, 10))

    def scatter_jet(ax, jet_data, mask, title):
        """Draw the subjets selected by `mask` on `ax`."""
        etas, phis, pts = [], [], []
        for i in range(jet_data.size(0)):
            if mask[i, 0, 0].item() > 0:  # Check only the first element of each subjet
                pT = jet_data[i, 0, 0].item()
                if pT > 0:
                    etas.append(jet_data[i, 1, 0].item())
                    phis.append(jet_data[i, 2, 0].item())
                    pts.append(pT)

        ax.set_xlabel('Eta')
        ax.set_ylabel('Phi')
        ax.set_title(title)

        if pts:
            # Single scatter with c=pT so the colorbar reflects the plotted values.
            points = ax.scatter(etas, phis, s=[value * 5 for value in pts], c=pts, alpha=0.7)
            fig.colorbar(points, ax=ax, label='pT')

    scatter_jet(ax1, subjets, torch.ones_like(context_mask), "Full Jet")
    scatter_jet(ax2, subjets, context_mask, "Context Subjets")
    scatter_jet(ax3, subjets, target_mask, "Target Subjets")

    plt.tight_layout()
    plt.show()

def visualize_predictions_vs_ground_truth(subjets, pred_repr, title="Ground Truth vs Predictions"):
    """
    Visualizes ground truth and predicted subjets in two eta-phi scatter plots.

    Both tensors are indexed as [subjet, feature, position]; the values at
    [i, 0, 0], [i, 1, 0] and [i, 2, 0] are read as pT, eta and phi. Subjets with
    pT <= 0 are not drawn. Points are sized and coloured by pT.

    Parameters:
    - subjets (torch.Tensor): Ground truth subjets tensor.
    - pred_repr (torch.Tensor): Predicted subjets tensor (same layout as `subjets`).
    - title (str): Title of the plot.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 7))

    def scatter_jet(ax, jet_data, panel_title):
        """Draw one jet on `ax` and attach a pT colorbar to that panel."""
        etas, phis, pts = [], [], []
        for i in range(jet_data.size(0)):
            pT = jet_data[i, 0, 0].item()
            if pT > 0:
                etas.append(jet_data[i, 1, 0].item())
                phis.append(jet_data[i, 2, 0].item())
                pts.append(pT)

        ax.set_xlabel('Eta')
        ax.set_ylabel('Phi')
        ax.set_title(panel_title)

        if pts:
            # Single scatter with c=pT so the colorbar reflects the plotted values;
            # ax=ax keeps each colorbar next to its own panel.
            points = ax.scatter(etas, phis, s=[value * 5 for value in pts], c=pts, alpha=0.7)
            fig.colorbar(points, ax=ax, label='pT')

    scatter_jet(ax1, subjets, "Ground Truth")
    scatter_jet(ax2, pred_repr, "Predictions")

    plt.suptitle(title)
    plt.show()

def visualize_training_loss(train_losses):
    """
    Draws the training-loss curve, one point per recorded epoch.

    Parameters:
    - train_losses (list of float): List of training loss values.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(train_losses, label='Training Loss')
    ax.set_title('Training Loss over Epochs')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()
    plt.show()


def visualize_sample_prediction(train_loader, model, device, step=0):
    """
    Visualizes a sample prediction after training.

    Takes the first batch from `train_loader`, runs the model on a random
    context/target split without gradients and plots the first jet of the batch
    next to the model's prediction for it. The model's train/eval mode is
    restored before returning.

    Parameters:
    train_loader (DataLoader): DataLoader for the training data.
    model (IJEPA): The trained model; expected to return the 4-tuple
        (pred_repr, target_repr, context_repr_shape, target_repr_shape).
    device (torch.device): Device to run the model on (CPU or GPU).
    step (int): Step number for labeling the visualization (default: 0).
    """
    print("Visualizing sample prediction")

    # Remember the current mode so a call in the middle of training does not
    # leave dropout disabled afterwards.
    was_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            # Only the subjets are needed; the other batch entries are ignored.
            _, sample_subjets, _, _ = next(iter(train_loader))
            sample_subjets = sample_subjets.to(device)

            batch_size, num_subjets, num_features, subjet_length = sample_subjets.size()

            # Create random context and target masks on the same device as the data
            context_masks, target_masks = create_random_masks(batch_size, num_subjets, num_features, subjet_length)
            context_masks = context_masks.to(device)
            target_masks = target_masks.to(device)

            context_subjets = sample_subjets * context_masks
            target_subjets = sample_subjets * target_masks

            pred_repr, target_repr, context_repr_shape, target_repr_shape = model(context_subjets, target_subjets)
    finally:
        model.train(was_training)

    print(f"Final context representation shape: {context_repr_shape}")
    print(f"Final target representation shape: {target_repr_shape}")

    if pred_repr is None:
        print("Model was built without a predictor; nothing to visualize.")
        return

    # The predictor emits one flattened (num_features * subjet_length) vector per
    # subjet; the plotting helper indexes [subjet, feature, position], so restore
    # the input layout first.
    pred_subjets = pred_repr.reshape(sample_subjets.shape)

    # Visualize the ground truth vs predictions for the first item in the batch
    visualize_predictions_vs_ground_truth(sample_subjets[0].cpu(), pred_subjets[0].cpu(), title=f"Final Model: Ground Truth vs Predictions (Step {step})")



In [4]:
if __name__ == "__main__":
    # End-to-end driver: load data, train the IJEPA model, save it and plot a sample.
    print("Starting main program")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load a subset of the dataset
    print("Loading dataset")
    try:
        train_dataset = JetDataset("../data/val/val_20_30.h5", subset_size=1000)  # Use only 1000 samples
    except Exception as e:
        # Report and re-raise: continuing would only fail later with a
        # NameError on `train_dataset`, hiding the real cause.
        print(f"Error loading dataset: {e}")
        raise

    if len(train_dataset) == 0:
        raise ValueError("The dataset is empty after filtering; nothing to train on.")

    print("Creating DataLoader")
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=custom_collate_fn)

    # Initialize model
    # input_dim is the flattened per-subjet size (num_features * subjet_length = 8 * 30).
    print("Initializing model")
    model = IJEPA(input_dim=240, embed_dim=512, depth=12, num_heads=8, mlp_ratio=4.0, dropout=0.1)
    # The parameters must live on the same device as the batches fed to the model.
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.04)

    # Training
    num_epochs = 10
    train_losses = []

    print(f"Starting training for {num_epochs} epochs")
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=True, position=0)

        # The per-jet `features` entry of the batch is not used by train_step.
        for step, (_, subjets, subjet_masks, particle_masks) in enumerate(progress_bar):
            subjets = subjets.to(device)
            subjet_masks = subjet_masks.to(device)
            particle_masks = particle_masks.to(device)

            loss = train_step(model, subjets, subjet_masks, particle_masks, optimizer, device, step)
            total_loss += loss

            progress_bar.set_postfix(loss=loss)

            if step % 100 == 0:
                print(f"\nEpoch {epoch+1}, Step {step}, Loss: {loss:.4f}")

        progress_bar.close()

        avg_train_loss = total_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {avg_train_loss:.4f}")

    print("Training completed")

    # Visualize training loss
    print("Visualizing training loss")
    visualize_training_loss(train_losses)

    print("Saving model")
    torch.save(model.state_dict(), 'ijepa_model.pth')

    print("Model saved. Starting evaluation.")

    # Visualize a sample prediction after training
    visualize_sample_prediction(train_loader, model, device, step=num_epochs)

    print("Evaluation completed.")

Starting main program
Using device: cuda
Loading dataset


KeyboardInterrupt: 