In [2]:
%load_ext autoreload
%autoreload 2

import sys, os
sys.path.append(os.path.abspath(os.path.join('..')))

import matplotlib.pyplot as plt 
import torch
import torch.nn as nn
import seaborn as sns
import numpy as np
from scipy.stats import pearsonr
import pandas as pd
import networkx as nx
from tqdm import tqdm_notebook as tqdm
import goatools
from goatools.base import download_go_basic_obo, download_ncbi_associations
from goatools.obo_parser import GODag
from goatools.associations import read_ncbi_gene2go
from goatools.go_enrichment import GOEnrichmentStudy


from dpp.util import Params, prepare_sns, load_mapping
from dpp.data.associations import load_diseases

prepare_sns(sns, {})

In [3]:
num_nodes = 10
embedding_size = 8

In [24]:
adj = torch.randint(low=0, high=2, size=(num_nodes, num_nodes))
# enforce symmetry 
for row in range(adj.shape[0]):
    for col in range(adj.shape[1]):
        if row > col:
            adj[row, col] = adj[col, row]

In [25]:
embeddings = torch.rand(size=(num_nodes, embedding_size))
print(embeddings)
embeddings = nn.Embedding.from_pretrained(embeddings)

tensor([[0.6506, 0.7550, 0.8824, 0.7077, 0.0434, 0.7944, 0.4711, 0.3924],
        [0.7747, 0.9373, 0.3369, 0.7672, 0.4132, 0.9224, 0.5649, 0.8292],
        [0.7782, 0.4631, 0.8377, 0.0066, 0.3016, 0.8193, 0.6468, 0.5768],
        [0.4460, 0.1819, 0.7086, 0.5556, 0.9397, 0.8525, 0.1167, 0.2807],
        [0.4184, 0.2806, 0.7493, 0.4164, 0.1621, 0.7129, 0.4619, 0.1291],
        [0.3676, 0.0625, 0.3834, 0.9405, 0.7357, 0.8351, 0.2951, 0.0453],
        [0.0531, 0.7144, 0.6921, 0.1820, 0.4543, 0.0888, 0.5622, 0.2063],
        [0.7748, 0.7176, 0.3569, 0.3097, 0.5829, 0.7407, 0.2536, 0.0636],
        [0.4108, 0.2819, 0.2285, 0.9151, 0.2669, 0.4498, 0.9203, 0.8702],
        [0.7773, 0.1691, 0.2494, 0.1585, 0.0511, 0.3216, 0.5251, 0.6897]])


In [26]:
disease_node = torch.tensor(0) 
query_node = torch.tensor(4)

In [27]:
mutual_interactors = torch.nonzero(adj[disease_node] * adj[query_node]).squeeze()

In [28]:
bilinear = nn.Bilinear(embedding_size, embedding_size, embedding_size)

In [31]:
embeddings(mutual_interactors).shape

torch.Size([3, 8])

In [33]:
disease_projs = bilinear(embeddings(disease_node).unsqueeze(0).expand(mutual_interactors.shape[0], -1), embeddings(mutual_interactors))

In [35]:
query_projs = bilinear(embeddings(disease_node).unsqueeze(0).expand(mutual_interactors.shape[0], -1), embeddings(mutual_interactors))

In [39]:
print(query_projs.shape)
print(disease_projs.shape)

torch.Size([3, 8])
torch.Size([3, 8])


In [44]:
torch.bmm(disease_projs.unsqueeze(1), query_projs.unsqueeze(2)).squeeze()

tensor([5.5901, 2.4591, 3.0949], grad_fn=<SqueezeBackward0>)

In [57]:
np.random.rand()

0.2480499777325702