<a href="https://colab.research.google.com/github/yaoCTSX/yaoCTSX/blob/main/NELL_SAT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv torch_geometric -f https://data.pyg.org/whl/torch-1.13.0+cpu.html

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.pyg.org/whl/torch-1.13.0+cpu.html
Collecting pyg_lib
  Downloading https://data.pyg.org/whl/torch-1.13.0%2Bcpu/pyg_lib-0.2.0%2Bpt113cpu-cp39-cp39-linux_x86_64.whl (626 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m626.9/626.9 KB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch_scatter
  Downloading https://data.pyg.org/whl/torch-1.13.0%2Bcpu/torch_scatter-2.1.1%2Bpt113cpu-cp39-cp39-linux_x86_64.whl (485 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.7/485.7 KB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch_sparse
  Downloading https://data.pyg.org/whl/torch-1.13.0%2Bcpu/torch_sparse-0.6.17%2Bpt113cpu-cp39-cp39-linux_x86_64.whl (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch_clus

In [None]:
import torch_geometric

In [None]:
#utils
import torch


def sparse_to_tensor(matrix):
    """Convert a sparse COO tensor into a dense 3 x N tensor.

    Rows 0 and 1 hold the row/column indices of the N stored entries and row 2
    holds their values. This lets the matrix be stored where sparse tensors
    are not supported.

    Note: the indices end up in the same tensor as the values, so they take the
    values' dtype. With float32 values, indices above 2**24 are not exact.
    """
    # Coalesce once; every coalesce() call on an uncoalesced tensor re-sorts it.
    coalesced = matrix.coalesce()
    indices = coalesced.indices()
    values = coalesced.values().unsqueeze(0)
    return torch.cat([indices, values], dim=0)


def tensor_to_sparse(matrix, size=None):
    """Convert a 3 x N tensor (indices in rows 0-1, values in row 2) to sparse COO.

    Args:
        matrix: 3 x N tensor in the layout produced by ``sparse_to_tensor``.
        size: optional full shape of the result. When omitted, PyTorch infers
            the smallest shape containing every index, so trailing all-zero
            rows/columns are not represented.

    Returns:
        A sparse COO tensor.
    """
    indices = matrix[0:2].long()
    # matrix[2] is always 1-D. matrix[2:3].squeeze() collapsed to a 0-d tensor
    # when N == 1, and sparse_coo_tensor then rejected it.
    values = matrix[2]
    if size is None:
        return torch.sparse_coo_tensor(indices, values)
    return torch.sparse_coo_tensor(indices, values, size)


def ensure_input_is_tensor(input):
    """Return ``input`` unchanged unless it is sparse, in which case convert it to 3 x N form."""
    return sparse_to_tensor(input) if input.is_sparse else input


def edge_to_node_matrix(edges, nodes, one_indexed=True):
    """Build the dense node-by-edge signed incidence matrix.

    Column ``c`` describes edge ``c = (tail, head)``: it holds -1 in the tail's
    row and +1 in the head's row. When ``one_indexed`` is truthy, node ids are
    shifted down by ``int(one_indexed)`` before being used as row indices.
    """
    shift = int(one_indexed)
    incidence = torch.zeros((len(nodes), len(edges)), dtype=torch.float)
    for col, (tail, head) in enumerate(edges):
        incidence[tail - shift][col] -= 1
        incidence[head - shift][col] += 1
    return incidence


def triangle_to_edge_matrix(triangles, edges):
    """Build the dense edge-by-triangle signed incidence matrix.

    Triangle ``[i, j, k]`` has boundary ``[i,j] + [j,k] - [i,k]``. Each face is
    looked up in ``edges`` in the listed orientation. If only the reversed edge
    exists, the coefficient's sign is flipped. A face missing in both
    orientations raises KeyError.
    """
    boundary = torch.zeros((len(edges), len(triangles)), dtype=torch.float)
    position = {edge: idx for idx, edge in enumerate(edges)}

    def _add_face(a, b, sign, col):
        # Prefer the stored (a, b) orientation; otherwise use (b, a) with the opposite sign.
        if (a, b) in position:
            boundary[position[(a, b)]][col] += sign
        else:
            boundary[position[(b, a)]][col] -= sign

    for col, (i, j, k) in enumerate(triangles):
        _add_face(i, j, 1, col)
        _add_face(j, k, 1, col)
        _add_face(i, k, -1, col)
    return boundary

In [None]:
#cochain
import torch
#from utils import sparse_to_tensor #+

def stl(t):
    """Return the shape of tensor ``t`` as a plain Python list."""
    return [*t.shape]

class CoChain:
    """Container for a 2-dimensional simplicial complex: features, boundaries and label.

    Attributes:
        X0, X1, X2: feature tensors for nodes, edges and triangles.
        b1, b2: boundary matrices in dense 3 x N (indices + values) form.
            Sparse inputs are converted on construction.
        label: label tensor.
    """

    def __init__(self, X0, X1, X2, b1, b2, label):
        self.X0 = X0
        self.X1 = X1
        self.X2 = X2
        # b1/b2 may arrive sparse or already in 3 x N form; store them as 3 x N.
        self.b1 = sparse_to_tensor(b1) if b1.is_sparse else b1
        self.b2 = sparse_to_tensor(b2) if b2.is_sparse else b2
        self.label = label

    def __repr__(self):
        name = f"CoChain(X0={stl(self.X0)}, X1={stl(self.X1)}, X2={stl(self.X2)}," \
               f" b1={stl(self.b1)}, b2={stl(self.b2)}, label={stl(self.label)})"
        return name

    # str() and repr() used to be two copies of the same format string.
    __str__ = __repr__

    def __eq__(self, other):
        """Exact comparison of every stored tensor; shapes must match too."""
        if not isinstance(other, CoChain):
            return NotImplemented

        def _same(a, b):
            # Check shapes first: torch.eq broadcasts, so e.g. a (1,) and a (3,)
            # tensor holding the same value would otherwise compare equal.
            return a.shape == b.shape and bool(torch.all(torch.eq(a, b)))

        pairs = ((self.X0, other.X0), (self.X1, other.X1), (self.X2, other.X2),
                 (self.b1, other.b1), (self.b2, other.b2), (self.label, other.label))
        return all(_same(a, b) for a, b in pairs)

In [None]:
#device
# Run on CUDA when torch can see a GPU, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
#simplicial_complex
import torch
#from utils import ensure_input_is_tensor #+
#from constants import DEVICE #+


class SimplicialComplex:
    """Features, Laplacian-style operators and label of one simplicial complex.

    ``L0``, ``L1`` and ``L2`` are stored in dense 3 x N (indices + values) form;
    sparse inputs are converted on construction. ``L1`` and ``L2`` may be None.
    ``batch`` is an optional list of batch-index tensors.
    """

    def __init__(self, X0, X1, X2, L0, L1, L2, label, batch=None):
        self.X0 = X0
        self.X1 = X1
        self.X2 = X2

        # L0, L1 and L2 can either be sparse or dense 3 x N; store them as 3 x N.
        self.L0 = ensure_input_is_tensor(L0)
        self.L1 = None if L1 is None else ensure_input_is_tensor(L1)
        self.L2 = None if L2 is None else ensure_input_is_tensor(L2)

        self.label = label
        self.batch = batch

    @staticmethod
    def _close(a, b):
        """allclose(atol=1e-5) that accepts None and refuses to broadcast shapes."""
        if a is None or b is None:
            return a is None and b is None
        return a.shape == b.shape and torch.allclose(a, b, atol=1e-5)

    def __eq__(self, other):
        pairs = [(self.X0, other.X0), (self.X1, other.X1), (self.X2, other.X2),
                 (self.L0, other.L0), (self.L1, other.L1), (self.L2, other.L2),
                 (self.label, other.label)]
        return all(self._close(a, b) for a, b in pairs)

    def unpack_features(self):
        return self.X0, self.X1, self.X2

    def unpack_laplacians(self):
        return self.L0, self.L1, self.L2

    def unpack_batch(self):
        return self.batch

    def to_device(self):
        """Move every stored tensor (and the batch index list, if any) to DEVICE in place."""
        self.X0 = self.X0.to(DEVICE)
        self.X1 = self.X1.to(DEVICE)
        self.X2 = self.X2.to(DEVICE)

        self.L0 = self.L0.to(DEVICE)
        if self.L1 is not None:
            self.L1 = self.L1.to(DEVICE)
        if self.L2 is not None:
            self.L2 = self.L2.to(DEVICE)

        # batch defaults to None; iterating over it used to raise TypeError.
        if self.batch is not None:
            self.batch = [b.to(DEVICE) for b in self.batch]

