In [2]:
import yaml
import json
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import h5py

class JetDataset(Dataset):
    """JetDataset class for loading and processing jet data from HDF5 files.

    Args:
        Dataset (_type_): _description_
    """
    def __init__(self, file_path, subset_size=None, transform=None, config=None):
        print(f"Initializing JetDataset with file: {file_path}")
        
        with h5py.File(file_path, 'r') as hdf:
            print("Loading features and subjets from HDF5 file")
            self.features = torch.tensor(hdf["particles/features"][:], dtype=torch.float32)
            self.subjets = [json.loads(subjet) for subjet in hdf["subjets"][:]]
        
        self.transform = transform
        self.config = config
        print(f"Raw dataset size: {len(self.subjets)} jets")
        print(f"Feature shape: {self.features.shape}")
        
        self.filter_good_jets()
        
        if subset_size is not None:
            print(f"Applying subset size: {subset_size}")
            self.features = self.features[:subset_size]
            self.subjets = self.subjets[:subset_size]
        
        print(f"Final dataset size: {len(self.subjets)} jets")

    def filter_good_jets(self):
        """
        Filters jets to retain only those with a sufficient number of real subjets.
        """
        print("Filtering good jets...")
        good_jets = []
        good_features = []
        
        for i in range(len(self.subjets)):
            num_real_subjets = self.get_num_real_subjets(self.subjets[i])
            if num_real_subjets >= 10:
                good_jets.append(self.subjets[i])
                good_features.append(self.features[i])
        
        self.subjets = good_jets
        self.features = torch.stack(good_features)
        print(f"Filtered to {len(self.subjets)} good jets")
    
    @staticmethod
    def get_num_real_subjets(jet):
        """
        Returns the number of real subjets in a given jet.
        """
        return sum(1 for subjet in jet if subjet['features']['num_ptcls'] > 0)

    def __len__(self):
        return len(self.subjets)

    def __getitem__(self, idx):
        """
        Retrieves the features and subjets for a given index and processes them.
        """
        check_raw_data(self.subjets, jet_index=idx)  # Debug statement

        print(f"\nFetching item {idx} from dataset")
        features = self.features[idx]
        subjets = self.subjets[idx]

        inspect_subjets(subjets)  # Debug statement

        subjets, subjet_mask, particle_mask = self.process_subjets(subjets)
        
        check_processed_data(subjets)  # Debug statement

        feature_names = ['pT', 'eta', 'phi']
        print("Normalizing features")
        features = normalize_features(features, feature_names, self.config, jet_type='Jets')
        
        if self.transform:
            print("Applying transform to features")
            features = self.transform(features)
        
        return features, subjets, subjet_mask, particle_mask

    def process_subjets(self, subjets):
        """
        Processes subjets to create tensor representations and masks.
        """
        print("Processing subjets")

        max_len = max(len(subjet['indices']) for subjet in subjets)
        print(f"Max length of indices in subjets: {max_len}")
        subjet_tensors = []
        subjet_mask = []
        particle_mask = []
        
        for i, subjet in enumerate(subjets):
            feature_tensors = [torch.tensor([subjet['features'][k]], dtype=torch.float32).expand(max_len) for k in ['pT', 'eta', 'phi', 'num_ptcls']]
            features = torch.stack(feature_tensors, dim=0)
            
            is_empty = subjet['features']['num_ptcls'] == 0
            subjet_mask.append(0 if is_empty else 1)
            particle_mask.append([1 if i < len(subjet['indices']) else 0 for i in range(max_len)])

            subjet_tensors.append(features)

        subjets = torch.stack(subjet_tensors)
        subjet_mask = torch.tensor(subjet_mask, dtype=torch.float32)
        particle_mask = torch.tensor(particle_mask, dtype=torch.float32)
        
        print(f"Final processed subjets shape: {subjets.shape}")
        print(f"Final subjet mask shape: {subjet_mask.shape}")
        print(f"Final particle mask shape: {particle_mask.shape}")
        
        return subjets, subjet_mask, particle_mask

