In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch_geometric as tg

In [2]:
# import QM9 dataset
from torch_geometric.datasets import QM9

In [3]:
# load some data from QM9
data = QM9('~/data/QM9') # this will download the dataset if it is not already present
sample_data = data[0]

In [4]:
class Subgraph:
    def __init__(self, graph, depth=1):
        self.num_nodes = graph.x.shape[0]
        self.subgraph = graph.clone()

    def convert_to_subgraph(self):
        self.add_subnode_features()
        self.add_node_flags()
        self.add_subnode_edges()
        self.add_node_subnode_edges()
        self.add_subgraph_batch_index()
        return self.subgraph

    def add_subnode_features(self):
        self.subgraph.x = self.subgraph.x.repeat(self.num_nodes+1,1)
        self.total_num_nodes = self.subgraph.x.shape[0]

    def add_node_flags(self):
        if hasattr(self.subgraph, 'x'):
            self.subgraph.ground_node = torch.arange(self.subgraph.x.shape[0]) < self.num_nodes
        else:
            print('No node features found. Please add node features first.')

    def add_subnode_edges(self):
        self.subgraph.subgraph_edge_index = self.subgraph.edge_index + self.num_nodes
        for subg in range(self.num_nodes):
            self.subgraph.subgraph_edge_index = torch.cat([self.subgraph.subgraph_edge_index, self.subgraph.edge_index + (subg+1)*self.num_nodes], dim=1)

    def add_node_subnode_edges(self):
        self.subgraph.node_subnode_index = torch.stack([torch.arange(self.num_nodes).repeat_interleave(self.num_nodes), torch.arange(self.num_nodes, self.total_num_nodes)], dim=0)

    def add_subgraph_batch_index(self):
        self.subgraph.subgraph_batch_index = torch.arange(self.num_nodes).repeat_interleave(self.num_nodes)

In [5]:
class FractalMP(tg.nn.MessagePassing):
    """Message Passing Neural Network Layer"""
    def __init__(
        self,
        node_features,
        edge_features,
        hidden_features,
        out_features,
        aggr="add",
        act=nn.ReLU,
        edge_inference=False,
    ):
        super().__init__(aggr=aggr)
        self.edge_inference = edge_inference
        self.message_net = nn.Sequential(
            nn.Linear(2 * node_features + edge_features, hidden_features),
            act(),
            nn.Linear(hidden_features, hidden_features),
            act(),
        )

        self.update_net = nn.Sequential(
            nn.Linear(node_features + hidden_features, hidden_features),
            act(),
            nn.Linear(hidden_features, out_features),
        )

        if edge_inference:
            self.edge_inferrer = nn.Sequential(
                nn.Linear(hidden_features, 1), nn.Sigmoid()
            )

    def forward(self, x, edge_index, subgraph_edge_index, node_subnode_index, ground_node, subgraph_batch_index, edge_attr=None):
        """Propagate"""
        x = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        x = self.propagate(node_subnode_index, x=x, edge_attr=edge_attr)
        x = self.propagate(subgraph_edge_index, x=x, edge_attr=edge_attr)
        # global pool over nodes whose ground node is false
        x[ground_node] = tg.nn.global_mean_pool(x[~ground_node], subgraph_batch_index)
        return x

    def message(self, x_i, x_j, edge_attr):
        """Send message with edge attributes"""
        input = [x_i, x_j, edge_attr]
        input = [val for val in input if val is not None]
        input = torch.cat(input, dim=-1)
        message = self.message_net(input)

        if self.edge_inference:
            message = message * self.edge_inferrer(message)
        return message

    def update(self, message, x):
        """Update node"""
        input = torch.cat((x, message), dim=-1)
        update = self.update_net(input)
        return update

In [6]:
class FractalNet(nn.Module):
    def __init__(self, node_features, edge_features, hidden_features, out_features, depth=1):
        super().__init__()
        self.depth = depth
        self.embedding = nn.Linear(node_features, hidden_features)
        self.fractal_mps = nn.ModuleList()
        for i in range(depth):
            self.fractal_mps.append(FractalMP(hidden_features, edge_features, hidden_features, hidden_features))
        self.output = nn.Linear(hidden_features, out_features)

    def forward(self, x, edge_index, subgraph_edge_index, node_subnode_index, ground_node, subgraph_batch_index, batch_idx, edge_attr=None):
        x = self.embedding(x)
        for i in range(self.depth):
            x = self.fractal_mps[i](x, edge_index, subgraph_edge_index, node_subnode_index, ground_node, subgraph_batch_index, edge_attr)
        x = self.output(x)
        # global pooling over nodes whose ground node is true
        x = tg.nn.global_mean_pool(x[ground_node], batch_idx)
        return x

