In [2]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch_geometric as tg
import tqdm.notebook as tqdm
from torch_geometric.loader import DataLoader
# import QM9 dataset
from torch_geometric.datasets import QM9

In [3]:
# load some data from QM9
data = QM9('~/data/QM9') # this will download the dataset if it is not already present
sample_data = data[0]

In [4]:
class Subgraph:
    def __init__(self, graph, depth=1):
        self.num_nodes = graph.x.shape[0]
        self.subgraph = graph.clone()

    def convert_to_subgraph(self):
        self.add_subnode_features()
        self.add_node_flags()
        self.add_subnode_edges()
        self.add_node_subnode_edges()
        self.add_subnode_node_edges()
        self.add_subgraph_batch_index()
        return self.subgraph

    def add_subnode_features(self):
        self.subgraph.x = self.subgraph.x.repeat(self.num_nodes+1,1)
        self.total_num_nodes = self.subgraph.x.shape[0]

    def add_node_flags(self):
        if hasattr(self.subgraph, 'x'):
            self.subgraph.ground_node = torch.arange(self.subgraph.x.shape[0]) < self.num_nodes
        else:
            print('No node features found. Please add node features first.')

    def add_subnode_edges(self):
        self.subgraph.subgraph_edge_index = self.subgraph.edge_index + self.num_nodes
        for subg in range(self.num_nodes):
            self.subgraph.subgraph_edge_index = torch.cat([self.subgraph.subgraph_edge_index, self.subgraph.edge_index + (subg+1)*self.num_nodes], dim=1)

    def add_node_subnode_edges(self):
        self.subgraph.node_subnode_index = torch.stack([torch.arange(self.num_nodes).repeat_interleave(self.num_nodes), torch.arange(self.num_nodes, self.total_num_nodes)], dim=0)

    def add_subnode_node_edges(self):
        self.subgraph.subnode_node_index = torch.stack([torch.arange(self.num_nodes, self.total_num_nodes), torch.arange(self.num_nodes).repeat_interleave(self.num_nodes)], dim=0)

    def add_subgraph_batch_index(self):
        self.subgraph.subgraph_batch_index = torch.arange(self.num_nodes).repeat_interleave(self.num_nodes)

In [5]:
class FractalMP(tg.nn.MessagePassing):
    """Message Passing Neural Network Layer"""
    def __init__(
        self,
        node_features,
        edge_features,
        hidden_features,
        out_features,
        aggr="add",
        act=nn.ReLU,
        edge_inference=False,
    ):
        super().__init__(aggr=aggr)
        self.edge_inference = edge_inference
        self.message_net = nn.Sequential(
            nn.Linear(2 * node_features + edge_features, hidden_features),
            act(),
            nn.Linear(hidden_features, hidden_features),
            act(),
        )

        self.update_net = nn.Sequential(
            nn.Linear(node_features + hidden_features, hidden_features),
            act(),
            nn.Linear(hidden_features, out_features),
        )

        if edge_inference:
            self.edge_inferrer = nn.Sequential(
                nn.Linear(hidden_features, 1), nn.Sigmoid()
            )

    def forward(self, x, edge_index, subgraph_edge_index, node_subnode_index, subnode_node_index ,ground_node, subgraph_batch_index, edge_attr=None):
        """Propagate"""
        x = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        x = self.propagate(node_subnode_index, x=x, edge_attr=edge_attr)
        x = self.propagate(subgraph_edge_index, x=x, edge_attr=edge_attr)
        x = self.propagate(subnode_node_index, x=x, edge_attr=edge_attr)
        # global pool over nodes whose ground node is false
        #x[ground_node] = tg.nn.global_mean_pool(x[~ground_node], subgraph_batch_index)
        return x

    def message(self, x_i, x_j, edge_attr):
        """Send message with edge attributes"""
        input = [x_i, x_j, edge_attr]
        input = [val for val in input if val is not None]
        input = torch.cat(input, dim=-1)
        message = self.message_net(input)

        if self.edge_inference:
            message = message * self.edge_inferrer(message)
        return message

    def update(self, message, x):
        """Update node"""
        input = torch.cat((x, message), dim=-1)
        update = self.update_net(input)
        return update

