# In [3]:
from torch.utils.data import Dataset, DataLoader
import h5py

class JetDataset(Dataset):
    """Dataset of jets and their subjets loaded from an HDF5 file.

    The file is read once at construction time:

    * ``particles/features`` is converted to a float32 tensor, indexed per jet
      along its first dimension.
    * ``subjets`` holds one JSON document per jet. Each decoded document is
      iterated as a sequence of subjet dicts that carry a ``features`` mapping
      (keys ``pT``, ``eta``, ``phi``, ``num_ptcls``) and an ``indices``
      sequence.

    Jets with too few non-empty subjets are dropped, then the optional
    ``subset_size`` cut is applied to what remains.

    Args:
        file_path: Path of the HDF5 file to read.
        subset_size: If not ``None``, keep only the first ``subset_size`` jets
            that survived filtering.
        transform: Optional callable applied to the normalized jet features in
            ``__getitem__``.
        config: Passed through unchanged to ``normalize_features``.
        min_real_subjets: Minimum number of subjets with ``num_ptcls > 0`` a
            jet needs in order to be kept.
        verbose: When ``False``, the progress messages printed by this class
            are suppressed.
    """

    def __init__(self, file_path, subset_size=None, transform=None, config=None,
                 min_real_subjets=10, verbose=True):
        self.verbose = verbose
        self.transform = transform
        self.config = config
        self._log(f"Initializing JetDataset with file: {file_path}")

        with h5py.File(file_path, 'r') as hdf:
            self._log("Loading features and subjets from HDF5 file")
            self.features = torch.tensor(hdf["particles/features"][:], dtype=torch.float32)
            self.subjets = [json.loads(subjet) for subjet in hdf["subjets"][:]]

        self._log(f"Raw dataset size: {len(self.subjets)} jets")
        self._log(f"Feature shape: {self.features.shape}")

        self.filter_good_jets(min_real_subjets)

        # The subset is taken after filtering, so it counts good jets only.
        if subset_size is not None:
            self._log(f"Applying subset size: {subset_size}")
            self.features = self.features[:subset_size]
            self.subjets = self.subjets[:subset_size]

        self._log(f"Final dataset size: {len(self.subjets)} jets")

    def _log(self, message):
        """Print ``message`` unless the instance has ``verbose`` set to False."""
        # getattr keeps instances that were built without __init__ working.
        if getattr(self, "verbose", True):
            print(message)

    def filter_good_jets(self, min_real_subjets=10):
        """Keep only jets with at least ``min_real_subjets`` non-empty subjets.

        Updates ``self.subjets`` and ``self.features`` in place; both keep the
        original relative order of the surviving jets.

        Args:
            min_real_subjets: Minimum number of subjets with ``num_ptcls > 0``
                required for a jet to be retained.
        """
        self._log("Filtering good jets...")
        keep = [
            position
            for position, jet in enumerate(self.subjets)
            if self.get_num_real_subjets(jet) >= min_real_subjets
        ]

        self.subjets = [self.subjets[position] for position in keep]
        # index_select copes with an empty selection (yielding a tensor with
        # zero rows), whereas torch.stack on an empty list raises.
        self.features = self.features.index_select(
            0, torch.tensor(keep, dtype=torch.long)
        )
        self._log(f"Filtered to {len(self.subjets)} good jets")

    @staticmethod
    def get_num_real_subjets(jet):
        """Return how many subjets of ``jet`` have ``num_ptcls`` greater than zero."""
        return sum(1 for subjet in jet if subjet['features']['num_ptcls'] > 0)

    def __len__(self):
        """Return the number of jets currently held by the dataset."""
        return len(self.subjets)

    def __getitem__(self, idx):
        """Return the processed sample stored at position ``idx``.

        Returns:
            A tuple ``(features, subjets, subjet_mask, particle_mask)`` where
            ``features`` is the jet feature tensor after ``normalize_features``
            and the optional ``transform``, and the remaining three items are
            the outputs of ``process_subjets`` for this jet.
        """
        # The three debug helpers below are defined outside this class; they
        # are always invoked because their side effects are not visible here.
        check_raw_data(self.subjets, jet_index=idx)  # Debug statement

        self._log(f"\nFetching item {idx} from dataset")
        features = self.features[idx]
        subjets = self.subjets[idx]

        inspect_subjets(subjets)  # Debug statement

        subjets, subjet_mask, particle_mask = self.process_subjets(subjets)

        check_processed_data(subjets)  # Debug statement

        feature_names = ['pT', 'eta', 'phi']
        self._log("Normalizing features")
        features = normalize_features(features, feature_names, self.config, jet_type='Jets')

        if self.transform is not None:
            self._log("Applying transform to features")
            features = self.transform(features)

        return features, subjets, subjet_mask, particle_mask

    def process_subjets(self, subjets):
        """Convert one jet's subjet dicts into padded tensors and masks.

        Args:
            subjets: Sequence of subjet dicts, each with a ``features`` mapping
                (``pT``, ``eta``, ``phi``, ``num_ptcls``) and an ``indices``
                sequence.

        Returns:
            A tuple of float32 tensors ``(subjets, subjet_mask, particle_mask)``
            where ``max_len`` is the longest ``indices`` length in the jet:

            * ``subjets``: shape ``(num_subjets, 4, max_len)``; each of the four
              scalar features is repeated ``max_len`` times along the last axis.
            * ``subjet_mask``: shape ``(num_subjets,)``; 0 where ``num_ptcls``
              equals 0, otherwise 1.
            * ``particle_mask``: shape ``(num_subjets, max_len)``; 1 for
              positions below the subjet's own ``indices`` length, else 0.

            An empty ``subjets`` sequence yields empty tensors of the same rank.
        """
        self._log("Processing subjets")

        feature_keys = ('pT', 'eta', 'phi', 'num_ptcls')
        num_subjets = len(subjets)
        lengths = [len(subjet['indices']) for subjet in subjets]
        # default=0 keeps an empty jet from raising inside max().
        max_len = max(lengths, default=0)
        self._log(f"Max length of indices in subjets: {max_len}")

        # One row per subjet, one column per scalar feature. The reshape pins
        # the rank to 2 even when there are no subjets at all.
        scalars = torch.tensor(
            [[subjet['features'][key] for key in feature_keys] for subjet in subjets],
            dtype=torch.float32,
        ).reshape(num_subjets, len(feature_keys))
        # Repeat every scalar along a new trailing axis; contiguous() turns the
        # broadcast view into real storage, as stacking separate tensors would.
        subjet_tensor = (
            scalars.unsqueeze(-1)
            .expand(num_subjets, len(feature_keys), max_len)
            .contiguous()
        )

        subjet_mask = torch.tensor(
            [0 if subjet['features']['num_ptcls'] == 0 else 1 for subjet in subjets],
            dtype=torch.float32,
        )

        # A position is valid when it lies below that subjet's indices length.
        positions = torch.arange(max_len).unsqueeze(0)
        valid_lengths = torch.tensor(lengths, dtype=torch.long).unsqueeze(1)
        particle_mask = (positions < valid_lengths).to(torch.float32)

        self._log(f"Final processed subjets shape: {subjet_tensor.shape}")
        self._log(f"Final subjet mask shape: {subjet_mask.shape}")
        self._log(f"Final particle mask shape: {particle_mask.shape}")

        return subjet_tensor, subjet_mask, particle_mask