In [None]:
#nn.utils
import torch
import scipy
import scipy.sparse.linalg as spl
import numpy as np
from scipy.sparse import coo_matrix
import networkx as nx
#from utils import edge_to_node_matrix, triangle_to_edge_matrix #+
#from models.CoChain import CoChain #+
import functools
import scipy.sparse as sp


def normalise_boundary(b1, b2):
    """Return the degree-normalised boundary pair ``(D1^-1 B1 / sqrt(2), B2 D3)``.

    ``b1``/``b2`` are boundary matrices in 3 x N (indices + values) form.
    D1^-1 is diagonal with 0.5 / (row sum of |B1|); rows whose sum is zero get 0
    instead of inf. D3 is (1/3) times the identity, sized to B2's columns.
    """
    B1 = to_sparse_coo(b1)
    B2 = to_sparse_coo(b2)
    n_nodes, n_edges = B1.shape
    n_tris = B2.shape[1]

    B1_c = B1.coalesce()
    abs_B1 = torch.sparse_coo_tensor(B1_c.indices(), torch.abs(B1_c.values()), (n_nodes, n_edges))
    row_abs_sum = torch.sparse.sum(abs_B1, dim=1).to_dense()
    inv_row_sum = torch.nan_to_num(1. / row_abs_sum, nan=0., posinf=0., neginf=0.)

    node_ids = torch.tensor(list(range(n_nodes)))
    D1_inv = torch.sparse_coo_tensor(torch.stack([node_ids, node_ids], dim=0),
                                     0.5 * inv_row_sum, (n_nodes, n_nodes))

    tri_ids = list(range(n_tris))
    D3 = torch.sparse_coo_tensor(torch.tensor([tri_ids, tri_ids]),
                                 (1 / 3.) * torch.ones(n_tris), (n_tris, n_tris))

    B2D3 = torch.sparse.mm(B2, D3)
    D1invB1 = (1 / np.sqrt(2.)) * torch.sparse.mm(D1_inv, B1)
    return D1invB1, B2D3


def preprocess_features(features):
    """Row-normalize feature matrix so every non-empty row sums to 1.

    Args:
        features: 2-D matrix with a ``.sum(1)`` method (e.g. a dense NumPy
            array). The product must be accepted by ``torch.tensor``; a
            SciPy sparse input is not converted here (TODO confirm expected
            input types with callers).

    Returns:
        torch.float tensor of the normalised features. All-zero rows stay zero.
    """
    # Cast to float: np.power on an integer array with a negative exponent raises ValueError.
    rowsum = np.asarray(features.sum(1), dtype=np.float64)
    # Zero rows give 1/0 = inf; that is expected and reset to 0 just below.
    with np.errstate(divide="ignore"):
        r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    features = r_mat_inv.dot(features)
    return torch.tensor(features, dtype=torch.float)


def remove_diag_sparse(sparse_adj):
    """Keep only the strictly upper-triangular entries of a sparse torch matrix.

    Note: ``scipy.sparse.triu(..., k=1)`` drops the diagonal and everything
    below it. For a symmetric adjacency, this keeps each undirected edge once.
    """
    upper = scipy.sparse.triu(torch_sparse_to_scipy_sparse(sparse_adj), k=1)
    return scipy_sparse_to_torch_sparse(upper)


def get_features(features, sc_list):
    """Stack one feature row per simplex by folding ``+`` over its vertices' rows.

    For boolean feature rows, ``+`` acts as a logical OR. Each result is cast to
    float. An empty ``sc_list`` gives an empty tensor.
    """
    def _combine(simplex):
        return functools.reduce(lambda acc, row: acc + row, (features[v] for v in simplex)).float()

    stacked = [_combine(simplex) for simplex in sc_list]
    if not stacked:
        return torch.tensor([])
    return torch.stack(stacked, dim=0)


def filter_simplices(node_features, simplice):
    """Return True when the vertices of ``simplice`` share at least one non-zero feature.

    The vertices' feature rows are combined with logical AND, and the sum of the
    result (as float) must be positive.
    """
    shared = functools.reduce(torch.logical_and, [node_features[v] for v in simplice])
    return torch.sum(shared.float()).item() > 0


def correct_orientation(L, up_or_down):
    """Turn a sparse Laplacian into a sparse sign matrix.

    L : n * n sparse Laplacian matrix
    up_or_down : int in {-1, 1}

    2 is added to the diagonal and all values are scaled by ``up_or_down``.
    Values below -1 (the shifted diagonal when ``up_or_down == -1``) are set to
    1. Finally every value is replaced by its sign.
    """
    n = L.shape[0]
    diag = torch.arange(n)
    shift = torch.sparse_coo_tensor(torch.stack([diag, diag], dim=0), torch.full((n,), 2.))
    shifted = (L + shift).coalesce()

    signed = shifted.values() * up_or_down
    signed = torch.where(signed < -1, torch.ones_like(signed), signed)
    return torch.sparse_coo_tensor(shifted.indices(), torch.sign(signed))


def convert_to_CoChain(adj, features, labels, X1=None, X2=None):
    """Lift a graph to a CoChain with nodes, edges and triangles (3-cliques).

    Args:
        adj: sparse torch adjacency. Every stored index pair (i, j) becomes the
            oriented edge i -> j, in coalesced order.
        features: node feature matrix; its row count defines the node set.
        labels: passed through unchanged.
        X1, X2: optional edge / triangle features. They default to the edge
            list (E x 2) and the sorted triangle vertex lists (T x 3).

    Returns:
        CoChain holding B1 (nodes x edges) and B2 (edges x triangles).
    """
    X0 = features

    nodes = list(range(X0.shape[0]))
    rows, cols = adj.coalesce().indices().tolist()
    edges = list(zip(rows, cols))
    # edges = [*filter(lambda x: filter_simplices(features, x), edges)]

    g = nx.Graph()
    g.add_nodes_from(nodes)
    g.add_edges_from(edges)

    # enumerate_all_cliques yields cliques in order of increasing size. Stop at
    # the first clique larger than a triangle instead of enumerating every
    # larger clique (potentially exponentially many) only to discard it.
    triangles = []
    for clique in nx.enumerate_all_cliques(g):
        if len(clique) > 3:
            break
        if len(clique) == 3:
            triangles.append(sorted(clique))
    # triangles = [*filter(lambda x: filter_simplices(features, x), triangles)]

    b1 = edge_to_node_matrix(edges, nodes, one_indexed=False).to_sparse()
    b2 = triangle_to_edge_matrix(triangles, edges).to_sparse()

    if X1 is None:
        X1 = torch.tensor(edges)

    if X2 is None:
        X2 = torch.tensor(triangles)

    return CoChain(X0, X1, X2, b1, b2, labels)


def repair_sparse(matrix, ideal_shape):
    """Re-declare a sparse matrix's shape so it is at least ``ideal_shape``.

    Sparse ops that infer the shape from the stored indices drop trailing
    all-zero rows/columns; this restores them. Dimensions already larger than
    ``ideal_shape`` are kept.

    Returns:
        A sparse COO tensor with the same entries, values dtype and no padding entry.
    """
    i_x, i_y = ideal_shape
    coalesced = matrix.coalesce()
    size = (max(i_x, matrix.shape[0]), max(i_y, matrix.shape[1]))
    # Declare the size instead of appending a (i_x-1, i_y-1) zero entry.
    # The old padding built float indices, stored an explicit zero and could
    # promote the values' dtype.
    return torch.sparse_coo_tensor(coalesced.indices(), coalesced.values(), size)


def to_sparse_coo(matrix, size=None):
    """Convert a 3 x N tensor (indices in rows 0-1, values in row 2) to sparse COO.

    Args:
        matrix: 3 x N tensor (indices + values).
        size: optional full shape. When omitted, the shape is inferred from the
            largest stored indices.

    Returns:
        A sparse COO tensor.
    """
    indices = matrix[0:2].long()
    # matrix[2] stays 1-D even when N == 1. matrix[2:3].squeeze() became a 0-d
    # tensor there and sparse_coo_tensor rejected it.
    values = matrix[2]
    if size is None:
        return torch.sparse_coo_tensor(indices, values)
    return torch.sparse_coo_tensor(indices, values, size)


def sparse_diag_identity(n):
    """Return the n x n identity as a sparse COO tensor of float32 ones."""
    diag = torch.arange(n)
    return torch.sparse_coo_tensor(torch.stack([diag, diag]), torch.ones(n))


def sparse_diag(tensor):
    """Return a sparse COO square matrix with the entries of ``tensor`` on its diagonal."""
    pos = torch.arange(tensor.shape[0])
    return torch.sparse_coo_tensor(torch.stack((pos, pos)), tensor)


