In [None]:
import torch 
import torch_geometric
import os
from torch.nn import Linear 
import torch.nn.functional as F 
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, GraphConv, GATv2Conv
from torch_geometric.nn import global_mean_pool, global_max_pool, global_add_pool, SAGPooling, Set2Set

In [None]:
class GCNBlock(torch.nn.Module):
    """Single graph-convolution block with a residual (skip) connection.

    The forward pass applies, in order: graph convolution -> optional batch
    normalization -> activation -> dropout, and finally adds the block input
    back onto the result.

    Args:
        input_dim: Passed as the first argument to ``gcn_type``.
        output_dim: Accepted for interface compatibility; not used by this block.
        hidden_dim: Passed as the second argument to ``gcn_type`` and, when
            ``use_batch_norm`` is set, used as the ``BatchNorm1d`` feature size.
        dropout: Dropout probability applied after the activation.
        use_batch_norm: If True, normalize the convolution output with
            ``torch.nn.BatchNorm1d(hidden_dim)``; otherwise apply no normalization.
        gcn_type: Convolution layer class, instantiated as
            ``gcn_type(input_dim, hidden_dim, **gcn_kwargs)``.
        activation: Callable applied to the (optionally normalized) convolution output.
        gcn_kwargs: Optional dict of extra keyword arguments forwarded to ``gcn_type``.

    Note:
        The residual sum ``x + x_res`` adds the convolution output to the raw
        block input, so both tensors must have broadcast-compatible shapes.
        NOTE(review): this presumably requires ``input_dim == hidden_dim`` --
        confirm for the convolution types actually used.
    """

    def __init__(self, input_dim, output_dim, hidden_dim, dropout=0.5, use_batch_norm=False, gcn_type=GCNConv, activation = F.relu, gcn_kwargs=None):
        super().__init__()
        self.use_batch_norm = use_batch_norm
        self.dropout = dropout
        self.gcn_type = gcn_type
        self.conv = gcn_type(input_dim, hidden_dim, **(gcn_kwargs or {}))
        self.activation = activation

        if use_batch_norm:
            self.batch_norm = torch.nn.BatchNorm1d(hidden_dim)
        else:
            # Must be an *instance*: assigning the bare class would make
            # ``self.batch_norm(x)`` construct a new module instead of
            # returning ``x`` unchanged.
            self.batch_norm = torch.nn.Identity()

    def forward(self, x, edge_index):
        """Run the block on node features ``x`` with connectivity ``edge_index``.

        Returns the block input plus the transformed features (residual sum).
        """
        x_res = x  # keep the untouched input for the skip connection
        x = self.conv(x, edge_index)
        x = self.batch_norm(x)
        x = self.activation(x)
        # Dropout is only active while the module is in training mode.
        x = F.dropout(x, p=self.dropout, training=self.training)
        return x + x_res


