In [None]:
# IPython shell escape: list the contents of this ndjson dataset directory.
!ls /eos/user/r/rdfexp/ecal/cluster/output_deepcluster_dumper/windows_data/electrons/run3_126X_2023/ndjson_126X_mcRun3_2023_forPU65_byevent_v1

In [None]:
import glob 
import gzip 
import json
import tarfile
import numpy as np
import matplotlib.pyplot as plt
import math
import os
from glob import glob
import numpy as np
import uproot
import torch
import torch_geometric
from torch_geometric.data import InMemoryDataset,  Dataset
from torch_geometric.transforms import NormalizeFeatures
from torch.utils.data import IterableDataset
from torch_geometric.data import Data
from multiprocessing import Pool

#from torch.utils.data import Dataset
import torch.nn.functional as F
import awkward as ak
from tqdm import tqdm

#from torch_geometric.data import InMemoryDataset
import os
import glob
import torch

In [None]:
import torch
from torch_geometric.data import Data

def convert_to_tensor(features_dict):
    """Pack a dictionary of per-feature lists into a single float tensor.

    Every value of ``features_dict`` is taken, in the dictionary's iteration
    order, as one row of an intermediate tensor which is then transposed, so
    each dictionary entry ends up as one column of the result.

    Args:
        features_dict: Mapping from feature name to a list of values.

    Returns:
        A ``torch.float`` tensor: the transpose of the stacked value lists.
    """
    per_feature_rows = list(features_dict.values())
    stacked = torch.tensor(per_feature_rows, dtype=torch.float)
    return stacked.T

def create_data_object(nodes_features, nodes_sim_features, edges_idx, edges_labels):
    """Build a PyG ``Data`` object from the dictionaries/lists of one graph.

    Args:
        nodes_features: Mapping feature name -> per-node list; becomes ``x``
            with one column per feature.
        nodes_sim_features: Mapping feature name -> per-node list; becomes
            ``y`` with one column per feature.
        edges_idx: List of ``[source, target]`` index pairs, one per edge.
            May be empty.
        edges_labels: Mapping label name -> per-edge list; becomes
            ``edge_attr`` with one column per label.

    Returns:
        ``Data(x, edge_index, edge_attr, y)`` where ``edge_index`` always has
        shape ``(2, num_edges)``.
    """
    x = convert_to_tensor(nodes_features)
    y = convert_to_tensor(nodes_sim_features)

    # reshape(-1, 2) is a no-op for a non-empty list of pairs, but guarantees
    # a (0, 2) tensor -- hence a (2, 0) edge_index -- for a graph without
    # edges; a bare torch.tensor([]) would be 1-D and break column slicing.
    edge_index = torch.tensor(edges_idx, dtype=torch.long).reshape(-1, 2).T

    # Same "dict of lists -> (n, n_keys)" layout as the node features.
    edge_attr = convert_to_tensor(edges_labels)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)