In [6]:
class MP(tg.nn.MessagePassing):
    """Message Passing Neural Network Layer"""
    def __init__(
        self,
        node_features,
        edge_features,
        hidden_features,
        out_features,
        aggr="add",
        act=nn.ReLU,
        edge_inference=False,
    ):
        super().__init__(aggr=aggr)
        self.edge_inference = edge_inference
        self.message_net = nn.Sequential(
            nn.Linear(2 * node_features + edge_features, hidden_features),
            act(),
            nn.Linear(hidden_features, hidden_features),
            act(),
        )

        self.update_net = nn.Sequential(
            nn.Linear(node_features + hidden_features, hidden_features),
            act(),
            nn.Linear(hidden_features, out_features),
        )

        if edge_inference:
            self.edge_inferrer = nn.Sequential(
                nn.Linear(hidden_features, 1), nn.Sigmoid()
            )

    def forward(self, x, edge_index, edge_attr=None):
        """Propagate"""
        x = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        # global pool over nodes whose ground node is false
        #x[ground_node] = tg.nn.global_mean_pool(x[~ground_node], subgraph_batch_index)
        return x

    def message(self, x_i, x_j, edge_attr):
        """Send message with edge attributes"""
        input = [x_i, x_j, edge_attr]
        input = [val for val in input if val is not None]
        input = torch.cat(input, dim=-1)
        message = self.message_net(input)

        if self.edge_inference:
            message = message * self.edge_inferrer(message)
        return message

    def update(self, message, x):
        """Update node"""
        input = torch.cat((x, message), dim=-1)
        update = self.update_net(input)
        return update

In [7]:
class FractalNet(nn.Module):
    def __init__(self, node_features, edge_features, hidden_features, out_features, depth=1):
        super().__init__()
        self.depth = depth
        self.embedding = nn.Linear(node_features, hidden_features)
        self.fractal_mps = nn.ModuleList()
        for i in range(depth):
            self.fractal_mps.append(FractalMP(hidden_features, edge_features, hidden_features, hidden_features))
        self.output = nn.Linear(hidden_features, out_features)

    def forward(self, x, edge_index, subgraph_edge_index, node_subnode_index, subnode_node_index ,ground_node, subgraph_batch_index, batch_idx, edge_attr=None):
        x = self.embedding(x)
        for i in range(self.depth):
            x = self.fractal_mps[i](x, edge_index, subgraph_edge_index, node_subnode_index, subnode_node_index, ground_node, subgraph_batch_index, edge_attr)
        x = self.output(x)
        # global pooling over nodes whose ground node is true
        x = tg.nn.global_mean_pool(x[ground_node], batch_idx)
        return x

In [17]:
class FractalNetSeparated(nn.Module):
    def __init__(self, node_features, edge_features, hidden_features, out_features, depth=1):
        super().__init__()
        self.depth = depth
        self.embedding = nn.Linear(node_features, hidden_features)
        self.ground_mps = nn.ModuleList()
        self.ground_to_sub_mps = nn.ModuleList()
        self.sub_mps = nn.ModuleList()
        self.sub_to_ground_mps = nn.ModuleList()
        for i in range(depth):
            self.ground_mps.append(MP(hidden_features, edge_features, hidden_features, hidden_features))
            self.ground_to_sub_mps.append(MP(hidden_features, edge_features, hidden_features, hidden_features))
            self.sub_mps.append(MP(hidden_features, edge_features, hidden_features, hidden_features))
            self.sub_to_ground_mps.append(MP(hidden_features, edge_features, hidden_features, hidden_features))
        self.output = nn.Linear(hidden_features, out_features)

    def forward(self, x, edge_index, subgraph_edge_index, node_subnode_index, subnode_node_index ,ground_node, subgraph_batch_index, batch_idx, edge_attr=None):
        x = self.embedding(x)
        for i in range(self.depth):
            x = self.ground_mps[i](x, edge_index, edge_attr)
            x = self.ground_to_sub_mps[i](x, node_subnode_index, edge_attr)
            x = self.sub_mps[i](x, subgraph_edge_index, edge_attr)
            x = self.sub_to_ground_mps[i](x, subnode_node_index, edge_attr)
        x = self.output(x)
        print('before global mean pool', x.shape)
        # global pooling over nodes whose ground node is true
        x = tg.nn.global_mean_pool(x[ground_node], batch_idx)
        print('after global mean pool', x.shape)
        return x