def chebyshev(L, X, k=3):
    """Chebyshev features ``[T_0 X, T_1 X, ..., T_{k-1} X]`` concatenated along dim 1.

    Uses ``T_0 X = X``, ``T_1 X = L X`` and ``T_i X = 2 L T_{i-1} X - T_{i-2} X``.

    Note: for ``k == 1`` this returns ``L X`` (T_1 only), not ``X``. That
    behaviour is kept as-is for existing callers.
    """
    if k == 1:
        return torch.sparse.mm(L, X)
    terms = [X, torch.sparse.mm(L, X)]
    for _ in range(2, k):
        # Plain tensor subtraction replaces the deprecated legacy
        # torch.sparse.FloatTensor.add(...) call.
        terms.append(2 * torch.sparse.mm(L, terms[-1]) - terms[-2])
    return torch.cat(terms, dim=1)


def torch_sparse_to_scipy_sparse(matrix):
    """Convert a 2-D torch sparse COO tensor to a ``scipy.sparse.coo_matrix`` of the same shape.

    The data is detached and copied to host memory first. This makes it work
    for tensors on any device and for tensors that require grad. It then hands
    scipy NumPy arrays in the documented ``(data, (row, col))`` form.
    """
    coalesced = matrix.coalesce()
    rows, cols = coalesced.indices().detach().cpu().numpy()
    values = coalesced.values().detach().cpu().numpy()
    return coo_matrix((values, (rows, cols)), shape=(matrix.shape[0], matrix.shape[1]))


def scipy_sparse_to_torch_sparse(matrix):
    """Convert a SciPy sparse matrix to a float32 torch sparse COO tensor.

    The declared shape of ``matrix`` is kept, so trailing all-zero rows or
    columns are not lost. The deprecated ``torch.sparse.FloatTensor(i, v)``
    constructor inferred the size from the largest stored index. Non-COO inputs
    are converted with ``tocoo()``.
    """
    coo = matrix.tocoo()
    indices = torch.tensor(np.vstack((coo.row, coo.col)), dtype=torch.long)
    values = torch.tensor(coo.data, dtype=torch.float)
    return torch.sparse_coo_tensor(indices, values, size=coo.shape)


def normalise(L):
    """Rescale a sparse torch Laplacian by its largest-magnitude eigenvalue.

    Every entry is multiplied by ``2 / lambda_max``. The diagonal is then
    replaced by ``1 - scaled_diagonal``.

    NOTE(review): this matches ``I - 2L/lambda_max`` on the diagonal but keeps
    ``+2 L_ij / lambda_max`` (not ``-2 L_ij / lambda_max``) off the diagonal.
    Confirm this is the intended normalisation.
    """
    n = L.shape[0]
    scipy_L = torch_sparse_to_scipy_sparse(L)
    lambda_max = spl.eigsh(scipy_L, k=1, which="LM", return_eigenvectors=False)[0]
    scaled = scipy_L.copy()
    scaled *= 2.0 / lambda_max
    scaled.setdiag(np.ones(n) - scaled.diagonal(0), 0)
    return scipy_sparse_to_torch_sparse(scaled)


def batch_all_feature_and_lapacian_pair(X, L_i, L_v):
    """Batch every (features, Laplacian) group with ``batch_feature_and_lapacian_pair``.

    Group ``k`` consists of ``X[k]`` (per-complex feature tensors), ``L_i[k]``
    (per-complex index tensors) and ``L_v[k]`` (per-complex value tensors).

    Returns:
        dict with keys 'features', 'lapacian_indices', 'lapacian_values' and
        'batch_index', each a list with one entry per group. Indices and values
        are kept as two separate tensors because sparse tensors can't be stored.
    """
    merged = [batch_feature_and_lapacian_pair(X[k], L_i[k], L_v[k]) for k in range(len(X))]
    features_dct = {'features': [m[0] for m in merged],
                    'lapacian_indices': [m[1] for m in merged],
                    'lapacian_values': [m[2] for m in merged],
                    'batch_index': [m[3] for m in merged]}
    return features_dct


def batch_feature_and_lapacian_pair(x_list, L_i_list, L_v_list):
    """Merge one group of complexes into a single block-diagonal batch.

    Returns:
        (features concatenated on dim 0, shifted Laplacian indices,
         concatenated Laplacian values, batch vector giving for every
         feature row the position of its complex in ``x_list``).
    """
    feature_batch = torch.cat(x_list, dim=0)
    sizes = [x.size()[0] for x in x_list]

    I_cat, V_cat = batch_sparse_matrix(L_i_list, L_v_list, sizes, sizes)

    owner = []
    for complex_id, n_rows in enumerate(sizes):
        owner.extend([complex_id] * n_rows)
    return feature_batch, I_cat, V_cat, torch.tensor(owner)


def batch_sparse_matrix(L_i_list, L_v_list, size_x, size_y):
    """Concatenate sparse matrices into one block-diagonal matrix (indices + values).

    Args:
        L_i_list: per-matrix index tensors (row 0 = rows, row 1 = cols).
        L_v_list: per-matrix value tensors.
        size_x, size_y: per-matrix row / column counts used as block offsets.

    Returns:
        (I_cat, V_cat): shifted indices concatenated on dim 1 and values on dim 0.
        The input tensors are not modified.
    """
    shifted = []
    row_offset, col_offset = 0, 0
    for idx, L_i in enumerate(list(L_i_list)):
        if idx > 0:
            row_offset += size_x[idx - 1]
            col_offset += size_y[idx - 1]
        # Shift a copy. The previous in-place `+=` wrote into the caller's
        # tensors (often views into stored complexes), so offsets piled up
        # every time the same data was batched again.
        block = L_i.clone()
        block[0] += row_offset
        block[1] += col_offset
        shifted.append(block)
    I_cat = torch.cat(shifted, dim=1)
    V_cat = torch.cat(L_v_list, dim=0)
    return I_cat, V_cat

In [None]:
#processor_template
from abc import abstractmethod, ABC
#from models.SimplicialComplex import SimplicialComplex #+

class NNProcessor(ABC):
    """Interface for model-specific preprocessing of simplicial complexes.

    A processor turns CoChain objects into a model-specific representation that
    can be collated for storage, indexed back out and batched for training.
    """

    @abstractmethod
    def process(self, CoChain):
        """Keep processing a CoChain until the result can be stored in an in-memory dataset."""

    @abstractmethod
    def collate(self, objectList: list):
        """Combine a list of objects representing our dataset into one big object to write to memory."""

    @abstractmethod
    def get(self, data: SimplicialComplex, slice : dict, idx : int):
        """Given an index and a collated object, take out the individual object."""

    @abstractmethod
    def batch(self, objectList: list):
        """Batch per-model objects and store their fields in a feature_dct.

        Returns feature_dct, label: a dictionary and a torch.tensor.
        """

    @abstractmethod
    def clean_features(self, simplicialComplex: SimplicialComplex):
        """Recombine separately stored sparse indices and values into sparse matrices.

        Torch sparse matrices cannot be used during multiprocessing, so indices
        and values are kept as separate tensors and recombined here once single
        threaded.
        """


In [None]:
#sat processor
import torch
#from models.ProcessorTemplate import NNProcessor #+
#from utils import ensure_input_is_tensor #+
#from models.nn_utils import to_sparse_coo #+
#from models.nn_utils import batch_all_feature_and_lapacian_pair, correct_orientation #+
#from models.SimplicialComplex import SimplicialComplex #+
#from constants import DEVICE


class SATComplex(SimplicialComplex):
    """SimplicialComplex whose edge operator is split into up and down parts.

    The base class's ``L1`` is stored as None. ``L1_up`` / ``L1_down`` hold the
    edge operators in dense 3 x N form.
    """

    def __init__(self, X0, X1, X2, L0, L1_up, L1_down, L2, label, batch=None):
        super().__init__(X0, X1, X2, L0, None, L2, label, batch=batch)
        self.L1_up = ensure_input_is_tensor(L1_up)
        self.L1_down = ensure_input_is_tensor(L1_down)

    def __eq__(self, other):
        names = ("X0", "X1", "X2", "L0", "L1_up", "L1_down", "L2", "label")
        # Build the full list first so every comparison runs, as before.
        checks = [torch.allclose(getattr(self, n), getattr(other, n), atol=1e-5) for n in names]
        return all(checks)

    def unpack_up_down(self):
        """Return ``[L1_up, L1_down]``."""
        return [self.L1_up, self.L1_down]

    def to_device(self):
        """Move every tensor, including the split edge operators, to DEVICE."""
        super().to_device()
        self.L1_up, self.L1_down = self.L1_up.to(DEVICE), self.L1_down.to(DEVICE)


