In [1]:
import math
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.aggr import Aggregation
from torch_geometric.utils import softmax
from torch import Tensor
from torch.nn import (
    BatchNorm1d,
    Dropout,
    InstanceNorm1d,
    LayerNorm,
    ReLU,
    Sequential,
)
from torch_geometric.nn.dense.linear import Linear


# Default compute device for the notebook: the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def fill_triangular_torch(x):
    """Arrange a vector of ``n * (n + 1) / 2`` entries into an ``(n, n)`` matrix.

    The matrix is built by concatenating ``x[n:]`` with ``x`` reversed and
    reshaping to ``(n, n)``. For ``x = [1, 2, 3, 4, 5, 6]`` this gives::

        [[4, 5, 6],
         [6, 5, 4],
         [3, 2, 1]]

    The lower triangle (diagonal included) holds every entry of ``x`` exactly
    once. The upper triangle is NOT zeroed: it contains repeated entries, so
    the result is a full matrix rather than a strictly triangular one.

    Args:
        x: tensor whose first dimension has a triangular-number length
            ``m = n * (n + 1) / 2``.

    Returns:
        Tensor of shape ``(n, n)``.

    Raises:
        ValueError: if ``m`` is not a triangular number.
    """
    m = x.shape[0]  # should be n * (n + 1) / 2
    # Solve n * (n + 1) / 2 == m in exact integer arithmetic.
    n = (math.isqrt(8 * m + 1) - 1) // 2
    if n * (n + 1) // 2 != m:
        raise ValueError(
            f'Expected a triangular-number length n * (n + 1) / 2, got {m}.')

    # The original offset m - (n**2 - m) simplifies to n; a plain Python int
    # is enough for the slice and avoids allocating a tensor on every call.
    x_tail = x[n:]

    return torch.cat([x_tail, torch.flip(x, [0])], 0).reshape(n, n)

def fill_diagonal_torch(a, val):
    """Overwrite the main diagonal of ``a`` with ``val`` in place and return ``a``.

    The diagonal positions are taken over the last two axes of ``a``; their
    count comes from ``a.shape[0]``, so ``a`` is expected to be square.
    """
    diag_idx = torch.arange(a.shape[0])
    a[..., diag_idx, diag_idx] = val
    return a

def construct_fisher_matrix_multiple_torch(outputs):
    """Build a batch of positive semi-definite matrices ``L @ L^T``.

    Each row of ``outputs`` is arranged into an ``(n, n)`` matrix ``Q`` by
    ``fill_triangular_torch``. ``L`` equals ``Q`` except on the diagonal,
    where every entry ``d`` is replaced by ``softplus(d)``. The returned
    matrices are ``L @ L^T``.

    All intermediates are created from ``Q`` itself, so the computation runs
    on whatever device and dtype ``outputs`` has.

    Args:
        outputs: tensor of shape ``(batch, n * (n + 1) / 2)``.

    Returns:
        Tensor of shape ``(batch, n, n)``.
    """
    Q = torch.vmap(fill_triangular_torch)(outputs)

    diag = torch.diagonal(Q, dim1=-2, dim2=-1)
    # Subtracting (d - softplus(d)) from the diagonal leaves softplus(d)
    # there; off-diagonal entries of Q are untouched.
    correction = torch.diag_embed(diag - torch.nn.functional.softplus(diag))
    L = Q - correction

    return torch.matmul(L, L.transpose(-1, -2))