In [9]:
class Net(nn.Module):
    def __init__(self, node_features, edge_features, hidden_features, out_features, depth=1):
        super().__init__()
        self.depth = depth
        self.embedding = nn.Linear(node_features, hidden_features)
        self.mps= nn.ModuleList()
        for i in range(depth):
            self.mps.append(MP(hidden_features, edge_features, hidden_features, hidden_features))
        self.output = nn.Linear(hidden_features, out_features)

    def forward(self, x, edge_index, batch_idx, edge_attr=None):
        x = self.embedding(x)
        for i in range(self.depth):
            x = self.mps[i](x, edge_index, edge_attr)
        x = self.output(x)
        # global pooling over nodes whose ground node is true
        x = tg.nn.global_mean_pool(x, batch_idx)
        return x

In [10]:
from warnings import warn
LABEL_INDEX = 7
def get_qm9(data_dir, device="cuda"):
    """Download the QM9 dataset from pytorch geometric. Put it onto the device. Split it up into train / validation / test.
    Args:
        data_dir: the directory to store the data.
        device: put the data onto this device.
    Returns:
        train dataset, validation dataset, test dataset.
    """
    dataset = QM9(data_dir)

    # Permute the dataset
    try:
        permu = torch.load("permute.pt")
        dataset = dataset[permu]
    except FileNotFoundError:
        warn("Using non-standard permutation since permute.pt does not exist.")
        dataset, _ = dataset.shuffle(return_perm=True)

    # z score / standard score targets to mean = 0 and std = 1.
    mean = dataset.data.y.mean(dim=0, keepdim=True)
    std = dataset.data.y.std(dim=0, keepdim=True)
    dataset.data.y = (dataset.data.y - mean) / std
    mean, std = mean[:, LABEL_INDEX].item(), std[:, LABEL_INDEX].item()

    # Move the data to the device (it should fit on lisa gpus)
    dataset.data = dataset.data.to(device)

    len_train = 100_000
    len_val = 10_000

    train = dataset[:len_train]
    valid = dataset[len_train : len_train + len_val]
    test = dataset[len_train + len_val :]

    assert len(dataset) == len(train) + len(valid) + len(test)

    return train, valid, test

In [11]:
sample_graph = Subgraph(sample_data).convert_to_subgraph()
node_features = sample_graph.x.shape[1]
edge_features = 0
hidden_features = 64
out_features = 1

In [None]:
### TRAINING A REGULAR FRACTAL NET ###

In [12]:
# create a fractal net and train it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = FractalNet(node_features, edge_features, hidden_features, out_features, depth=3).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()
# create a dataloader for qm9
dataset , _ ,_  = get_qm9("data/qm9", device=device)
# take a subset of the dataset
dataset = dataset[:1000]
loader = DataLoader(dataset, batch_size=1, shuffle=True)

  warn("Using non-standard permutation since permute.pt does not exist.")


KeyboardInterrupt: 

In [None]:
model.train()
for epoch in range(10):
    avg_loss = 0
    for data in tqdm.tqdm(loader):
        data = data.to(device)
        data = Subgraph(data).convert_to_subgraph().to(device)
        optimizer.zero_grad()
        target = data.y[:, LABEL_INDEX]
        out = model(data.x, data.edge_index, data.subgraph_edge_index, data.node_subnode_index, data.subnode_node_index,data.ground_node, data.subgraph_batch_index, data.batch)
        loss = criterion(out.mean(), target.mean())
        loss.backward()
        avg_loss += loss.item()
        optimizer.step()
        # show loss on tqdm
        #tqdm.tqdm.write(f'Epoch: {epoch}, Loss: {loss.item()}')
    print(f'Epoch: {epoch}, Loss: {avg_loss/len(loader)}')