class SATProcessor(NNProcessor):
    """NNProcessor for the SAT model: stores L0, L1_up, L1_down and L2 per complex."""

    # (attribute name == slice key, concatenation dim). Features and labels are
    # stacked along rows; 3 x N Laplacians are stacked along columns.
    _FIELDS = (("X0", 0), ("X1", 0), ("X2", 0),
               ("L0", -1), ("L1_up", -1), ("L1_down", -1), ("L2", -1),
               ("label", 0))

    def process(self, CoChain):
        """Build the sign-corrected operators of a CoChain and wrap them in a SATComplex."""
        b1, b2 = to_sparse_coo(CoChain.b1), to_sparse_coo(CoChain.b2)

        X0, X1, X2 = CoChain.X0, CoChain.X1, CoChain.X2

        L0 = correct_orientation(torch.sparse.mm(b1, b1.t()), 1)
        L1_up = correct_orientation(torch.sparse.mm(b2, b2.t()), 1)
        L1_down = correct_orientation(torch.sparse.mm(b1.t(), b1), 1)
        L2 = correct_orientation(torch.sparse.mm(b2.t(), b2), 1)

        assert X0.shape[0] == L0.shape[0]
        assert X1.shape[0] == L1_up.shape[0]
        assert X1.shape[0] == L1_down.shape[0]
        assert X2.shape[0] == L2.shape[0]

        return SATComplex(X0, X1, X2, L0, L1_up, L1_down, L2, CoChain.label)

    def collate(self, data_list):
        """Concatenate SATComplexes into one SATComplex plus slice boundaries.

        ``slices[key]`` lists cumulative offsets starting at 0. Complex ``idx``
        occupies ``[slices[key][idx], slices[key][idx + 1])`` along that field's
        concatenation dim.
        """
        pieces = {key: [] for key, _ in self._FIELDS}
        slices = {key: [0] for key, _ in self._FIELDS}
        for data in data_list:
            for key, dim in self._FIELDS:
                value = getattr(data, key)
                pieces[key].append(value)
                slices[key].append(slices[key][-1] + value.shape[dim])

        # Labels are concatenated along dim 0 (previously -1), matching the
        # shape[0]-based slices for multi-dimensional labels as well.
        merged = {key: torch.cat(pieces[key], dim=dim).cpu() for key, dim in self._FIELDS}

        data = SATComplex(merged["X0"], merged["X1"], merged["X2"], merged["L0"],
                          merged["L1_up"], merged["L1_down"], merged["L2"], merged["label"])
        return data, slices

    def get(self, data, slices, idx):
        """Extract complex ``idx`` from a collated SATComplex."""
        def _span(key):
            window = slices[key][idx:idx + 2]
            return slice(window[0], window[1])

        X0 = data.X0[_span("X0")]
        X1 = data.X1[_span("X1")]
        X2 = data.X2[_span("X2")]

        L0 = data.L0[:, _span("L0")]
        L1_up = data.L1_up[:, _span("L1_up")]
        L1_dn = data.L1_down[:, _span("L1_down")]
        L2 = data.L2[:, _span("L2")]

        label = data.label[_span("label")]

        return SATComplex(X0, X1, X2, L0, L1_up, L1_dn, L2, label)

    def batch(self, objectList):
        """Batch SATComplexes into one block-diagonal SATComplex.

        Returns ``(complex, labels)``. ``complex.batch`` holds the batch vectors
        for nodes, edges and triangles. The complex's own label is a placeholder
        ``tensor([0])``.
        """
        def _split(L):
            # 3 x N -> (2 x N indices, N values).
            # - Clone the indices so later index shifting cannot write into the stored complex.
            # - L[2] stays 1-D when N == 1; L[2:3].squeeze() turned it into a 0-d
            #   tensor and broke the concatenation below.
            return L[0:2].clone(), L[2]

        def unpack_SATComplex(sat):
            laplacians = [_split(L) for L in (sat.L0, sat.L1_up, sat.L1_down, sat.L2)]
            # X1 is listed twice so it pairs with both L1_up and L1_down.
            features = [sat.X0, sat.X1, sat.X1, sat.X2]
            return features, [i for i, _ in laplacians], [v for _, v in laplacians], sat.label

        unpacked = [unpack_SATComplex(g) for g in objectList]
        X, L_i, L_v, labels = [*zip(*unpacked)]
        X, L_i, L_v = [*zip(*X)], [*zip(*L_i)], [*zip(*L_v)]

        features_dct = batch_all_feature_and_lapacian_pair(X, L_i, L_v)

        labels = torch.cat(labels, dim=0)

        X0, X1, _, X2 = features_dct['features']
        L0, L1_u, L1_d, L2 = [torch.cat([i, v.unsqueeze(0)], dim=0)
                              for i, v in zip(features_dct['lapacian_indices'],
                                              features_dct['lapacian_values'])]

        batch = features_dct['batch_index']
        # Drop the batch vector of the duplicated X1 entry -> [nodes, edges, triangles].
        del batch[1]

        batched = SATComplex(X0, X1, X2, L0, L1_u, L1_d, L2, torch.tensor([0]), batch)
        return batched, labels

    def clean_features(self, satComplex):
        """Turn the stored 3 x N operators back into sparse COO tensors."""
        satComplex.L0 = to_sparse_coo(satComplex.L0)
        satComplex.L1_up = to_sparse_coo(satComplex.L1_up)
        satComplex.L1_down = to_sparse_coo(satComplex.L1_down)
        satComplex.L2 = to_sparse_coo(satComplex.L2)
        return satComplex

    def repair(self, satComplex):
        """No repair is needed for SAT; returns the complex unchanged."""
        return satComplex


In [None]:
#sat_model
import torch
import torch.nn as nn
from torch_geometric.nn import global_mean_pool
import torch.nn.functional as F
import functools
import warnings
from typing import Optional, Tuple

import torch
from torch import Tensor
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.parameter import Parameter


class SATLayer_orientated(nn.Module):
    """Single-head simplicial attention layer that keeps orientation signs.

    Features are projected by ``layer``. For every stored entry (i, j) of
    ``adj``, the logit is ``a_1(|h_i|) + a_2(|h_j|) + a_3(|h_i|)``. Logits are
    soft-maxed over the stored entries of row i and then multiplied by
    adj's signed value at (i, j).
    """

    def __init__(self, input_size, output_size, bias=True):
        super().__init__()
        self.a_1 = nn.Linear(output_size, 1, bias=bias)
        self.a_2 = nn.Linear(output_size, 1, bias=bias)
        self.a_3 = nn.Linear(output_size, 1, bias=bias)
        self.layer = nn.Linear(input_size, output_size, bias=bias)

    def forward(self, features, adj):
        """
        features : n * m dense matrix of feature vectors
        adj : n * n  sparse signed orientation matrix
        output : n * k dense matrix of new feature vectors (rows of adj with no
                 stored entries produce zero rows)
        """
        features = self.layer(features)
        n = features.shape[0]
        adj = adj.coalesce()
        indices = adj.indices()
        values = adj.values()
        rows, cols = indices[0, :], indices[1, :]

        magnitude = features.abs()
        # Score only the stored entries. Broadcasting a_1 + a_2.T + a_3 built a
        # dense n x n matrix (O(n^2) memory) just to read nnz values from it.
        v = (self.a_1(magnitude).squeeze(-1)[rows]
             + self.a_2(magnitude).squeeze(-1)[cols]
             + self.a_3(magnitude).squeeze(-1)[rows])
        # Explicit (n, n) size: an inferred size drops trailing rows/columns
        # without entries, which then no longer lines up with `features`.
        e = torch.sparse_coo_tensor(indices, v, (n, n))
        attention = torch.sparse.softmax(e, dim=1)
        a_v = torch.mul(attention.coalesce().values(), values)
        attention = torch.sparse_coo_tensor(indices, a_v, (n, n))

        return torch.sparse.mm(attention, features)


class SATLayer_regular(nn.Module):
    """Single-head simplicial attention layer that ignores orientation signs.

    Only the sparsity pattern of ``adj`` is used. For every stored entry (i, j),
    the logit is ``a_1(h_i) + a_2(h_j) + a_3(h_i)``, soft-maxed over the stored
    entries of row i. Output row i is the attention-weighted sum of the
    projected features h_j.
    """

    def __init__(self, input_size, output_size, bias=True):
        super().__init__()
        self.a_1 = nn.Linear(output_size, 1, bias=bias)
        self.a_2 = nn.Linear(output_size, 1, bias=bias)
        self.a_3 = nn.Linear(output_size, 1, bias=bias)
        self.layer = nn.Linear(input_size, output_size, bias=bias)

    def forward(self, features, adj):
        """
        features : n * m dense matrix of feature vectors
        adj : n * n sparse matrix; only the positions of its stored entries are used
        output : n * k dense matrix of new feature vectors (rows of adj with no
                 stored entries produce zero rows)
        """
        features = self.layer(features)
        n = features.shape[0]
        indices = adj.coalesce().indices()
        rows, cols = indices[0, :], indices[1, :]

        # Score only the stored entries. Broadcasting a_1 + a_2.T + a_3 built a
        # dense n x n matrix (O(n^2) memory) just to read nnz values from it.
        v = (self.a_1(features).squeeze(-1)[rows]
             + self.a_2(features).squeeze(-1)[cols]
             + self.a_3(features).squeeze(-1)[rows])
        # Explicit (n, n) size: an inferred size drops trailing rows/columns
        # without entries, which then no longer lines up with `features`.
        e = torch.sparse_coo_tensor(indices, v, (n, n))
        attention = torch.sparse.softmax(e, dim=1)

        return torch.sparse.mm(attention, features)