def custom_collate_fn(batch):    
    """
    Custom collate function for DataLoader to handle variable-sized subjets and particles.
    """
    print("\n--- Starting custom_collate_fn ---")
    
    features, subjets, subjet_masks, particle_masks = zip(*batch)
    
    features = torch.stack(features)
    print(f"Features shape after stacking: {features.shape}")
    
    max_subjets = max(s.size(0) for s in subjets)
    max_subjet_features = max(s.size(1) for s in subjets)
    max_subjet_length = max(s.size(2) for s in subjets)
    
    print(f"Max subjets: {max_subjets}")
    print(f"Max subjet features: {max_subjet_features}")
    print(f"Max subjet length: {max_subjet_length}")
    
    padded_subjets = []
    padded_subjet_masks = []
    padded_particle_masks = []
    
    for i, (s, sm, pm) in enumerate(zip(subjets, subjet_masks, particle_masks)):
        pad_subjets = max_subjets - s.size(0)
        pad_features = max_subjet_features - s.size(1)
        pad_length = max_subjet_length - s.size(2)
        
        print(f"Padding required - Subjets: {pad_subjets}, Features: {pad_features}, Length: {pad_length}")

        padded_s = F.pad(s, (0, pad_length, 0, pad_features, 0, pad_subjets), "constant", 0)
        padded_subjets.append(padded_s)
        
        padded_sm = F.pad(sm, (0, pad_subjets), "constant", 0)
        padded_subjet_masks.append(padded_sm)
        
        padded_pm = F.pad(pm, (0, pad_length, 0, pad_subjets), "constant", 0)
        padded_particle_masks.append(padded_pm)
        
        print(f"Padded subjets shape: {padded_s.shape}")
        print(f"Padded subjet mask shape: {padded_sm.shape}")
        print(f"Padded particle mask shape: {padded_pm.shape}")
    
    subjets = torch.stack(padded_subjets)
    subjet_masks = torch.stack(padded_subjet_masks)
    particle_masks = torch.stack(padded_particle_masks)
    
    print(f"\nFinal stacked subjets shape: {subjets.shape}")
    print(f"Final stacked subjet masks shape: {subjet_masks.shape}")
    print(f"Final stacked particle masks shape: {particle_masks.shape}")
    
    print("--- End of custom_collate_fn ---\n")
    
    return features, subjets, subjet_masks, particle_masks

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        print("Initializing Attention module")
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        print(f"Attention forward pass with input shape: {x.shape}")
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        print(f"Attention output shape: {x.shape}")
        return x

