In [28]:
%load_ext autoreload
%autoreload 2

from typing import List, Dict, Any
import itertools
import random
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import QM9
from torch_geometric.data import Data
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.utils import to_dense_adj, dense_to_sparse, remove_self_loops
from data_utils import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [29]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

max_num_nodes = 29

# TODO: pre-transform and store matrices to disk
transform = T.Compose([
    SelectQM9TargetProperties(properties=["homo", "lumo"]),
    SelectQM9NodeFeatures(features=["atom_type"]),
    AddAdjacencyMatrix(max_num_nodes=max_num_nodes),
    AddNodeAttributeMatrix(max_num_nodes=max_num_nodes),
    T.ToDevice(device=device)
])

dataset = QM9(root="./data", transform=transform)

train_dataset, val_dataset, test_dataset = create_qm9_data_split(dataset=dataset)

num_node_features = dataset.num_node_features
num_edge_features = dataset.num_edge_features

In [30]:
class Encoder(nn.Module):

    def __init__(self, hparams: Dict[str, Any]) -> None:
        super().__init__()

        # TODO: two graph convolutional layers (32 and 64 channeles) with identity connection (edge conditioned graph convolution)
        self.conv1 = GCNConv(in_channels=hparams["num_node_features"], out_channels=32)
        self.conv2 = GCNConv(in_channels=32, out_channels=64)
        self.fc = nn.Linear(in_features=64, out_features=128)

    def forward(self, data: Data):
        x, edge_index, batch, edge_attr = data.x, data.edge_index, data.batch, data.edge_attr

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = global_mean_pool(x, batch)
        x = self.fc(x)
        return x
    
class Decoder(nn.Module):

    def __init__(self, hparams: Dict[str, Any]) -> None:
        super().__init__()

        self.fcls = nn.Sequential(
            nn.Linear(in_features=128, out_features=128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=256),
            nn.BatchNorm1d(num_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=512),
            nn.BatchNorm1d(num_features=512),
            nn.ReLU(),
        )

        self.max_num_nodes = hparams["max_num_nodes"]
        self.num_node_features = hparams["num_node_features"]

        # the atom graph is symmetric so we only predict the upper triangular part
        upper_triangular_size = int(self.max_num_nodes * (self.max_num_nodes + 1) / 2)
        self.fc_adjacency = nn.Linear(in_features=512, out_features=upper_triangular_size)

        self.fc_node_features = nn.Linear(in_features=512, out_features=self.max_num_nodes * hparams["num_node_features"])
        self.fc_edge_features = nn.Linear(in_features=512, out_features=self.max_num_nodes * hparams["num_edge_features"])
        

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        x = self.fcls(z)
        # predict upper triangular matrix including the diagonal
        adj_triu_matrix = self.fc_adjacency(x)
        node_features = self.fc_node_features(x)
        edge_features = self.fc_edge_features(x)

        # reshape matrices
        node_features = node_features.view(-1, self.max_num_nodes, self.num_node_features)

        return adj_triu_matrix, node_features, edge_features


class GraphVAE(nn.Module):

    def __init__(self, hparams: Dict[str, Any]) -> None:
        super().__init__()

        self.encoder = Encoder(hparams=hparams)
        self.decoder = Decoder(hparams=hparams)

    def forward(self, data) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        z = self.encoder(data)
        x = self.decoder(z)
        return x
    
    def reconstruction_loss(
        self, 
        input: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], 
        target: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ):
        input_adj_triu_mat, input_node_features, _ = input
        target_adj_triu_mat, target_node_features, _ = target

        # TODO: average over separately over diagonal and off-diagonal elements
        adjacency_loss = F.binary_cross_entropy_with_logits(input=input_adj_triu_mat, target=target_adj_triu_mat)

        node_logits = input_node_features.view(-1, input_node_features.size(-1))
        node_targets = target_node_features.argmax(dim=2).view(-1)
        node_feature_loss = F.cross_entropy(input=node_logits, target=node_targets)

        return adjacency_loss + node_feature_loss


In [31]:
hparams = {
    "batch_size": 32,
    "max_num_nodes": max_num_nodes,
    "learning_rate": 1e-3,
    "beta_1": 0.5,
    "epochs": 1000,
    "num_node_features": num_node_features,
    "num_edge_features": num_edge_features
}

# TODO: plot reconstruction
# TODO: implement reconstruction loss adjacency
# TODO: implement reconstruction loss node features
# TODO: implement reconstruction loss edge features

