In [95]:
import torch_geometric
import torch
from torch_geometric.nn.pool import knn_graph
import swyft

import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

In [82]:
class GCNConv(MessagePassing):
    """Experimental mean-aggregation message-passing layer for point clouds.

    Self-loops are added to ``edge_index``. Each node then receives the
    first feature column (``x[:, :1]``) of every neighbour and of itself.
    These messages are averaged (``aggr='mean'``) and a learnable bias is
    added.

    NOTE(review): ``message`` slices column 0 as the message and the earlier
    draft treated columns ``1:`` as coordinates. Read as ``[m, r...]``, this
    is an assumption to confirm with whoever builds ``x``.

    Because the message has a single column, the output is ``[N, 1] + bias``.
    For ``out_channels > 1`` this broadcasts to ``[N, out_channels]`` with
    every column sharing the same aggregated value.

    Args:
        in_channels: Input feature size of ``lin``'s first layer.
        out_channels: Size of the bias and of ``lin``'s output.
        hidden_channels: Width of ``lin``'s hidden layer (default 256).
    """

    def __init__(self, in_channels, out_channels, hidden_channels=256):
        super().__init__(aggr='mean')  # average incoming messages
        # NOTE(review): ``lin`` is not applied in ``forward`` yet. It is kept
        # so the module's parameters / state_dict keys stay the same.
        self.lin = torch.nn.Sequential(
            torch.nn.Linear(in_channels, hidden_channels),
            torch.nn.ReLU(),
            # Explicit Linear rather than LazyLinear: the input size is known,
            # and a lazy layer that is never called stays uninitialised.
            torch.nn.Linear(hidden_channels, out_channels),
        )
        # torch.Tensor(n) would be uninitialised memory; start from zeros.
        self.bias = Parameter(torch.zeros(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every layer of ``lin`` that supports it and zero the bias."""
        # nn.Sequential itself has no reset_parameters(); reset each child.
        for layer in self.lin:
            if hasattr(layer, 'reset_parameters'):
                layer.reset_parameters()
        torch.nn.init.zeros_(self.bias)

    def forward(self, x, edge_index):
        """Propagate messages over ``edge_index`` (plus self-loops) and add the bias.

        Args:
            x: Node feature matrix, indexed as ``[N, in_channels]``.
            edge_index: Graph connectivity of shape ``[2, E]``.

        Returns:
            The mean of neighbour messages plus ``self.bias``.
        """
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Symmetric GCN normalisation coefficients, one per edge. They are
        # passed to ``message`` but not used by the current message.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        out = self.propagate(edge_index, x=x, norm=norm)

        # Out-of-place add: do not mutate the tensor returned by propagate.
        return out + self.bias

    def message(self, norm, x_j, x_i):
        """Return the source node's first feature column as the message.

        PyG fills message arguments by name. ``norm`` and ``x_i`` are
        accepted but unused here. A GCN-style alternative would be
        ``norm.view(-1, 1) * x_j``.
        """
        return x_j[:, :1]

In [91]:
# Build the example layer on the GPU when available, otherwise on the CPU,
# so the notebook also runs on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
g = GCNConv(4, 1).to(device)

In [92]:
# Example usage: run the layer on a small random point cloud.
torch.manual_seed(0)  # make the example output reproducible
# 10 nodes with 4 features each, created on the same device as `g`.
points = torch.rand((10, 4), device=g.bias.device)
k = 5  # number of nearest neighbours per node
# NOTE(review): the kNN graph is built from all 4 columns, whereas
# GCNConv's message slices column 0 separately from columns 1: —
# confirm whether knn_graph(points[:, 1:], k) was intended.
edge_index = knn_graph(points, k)
gg = g(points, edge_index)
gg

tensor([[0.5617],
        [0.3783],
        [0.6214],
        [0.5144],
        [0.6214],
        [0.2190],
        [0.1937],
        [0.2190],
        [0.3588],
        [0.2714]], device='cuda:0', grad_fn=<AddBackward0>)

In [93]:
class Network(swyft.SwyftModule):
    """Swyft module that estimates 1-D log-ratios for the parameter ``d``.

    The observation ``A['x']`` is handed straight to the estimator; no
    embedding network is applied first. NOTE(review): ``num_features=1``
    presumably means ``A['x']`` carries one feature per sample — confirm.
    """

    def __init__(self):
        super().__init__()
        estimator = swyft.LogRatioEstimator_1dim(
            num_features=1,
            num_params=1,
            varnames='d',
        )
        self.logratios = estimator

    def forward(self, A, B):
        """Return log-ratio estimates for observations ``A`` and parameters ``B``."""
        features, params = A['x'], B['d']
        return self.logratios(features, params)

# NOTE(review): `samples` is not defined in any cell shown here. It must be
# created (e.g. by a swyft simulator) before this cell runs.
trainer = swyft.SwyftTrainer(accelerator='cpu', devices=1, max_epochs=10, precision=64)
# Train / validation / test split. These fractions should sum to 1
# (previously [0.7, 0.3, 0.1], which summed to 1.1).
dm = swyft.SwyftDataModule(samples, fractions=[0.6, 0.3, 0.1], num_workers=0, batch_size=64)
network = Network()
trainer.fit(network, dm)

NameError: name 'swyft' is not defined