In [1]:
# Re-import edited local modules (e.g. model.py providing ToyNet) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import SAGEConv, to_hetero
from model import ToyNet

In [3]:
# Load ogbn-mag as a single HeteroData graph. preprocess='metapath2vec' supplies 128-dim features
# for the author / institution / field_of_study nodes, and T.ToUndirected() adds the reverse
# `rev_*` edge types visible in the summary output further down.
data = OGB_MAG(root='hetero/data', preprocess='metapath2vec', transform=T.ToUndirected())[0]

In [4]:
class GNN(torch.nn.Module):
    """Two-layer GraphSAGE network written for a homogeneous graph.

    It is meant to be passed through ``to_hetero`` so that each layer is
    duplicated per node/edge type of a heterogeneous graph.

    Args:
        hidden_channels: width of the representation after the first layer.
        out_channels: width of the output of the second layer.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # (-1, -1): source/target input sizes are inferred lazily on the first forward pass,
        # which lets the same layer adapt to node types with different feature widths.
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        hidden = self.conv1(x, edge_index).relu()
        return self.conv2(hidden, edge_index)


In [5]:
# Homogeneous two-layer GraphSAGE template (see the GNN class above).
# NOTE(review): ogbn-mag paper labels span 349 venue classes, so out_channels=2 looks like a
# placeholder — confirm before training on data.y.
model = GNN(hidden_channels=64, out_channels=2)

In [6]:
# Convert a fresh GNN (rather than whatever `model` currently holds) so re-running this cell
# never feeds an already-converted heterogeneous GraphModule back into to_hetero.
# aggr='sum' combines the per-edge-type messages arriving at the same destination node type.
model = to_hetero(GNN(hidden_channels=64, out_channels=2), data.metadata(), aggr='sum')

In [7]:
# Toy metadata for a text/image source graph, in the same (node_types, edge_types) form as
# data.metadata(). Alternative layout with a query node, kept for reference:
#   (['txt_src', 'img_src', 'ques'],
#    [('ques', 'contains', 'txt_src'), ('ques', 'contains', 'img_src')])
# NOTE(review): the img_src -> img_src relation is named 'contains' while every other relation
# is 'link1' — confirm this is intentional.
graph_meta = (
    ['txt_src', 'img_src'],
    [
        ('txt_src', 'link1', 'txt_src'),
        ('img_src', 'contains', 'img_src'),
        ('txt_src', 'link1', 'img_src'),
        ('img_src', 'link1', 'txt_src'),
    ],
)

In [8]:
# Convert a *fresh* GNN for the toy text/image graph. At this point `model` is already the
# heterogeneous GraphModule built for OGB_MAG, so passing it to to_hetero again would re-trace
# the converted module instead of the plain homogeneous GNN.
model = to_hetero(GNN(hidden_channels=64, out_channels=2), graph_meta, aggr='sum')

In [9]:
# Instantiate the project-local ToyNet model (imported from model.py).
model2 = ToyNet()

In [10]:
# Build the heterogeneous model from a fresh ToyNet so re-running this cell never wraps an
# already-converted GraphModule a second time; layers are duplicated per type in graph_meta.
model2 = to_hetero(ToyNet(), graph_meta, aggr='sum')

In [11]:
# Summarise the graph: node types with feature/label shapes and edge types with edge counts.
data

HeteroData(
  [1mpaper[0m={
    x=[736389, 128],
    year=[736389],
    y=[736389],
    train_mask=[736389],
    val_mask=[736389],
    test_mask=[736389]
  },
  [1mauthor[0m={ x=[1134649, 128] },
  [1minstitution[0m={ x=[8740, 128] },
  [1mfield_of_study[0m={ x=[59965, 128] },
  [1m(author, affiliated_with, institution)[0m={ edge_index=[2, 1043998] },
  [1m(author, writes, paper)[0m={ edge_index=[2, 7145660] },
  [1m(paper, cites, paper)[0m={ edge_index=[2, 10792672] },
  [1m(paper, has_topic, field_of_study)[0m={ edge_index=[2, 7505078] },
  [1m(institution, rev_affiliated_with, author)[0m={ edge_index=[2, 1043998] },
  [1m(paper, rev_writes, author)[0m={ edge_index=[2, 7145660] },
  [1m(field_of_study, rev_has_topic, paper)[0m={ edge_index=[2, 7505078] }
)

In [26]:
# Per-node-type feature matrices, keyed by node type (torch abbreviates large tensors with "...").
data.x_dict

{'paper': tensor([[-0.0954,  0.0408, -0.2109,  ...,  0.0616, -0.0277, -0.1338],
         [-0.1510, -0.1073, -0.2220,  ...,  0.3458, -0.0277, -0.2185],
         [-0.1148, -0.1760, -0.2606,  ...,  0.1731, -0.1564, -0.2780],
         ...,
         [ 0.0228, -0.0865,  0.0981,  ..., -0.0547, -0.2077, -0.2305],
         [-0.2891, -0.2029, -0.1525,  ...,  0.1042,  0.2041, -0.3528],
         [-0.0890, -0.0348, -0.2642,  ...,  0.2601, -0.0875, -0.5171]]),
 'author': tensor([[-0.4683,  0.1084, -0.0180,  ..., -0.2873,  0.3973,  0.0373],
         [ 0.1035, -0.3703, -0.3722,  ...,  0.5777,  0.0044, -0.3645],
         [ 0.3745,  0.0797,  0.3995,  ...,  0.0166, -0.5806, -0.1265],
         ...,
         [-0.0076,  0.6291,  0.0684,  ...,  0.0279,  0.1603, -0.0225],
         [ 0.1657, -0.1814,  0.2352,  ..., -0.4000, -0.4608, -0.7904],
         [-0.4098,  0.0470, -0.2027,  ...,  0.1393, -0.1985, -0.6175]]),
 'institution': tensor([[ 0.9148, -0.4798, -0.5734,  ...,  0.5746,  0.0610,  0.4985],
         [-

In [19]:
import numpy as np
y = np.zeros(10)
num_nodes = len(y)
node_idx = list(range(num_nodes))
# Fully connected directed graph without self-loops: every node i points to each j != i.
# Sources repeat each index (num_nodes - 1) times, matching the target order below.
source_nodes = [src for src in node_idx for _ in range(num_nodes - 1)]
target_nodes = [dst for src in node_idx for dst in node_idx if dst != src]

In [12]:
# Inspect the converted model: each layer has become a ModuleDict keyed by node type.
model2

GraphModule(
  (linq): ModuleDict(
    (txt_src): Linear(in_features=768, out_features=512, bias=True)
    (img_src): Linear(in_features=768, out_features=512, bias=True)
  )
  (blinq): ModuleDict(
    (txt_src): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (img_src): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (linc): ModuleDict(
    (txt_src): Linear(in_features=768, out_features=512, bias=True)
    (img_src): Linear(in_features=768, out_features=512, bias=True)
  )
  (blinc): ModuleDict(
    (txt_src): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (img_src): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (lini): ModuleDict(
    (txt_src): Linear(in_features=2048, out_features=512, bias=True)
    (img_src): Linear(in_features=2048, out_features=512, bias=True)
  )
  (blini): ModuleDict(
    (txt_src): BatchNorm1d(512, eps=1e

In [14]:
# Toy bipartite example: 4 source nodes, each linked to all 3 target nodes.
srcs = list(range(4))
tar = list(range(3))
num_srcs, num_targets = len(srcs), len(tar)

In [18]:
# Repeat each source position once per target: [0, 0, 0, 1, 1, 1, ...].
source_nodes = [src for src in range(num_srcs) for _ in range(num_targets)]

In [20]:
# Source row of the toy bipartite edge list.
source_nodes

[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]

In [23]:
# Cycle through target *positions* once per source ([0, 1, 2, 0, 1, 2, ...]). Using
# range(num_targets) rather than the values in `tar` keeps this consistent with how
# source_nodes is built (positions, not ids), so edge_index always holds valid node indices.
target_nodes = list(range(num_targets)) * num_srcs

In [24]:
# Two index rows -> a (2, num_edges) long tensor: row 0 holds sources, row 1 holds targets.
edge_pairs = [source_nodes, target_nodes]
edge_index = torch.tensor(edge_pairs, dtype=torch.long)

In [25]:
# Row 0 = source indices, row 1 = target indices (every source linked to every target).
edge_index

tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3],
        [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]])

In [33]:
import pickle
import json

In [34]:
# Load the WebQA train/val annotations. Decode explicitly as UTF-8: without `encoding`,
# open() uses the locale-dependent default, so text fields could fail to decode on some platforms.
with open("../model/WebQA_train_val.json", 'r', encoding='utf-8') as f:
    dump_j = json.load(f)

In [28]:
# NOTE(review): pickle.load can execute arbitrary code — only load embedding files you produced or trust.
# Presumably maps question id -> fact group -> snippet id -> embedding vector (see the lookup
# further down) — confirm.
emb_path = "../model/ImgQueries_embeddings.pkl"
with open(emb_path, 'rb') as emb_file:
    dump = pickle.load(emb_file)

In [36]:
# Peek at the 'txt_negFacts' entries (presumably distractor text facts) for one sample question id;
# each entry carries title, fact, url and snippet_id.
dump_j['d5bbc6d80dba11ecb1e81171463288e9']['txt_negFacts']

[{'title': 'Xanadu Houses',
  'fact': 'Construction of the Xanadu house in Kissimmee, Florida, began with the pouring of a concrete slab base and the erection of a tension ring 40 feet (12 m) in diameter to anchor the domed roof of what would become the "Great Room" of the house.',
  'url': 'https://en.wikipedia.org/wiki/Xanadu_Houses',
  'snippet_id': 'd5bbc6d80dba11ecb1e81171463288e9_6'},
 {'title': 'Xanadu Houses',
  'fact': 'The Xanadu house in Kissimmee, Florida used an automated system controlled by Commodore microcomputers. The house had fifteen rooms; of these the kitchen, party room, health spa, and bedrooms all used computers and other electronic equipment heavily in their design.',
  'url': 'https://en.wikipedia.org/wiki/Xanadu_Houses',
  'snippet_id': 'd5bbc6d80dba11ecb1e81171463288e9_7'},
 {'title': 'Booker T. Washington',
  'fact': "In 1946, he was honored on the first coin to feature an African American, the Booker T. Washington Memorial Half Dollar, which was minted by 

In [38]:
# Look up the stored embedding for one text fact by its snippet_id and check its shape.
dump['d5bbc6d80dba11ecb1e81171463288e9']['txt_negFacts']['d5bbc6d80dba11ecb1e81171463288e9_6'].shape

(768,)