# TODO: filter chemically invalid mols from QM9 dataset
# TODO: add kl-loss
# TODO: implement encoder from the paper
# TODO: graph matching

batch_size = hparams["batch_size"]
dataloaders = {
    "train_single": DataLoader(train_dataset[:1], batch_size=batch_size, shuffle=True),
    "train_tiny": DataLoader(train_dataset[:batch_size], batch_size=batch_size, shuffle=True),
    "train_small": DataLoader(train_dataset[:4096], batch_size=batch_size, shuffle=True),
    "train": DataLoader(train_dataset, batch_size=batch_size, shuffle=True),

    "val": DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
}

val_subset_count = 32
dataloaders["val_subsets"] = create_validation_subset_loaders(validation_dataset=val_dataset, subset_count=32, batch_size=batch_size)

In [34]:
graph_vae_model = GraphVAE(hparams=hparams).to(device=device)
optimizer = torch.optim.Adam(
    graph_vae_model.parameters(),
    lr=hparams["learning_rate"],
    betas=(hparams["beta_1"], 0.999)
)
epochs = hparams["epochs"]

train_loader = dataloaders["train_tiny"]
val_subset_loader_iterator = itertools.cycle(dataloaders["val_subsets"])

validation_interval = 100

writer = create_tensorboard_writer(experiment_name="graph-vae")

for epoch in range(epochs):
    graph_vae_model.train()
    for batch_index, train_batch in enumerate(tqdm(train_loader,  desc=f"Epoch {epoch + 1} Training")):
        optimizer.zero_grad()
        train_prediction = graph_vae_model(train_batch)
        train_target = (train_batch.adj_triu_mat, train_batch.node_mat, ...)

        loss = graph_vae_model.reconstruction_loss(input=train_prediction, target=train_target)

        loss.backward()
        optimizer.step()

        iteration = len(train_loader) * epoch + batch_index
        writer.add_scalars("Loss", {"Training": loss.item()}, iteration)

        if iteration % validation_interval == 0:
            graph_vae_model.eval()
            val_loss_sum = 0

            # Get the next subset of the validation set
            val_loader = next(val_subset_loader_iterator)
            with torch.no_grad():
                for val_batch in val_loader:
                    val_prediction = graph_vae_model(val_batch)
                    val_target = (val_batch.adj_triu_mat, val_batch.node_mat, ...)
                    val_loss_sum += graph_vae_model.reconstruction_loss(
                        input=val_prediction,
                        target=val_target
                    )
            
            val_loss = val_loss_sum / len(val_loader)
            writer.add_scalars("Loss", {"Validation": val_loss.item()}, iteration)
            
            graph_vae_model.train()

    # Visualization
    # get random mol from train dataset and visualize
    sample_index = random.randint(0, len(train_loader) * batch_size)
    sample_index = 32
    sample = train_dataset[sample_index]
    # print(sample)
    writer.add_image('Input', molecule_graph_data_to_image(sample), global_step=epoch, dataformats="NCHW")

    graph_vae_model.eval()
    prediction = graph_vae_model(sample)

    # convert predicted matrices into Graph
    adj_triu_mat = torch.where(F.sigmoid(prediction[0]) > 0.5, 1.0, 0.0)[0]

    n = hparams["max_num_nodes"]
    mask = torch.ones(n, n).triu() == 1
    adj_mat = torch.zeros(n, n, device=device)
    adj_mat[mask] = adj_triu_mat
    diagonal = adj_mat.diagonal()
    # create symmetric matrix
    adj_mat = adj_mat + adj_mat.t() - torch.diag(diagonal)

    invalid_node_indices = (diagonal == 0).nonzero().flatten().cpu().tolist()

    edge_index, _ = dense_to_sparse(adj=adj_mat)
    edge_index, _ = remove_self_loops(edge_index=edge_index)

    # assume hydrogen for all atoms for now
    x = torch.zeros(29, 5)
    x[:, 0] = 1

    reconstructed_sample = Data(x=x, edge_index=edge_index)
    reconstructed_sample = remove_nodes(data=reconstructed_sample, nodes_to_remove=invalid_node_indices)

    def create_edge_attr(edge_index, edge_attr_matrix):
        # Assuming the edge_attr_matrix is ordered in the same way as the adjacency matrix
        edge_attr = edge_attr_matrix[edge_index[0], edge_index[1]]
        return edge_attr
    
    # add edge attributes
    edge_attr = torch.zeros(reconstructed_sample.num_edges, 4)
    # single bond for now
    edge_attr[:, 0] = 1
    reconstructed_sample.edge_attr = edge_attr
    
    writer.add_image('Reconstruction', molecule_graph_data_to_image(reconstructed_sample), global_step=epoch, dataformats="NCHW")  