class PReLU(torch.nn.Module):
    r"""Parametric ReLU with a configurable initial slope.

    .. math::
        \text{PReLU}(x) = \max(0,x) + a * \min(0,x)

    :math:`a` is learnable. With ``num_parameters=1`` a single slope is shared
    by all channels. Otherwise there is one slope per channel: the 2nd input
    dim, as in :func:`torch.nn.functional.prelu`.

    .. note::
        weight decay should not be used when learning :math:`a` for good performance.

    Args:
        num_parameters (int): number of slopes to learn, either 1 or the number
            of input channels. Default: 1
        const_val (float): initial value of every slope. Default: 0.25
        device, dtype: factory kwargs for the slope parameter.

    Shape:
        - Input: :math:`(*)`, any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Attributes:
        weight (Tensor): the learnable slopes, of shape (:attr:`num_parameters`).

    Examples::

        >>> m = PReLU(const_val=0.1)
        >>> output = m(torch.tensor([-2.0, 3.0]))   # -> [-0.2, 3.0]
    """
    __constants__ = ['num_parameters']
    num_parameters: int

    def __init__(self, num_parameters: int = 1, const_val: float = 0.25,
                 device=None, dtype=None) -> None:
        # Initialise nn.Module before assigning any attribute on it.
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.num_parameters = num_parameters
        # empty().fill_() keeps the default floating dtype even for an int const_val.
        self.weight = Parameter(torch.empty(num_parameters, **factory_kwargs).fill_(const_val))

    def forward(self, input: Tensor) -> Tensor:
        return F.prelu(input, self.weight)

    def extra_repr(self) -> str:
        return 'num_parameters={}'.format(self.num_parameters)

#class PRELU(nn.PReLU):

    # def forward(self, input):
    #     return nn.PReLU(input, self.weight)


class PlanetoidSAT(nn.Module):
    """SAT encoder producing node embeddings from a node/edge/triangle complex.

    Node features are binarised. Edge and triangle features are the logical AND
    of their vertices' binarised features. Each dimension goes through
    ``k_heads`` summed attention heads, and edge / triangle outputs are
    projected back to nodes with B1 and B2.
    """

    def __init__(self, num_node_feats, output_size, bias=True):
        super().__init__()
        # Must equal the number of edge operators (L1_up, L1_down): layer_e is zipped with them.
        k_heads = 2
        self.layer_n = torch.nn.ModuleList([SATLayer_regular(num_node_feats, output_size, bias) for _ in range(k_heads)])
        self.layer_e = torch.nn.ModuleList([SATLayer_regular(num_node_feats, output_size, bias) for _ in range(k_heads)])
        self.layer_t = torch.nn.ModuleList([SATLayer_regular(num_node_feats, output_size, bias) for _ in range(k_heads)])
        # Not used in forward; kept so existing state_dicts / init order still match.
        self.layer_s = torch.nn.ModuleList([SATLayer_regular(num_node_feats, output_size, bias) for _ in range(k_heads)])
        self.f = PReLU()

        self.tri_layer = nn.Linear(output_size, output_size)

    @staticmethod
    def _sum_heads(outputs):
        """Sum a list of same-shaped head outputs."""
        return functools.reduce(lambda a, b: a + b, outputs)

    def forward(self, simplicialComplex, B1, B2):
        X0, X1, X2 = simplicialComplex.unpack_features()
        L0, _, L2 = simplicialComplex.unpack_laplacians()
        L1 = simplicialComplex.unpack_up_down()

        # Binarise into a new tensor. The old `X0[X0 != 0] = 1` wrote into the
        # caller's simplicialComplex.X0.
        X0 = (X0 != 0).to(X0.dtype)

        # X1 / X2 hold vertex ids per edge / triangle; their features are the
        # AND of the vertices' binarised node features.
        X1 = torch.logical_and(X0[X1[:, 0]], X0[X1[:, 1]]).float()
        X2 = torch.logical_and(X0[X2[:, 0]], torch.logical_and(X0[X2[:, 1]], X0[X2[:, 2]])).float()

        X0 = self.f(self._sum_heads([sat(X0, L0) for sat in self.layer_n]))
        X1 = self.f(self._sum_heads([sat(X1, L) for L, sat in zip(L1, self.layer_e)]))
        X2 = self.f(self._sum_heads([sat(X2, L2) for sat in self.layer_t]))

        X0 = (X0 + torch.sparse.mm(B1, X1) + torch.sparse.mm(B1, self.tri_layer(torch.sparse.mm(B2, X2)))) / 3
        return X0

In [None]:
#planetoid.logreg
import torch
import torch.nn as nn
import torch.nn.functional as F

class LogReg(nn.Module):
    """Linear classifier returning logits, used on top of frozen embeddings."""

    def __init__(self, ft_in, nb_classes):
        super().__init__()
        self.fc = nn.Linear(ft_in, nb_classes)
        self.apply(self.weights_init)

    def weights_init(self, m):
        """Xavier-uniform weights and zero bias for Linear layers; other modules are untouched."""
        if not isinstance(m, nn.Linear):
            return
        torch.nn.init.xavier_uniform_(m.weight.data)
        if m.bias is not None:
            m.bias.data.fill_(0.0)

    def forward(self, seq):
        return self.fc(seq)

In [None]:
#planetoid.DGI
#from models.nn_utils import convert_to_CoChain, torch_sparse_to_scipy_sparse, scipy_sparse_to_torch_sparse, \
#    normalise_boundary #+
import scipy
import numpy as np
import torch
import torch.nn as nn
#from constants import DEVICE #+


def convert_to_device(lst):
    """Return a new list with every tensor of ``lst`` moved to DEVICE."""
    moved = []
    for tensor in lst:
        moved.append(tensor.to(DEVICE))
    return moved


def corruption_function(simplicialComplex, processor_type, p=0.000):
    """Build the DGI negative sample: shuffled node features on a (possibly perturbed) graph.

    Node feature rows are randomly permuted. The graph is the sparsity pattern
    of ``simplicialComplex.L0``. When ``p`` is non-zero, each upper-triangular
    node pair (diagonal included) is additionally flipped, i.e. the edge is
    added or removed, with probability ``p``. The corrupted graph is lifted to a
    complex again and passed through ``processor_type`` like a real sample.

    Returns:
        (corrupted complex, normalised B1, normalised B2)
    """
    L0 = simplicialComplex.L0
    X0 = simplicialComplex.X0
    nb_nodes = X0.shape[0]
    idx = np.random.permutation(nb_nodes)
    C_X0 = X0[idx]

    # Every stored position of L0 (diagonal included) starts as -1 ...
    L0_i = L0.coalesce().indices().cpu()
    cor_adj = torch.sparse_coo_tensor(L0_i, -torch.ones(L0_i.shape[1]))

    if p != 0:
        # ... and flipped positions get +1, so after abs(): kept edge -> 1,
        # deleted edge (-1 + 1) -> 0, inserted edge -> 1 (logical xor).
        # Skipped for p == 0, where it only built an O(n^2) index set of zero flips.
        flip_i = torch.triu_indices(nb_nodes, nb_nodes, 0)
        flip_v = torch.tensor(np.random.binomial(1, p, size=(flip_i.shape[1])), dtype=torch.float)
        cor_adj = cor_adj + torch.sparse_coo_tensor(flip_i, flip_v)

    cor_adj = cor_adj.coalesce()
    cor_adj = torch.sparse_coo_tensor(cor_adj.indices(), torch.abs(cor_adj.values()))
    # Keep one orientation per edge, drop the diagonal and any deleted (zero) entries.
    cor_adj = scipy.sparse.triu(torch_sparse_to_scipy_sparse(cor_adj), k=1)
    cor_adj.eliminate_zeros()
    cor_adj = scipy_sparse_to_torch_sparse(cor_adj)

    fake_labels = torch.zeros(nb_nodes)
    cochain = convert_to_CoChain(cor_adj, C_X0, fake_labels)
    corrupted_train = processor_type.process(cochain)
    corrupted_train = processor_type.batch([corrupted_train])[0]
    corrupted_train = processor_type.clean_features(corrupted_train)
    corrupted_train = processor_type.repair(corrupted_train)

    b1, b2 = normalise_boundary(cochain.b1, cochain.b2)

    return corrupted_train, b1, b2


