In [80]:
import torch
import pickle
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import math


In [103]:
def positional_encoding(d_model, depth_vec):
    size,_ = depth_vec.shape

    div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))

    pe = torch.zeros(size, d_model)

    pe[:, 0::2] = torch.sin(depth_vec * div_term)
    pe[:, 1::2] = torch.cos(depth_vec * div_term)

    return pe

def to_combined_batch(list_data, data_dict, embedding_dim):
    batch_list = []
    for (x1, x2, y) in list_data:
        # x1/x_t is conj, x2/x_s is stmt
        conj = x1
        stmt = x2

        conj_graph = data_dict[conj]
        stmt_graph = data_dict[stmt]

        # batch_list.append(
        #     LinkData(edge_index_s=stmt_graph.edge_index, x_s=stmt_graph.x, edge_attr_s=stmt_graph.edge_attr, edge_index_t=conj_graph.edge_index, x_t=conj_graph.x, edge_attr_t=conj_graph.edge_attr, y=torch.tensor(y)))

        # Concatenate node feature matrices
        combined_features = torch.cat([conj_graph.x, stmt_graph.x], dim=0)

        # Combine edge indices
        num_nodes_g1 = conj_graph.num_nodes
        edge_index1 = conj_graph.edge_index
        edge_index2 = stmt_graph.edge_index +num_nodes_g1
        combined_edge_index = torch.cat([edge_index1, edge_index2], dim=1)

        # combine edge attributes
        edge_attr1 = conj_graph.edge_attr
        edge_attr2 = stmt_graph.edge_attr +num_nodes_g1
        combined_edge_attr = torch.cat([edge_attr1, edge_attr2], dim=0)


        # Compute disjoint pairwise complete edge indices
        complete_edge_index1 = torch.cartesian_prod(torch.arange(num_nodes_g1),
                                                    torch.arange(num_nodes_g1))  # All pairs of nodes in conj_graph

        complete_edge_index2 = torch.cartesian_prod(torch.arange(num_nodes_g1, num_nodes_g1 + stmt_graph.num_nodes),
                                                    torch.arange(num_nodes_g1,
                                                                num_nodes_g1 + stmt_graph.num_nodes))  # All pairs of nodes in stmt_graph

        complete_edge_index = torch.cat([complete_edge_index1, complete_edge_index2], dim=0).t().contiguous()


        # positional encodings

        graph_ind = torch.cat([torch.ones(num_nodes_g1), torch.ones(stmt_graph.num_nodes) * 2], dim=0)
        pos_enc = positional_encoding(embedding_dim, graph_ind.unsqueeze(1))


        #append combined graph to batch
        batch_list.append(CombinedGraphData(combined_x=combined_features,combined_edge_index=combined_edge_index, combined_edge_attr=combined_edge_attr, complete_edge_index=complete_edge_index, num_nodes_g1=num_nodes_g1, pos_enc=pos_enc, y=y))

    loader = iter(DataLoader(batch_list, batch_size=len(batch_list)))

    batch = next(iter(loader))

    return batch


class CombinedGraphData(Data):
    def __init__(self, combined_x, combined_edge_index, num_nodes_g1, y, combined_edge_attr=None, complete_edge_index=None, pos_enc = None):
        super().__init__()
        self.y = y
        # node features concatenated along first dimension
        self.x = combined_x
        # adjacency matrix representing nodes from both graphs. Nodes from second graph have num_nodes_g1 added so they represent disjoint sets, but can be computed in parallel
        self.edge_index = combined_edge_index
        self.num_nodes_g1 = num_nodes_g1

        # combined edge features in format as above
        self.combined_edge_attr = combined_edge_attr

        self.complete_edge_index = complete_edge_index
        self.pos_enc = pos_enc


In [104]:
batch = to_combined_batch(train[:2], torch_graph_dict, 128)


In [105]:
batch