In [None]:
class ECALGraphDataset(IterableDataset):
    """Iterable dataset that streams graphs from grouped ``.pt`` files.

    Layout under ``root``:

    * ``raw/``       -- input ``*.tar.gz`` archives consumed by ``preprocess``.
    * ``processed/`` -- collated groups of graphs written by ``preprocess``
      as ``graph_data_group_<raw index>_<chunk index>.pt`` and read back
      lazily during iteration.

    Args:
        root: Base directory containing ``raw/`` and ``processed/``.
        transform: Optional callable applied to every graph when it is
            yielded during iteration.
        pre_transform: Optional callable applied to every graph before it is
            written to disk.
        pre_filter: Optional predicate; graphs for which it is falsy are not
            written to disk.
        graphs_in_file: Maximum number of graphs stored in one ``.pt`` file.
    """

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None, graphs_in_file=512):
        super().__init__()
        self.root = root
        self.transform = transform
        self.pre_transform = pre_transform
        self.pre_filter = pre_filter
        self.raw_dir = os.path.join(self.root, "raw")
        self.processed_dir = os.path.join(self.root, "processed")
        self.graphs_in_file = graphs_in_file
        self._refresh_file_range()

    def _refresh_file_range(self):
        """Recompute the ``[_start, _end)`` file range from the files on disk."""
        self._nfiles = len(self.processed_files)
        self._start = 0
        self._end = self._nfiles

    @property
    def raw_paths(self):
        """Sorted full paths of the ``*.tar.gz`` archives under ``<root>/raw``."""
        # Sorted so that the group index assigned in ``preprocess`` is
        # reproducible from one run to the next.
        return sorted(glob.glob(os.path.join(self.root, "raw", "*.tar.gz")))

    @property
    def processed_files(self):
        """Sorted names of the ``.pt`` files under ``<root>/processed``.

        Returns an empty list when the directory does not exist yet, so that
        a dataset can be constructed before ``preprocess`` has ever run.
        """
        processed_dir = os.path.join(self.root, "processed")
        try:
            entries = os.listdir(processed_dir)
        except FileNotFoundError:
            return []
        return sorted(f for f in entries if f.endswith('.pt'))

    def _save_group(self, graphs, group_idx, file_idx, igraph):
        """Collate ``graphs`` and write them as one processed ``.pt`` file."""
        print(f"Saving file, {igraph}, {len(graphs)}")
        group_data, slices = torch_geometric.data.InMemoryDataset.collate(graphs)
        save_path = os.path.join(
            self.processed_dir,
            f'graph_data_group_{group_idx}_{file_idx}.pt')
        torch.save((group_data, slices), save_path)

    def _preprocess(self, args):
        """Convert one raw archive into one or more processed files.

        Args:
            args: ``(file, group_idx)`` tuple -- path of the raw archive and
                the index used in the output file names.

        ``read_file`` is expected to be defined elsewhere and to return an
        iterable of graph dictionaries with the keys ``nodes_features``,
        ``nodes_sim_features``, ``edges_idx`` and ``edges_labels``.
        """
        file, group_idx = args
        dataset = read_file(file)
        os.makedirs(self.processed_dir, exist_ok=True)

        graphs_in_group = []
        ifiles = 0
        igraph = -1  # stays -1 when the archive contains no graph at all
        for igraph, graph in enumerate(dataset):
            # Create PyG data object
            data = create_data_object(
                nodes_features=graph['nodes_features'],
                nodes_sim_features=graph['nodes_sim_features'],
                edges_idx=graph['edges_idx'],
                edges_labels=graph['edges_labels']
            )

            # Apply pre-filter if specified
            if self.pre_filter and not self.pre_filter(data):
                continue

            # Apply pre-transform if specified
            if self.pre_transform:
                data = self.pre_transform(data)

            graphs_in_group.append(data)

            if len(graphs_in_group) >= self.graphs_in_file:
                self._save_group(graphs_in_group, group_idx, ifiles, igraph)
                graphs_in_group = []
                ifiles += 1

        # Flush the trailing partial group after the loop rather than on
        # "igraph == last index": that check is skipped (and the group lost)
        # whenever the final graph is rejected by ``pre_filter``.
        if graphs_in_group:
            self._save_group(graphs_in_group, group_idx, ifiles, igraph)

    def preprocess(self, num_workers=4):
        """Process every raw archive into PyTorch Geometric ``.pt`` groups.

        Args:
            num_workers: Accepted for interface compatibility; the archives
                are currently processed sequentially in this process.
        """
        os.makedirs(self.processed_dir, exist_ok=True)
        for ifile, file in enumerate(self.raw_paths):
            self._preprocess((file, ifile))
        # New files were just written: make them visible to ``__iter__``.
        self._refresh_file_range()

    def _load_group(self, file_path):
        """Load one processed file and yield its graphs one by one.

        Args:
            file_path: File name relative to ``<root>/processed``.

        Yields:
            One ``Data`` object per stored graph, with ``transform`` applied
            when it is set.
        """
        full_path = os.path.join(self.root, "processed", file_path)
        group_data, slices = torch.load(full_path)

        num_graphs = len(slices['x']) - 1
        for i in range(num_graphs):
            data = Data()
            for key in group_data.keys():
                if key in slices:
                    start, end = slices[key][i], slices[key][i + 1]
                    if key == 'edge_index':
                        # edge_index is (2, num_edges): slice along columns.
                        data[key] = group_data[key][:, start:end]
                    else:
                        data[key] = group_data[key][start:end]
                else:
                    # No slice information: share the attribute as-is.
                    data[key] = group_data[key]
            if self.transform:
                data = self.transform(data)
            yield data

    def __iter__(self):
        """Yield every graph of the processed files assigned to this worker.

        In single-process loading the full ``[_start, _end)`` range of files
        is read; inside a ``DataLoader`` worker the range is split into
        contiguous chunks of ``ceil(n_files / num_workers)`` files.
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading
            iter_start = self._start
            iter_end = self._end
        else:  # in a worker process: split the workload
            per_worker = int(math.ceil((self._end - self._start) / float(worker_info.num_workers)))
            iter_start = self._start + worker_info.id * per_worker
            iter_end = min(iter_start + per_worker, self._end)

        # List the directory once per pass instead of once per file.
        files = self.processed_files
        for file_name in files[iter_start:iter_end]:
            yield from self._load_group(file_name)


In [None]:
# Dataset over the photon ("gammas") sample; each processed file may hold up
# to 51200 graphs.
d = ECALGraphDataset(
    root="/eos/user/r/rdfexp/ecal/cluster/output_deepcluster_dumper/windows_data/gammas/run3_126X_2023_overlapTraining_double/ndjson_126X_mcRun3_2023_forPU65_byevent_v1/",
    graphs_in_file=51200,
)


In [None]:
# Import here so the cell does not depend on a later cell having been run:
# the graph-aware DataLoader is otherwise only imported further down.
from torch_geometric.loader import DataLoader

# Batches of 512 graphs streamed from `d` by a single worker process.
dl = DataLoader(
    d,
    batch_size=512,
    num_workers=1,
    pin_memory=True,  # Faster data transfer to GPU
)

In [None]:
# Smoke test: print the batches with index 0..100, then stop.
for i, df in enumerate(dl):
    print(i, df)
    if i >= 100:
        break

In [None]:
from torch_geometric.loader import DataLoader
import torch

def create_batch_loader(dataset, batch_size=32, num_workers=0, shuffle=False):
    """
    Creates a DataLoader instance for batching graphs.

    Args:
        dataset: Dataset of graphs to batch
        batch_size: Number of graphs per batch
        num_workers: Number of worker processes for loading data
        shuffle: Whether to shuffle the dataset. Must stay False for an
            iterable-style dataset: the underlying torch DataLoader rejects
            shuffling for those.

    Returns:
        DataLoader instance configured for graph batching
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
    )

