In [17]:
%load_ext autoreload
%autoreload 2

from typing import List, Dict, Any, Tuple, Union, Optional
import os
import sys

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
import pytorch_lightning as pl
from torch_geometric.data import Data
from config import get_config

from tng_halo.models import GNNBlock
from tng_halo import training_utils

%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [23]:
class BinaryNodeClassifier(pl.LightningModule):
    def __init__(
        self, 
        input_size: int, 
        hidden_sizes: List[int], 
        embed_size: int = None,
        graph_layer: str = "ChebConv", 
        graph_layer_args: Dict[str, Any] = None,
        activation_name: callable = nn.ReLU(), 
        layer_norm: bool = False, 
        norm_first: bool = False,
        optimizer_args: Dict[str, Any] = None,
        scheduler_args: Dict[str, Any] = None
    ) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_sizes = hidden_sizes
        self.embed_size = embed_size
        self.graph_layer = graph_layer
        self.graph_layer_args = graph_layer_args or {}
        self.activation_name = activation_name
        self.layer_norm = layer_norm
        self.norm_first = norm_first
        self.optimizer_args = optimizer_args or {}
        self.scheduler_args = scheduler_args or {}
        self.save_hyperparameters()
        self._setup_model()

    def _setup_model(self) -> None:
        if self.embed_size:
            self.embed_layer = nn.Linear(self.input_size, self.embed_size)
            input_size = self.embed_size
        else:
            self.embed_layer = None
            input_size = self.input_size

        layer_sizes = [input_size] + self.hidden_sizes
        activation_fn = training_utils.get_activation(self.activation_name)

        self.layers = nn.ModuleList()
        for i in range(len(layer_sizes) - 1):
            layer = GNNBlock(
                layer_sizes[i-1], 
                layer_sizes[i], 
                layer_name=self.graph_layer,
                layer_args=self.graph_layer_args, 
                layer_norm=self.layer_norm, 
                norm_first=self.norm_first,
                activation_fn=activation_fn
            )
            self.layers.append(layer)
        self.output_layer = GNNBlock(
            layer_sizes[-1], 1,
            layer_name=self.graph_layer,
            layer_args=self.graph_layer_args,
            layer_norm=self.layer_norm,
            norm_first=self.norm_first,
            activation_fn=nn.Identity()
        )

    def forward(
        self, 
        x: torch.Tensor, 
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor = None,
        edge_weight: torch.Tensor = None
    ) -> torch.Tensor:

        if self.embed_layer:
            x = self.embed_layer(x)
        for layer in self.layers:
            x = layer(x, edge_index, edge_attr, edge_weight)
        x = self.output_layer(x, edge_index, edge_attr, edge_weight)
        return x

    def training_step(self, batch, batch_idx):
        yhat = self.forward(
            batch.x, batch.edge_index, batch.edge_attr, batch.edge_weight)
        loss = F.binary_cross_entropy_with_logits(yhat, batch.y)
        self.log(
            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation(self, batch, batch_idx):
        yhat = self.forward(
            batch.x, batch.edge_index, batch.edge_attr, batch.edge_weight)
        loss = F.binary_cross_entropy_with_logits(yhat, batch.y)
        self.log(
            "val_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return training_utils.configure_optimizers(
            parameters, self.optimizer_args, self.scheduler_args)

In [24]:
config = get_config()

In [25]:
# generate a random graph
x = torch.tensor([[2, 1], [5, 6], [3, 7]], dtype=torch.float)
y = torch.tensor([[0], [1], [0]], dtype=torch.float)
edge_index = torch.tensor(
    [[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
data = Data(x=x, y=y, edge_index=edge_index)

In [26]:
model = BinaryNodeClassifier(
    input_size=config.model.input_size,
    hidden_sizes=config.model.hidden_sizes,
    embed_size=config.model.embed_size,
    graph_layer=config.model.graph_layer,
    graph_layer_args=config.model.graph_layer_args,
    activation_name=config.model.activation_name,
    layer_norm=config.model.layer_norm,
    norm_first=config.model.norm_first,
    optimizer_args=config.optimizer,
    scheduler_args=config.scheduler
)

In [32]:
yhat = model(data.x, data.edge_index)
loss = F.binary_cross_entropy_with_logits(yhat, data.y)
print(yhat, loss)

tensor([[ 1.3063],
        [-1.1223],
        [-0.1839]], grad_fn=<AddBackward0>) tensor(1.1852, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
