In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from torch.nn.parameter import Parameter


In [3]:
class MolecularGraphNeuralNetwork(nn.Module):
    """Graph neural network that turns a batch of molecules into vectors.

    Each molecule is given as a sequence of fingerprint indices plus an
    adjacency matrix. Fingerprints are embedded, refined by ``layer_hidden``
    rounds of neighbourhood aggregation, and then summed per molecule.

    Args:
        N_fingerprint: number of rows in the fingerprint embedding table.
        dim: size of the fingerprint / molecule vectors.
        layer_hidden: number of update layers (one ``nn.Linear`` each).
        device: device on which the padded adjacency matrix is allocated.
    """

    def __init__(self, N_fingerprint, dim, layer_hidden, device):
        super(MolecularGraphNeuralNetwork, self).__init__()
        self.embed_fingerprint = nn.Embedding(N_fingerprint, dim)
        self.W_fingerprint = nn.ModuleList([nn.Linear(dim, dim)
                                            for _ in range(layer_hidden)])
        self.device = device
        self.layer_hidden = layer_hidden

    def pad(self, matrices, pad_value):
        """Pad the list of matrices
        with a pad_value (e.g., 0) for batch processing.
        For example, given a list of matrices [A, B, C],
        we obtain a new matrix [A00, 0B0, 00C],
        where 0 is the zero (i.e., pad value) matrix.

        Returns a float32 tensor on ``self.device`` whose shape is the sum of
        the input row counts by the sum of the input column counts.
        """
        shapes = [m.shape for m in matrices]
        total_rows = sum(s[0] for s in shapes)
        total_cols = sum(s[1] for s in shapes)
        # Allocate the filled buffer directly on the target device instead of
        # building a float64 NumPy array on the host and copying it over.
        pad_matrices = torch.full((total_rows, total_cols), pad_value,
                                  dtype=torch.float32, device=self.device)
        row, col = 0, 0
        for matrix, (m, n) in zip(matrices, shapes):
            # Place each matrix on the block diagonal.
            pad_matrices[row:row+m, col:col+n] = matrix
            row += m
            col += n
        return pad_matrices

    def update(self, matrix, vectors, layer):
        """Apply update layer ``layer``: transform, then add neighbour sums.

        ``matrix`` is multiplied with the transformed vectors, so row ``i`` of
        the result is vector ``i`` plus the ``matrix``-weighted sum of all
        transformed vectors.
        """
        hidden_vectors = torch.relu(self.W_fingerprint[layer](vectors))
        return hidden_vectors + torch.mm(matrix, hidden_vectors)

    def sum(self, vectors, axis):
        """Sum ``vectors`` over dim 0 within each chunk given by ``axis``.

        Despite its name, ``axis`` is forwarded to ``torch.split`` as the
        chunk size(s) along dim 0; the result has one row per chunk.
        """
        sum_vectors = [torch.sum(v, 0) for v in torch.split(vectors, axis)]
        return torch.stack(sum_vectors)

    def mean(self, vectors, axis):
        """Average ``vectors`` over dim 0 within each chunk given by ``axis``.

        ``axis`` is forwarded to ``torch.split`` as the chunk size(s) along
        dim 0; the result has one row per chunk.
        """
        mean_vectors = [torch.mean(v, 0) for v in torch.split(vectors, axis)]
        return torch.stack(mean_vectors)

    def forward(self, inputs):
        """Compute one vector per molecule.

        Args:
            inputs: a 3-item sequence ``(fingerprints, adjacencies,
                molecular_sizes)``. ``fingerprints`` is a sequence of index
                tensors (concatenated and fed to the embedding),
                ``adjacencies`` a sequence of 2-D matrices (combined
                block-diagonally), and ``molecular_sizes`` the per-molecule
                row counts used to split the result.

        Returns:
            Tensor with one row per entry of ``molecular_sizes`` and ``dim``
            columns.
        """
        # Cat or pad each input data for batch processing.
        fingerprints, adjacencies, molecular_sizes = inputs
        # NOTE(review): the concatenated indices are not moved to self.device
        # here; presumably the caller already placed them there - confirm.
        fingerprints = torch.cat(fingerprints)
        adjacencies = self.pad(adjacencies, 0)

        # GNN layer (update the fingerprint vectors).
        fingerprint_vectors = self.embed_fingerprint(fingerprints)
        for layer in range(self.layer_hidden):
            hs = self.update(adjacencies, fingerprint_vectors, layer)
            # fingerprint_vectors = F.normalize(hs, 2, 1)  # normalize.
            fingerprint_vectors = hs

        # Molecular vector by sum or mean of the fingerprint vectors.
        molecular_vectors = self.sum(fingerprint_vectors, molecular_sizes)
        # molecular_vectors = self.mean(fingerprint_vectors, molecular_sizes)

        return molecular_vectors

    