class MLP(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        print("Initializing MLP module")
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        print(f"MLP forward pass with input shape: {x.shape}")
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        print(f"MLP output shape: {x.shape}")
        return x

class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        print("Initializing Block module")
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = nn.Identity() if drop_path <= 0 else nn.Dropout(drop_path)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        print(f"Block forward pass with input shape: {x.shape}")
        y = self.attn(self.norm1(x))
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        print(f"Block output shape: {x.shape}")
        return x

class JetsTransformer(nn.Module):
    def __init__(self, num_features, embed_dim, depth, num_heads, mlp_ratio, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=nn.LayerNorm):
        super().__init__()
        print("Initializing JetsTransformer module")
        self.num_features = num_features
        self.embed_dim = embed_dim
        
        # Adjust the input dimensions based on the new input shape
        self.patch_embed = nn.Linear(num_features * 30, embed_dim)  # num_features * subjet_length
        
        self.pos_embed = nn.Parameter(torch.zeros(1, 512, embed_dim))
        self.blocks = nn.ModuleList([Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate, norm_layer=norm_layer) for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        print(f"JetsTransformer forward pass with input shape: {x.shape}")
        B, N, C, L = x.shape
        x = x.view(B, N, -1)  # Flatten last two dimensions to [B, N, C*L]
        print(f"Flattened input shape: {x.shape}")
        x = self.patch_embed(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        print(f"JetsTransformer output shape: {x.shape}")
        return x.view(B, N, -1)  # Reshape back if necessary

class JetsTransformerPredictor(nn.Module):
    def __init__(self, num_features, embed_dim, predictor_embed_dim, depth, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm):
        super().__init__()
        print("Initializing JetsTransformerPredictor module")
        self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
        self.predictor_blocks = nn.ModuleList([Block(dim=predictor_embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate, norm_layer=norm_layer) for i in range(depth)])
        self.predictor_norm = norm_layer(predictor_embed_dim)
        self.predictor_proj = nn.Linear(predictor_embed_dim, 8 * 30, bias=True)  # Match target dimensions
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, masks_x, masks):
        print(f"JetsTransformerPredictor forward pass with input shape: {x.shape}")
        x = self.predictor_embed(x)
        for blk in self.predictor_blocks:
            x = blk(x)
        x = self.predictor_proj(x)
        print(f"JetsTransformerPredictor output shape: {x.shape}")
        return x.view(x.size(0), x.size(1), 8, 30)  # Reshape to match target_repr shape

class JJEPA(nn.Module):
    def __init__(self, input_dim, embed_dim, depth, num_heads, mlp_ratio, dropout=0.1, use_predictor=True):
        super(JJEPA, self).__init__()
        print("Initializing JJEPA module")
        self.use_predictor = use_predictor
        self.context_transformer = JetsTransformer(num_features=input_dim, embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, drop_rate=dropout)
        if self.use_predictor:
            self.predictor_transformer = JetsTransformerPredictor(num_features=input_dim, embed_dim=embed_dim, predictor_embed_dim=embed_dim//2, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, drop_rate=dropout)


        # Debug Statement - Dimension check
        self.input_check = DimensionCheckLayer("Model Input", 3)
        self.context_check = DimensionCheckLayer("After Context Transformer", 3)
        self.predictor_check = DimensionCheckLayer("After Predictor", 3)

    def forward(self, context, target):
        print(f"JJEPA forward pass with context shape: {context.shape} and target shape: {target.shape}")
        context = context.to(next(self.parameters()).device)
        target = target.to(next(self.parameters()).device)
        
        context_repr = self.context_transformer(context)
        # Debug Statement
        context_repr = self.context_check(context_repr)
        if self.use_predictor:
            pred_repr = self.predictor_transformer(context_repr, None, None)
            pred_repr = self.predictor_check(pred_repr)
            print(f"JJEPA output - pred_repr shape: {pred_repr.shape}, context_repr shape: {context_repr.shape}, target shape: {target.shape}")
            return pred_repr, context_repr, target
        
        print(f"JJEPA output - context_repr shape: {context_repr.shape}, target shape: {target.shape}")
        return context_repr, target

def train_step(model, subjets, subjet_masks, particle_masks, optimizer, device, step):
    print(f"\nStarting training step {step}")
    
    # Debug Statement
    check_processed_data(subjets)
    
    batch_size, num_subjets, num_features, subjet_length = subjets.size()
    print(f"Input shapes - Subjets: {subjets.shape}, Subjet masks: {subjet_masks.shape}, Particle masks: {particle_masks.shape}")
    
    context_masks, target_masks = create_random_masks(batch_size, num_subjets, num_features, subjet_length)
    print(f"Context masks shape: {context_masks.shape}, Target masks shape: {target_masks.shape}")
    
    context_masks = context_masks.to(device)
    target_masks = target_masks.to(device)
    subjet_masks = subjet_masks.to(device)
    particle_masks = particle_masks.to(device)
    
    context_subjets = subjets * context_masks
    target_subjets = subjets * target_masks
    
    optimizer.zero_grad()
    
    print("Forwarding through model")
    pred_repr, context_repr, target_repr = model(context_subjets, target_subjets)
    
    print(f"Predicted representation shape: {pred_repr.shape}")
    print(f"Target representation shape: {target_repr.shape}")
    
    combined_mask = target_masks.to(device) * subjet_masks.unsqueeze(-1).unsqueeze(-1).expand_as(target_masks).to(device)
    
    pred_repr = pred_repr.to(device)
    target_repr = target_repr.to(device)
    
    print("Calculating loss")
    loss = F.mse_loss(pred_repr * combined_mask, target_repr * combined_mask)
    print(f"Calculated loss: {loss.item()}")
    
    loss.backward()
    optimizer.step()
    
    if step % 500 == 0:
        print_jet_details(pred_repr[0].cpu(), "Predicted")
        visualize_predictions_vs_ground_truth(subjets[0].cpu(), pred_repr[0].cpu(), title=f"Ground Truth vs Predictions (Step {step})")
        print(f"Context representation shape: {context_repr.shape}")
        print(f"Target representation shape: {target_repr.shape}")
        
    return loss.item()



def create_random_masks(batch_size, num_subjets, num_features, subjet_length, context_scale=0.7):
    print(f"Creating random masks with batch_size={batch_size}, num_subjets={num_subjets}")
    context_masks = []
    target_masks = []

    for i in range(batch_size):
        indices = torch.randperm(num_subjets)
        context_size = int(num_subjets * context_scale)
        context_indices = indices[:context_size]
        target_indices = indices[context_size:]

        context_mask = torch.zeros(num_subjets, num_features, subjet_length)
        target_mask = torch.zeros(num_subjets, num_features, subjet_length)

        context_mask[context_indices] = 1
        target_mask[target_indices] = 1

        context_masks.append(context_mask)
        target_masks.append(target_mask)

    return torch.stack(context_masks), torch.stack(target_masks)

def normalize_features(features, feature_names, config, jet_type="Jets"):
    print(f"Normalizing features with shape: {features.shape}")
    normalized_features = features.clone()
    for i, feature_name in enumerate(feature_names):
        method = config["INPUTS"]["SEQUENTIAL"][jet_type].get(feature_name, "none")
        print(f"Normalizing feature '{feature_name}' with method '{method}'")

        if method == "normalize":
            mean = features[:, i].mean()
            std = features[:, i].std()
            std = std if std > 1e-8 else 1.0
            normalized_features[:, i] = (features[:, i] - mean) / std
        elif method == "log_normalize":
            normalized_features[:, i] = torch.log1p(features[:, i])

    print(f"Normalized features shape: {normalized_features.shape}")
    return normalized_features

def print_jet_details(jet, name):
    print(f"\n{name} Jet Details:")
    print(f"Shape: {jet.shape}")
    print(f"Non-zero elements: {torch.count_nonzero(jet)}")
    print("\nFirst few elements:")
    print(jet[0, :5, :5])

def check_raw_data(subjets_data, jet_index=0):
    print(f"\n--- Checking Raw Data for Jet {jet_index} ---")
    jet_subjets = subjets_data[jet_index]
    print(f"Number of subjets: {len(jet_subjets)}")
    print(f"Subjet features: {list(jet_subjets[0]['features'].keys())}")
    print(f"Sample subjet feature values: {jet_subjets[0]['features']}")

def check_processed_data(processed_subjets, batch_index=0):
    print(f"\n--- Checking Processed Data for Batch Item {batch_index} ---")
    print(f"Processed shape: {processed_subjets.shape}")
    if len(processed_subjets.shape) == 3:
        num_subjets, num_features, subjet_length = processed_subjets.shape
        print(f"Number of subjets: {num_subjets}")
        print(f"Number of features: {num_features}")
        print(f"Subjet length: {subjet_length}")
        print("\nFirst few values of each feature:")
        for i in range(num_features):
            print(f"Feature {i}: {processed_subjets[0, i, :5]}")
    else:
        print("Unexpected shape for processed subjets")

def inspect_subjets(subjets, num_samples=5):
    print("\n--- Inspecting Subjets ---")
    for i in range(min(num_samples, len(subjets))):
        subjet = subjets[i]
        print(f"\nSubjet {i}:")
        print(f"  pT: {subjet['features']['pT']:.2f}")
        print(f"  eta: {subjet['features']['eta']:.2f}")
        print(f"  phi: {subjet['features']['phi']:.2f}")
        print(f"  num_ptcls: {subjet['features']['num_ptcls']}")

class DimensionCheckLayer(torch.nn.Module):
    def __init__(self, name, expected_dims):
        super().__init__()
        self.name = name
        self.expected_dims = expected_dims

    def forward(self, x):
        if len(x.shape) != self.expected_dims:
            print(f"WARNING: {self.name} has {len(x.shape)} dimensions, expected {self.expected_dims}")
        return x


def visualize_training_loss(train_losses):
    print("Visualizing training loss")
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.title('Training Loss over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

if __name__ == "__main__":
    print("Starting main program")
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    with open('config.yaml', 'r') as file:
        config = yaml.safe_load(file)
        
    try:
        print("Loading dataset")
        train_dataset = JetDataset("../data/val/val_20_30.h5", subset_size=1000, config=config)
    except Exception as e:
        print(f"Error loading dataset: {e}")

    print("Creating DataLoader")
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=custom_collate_fn)

    print("Initializing model")
    model = JJEPA(input_dim=240, embed_dim=512, depth=12, num_heads=8, mlp_ratio=4.0, dropout=0.1).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.04)

    num_epochs = 10
    train_losses = []
    
    print(f"Starting training for {num_epochs} epochs")
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        progress_bar = tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}", leave=True, position=0)
    
        for step, (features, subjets, subjet_masks, particle_masks) in enumerate(train_loader):
            features = features.to(device)
            subjets = subjets.to(device)
            subjet_masks = subjet_masks.to(device)
            particle_masks = particle_masks.to(device)
            
            loss = train_step(model, subjets, subjet_masks, particle_masks, optimizer, device, step)
            total_loss += loss
            
            progress_bar.set_postfix(loss=loss)
            progress_bar.update(1)
            
            if step % 100 == 0:
                print(f"\nEpoch {epoch+1}, Step {step}, Loss: {loss:.4f}")
        
        progress_bar.close()

        avg_train_loss = total_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {avg_train_loss:.4f}")

    print("Training completed")
    print("Visualizing training loss")
    visualize_training_loss(train_losses)

    print("Saving model")
    torch.save(model.state_dict(), 'ijepa_model.pth')

    print("Model saved.")


Starting main program
Using device: cuda
Loading dataset
Initializing JetDataset with file: ../data/val/val_20_30.h5
Error loading dataset: [Errno 2] Unable to synchronously open file (unable to open file: name = '../data/val/val_20_30.h5', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)
Creating DataLoader


NameError: name 'train_dataset' is not defined