In [None]:
### TRAINING A UNROLLED FRACTAL NET ###

In [18]:
# create a fractal net and train it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = FractalNetSeparated(node_features, edge_features, hidden_features, out_features, depth=3).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()
# create a dataloader for qm9
dataset , _ ,_  = get_qm9("data/qm9", device=device)
# take a subset of the dataset
dataset = dataset[:1000]
loader = DataLoader(dataset, batch_size=1, shuffle=True)

  warn("Using non-standard permutation since permute.pt does not exist.")


In [23]:
model.train()
for epoch in range(10):
    avg_loss = 0
    for data in tqdm.tqdm(loader):
        data = data.to(device)
        data = Subgraph(data).convert_to_subgraph().to(device)
        optimizer.zero_grad()
        target = data.y[:, LABEL_INDEX]
        out = model(data.x, data.edge_index, data.subgraph_edge_index, data.node_subnode_index, data.subnode_node_index,data.ground_node, data.subgraph_batch_index, data.batch)
        print('out is ', out)
        print('target is ', target)
        loss = criterion(out.mean(), target.mean())
        loss.backward()
        avg_loss += loss.item()
        optimizer.step()
        # show loss on tqdm
        #tqdm.tqdm.write(f'Epoch: {epoch}, Loss: {loss.item()}')
    print(f'Epoch: {epoch}, Loss: {avg_loss/len(loader)}')

  0%|          | 0/1000 [00:00<?, ?it/s]

before global mean pool torch.Size([380, 1])
after global mean pool torch.Size([1, 1])
out is  tensor([[0.1307]], device='cuda:0', grad_fn=<DivBackward0>)
target is  tensor([0.7190], device='cuda:0')
before global mean pool torch.Size([600, 1])
after global mean pool torch.Size([1, 1])
out is  tensor([[0.1359]], device='cuda:0', grad_fn=<DivBackward0>)
target is  tensor([0.1326], device='cuda:0')
before global mean pool torch.Size([552, 1])
after global mean pool torch.Size([1, 1])
out is  tensor([[0.1348]], device='cuda:0', grad_fn=<DivBackward0>)
target is  tensor([-0.3664], device='cuda:0')
before global mean pool torch.Size([240, 1])
after global mean pool torch.Size([1, 1])
out is  tensor([[0.1264]], device='cuda:0', grad_fn=<DivBackward0>)
target is  tensor([2.5660], device='cuda:0')
before global mean pool torch.Size([420, 1])
after global mean pool torch.Size([1, 1])
out is  tensor([[0.1337]], device='cuda:0', grad_fn=<DivBackward0>)
target is  tensor([0.6191], device='cuda:0')

KeyboardInterrupt: 

In [None]:
### TRAINING A NORMAL NET ###

In [None]:
# create a fractal net and train it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(node_features, edge_features, hidden_features, out_features, depth=3).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()
# create a dataloader for qm9
dataset , _ ,_  = get_qm9("data/qm9", device=device)
# take a subset of the dataset
dataset = dataset[:1000]
loader = DataLoader(dataset, batch_size=1, shuffle=True)

In [None]:
model.train()
for epoch in range(10):
    avg_loss = 0
    for data in tqdm.tqdm(loader):
        data = data.to(device)
        optimizer.zero_grad()
        target = data.y[:, LABEL_INDEX]
        out = model(data.x, data.edge_index, data.batch)
        loss = criterion(out.mean(), target.mean())
        loss.backward()
        avg_loss += loss.item()
        optimizer.step()
        # show loss on tqdm
        #tqdm.tqdm.write(f'Epoch: {epoch}, Loss: {loss.item()}')
    print(f'Epoch: {epoch}, Loss: {avg_loss/len(loader)}')