# Implementation of the Mini-Batch Method
This notebook uses PyTorch Geometric's `DataLoader` to split a graph dataset into mini-batches.

## Installation of PyTorch Geometric
Install PyTorch Geometric, a library for graph neural networks (GNNs), along with SciPy.

In [None]:
!pip install torch-geometric
!pip install scipy

## Loading the Dataset
We load MUTAG, one of the graph classification datasets in the widely used TUDataset benchmark collection.
MUTAG contains 188 graphs, each labeled with one of two classes.
The following code downloads the dataset to `/tmp/MUTAG` on first use and loads it.

In [None]:
from torch_geometric.datasets import TUDataset

dataset = TUDataset(root="/tmp/MUTAG", name="MUTAG")

print("Number of graphs:", len(dataset))
print("Number of classes:", dataset.num_classes)

data = dataset[0]  # first graph
print(data)


Define a function that prints information about a single graph.

In [3]:
def graph_info(data):
    """Print basic statistics and the raw tensors of a single graph."""
    print("Number of nodes:", data.num_nodes)
    print("Number of edges:", data.num_edges)
    print("Number of features per node:", data.num_node_features)
    print("Is the graph undirected?", data.is_undirected())
    print("Does it have isolated nodes?", data.has_isolated_nodes())
    print("Does it have self-loops?", data.has_self_loops())

    print()

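    # keys() is a method in recent PyTorch Geometric releases
    # (older releases expose data.keys as a property instead).
    print("Keys:", data.keys())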
    print("Features for each node:")
    print(data["x"])
    print("Label of the graph:")
    print(data["y"])
    print("Edges:")
    print(data["edge_index"])


Using the function, display information about the first graph.

In [None]:
graph_info(data)
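
To make the `edge_index` printout easier to read: PyTorch Geometric stores edges as a tensor of shape `[2, num_edges]`, where column `i` holds the source and target node of edge `i`, and an undirected edge appears once in each direction. The plain-Python sketch below mimics a few of the checks in `graph_info` on a hypothetical three-node path graph. The helper names and the toy graph are assumptions made for illustration; they are not part of PyTorch Geometric, and the library's own checks may treat edge cases differently.

```python
# Toy edge_index in COO layout: row 0 = source nodes, row 1 = target nodes.
# Hypothetical path graph 0 - 1 - 2, with each undirected edge stored twice.
edge_index = [
    [0, 1, 1, 2],
    [1, 0, 2, 1],
]
num_nodes = 3


def edge_pairs(edge_index):
    """Return the edges as a list of (source, target) tuples."""
    return list(zip(edge_index[0], edge_index[1]))


def is_undirected(edge_index):
    """True if every edge (u, v) also appears as (v, u)."""
    pairs = set(edge_pairs(edge_index))
    return all((v, u) in pairs for (u, v) in pairs)


def has_self_loops(edge_index):
    """True if some edge starts and ends at the same node."""
    return any(u == v for u, v in edge_pairs(edge_index))


def has_isolated_nodes(edge_index, num_nodes):
    """Simplified check: True if some node is not an endpoint of any edge."""
    touched = set(edge_index[0]) | set(edge_index[1])
    return len(touched) < num_nodes


print("Stored (directed) edges:", len(edge_pairs(edge_index)))
print("Is undirected:", is_undirected(edge_index))
print("Has self-loops:", has_self_loops(edge_index))
print("Has isolated nodes:", has_isolated_nodes(edge_index, num_nodes))
```

Because each undirected edge is stored in both directions, the edge count reported for an undirected graph without self-loops is twice the number of undirected edges.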

## Extraction of Mini-Batches
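PyTorch Geometric's `DataLoader` shuffles the dataset at the start of each pass (`shuffle=True`) and splits it into mini-batches.
Each mini-batch contains multiple graphs; with 188 graphs and a batch size of 64, one pass yields batches of 64, 64, and 60 graphs.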

In [None]:
from torch_geometric.loader import DataLoader

batch_size = 64  # Number of graphs in each batch
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

for batch in loader:
    print("Batch:", batch)
    print("Number of graphs in the batch:", batch.num_graphs)
    print()
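
Each `Batch` object printed above holds several graphs as one large disconnected graph: node features are concatenated, each graph's edge indices are shifted by the number of nodes that precede it, and a `batch` vector records which graph every node belongs to. The following plain-Python sketch illustrates that idea on two hypothetical toy graphs. `merge_graphs` is an illustrative helper written for this example, not PyTorch Geometric's implementation.

```python
def merge_graphs(graphs):
    """Merge graphs into one disconnected graph.

    Each graph is a (num_nodes, edge_index) pair, where edge_index is
    [sources, targets]. Returns the merged edge_index and a batch vector
    that maps every node to the index of its original graph.
    """
    sources, targets, batch = [], [], []
    offset = 0  # number of nodes already placed in the merged graph
    for graph_id, (num_nodes, edge_index) in enumerate(graphs):
        # Shift node indices so they stay unique in the merged graph.
        sources.extend(u + offset for u in edge_index[0])
        targets.extend(v + offset for v in edge_index[1])
        batch.extend([graph_id] * num_nodes)
        offset += num_nodes
    return [sources, targets], batch


# Two hypothetical toy graphs: a single edge 0 - 1 and a path 0 - 1 - 2.
graph_a = (2, [[0, 1], [1, 0]])
graph_b = (3, [[0, 1, 1, 2], [1, 0, 2, 1]])

merged_edge_index, batch_vector = merge_graphs([graph_a, graph_b])
print("Merged edge_index:", merged_edge_index)
print("Batch vector:", batch_vector)
```

In PyTorch Geometric the corresponding vector is available as `batch.batch`, and graph-level pooling layers use it to aggregate node features per graph.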