In [None]:
# Optimizing Dataset Processing

Let's discuss strategies to optimize and parallelize the dataset processing:

1. **Multiprocessing** to parallelize file processing
2. **Memory-mapped files** to reduce memory usage
3. **Efficient serialization** to speed up I/O operations
4. **Chunking** to process data in parallel batches
5. **Alternative storage formats** for PyTorch Geometric Data objects

```python
import concurrent.futures
import os
from tqdm.auto import tqdm
import time

def process_single_graph(args):
    """Load one graph from an HDF5 file, filter/transform it, and save it.

    All inputs arrive packed in a single tuple so the function can be
    handed directly to ``Executor.map``.

    Args:
        args: Tuple ``(file_path, idx, output_dir, validate_data,
            pre_filter, pre_transform)``. ``pre_filter`` and
            ``pre_transform`` may each be ``None``.

    Returns:
        tuple: ``(idx, saved)``. ``saved`` is ``True`` when the graph was
        written to ``output_dir`` as ``data_{idx}.pt`` and ``False`` when
        ``pre_filter`` rejected it.
    """
    source_path, idx, target_dir, do_validate, keep_fn, transform_fn = args

    # The HDF5 handle is only needed while reading; it is closed before
    # filtering, transforming and saving.
    with h5py.File(source_path, "r") as handle:
        graph = load_graph(
            handle,
            idx,
            float_dtype=torch.float32,
            int_dtype=torch.int64,
            validate=do_validate,
        )

    keep = keep_fn is None or keep_fn(graph)
    if not keep:
        return idx, False  # Rejected by the filter; nothing is written.

    if transform_fn is not None:
        graph = transform_fn(graph)

    target = os.path.join(target_dir, f"data_{idx}.pt")
    torch.save(graph, target)
    return idx, True

class OptimizedCsDataset(CsDataset):
    """CsDataset variant that can pre-process its raw files with a process pool."""

    def process_parallel(self, max_workers=None, chunk_size=100):
        """Process the dataset in parallel using multiprocessing.

        Writes ``metadata.json`` into ``self.processed_dir`` and then, for
        every file in ``self.raw_paths``, runs ``process_single_graph`` for
        each graph index in a ``ProcessPoolExecutor``.

        Args:
            max_workers: Maximum number of worker processes. Defaults to
                the CPU count, capped at 16.
            chunk_size: Number of graphs sent to a worker in one batch
                (forwarded as ``chunksize`` to ``Executor.map``).

        Returns:
            None

        Raises:
            FileNotFoundError: If one of the raw input files is missing.
            ValueError: If ``chunk_size`` is smaller than 1.
        """
        # Convert NumPy arrays to Python lists for JSON serialization
        metadata = {
            "manifold_codes": [v.item() for v in self.manifold_codes],
            "manifolds": [str(m) for m in self.manifold_names],
            "boundaries": [str(m) for m in self.boundaries],
            "boundary_codes": [v.item() for v in self.boundary_codes],
        }

        with open(os.path.join(self.processed_dir, "metadata.json"), "w") as f:
            json.dump(metadata, f)

        if max_workers is None:
            # os.cpu_count() returns None when the count cannot be determined.
            max_workers = min(os.cpu_count() or 1, 16)

        print(f"Processing dataset using {max_workers} workers")

        start_time = time.time()
        total_processed = 0

        # NOTE(review): output files are named after the per-file graph
        # index, so with several raw files later files appear to overwrite
        # earlier `data_{idx}.pt` files -- confirm against CsDataset's
        # indexing scheme.
        for file_idx, file in enumerate(self.raw_paths):
            if not os.path.exists(file):
                raise FileNotFoundError(f"Input file {file} does not exist.")

            # Only the graph count is needed in the parent process; close
            # the file again before the workers open it themselves.
            with h5py.File(file, "r") as f:
                num_graphs = int(f["num_causal_sets"][()])
            print(f"Processing file {file_idx+1}/{len(self.raw_paths)}: {file} with {num_graphs} graphs")

            # pre_filter / pre_transform are shipped to the workers and
            # therefore have to be picklable.
            args_list = [
                (file, idx, self.processed_dir, self.validate_data, self.pre_filter, self.pre_transform)
                for idx in range(num_graphs)
            ]

            with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
                # `chunksize` makes the pool hand out batches of tasks,
                # which is what actually reduces the per-task IPC overhead.
                results = executor.map(process_single_graph, args_list, chunksize=chunk_size)
                for _, success in tqdm(results, total=num_graphs):
                    if success:
                        total_processed += 1

        elapsed = time.time() - start_time
        print(f"Processed {total_processed} graphs in {elapsed:.2f} seconds")
        # Avoid dividing by zero when every graph was filtered out or the
        # inputs were empty.
        if total_processed:
            print(f"Average processing time: {elapsed / total_processed:.4f} seconds per graph")
```


```python
import pickle
import torch.utils.data
import tempfile
import shutil
import numpy as np
from torch_sparse import SparseTensor

class MemoryEfficientStorage:
    """Hybrid on-disk storage for PyTorch Geometric ``Data`` objects.

    Each attribute is stored as a tagged entry in one pickled dictionary;
    tensors above a size threshold are written to separate ``.npy`` files
    that are opened with NumPy memory mapping when loading.
    """

    @staticmethod
    def save_data_efficient(data, file_path, large_tensor_threshold=1_000_000):
        """Save a PyTorch Geometric Data object.

        This method:
        1. Converts ``torch_sparse.SparseTensor`` attributes to COO parts
        2. Writes large tensors to side-car ``.npy`` files
        3. Pickles everything else into ``file_path`` (no compression)

        Args:
            data: Object yielding ``(key, value)`` pairs when iterated,
                e.g. a PyTorch Geometric ``Data`` object.
            file_path: Path of the main pickle file. Large tensors go to
                ``"{file_path}.{key}.npy"`` next to it.
            large_tensor_threshold: Tensors with more elements than this
                are stored as ``.npy`` files instead of inside the pickle.
        """
        serialized = {}

        for key, value in data:
            if torch.is_tensor(value):
                dense = value.detach().cpu()
                if dense.numel() > large_tensor_threshold:
                    np_path = f"{file_path}.{key}.npy"
                    np.save(np_path, dense.numpy())
                    serialized[key] = {
                        'type': 'numpy_memmap',
                        'path': np_path,
                        'shape': value.shape,
                        'dtype': str(value.dtype),
                    }
                else:
                    serialized[key] = {
                        'type': 'tensor',
                        'data': dense,
                    }
            elif isinstance(value, SparseTensor):
                row, col, values = value.coo()
                serialized[key] = {
                    'type': 'sparse_tensor',
                    'row': row.detach().cpu(),
                    'col': col.detach().cpu(),
                    'values': values.detach().cpu() if values is not None else None,
                    # SparseTensor.size() requires a `dim` argument;
                    # sparse_sizes() returns the (rows, cols) pair that the
                    # constructor's `sparse_sizes=` expects on load.
                    'size': value.sparse_sizes(),
                }
            else:
                # Any other attribute is pickled as-is.
                serialized[key] = {
                    'type': 'pickle',
                    'data': value,
                }

        with open(file_path, 'wb') as f:
            pickle.dump(serialized, f, protocol=pickle.HIGHEST_PROTOCOL)

    @staticmethod
    def load_data_efficient(file_path):
        """Load a PyTorch Geometric Data object written by ``save_data_efficient``.

        Args:
            file_path: Path to the saved data

        Returns:
            PyTorch Geometric Data object

        Raises:
            ValueError: If an entry carries an unknown ``type`` tag.
        """
        # NOTE: pickle can execute arbitrary code while loading; only use
        # this on files produced by a trusted process.
        with open(file_path, 'rb') as f:
            serialized = pickle.load(f)

        data = Data()

        for key, value in serialized.items():
            kind = value['type']
            if kind == 'numpy_memmap':
                np_path = value['path']
                if not os.path.exists(np_path):
                    # The recorded path goes stale when the dataset
                    # directory is moved; fall back to the side-car file
                    # sitting next to the pickle.
                    np_path = os.path.join(
                        os.path.dirname(file_path), os.path.basename(np_path)
                    )
                np_array = np.load(np_path, mmap_mode='r')
                # Copy out of the read-only map so the tensor owns writable
                # memory and does not keep the file mapped.
                setattr(data, key, torch.from_numpy(np.array(np_array)))
            elif kind == 'tensor':
                setattr(data, key, value['data'])
            elif kind == 'sparse_tensor':
                sparse = SparseTensor(
                    row=value['row'],
                    col=value['col'],
                    value=value['values'],
                    sparse_sizes=value['size'],
                )
                setattr(data, key, sparse)
            elif kind == 'pickle':
                setattr(data, key, value['data'])
            else:
                # Dropping the attribute silently would hand back an
                # incomplete object.
                raise ValueError(
                    f"Unknown entry type {kind!r} for attribute {key!r} in {file_path}"
                )

        return data

class EfficientCsDataset(OptimizedCsDataset):
    """Dataset that reads and writes its samples via MemoryEfficientStorage."""

    def process(self):
        """Convert every raw HDF5 file into per-graph files on disk.

        Writes ``metadata.json`` to ``self.processed_dir``, then loads each
        graph, applies ``pre_filter`` / ``pre_transform`` when set, and
        stores the result as ``data_{idx}.pt`` using
        ``MemoryEfficientStorage.save_data_efficient``.

        Raises:
            FileNotFoundError: If a raw input file does not exist.
        """
        # .item() / str() turn NumPy values into JSON-serializable ones.
        metadata = {
            "manifold_codes": [code.item() for code in self.manifold_codes],
            "manifolds": [str(name) for name in self.manifold_names],
            "boundaries": [str(name) for name in self.boundaries],
            "boundary_codes": [code.item() for code in self.boundary_codes],
        }
        metadata_path = os.path.join(self.processed_dir, "metadata.json")
        with open(metadata_path, "w") as meta_file:
            json.dump(metadata, meta_file)

        # NOTE(review): `idx` restarts for every raw file, so with several
        # files the `data_{idx}.pt` names appear to collide -- confirm.
        for raw_file in self.raw_paths:
            if not os.path.exists(raw_file):
                raise FileNotFoundError(f"Input file {raw_file} does not exist.")
            with h5py.File(raw_file, "r") as h5:
                print(f"Processing file: {raw_file}")
                graph_count = h5["num_causal_sets"][()]
                print(f"Number of causal sets: {graph_count}")
                for idx in tqdm(range(graph_count)):
                    graph = load_graph(
                        h5,
                        idx,
                        float_dtype=torch.float32,
                        int_dtype=torch.int64,
                        validate=self.validate_data,
                    )
                    rejected = self.pre_filter is not None and not self.pre_filter(graph)
                    if rejected:
                        continue
                    if self.pre_transform is not None:
                        graph = self.pre_transform(graph)

                    target = os.path.join(self.processed_dir, f"data_{idx}.pt")
                    MemoryEfficientStorage.save_data_efficient(graph, target)

    def get(self, idx):
        """Load sample ``idx`` from ``self.processed_dir`` and apply ``self.transform``.

        NOTE(review): if the base class also applies ``transform`` in
        ``__getitem__`` the transform would run twice -- confirm.
        """
        sample_path = os.path.join(self.processed_dir, f"data_{idx}.pt")
        sample = MemoryEfficientStorage.load_data_efficient(sample_path)
        return sample if self.transform is None else self.transform(sample)
```
# Storage Formats for PyTorch Geometric Data Objects

When handling large graph datasets, choosing the right storage format is crucial. Here are some options:

## 1. PyTorch's Native Format (.pt)
- **Pros**: Direct compatibility with PyTorch, fast loading in PyTorch
- **Cons**: Files can be large, not human-readable, versioning issues between PyTorch versions

## 2. Memory-Mapped Files (with .npy for large tensors)
- **Pros**: Reduced memory usage, allows working with datasets that don't fit in RAM
- **Cons**: Slightly slower access, more complex implementation

## 3. Compressed Formats
- **Pros**: Smaller file size, good for storage and transfer
- **Cons**: Additional compression/decompression overhead

## 4. HDF5 Format (.h5)
- **Pros**: Good for hierarchical data, partial loading, good compression
- **Cons**: More complex API, potential compatibility issues

## 5. LMDB (Lightning Memory-Mapped Database)
- **Pros**: Great for large datasets, fast random access, transactional
- **Cons**: More setup required, less intuitive API

Our `MemoryEfficientStorage` class implements a hybrid approach:
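- Small tensors: Serialized inline with Python's pickle format
- Large tensors (more than 1,000,000 elements): Written to separate `.npy` files and opened with NumPy memory mapping on load
- Metadata and other attributes: Handled via Python's pickle format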

This balances memory efficiency and performance for large graph datasets.
```python
# Example of using the optimized dataset classes

# Set up paths
datapath = os.path.join("/mnt", "dataLinux", "machinelearning_data", "QuantumGrav/causal_sets")
files = [
    os.path.join(datapath, "cset_data_min=300_max=650_N=25000_d=2.h5"),
    os.path.join(datapath, "cset_data_min=300_max=650_N=25000_d=3.h5"),
    os.path.join(datapath, "cset_data_min=300_max=650_N=25000_d=4.h5"),
]

# Define transforms
onehot_transform = OneHotEncodeTargets(manifold_classes=6, boundary_classes=3, dimension_classes=3)

# Create dataset with parallel processing
print("Creating dataset with parallel processing...")
parallel_dset = OptimizedCsDataset(
    input=files,
    output=os.path.join(datapath, "processed_parallel"),
    # A transform must be an *instance*: passing the class itself would
    # make `transform(data)` construct a transform object instead of
    # converting the sample.
    transform=T.ToSparseTensor(),
    pre_transform=onehot_transform,
    pre_filter=None,
    validate_data=True,
)

# This would process the dataset in parallel
# Uncomment to run:
# parallel_dset.process_parallel(max_workers=8, chunk_size=100)

# Create dataset with memory-efficient storage
print("Creating dataset with memory-efficient storage...")
efficient_dset = EfficientCsDataset(
    input=files,
    output=os.path.join(datapath, "processed_efficient"),
    transform=T.ToSparseTensor(),  # instance, not the class (see above)
    pre_transform=onehot_transform,
    pre_filter=None,
    validate_data=True,
)

# This would process the dataset with memory-efficient storage
# Uncomment to run:
# efficient_dset.process()

# Benchmark loading times
def benchmark_loading(dataset, num_samples=100):
    """Time how long it takes to load the first samples of ``dataset``.

    Loads ``dataset[0] .. dataset[k-1]`` where ``k`` is the smaller of
    ``num_samples`` and ``len(dataset)``, then prints the mean time per
    sample.

    Args:
        dataset: Object supporting ``len()`` and integer indexing.
        num_samples: Upper bound on the number of samples to load.

    Returns:
        None
    """
    # The average has to be taken over the samples actually loaded, which
    # is fewer than num_samples for small datasets.
    count = max(0, min(num_samples, len(dataset)))

    start_time = time.perf_counter()
    for i in range(count):
        _ = dataset[i]
    elapsed = time.perf_counter() - start_time

    if count == 0:
        print("No samples loaded; nothing to benchmark.")
        return
    print(f"Average loading time: {elapsed / count:.4f} seconds per graph")
```
# PyTorch Geometric-Specific Optimizations

Beyond general parallelization and efficient storage, there are some PyTorch Geometric-specific optimizations:

## 1. Pre-Computing Sparse Formats

Converting from dense to sparse formats is computationally expensive. Pre-computing and storing these representations can save time:

```python
# Pre-compute sparse formats during processing
# The two values returned by dense_to_sparse are stored on the sample so the
# conversion is not repeated at load time. `adjacency_matrix` is presumably
# the dense adjacency of the graph -- TODO confirm against the loader.
data.edge_index, data.edge_attr = dense_to_sparse(adjacency_matrix)
```

## 2. Using Pre-Normalized Adjacency Matrices

For GCN-based models, pre-computing the normalized adjacency matrix can speed up training:

```python
# During preprocessing:
from torch_geometric.nn.conv.gcn_conv import gcn_norm

edge_index, edge_weight = dense_to_sparse(adj_matrix)
# Current PyTorch Geometric exposes the normalization as the module-level
# `gcn_norm(edge_index, edge_weight, num_nodes)` rather than `GCNConv.norm`.
# Build the layer with `GCNConv(..., normalize=False)` afterwards so the
# normalization is not applied a second time during training.
edge_index, edge_weight = gcn_norm(edge_index, edge_weight, num_nodes)
```

## 3. Batching Strategy

For large graphs, consider using smaller batch sizes or specialized batching strategies.

## 4. In-Memory Dataset for Training

After preprocessing, consider using an in-memory dataset variant for training if your data fits in RAM:

```python
class InMemoryGraphDataset(InMemoryDataset):
    """Hold every sample of an existing dataset collated in memory."""

    def __init__(self, dataset):
        """Materialize ``dataset`` and collate it into ``self.data`` / ``self.slices``.

        Args:
            dataset: Object supporting ``len()`` and integer indexing that
                yields the samples to keep in memory.
        """
        # The base-class initializer sets up the state that indexing and
        # len() rely on; without it the instance is only half-constructed.
        super().__init__()
        self.data_list = [dataset[i] for i in range(len(dataset))]
        self.data, self.slices = self.collate(self.data_list)
```

## 5. Using `torch_sparse` and `torch_scatter`

These libraries provide optimized operations for sparse data structures and are used by PyTorch Geometric internally.

# Optimizing Dataset Processing

Let's discuss strategies to optimize and parallelize the dataset processing:

1. **Multiprocessing** to parallelize file processing
2. **Memory-mapped files** to reduce memory usage
3. **Efficient serialization** to speed up I/O operations
4. **Chunking** to process data in parallel batches
5. **Alternative storage formats** for PyTorch Geometric Data objects

```python
import concurrent.futures
import os
from tqdm.auto import tqdm
import time

def process_single_graph(args):
    """Load one graph from an HDF5 file, filter/transform it, and save it.

    All inputs arrive packed in a single tuple so the function can be
    handed directly to ``Executor.map``.

    Args:
        args: Tuple ``(file_path, idx, output_dir, validate_data,
            pre_filter, pre_transform)``. ``pre_filter`` and
            ``pre_transform`` may each be ``None``.

    Returns:
        tuple: ``(idx, saved)``. ``saved`` is ``True`` when the graph was
        written to ``output_dir`` as ``data_{idx}.pt`` and ``False`` when
        ``pre_filter`` rejected it.
    """
    source_path, idx, target_dir, do_validate, keep_fn, transform_fn = args

    # The HDF5 handle is only needed while reading; it is closed before
    # filtering, transforming and saving.
    with h5py.File(source_path, "r") as handle:
        graph = load_graph(
            handle,
            idx,
            float_dtype=torch.float32,
            int_dtype=torch.int64,
            validate=do_validate,
        )

    keep = keep_fn is None or keep_fn(graph)
    if not keep:
        return idx, False  # Rejected by the filter; nothing is written.

    if transform_fn is not None:
        graph = transform_fn(graph)

    target = os.path.join(target_dir, f"data_{idx}.pt")
    torch.save(graph, target)
    return idx, True

class OptimizedCsDataset(CsDataset):
    """CsDataset variant that can pre-process its raw files with a process pool."""

    def process_parallel(self, max_workers=None, chunk_size=100):
        """Process the dataset in parallel using multiprocessing.

        Writes ``metadata.json`` into ``self.processed_dir`` and then, for
        every file in ``self.raw_paths``, runs ``process_single_graph`` for
        each graph index in a ``ProcessPoolExecutor``.

        Args:
            max_workers: Maximum number of worker processes. Defaults to
                the CPU count, capped at 16.
            chunk_size: Number of graphs sent to a worker in one batch
                (forwarded as ``chunksize`` to ``Executor.map``).

        Returns:
            None

        Raises:
            FileNotFoundError: If one of the raw input files is missing.
            ValueError: If ``chunk_size`` is smaller than 1.
        """
        # Convert NumPy arrays to Python lists for JSON serialization
        metadata = {
            "manifold_codes": [v.item() for v in self.manifold_codes],
            "manifolds": [str(m) for m in self.manifold_names],
            "boundaries": [str(m) for m in self.boundaries],
            "boundary_codes": [v.item() for v in self.boundary_codes],
        }

        with open(os.path.join(self.processed_dir, "metadata.json"), "w") as f:
            json.dump(metadata, f)

        if max_workers is None:
            # os.cpu_count() returns None when the count cannot be determined.
            max_workers = min(os.cpu_count() or 1, 16)

        print(f"Processing dataset using {max_workers} workers")

        start_time = time.time()
        total_processed = 0

        # NOTE(review): output files are named after the per-file graph
        # index, so with several raw files later files appear to overwrite
        # earlier `data_{idx}.pt` files -- confirm against CsDataset's
        # indexing scheme.
        for file_idx, file in enumerate(self.raw_paths):
            if not os.path.exists(file):
                raise FileNotFoundError(f"Input file {file} does not exist.")

            # Only the graph count is needed in the parent process; close
            # the file again before the workers open it themselves.
            with h5py.File(file, "r") as f:
                num_graphs = int(f["num_causal_sets"][()])
            print(f"Processing file {file_idx+1}/{len(self.raw_paths)}: {file} with {num_graphs} graphs")

            # pre_filter / pre_transform are shipped to the workers and
            # therefore have to be picklable.
            args_list = [
                (file, idx, self.processed_dir, self.validate_data, self.pre_filter, self.pre_transform)
                for idx in range(num_graphs)
            ]

            with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
                # `chunksize` makes the pool hand out batches of tasks,
                # which is what actually reduces the per-task IPC overhead.
                results = executor.map(process_single_graph, args_list, chunksize=chunk_size)
                for _, success in tqdm(results, total=num_graphs):
                    if success:
                        total_processed += 1

        elapsed = time.time() - start_time
        print(f"Processed {total_processed} graphs in {elapsed:.2f} seconds")
        # Avoid dividing by zero when every graph was filtered out or the
        # inputs were empty.
        if total_processed:
            print(f"Average processing time: {elapsed / total_processed:.4f} seconds per graph")
```


```python
import pickle
import torch.utils.data
import tempfile
import shutil
import numpy as np
from torch_sparse import SparseTensor

class MemoryEfficientStorage:
    """Hybrid on-disk storage for PyTorch Geometric ``Data`` objects.

    Each attribute is stored as a tagged entry in one pickled dictionary;
    tensors above a size threshold are written to separate ``.npy`` files
    that are opened with NumPy memory mapping when loading.
    """

    @staticmethod
    def save_data_efficient(data, file_path, large_tensor_threshold=1_000_000):
        """Save a PyTorch Geometric Data object.

        This method:
        1. Converts ``torch_sparse.SparseTensor`` attributes to COO parts
        2. Writes large tensors to side-car ``.npy`` files
        3. Pickles everything else into ``file_path`` (no compression)

        Args:
            data: Object yielding ``(key, value)`` pairs when iterated,
                e.g. a PyTorch Geometric ``Data`` object.
            file_path: Path of the main pickle file. Large tensors go to
                ``"{file_path}.{key}.npy"`` next to it.
            large_tensor_threshold: Tensors with more elements than this
                are stored as ``.npy`` files instead of inside the pickle.
        """
        serialized = {}

        for key, value in data:
            if torch.is_tensor(value):
                dense = value.detach().cpu()
                if dense.numel() > large_tensor_threshold:
                    np_path = f"{file_path}.{key}.npy"
                    np.save(np_path, dense.numpy())
                    serialized[key] = {
                        'type': 'numpy_memmap',
                        'path': np_path,
                        'shape': value.shape,
                        'dtype': str(value.dtype),
                    }
                else:
                    serialized[key] = {
                        'type': 'tensor',
                        'data': dense,
                    }
            elif isinstance(value, SparseTensor):
                row, col, values = value.coo()
                serialized[key] = {
                    'type': 'sparse_tensor',
                    'row': row.detach().cpu(),
                    'col': col.detach().cpu(),
                    'values': values.detach().cpu() if values is not None else None,
                    # SparseTensor.size() requires a `dim` argument;
                    # sparse_sizes() returns the (rows, cols) pair that the
                    # constructor's `sparse_sizes=` expects on load.
                    'size': value.sparse_sizes(),
                }
            else:
                # Any other attribute is pickled as-is.
                serialized[key] = {
                    'type': 'pickle',
                    'data': value,
                }

        with open(file_path, 'wb') as f:
            pickle.dump(serialized, f, protocol=pickle.HIGHEST_PROTOCOL)

    @staticmethod
    def load_data_efficient(file_path):
        """Load a PyTorch Geometric Data object written by ``save_data_efficient``.

        Args:
            file_path: Path to the saved data

        Returns:
            PyTorch Geometric Data object

        Raises:
            ValueError: If an entry carries an unknown ``type`` tag.
        """
        # NOTE: pickle can execute arbitrary code while loading; only use
        # this on files produced by a trusted process.
        with open(file_path, 'rb') as f:
            serialized = pickle.load(f)

        data = Data()

        for key, value in serialized.items():
            kind = value['type']
            if kind == 'numpy_memmap':
                np_path = value['path']
                if not os.path.exists(np_path):
                    # The recorded path goes stale when the dataset
                    # directory is moved; fall back to the side-car file
                    # sitting next to the pickle.
                    np_path = os.path.join(
                        os.path.dirname(file_path), os.path.basename(np_path)
                    )
                np_array = np.load(np_path, mmap_mode='r')
                # Copy out of the read-only map so the tensor owns writable
                # memory and does not keep the file mapped.
                setattr(data, key, torch.from_numpy(np.array(np_array)))
            elif kind == 'tensor':
                setattr(data, key, value['data'])
            elif kind == 'sparse_tensor':
                sparse = SparseTensor(
                    row=value['row'],
                    col=value['col'],
                    value=value['values'],
                    sparse_sizes=value['size'],
                )
                setattr(data, key, sparse)
            elif kind == 'pickle':
                setattr(data, key, value['data'])
            else:
                # Dropping the attribute silently would hand back an
                # incomplete object.
                raise ValueError(
                    f"Unknown entry type {kind!r} for attribute {key!r} in {file_path}"
                )

        return data

class EfficientCsDataset(OptimizedCsDataset):
    """Dataset that reads and writes its samples via MemoryEfficientStorage."""

    def process(self):
        """Convert every raw HDF5 file into per-graph files on disk.

        Writes ``metadata.json`` to ``self.processed_dir``, then loads each
        graph, applies ``pre_filter`` / ``pre_transform`` when set, and
        stores the result as ``data_{idx}.pt`` using
        ``MemoryEfficientStorage.save_data_efficient``.

        Raises:
            FileNotFoundError: If a raw input file does not exist.
        """
        # .item() / str() turn NumPy values into JSON-serializable ones.
        metadata = {
            "manifold_codes": [code.item() for code in self.manifold_codes],
            "manifolds": [str(name) for name in self.manifold_names],
            "boundaries": [str(name) for name in self.boundaries],
            "boundary_codes": [code.item() for code in self.boundary_codes],
        }
        metadata_path = os.path.join(self.processed_dir, "metadata.json")
        with open(metadata_path, "w") as meta_file:
            json.dump(metadata, meta_file)

        # NOTE(review): `idx` restarts for every raw file, so with several
        # files the `data_{idx}.pt` names appear to collide -- confirm.
        for raw_file in self.raw_paths:
            if not os.path.exists(raw_file):
                raise FileNotFoundError(f"Input file {raw_file} does not exist.")
            with h5py.File(raw_file, "r") as h5:
                print(f"Processing file: {raw_file}")
                graph_count = h5["num_causal_sets"][()]
                print(f"Number of causal sets: {graph_count}")
                for idx in tqdm(range(graph_count)):
                    graph = load_graph(
                        h5,
                        idx,
                        float_dtype=torch.float32,
                        int_dtype=torch.int64,
                        validate=self.validate_data,
                    )
                    rejected = self.pre_filter is not None and not self.pre_filter(graph)
                    if rejected:
                        continue
                    if self.pre_transform is not None:
                        graph = self.pre_transform(graph)

                    target = os.path.join(self.processed_dir, f"data_{idx}.pt")
                    MemoryEfficientStorage.save_data_efficient(graph, target)

    def get(self, idx):
        """Load sample ``idx`` from ``self.processed_dir`` and apply ``self.transform``.

        NOTE(review): if the base class also applies ``transform`` in
        ``__getitem__`` the transform would run twice -- confirm.
        """
        sample_path = os.path.join(self.processed_dir, f"data_{idx}.pt")
        sample = MemoryEfficientStorage.load_data_efficient(sample_path)
        return sample if self.transform is None else self.transform(sample)
```
# Storage Formats for PyTorch Geometric Data Objects

When handling large graph datasets, choosing the right storage format is crucial. Here are some options:

## 1. PyTorch's Native Format (.pt)
- **Pros**: Direct compatibility with PyTorch, fast loading in PyTorch
- **Cons**: Files can be large, not human-readable, versioning issues between PyTorch versions

## 2. Memory-Mapped Files (with .npy for large tensors)
- **Pros**: Reduced memory usage, allows working with datasets that don't fit in RAM
- **Cons**: Slightly slower access, more complex implementation

## 3. Compressed Formats
- **Pros**: Smaller file size, good for storage and transfer
- **Cons**: Additional compression/decompression overhead

## 4. HDF5 Format (.h5)
- **Pros**: Good for hierarchical data, partial loading, good compression
- **Cons**: More complex API, potential compatibility issues

## 5. LMDB (Lightning Memory-Mapped Database)
- **Pros**: Great for large datasets, fast random access, transactional
- **Cons**: More setup required, less intuitive API

Our `MemoryEfficientStorage` class implements a hybrid approach:
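- Small tensors: Serialized inline with Python's pickle format
- Large tensors (more than 1,000,000 elements): Written to separate `.npy` files and opened with NumPy memory mapping on load
- Metadata and other attributes: Handled via Python's pickle format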

This balances memory efficiency and performance for large graph datasets.
```python
# Example of using the optimized dataset classes

# Set up paths
datapath = os.path.join("/mnt", "dataLinux", "machinelearning_data", "QuantumGrav/causal_sets")
files = [
    os.path.join(datapath, "cset_data_min=300_max=650_N=25000_d=2.h5"),
    os.path.join(datapath, "cset_data_min=300_max=650_N=25000_d=3.h5"),
    os.path.join(datapath, "cset_data_min=300_max=650_N=25000_d=4.h5"),
]

# Define transforms
onehot_transform = OneHotEncodeTargets(manifold_classes=6, boundary_classes=3, dimension_classes=3)

# Create dataset with parallel processing
print("Creating dataset with parallel processing...")
parallel_dset = OptimizedCsDataset(
    input=files,
    output=os.path.join(datapath, "processed_parallel"),
    # A transform must be an *instance*: passing the class itself would
    # make `transform(data)` construct a transform object instead of
    # converting the sample.
    transform=T.ToSparseTensor(),
    pre_transform=onehot_transform,
    pre_filter=None,
    validate_data=True,
)

# This would process the dataset in parallel
# Uncomment to run:
# parallel_dset.process_parallel(max_workers=8, chunk_size=100)

# Create dataset with memory-efficient storage
print("Creating dataset with memory-efficient storage...")
efficient_dset = EfficientCsDataset(
    input=files,
    output=os.path.join(datapath, "processed_efficient"),
    transform=T.ToSparseTensor(),  # instance, not the class (see above)
    pre_transform=onehot_transform,
    pre_filter=None,
    validate_data=True,
)

# This would process the dataset with memory-efficient storage
# Uncomment to run:
# efficient_dset.process()

# Benchmark loading times
def benchmark_loading(dataset, num_samples=100):
    """Time how long it takes to load the first samples of ``dataset``.

    Loads ``dataset[0] .. dataset[k-1]`` where ``k`` is the smaller of
    ``num_samples`` and ``len(dataset)``, then prints the mean time per
    sample.

    Args:
        dataset: Object supporting ``len()`` and integer indexing.
        num_samples: Upper bound on the number of samples to load.

    Returns:
        None
    """
    # The average has to be taken over the samples actually loaded, which
    # is fewer than num_samples for small datasets.
    count = max(0, min(num_samples, len(dataset)))

    start_time = time.perf_counter()
    for i in range(count):
        _ = dataset[i]
    elapsed = time.perf_counter() - start_time

    if count == 0:
        print("No samples loaded; nothing to benchmark.")
        return
    print(f"Average loading time: {elapsed / count:.4f} seconds per graph")
```
# PyTorch Geometric-Specific Optimizations

Beyond general parallelization and efficient storage, there are some PyTorch Geometric-specific optimizations:

## 1. Pre-Computing Sparse Formats

Converting from dense to sparse formats is computationally expensive. Pre-computing and storing these representations can save time:

```python
# Pre-compute sparse formats during processing
# The two values returned by dense_to_sparse are stored on the sample so the
# conversion is not repeated at load time. `adjacency_matrix` is presumably
# the dense adjacency of the graph -- TODO confirm against the loader.
data.edge_index, data.edge_attr = dense_to_sparse(adjacency_matrix)
```

## 2. Using Pre-Normalized Adjacency Matrices

For GCN-based models, pre-computing the normalized adjacency matrix can speed up training:

```python
# During preprocessing:
from torch_geometric.nn.conv.gcn_conv import gcn_norm

edge_index, edge_weight = dense_to_sparse(adj_matrix)
# Current PyTorch Geometric exposes the normalization as the module-level
# `gcn_norm(edge_index, edge_weight, num_nodes)` rather than `GCNConv.norm`.
# Build the layer with `GCNConv(..., normalize=False)` afterwards so the
# normalization is not applied a second time during training.
edge_index, edge_weight = gcn_norm(edge_index, edge_weight, num_nodes)
```

## 3. Batching Strategy

For large graphs, consider using smaller batch sizes or specialized batching strategies.

## 4. In-Memory Dataset for Training

After preprocessing, consider using an in-memory dataset variant for training if your data fits in RAM:

```python
class InMemoryGraphDataset(InMemoryDataset):
    """Hold every sample of an existing dataset collated in memory."""

    def __init__(self, dataset):
        """Materialize ``dataset`` and collate it into ``self.data`` / ``self.slices``.

        Args:
            dataset: Object supporting ``len()`` and integer indexing that
                yields the samples to keep in memory.
        """
        # The base-class initializer sets up the state that indexing and
        # len() rely on; without it the instance is only half-constructed.
        super().__init__()
        self.data_list = [dataset[i] for i in range(len(dataset))]
        self.data, self.slices = self.collate(self.data_list)
```

## 5. Using `torch_sparse` and `torch_scatter`

These libraries provide optimized operations for sparse data structures and are used by PyTorch Geometric internally.