In [2]:
import os
import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from torch_geometric.data import HeteroData
from torch_geometric.nn import GENConv, HeteroConv

In [3]:
# Pre-built heterogeneous lncRNA–protein interaction graph (summary printed below).
GRAPH_PATH = os.path.join('data', 'combined_dbs_heteroGraph.pt')

# The file holds a pickled HeteroData object rather than a plain tensor state_dict.
# PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle it, so we
# opt out explicitly. Only do this for files you trust: unpickling can execute code.
hetero_data = torch.load(GRAPH_PATH, weights_only=False)
print(hetero_data)

HeteroData(
  lncRNA={ x=[1269, 2] },
  protein={ x=[11585, 2] },
  (lncRNA, interacts, protein)={
    edge_index=[2, 7635],
    edge_attr=[7635, 4],
  },
  (protein, interacts, protein)={
    edge_index=[2, 148992],
    edge_attr=[148992, 1],
  }
)


In [4]:
class HeteroGNNEncoder(nn.Module):
    """Single-layer heterogeneous GNN encoder built from one GENConv per edge type.

    Every edge type in ``metadata[1]`` gets its own ``GENConv``. ``HeteroConv`` runs
    them and sums the per-relation outputs for each destination node type. A
    per-node-type ``LayerNorm`` is then applied to each result.

    Args:
        in_channels: Input feature size passed to every GENConv (shared by all
            node types).
        out_channels: Output embedding size.
        metadata: ``(node_types, edge_types)`` pair, e.g. ``HeteroData.metadata()``.
        hidden_channels: Currently unused; accepted only for backward
            compatibility with existing callers.
        edge_dims: Optional mapping ``edge_type tuple -> edge_attr width`` used
            as GENConv's ``edge_dim``. Defaults to ``DEFAULT_EDGE_DIMS``. It must
            cover every edge type in ``metadata[1]``, including any reverse edge
            types you add.

    Raises:
        KeyError: if an edge type in ``metadata[1]`` has no entry in ``edge_dims``.
    """

    # edge_attr widths of the two relations in the original combined_dbs graph.
    DEFAULT_EDGE_DIMS = {
        ('lncRNA', 'interacts', 'protein'): 4,
        ('protein', 'interacts', 'protein'): 1,
    }

    def __init__(self, in_channels, out_channels, metadata, hidden_channels=64,
                 edge_dims=None):
        super().__init__()

        if edge_dims is None:
            edge_dims = self.DEFAULT_EDGE_DIMS

        # Fail early with an actionable message instead of a bare KeyError
        # from inside the comprehension below.
        missing = [edge_type for edge_type in metadata[1] if edge_type not in edge_dims]
        if missing:
            raise KeyError(
                f"No edge_dim configured for edge type(s) {missing}; "
                f"pass edge_dims={{edge_type: edge_attr_width, ...}}"
            )

        # Store convs in a ModuleDict keyed by '__'-joined strings, because
        # nn.ModuleDict keys must be strings.
        # Note: the same module objects are registered again inside
        # self.hetero_conv below. Parameters are shared (not duplicated), but
        # state_dict() lists them under both prefixes.
        self.convs = nn.ModuleDict({
            '__'.join(edge_type): GENConv(
                in_channels=in_channels,
                out_channels=out_channels,
                edge_dim=edge_dims[edge_type],
                aggr='softmax',
                t=1.0,
                learn_t=True,
                num_layers=2,
                norm='layer'
            )
            for edge_type in metadata[1]
        })

        # HeteroConv takes the original (src, rel, dst) tuple keys.
        self.hetero_conv = HeteroConv({
            edge_type: self.convs['__'.join(edge_type)]
            for edge_type in metadata[1]
        }, aggr='sum')

        # One LayerNorm per node type.
        self.norms = nn.ModuleDict({
            node_type: nn.LayerNorm(out_channels)
            for node_type in metadata[0]
        })

    def forward(self, x_dict, edge_index_dict, edge_attr_dict):
        """Run one round of message passing and normalise the outputs.

        Args:
            x_dict: node type -> node feature tensor.
            edge_index_dict: edge type -> edge_index tensor.
            edge_attr_dict: edge type -> edge_attr tensor. Its width must match
                the ``edge_dims`` entry for that edge type.

        Returns:
            dict mapping node type -> normalised embedding. Only node types that
            are the destination of at least one edge type appear. A node type that
            only acts as a source (e.g. lncRNA in the original graph) gets no
            entry unless reverse edges are added.
        """
        out_dict = self.hetero_conv(x_dict, edge_index_dict, edge_attr_dict)

        return {
            node_type: self.norms[node_type](x)
            for node_type, x in out_dict.items()
        }


In [5]:
# Seed so the randomly initialised encoder (and therefore the embeddings) is reproducible.
torch.manual_seed(42)

IN_CHANNELS = 2  # every node type in this graph carries 2 input features
assert all(x.size(-1) == IN_CHANNELS for x in hetero_data.x_dict.values()), \
    "in_channels does not match node feature width"

encoder = HeteroGNNEncoder(
    in_channels=IN_CHANNELS,
    out_channels=128,
    metadata=hetero_data.metadata(),
    hidden_channels=64,
)
# Inference only. Switching to eval mode guards against training-only layers
# (dropout, batch norm) being added to the encoder later.
encoder.eval()

with torch.no_grad():
    embeddings = encoder(
        hetero_data.x_dict,
        hetero_data.edge_index_dict,
        hetero_data.edge_attr_dict,
    )

for node_type, emb in embeddings.items():
    print(f"{node_type} embedding shape: {emb.shape}")

# Node types that never receive messages get no embedding at all; report them
# explicitly instead of silently omitting them.
missing_types = [nt for nt in hetero_data.node_types if nt not in embeddings]
if missing_types:
    print(f"No embedding produced for: {missing_types} (no incoming edge types)")




protein embedding shape: torch.Size([11585, 128])


In [None]:
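## No embedding is produced for lncRNA: it only appears as the source of edges, so it never receives any messages.
## Next step: add reverse edges (protein -> lncRNA), e.g. with torch_geometric.transforms.ToUndirected, so lncRNA nodes get embeddings.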