In [20]:
import torch
import torch.nn as nn
import tqdm.notebook as tqdm
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import QM9
from warnings import warn
from networks import FractalNet, FractalNetSeparated, Net, GNN_no_rel, GNN
from subgraph import Graph_to_Subgraph

In [21]:
LABEL_INDEX = 7
def get_qm9(data_dir, device="cuda", transform=None):
    """Download the QM9 dataset from pytorch geometric. Put it onto the device. Split it up into train / validation / test.
    Args:
        data_dir: the directory to store the data.
        device: put the data onto this device.
    Returns:
        train dataset, validation dataset, test dataset.
    """
    dataset = QM9(data_dir, transform=transform)

    # Permute the dataset
    try:
        permu = torch.load("permute.pt")
        dataset = dataset[permu]
    except FileNotFoundError:
        warn("Using non-standard permutation since permute.pt does not exist.")
        dataset, _ = dataset.shuffle(return_perm=True)

    # z score / standard score targets to mean = 0 and std = 1.
    mean = dataset.data.y.mean(dim=0, keepdim=True)
    std = dataset.data.y.std(dim=0, keepdim=True)
    dataset.data.y = (dataset.data.y - mean) / std
    mean, std = mean[:, LABEL_INDEX].item(), std[:, LABEL_INDEX].item()

    # Move the data to the device (it should fit on lisa gpus)
    dataset.data = dataset.data.to(device)

    len_train = 100_000
    len_val = 10_000

    train = dataset[:len_train]
    valid = dataset[len_train : len_train + len_val]
    test = dataset[len_train + len_val :]

    assert len(dataset) == len(train) + len(valid) + len(test)

    return train, valid, test

In [22]:
node_features = 5
Z_ONE_HOT_DIM = 5
EDGE_ATTR_DIM = 4
edge_features = 0
hidden_features = 64
out_features = 1

# TRAINING SHARED PARAMETERS FRACTAL NET

In [23]:
#TODO: actually make this you lazyass

## TRAINING A UNROLLED FRACTAL NET

In [24]:
# create a fractal net and train it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = FractalNetSeparated(node_features, edge_features, hidden_features, out_features, depth=1, pool='add').to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()
# create a dataloader for qm9
train, valid, test = get_qm9("data/qm9", device=device, transform=Graph_to_Subgraph())
# take a subset of the dataset
loader = DataLoader(train, batch_size=32, shuffle=True)
valid_loader = DataLoader(valid, batch_size=32, shuffle=False)

  warn("Using non-standard permutation since permute.pt does not exist.")


In [25]:
# get total number of parameters of the model
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params}")

Total number of parameters: 99777


In [26]:
# store loss per epoch
losses = []
val_losses = []
for epoch in range(10):
    model.train()
    train_loss = 0
    for data in tqdm.tqdm(loader):
        data = data.to(device)
        optimizer.zero_grad()
        target = data.y[:, LABEL_INDEX]
        # keep only ground nodes in the data.batch
        data.batch = data.batch[data.ground_node]
        out = model(data.x[:, :Z_ONE_HOT_DIM], data.edge_index, data.subgraph_edge_index, data.node_subnode_index, data.subnode_node_index,data.ground_node, data.subgraph_batch_index, data.batch)
        #print("data batch shape", data.batch.shape)
        #print('data x shape', data.x.shape)
        loss = criterion(out.squeeze(), target)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        # show loss on tqdm
        #tqdm.tqdm.write(f'Epoch: {epoch}, Loss: {loss.item()}')
    # store loss per epoch
    # get performance on the validation set
    model.eval()
    with torch.no_grad():
        valid_loss = 0
        for data in tqdm.tqdm(valid_loader):
            data = data.to(device)
            target = data.y[:, LABEL_INDEX]
            data.batch = data.batch[data.ground_node]
            out = model(data.x[:, :Z_ONE_HOT_DIM], data.edge_index, data.subgraph_edge_index, data.node_subnode_index, data.subnode_node_index,data.ground_node, data.subgraph_batch_index, data.batch)
            loss = criterion(out.squeeze(), target)
            valid_loss += loss.item()

    losses.append(train_loss/len(loader))
    val_losses.append(valid_loss/len(valid_loader))
    print(f'Epoch: {epoch}, Loss: {train_loss/len(loader)}, Valid Loss: {valid_loss/len(valid_loader)}')

subgraph_results = {'train_loss': losses, 'valid_loss': val_losses}

  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 0, Loss: 0.07443700703653973, Valid Loss: 0.03843221781900325


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 1, Loss: 0.007744406558033079, Valid Loss: 0.0012437402246651588


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 2, Loss: 0.007773048592555569, Valid Loss: 0.003685983083844661


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 3, Loss: 0.01643546956624603, Valid Loss: 0.005171041473926316


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 4, Loss: 0.006900626265608007, Valid Loss: 0.0016544421736043864


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 5, Loss: 0.006213742802614579, Valid Loss: 0.0024829512902203735


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 6, Loss: 0.0048536804817290976, Valid Loss: 0.0010665426027593604


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 7, Loss: 0.0060062585594842675, Valid Loss: 0.0012455974522246845


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 8, Loss: 0.004149867006061831, Valid Loss: 0.0013096979776465784


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 9, Loss: 0.00325009604133782, Valid Loss: 0.003286436346683877


# TRAINING A NORMAL GNN NET

In [28]:
# create a fractal net and train it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GNN_no_rel(5, edge_features, hidden_features, out_features, num_convolution_blocks=2, pooling='add').to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()
# create a dataloader for qm9
train, valid, test = get_qm9("data/qm9", device=device, transform=None)
# take a subset of the dataset
loader = DataLoader(train, batch_size=32, shuffle=True)
valid_loader = DataLoader(valid, batch_size=32, shuffle=False)

  warn("Using non-standard permutation since permute.pt does not exist.")