Epoch 1 Training: 100%|██████████| 1/1 [00:01<00:00,  1.27s/it]


Node count:
20
Data(x=[20, 5], edge_index=[2, 138])


Epoch 2 Training: 100%|██████████| 1/1 [00:00<00:00,  7.54it/s]


Node count:
17
Data(x=[17, 5], edge_index=[2, 58])


Epoch 3 Training: 100%|██████████| 1/1 [00:00<00:00,  8.64it/s]


Node count:
20
Data(x=[20, 5], edge_index=[2, 48])


Epoch 4 Training: 100%|██████████| 1/1 [00:00<00:00,  6.84it/s]


Node count:
19
Data(x=[19, 5], edge_index=[2, 34])


Epoch 5 Training: 100%|██████████| 1/1 [00:00<00:00,  8.60it/s]


Node count:
19
Data(x=[19, 5], edge_index=[2, 22])


Epoch 6 Training: 100%|██████████| 1/1 [00:00<00:00,  8.70it/s]


Node count:
19
Data(x=[19, 5], edge_index=[2, 22])


Epoch 7 Training: 100%|██████████| 1/1 [00:00<00:00,  8.29it/s]


Node count:
19
Data(x=[19, 5], edge_index=[2, 22])


Epoch 8 Training: 100%|██████████| 1/1 [00:00<00:00,  8.92it/s]

Node count:
19





Data(x=[19, 5], edge_index=[2, 22])


Epoch 9 Training: 100%|██████████| 1/1 [00:00<00:00,  6.64it/s]


Node count:
19
Data(x=[19, 5], edge_index=[2, 22])


Epoch 10 Training: 100%|██████████| 1/1 [00:00<00:00,  8.46it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 22])


Epoch 11 Training: 100%|██████████| 1/1 [00:00<00:00,  7.36it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 22])


Epoch 12 Training: 100%|██████████| 1/1 [00:00<00:00,  8.94it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 22])


Epoch 13 Training: 100%|██████████| 1/1 [00:00<00:00,  8.72it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 22])


Epoch 14 Training: 100%|██████████| 1/1 [00:00<00:00,  8.24it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 22])


Epoch 15 Training: 100%|██████████| 1/1 [00:00<00:00,  8.73it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 20])


Epoch 16 Training: 100%|██████████| 1/1 [00:00<00:00,  8.63it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 22])


Epoch 17 Training: 100%|██████████| 1/1 [00:00<00:00,  8.32it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 22])


Epoch 18 Training: 100%|██████████| 1/1 [00:00<00:00,  8.54it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 22])


Epoch 19 Training: 100%|██████████| 1/1 [00:00<00:00,  7.61it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 22])


Epoch 20 Training: 100%|██████████| 1/1 [00:00<00:00,  8.14it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 24])


Epoch 21 Training: 100%|██████████| 1/1 [00:00<00:00,  7.40it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 24])


Epoch 22 Training: 100%|██████████| 1/1 [00:00<00:00,  8.27it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 24])


Epoch 23 Training: 100%|██████████| 1/1 [00:00<00:00,  8.48it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 24])


Epoch 24 Training: 100%|██████████| 1/1 [00:00<00:00,  8.12it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 24])


Epoch 25 Training: 100%|██████████| 1/1 [00:00<00:00,  8.13it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 24])


Epoch 26 Training: 100%|██████████| 1/1 [00:00<00:00,  7.78it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 24])


Epoch 27 Training: 100%|██████████| 1/1 [00:00<00:00,  7.40it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 24])


Epoch 28 Training: 100%|██████████| 1/1 [00:00<00:00,  8.50it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 24])


Epoch 29 Training: 100%|██████████| 1/1 [00:00<00:00,  8.44it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 24])


Epoch 30 Training: 100%|██████████| 1/1 [00:00<00:00,  7.82it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 24])


Epoch 31 Training: 100%|██████████| 1/1 [00:00<00:00,  7.00it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 24])


Epoch 32 Training: 100%|██████████| 1/1 [00:00<00:00,  7.89it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 24])