######################################################################################################
# This section is adopted from https://github.com/PetarV-/DGI/tree/61baf67d7052905c77bdeb28c22926f04e182362
######################################################################################################

class AvgReadout(nn.Module):
    """Average readout over dim 1, optionally weighted by a mask."""

    def __init__(self):
        super().__init__()

    def forward(self, seq, msk):
        """Mean of ``seq`` over dim 1; with ``msk``, a mask-weighted sum divided by the mask total."""
        if msk is not None:
            weights = msk.unsqueeze(-1)
            return (seq * weights).sum(1) / weights.sum()
        return seq.mean(1)


class Discriminator(nn.Module):
    """Bilinear critic that scores node embeddings against a summary vector (DGI style).

    Args:
        n_h: width shared by the node embeddings and the summary vector.
    """

    def __init__(self, n_h):
        super(Discriminator, self).__init__()
        self.f_k = nn.Bilinear(n_h, n_h, 1)

        # Move each submodule (self first, then f_k) to DEVICE and initialise Bilinear layers.
        for module in self.modules():
            self.weights_init(module.to(DEVICE))

    def weights_init(self, m):
        """Xavier-uniform weights and zero bias for ``nn.Bilinear``; other modules untouched."""
        if not isinstance(m, nn.Bilinear):
            return
        torch.nn.init.xavier_uniform_(m.weight.data)
        if m.bias is not None:
            m.bias.data.fill_(0.0)

    def forward(self, c, h_pl, h_mi):
        """Score positive and negative embeddings against the summary ``c``.

        Args:
            c: summary vector; a singleton axis is inserted at position 1 and the
                result is expanded to ``h_pl``'s shape.
            h_pl: positive (real) embeddings; squeezing axis 2 of the bilinear output
                implies a 3-D input.
            h_mi: negative (corrupted) embeddings, same layout as ``h_pl``.

        Returns:
            Positive scores followed by negative scores, concatenated along axis 1.
        """
        summary = torch.unsqueeze(c, 1).expand_as(h_pl)
        pos_scores = torch.squeeze(self.f_k(h_pl, summary), 2)
        neg_scores = torch.squeeze(self.f_k(h_mi, summary), 2)
        return torch.cat((pos_scores, neg_scores), 1)


class DGI(nn.Module):
    """Deep Graph Infomax wrapper around a simplicial-complex encoder.

    Args:
        input_size: first constructor argument of ``model``.
        output_size: second constructor argument of ``model``; also the width of the
            bilinear discriminator.
        model: encoder class, instantiated as ``model(input_size, output_size)`` and
            moved to DEVICE.
    """

    def __init__(self, input_size, output_size, model):
        super(DGI, self).__init__()
        self.model = model(input_size, output_size).to(DEVICE)
        self.read = AvgReadout()

        # NOTE(review): the reference DGI implementation squashes the summary with a
        # sigmoid (hence the attribute name); SELU is kept as-is here — confirm intent.
        self.sigm = nn.SELU()

        self.disc = Discriminator(output_size)

    def forward(self, simplicialComplex, b1, b2, processor_type):
        """Score the real complex and a corrupted copy against the real-graph summary.

        Args:
            simplicialComplex: processed complex to encode; moved to the device in place.
            b1, b2: boundary matrices for ``simplicialComplex`` (passed through to the encoder).
            processor_type: processor used by ``corruption_function`` to rebuild the
                corrupted complex.

        Returns:
            Discriminator logits: scores of the real nodes, then of the corrupted nodes,
            concatenated along axis 1.
        """
        corrupted_complex, cb1, cb2 = corruption_function(simplicialComplex, processor_type)
        simplicialComplex.to_device()
        corrupted_complex.to_device()
        cb1 = cb1.to(DEVICE)
        cb2 = cb2.to(DEVICE)

        # Leading batch axis so AvgReadout's mean over axis 1 averages over nodes.
        h_1 = self.model(simplicialComplex, b1, b2).unsqueeze(0)
        c = self.sigm(self.read(h_1, None))

        h_2 = self.model(corrupted_complex, cb1, cb2).unsqueeze(0)

        return self.disc(c, h_1, h_2)

    def embed(self, simplicialComplex, b1, b2):
        """Compute detached node embeddings and their read-out summary.

        Args:
            simplicialComplex: processed complex to encode; moved to the device in place.
            b1, b2: boundary matrices passed through to the encoder.

        Returns:
            ``(h_1, c)``. ``h_1`` holds the encoder output as returned by ``model``
            (without the batch axis added in ``forward``). ``c`` is the mean over nodes
            of a batch-of-one view of ``h_1``, i.e. the same read-out ``forward`` uses
            before its activation. Both are detached.
        """
        simplicialComplex.to_device()
        # Inference only: no autograd graph is needed for the returned embeddings.
        with torch.no_grad():
            h_1 = self.model(simplicialComplex, b1, b2)
            # Read out over the node axis as in forward(); applying AvgReadout to the
            # un-batched 2-D output would average over the feature axis instead.
            c = self.read(h_1.unsqueeze(0), None)

        return h_1.detach(), c.detach()

In [None]:
#fake_dataset
from torch_geometric.datasets import Planetoid
import torch
import networkx as nx
#from models.nn_utils import convert_to_CoChain, to_sparse_coo #+
#from utils import edge_to_node_matrix, triangle_to_edge_matrix #+


class GraphObject:
    """Minimal stand-in for a graph data object, exposing the attributes the dataset code reads.

    Attributes:
        x: node feature tensor (``gen_dataset`` passes its ``X0`` matrix here).
        edge_index: connectivity tensor.
        y: node labels.
        train_mask, val_mask, test_mask: node split masks.
    """

    def __init__(self, x, edge_index, y, train_mask, val_mask, test_mask):
        # Features and connectivity.
        self.x, self.edge_index = x, edge_index
        # Targets.
        self.y = y
        # Split masks.
        self.train_mask, self.val_mask, self.test_mask = train_mask, val_mask, test_mask

def _index_to_mask(index, n):
    """Return a boolean tensor of length ``n`` that is True exactly at the positions in ``index``."""
    mask = torch.zeros(n, dtype=torch.bool)
    # Explicit long dtype: torch.tensor([]) would be float and unusable as an index,
    # so an empty split would otherwise crash.
    mask[torch.tensor(index, dtype=torch.long)] = True
    return mask


def _clique_labels(g, n):
    """Label each of the ``n`` nodes of ``g`` by clique membership.

    Returns a list whose entry ``i`` is 2 if node ``i`` lies in a 4-clique, 1 if it
    lies in a triangle but in no 4-clique, and 0 otherwise.
    """
    labels = [0] * n
    for clique in nx.enumerate_all_cliques(g):
        size = len(clique)
        if size > 4:
            # Cliques are yielded in order of non-decreasing size, so nothing
            # after this point can affect a label.
            break
        if size == 4:
            for node in clique:
                labels[node] = 2
        elif size == 3:
            for node in clique:
                labels[node] = max(labels[node], 1)
    return labels


def _split_nodes(labels, per_class=20, val_no=300, test_no=1000):
    """Assign node indices to train / val / test by walking the nodes in index order.

    While fewer than ``per_class`` training nodes of every label have been collected,
    a node joins the training set only if its own label's quota is not yet full;
    otherwise it is left out of every split. Once all quotas are full, subsequent
    nodes fill validation (``val_no``) and then test (``test_no``). If some label has
    fewer than ``per_class`` nodes the quotas never fill and val/test stay empty.
    """
    num_classes = 3
    train_no = per_class * num_classes
    counts = [0] * num_classes
    train, val, test = [], [], []
    for node, label in enumerate(labels):
        if sum(counts) < train_no:
            if counts[label] < per_class:
                train.append(node)
                counts[label] += 1
        elif len(val) < val_no:
            val.append(node)
        elif len(test) < test_no:
            test.append(node)
    return train, val, test