In [7]:
sample_graph = Subgraph(sample_data).convert_to_subgraph()
node_features = sample_graph.x.shape[1]
edge_features = 0
hidden_features = 64
out_features = 1

In [9]:
# create a fractal net and train it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = FractalNet(node_features, edge_features, hidden_features, out_features, depth=3).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()

In [29]:
from warnings import warn
LABEL_INDEX = 7
def get_qm9(data_dir, device="cuda"):
    """Download the QM9 dataset from pytorch geometric. Put it onto the device. Split it up into train / validation / test.
    Args:
        data_dir: the directory to store the data.
        device: put the data onto this device.
    Returns:
        train dataset, validation dataset, test dataset.
    """
    dataset = QM9(data_dir)

    # Permute the dataset
    try:
        permu = torch.load("permute.pt")
        dataset = dataset[permu]
    except FileNotFoundError:
        warn("Using non-standard permutation since permute.pt does not exist.")
        dataset, _ = dataset.shuffle(return_perm=True)

    # z score / standard score targets to mean = 0 and std = 1.
    mean = dataset.data.y.mean(dim=0, keepdim=True)
    std = dataset.data.y.std(dim=0, keepdim=True)
    dataset.data.y = (dataset.data.y - mean) / std
    mean, std = mean[:, LABEL_INDEX].item(), std[:, LABEL_INDEX].item()

    # Move the data to the device (it should fit on lisa gpus)
    dataset.data = dataset.data.to(device)

    len_train = 100_000
    len_val = 10_000

    train = dataset[:len_train]
    valid = dataset[len_train : len_train + len_val]
    test = dataset[len_train + len_val :]

    assert len(dataset) == len(train) + len(valid) + len(test)

    return train, valid, test

In [30]:
# create a dataloader for qm9
from torch_geometric.loader import DataLoader
dataset , _ ,_  = get_qm9("data/qm9", device=device)
# take a subset of the dataset
dataset = dataset[:1000]
loader = DataLoader(dataset, batch_size=1, shuffle=True)

  warn("Using non-standard permutation since permute.pt does not exist.")


In [32]:
# start training the model
# import tqdm to track progress of loss
import tqdm.notebook as tqdm
model.train()
for epoch in range(10):
    avg_loss = 0
    for data in tqdm.tqdm(loader):
        data = data.to(device)
        data = Subgraph(data).convert_to_subgraph().to(device)
        optimizer.zero_grad()
        target = data.y[:, LABEL_INDEX]
        out = model(data.x, data.edge_index, data.subgraph_edge_index, data.node_subnode_index, data.ground_node, data.subgraph_batch_index, data.batch)
        loss = criterion(out.mean(), target.mean())
        loss.backward()
        avg_loss += loss.item()
        optimizer.step()
        # show loss on tqdm
        #tqdm.tqdm.write(f'Epoch: {epoch}, Loss: {loss.item()}')
    print(f'Epoch: {epoch}, Loss: {avg_loss/len(loader)}')


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch: 0, Loss: 1.0402429077640636


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch: 1, Loss: 0.9551820499933509


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch: 2, Loss: 0.8987012124667024


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch: 3, Loss: 0.9026593807928802


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch: 4, Loss: 0.8925886963500979


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch: 5, Loss: 0.8683935219835505


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch: 6, Loss: 0.7846923819308054


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch: 7, Loss: 0.825826865932264


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch: 8, Loss: 0.7890269650978861


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch: 9, Loss: 0.7157069895064606


In [None]:
# MISC
def graph_to_subgraph(graph, depth=1):
    num_nodes = graph.x.shape[0]
    subgraph = graph.clone()
    subgraph.x = subgraph.x.repeat(num_nodes+1,1) # Initialize subnodes by concatenation
    total_num_nodes = subgraph.x.shape[0]
    ground_node = torch.arange(subgraph.x.shape[0]) < num_nodes # Mask for whether the node is a subnode or not
    subgraph.ground_node = ground_node # Add to the Data object
    # Create subgraph edge index which is same as edge index but for every subgraph
    subgraph.subgraph_edge_index = subgraph.edge_index + num_nodes
    for subg in range(num_nodes):
        subgraph.subgraph_edge_index = torch.cat([subgraph.subgraph_edge_index, subgraph.edge_index + (subg+1)*num_nodes], dim=1)
    # Create edge index for directed edges between ground node and it's subnodes
    subgraph.interaction_index = torch.stack([torch.arange(num_nodes).repeat_interleave(num_nodes), torch.arange(num_nodes, total_num_nodes)], dim=0)