## ADD IN AN MLP TO GET US TO THE RIGHT DIMENSIONALITY
class MLP(Sequential):
    """Stack of ``Linear`` layers with optional normalisation between them.

    Every layer except the last is followed by an optional normalisation
    layer, a ``ReLU`` and a ``Dropout``. The last ``Linear`` is left bare.

    Args:
        channels: layer widths; ``channels[i - 1] -> channels[i]`` per layer.
        norm: ``'batch'``, ``'layer'``, ``'instance'`` or ``None`` / empty
            for no normalisation.
        bias: forwarded to every ``Linear``.
        dropout: probability used by every ``Dropout``.

    Raises:
        NotImplementedError: if ``norm`` is any other non-empty value and at
            least one hidden layer exists.
    """
    def __init__(self, channels: List[int], norm: Optional[str] = None,
                 bias: bool = True, dropout: float = 0.):
        final_idx = len(channels) - 1
        layers = []
        for idx in range(1, len(channels)):
            width = channels[idx]
            layers.append(Linear(channels[idx - 1], width, bias=bias))

            if idx >= final_idx:
                # The output layer gets no norm / activation / dropout.
                continue

            if norm == 'batch':
                layers.append(BatchNorm1d(width, affine=True))
            elif norm == 'layer':
                layers.append(LayerNorm(width, elementwise_affine=True))
            elif norm == 'instance':
                layers.append(InstanceNorm1d(width, affine=False))
            elif norm:
                raise NotImplementedError(
                    f'Normalization layer "{norm}" not supported.')
            layers.extend((ReLU(), Dropout(dropout)))

        super().__init__(*layers)