class TestModelStruc(nn.Module):
    """Recommendation model over admission sequences with molecule embeddings.

    Two code vocabularies are embedded, averaged per admission, and encoded
    by one GRU each. The last admission's representation forms a query that
    is projected to ``vocab_size[2]`` output logits.

    Args:
        vocab_size: sequence of three vocabulary sizes; entries 0 and 1 size
            the two input embeddings, entry 2 sizes the output layer.
        ddi_adj: square array-like multiplied element-wise with the pairwise
            probability matrix in ``forward``.
        ddi_matrix: 2-D array exposing ``.shape``; its second dimension sets
            the width of the intermediate ``fraction_linear`` layer.
        GNNSet: iterable of per-molecule tuples; it is transposed with
            ``zip(*GNNSet)`` and passed to ``MolecularGraphNeuralNetwork``.
        N_fingerprints: number of fingerprint embeddings for the GNN.
        emb_dim: embedding width.
        device: device used for input indices and the ddi tensors.
    """

    def __init__(self, vocab_size, ddi_adj, ddi_matrix, GNNSet, N_fingerprints, emb_dim=256, device=torch.device('cpu:0')):
        super(TestModelStruc, self).__init__()

        self.device = device

        # pre-embedding
        self.embeddings = nn.ModuleList(
            [nn.Embedding(vocab_size[i], emb_dim) for i in range(2)])
        self.dropout = nn.Dropout(p=0.7)
        self.encoders = nn.ModuleList([nn.GRU(emb_dim, emb_dim*2, batch_first=True) for _ in range(2)])
        self.query = nn.Sequential(
                nn.ReLU(),
                nn.Linear(4 * emb_dim, emb_dim)
        )

        # query -> ddi_matrix columns -> output vocabulary
        self.fraction_linear = nn.Sequential(
            nn.Linear(emb_dim, ddi_matrix.shape[1])
        )
        # self.output = MaskLinear(ddi_matrix.shape[1], vocab_size[2])
        self.output = nn.Linear(ddi_matrix.shape[1], vocab_size[2])

        # molecule (GNN) branch
        self.GNN_linear = nn.Linear(emb_dim, vocab_size[2])
        self.GNNSet = list(zip(*GNNSet))
        # The GNN is intentionally kept as a local: assigning it to an
        # attribute would register it as a submodule and change the
        # state_dict keys, breaking previously saved checkpoints.
        gnn = MolecularGraphNeuralNetwork(N_fingerprints, emb_dim, layer_hidden=2, device=device)
        self.GNN_emb = gnn(self.GNNSet)
        # Independent leaf copy of the GNN output. clone().detach() is the
        # documented way to copy-construct from a tensor; torch.tensor(tensor)
        # emits a UserWarning.
        # NOTE(review): this is a plain tensor attribute, not an nn.Parameter,
        # so it is absent from parameters() and state_dict().
        self.GNN_emb_fix = self.GNN_emb.clone().detach().requires_grad_(True)

        # graphs
        self.tensor_ddi_adj = torch.FloatTensor(ddi_adj).to(device)
        self.tensor_ddi_matrix = torch.FloatTensor(ddi_matrix).to(device)
        self.init_weights()

    def forward(self, input):
        """Score the output vocabulary for the last admission of a sequence.

        Args:
            input: iterable of admissions. For each admission ``adm``,
                ``adm[0]`` and ``adm[1]`` are sequences of integer indices
                into the first and second embedding tables respectively.

        Returns:
            ``(result, batch_neg)`` where ``result`` has shape
            ``(1, vocab_size[2])`` (pre-sigmoid scores) and ``batch_neg`` is
            the scalar sum of pairwise probability products weighted by
            ``ddi_adj``.
        """
        # generate medical embeddings and queries
        i1_seq = []
        i2_seq = []

        def mean_embedding(embedding):
            return embedding.mean(dim=1).unsqueeze(dim=0)  # (1,1,dim)

        for adm in input:
            i1 = mean_embedding(self.dropout(self.embeddings[0](torch.LongTensor(adm[0]).unsqueeze(dim=0).to(self.device)))) # (1,1,dim)
            i2 = mean_embedding(self.dropout(self.embeddings[1](torch.LongTensor(adm[1]).unsqueeze(dim=0).to(self.device))))
            i1_seq.append(i1)
            i2_seq.append(i2)
        i1_seq = torch.cat(i1_seq, dim=1) #(1,seq,dim)
        i2_seq = torch.cat(i2_seq, dim=1) #(1,seq,dim)

        # Only the per-step outputs are used; the final hidden states are
        # discarded.
        o1, _ = self.encoders[0](
            i1_seq
        ) # o1:(1, seq, dim*2)
        o2, _ = self.encoders[1](
            i2_seq
        )
        patient_representations = torch.cat([o1, o2], dim=-1).squeeze(dim=0) # (seq, dim*4)
        # Keep only the last admission's row.
        query = self.query(patient_representations)[-1:, :] # (1, dim)

        # attention over the GNN molecule embeddings
        att = F.softmax(torch.mm(query, self.GNN_emb_fix.t()), dim=-1)  # (1, size)
        # Currently unused: only consumed by the commented-out
        # `torch.mul(fraction_out, GNN_out)` variant below.
        GNN_out = self.GNN_linear(torch.mm(att, self.GNN_emb_fix))  # (1, vocab_size[2])

        # fraction branch
        # fraction_out = self.output(self.fraction_linear(query), self.tensor_ddi_matrix.t())
        fraction_out = self.output(self.fraction_linear(query))

        # result = torch.mul(fraction_out, GNN_out)
        result = fraction_out

        # torch.sigmoid is used instead of F.sigmoid, which warns as
        # deprecated on older PyTorch releases.
        neg_pred_prob = torch.sigmoid(result)
        neg_pred_prob = neg_pred_prob.t() * neg_pred_prob  # (voc_size, voc_size)
        batch_neg = neg_pred_prob.mul(self.tensor_ddi_adj).sum()

        return result, batch_neg

    def init_weights(self):
        """Initialize the two input embedding tables uniformly in [-0.1, 0.1]."""
        initrange = 0.1
        for item in self.embeddings:
            item.weight.data.uniform_(-initrange, initrange)

In [None]:
# Checkpoint to resume from.
# NOTE(review): this is a hardcoded absolute path; consider deriving it from a
# configurable directory. The previous `.format(model_name)` call was dropped:
# the string has no placeholders, so it was a no-op that only added a
# dependency on an otherwise unused `model_name` variable.
resume_name = '/home/chaoqiy2/TEMP/GAMENet/code2/saved/emb_LR_drug_molecule/Epoch_39_JA_0.5123_DDI_0.0597.model'

# NOTE(review): voc_size, ddi_adj, ddi_matrix, GNNSet, N_fingerprint and device
# are not defined in the visible cells; they must come from an earlier cell.
model = TestModelStruc(voc_size, ddi_adj, ddi_matrix, GNNSet, N_fingerprint, emb_dim=256, device=device)
# Pass the path (not an open file object) so no file handle is leaked, and
# remap storages to `device` so a GPU-saved checkpoint loads on CPU-only hosts.
# SECURITY: torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load(resume_name, map_location=device))