def create_dynamic_batch_loader(dataset, max_nodes=10000, mode='node',
                              num_workers=0, shuffle=False):
    """
    Creates a DataLoader instance with dynamic batching based on node count.

    Args:
        dataset: Dataset of graphs to batch
        max_nodes: Maximum size of a batch, measured according to ``mode``
        mode: Either 'node' or 'edge' for size measurement
        num_workers: Number of worker processes for loading data
        shuffle: Whether the sampler visits the dataset in random order

    Returns:
        DataLoader instance configured for dynamic batching
    """
    from torch_geometric.loader import DynamicBatchSampler

    # NOTE(review): batch samplers presumably need an indexable (map-style)
    # dataset -- confirm before using this with an iterable-style dataset.
    sampler = DynamicBatchSampler(
        dataset,
        max_num=max_nodes,
        mode=mode,
        shuffle=shuffle,
    )

    return DataLoader(
        dataset,
        batch_sampler=sampler,
        num_workers=num_workers,
        pin_memory=True
    )

In [None]:
# Build a loader over `d` with the helper's default settings.
df = create_batch_loader(dataset=d)

In [None]:
# Pull one item from `df`; as the last expression of the cell its value is
# what the notebook displays.
next(iter(df))

In [None]:
# Walk the items with index 0..1000, printing one dot per item.
# NOTE: the loop variable rebinds `df`, so after this cell `df` refers to the
# last item that was read, no longer to the object being iterated.
for i, df in enumerate(df):
    print(".", end="")
    if i >= 1000:
        break

In [None]:
import networkx as nx
import matplotlib.pyplot as plt

def draw_graph(data):
    """Draw a graph with networkx, placing and sizing nodes from ``data.x``.

    Columns 4 and 5 of ``data.x`` give each node's (x, y) drawing position
    and column 0 gives its marker size; the original code names these
    ``ieta``, ``iphi`` and ``en``. Edges are the columns of
    ``data.edge_index``.

    Args:
        data: Object exposing a 2-D ``x`` tensor (at least 6 columns) and an
            ``edge_index`` tensor whose transpose iterates over
            ``(source, target)`` pairs.
    """
    graph = nx.Graph()

    # One node per row of x, carrying its position and its size.
    for node_id, row in enumerate(data.x):
        graph.add_node(
            node_id,
            pos=(row[4].item(), row[5].item()),
            size=row[0].item(),
        )

    # One edge per column of edge_index.
    for source, target in data.edge_index.T.tolist():
        graph.add_edge(source, target)

    # Collect the drawing attributes back from the graph.
    positions = nx.get_node_attributes(graph, 'pos')
    node_sizes = [attrs['size'] for _, attrs in graph.nodes(data=True)]

    plt.figure(figsize=(10, 10))
    nx.draw(
        graph,
        positions,
        with_labels=True,
        node_size=node_sizes,
        node_color='skyblue',
        edge_color='gray',
        font_size=8,
    )
    plt.show()


In [None]:
# NOTE(review): `df` is whatever the previously executed cells left bound to
# that name -- presumably a batch of graphs rather than a loader; confirm.
draw_graph(data=df)

In [None]:
# Convert every entry of `data_json` to a graph object and draw it.
# `data_json` is not defined in this part of the notebook; it is indexed by
# position here, exactly as before.
for i in range(len(data_json)):
    graph_data = data_json[i]
    g = create_data_object(
        nodes_features=graph_data["nodes_features"],
        nodes_sim_features=graph_data["nodes_sim_features"],
        edges_idx=graph_data["edges_idx"],
        edges_labels=graph_data["edges_labels"],
    )
    draw_graph(data=g)