CombinedGraphDataBatch(y=[2], x=[117, 1000], edge_index=[2, 159], num_nodes_g1=[2], combined_edge_attr=[159], complete_edge_index=[2, 4183], pos_enc=[117, 128], batch=[117], ptr=[3])

In [106]:
batch.num_nodes_g1

tensor([53, 22])

In [123]:
batch.combined_edge_attr

tensor([ 0.,  1.,  2.,  3.,  4.,  0.,  1.,  0.,  1.,  0.,  1.,  0.,  1.,  2.,
         3.,  0.,  1.,  0.,  1.,  0.,  0.,  0.,  1.,  2.,  3.,  0.,  0.,  1.,
         0.,  0.,  0.,  1.,  0.,  1.,  0.,  1.,  0.,  1.,  2.,  3.,  4.,  0.,
         0.,  1.,  2.,  3.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  1.,  0.,  1.,
         0.,  0.,  0.,  0.,  1.,  2.,  0.,  1.,  0.,  0.,  1.,  0.,  1.,  0.,
         0.,  1.,  2.,  3.,  4.,  0.,  1.,  0.,  1.,  0.,  1.,  2.,  0.,  1.,
        53., 53., 54., 53., 53., 54., 53., 54., 53., 54., 53., 54., 53., 54.,
        53., 53., 54., 53., 54., 53., 53., 54.,  0.,  0.,  1.,  0.,  1.,  0.,
         1.,  0.,  1.,  0.,  1.,  0.,  1.,  0.,  1.,  0.,  0.,  1.,  0.,  1.,
         0.,  1.,  0.,  1.,  0.,  1.,  0.,  1.,  0.,  1.,  0.,  0.,  1.,  0.,
         1., 22., 23., 22., 22., 23., 22., 22., 23., 22., 23., 22., 22., 23.,
        22., 22., 22., 22., 23.])

In [121]:
batch.pos_enc[52] - batch.pos_enc[53]