class FishnetsAggregation(Aggregation):
    r"""Fishnets aggregation for GNNs.

    Every input row is linearly mapped to ``n_p + n_p * (n_p + 1) / 2``
    values. The first ``n_p`` form a score vector :math:`\mathbf{t}_i`; the
    remainder parameterises a matrix :math:`\mathbf{L}_i` that gives a
    positive semi-definite Fisher matrix
    :math:`\mathbf{F}_i = \mathbf{L}_i \mathbf{L}_i^\top`. Scores and Fisher
    matrices are summed per group, an identity prior is added, and the output
    is

    .. math::
        \hat{\theta} = \Big(\mathbf{I} + \sum_i \mathbf{F}_i\Big)^{-1}
        \sum_i \mathbf{t}_i .

    If ``in_size != n_p`` the result is mapped back to ``in_size`` features
    by a second linear layer.

    Args:
        n_p (int): latent space size (length of the score vector).
        in_size (int, optional): feature size of the incoming and outgoing
            tensors. (default: ``n_p``)
        semi_grad (bool, optional): stored on the instance but not used by
            :meth:`forward`. (default: :obj:`False`)
    """
    def __init__(self, n_p: int, in_size: Optional[int] = None,
                 semi_grad: bool = False):
        super().__init__()

        self.n_p = n_p
        self.in_size = n_p if in_size is None else in_size
        self.semi_grad = semi_grad
        self.fishnets_dims = n_p + ((n_p * (n_p + 1)) // 2)

        # Imported locally on purpose: a later cell rebinds the global name
        # `Linear` to torch.nn.Linear.
        from torch_geometric.nn import Linear
        # The layers are placed on the notebook-level `device` at construction
        # so the module can be called without an explicit `.to(...)`.
        self.lin_1 = Linear(self.in_size, self.fishnets_dims,
                            bias=True).to(device)
        # Created even when in_size == n_p (it is then never called) so the
        # set of parameters does not depend on the configuration.
        self.lin_2 = Linear(n_p, self.in_size, bias=True).to(device)

    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        """Aggregate ``x`` per group into one estimate per group.

        Args:
            x: features of size ``in_size`` along the last dimension.
            index, ptr, dim_size, dim: forwarded unchanged to ``self.reduce``.

        Returns:
            Tensor with ``n_p`` features per group, or ``in_size`` features
            when ``in_size != n_p``.
        """
        # Project to n_p score entries followed by n_p * (n_p + 1) / 2
        # Fisher-factor entries.
        x = self.lin_1(x)
        score = x[..., :self.n_p]
        fisher = x[..., self.n_p:]

        # Sum the scores within each group.
        score = self.reduce(score, index, ptr, dim_size, dim, reduce='sum')

        # Build one (n_p, n_p) Fisher matrix per input row, then sum them
        # within each group; reduce works on the flattened matrices.
        fisher = construct_fisher_matrix_multiple_torch(fisher)
        fisher = self.reduce(fisher.reshape(-1, self.n_p ** 2),
                             index, ptr, dim_size, dim, reduce='sum')
        fisher = fisher.reshape(-1, self.n_p, self.n_p)

        # Identity prior, created on the same device / dtype as the data
        # rather than on the notebook-level default device.
        prior = torch.eye(self.n_p, dtype=fisher.dtype, device=fisher.device)
        fisher = fisher + prior

        # Solve fisher @ mle = score directly instead of forming the inverse.
        mle = torch.linalg.solve(fisher, score.unsqueeze(-1)).squeeze(-1)

        # If a bottleneck is used, map back to the node feature size.
        if self.in_size != self.n_p:
            mle = self.lin_2(mle)

        return mle


In [2]:
# Scratch check: indexing a range returns the element at that position.
# Defines nothing that later cells use.
range(64)[1]

1

In [3]:
# Smoke test: run the aggregation once on an all-ones batch of 300 rows.
n_p = 4
embedding_size = n_p + (n_p * (n_p + 1)) // 2  # score entries + Fisher-factor entries
print("embedding size", embedding_size)

mle = FishnetsAggregation(n_p=n_p)(torch.ones((300, n_p), device=device))

embedding size 14


In [None]:
import torch
import torch.nn.functional as F
from ogb.nodeproppred import Evaluator, PygNodePropPredDataset
from torch.nn import LayerNorm, Linear, ReLU
from tqdm import tqdm

from torch_geometric.loader import RandomNodeLoader
from torch_geometric.nn import DeepGCNLayer, GENConv
from torch_geometric.utils import scatter

# Load the ogbn-proteins node-property-prediction dataset and its index split.
# NOTE(review): `root` is an absolute, machine-specific path — adjust before
# running elsewhere.
dataset = PygNodePropPredDataset('ogbn-proteins', root='/data101/makinen/ogbn/')
splitted_idx = dataset.get_idx_split()
data = dataset[0]
data.node_species = None
# Targets are cast to float; the loss used later is BCEWithLogitsLoss.
data.y = data.y.to(torch.float)

# Initialize features of nodes by aggregating edge features: edge attributes
# are summed over all edges sharing the same `col` entry, giving one row per
# node. `row` is not used afterwards.
row, col = data.edge_index
data.x = scatter(data.edge_attr, col, dim_size=data.num_nodes, reduce='sum')

# Set split indices to masks: one boolean mask per split, stored on `data`
# as `train_mask`, `valid_mask` and `test_mask`.
for split in ['train', 'valid', 'test']:
    mask = torch.zeros(data.num_nodes, dtype=torch.bool)
    mask[splitted_idx[split]] = True
    data[f'{split}_mask'] = mask

# Training batches: 40 parts per pass, shuffled.
train_loader = RandomNodeLoader(data, num_parts=40, shuffle=True,
                                num_workers=5)


# Evaluation batches: 5 parts per pass, no `shuffle` argument given.
test_loader = RandomNodeLoader(data, num_parts=5, num_workers=5)


class DeeperGCN(torch.nn.Module):
    """Stack of ``DeepGCNLayer`` blocks whose ``GENConv`` uses Fishnets aggregation.

    Args:
        n_p: latent size passed to every ``FishnetsAggregation``.
        num_layers: number of ``DeepGCNLayer`` blocks.
        hidden_channels: width of the node / edge embeddings.
            (default: ``n_p``)
        in_channels: size of the raw node features.
            (default: ``data.x.size(-1)`` of the notebook-level ``data``)
        edge_channels: size of the raw edge features.
            (default: ``data.edge_attr.size(-1)``)
        out_channels: number of outputs per node.
            (default: ``data.y.size(-1)``)
    """
    def __init__(self, n_p, num_layers, hidden_channels=None,
                 in_channels=None, edge_channels=None, out_channels=None):
        super().__init__()

        if hidden_channels is None:
            hidden_channels = n_p

        # Fall back to the notebook-level dataset only for sizes the caller
        # did not supply, so the model can also be built without it.
        if in_channels is None:
            in_channels = data.x.size(-1)
        if edge_channels is None:
            edge_channels = data.edge_attr.size(-1)
        if out_channels is None:
            out_channels = data.y.size(-1)

        self.node_encoder = Linear(in_channels, hidden_channels)
        self.edge_encoder = Linear(edge_channels, hidden_channels)

        self.layers = torch.nn.ModuleList()
        for i in range(1, num_layers + 1):
            # The aggregation maps hidden_channels -> n_p internally and back,
            # so the conv's input and output widths are both hidden_channels.
            conv = GENConv(hidden_channels, hidden_channels,
                           aggr=FishnetsAggregation(in_size=hidden_channels, n_p=n_p),
                           t=1.0, learn_t=False,
                           num_layers=2, norm='layer')
            norm = LayerNorm(hidden_channels, elementwise_affine=True)
            act = ReLU(inplace=True)

            # ckpt_grad receives i % 3, i.e. a falsy value for every third
            # layer. NOTE(review): presumably toggles gradient checkpointing;
            # confirm against the DeepGCNLayer documentation.
            layer = DeepGCNLayer(conv, norm, act, block='res+', dropout=0.1,
                                 ckpt_grad=i % 3)
            self.layers.append(layer)

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, edge_attr):
        """Return one row of ``out_channels`` raw outputs (logits) per node."""
        x = self.node_encoder(x)
        edge_attr = self.edge_encoder(edge_attr)

        # Only the conv of the first block is applied here; its norm and
        # activation are applied after the remaining blocks instead.
        x = self.layers[0].conv(x, edge_index, edge_attr)

        for layer in self.layers[1:]:
            x = layer(x, edge_index, edge_attr)

        x = self.layers[0].act(self.layers[0].norm(x))
        x = F.dropout(x, p=0.1, training=self.training)

        return self.lin(x)