def gen_dataset():
    """Build a synthetic node-classification task on the Cora citation graph.

    Labels come from clique membership: 2 for nodes in a 4-clique, 1 for nodes that are
    only in triangles, 0 otherwise. Features are an all-ones ``n x n`` matrix. Training
    takes the first 20 nodes of each label in index order; the next 300 nodes go to
    validation and the following 1000 to test.

    Returns:
        GraphObject with features ``X0``, ``edge_index`` taken from the upper triangle
        of the adjacency matrix, labels ``y`` and boolean train/val/test masks.
    """
    data = Planetoid('./data', "Cora")[0]

    n = data.x.shape[0]
    adj = torch.zeros((n, n))
    adj = adj.index_put_(tuple(data.edge_index), torch.ones(1))
    # Keep each undirected edge once (upper triangle, diagonal included).
    adj = torch.triu(adj)

    g = nx.Graph()
    g.add_nodes_from(range(n))
    g.add_edges_from(torch.nonzero(adj).tolist())

    labels = _clique_labels(g, n)
    y = torch.tensor(labels)

    train_index, val_index, test_index = _split_nodes(labels)
    train_mask = _index_to_mask(train_index, n)
    val_mask = _index_to_mask(val_index, n)
    test_mask = _index_to_mask(test_index, n)

    # Structure-only task: every node gets the same all-ones feature row.
    X0 = torch.ones((n, n))
    edge_index = torch.nonzero(adj).T

    assert (X0.shape[0] == y.shape[0])
    assert (y.shape[0] == train_mask.shape[0])
    assert (y.shape[0] == val_mask.shape[0])
    assert (y.shape[0] == test_mask.shape[0])

    return GraphObject(X0, edge_index, y, train_mask, val_mask, test_mask)



# Inside a notebook __name__ is "__main__", so this always runs. It builds the
# synthetic graph once (Planetoid downloads Cora into ./data if needed) and discards it.
if __name__ == "__main__":
    gen_dataset()

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


In [None]:
#planetoid_dataset
from torch_geometric.datasets import Planetoid
from torch_geometric.data import InMemoryDataset
import numpy as np
import torch
#from models.nn_utils import convert_to_CoChain, remove_diag_sparse, to_sparse_coo, normalise_boundary, preprocess_features #+
#from Planetoid.FakeDataset import gen_dataset #+


class PlanetoidSCDataset(InMemoryDataset):
    """Planetoid graph (or the synthetic ``'fake'`` graph) lifted to a simplicial complex.

    The processed complex is cached under ``{root}/{dataset_name}/{processor class name}``.
    Train/val/test node masks are copied from the downloaded graph object.

    Args:
        root: base data directory.
        dataset_name: a Planetoid name such as ``'Cora'``, or ``'fake'`` to use
            ``gen_dataset()``.
        processor_type: processor instance whose ``process``/``collate``/``batch``/
            ``clean_features``/``repair``/``get`` methods build and read the complex.
    """

    def __init__(self, root, dataset_name, processor_type):
        self.root = root
        self.dataset_name = dataset_name
        self.processor_type = processor_type

        folder = f"{root}/{self.dataset_name}/{processor_type.__class__.__name__}"

        # NOTE(review): the base-class constructor presumably re-assigns self.root to
        # `folder`, so download() would fetch Planetoid into the nested folder — confirm.
        super().__init__(folder)
        self.data, self.slices = torch.load(self.processed_paths[0])

    def __len__(self):
        """Number of stored complexes, derived from the ``X0`` slice boundaries."""
        return len(self.slices["X0"]) - 1

    def load_dataset(self):
        """Load and return the cached ``(data, slices)`` pair from the processed file.

        This only reads from disk; processing happens in ``process()``.
        """
        print("Loading dataset_processor from disk...")
        data, slices = torch.load(self.processed_paths[0])
        return data, slices

    @property
    def raw_file_names(self):
        # No raw files are tracked; the source graph is produced inside download().
        return []

    def download(self):
        """Fetch (or generate) the source graph and record its node split masks.

        ``process()`` reads ``self.data_download`` set here, and the label/embedding
        accessors read the ``*_split`` masks set here.
        """
        # Instantiating this will download and process the graph dataset_processor.
        if self.dataset_name == 'fake':
            self.data_download = gen_dataset()
        else:
            self.data_download = Planetoid(self.root, self.dataset_name)[0]
        nodes = self.data_download.x.shape[0]
        self.nodes = np.array([i for i in range(nodes)])

        self.test_split = self.data_download.test_mask
        self.train_split = self.data_download.train_mask
        self.val_split = self.data_download.val_mask

    @property
    def processed_file_names(self):
        return ["features.pt"]

    def process(self):
        """Convert the downloaded graph into a processed complex and save it to disk."""
        data = self.data_download
        features, edges, labels = data.x, data.edge_index, data.y
        adj_ones = torch.ones(edges.shape[1])
        adj = torch.sparse_coo_tensor(edges, adj_ones)
        # features = preprocess_features(features)
        adj = remove_diag_sparse(adj)
        dataset = convert_to_CoChain(adj, features, labels)
        dataset = [self.processor_type.process(dataset)]
        data, slices = self.processor_type.collate(dataset)
        torch.save((data, slices), self.processed_paths[0])

    def get_boundary(self, edge_list, features):
        """Return normalised boundary matrices ``(b1, b2)`` for the graph given by ``edge_list``."""
        adj_ones = torch.ones(edge_list.shape[1])
        adj = torch.sparse_coo_tensor(edge_list, adj_ones)

        # features = preprocess_features(features)
        adj = remove_diag_sparse(adj)
        cochain = convert_to_CoChain(adj, features, None)
        b1, b2 = normalise_boundary(cochain.b1, cochain.b2)
        return b1, b2

    def _get_node_subsection(self, idx_list):
        """Build and process the sub-complex induced by the nodes in ``idx_list``.

        The stored ``L0`` is densified and restricted to the selected rows/columns, and
        its strict upper triangle is used as the adjacency of the sub-complex.
        """
        dataset = self.__getitem__(0)
        idx_list = torch.tensor(idx_list)
        adj = to_sparse_coo(dataset.L0).to_dense()
        adj = torch.index_select(adj, 0, idx_list)
        adj = torch.index_select(adj, 1, idx_list)
        adj = torch.triu(adj, diagonal=1).to_sparse()
        features = dataset.X0[idx_list]
        labels = dataset.label[idx_list]
        simplicialComplex = convert_to_CoChain(adj, features, labels)
        simplicialComplex = self.processor_type.process(simplicialComplex)
        simplicialComplex = self.processor_type.batch([simplicialComplex])[0]
        simplicialComplex = self.processor_type.clean_features(simplicialComplex)
        return simplicialComplex

    def get_full(self):
        """Return the full processed complex and boundary matrices rebuilt from its ``L0`` pattern."""
        simplicialComplex = self.get(0)
        simplicialComplex = self.processor_type.batch([simplicialComplex])[0]
        simplicialComplex = self.processor_type.clean_features(simplicialComplex)
        simplicialComplex = self.processor_type.repair(simplicialComplex)
        b1, b2 = self.get_boundary(simplicialComplex.L0.coalesce().indices(), simplicialComplex.X0)
        return simplicialComplex, b1, b2

    def get_train_labels(self):
        """Labels of the training nodes."""
        simplicialComplex = self.get(0)
        return simplicialComplex.label[self.train_split]

    def get_val_labels(self):
        """Labels of the validation nodes."""
        simplicialComplex = self.get(0)
        return simplicialComplex.label[self.val_split]

    def get_test_labels(self):
        """Labels of the test nodes."""
        simplicialComplex = self.get(0)
        return simplicialComplex.label[self.test_split]

    def get_labels(self):
        """Labels of all nodes."""
        simplicialComplex = self.get(0)
        return simplicialComplex.label

    def get_train_embeds(self, embeds):
        """Rows of ``embeds`` selected by the training mask."""
        return embeds[self.train_split]

    def get_val_embeds(self, embeds):
        """Rows of ``embeds`` selected by the validation mask."""
        return embeds[self.val_split]

    def get_test_embeds(self, embeds):
        """Rows of ``embeds`` selected by the test mask."""
        return embeds[self.test_split]

    def __getitem__(self, idx):
        return self.processor_type.get(self.data, self.slices, idx)

    def get_name(self):
        """Return the dataset name given to the constructor."""
        # `self.name` is never assigned by this class, so reading it raised AttributeError.
        return self.dataset_name


In [None]:
# (processor instance, model class) pair: the training cell reads nn_mod[0] as the
# processor and passes nn_mod[1] to DGI as the encoder class.
planetoid_SAT = [SATProcessor(), PlanetoidSAT]