tensor([-6.7826e-02,  9.5645e-01, -2.2533e-01,  8.0834e-01, -3.1592e-01,
         6.6081e-01, -3.5853e-01,  5.2777e-01, -3.6896e-01,  4.1455e-01,
        -3.5916e-01,  3.2171e-01, -3.3759e-01,  2.4746e-01, -3.1002e-01,
         1.8912e-01, -2.8014e-01,  1.4384e-01, -2.5028e-01,  1.0901e-01,
        -2.2177e-01,  8.2390e-02, -1.9534e-01,  6.2150e-02, -1.7131e-01,
         4.6812e-02, -1.4975e-01,  3.5220e-02, -1.3060e-01,  2.6477e-02,
        -1.1369e-01,  1.9892e-02, -9.8836e-02,  1.4938e-02, -8.5840e-02,
         1.1213e-02, -7.4498e-02,  8.4153e-03, -6.4619e-02,  6.3143e-03,
        -5.6027e-02,  4.7372e-03, -4.8562e-02,  3.5535e-03, -4.2082e-02,
         2.6655e-03, -3.6461e-02,  1.9992e-03, -3.1586e-02,  1.4994e-03,
        -2.7360e-02,  1.1245e-03, -2.3698e-02,  8.4335e-04, -2.0525e-02,
         6.3246e-04, -1.7776e-02,  4.7427e-04, -1.5395e-02,  3.5566e-04,
        -1.3332e-02,  2.6673e-04, -1.1546e-02,  1.9997e-04, -9.9988e-03,
         1.4997e-04, -8.6589e-03,  1.1247e-04, -7.4

In [120]:
batch.pos_enc[53]

tensor([ 9.0930e-01, -4.1615e-01,  9.8705e-01, -1.6044e-01,  9.9748e-01,
         7.0948e-02,  9.6323e-01,  2.6869e-01,  9.0213e-01,  4.3146e-01,
         8.2710e-01,  5.6205e-01,  7.4690e-01,  6.6493e-01,  6.6713e-01,
         7.4494e-01,  5.9113e-01,  8.0658e-01,  5.2071e-01,  8.5373e-01,
         4.5669e-01,  8.8962e-01,  3.9926e-01,  9.1684e-01,  3.4821e-01,
         9.3742e-01,  3.0314e-01,  9.5295e-01,  2.6355e-01,  9.6464e-01,
         2.2891e-01,  9.7345e-01,  1.9867e-01,  9.8007e-01,  1.7233e-01,
         9.8504e-01,  1.4942e-01,  9.8877e-01,  1.2951e-01,  9.9158e-01,
         1.1223e-01,  9.9368e-01,  9.7240e-02,  9.9526e-01,  8.4239e-02,
         9.9645e-01,  7.2970e-02,  9.9733e-01,  6.3203e-02,  9.9800e-01,
         5.4741e-02,  9.9850e-01,  4.7410e-02,  9.9888e-01,  4.1059e-02,
         9.9916e-01,  3.5558e-02,  9.9937e-01,  3.0794e-02,  9.9953e-01,
         2.6667e-02,  9.9964e-01,  2.3094e-02,  9.9973e-01,  1.9999e-02,
         9.9980e-01,  1.7318e-02,  9.9985e-01,  1.4

In [8]:
with open("/home/sean/Documents/phd/aitp/data/hol4/train_test_data.pk", "rb") as f:
    train, val, test, enc_nodes = pickle.load(f)

with open("/home/sean/Documents/phd/aitp/data/hol4/torch_graph_dict.pk", "rb") as f:
    torch_graph_dict = pickle.load(f)


In [13]:
torch_graph_dict[train[0][0]]

Data(x=[53, 1000], edge_index=[2, 84], edge_attr=[84], labels=[53])

In [53]:

print(torch_graph_dict[train[0][0]].edge_attr)
print(torch_graph_dict[train[0][1]].edge_attr)

print(torch_graph_dict[train[1][0]].edge_attr)
print(torch_graph_dict[train[1][1]].edge_attr)


print (train[1])

tensor([0., 1., 2., 3., 4., 0., 1., 0., 1., 0., 1., 0., 1., 2., 3., 0., 1., 0.,
        1., 0., 0., 0., 1., 2., 3., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0., 1.,
        0., 1., 2., 3., 4., 0., 0., 1., 2., 3., 0., 0., 0., 0., 1., 0., 0., 1.,
        0., 1., 0., 0., 0., 0., 1., 2., 0., 1., 0., 0., 1., 0., 1., 0., 0., 1.,
        2., 3., 4., 0., 1., 0., 1., 0., 1., 2., 0., 1.])
tensor([0., 0., 1., 0., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 0., 1., 0.,
        1., 0., 0., 1.])
tensor([0., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 0., 1.,
        0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 0., 1., 0., 1.])
tensor([0., 1., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 1., 0., 0., 0., 0., 1.])
('@ C$bool$ ! | Vn @ C$bool$ ! | Vm @ C$bool$ ! | Vp @ @ C$bool$ /\\ @ @ C$min$ = @ @ C$arithmetic$ <= @ @ C$arithmetic$ MIN Vm Vn Vp @ @ C$bool$ \\/ @ @ C$arithmetic$ <= Vm Vp @ @ C$arithmetic$ <= Vn Vp @ @ C$min$ = @ @ C$arithmetic$ <= Vp @ @ C$arithmetic$ MIN Vm Vn @ @ C$bool$

4183

In [59]:
batch


CombinedGraphDataBatch(y=[2], x=[117, 1000], edge_index=[2, 159], num_nodes_g1=[2], combined_edge_attr=[159], complete_edge_index=[2, 4183], batch=[117], ptr=[3])

In [50]:
batch.num_nodes_g1

tensor([53, 22])

In [28]:
batch.y

tensor([1, 0])