Epoch 33 Training: 100%|██████████| 1/1 [00:00<00:00,  5.91it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 26])


Epoch 34 Training: 100%|██████████| 1/1 [00:00<00:00,  7.97it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 26])


Epoch 35 Training: 100%|██████████| 1/1 [00:00<00:00,  8.82it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 26])


Epoch 36 Training: 100%|██████████| 1/1 [00:00<00:00,  7.78it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 26])


Epoch 37 Training: 100%|██████████| 1/1 [00:00<00:00,  7.72it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 26])


Epoch 38 Training: 100%|██████████| 1/1 [00:00<00:00,  7.05it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 26])


Epoch 39 Training: 100%|██████████| 1/1 [00:00<00:00,  7.75it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 28])


Epoch 40 Training: 100%|██████████| 1/1 [00:00<00:00,  9.01it/s]

Node count:
18





Data(x=[18, 5], edge_index=[2, 32])


Epoch 41 Training: 100%|██████████| 1/1 [00:00<00:00,  8.28it/s]

Node count:
18





Data(x=[18, 5], edge_index=[2, 32])


Epoch 42 Training: 100%|██████████| 1/1 [00:00<00:00,  7.96it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 32])


Epoch 43 Training: 100%|██████████| 1/1 [00:00<00:00,  7.00it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 32])


Epoch 44 Training: 100%|██████████| 1/1 [00:00<00:00,  8.01it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 32])


Epoch 45 Training: 100%|██████████| 1/1 [00:00<00:00,  6.63it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 32])


Epoch 46 Training: 100%|██████████| 1/1 [00:00<00:00,  8.43it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 32])


Epoch 47 Training: 100%|██████████| 1/1 [00:00<00:00,  8.12it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 32])


Epoch 48 Training: 100%|██████████| 1/1 [00:00<00:00,  7.65it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 32])


Epoch 49 Training: 100%|██████████| 1/1 [00:00<00:00,  6.86it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 32])


Epoch 50 Training: 100%|██████████| 1/1 [00:00<00:00,  7.85it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 32])


Epoch 51 Training: 100%|██████████| 1/1 [00:00<00:00,  7.67it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 32])


Epoch 52 Training: 100%|██████████| 1/1 [00:00<00:00,  7.53it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 32])


Epoch 53 Training: 100%|██████████| 1/1 [00:00<00:00,  8.56it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 32])


Epoch 54 Training: 100%|██████████| 1/1 [00:00<00:00,  6.91it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 32])


Epoch 55 Training: 100%|██████████| 1/1 [00:00<00:00,  8.44it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 34])


Epoch 56 Training: 100%|██████████| 1/1 [00:00<00:00,  9.30it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 34])


Epoch 57 Training: 100%|██████████| 1/1 [00:00<00:00,  6.31it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 36])


Epoch 58 Training: 100%|██████████| 1/1 [00:00<00:00,  8.42it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 36])


Epoch 59 Training: 100%|██████████| 1/1 [00:00<00:00,  7.30it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 36])


Epoch 60 Training: 100%|██████████| 1/1 [00:00<00:00,  7.41it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 38])


Epoch 61 Training: 100%|██████████| 1/1 [00:00<00:00,  8.52it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 38])


Epoch 62 Training: 100%|██████████| 1/1 [00:00<00:00,  8.79it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 38])


Epoch 63 Training: 100%|██████████| 1/1 [00:00<00:00,  8.70it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 38])


Epoch 64 Training: 100%|██████████| 1/1 [00:00<00:00,  7.51it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 38])


Epoch 65 Training: 100%|██████████| 1/1 [00:00<00:00,  7.75it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 38])


Epoch 66 Training: 100%|██████████| 1/1 [00:00<00:00,  8.19it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 38])


Epoch 67 Training: 100%|██████████| 1/1 [00:00<00:00,  9.34it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 38])


Epoch 68 Training: 100%|██████████| 1/1 [00:00<00:00,  9.12it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 38])


Epoch 69 Training: 100%|██████████| 1/1 [00:00<00:00,  8.53it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 38])


Epoch 70 Training: 100%|██████████| 1/1 [00:00<00:00,  7.93it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 36])


Epoch 71 Training: 100%|██████████| 1/1 [00:00<00:00,  9.67it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 36])


Epoch 72 Training: 100%|██████████| 1/1 [00:00<00:00,  7.97it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 36])