# Hyperparameters of the run.
FISHNETS_N_P = 5       # latent size of the Fishnets aggregation
HIDDEN_CHANNELS = 24   # node / edge embedding width
NUM_LAYERS = 3 # was 28
SEED = 0               # fixed seed so weight initialisation is repeatable

# Seed before the model is built so its initial weights do not change
# between kernel restarts.
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DeeperGCN(n_p=FISHNETS_N_P, num_layers=NUM_LAYERS, hidden_channels=HIDDEN_CHANNELS).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# The model returns raw logits; this loss applies the sigmoid itself.
criterion = torch.nn.BCEWithLogitsLoss()
evaluator = Evaluator('ogbn-proteins')


def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion``,
    ``train_loader`` and ``device``.

    Args:
        epoch: epoch number, used only for the progress-bar label.

    Returns:
        Mean training loss, weighted by the number of training nodes in each
        batch.
    """
    model.train()

    total_loss = total_examples = 0
    # The context manager closes the bar even if a batch raises.
    with tqdm(total=len(train_loader), position=0) as pbar:
        pbar.set_description(f'Training epoch: {epoch:04d}')

        for data in train_loader:
            optimizer.zero_grad()
            data = data.to(device)
            out = model(data.x, data.edge_index, data.edge_attr)
            loss = criterion(out[data.train_mask], data.y[data.train_mask])
            loss.backward()
            optimizer.step()

            # Count the training nodes once; it is used as the weight of this
            # batch's (mean) loss and for the final normalisation.
            num_examples = int(data.train_mask.sum())
            total_loss += float(loss) * num_examples
            total_examples += num_examples

            pbar.update(1)

    return total_loss / total_examples


@torch.no_grad()
def test(epoch=None):
    """Evaluate the model on every split and return the ROC-AUC scores.

    Uses the notebook-level ``model``, ``test_loader``, ``evaluator`` and
    ``device``.

    Args:
        epoch: epoch number for the progress-bar label. When omitted, a
            notebook-level ``epoch`` variable is used if one exists;
            otherwise the label carries no epoch number.

    Returns:
        Tuple ``(train_rocauc, valid_rocauc, test_rocauc)``.
    """
    model.eval()

    # Keep the old label when called as `test()` inside the training loop,
    # but do not fail when no global `epoch` has been defined yet.
    if epoch is None:
        epoch = globals().get('epoch')
    if epoch is None:
        description = 'Evaluating'
    else:
        description = f'Evaluating epoch: {epoch:04d}'

    splits = ('train', 'valid', 'test')
    y_true = {split: [] for split in splits}
    y_pred = {split: [] for split in splits}

    # The context manager closes the bar even if a batch raises.
    with tqdm(total=len(test_loader), position=0) as pbar:
        pbar.set_description(description)

        for data in test_loader:
            # Node features are rebuilt from the edges present in this batch,
            # whereas train() uses data.x as delivered by the loader.
            # NOTE(review): confirm this difference is intended.
            _, col = data.edge_index
            data.x = scatter(data.edge_attr, col, dim_size=data.num_nodes,
                             reduce='sum')

            data = data.to(device)
            out = model(data.x, data.edge_index, data.edge_attr)

            for split in splits:
                mask = data[f'{split}_mask']
                y_true[split].append(data.y[mask].cpu())
                y_pred[split].append(out[mask].cpu())

            pbar.update(1)

    rocauc = {
        split: evaluator.eval({
            'y_true': torch.cat(y_true[split], dim=0),
            'y_pred': torch.cat(y_pred[split], dim=0),
        })['rocauc']
        for split in splits
    }

    return rocauc['train'], rocauc['valid'], rocauc['test']

# Train for 28 epochs, keeping the per-epoch loss; evaluate on every even epoch.
losses = []
for epoch in range(1, 29):
    loss = train(epoch)
    losses.append(loss)

    if epoch % 2:
        # Odd epoch: report the training loss only.
        print(f'Loss: {loss:.4f}')
        continue

    train_rocauc, valid_rocauc, test_rocauc = test()
    print(f'Loss: {loss:.4f}, Train: {train_rocauc:.4f}, '
          f'Val: {valid_rocauc:.4f}, Test: {test_rocauc:.4f}')

Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)
Training epoch: 0001: 100%|██████████| 40/40 [00:52<00:00,  1.32s/it]


Loss: 0.4014


Training epoch: 0002: 100%|██████████| 40/40 [00:53<00:00,  1.34s/it]
Evaluating epoch: 0002: 100%|██████████| 5/5 [00:15<00:00,  3.14s/it]


Loss: 0.3175, Train: 0.7129, Val: 0.6674, Test: 0.6465


Training epoch: 0003: 100%|██████████| 40/40 [00:53<00:00,  1.34s/it]


Loss: 0.3063


Training epoch: 0004: 100%|██████████| 40/40 [00:53<00:00,  1.34s/it]
Evaluating epoch: 0004: 100%|██████████| 5/5 [00:15<00:00,  3.01s/it]


Loss: 0.2996, Train: 0.7534, Val: 0.7072, Test: 0.6812


Training epoch: 0005: 100%|██████████| 40/40 [00:53<00:00,  1.34s/it]


Loss: 0.2931


Training epoch: 0006: 100%|██████████| 40/40 [00:53<00:00,  1.33s/it]
Evaluating epoch: 0006: 100%|██████████| 5/5 [00:16<00:00,  3.38s/it]


Loss: 0.2885, Train: 0.7627, Val: 0.6956, Test: 0.6728


Training epoch: 0007: 100%|██████████| 40/40 [00:53<00:00,  1.33s/it]


Loss: 0.2851


Training epoch: 0008: 100%|██████████| 40/40 [00:53<00:00,  1.34s/it]
Evaluating epoch: 0008: 100%|██████████| 5/5 [00:16<00:00,  3.35s/it]


Loss: 0.2865, Train: 0.7741, Val: 0.7466, Test: 0.7258


Training epoch: 0009: 100%|██████████| 40/40 [00:53<00:00,  1.34s/it]


Loss: 0.2830


Training epoch: 0010: 100%|██████████| 40/40 [00:53<00:00,  1.34s/it]
Evaluating epoch: 0010: 100%|██████████| 5/5 [00:16<00:00,  3.24s/it]


Loss: 0.2870, Train: 0.7867, Val: 0.7598, Test: 0.7228


Training epoch: 0011: 100%|██████████| 40/40 [00:53<00:00,  1.33s/it]


Loss: 0.2836


Training epoch: 0012: 100%|██████████| 40/40 [00:53<00:00,  1.34s/it]
Evaluating epoch: 0012: 100%|██████████| 5/5 [00:15<00:00,  3.16s/it]


Loss: 0.2798, Train: 0.7855, Val: 0.7261, Test: 0.6974


Training epoch: 0013: 100%|██████████| 40/40 [00:53<00:00,  1.34s/it]


Loss: 0.2781


Training epoch: 0014: 100%|██████████| 40/40 [00:53<00:00,  1.34s/it]
Evaluating epoch: 0014: 100%|██████████| 5/5 [00:16<00:00,  3.25s/it]


Loss: 0.2764, Train: 0.7801, Val: 0.7304, Test: 0.6953


Training epoch: 0015: 100%|██████████| 40/40 [00:53<00:00,  1.33s/it]


Loss: 0.2749


Training epoch: 0016: 100%|██████████| 40/40 [00:53<00:00,  1.34s/it]
Evaluating epoch: 0016: 100%|██████████| 5/5 [00:16<00:00,  3.29s/it]


Loss: 0.2748, Train: 0.7901, Val: 0.7318, Test: 0.7172


Training epoch: 0017: 100%|██████████| 40/40 [00:53<00:00,  1.33s/it]


Loss: 0.2729


Training epoch: 0018: 100%|██████████| 40/40 [00:54<00:00,  1.37s/it]
Evaluating epoch: 0018: 100%|██████████| 5/5 [00:16<00:00,  3.21s/it]


Loss: 0.2741, Train: 0.7968, Val: 0.7564, Test: 0.7170


Training epoch: 0019: 100%|██████████| 40/40 [00:53<00:00,  1.34s/it]


Loss: 0.2746


Training epoch: 0020: 100%|██████████| 40/40 [00:53<00:00,  1.33s/it]
Evaluating epoch: 0020: 100%|██████████| 5/5 [00:15<00:00,  3.00s/it]


Loss: 0.2712, Train: 0.8010, Val: 0.7584, Test: 0.7170


Training epoch: 0021: 100%|██████████| 40/40 [00:53<00:00,  1.34s/it]


Loss: 0.2721


Training epoch: 0022: 100%|██████████| 40/40 [00:53<00:00,  1.34s/it]
Evaluating epoch: 0022: 100%|██████████| 5/5 [00:15<00:00,  3.00s/it]


Loss: 0.2713, Train: 0.8005, Val: 0.7637, Test: 0.7309


Training epoch: 0023: 100%|██████████| 40/40 [00:53<00:00,  1.34s/it]


Loss: 0.2700


Training epoch: 0024:  18%|█▊        | 7/40 [00:13<00:55,  1.69s/it]

In [8]:
# Display the value currently bound to the training-loop variable `epoch`.
epoch

28

In [6]:
# Re-print the most recent loss and ROC-AUC values left in the kernel by the
# training loop.
print(f'Loss: {loss:.4f}, Train: {train_rocauc:.4f}, '
          f'Val: {valid_rocauc:.4f}, Test: {test_rocauc:.4f}')

Loss: 0.2754, Train: 0.7463, Val: 0.7339, Test: 0.6979


In [7]:
# Same statement as the preceding cell; it prints the same kernel variables
# again.
print(f'Loss: {loss:.4f}, Train: {train_rocauc:.4f}, '
          f'Val: {valid_rocauc:.4f}, Test: {test_rocauc:.4f}')

Loss: 0.2754, Train: 0.7463, Val: 0.7339, Test: 0.6979


In [12]:
# Save the trained weights. Despite its name, `outdir` is passed to torch.save
# as the target file path.
# NOTE(review): the path is absolute and machine-specific, and the file name
# says "nlyr_14" while NUM_LAYERS is set to 3 in the training cell — confirm
# which configuration this checkpoint belongs to.
outdir = '/data101/makinen/graph_fishnets/models/fishnet_nc_24_nlyr_14'

torch.save(model.state_dict(), outdir)

In [14]:
# Rebuild the network with the same hyperparameters used for training so that
# every parameter shape matches the checkpoint. Passing n_p=HIDDEN_CHANNELS
# would size the Fishnets layers for a different latent dimension.
model = DeeperGCN(n_p=FISHNETS_N_P, num_layers=NUM_LAYERS,
                  hidden_channels=HIDDEN_CHANNELS).to(device)
# map_location lets the checkpoint load on the current device even if it was
# written from another one.
model.load_state_dict(torch.load(outdir, map_location=device))
model.eval()

DeeperGCN(
  (node_encoder): Linear(in_features=8, out_features=24, bias=True)
  (edge_encoder): Linear(in_features=8, out_features=24, bias=True)
  (layers): ModuleList(
    (0-13): 14 x DeepGCNLayer(block=res+)
  )
  (lin): Linear(in_features=24, out_features=112, bias=True)
)