In [None]:
import numpy as np
import networkx as nx
import random
import torch
import matplotlib.pyplot as plt
from collections import defaultdict
from sklearn import metrics

import load_data as ld
import func
import evaluation
import algorithms.multiorg_SDNE as sdne

In [None]:
# Run configuration: the two organisms whose PPI networks are aligned.
# NOTE(review): 'cel' / 'sce' look like organism codes understood by load_data — confirm.
org_num = 2
org0 = 'cel'
org1 = 'sce'

# Seed every RNG imported by this notebook so that Restart & Run All reproduces
# the same training run. Full determinism still depends on what the project
# modules (load_data, multiorg_SDNE) do internally.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load one PPI network per organism.
# NOTE(review): k_core=0 / lcc=False presumably disable k-core pruning and
# largest-connected-component filtering — verify in load_data.load_ppi.
g0 = ld.load_ppi(org0, k_core=0, lcc=False)
g1 = ld.load_ppi(org1, k_core=0, lcc=False)

In [None]:
# Node <-> index lookup tables for both networks.
# The node->index maps are wrapped in a defaultdict so that an unknown node
# resolves to -1 instead of raising KeyError. Gotcha: a lookup of a missing key
# also inserts that key (with value -1) into the mapping.
g0_node2index = defaultdict(lambda: -1, func.node2index(g0))
g1_node2index = defaultdict(lambda: -1, func.node2index(g1))

# Reverse direction: index -> node.
g0_index2node = func.index2node(g0)
g1_index2node = func.index2node(g1)

In [None]:
# --- Ontology pairs (used later as evaluation labels) -----------------------
ontology_file = org0 + '_' + org1 + '_ontology_pairs.txt'
ontology = ld.filter_anchor(
    ld.load_go_pairs(org0, org1, ontology_file),
    g0_node2index,
    g1_node2index,
)
# NOTE(review): filter_anchor presumably drops pairs whose nodes are absent
# from either network — confirm in load_data.
print('ontology', len(ontology))

# --- Ortholog pairs (used later as anchors for cross-network training) ------
ortholog = ld.filter_anchor(ld.load_anchor(org0, org1), g0_node2index, g1_node2index)
print('ortholog', len(ortholog))

# Both indicator matrices are (number of g0 nodes) x (number of g1 nodes).
pair_matrix_shape = (len(g0.nodes()), len(g1.nodes()))

# Ortholog entries are triples; only the two node identifiers are used here.
ortholog_set = {
    (g0_node2index[src], g1_node2index[dst]) for src, dst, _ in ortholog
}
ortholog_matrix = np.zeros(pair_matrix_shape, dtype=int)
for row, col in ortholog_set:
    ortholog_matrix[row, col] = 1

# Ontology entries are plain (g0 node, g1 node) pairs.
ontology_set = {
    (g0_node2index[src], g1_node2index[dst]) for src, dst in ontology
}
ontology_matrix = np.zeros(pair_matrix_shape, dtype=int)
for row, col in ontology_set:
    ontology_matrix[row, col] = 1

In [None]:
# Keep only genes that have at least one ontology annotation.
# ontology_matrix rows index g0 nodes and columns index g1 nodes, so the row
# sums count annotated partners per g0 gene and the column sums per g1 gene.
org0_annotations = np.sum(ontology_matrix, axis=1)
org1_annotations = np.sum(ontology_matrix, axis=0)
org0_ontology_indexes = np.flatnonzero(org0_annotations > 0).tolist()
org1_ontology_indexes = np.flatnonzero(org1_annotations > 0).tolist()

# Label matrix restricted to the annotated genes of both organisms.
# (The original referenced undefined names org0_indexes / org1_indexes here.)
test_matrix = ontology_matrix[org0_ontology_indexes][:, org1_ontology_indexes]

In [None]:
# Build the shared (joint) model and one SDNE wrapper per network.
device = 'cpu'

# The joint model receives the node count of each network as its input sizes.
network_sizes = [len(g0.nodes), len(g1.nodes)]
joint = sdne.SDNEJoint(network_sizes, hidden_layers=[1024, 128], device=device)

# Move the joint model, then every encoder and every decoder, onto the device.
joint.to(device)
for encoder in joint.encoders:
    encoder.to(device)
for decoder in joint.decoders:
    decoder.to(device)

# One SDNE model per organism; the second argument is the organism's position.
models = [
    sdne.SDNE(graph, position, joint, device=device)
    for position, graph in enumerate((g0, g1))
]
model0, model1 = models

# The alignment optimizer updates the encoder parameters of both organisms.
alignment_parameters = joint.encoders_parameters(0) + joint.encoders_parameters(1)
optimizer_align = torch.optim.Adam(alignment_parameters)

# Every unordered pair (a, b) of organism positions with a < b.
indexes = []
for first in range(org_num):
    for second in range(first + 1, org_num):
        indexes.append((first, second))

In [None]:
# Training: alternate per-network fitting with a cross-network step that uses
# the ortholog pairs.
n_rounds = 5
for _ in range(n_rounds):
    # Fit each organism's model on its own network.
    for model in models:
        model.fit(batch_size=128, epochs=10, verbose=0)
    # Cross-network step over the organism pairs listed in `indexes`.
    sdne.cross_training(
        models,
        [ortholog_set],
        indexes,
        [ortholog_matrix],
        optimizer_align,
        device=device,
    )

In [None]:
# Get the node embeddings of both trained models and the pairwise score matrix.
# (The original called m0 / m1, which are never defined; the models are bound
# as model0 / model1 in the initialisation cell.)
emb0 = model0.get_embeddings()
emb1 = model1.get_embeddings()

# S[i, j] is the dot product of g0 node i's embedding with g1 node j's embedding.
# NOTE(review): assumes get_embeddings() returns a 2-D array with one row per node — confirm.
S = emb0.dot(emb1.T)

In [None]:
# Evaluate the scores on annotated genes only: S is restricted to the same rows
# and columns that were used to build test_matrix.
# (The original wrapped an undefined bare evaluate_all(...) call and used the
# undefined names org0_indexes / org1_indexes.)
# NOTE(review): assumes evaluation.evaluate_all takes (score_matrix, label_matrix) — confirm in evaluation.py.
evaluation.evaluate_all(S[org0_ontology_indexes][:, org1_ontology_indexes], test_matrix)