Epoch 73 Training: 100%|██████████| 1/1 [00:00<00:00,  8.71it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 34])


Epoch 74 Training: 100%|██████████| 1/1 [00:00<00:00,  7.62it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 36])


Epoch 75 Training: 100%|██████████| 1/1 [00:00<00:00,  7.98it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 36])


Epoch 76 Training: 100%|██████████| 1/1 [00:00<00:00,  7.37it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 36])


Epoch 77 Training: 100%|██████████| 1/1 [00:00<00:00,  8.80it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 36])


Epoch 78 Training: 100%|██████████| 1/1 [00:00<00:00,  8.88it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 34])


Epoch 79 Training: 100%|██████████| 1/1 [00:00<00:00,  9.13it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 34])


Epoch 80 Training: 100%|██████████| 1/1 [00:00<00:00,  6.20it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 32])


Epoch 81 Training: 100%|██████████| 1/1 [00:00<00:00,  8.64it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 32])


Epoch 82 Training: 100%|██████████| 1/1 [00:00<00:00,  6.89it/s]


Node count:
18
Data(x=[18, 5], edge_index=[2, 32])


Epoch 83 Training: 100%|██████████| 1/1 [00:00<00:00,  8.55it/s]


Node count:
19
Data(x=[19, 5], edge_index=[2, 30])


Epoch 84 Training: 100%|██████████| 1/1 [00:00<00:00,  7.24it/s]


Node count:
19
Data(x=[19, 5], edge_index=[2, 28])


Epoch 85 Training: 100%|██████████| 1/1 [00:00<00:00,  7.58it/s]


Node count:
19
Data(x=[19, 5], edge_index=[2, 26])


Epoch 86 Training: 100%|██████████| 1/1 [00:00<00:00,  7.53it/s]


Node count:
19
Data(x=[19, 5], edge_index=[2, 24])


Epoch 87 Training: 100%|██████████| 1/1 [00:00<00:00,  7.51it/s]


Node count:
19
Data(x=[19, 5], edge_index=[2, 22])


Epoch 88 Training: 100%|██████████| 1/1 [00:00<00:00,  7.86it/s]


Node count:
19
Data(x=[19, 5], edge_index=[2, 22])


Epoch 89 Training: 100%|██████████| 1/1 [00:00<00:00,  7.26it/s]


Node count:
19
Data(x=[19, 5], edge_index=[2, 22])


Epoch 90 Training: 100%|██████████| 1/1 [00:00<00:00,  7.15it/s]


Node count:
19
Data(x=[19, 5], edge_index=[2, 22])


Epoch 91 Training: 100%|██████████| 1/1 [00:00<00:00,  8.77it/s]


Node count:
19
Data(x=[19, 5], edge_index=[2, 22])


Epoch 92 Training: 100%|██████████| 1/1 [00:00<00:00,  7.02it/s]


Node count:
19
Data(x=[19, 5], edge_index=[2, 22])


Epoch 93 Training: 100%|██████████| 1/1 [00:00<00:00,  8.61it/s]


Node count:
19
Data(x=[19, 5], edge_index=[2, 22])


Epoch 94 Training: 100%|██████████| 1/1 [00:00<00:00,  8.17it/s]


Node count:
19
Data(x=[19, 5], edge_index=[2, 22])


Epoch 95 Training: 100%|██████████| 1/1 [00:00<00:00,  7.80it/s]


Node count:
19
Data(x=[19, 5], edge_index=[2, 22])


Epoch 96 Training: 100%|██████████| 1/1 [00:00<00:00,  7.69it/s]


Node count:
19
Data(x=[19, 5], edge_index=[2, 22])


Epoch 97 Training: 100%|██████████| 1/1 [00:00<00:00,  8.27it/s]


Node count:
19
Data(x=[19, 5], edge_index=[2, 22])


Epoch 98 Training: 100%|██████████| 1/1 [00:00<00:00,  7.61it/s]


Node count:
19
Data(x=[19, 5], edge_index=[2, 22])


Epoch 99 Training: 100%|██████████| 1/1 [00:00<00:00,  7.86it/s]


Node count:
19
Data(x=[19, 5], edge_index=[2, 22])


Epoch 100 Training: 100%|██████████| 1/1 [00:00<00:00,  8.13it/s]


Node count:
19
Data(x=[19, 5], edge_index=[2, 22])


Epoch 101 Training:   0%|          | 0/1 [00:00<?, ?it/s]


KeyboardInterrupt: 