In [2]:
import torch


In [9]:
# Load the directed train / validation / test graphs.
# The files contain pickled PyG HeteroData objects (see the printed summaries),
# not plain tensors, so they need full unpickling. `weights_only=False` is
# passed explicitly because PyTorch 2.6 switched the default to True, which
# rejects arbitrary pickled classes.
# SECURITY: full unpickling can execute arbitrary code - only load files you trust.
train_data = torch.load('data/directed_train_graph.pt', weights_only=False)
val_data   = torch.load('data/directed_val_graph.pt', weights_only=False)
test_data  = torch.load('data/directed_test_graph.pt', weights_only=False)


In [10]:
# Show each directed split as: heading, underline, then the graph summary.
# One print call per split; sep='\n' puts every item on its own line.
print('Directed Training Data', '----------------------', train_data, sep='\n')
print('Directed Validation Data', '------------------------', val_data, sep='\n')
print('Directed Test Data', '------------------', test_data, sep='\n')


Directed Training Data
----------------------
HeteroData(
  lncRNA={ x=[1269, 2] },
  protein={ x=[11585, 2] },
  (lncRNA, interacts, protein)={
    edge_index=[2, 5119],
    edge_attr=[5119, 4],
  },
  (protein, interacts, protein)={
    edge_index=[2, 59544],
    edge_attr=[59544, 1],
  }
)
Directed Validation Data
------------------------
HeteroData(
  lncRNA={ x=[1269, 2] },
  protein={ x=[11585, 2] },
  (lncRNA, interacts, protein)={
    edge_index=[2, 640],
    edge_attr=[640, 4],
  },
  (protein, interacts, protein)={
    edge_index=[2, 7443],
    edge_attr=[7443, 1],
  }
)
Directed Test Data
------------------
HeteroData(
  lncRNA={ x=[1269, 2] },
  protein={ x=[11585, 2] },
  (lncRNA, interacts, protein)={
    edge_index=[2, 640],
    edge_attr=[640, 4],
  },
  (protein, interacts, protein)={
    edge_index=[2, 7443],
    edge_attr=[7443, 1],
  }
)


In [11]:
from torch_geometric.data import HeteroData

def make_undirected(data: HeteroData) -> HeteroData:
    """Return a new graph in which every edge type also has a reverse edge type.

    For each edge type ``(src, rel, dst)`` a companion edge type
    ``(dst, rel + '_rev', src)`` is added. Its ``edge_index`` is the forward
    ``edge_index`` with the two rows swapped, and its ``edge_attr`` (when the
    forward edge type has one) is the same tensor as the forward one.

    Args:
        data: Heterogeneous graph to augment. It is only read, never modified.

    Returns:
        A new ``HeteroData`` containing every node and edge attribute stored in
        ``data`` plus the generated reverse edge types. Attribute tensors are
        reused, not cloned.

    Raises:
        AttributeError: If an edge type has no ``edge_index``.

    Notes:
        The function is safe to apply to its own output. No reverse is
        generated for an edge type when
        * the target reverse edge type already exists in ``data``, or
        * the edge type is itself the ``'_rev'`` companion of an edge type
          that exists in ``data``.
        Without this guard a second call would create ``'_rev_rev'`` relations
        and overwrite the existing ``'_rev'`` stores.
    """
    rev_suffix = '_rev'
    new_data = HeteroData()

    # Copy every node-level attribute, not only `x`, so nothing is silently
    # dropped and node types without features do not raise.
    for node_type in data.node_types:
        for key, value in data[node_type].items():
            new_data[node_type][key] = value

    # Snapshot of the input's edge types, used to detect existing reverses.
    known_edge_types = set(data.edge_types)

    for edge_type in data.edge_types:
        src, rel, dst = edge_type
        store = data[edge_type]
        edge_index = store.edge_index
        edge_attr = store.edge_attr if 'edge_attr' in store else None

        # Forward direction: keep all attributes of the original edge type.
        for key, value in store.items():
            new_data[edge_type][key] = value

        reversed_edge_type = (dst, rel + rev_suffix, src)

        # The input already carries this reverse edge type; it is copied when
        # the loop reaches it, so do not overwrite it with a generated one.
        if reversed_edge_type in known_edge_types:
            continue

        # This edge type is itself a generated reverse of an existing edge
        # type; reversing it again would only duplicate the forward edges.
        if rel.endswith(rev_suffix) and (dst, rel[:-len(rev_suffix)], src) in known_edge_types:
            continue

        # Reverse direction: swap the source and target rows.
        new_data[reversed_edge_type].edge_index = edge_index[[1, 0], :]
        if edge_attr is not None:
            new_data[reversed_edge_type].edge_attr = edge_attr  # same attr

    return new_data


In [12]:
# Replace each split with its reverse-edge-augmented version.
# The input tuple is built before any name is rebound, so every call receives
# the graph currently bound to that name. Re-running this cell feeds the
# already-augmented graphs back into make_undirected.
train_data, val_data, test_data = (
    make_undirected(split_graph)
    for split_graph in (train_data, val_data, test_data)
)

In [8]:
# Show each augmented split as: heading, underline, then the graph summary.
# One print call per split; sep='\n' puts every item on its own line.
print('Undirected Training Data', '----------------------', train_data, sep='\n')
print('Undirected Validation Data', '------------------------', val_data, sep='\n')
print('Undirected Test Data', '------------------', test_data, sep='\n')


Undirected Training Data
----------------------
HeteroData(
  lncRNA={ x=[1269, 2] },
  protein={ x=[11585, 2] },
  (lncRNA, interacts, protein)={
    edge_index=[2, 5119],
    edge_attr=[5119, 4],
  },
  (protein, interacts_rev, lncRNA)={
    edge_index=[2, 5119],
    edge_attr=[5119, 4],
  },
  (protein, interacts, protein)={
    edge_index=[2, 59544],
    edge_attr=[59544, 1],
  },
  (protein, interacts_rev, protein)={
    edge_index=[2, 59544],
    edge_attr=[59544, 1],
  }
)
Undirected Validation Data
------------------------
HeteroData(
  lncRNA={ x=[1269, 2] },
  protein={ x=[11585, 2] },
  (lncRNA, interacts, protein)={
    edge_index=[2, 640],
    edge_attr=[640, 4],
  },
  (protein, interacts_rev, lncRNA)={
    edge_index=[2, 640],
    edge_attr=[640, 4],
  },
  (protein, interacts, protein)={
    edge_index=[2, 7443],
    edge_attr=[7443, 1],
  },
  (protein, interacts_rev, protein)={
    edge_index=[2, 7443],
    edge_attr=[7443, 1],
  }
)
Undirected Test Data
-------------

In [13]:
# Write the three augmented graphs to disk, one file per split.
torch.save(obj=train_data, f='data/undirected_train_graph.pt')
torch.save(obj=val_data, f='data/undirected_val_graph.pt')
torch.save(obj=test_data, f='data/undirected_test_graph.pt')