In [29]:
# get total number of parameters of the model
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params}")

Total number of parameters: 186241


In [30]:
model.train()
# store loss per epoch
train_losses = []
valid_losses = []
for epoch in range(15):
    avg_loss = 0
    for data in tqdm.tqdm(loader):
        data = data.to(device)
        optimizer.zero_grad()
        target = data.y[:, LABEL_INDEX]
        out = model(data.x[:, :Z_ONE_HOT_DIM], data.edge_index, None, data.batch)
        loss = criterion(out.squeeze(), target)
        loss.backward()
        avg_loss += loss.item()
        optimizer.step()
        # show loss on tqdm
        #tqdm.tqdm.write(f'Epoch: {epoch}, Loss: {loss.item()}')
    # store loss per epoch
    # get performance on the validation set
    model.eval()
    with torch.no_grad():
        valid_loss = 0
        for data in tqdm.tqdm(valid_loader):
            data = data.to(device)
            target = data.y[:, LABEL_INDEX]
            out = model(data.x[:, :Z_ONE_HOT_DIM], data.edge_index, None, data.batch)
            loss = criterion(out.squeeze(), target)
            valid_loss += loss.item()

    train_losses.append(avg_loss/len(loader))
    valid_losses.append(valid_loss/len(valid_loader))
    print(f'Epoch: {epoch}, Loss: {avg_loss/len(loader)}, Valid Loss: {valid_loss/len(valid_loader)}')
normal_gnn_results = {'train_loss': train_losses, 'valid_loss': valid_losses}

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 0, Loss: 0.30232476485073567, Valid Loss: 0.19605280635075067


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 1, Loss: 0.04152809501268202, Valid Loss: 0.0006900344012076773


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 2, Loss: 0.002660130196438404, Valid Loss: 0.006016587888369688


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 3, Loss: 0.0022843028054706518, Valid Loss: 0.0004528283462160454


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 4, Loss: 0.002089855820953962, Valid Loss: 0.0006411828269167378


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 5, Loss: 0.0015825997749320231, Valid Loss: 0.0008497668258933582


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 6, Loss: 0.0014492124186325237, Valid Loss: 0.000671385707508121


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 7, Loss: 0.0015568956145597622, Valid Loss: 0.0009621920917866031


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 8, Loss: 0.0009419431429613905, Valid Loss: 0.0014922298416986275


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 9, Loss: 0.0009709137557791837, Valid Loss: 0.000607927473590742


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 10, Loss: 0.0010836833062155347, Valid Loss: 0.0043703175712222105


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 11, Loss: 0.0008475799004897635, Valid Loss: 8.535673726999466e-05


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 12, Loss: 0.0006939245965132067, Valid Loss: 0.0005486315193834504


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 13, Loss: 0.0006014579452617908, Valid Loss: 0.0035062241096483968


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 14, Loss: 0.0006556307674165146, Valid Loss: 7.747562439487756e-05


  0%|          | 0/3125 [00:00<?, ?it/s]

# TRAINING A GNN WITH EDGE FEATURES

In [None]:
# create a fractal net and train it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GNN(n_node_features = Z_ONE_HOT_DIM,
            n_edge_features=EDGE_ATTR_DIM,
            n_hidden=64,
            n_output=out_features,
            num_convolution_blocks=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()
# create a dataloader for qm9
train, valid, test = get_qm9("data/qm9", device=device, transform=None)
# take a subset of the dataset
#dataset = dataset[:4000]
loader = DataLoader(train, batch_size=32, shuffle=True)
valid_loader = DataLoader(valid, batch_size=32, shuffle=False)

In [None]:
# get total number of parameters of the model
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params}")

In [None]:
# store loss per epoch
train_loss_GNN = []
val_loss_GNN = []
for epoch in range(10):
    model.train()
    train_loss = 0
    for data in tqdm.tqdm(loader):
        data = data.to(device)
        optimizer.zero_grad()
        target = data.y[:, LABEL_INDEX]
        out = model(data.x[:, :Z_ONE_HOT_DIM],
                    data.edge_index,
                    data.edge_attr.argmax(dim=-1),
                    data.batch)
        loss = criterion(out.squeeze(), target)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        # show loss on tqdm
        #tqdm.tqdm.write(f'Epoch: {epoch}, Loss: {loss.item()}')
    # store loss per epoch
    # evaluate on the valid set
    model.eval()
    with torch.no_grad():
        valid_loss = 0
        for data in valid_loader:
            data = data.to(device)
            target = data.y[:, LABEL_INDEX]
            out = model(data.x[:, :Z_ONE_HOT_DIM],
                        data.edge_index,
                        data.edge_attr.argmax(dim=-1),
                        data.batch)
            loss = criterion(out.squeeze(), target)
            valid_loss += loss.item()
    train_loss_GNN.append(train_loss/len(loader))
    val_loss_GNN.append(valid_loss/len(valid_loader))
    print(f'Epoch: {epoch}, Loss: {train_loss/len(loader)}, Valid Loss: {valid_loss/len(valid_loader)}')
edge_gnn_results = {'train_loss': train_loss_GNN, 'valid_loss': val_loss_GNN}

# PLOTTING LOSS

In [None]:
# plot loss
# IGNORE FOR NOW #
import matplotlib.pyplot as plt
# plot results from different runs
plt.plot(normal_gnn_results['train_loss'], label='normal gnn')
plt.plot(normal_gnn_results['valid_loss'], label='normal gnn')
plt.plot(edge_gnn_results['train_loss'], label='edge gnn')
plt.plot(edge_gnn_results['valid_loss'], label='edge gnn')
plt.legend()
plt.show()