In [None]:
class RegressionBlock(torch.nn.Module):
    """Fully connected head mapping ``input_dim`` features to ``output_dim`` outputs.

    With an empty ``hidden_dims`` the block is a single ``Linear`` layer.
    Otherwise it is a ``torch.nn.Sequential`` of
    ``Linear -> activation`` pairs (one per entry of ``hidden_dims``) followed
    by a final ``Linear`` to ``output_dim``; no activation is applied to the
    final layer's output.

    Args:
        input_dim: Size of the incoming feature vector.
        output_dim: Size of the produced output vector.
        hidden_dims: Sequence of hidden layer widths; may be empty.
        activation: Either an ``nn.Module`` or a plain callable such as
            ``F.relu``. Plain callables are wrapped in a small adapter module,
            because ``torch.nn.Sequential`` only accepts ``nn.Module`` entries.
    """

    class _Activation(torch.nn.Module):
        """Adapter that lets a plain callable be used as a ``Sequential`` entry."""

        def __init__(self, fn):
            super().__init__()
            self.fn = fn

        def forward(self, x):
            return self.fn(x)

    def __init__(self, input_dim, output_dim, hidden_dims, activation = F.relu):
        super().__init__()
        self.activation = activation
        self.hidden_dims = hidden_dims

        if len(hidden_dims) == 0:
            self.layers = Linear(input_dim, output_dim)
            return

        # Sequential rejects bare functions (the default F.relu included), so
        # wrap anything that is not already a Module.
        if isinstance(activation, torch.nn.Module):
            act_layer = activation
        else:
            act_layer = self._Activation(activation)

        widths = [input_dim, *hidden_dims]
        layers = []
        for in_dim, out_dim in zip(widths[:-1], widths[1:]):
            # NOTE(review): original author questioned whether these hidden
            # layers need a bias term -- still open.
            layers.append(Linear(in_dim, out_dim))
            layers.append(act_layer)
        layers.append(Linear(widths[-1], output_dim))
        self.layers = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw outputs of the final linear layer.

        No softmax is applied here; per the original author's note, one would
        be needed for categorical outputs but not for regression outputs.
        """
        return self.layers(x)

In [None]:
class GraphFeaturesBlock(torch.nn.Module):
    """Fully connected network applied to graph-level feature vectors.

    With an empty ``hidden_dims`` the block is a single ``Linear`` layer.
    Otherwise it is a ``torch.nn.Sequential`` of
    ``Linear -> activation`` pairs (one per entry of ``hidden_dims``) followed
    by a final ``Linear`` to ``output_dim``.

    Args:
        input_dim: Size of the incoming graph feature vector.
        output_dim: Size of the produced feature vector.
        hidden_dims: Sequence of hidden layer widths; may be empty.
        dropout: Stored on the instance as ``self.dropout``; it is not applied
            anywhere in the forward pass.
        activation: Either an ``nn.Module`` or a plain callable such as
            ``F.relu``. Plain callables are wrapped in a small adapter module,
            because ``torch.nn.Sequential`` only accepts ``nn.Module`` entries.
    """

    class _Activation(torch.nn.Module):
        """Adapter that lets a plain callable be used as a ``Sequential`` entry."""

        def __init__(self, fn):
            super().__init__()
            self.fn = fn

        def forward(self, x):
            return self.fn(x)

    def __init__(self, input_dim, output_dim, hidden_dims, dropout=0.5, activation = F.relu):
        super().__init__()
        self.dropout = dropout
        self.activation = activation
        self.hidden_dims = hidden_dims

        if len(hidden_dims) == 0:
            self.linear = Linear(input_dim, output_dim)
            return

        # Sequential rejects bare functions (the default F.relu included), so
        # wrap anything that is not already a Module.
        if isinstance(activation, torch.nn.Module):
            act_layer = activation
        else:
            act_layer = self._Activation(activation)

        widths = [input_dim, *hidden_dims]
        layers = []
        for in_dim, out_dim in zip(widths[:-1], widths[1:]):
            # NOTE(review): original author left a "check bias" reminder here.
            layers.append(Linear(in_dim, out_dim))
            layers.append(act_layer)
        layers.append(Linear(widths[-1], output_dim))
        self.linear = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Map graph-level features ``x`` through the network and return the result."""
        return self.linear(x)

`GCNModel` is the top-level model class that ties the other blocks together. Its forward pass does the following:
- passes the input through the GCN network (a succession of GCN blocks) and pools the resulting node features.
- **if** `use_graph_features=True`, applies the graph feature network to the graph-level features and concatenates the result with the pooled node features.

- passes the result through the regression net to obtain the output (dimension, boundary_id, manifold_id):
```text
x -> gcn -> pool -> concat(_, g) -> regression -> output
g ----------------> MLP_g(g) _|
```

In [None]:
class GCNModel(torch.nn.Module):
    """Container model: GCN network -> pooling -> (optional graph features) -> regression head.

    Args:
        gcn_net: Callable invoked as ``gcn_net(x, edge_index)``.
        regression_net: Callable applied to the pooled (and optionally
            concatenated) features to produce the model output.
        graph_features_net: Callable applied to ``graph_features``; only used
            when ``use_graph_features`` is True.
        pooling_layer: Callable invoked as ``pooling_layer(x, batch)``.
        use_graph_features: If True, the output of ``graph_features_net`` is
            concatenated to the pooled features along the last dimension.
    """

    def __init__(self, gcn_net, regression_net, graph_features_net, pooling_layer, use_graph_features=False):
        super().__init__()
        self.gcn_net = gcn_net
        self.regression_net = regression_net
        self.graph_features_net = graph_features_net
        self.use_graph_features = use_graph_features
        self.pooling_layer = pooling_layer

    def forward(self, x, edge_index, batch, graph_features=None):
        """Compute the model output.

        Args:
            x: Input passed to ``gcn_net``.
            edge_index: Connectivity passed to ``gcn_net``.
            batch: Second argument passed to ``pooling_layer``.
            graph_features: Input for ``graph_features_net``; required when
                the model was built with ``use_graph_features=True`` and
                ignored otherwise.

        Raises:
            ValueError: If ``use_graph_features`` is True but
                ``graph_features`` is None.
        """
        # Fail fast with a clear message instead of an opaque error from
        # inside the sub-network or torch.cat.
        if self.use_graph_features and graph_features is None:
            raise ValueError(
                "graph_features must be provided when use_graph_features=True"
            )

        x = self.gcn_net(x, edge_index)
        x = self.pooling_layer(x, batch)
        if self.use_graph_features:
            graph_features = self.graph_features_net(graph_features)
            x = torch.cat((x, graph_features), dim=-1)  # join along the feature (last) dim
        return self.regression_net(x)

        


**TODO**
- [ ] how do GCNConv, GraphConv, SAGEConv work? 
- [ ] how does GlobalAttention, Set2Set work?
- [ ] train the damn thing