In [None]:
!pip install tb-nightly
!pip install future

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tb-nightly
  Downloading tb_nightly-2.13.0a20230326-py3-none-any.whl (5.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m41.6 MB/s[0m eta [36m0:00:00[0m
Collecting tensorboard-data-server<0.8.0,>=0.7.0
  Downloading tensorboard_data_server-0.7.0-py3-none-manylinux2014_x86_64.whl (6.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.6/6.6 MB[0m [31m56.0 MB/s[0m eta [36m0:00:00[0m
Collecting google-auth-oauthlib<1.1,>=0.5
  Downloading google_auth_oauthlib-1.0.0-py2.py3-none-any.whl (18 kB)
Installing collected packages: tensorboard-data-server, google-auth-oauthlib, tb-nightly
  Attempting uninstall: tensorboard-data-server
    Found existing installation: tensorboard-data-server 0.6.1
    Uninstalling tensorboard-data-server-0.6.1:
      Successfully uninstalled tensorboard-data-server-0.6.1
  Attempting unin

In [None]:

# Experiment configuration. The training cell below repeats these settings so it can
# run on its own; this cell only builds the encoder to print its architecture.
dataset = 'Cora'
dataset_features_dct = {'Cora' : 1433, 'CiteSeer' : 3703, 'PubMed' : 500, 'fake' : 2708}
dataset_classes_dct = {'Cora' : 7, 'CiteSeer' : 6, 'PubMed' : 3 , 'fake' : 3}
input_size = dataset_features_dct[dataset]
output_size = 512
nb_epochs = 50
test_epochs = 5
lr = 0.001
l2_coef = 0.0
patience = 20

# nn_mod = planetoid_GCN
# nn_mod = planetoid_GAT
# nn_mod = planetoid_SCN
# nn_mod = planetoid_SCConv
nn_mod = planetoid_SAT
# nn_mod = planetoid_SAN

# nn_mod is a (processor instance, model class) pair.
processor_type = nn_mod[0]
model = nn_mod[1]

dgi = DGI(input_size, output_size, model)
print(dgi)

DGI(
  (model): PlanetoidSAT(
    (layer_n): ModuleList(
      (0): SATLayer_regular(
        (a_1): Linear(in_features=512, out_features=1, bias=True)
        (a_2): Linear(in_features=512, out_features=1, bias=True)
        (a_3): Linear(in_features=512, out_features=1, bias=True)
        (layer): Linear(in_features=1433, out_features=512, bias=True)
      )
      (1): SATLayer_regular(
        (a_1): Linear(in_features=512, out_features=1, bias=True)
        (a_2): Linear(in_features=512, out_features=1, bias=True)
        (a_3): Linear(in_features=512, out_features=1, bias=True)
        (layer): Linear(in_features=1433, out_features=512, bias=True)
      )
    )
    (layer_e): ModuleList(
      (0): SATLayer_regular(
        (a_1): Linear(in_features=512, out_features=1, bias=True)
        (a_2): Linear(in_features=512, out_features=1, bias=True)
        (a_3): Linear(in_features=512, out_features=1, bias=True)
        (layer): Linear(in_features=1433, out_features=512, bias=True)


In [None]:
#from Planetoid.PlanetoidDataset import PlanetoidSCDataset #+
#from models import planetoid_GCN, planetoid_GAT, planetoid_SCN, planetoid_SCConv, planetoid_SAN, planetoid_SAT
import torch.nn as nn
import torch
#from Planetoid.DGI import DGI #+
#from Planetoid.logreg import LogReg #+
#from constants import DEVICE #+

# Experiment configuration (repeated here so this cell can be re-run on its own).
dataset = 'Cora'
dataset_features_dct = {'Cora' : 1433, 'CiteSeer' : 3703, 'PubMed' : 500, 'fake' : 2708}
dataset_classes_dct = {'Cora' : 7, 'CiteSeer' : 6, 'PubMed' : 3 , 'fake' : 3}
input_size = dataset_features_dct[dataset]
output_size = 512
nb_epochs = 50
test_epochs = 5
lr = 0.001
l2_coef = 0.0
patience = 20

# nn_mod = planetoid_GCN
# nn_mod = planetoid_GAT
# nn_mod = planetoid_SCN
# nn_mod = planetoid_SCConv
nn_mod = planetoid_SAT
# nn_mod = planetoid_SAN

# nn_mod is a (processor instance, model class) pair.
processor_type = nn_mod[0]
model = nn_mod[1]

dgi = DGI(input_size, output_size, model)
optimiser = torch.optim.Adam(dgi.parameters(), lr=lr, weight_decay=l2_coef)
b_xent = nn.BCEWithLogitsLoss()
xent = nn.CrossEntropyLoss()

if __name__ == "__main__":

    data = PlanetoidSCDataset('./data', dataset, processor_type)

    # NOTE(review): replaces the slices dict with a one-element list holding its last
    # (key, value) pair; presumably a workaround so the base-class get(0) returns the
    # whole stored complex — confirm before changing.
    for v in data.slices.items():
        data.slices = [v]

    data_full, b1, b2 = data.get_full()

    cnt_wait = 0
    best = 1e9
    best_t = 0
    bl = False
    b1 = b1.to(DEVICE)
    b2 = b2.to(DEVICE)

    # Discriminator targets: 1 for every real node, 0 for every corrupted node (the
    # order in which the discriminator concatenates its scores). They do not change
    # between epochs, so build them once.
    nb_nodes = data_full.X0.shape[0]
    lbl_1 = torch.ones(1, nb_nodes)
    lbl_2 = torch.zeros(1, nb_nodes)
    lbl = torch.cat((lbl_1, lbl_2), 1).to(DEVICE)

    for epoch in range(nb_epochs):
        dgi.train()
        optimiser.zero_grad()

        logits = dgi(data_full, b1, b2, processor_type)
        loss = b_xent(logits, lbl)

        print('Loss:', loss)

        # Track the best loss as a Python float so `best` does not hold a device
        # tensor (and its autograd history) across epochs.
        loss_value = loss.item()
        if loss_value < best:
            best = loss_value
            best_t = epoch
            cnt_wait = 0
            torch.save(dgi.state_dict(), f'./data/{model.__name__}_dgi.pkl')
            if epoch != 0:
                bl = True
        else:
            # Patience only counts once an improvement after epoch 0 has been seen.
            if bl:
                cnt_wait += 1

        if cnt_wait == patience:
            print('Early stopping!')
            break

        loss.backward()
        optimiser.step()

    print('Loading {}th epoch'.format(best_t))
    dgi.load_state_dict(torch.load(f'./data/{model.__name__}_dgi.pkl'))

    embeds, _ = dgi.embed(data_full, b1, b2)
    # embeds = data_full.X0.to(DEVICE)
    # output_size = 79
    # with open("./embeddings.py", 'w') as f:
    #     f.write(f'embeddings = {embeds.tolist()}')
    # with open("./labels.py", 'w') as f:
    #     f.write(f'labels = {data.get_labels().tolist()}')
    train_embs = data.get_train_embeds(embeds)
    val_embs = data.get_val_embeds(embeds)
    test_embs = data.get_test_embeds(embeds)

    train_lbls = data.get_train_labels().to(DEVICE)
    # Per-class counts of the training labels (diagnostic only; not used below).
    x_unique = train_lbls.unique(sorted=True)
    x_unique_count = torch.stack([(train_lbls == x_u).sum() for x_u in x_unique])
    val_lbls = data.get_val_labels().to(DEVICE)
    test_lbls = data.get_test_labels().to(DEVICE)

    tot = torch.zeros(1).to(DEVICE)

    accs = []

    # Linear-probe evaluation: fit a fresh LogReg classifier `test_epochs` times on the
    # frozen embeddings and record the test accuracy of each run.
    for _ in range(test_epochs):
        log = LogReg(output_size, dataset_classes_dct[dataset])
        # Move the probe before constructing its optimiser (the order torch.optim recommends).
        log.to(DEVICE)
        opt = torch.optim.Adam(log.parameters(), lr=0.01, weight_decay=0.0)

        for _ in range(100):
            log.train()
            opt.zero_grad()

            logits = log(train_embs)
            loss = xent(logits, train_lbls)

            loss.backward()
            opt.step()

        # Inference only; no gradients needed for the test predictions.
        with torch.no_grad():
            logits = log(test_embs)
        preds = torch.argmax(logits, dim=1)
        acc = torch.sum(preds == test_lbls).float() / test_lbls.shape[0]
        accs.append(acc * 100)
        print(model.__name__)
        print(acc)
        tot += acc

    print('Average accuracy:', tot / test_epochs)

    accs = torch.stack(accs)
    print(accs.mean())
    print(accs.std())

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!
Processing...
Done!


Loss: tensor(0.6931, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
Loss: tensor(0.6929, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
Loss: tensor(0.6920, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
Loss: tensor(0.6901, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
Loss: tensor(0.6868, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
Loss: tensor(0.6805, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
Loss: tensor(0.6724, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
Loss: tensor(0.6580, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
Loss: tensor(0.6437, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
Loss: tensor(0.6246, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
Loss: tensor(0.6171, device='cuda:0',
       grad_fn=<Binary

In [None]:
# Fresh dataset instance: its `slices` are loaded straight from the processed file,
# not the patched copy created in the training cell.
data = PlanetoidSCDataset(root='./data', dataset_name=dataset, processor_type=processor_type)
print(data)

PlanetoidSCDataset()


In [None]:
# Print the dataset object's string form.
print(data)

PlanetoidSCDataset()
